Skip to main content

kaish_kernel/scheduler/
scatter.rs

1//! Scatter/Gather — Parallel pipeline execution.
2//!
3//! Scatter splits input into items and runs the pipeline in parallel.
4//! Gather collects the parallel results.
5//!
6//! # Example
7//!
8//! ```text
9//! cat urls.txt | scatter | fetch url=${ITEM} | gather
10//! ```
11//!
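//! This reads URLs, then runs `fetch` once per URL in parallel and
//! collects all results.
//!
//! Note that scatter never splits text implicitly (see [`extract_items`]):
//! plain text is treated as a single item, so fanning out over several
//! items requires structured input — a JSON array such as the output of
//! `split`, `seq`, `glob`, or `find`.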
14
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::time::Duration;
18
19use tokio::sync::Semaphore;
20use tracing::Instrument;
21
22use crate::ast::{Command, Value};
23use crate::dispatch::CommandDispatcher;
24use crate::duration::parse_duration;
25use crate::interpreter::ExecResult;
26use crate::tools::{ExecContext, ToolRegistry};
27
28use super::pipeline::PipelineRunner;
29
/// Tuning knobs for a `scatter` stage.
#[derive(Clone, Debug)]
pub struct ScatterOptions {
    /// Name of the variable each worker sees its item under
    /// (`"ITEM"` unless overridden).
    pub var_name: String,
    /// Upper bound on concurrently running workers (8 unless overridden).
    pub limit: usize,
    /// Optional per-worker deadline. When set, a worker still running after
    /// this long is cancelled: its external children receive SIGTERM/SIGKILL
    /// and the `ScatterResult.timed_out` flag is raised for it.
    pub timeout: Option<Duration>,
}
42
/// Options for gather operation.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Show progress indicator.
    ///
    /// NOTE(review): not read anywhere in this module — confirm whether a
    /// caller elsewhere renders progress.
    pub progress: bool,
    /// Keep only the first N results, counted in input order (0 = all).
    ///
    /// Truncation is applied after every worker has finished; the parallel
    /// stage in this module does not cancel the remaining workers early.
    pub first: usize,
    /// Output format: `"json"` emits every task as a row (failures
    /// included); any other value — the default is `"lines"` — emits the
    /// stdout of each successful task, one per line.
    pub format: String,
}
53
54impl Default for ScatterOptions {
55    fn default() -> Self {
56        Self {
57            var_name: "ITEM".to_string(),
58            limit: 8,
59            timeout: None,
60        }
61    }
62}
63
64impl Default for GatherOptions {
65    fn default() -> Self {
66        Self {
67            progress: false,
68            first: 0,
69            format: "lines".to_string(),
70        }
71    }
72}
73
74/// Result from a single scatter worker.
75#[derive(Debug, Clone)]
76pub struct ScatterResult {
77    /// The input item that was processed.
78    pub item: String,
79    /// The execution result.
80    pub result: ExecResult,
81    /// Whether the worker was cancelled by the per-worker `--timeout`.
82    pub timed_out: bool,
83}
84
85/// Runs scatter/gather pipelines.
86///
87/// Uses a single dispatcher for sequential stages (pre_scatter, post_gather),
88/// and forks it per parallel worker via [`CommandDispatcher::fork`]. Each
89/// worker gets its own subkernel with snapshotted session state so they can
90/// run concurrently without racing on scope/cwd/aliases.
91pub struct ScatterGatherRunner {
92    tools: Arc<ToolRegistry>,
93    /// Full dispatch chain for sequential stages (pre_scatter, post_gather).
94    /// Parallel workers fork from this dispatcher.
95    sequential_dispatcher: Arc<dyn CommandDispatcher>,
96}
97
98impl ScatterGatherRunner {
99    /// Create a new scatter/gather runner.
100    ///
101    /// `dispatcher` drives sequential stages directly and serves as the fork
102    /// source for parallel workers.
103    pub fn new(
104        tools: Arc<ToolRegistry>,
105        dispatcher: Arc<dyn CommandDispatcher>,
106    ) -> Self {
107        Self { tools, sequential_dispatcher: dispatcher }
108    }
109
110    /// Execute a scatter/gather pipeline.
111    ///
112    /// The pipeline is split into three parts:
113    /// - pre_scatter: commands before scatter
114    /// - parallel: commands between scatter and gather
115    /// - post_gather: commands after gather
116    ///
117    /// Returns the final result after all stages complete.
118    #[tracing::instrument(level = "info", skip(self, pre_scatter, scatter_opts, parallel, gather_opts, post_gather, ctx), fields(item_count = tracing::field::Empty, parallelism = scatter_opts.limit))]
119    pub async fn run(
120        &self,
121        pre_scatter: &[Command],
122        scatter_opts: ScatterOptions,
123        parallel: &[Command],
124        gather_opts: GatherOptions,
125        post_gather: &[Command],
126        ctx: &mut ExecContext,
127    ) -> ExecResult {
128        let runner = PipelineRunner::new(self.tools.clone());
129
130        // Run pre-scatter commands to get input.
131        // Uses run_sequential to avoid async recursion (scatter → run → scatter).
132        let (text, data) = if pre_scatter.is_empty() {
133            // Use existing stdin
134            let data = ctx.take_stdin_data();
135            let text = ctx.take_stdin().unwrap_or_default();
136            (text, data)
137        } else {
138            let result = runner.run_sequential(pre_scatter, ctx, &*self.sequential_dispatcher).await;
139            if !result.ok() {
140                return result;
141            }
142            (result.text_out().into_owned(), result.data)
143        };
144
145        // Extract items from structured data or text
146        let items = match extract_items(data.as_ref(), &text) {
147            Ok(items) => items,
148            Err(msg) => return ExecResult::failure(1, msg),
149        };
150        if items.is_empty() {
151            return ExecResult::success("");
152        }
153
154        tracing::Span::current().record("item_count", items.len());
155
156        // Run parallel stage
157        let results = self
158            .run_parallel(&items, &scatter_opts, parallel, ctx)
159            .await;
160
161        // Gather results
162        let GatherOutput {
163            text: gathered,
164            dropped_failures,
165        } = gather_results(&results, &gather_opts);
166
167        // The line format can't carry a failed worker as a row. Rather than
168        // silently omit it (data corruption — the caller sees fewer rows than
169        // items scattered), fail loud: a non-zero exit plus an err naming the
170        // failed items. Feeding the truncated set into post-gather would
171        // propagate the corruption, so we short-circuit before running it.
172        if !dropped_failures.is_empty() {
173            let err = format!(
174                "gather: {} task(s) failed and were omitted from line output: {} (use --json to capture per-task status)",
175                dropped_failures.len(),
176                dropped_failures.join(", ")
177            );
178            return ExecResult::from_output(1, gathered, err);
179        }
180
181        // Run post-gather commands if any
182        if post_gather.is_empty() {
183            ExecResult::success(gathered)
184        } else {
185            ctx.set_stdin(gathered);
186            runner.run_sequential(post_gather, ctx, &*self.sequential_dispatcher).await
187        }
188    }
189
190    /// Run the parallel stage for all items.
191    ///
192    /// Each worker gets its own forked dispatcher via
193    /// [`CommandDispatcher::fork`]. The fork snapshots per-session state
194    /// (scope, cwd, aliases, user tools) so workers can run concurrently
195    /// without racing. Forks are cheap (Scope is COW, plus a few Arc bumps),
196    /// and they unlock the full dispatch chain inside workers — user tools,
197    /// `.kai` scripts, and `$(...)` in args all work.
198    #[tracing::instrument(level = "debug", skip(self, items, opts, commands, base_ctx), fields(worker_count = items.len()))]
199    async fn run_parallel(
200        &self,
201        items: &[String],
202        opts: &ScatterOptions,
203        commands: &[Command],
204        base_ctx: &ExecContext,
205    ) -> Vec<ScatterResult> {
206        let semaphore = Arc::new(Semaphore::new(opts.limit));
207        let tools = self.tools.clone();
208        let var_name = opts.var_name.clone();
209
210        // Spawn parallel tasks
211        let mut handles = Vec::with_capacity(items.len());
212
213        for item in items.iter().cloned() {
214            let permit = semaphore.clone().acquire_owned().await;
215            let tools = tools.clone();
216            // Fork attached: the worker's cancel token is a child of the
217            // parent kernel's, so a parent cancel (request timeout, embedder
218            // Kernel::cancel) cascades into the worker and kills its
219            // external children via the wait_or_kill discipline.
220            let worker_dispatcher = self.sequential_dispatcher.fork_attached().await;
221            let commands = commands.to_vec();
222            let var_name = var_name.clone();
223            let base_scope = base_ctx.scope.clone();
224            let backend = base_ctx.backend.clone();
225            let cwd = base_ctx.cwd.clone();
226            let parent_token = base_ctx.cancel.clone();
227            let worker_token = parent_token.child_token();
228
229            // Per-worker timeout: spawn a delay task that cancels the worker's
230            // child token after `opts.timeout`. The cancel cascades into the
231            // worker's externals via the fork's cancel link. `timed_out_flag`
232            // distinguishes timeout from explicit parent cancellation when
233            // tagging ScatterResult.
234            let timed_out_flag = Arc::new(AtomicBool::new(false));
235            let timer_handle: Option<tokio::task::JoinHandle<()>> = opts.timeout.map(|d| {
236                let cancel = worker_token.clone();
237                let flag = timed_out_flag.clone();
238                tokio::spawn(async move {
239                    tokio::time::sleep(d).await;
240                    flag.store(true, Ordering::SeqCst);
241                    cancel.cancel();
242                })
243            });
244            let timed_out_check = timed_out_flag.clone();
245
246            let item_label = if item.len() > 64 {
247                format!("{}...", &item[..64])
248            } else {
249                item.clone()
250            };
251            let worker_span = tracing::debug_span!("scatter_worker", item = %item_label);
252            // Propagate the embedder's trace context across the spawn boundary so
253            // each worker's spans stay in the same trace. `.instrument` below
254            // provides the tracing parent; this provides the OTel parent.
255            let handle = tokio::spawn(crate::telemetry::bind_current_context(async move {
256                let _permit = permit; // Hold permit until done
257
258                // Create context for this worker
259                let mut scope = base_scope;
260                scope.set(&var_name, Value::String(item.clone()));
261
262                let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
263                ctx.set_cwd(cwd);
264                ctx.cancel = worker_token;
265
266                // Run through PipelineRunner + dispatcher (full resolution chain).
267                // Uses run_sequential to avoid async recursion and infinite future size.
268                let runner = PipelineRunner::new(tools);
269                let result = runner.run_sequential(&commands, &mut ctx, &*worker_dispatcher).await;
270
271                // Worker finished — abort the timer if still pending so it
272                // doesn't fire a now-pointless cancel and idle resources.
273                if let Some(h) = timer_handle {
274                    h.abort();
275                }
276
277                let timed_out = timed_out_check.load(Ordering::SeqCst);
278                ScatterResult { item, result, timed_out }
279            }.instrument(worker_span)));
280
281            handles.push(handle);
282        }
283
284        // Collect results
285        let mut results = Vec::with_capacity(handles.len());
286        for handle in handles {
287            match handle.await {
288                Ok(result) => results.push(result),
289                Err(e) => {
290                    results.push(ScatterResult {
291                        item: String::new(),
292                        result: ExecResult::failure(1, format!("Task panicked: {}", e)),
293                        timed_out: false,
294                    });
295                }
296            }
297        }
298
299        results
300    }
301}
302
303/// Extract items from structured data or text.
304///
305/// kaish does not split implicitly — this function requires structured data
306/// (JSON array from split/seq/glob/find) for multi-item input. Single-line
307/// text is treated as one item. Multi-line text without structured data is
308/// an error.
309pub fn extract_items(data: Option<&Value>, text: &str) -> Result<Vec<String>, String> {
310    // 1. Structured data (JSON array from split/seq/glob/find) — use it
311    if let Some(Value::Json(serde_json::Value::Array(arr))) = data {
312        return Ok(arr.iter().map(|v| match v {
313            serde_json::Value::String(s) => s.clone(),
314            other => other.to_string(),
315        }).collect());
316    }
317    if let Some(Value::String(s)) = data {
318        return Ok(vec![s.clone()]);
319    }
320
321    // 2. Empty — return empty
322    let trimmed = text.trim();
323    if trimmed.is_empty() {
324        return Ok(vec![]);
325    }
326
327    // 3. Raw text without structured data — one item (no implicit splitting)
328    Ok(vec![trimmed.to_string()])
329}
330
/// Everything `gather_results` produces: the rendered text, plus the items
/// of failed tasks that the line format had no way to show.
struct GatherOutput {
    /// Rendered output — a pretty JSON array, or newline-joined stdout rows.
    text: String,
    /// Items whose worker failed and were therefore left out of `text`.
    /// Only the line format fills this in; the JSON format keeps every task
    /// as a row with an explicit `"ok"` field, so it leaves this empty.
    dropped_failures: Vec<String>,
}
340
341/// Gather results into output string.
342///
343/// The JSON format emits every task as a row (`"ok"` discriminates success
344/// from failure). The line format can only carry stdout, so it returns the
345/// successful rows in `text` and reports the failed items in
346/// `dropped_failures` — the caller (`run`) turns that into a loud non-zero
347/// exit rather than letting the failures vanish (see `docs/issues.md`).
348fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> GatherOutput {
349    let results_to_use = if opts.first > 0 && opts.first < results.len() {
350        &results[..opts.first]
351    } else {
352        results
353    };
354
355    if opts.format == "json" {
356        // Output as JSON array of objects
357        let json_results: Vec<serde_json::Value> = results_to_use
358            .iter()
359            .map(|r| {
360                serde_json::json!({
361                    "item": r.item,
362                    "ok": r.result.ok(),
363                    "code": r.result.code,
364                    "out": r.result.text_out().trim(),
365                    "err": r.result.err.trim(),
366                    "timed_out": r.timed_out,
367                })
368            })
369            .collect();
370
371        GatherOutput {
372            text: serde_json::to_string_pretty(&json_results).unwrap_or_default(),
373            dropped_failures: Vec::new(),
374        }
375    } else {
376        // Output as lines (stdout from each successful worker, separated by
377        // newlines). Failed workers can't be represented as a stdout row, so
378        // we collect their items and let `run` fail loud instead of dropping
379        // them silently.
380        let text = results_to_use
381            .iter()
382            .filter(|r| r.result.ok())
383            .map(|r| r.result.text_out())
384            .map(|t| t.trim().to_string())
385            .collect::<Vec<_>>()
386            .join("\n");
387        let dropped_failures = results_to_use
388            .iter()
389            .filter(|r| !r.result.ok())
390            .map(|r| r.item.clone())
391            .collect();
392        GatherOutput {
393            text,
394            dropped_failures,
395        }
396    }
397}
398
399/// Parse scatter options from tool args.
400pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
401    let mut opts = ScatterOptions::default();
402
403    if let Some(Value::String(name)) = args.named.get("as") {
404        opts.var_name = name.clone();
405    }
406
407    if let Some(Value::Int(n)) = args.named.get("limit") {
408        let requested = *n;
409        let clamped = requested.clamp(1, SCATTER_LIMIT_MAX as i64);
410        if requested > SCATTER_LIMIT_MAX as i64 {
411            tracing::warn!(
412                target: "kaish::scatter",
413                requested = requested,
414                ceiling = SCATTER_LIMIT_MAX,
415                "scatter limit clamped to ceiling"
416            );
417        }
418        opts.limit = clamped as usize;
419    }
420
421    // --timeout DURATION: per-worker timeout. Accepts the same forms as the
422    // `timeout` builtin (30, 5s, 500ms, 2m, 1h). Invalid input is ignored
423    // with a warn so a typo doesn't silently disable cancellation.
424    if let Some(Value::String(s)) = args.named.get("timeout") {
425        match parse_duration(s) {
426            Some(d) => opts.timeout = Some(d),
427            None => tracing::warn!(
428                target: "kaish::scatter",
429                value = %s,
430                "scatter --timeout: invalid duration (try: 30, 5s, 500ms, 2m, 1h)"
431            ),
432        }
433    } else if let Some(Value::Int(n)) = args.named.get("timeout") {
434        if *n >= 0 {
435            opts.timeout = Some(Duration::from_secs(*n as u64));
436        }
437    }
438
439    opts
440}
441
/// Ceiling on the concurrency accepted by `scatter --limit N`.
///
/// Larger requests are clamped down to this value, and a `tracing::warn` is
/// emitted, because clamping without telling the user would break the
/// "no silent fallbacks" rule.
pub const SCATTER_LIMIT_MAX: usize = 10000;
446
447/// Parse gather options from tool args.
448pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
449    let mut opts = GatherOptions::default();
450
451    if args.has_flag("progress") {
452        opts.progress = true;
453    }
454
455    if let Some(Value::Int(n)) = args.named.get("first") {
456        opts.first = (*n).max(0) as usize;
457    }
458
459    if let Some(Value::String(fmt)) = args.named.get("format") {
460        opts.format = fmt.clone();
461    }
462
463    opts
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Build a `ScatterResult` for `item` that did not time out.
    fn row(item: &str, result: ExecResult) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result,
            timed_out: false,
        }
    }

    /// Build `ToolArgs` carrying the given named arguments.
    fn named_args(pairs: Vec<(&str, Value)>) -> ToolArgs {
        let mut args = ToolArgs::new();
        for (name, value) in pairs {
            args.named.insert(name.to_string(), value);
        }
        args
    }

    /// Gather options selecting the JSON format.
    fn json_opts() -> GatherOptions {
        GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        }
    }

    // --- extract_items -------------------------------------------------

    #[test]
    fn test_extract_items_structured_json_array() {
        let data = Value::Json(serde_json::json!(["a", "b", "c"]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_extract_items_structured_mixed_types() {
        let data = Value::Json(serde_json::json!([1, "two", true]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["1", "two", "true"]);
    }

    #[test]
    fn test_extract_items_structured_string() {
        let data = Value::String("single".into());
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["single"]);
    }

    #[test]
    fn test_extract_items_single_line_text() {
        let items = extract_items(None, "hello").unwrap();
        assert_eq!(items, vec!["hello"]);
    }

    #[test]
    fn test_extract_items_empty() {
        let items = extract_items(None, "").unwrap();
        assert!(items.is_empty());
    }

    #[test]
    fn test_extract_items_whitespace_only_is_empty() {
        let items = extract_items(None, "  \n\t ").unwrap();
        assert!(items.is_empty());
    }

    #[test]
    fn test_extract_items_text_is_trimmed() {
        let items = extract_items(None, "  hello \n").unwrap();
        assert_eq!(items, vec!["hello"]);
    }

    #[test]
    fn test_extract_items_multiline_is_one_item() {
        // No implicit splitting — multi-line text is one item
        let items = extract_items(None, "one\ntwo\nthree").unwrap();
        assert_eq!(items, vec!["one\ntwo\nthree"]);
    }

    #[test]
    fn test_extract_items_structured_overrides_text() {
        // Structured data takes priority over text
        let data = Value::Json(serde_json::json!(["x", "y"]));
        let items = extract_items(Some(&data), "ignored\ntext").unwrap();
        assert_eq!(items, vec!["x", "y"]);
    }

    // --- gather_results ------------------------------------------------

    #[test]
    fn test_gather_results_lines() {
        let results = vec![
            row("a", ExecResult::success("result_a")),
            row("b", ExecResult::success("result_b")),
        ];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output.text, "result_a\nresult_b");
        assert!(output.dropped_failures.is_empty());
    }

    #[test]
    fn test_gather_results_lines_reports_dropped_failures() {
        // A failed worker must not vanish from line output: it is reported in
        // `dropped_failures` so the caller can fail loud (docs/issues.md).
        let results = vec![
            row("a", ExecResult::success("result_a")),
            row("b", ExecResult::failure(1, "boom")),
        ];

        let output = gather_results(&results, &GatherOptions::default());
        // Successful rows still render; the failure is reported, not dropped.
        assert_eq!(output.text, "result_a");
        assert_eq!(output.dropped_failures, vec!["b".to_string()]);
    }

    #[test]
    fn test_gather_results_json_keeps_failures_as_rows() {
        // JSON carries failures as rows (ok: false), so it drops nothing.
        let results = vec![row("b", ExecResult::failure(2, "boom"))];

        let output = gather_results(&results, &json_opts());
        assert!(output.dropped_failures.is_empty());
        assert!(output.text.contains("\"ok\": false"));
        assert!(output.text.contains("\"code\": 2"));
    }

    #[test]
    fn test_gather_results_json() {
        let results = vec![row("test", ExecResult::success("output"))];

        let output = gather_results(&results, &json_opts());
        assert!(output.text.contains("\"item\": \"test\""));
        assert!(output.text.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_json_empty_input() {
        let output = gather_results(&[], &json_opts());
        assert_eq!(output.text, "[]");
        assert!(output.dropped_failures.is_empty());
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = vec![
            row("a", ExecResult::success("1")),
            row("b", ExecResult::success("2")),
            row("c", ExecResult::success("3")),
        ];

        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert_eq!(output.text, "1\n2");
    }

    #[test]
    fn test_gather_results_first_zero_keeps_all() {
        let results = vec![
            row("a", ExecResult::success("1")),
            row("b", ExecResult::success("2")),
        ];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output.text, "1\n2");
    }

    #[test]
    fn test_gather_results_first_beyond_len_keeps_all() {
        let results = vec![
            row("a", ExecResult::success("1")),
            row("b", ExecResult::success("2")),
        ];

        let opts = GatherOptions {
            first: 5,
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert_eq!(output.text, "1\n2");
    }

    // --- parse_scatter_options -----------------------------------------

    #[test]
    fn test_parse_scatter_options() {
        let args = named_args(vec![
            ("as", Value::String("URL".to_string())),
            ("limit", Value::Int(4)),
        ]);

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_scatter_options_defaults() {
        let opts = parse_scatter_options(&ToolArgs::new());
        assert_eq!(opts.var_name, "ITEM");
        assert_eq!(opts.limit, 8);
        assert!(opts.timeout.is_none());
    }

    #[test]
    fn scatter_limit_clamps_to_ceiling() {
        let args = named_args(vec![("limit", Value::Int(999_999))]);
        let opts = parse_scatter_options(&args);
        assert_eq!(opts.limit, SCATTER_LIMIT_MAX);
    }

    #[test]
    fn scatter_limit_raises_zero_to_one() {
        let args = named_args(vec![("limit", Value::Int(0))]);
        let opts = parse_scatter_options(&args);
        assert_eq!(opts.limit, 1);
    }

    #[test]
    fn scatter_limit_raises_negative_to_one() {
        let args = named_args(vec![("limit", Value::Int(-42))]);
        let opts = parse_scatter_options(&args);
        assert_eq!(opts.limit, 1);
    }

    #[test]
    fn scatter_limit_preserves_valid_values() {
        let args = named_args(vec![("limit", Value::Int(500))]);
        let opts = parse_scatter_options(&args);
        assert_eq!(opts.limit, 500);
    }

    #[test]
    fn scatter_timeout_int_is_seconds() {
        let args = named_args(vec![("timeout", Value::Int(3))]);
        let opts = parse_scatter_options(&args);
        assert_eq!(opts.timeout, Some(Duration::from_secs(3)));
    }

    #[test]
    fn scatter_timeout_negative_int_is_ignored() {
        let args = named_args(vec![("timeout", Value::Int(-1))]);
        let opts = parse_scatter_options(&args);
        assert!(opts.timeout.is_none());
    }

    #[test]
    fn scatter_timeout_duration_string() {
        let args = named_args(vec![("timeout", Value::String("5s".to_string()))]);
        let opts = parse_scatter_options(&args);
        assert_eq!(opts.timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn scatter_timeout_invalid_string_is_ignored() {
        let args = named_args(vec![("timeout", Value::String("not-a-duration".to_string()))]);
        let opts = parse_scatter_options(&args);
        assert!(opts.timeout.is_none());
    }

    // --- parse_gather_options ------------------------------------------

    #[test]
    fn test_parse_gather_options() {
        let args = named_args(vec![
            ("first", Value::Int(5)),
            ("format", Value::String("json".to_string())),
        ]);

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }

    #[test]
    fn test_parse_gather_options_defaults() {
        let opts = parse_gather_options(&ToolArgs::new());
        assert!(!opts.progress);
        assert_eq!(opts.first, 0);
        assert_eq!(opts.format, "lines");
    }

    #[test]
    fn test_parse_gather_options_negative_first_is_zero() {
        let args = named_args(vec![("first", Value::Int(-3))]);
        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 0);
    }
}