Skip to main content

kaish_kernel/scheduler/
scatter.rs

//! Scatter/Gather — parallel pipeline execution.
//!
//! `scatter` splits its input into items (one per line, or one per element of
//! a JSON array) and runs the commands between `scatter` and `gather` once per
//! item, in parallel, with the item bound to a variable (`ITEM` by default).
//! `gather` collects the per-item results into a single output.
//!
//! # Example
//!
//! ```text
//! cat urls.txt | scatter | fetch url=${ITEM} | gather
//! ```
//!
//! This reads URLs, runs `fetch` once per URL in parallel, and then collects
//! all results.
14
15use std::sync::Arc;
16
17use tokio::sync::Semaphore;
18
19use crate::ast::{Command, Value};
20use crate::dispatch::CommandDispatcher;
21use crate::interpreter::ExecResult;
22use crate::tools::{ExecContext, ToolRegistry};
23
24use super::pipeline::PipelineRunner;
25
/// Tuning knobs for the scatter stage.
///
/// Normally built by [`parse_scatter_options`] from the `as=` and `limit=`
/// arguments; otherwise start from [`Default`].
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Name of the variable each worker sees its item under (default: `"ITEM"`).
    pub var_name: String,
    /// Upper bound on concurrently running workers (default: 8).
    pub limit: usize,
}

impl Default for ScatterOptions {
    fn default() -> Self {
        let var_name = String::from("ITEM");
        let limit = 8;
        ScatterOptions { var_name, limit }
    }
}
43
/// Tuning knobs for the gather stage.
///
/// Normally built by [`parse_gather_options`] from the `progress` flag and the
/// `first=` / `format=` arguments; otherwise start from [`Default`].
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Request a progress indicator (default: off).
    /// NOTE(review): nothing in this module reads this field yet.
    pub progress: bool,
    /// Keep only the first N results, in input order (0 = keep all).
    /// NOTE(review): the workers for the remaining items are not cancelled;
    /// their results are simply discarded.
    pub first: usize,
    /// Output format: `"json"`, or anything else for plain lines
    /// (default: `"lines"`).
    pub format: String,
}

impl Default for GatherOptions {
    fn default() -> Self {
        GatherOptions {
            format: String::from("lines"),
            first: 0,
            progress: false,
        }
    }
}
64
65/// Result from a single scatter worker.
66#[derive(Debug, Clone)]
67pub struct ScatterResult {
68    /// The input item that was processed.
69    pub item: String,
70    /// The execution result.
71    pub result: ExecResult,
72}
73
74/// Runs scatter/gather pipelines.
75///
76/// The dispatcher is used for pre_scatter, post_gather, and parallel worker
77/// command execution. This enables scatter blocks to run user tools, scripts,
78/// and external commands — not just builtins.
79pub struct ScatterGatherRunner {
80    tools: Arc<ToolRegistry>,
81    dispatcher: Arc<dyn CommandDispatcher>,
82}
83
84impl ScatterGatherRunner {
85    /// Create a new scatter/gather runner with the given dispatcher.
86    ///
87    /// The dispatcher handles the full command resolution chain for all
88    /// pipeline stages (pre_scatter, parallel workers, post_gather).
89    pub fn new(tools: Arc<ToolRegistry>, dispatcher: Arc<dyn CommandDispatcher>) -> Self {
90        Self { tools, dispatcher }
91    }
92
93    /// Execute a scatter/gather pipeline.
94    ///
95    /// The pipeline is split into three parts:
96    /// - pre_scatter: commands before scatter
97    /// - parallel: commands between scatter and gather
98    /// - post_gather: commands after gather
99    ///
100    /// Returns the final result after all stages complete.
101    pub async fn run(
102        &self,
103        pre_scatter: &[Command],
104        scatter_opts: ScatterOptions,
105        parallel: &[Command],
106        gather_opts: GatherOptions,
107        post_gather: &[Command],
108        ctx: &mut ExecContext,
109    ) -> ExecResult {
110        let runner = PipelineRunner::new(self.tools.clone());
111
112        // Run pre-scatter commands to get input.
113        // Uses run_sequential to avoid async recursion (scatter → run → scatter).
114        let input = if pre_scatter.is_empty() {
115            // Use existing stdin
116            ctx.take_stdin().unwrap_or_default()
117        } else {
118            let result = runner.run_sequential(pre_scatter, ctx, &*self.dispatcher).await;
119            if !result.ok() {
120                return result;
121            }
122            result.out
123        };
124
125        // Split input into items
126        let items = split_input(&input);
127        if items.is_empty() {
128            return ExecResult::success("");
129        }
130
131        // Run parallel stage
132        let results = self
133            .run_parallel(&items, &scatter_opts, parallel, ctx)
134            .await;
135
136        // Gather results
137        let gathered = gather_results(&results, &gather_opts);
138
139        // Run post-gather commands if any
140        if post_gather.is_empty() {
141            ExecResult::success(gathered)
142        } else {
143            ctx.set_stdin(gathered);
144            runner.run_sequential(post_gather, ctx, &*self.dispatcher).await
145        }
146    }
147
148    /// Run the parallel stage for all items.
149    ///
150    /// # Safety constraint
151    ///
152    /// Parallel workers share the dispatcher via `Arc`. The dispatcher MUST be
153    /// stateless (like `BackendDispatcher`) when used from parallel workers.
154    /// Using a stateful dispatcher (like `Kernel`, which writes to `self.scope`)
155    /// would cause data races — parallel workers would stomp each other's scope.
156    ///
157    /// The `Kernel` dispatcher is safe for pre_scatter and post_gather (sequential),
158    /// but scatter workers must use `BackendDispatcher` until per-worker kernel
159    /// instances are implemented.
160    async fn run_parallel(
161        &self,
162        items: &[String],
163        opts: &ScatterOptions,
164        commands: &[Command],
165        base_ctx: &ExecContext,
166    ) -> Vec<ScatterResult> {
167        let semaphore = Arc::new(Semaphore::new(opts.limit));
168        let tools = self.tools.clone();
169        let dispatcher = self.dispatcher.clone();
170        let var_name = opts.var_name.clone();
171
172        // Spawn parallel tasks
173        let mut handles = Vec::with_capacity(items.len());
174
175        for item in items.iter().cloned() {
176            let permit = semaphore.clone().acquire_owned().await;
177            let tools = tools.clone();
178            let dispatcher = dispatcher.clone();
179            let commands = commands.to_vec();
180            let var_name = var_name.clone();
181            let base_scope = base_ctx.scope.clone();
182            let backend = base_ctx.backend.clone();
183            let cwd = base_ctx.cwd.clone();
184
185            let handle = tokio::spawn(async move {
186                let _permit = permit; // Hold permit until done
187
188                // Create context for this worker
189                let mut scope = base_scope;
190                scope.set(&var_name, Value::String(item.clone()));
191
192                let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
193                ctx.set_cwd(cwd);
194
195                // Run through PipelineRunner + dispatcher (full resolution chain).
196                // Uses run_sequential to avoid async recursion and infinite future size.
197                let runner = PipelineRunner::new(tools);
198                let result = runner.run_sequential(&commands, &mut ctx, &*dispatcher).await;
199
200                ScatterResult { item, result }
201            });
202
203            handles.push(handle);
204        }
205
206        // Collect results
207        let mut results = Vec::with_capacity(handles.len());
208        for handle in handles {
209            match handle.await {
210                Ok(result) => results.push(result),
211                Err(e) => {
212                    results.push(ScatterResult {
213                        item: String::new(),
214                        result: ExecResult::failure(1, format!("Task panicked: {}", e)),
215                    });
216                }
217            }
218        }
219
220        results
221    }
222}
223
224/// Split input into items (by newlines or JSON array).
225fn split_input(input: &str) -> Vec<String> {
226    let trimmed = input.trim();
227
228    // Try to parse as JSON array first
229    if trimmed.starts_with('[')
230        && let Ok(arr) = serde_json::from_str::<Vec<serde_json::Value>>(trimmed) {
231            return arr
232                .into_iter()
233                .map(|v| match v {
234                    serde_json::Value::String(s) => s,
235                    other => other.to_string(),
236                })
237                .collect();
238        }
239
240    // Fall back to line splitting
241    trimmed
242        .lines()
243        .map(|s| s.to_string())
244        .filter(|s| !s.is_empty())
245        .collect()
246}
247
248/// Gather results into output string.
249fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
250    let results_to_use = if opts.first > 0 && opts.first < results.len() {
251        &results[..opts.first]
252    } else {
253        results
254    };
255
256    if opts.format == "json" {
257        // Output as JSON array of objects
258        let json_results: Vec<serde_json::Value> = results_to_use
259            .iter()
260            .map(|r| {
261                serde_json::json!({
262                    "item": r.item,
263                    "ok": r.result.ok(),
264                    "code": r.result.code,
265                    "out": r.result.out.trim(),
266                    "err": r.result.err.trim(),
267                })
268            })
269            .collect();
270
271        serde_json::to_string_pretty(&json_results).unwrap_or_default()
272    } else {
273        // Output as lines (stdout from each, separated by newlines)
274        results_to_use
275            .iter()
276            .filter(|r| r.result.ok())
277            .map(|r| r.result.out.trim())
278            .collect::<Vec<_>>()
279            .join("\n")
280    }
281}
282
283/// Parse scatter options from tool args.
284pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
285    let mut opts = ScatterOptions::default();
286
287    if let Some(Value::String(name)) = args.named.get("as") {
288        opts.var_name = name.clone();
289    }
290
291    if let Some(Value::Int(n)) = args.named.get("limit") {
292        opts.limit = (*n).max(1) as usize;
293    }
294
295    opts
296}
297
298/// Parse gather options from tool args.
299pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
300    let mut opts = GatherOptions::default();
301
302    if args.has_flag("progress") {
303        opts.progress = true;
304    }
305
306    if let Some(Value::Int(n)) = args.named.get("first") {
307        opts.first = (*n).max(0) as usize;
308    }
309
310    if let Some(Value::String(fmt)) = args.named.get("format") {
311        opts.format = fmt.clone();
312    }
313
314    opts
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Build a successful worker result for `item` whose stdout is `out`.
    fn ok_result(item: &str, out: &'static str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::success(out),
        }
    }

    /// Build a failed worker result (exit code 1) for `item`.
    fn failed_result(item: &str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::failure(1, "boom".to_string()),
        }
    }

    #[test]
    fn test_split_input_lines() {
        let items = split_input("one\ntwo\nthree\n");
        assert_eq!(items, vec!["one", "two", "three"]);
    }

    #[test]
    fn test_split_input_crlf_lines() {
        assert_eq!(split_input("a\r\nb\r\n"), vec!["a", "b"]);
    }

    #[test]
    fn test_split_input_skips_empty_lines() {
        assert_eq!(split_input("a\n\n\nb"), vec!["a", "b"]);
    }

    #[test]
    fn test_split_input_json_array() {
        let items = split_input(r#"["a", "b", "c"]"#);
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_split_input_json_mixed() {
        let items = split_input(r#"[1, "two", true]"#);
        assert_eq!(items, vec!["1", "two", "true"]);
    }

    #[test]
    fn test_split_input_invalid_json_falls_back_to_lines() {
        assert_eq!(split_input("[not json\nsecond"), vec!["[not json", "second"]);
    }

    #[test]
    fn test_split_input_empty() {
        assert!(split_input("").is_empty());
        assert!(split_input("  \n\n ").is_empty());
    }

    #[test]
    fn test_gather_results_lines() {
        let results = vec![ok_result("a", "result_a"), ok_result("b", "result_b")];
        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_lines_skips_failures() {
        let results = vec![ok_result("a", "kept"), failed_result("b")];
        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "kept");
    }

    #[test]
    fn test_gather_results_json() {
        let results = vec![ok_result("test", "output")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_json_includes_failures() {
        let results = vec![failed_result("bad")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"bad\""));
        assert!(output.contains("\"ok\": false"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = vec![
            ok_result("a", "1"),
            ok_result("b", "2"),
            ok_result("c", "3"),
        ];
        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };
        assert_eq!(gather_results(&results, &opts), "1\n2");
    }

    #[test]
    fn test_gather_results_first_larger_than_len_keeps_all() {
        let results = vec![ok_result("a", "1"), ok_result("b", "2")];
        let opts = GatherOptions {
            first: 10,
            ..Default::default()
        };
        assert_eq!(gather_results(&results, &opts), "1\n2");
    }

    #[test]
    fn test_parse_scatter_options() {
        let mut args = ToolArgs::new();
        args.named.insert("as".to_string(), Value::String("URL".to_string()));
        args.named.insert("limit".to_string(), Value::Int(4));

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_scatter_options_defaults() {
        let opts = parse_scatter_options(&ToolArgs::new());
        assert_eq!(opts.var_name, "ITEM");
        assert_eq!(opts.limit, 8);
    }

    #[test]
    fn test_parse_scatter_options_limit_at_least_one() {
        for n in [0, -7] {
            let mut args = ToolArgs::new();
            args.named.insert("limit".to_string(), Value::Int(n));
            assert_eq!(parse_scatter_options(&args).limit, 1);
        }
    }

    #[test]
    fn test_parse_gather_options() {
        let mut args = ToolArgs::new();
        args.named.insert("first".to_string(), Value::Int(5));
        args.named.insert("format".to_string(), Value::String("json".to_string()));

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }

    #[test]
    fn test_parse_gather_options_negative_first_means_all() {
        let mut args = ToolArgs::new();
        args.named.insert("first".to_string(), Value::Int(-3));
        assert_eq!(parse_gather_options(&args).first, 0);
    }
}