Skip to main content

kaish_kernel/scheduler/
scatter.rs

1//! Scatter/Gather — Parallel pipeline execution.
2//!
3//! Scatter splits input into items and runs the pipeline in parallel.
4//! Gather collects the parallel results.
5//!
6//! # Example
7//!
8//! ```text
9//! cat urls.txt | scatter | fetch url=${ITEM} | gather
10//! ```
11//!
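//! This reads URLs, then for each URL runs `fetch` in parallel,
//! then collects all results.
//!
//! Note: scatter does not split text on its own. Multi-line text that arrives
//! without structured data is treated as a single item, so the stage feeding
//! `scatter` must emit a structured array for the items to fan out.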
14
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::time::Duration;
18
19use tokio::sync::Semaphore;
20use tracing::Instrument;
21
22use crate::ast::{Command, Value};
23use crate::dispatch::CommandDispatcher;
24use crate::duration::parse_duration;
25use crate::interpreter::ExecResult;
26use crate::tools::{ExecContext, ToolRegistry};
27
28use super::pipeline::PipelineRunner;
29
/// Options for scatter operation.
///
/// Controls how the parallel stage binds items and how much concurrency it
/// is allowed. See the `Default` impl for the values used when an option is
/// not supplied.
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Variable name to bind each item to (default: "ITEM").
    ///
    /// The item is stored in each worker's scope as a string value.
    pub var_name: String,
    /// Maximum parallelism (default: 8).
    ///
    /// Used as the permit count of the semaphore that gates worker spawning.
    pub limit: usize,
    /// Per-worker timeout. When `Some`, each worker is cancelled after this
    /// duration; the worker's external children get SIGTERM/SIGKILL and the
    /// `ScatterResult.timed_out` flag is set.
    ///
    /// `None` (the default) means workers are never cancelled by a timer.
    pub timeout: Option<Duration>,
}
42
/// Options for gather operation.
///
/// Controls how the per-worker results are folded into one output string.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Show progress indicator.
    ///
    /// NOTE(review): nothing in this file reads this flag; presumably it is
    /// consumed elsewhere — confirm.
    pub progress: bool,
    /// Take first N results and cancel rest (0 = all).
    ///
    /// NOTE(review): `gather_results` in this file only truncates the already
    /// collected results to the first N; no cancellation is visible here.
    pub first: usize,
    /// Output format: "json" or "lines".
    ///
    /// Only the exact string "json" selects JSON output; any other value is
    /// handled as "lines".
    pub format: String,
}
53
54impl Default for ScatterOptions {
55    fn default() -> Self {
56        Self {
57            var_name: "ITEM".to_string(),
58            limit: 8,
59            timeout: None,
60        }
61    }
62}
63
64impl Default for GatherOptions {
65    fn default() -> Self {
66        Self {
67            progress: false,
68            first: 0,
69            format: "lines".to_string(),
70        }
71    }
72}
73
/// Result from a single scatter worker.
///
/// One value is produced per input item by the parallel stage and later
/// consumed by the gather step.
#[derive(Debug, Clone)]
pub struct ScatterResult {
    /// The input item that was processed.
    pub item: String,
    /// The execution result.
    pub result: ExecResult,
    /// Whether the worker was cancelled by the per-worker `--timeout`.
    ///
    /// Set from a flag that only the timeout timer writes, so a cancellation
    /// coming from the parent leaves this `false`.
    pub timed_out: bool,
}
84
/// Runs scatter/gather pipelines.
///
/// Uses a single dispatcher for sequential stages (pre_scatter, post_gather),
/// and forks it per parallel worker via [`CommandDispatcher::fork`]. Each
/// worker gets its own subkernel with snapshotted session state so they can
/// run concurrently without racing on scope/cwd/aliases.
pub struct ScatterGatherRunner {
    /// Tool registry handed to every `PipelineRunner` this runner creates,
    /// both for the sequential stages and for each parallel worker.
    tools: Arc<ToolRegistry>,
    /// Full dispatch chain for sequential stages (pre_scatter, post_gather).
    /// Parallel workers fork from this dispatcher.
    sequential_dispatcher: Arc<dyn CommandDispatcher>,
}
97
98impl ScatterGatherRunner {
99    /// Create a new scatter/gather runner.
100    ///
101    /// `dispatcher` drives sequential stages directly and serves as the fork
102    /// source for parallel workers.
103    pub fn new(
104        tools: Arc<ToolRegistry>,
105        dispatcher: Arc<dyn CommandDispatcher>,
106    ) -> Self {
107        Self { tools, sequential_dispatcher: dispatcher }
108    }
109
110    /// Execute a scatter/gather pipeline.
111    ///
112    /// The pipeline is split into three parts:
113    /// - pre_scatter: commands before scatter
114    /// - parallel: commands between scatter and gather
115    /// - post_gather: commands after gather
116    ///
117    /// Returns the final result after all stages complete.
118    #[tracing::instrument(level = "info", skip(self, pre_scatter, scatter_opts, parallel, gather_opts, post_gather, ctx), fields(item_count = tracing::field::Empty, parallelism = scatter_opts.limit))]
119    pub async fn run(
120        &self,
121        pre_scatter: &[Command],
122        scatter_opts: ScatterOptions,
123        parallel: &[Command],
124        gather_opts: GatherOptions,
125        post_gather: &[Command],
126        ctx: &mut ExecContext,
127    ) -> ExecResult {
128        let runner = PipelineRunner::new(self.tools.clone());
129
130        // Run pre-scatter commands to get input.
131        // Uses run_sequential to avoid async recursion (scatter → run → scatter).
132        let (text, data) = if pre_scatter.is_empty() {
133            // Use existing stdin
134            let data = ctx.take_stdin_data();
135            let text = ctx.take_stdin().unwrap_or_default();
136            (text, data)
137        } else {
138            let result = runner.run_sequential(pre_scatter, ctx, &*self.sequential_dispatcher).await;
139            if !result.ok() {
140                return result;
141            }
142            (result.text_out().into_owned(), result.data)
143        };
144
145        // Extract items from structured data or text
146        let items = match extract_items(data.as_ref(), &text) {
147            Ok(items) => items,
148            Err(msg) => return ExecResult::failure(1, msg),
149        };
150        if items.is_empty() {
151            return ExecResult::success("");
152        }
153
154        tracing::Span::current().record("item_count", items.len());
155
156        // Run parallel stage
157        let results = self
158            .run_parallel(&items, &scatter_opts, parallel, ctx)
159            .await;
160
161        // Gather results
162        let gathered = gather_results(&results, &gather_opts);
163
164        // Run post-gather commands if any
165        if post_gather.is_empty() {
166            ExecResult::success(gathered)
167        } else {
168            ctx.set_stdin(gathered);
169            runner.run_sequential(post_gather, ctx, &*self.sequential_dispatcher).await
170        }
171    }
172
173    /// Run the parallel stage for all items.
174    ///
175    /// Each worker gets its own forked dispatcher via
176    /// [`CommandDispatcher::fork`]. The fork snapshots per-session state
177    /// (scope, cwd, aliases, user tools) so workers can run concurrently
178    /// without racing. Forks are cheap (Scope is COW, plus a few Arc bumps),
179    /// and they unlock the full dispatch chain inside workers — user tools,
180    /// `.kai` scripts, and `$(...)` in args all work.
181    #[tracing::instrument(level = "debug", skip(self, items, opts, commands, base_ctx), fields(worker_count = items.len()))]
182    async fn run_parallel(
183        &self,
184        items: &[String],
185        opts: &ScatterOptions,
186        commands: &[Command],
187        base_ctx: &ExecContext,
188    ) -> Vec<ScatterResult> {
189        let semaphore = Arc::new(Semaphore::new(opts.limit));
190        let tools = self.tools.clone();
191        let var_name = opts.var_name.clone();
192
193        // Spawn parallel tasks
194        let mut handles = Vec::with_capacity(items.len());
195
196        for item in items.iter().cloned() {
197            let permit = semaphore.clone().acquire_owned().await;
198            let tools = tools.clone();
199            // Fork attached: the worker's cancel token is a child of the
200            // parent kernel's, so a parent cancel (request timeout, embedder
201            // Kernel::cancel) cascades into the worker and kills its
202            // external children via the wait_or_kill discipline.
203            let worker_dispatcher = self.sequential_dispatcher.fork_attached().await;
204            let commands = commands.to_vec();
205            let var_name = var_name.clone();
206            let base_scope = base_ctx.scope.clone();
207            let backend = base_ctx.backend.clone();
208            let cwd = base_ctx.cwd.clone();
209            let parent_token = base_ctx.cancel.clone();
210            let worker_token = parent_token.child_token();
211
212            // Per-worker timeout: spawn a delay task that cancels the worker's
213            // child token after `opts.timeout`. The cancel cascades into the
214            // worker's externals via the fork's cancel link. `timed_out_flag`
215            // distinguishes timeout from explicit parent cancellation when
216            // tagging ScatterResult.
217            let timed_out_flag = Arc::new(AtomicBool::new(false));
218            let timer_handle: Option<tokio::task::JoinHandle<()>> = opts.timeout.map(|d| {
219                let cancel = worker_token.clone();
220                let flag = timed_out_flag.clone();
221                tokio::spawn(async move {
222                    tokio::time::sleep(d).await;
223                    flag.store(true, Ordering::SeqCst);
224                    cancel.cancel();
225                })
226            });
227            let timed_out_check = timed_out_flag.clone();
228
229            let item_label = if item.len() > 64 {
230                format!("{}...", &item[..64])
231            } else {
232                item.clone()
233            };
234            let worker_span = tracing::debug_span!("scatter_worker", item = %item_label);
235            let handle = tokio::spawn(async move {
236                let _permit = permit; // Hold permit until done
237
238                // Create context for this worker
239                let mut scope = base_scope;
240                scope.set(&var_name, Value::String(item.clone()));
241
242                let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
243                ctx.set_cwd(cwd);
244                ctx.cancel = worker_token;
245
246                // Run through PipelineRunner + dispatcher (full resolution chain).
247                // Uses run_sequential to avoid async recursion and infinite future size.
248                let runner = PipelineRunner::new(tools);
249                let result = runner.run_sequential(&commands, &mut ctx, &*worker_dispatcher).await;
250
251                // Worker finished — abort the timer if still pending so it
252                // doesn't fire a now-pointless cancel and idle resources.
253                if let Some(h) = timer_handle {
254                    h.abort();
255                }
256
257                let timed_out = timed_out_check.load(Ordering::SeqCst);
258                ScatterResult { item, result, timed_out }
259            }.instrument(worker_span));
260
261            handles.push(handle);
262        }
263
264        // Collect results
265        let mut results = Vec::with_capacity(handles.len());
266        for handle in handles {
267            match handle.await {
268                Ok(result) => results.push(result),
269                Err(e) => {
270                    results.push(ScatterResult {
271                        item: String::new(),
272                        result: ExecResult::failure(1, format!("Task panicked: {}", e)),
273                        timed_out: false,
274                    });
275                }
276            }
277        }
278
279        results
280    }
281}
282
283/// Extract items from structured data or text.
284///
285/// kaish does not split implicitly — this function requires structured data
286/// (JSON array from split/seq/glob/find) for multi-item input. Single-line
287/// text is treated as one item. Multi-line text without structured data is
288/// an error.
289pub fn extract_items(data: Option<&Value>, text: &str) -> Result<Vec<String>, String> {
290    // 1. Structured data (JSON array from split/seq/glob/find) — use it
291    if let Some(Value::Json(serde_json::Value::Array(arr))) = data {
292        return Ok(arr.iter().map(|v| match v {
293            serde_json::Value::String(s) => s.clone(),
294            other => other.to_string(),
295        }).collect());
296    }
297    if let Some(Value::String(s)) = data {
298        return Ok(vec![s.clone()]);
299    }
300
301    // 2. Empty — return empty
302    let trimmed = text.trim();
303    if trimmed.is_empty() {
304        return Ok(vec![]);
305    }
306
307    // 3. Raw text without structured data — one item (no implicit splitting)
308    Ok(vec![trimmed.to_string()])
309}
310
311/// Gather results into output string.
312fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
313    let results_to_use = if opts.first > 0 && opts.first < results.len() {
314        &results[..opts.first]
315    } else {
316        results
317    };
318
319    if opts.format == "json" {
320        // Output as JSON array of objects
321        let json_results: Vec<serde_json::Value> = results_to_use
322            .iter()
323            .map(|r| {
324                serde_json::json!({
325                    "item": r.item,
326                    "ok": r.result.ok(),
327                    "code": r.result.code,
328                    "out": r.result.text_out().trim(),
329                    "err": r.result.err.trim(),
330                    "timed_out": r.timed_out,
331                })
332            })
333            .collect();
334
335        serde_json::to_string_pretty(&json_results).unwrap_or_default()
336    } else {
337        // Output as lines (stdout from each, separated by newlines)
338        results_to_use
339            .iter()
340            .filter(|r| r.result.ok())
341            .map(|r| r.result.text_out())
342            .map(|t| t.trim().to_string())
343            .collect::<Vec<_>>()
344            .join("\n")
345    }
346}
347
348/// Parse scatter options from tool args.
349pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
350    let mut opts = ScatterOptions::default();
351
352    if let Some(Value::String(name)) = args.named.get("as") {
353        opts.var_name = name.clone();
354    }
355
356    if let Some(Value::Int(n)) = args.named.get("limit") {
357        let requested = *n;
358        let clamped = requested.clamp(1, SCATTER_LIMIT_MAX as i64);
359        if requested > SCATTER_LIMIT_MAX as i64 {
360            tracing::warn!(
361                target: "kaish::scatter",
362                requested = requested,
363                ceiling = SCATTER_LIMIT_MAX,
364                "scatter limit clamped to ceiling"
365            );
366        }
367        opts.limit = clamped as usize;
368    }
369
370    // --timeout DURATION: per-worker timeout. Accepts the same forms as the
371    // `timeout` builtin (30, 5s, 500ms, 2m, 1h). Invalid input is ignored
372    // with a warn so a typo doesn't silently disable cancellation.
373    if let Some(Value::String(s)) = args.named.get("timeout") {
374        match parse_duration(s) {
375            Some(d) => opts.timeout = Some(d),
376            None => tracing::warn!(
377                target: "kaish::scatter",
378                value = %s,
379                "scatter --timeout: invalid duration (try: 30, 5s, 500ms, 2m, 1h)"
380            ),
381        }
382    } else if let Some(Value::Int(n)) = args.named.get("timeout") {
383        if *n >= 0 {
384            opts.timeout = Some(Duration::from_secs(*n as u64));
385        }
386    }
387
388    opts
389}
390
/// Upper bound on the concurrency `scatter limit=N` accepts. Users who
/// ask for more get a `tracing::warn` and are clamped to this value —
/// silent clamping would violate the "no silent fallbacks" rule.
///
/// Applied by `parse_scatter_options` when it reads the `limit` argument.
pub const SCATTER_LIMIT_MAX: usize = 10_000;
395
396/// Parse gather options from tool args.
397pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
398    let mut opts = GatherOptions::default();
399
400    if args.has_flag("progress") {
401        opts.progress = true;
402    }
403
404    if let Some(Value::Int(n)) = args.named.get("first") {
405        opts.first = (*n).max(0) as usize;
406    }
407
408    if let Some(Value::String(fmt)) = args.named.get("format") {
409        opts.format = fmt.clone();
410    }
411
412    opts
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Build a successful, non-timed-out worker result.
    fn ok_result(item: &str, out: &str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::success(out),
            timed_out: false,
        }
    }

    /// Build tool args holding the given named values.
    fn named_args(pairs: Vec<(&str, Value)>) -> ToolArgs {
        let mut args = ToolArgs::new();
        for (key, value) in pairs {
            args.named.insert(key.to_string(), value);
        }
        args
    }

    /// Parse scatter options from args that carry only `limit`.
    fn scatter_limit_for(requested: i64) -> usize {
        parse_scatter_options(&named_args(vec![("limit", Value::Int(requested))])).limit
    }

    #[test]
    fn test_extract_items_structured_json_array() {
        let data = Value::Json(serde_json::json!(["a", "b", "c"]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_extract_items_structured_mixed_types() {
        let data = Value::Json(serde_json::json!([1, "two", true]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["1", "two", "true"]);
    }

    #[test]
    fn test_extract_items_structured_string() {
        let data = Value::String("single".into());
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["single"]);
    }

    #[test]
    fn test_extract_items_single_line_text() {
        let items = extract_items(None, "hello").unwrap();
        assert_eq!(items, vec!["hello"]);
    }

    #[test]
    fn test_extract_items_empty() {
        let items = extract_items(None, "").unwrap();
        assert!(items.is_empty());
    }

    #[test]
    fn test_extract_items_multiline_is_one_item() {
        // No implicit splitting — multi-line text is one item
        let items = extract_items(None, "one\ntwo\nthree").unwrap();
        assert_eq!(items, vec!["one\ntwo\nthree"]);
    }

    #[test]
    fn test_extract_items_structured_overrides_text() {
        // Structured data takes priority over text
        let data = Value::Json(serde_json::json!(["x", "y"]));
        let items = extract_items(Some(&data), "ignored\ntext").unwrap();
        assert_eq!(items, vec!["x", "y"]);
    }

    #[test]
    fn test_gather_results_lines() {
        let results = vec![ok_result("a", "result_a"), ok_result("b", "result_b")];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_json() {
        let results = vec![ok_result("test", "output")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };

        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = vec![ok_result("a", "1"), ok_result("b", "2"), ok_result("c", "3")];
        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };

        let output = gather_results(&results, &opts);
        assert_eq!(output, "1\n2");
    }

    #[test]
    fn test_parse_scatter_options() {
        let args = named_args(vec![
            ("as", Value::String("URL".to_string())),
            ("limit", Value::Int(4)),
        ]);

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_gather_options() {
        let args = named_args(vec![
            ("first", Value::Int(5)),
            ("format", Value::String("json".to_string())),
        ]);

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }

    #[test]
    fn scatter_limit_clamps_to_ceiling() {
        assert_eq!(scatter_limit_for(999_999), SCATTER_LIMIT_MAX);
    }

    #[test]
    fn scatter_limit_raises_zero_to_one() {
        assert_eq!(scatter_limit_for(0), 1);
    }

    #[test]
    fn scatter_limit_raises_negative_to_one() {
        assert_eq!(scatter_limit_for(-42), 1);
    }

    #[test]
    fn scatter_limit_preserves_valid_values() {
        assert_eq!(scatter_limit_for(500), 500);
    }
}