Skip to main content

kaish_kernel/scheduler/
scatter.rs

1//! Scatter/Gather — Parallel pipeline execution.
2//!
3//! Scatter splits input into items and runs the pipeline in parallel.
4//! Gather collects the parallel results.
5//!
6//! # Example
7//!
8//! ```text
9//! cat urls.txt | scatter | fetch url=${ITEM} | gather
10//! ```
11//!
12//! This reads URLs, then for each URL runs `fetch` in parallel,
13//! then collects all results.
14
15use std::sync::Arc;
16
17use tokio::sync::Semaphore;
18
19use crate::ast::{Command, Value};
20use crate::interpreter::ExecResult;
21use crate::tools::{ExecContext, ToolRegistry};
22
23use super::pipeline::{run_sequential_pipeline, run_sequential_pipeline_owned};
24
/// Configuration for a `scatter` stage.
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Name of the variable each worker sees its item bound to (default: `"ITEM"`).
    pub var_name: String,
    /// Upper bound on how many workers run at the same time (default: 8).
    pub limit: usize,
}

impl Default for ScatterOptions {
    fn default() -> Self {
        let var_name = String::from("ITEM");
        let limit = 8;
        ScatterOptions { var_name, limit }
    }
}
42
/// Configuration for a `gather` stage.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Request a progress indicator.
    /// NOTE(review): not read by the gather code in this file.
    pub progress: bool,
    /// Keep only the first N results, in input order (0 = keep all).
    /// As implemented here this truncates the output; the remaining workers
    /// still run to completion.
    pub first: usize,
    /// Output format: `"json"` or `"lines"` (anything else is treated as lines).
    pub format: String,
}

impl Default for GatherOptions {
    fn default() -> Self {
        GatherOptions {
            format: String::from("lines"),
            first: 0,
            progress: false,
        }
    }
}
63
64/// Result from a single scatter worker.
65#[derive(Debug, Clone)]
66pub struct ScatterResult {
67    /// The input item that was processed.
68    pub item: String,
69    /// The execution result.
70    pub result: ExecResult,
71}
72
73/// Runs scatter/gather pipelines.
74pub struct ScatterGatherRunner {
75    tools: Arc<ToolRegistry>,
76}
77
78impl ScatterGatherRunner {
79    /// Create a new scatter/gather runner.
80    pub fn new(tools: Arc<ToolRegistry>) -> Self {
81        Self { tools }
82    }
83
84    /// Execute a scatter/gather pipeline.
85    ///
86    /// The pipeline is split into three parts:
87    /// - pre_scatter: commands before scatter
88    /// - parallel: commands between scatter and gather
89    /// - post_gather: commands after gather
90    ///
91    /// Returns the final result after all stages complete.
92    pub async fn run(
93        &self,
94        pre_scatter: &[Command],
95        scatter_opts: ScatterOptions,
96        parallel: &[Command],
97        gather_opts: GatherOptions,
98        post_gather: &[Command],
99        ctx: &mut ExecContext,
100    ) -> ExecResult {
101        // Run pre-scatter commands to get input
102        let input = if pre_scatter.is_empty() {
103            // Use existing stdin
104            ctx.take_stdin().unwrap_or_default()
105        } else {
106            let result = run_sequential_pipeline(&self.tools, pre_scatter, ctx).await;
107            if !result.ok() {
108                return result;
109            }
110            result.out
111        };
112
113        // Split input into items
114        let items = split_input(&input);
115        if items.is_empty() {
116            return ExecResult::success("");
117        }
118
119        // Run parallel stage
120        let results = self
121            .run_parallel(&items, &scatter_opts, parallel, ctx)
122            .await;
123
124        // Gather results
125        let gathered = gather_results(&results, &gather_opts);
126
127        // Run post-gather commands if any
128        if post_gather.is_empty() {
129            ExecResult::success(gathered)
130        } else {
131            ctx.set_stdin(gathered);
132            run_sequential_pipeline(&self.tools, post_gather, ctx).await
133        }
134    }
135
136    /// Run the parallel stage for all items.
137    async fn run_parallel(
138        &self,
139        items: &[String],
140        opts: &ScatterOptions,
141        commands: &[Command],
142        base_ctx: &ExecContext,
143    ) -> Vec<ScatterResult> {
144        let semaphore = Arc::new(Semaphore::new(opts.limit));
145        let tools = self.tools.clone();
146        let var_name = opts.var_name.clone();
147
148        // Spawn parallel tasks
149        let mut handles = Vec::with_capacity(items.len());
150
151        for item in items.iter().cloned() {
152            let permit = semaphore.clone().acquire_owned().await;
153            let tools = tools.clone();
154            let commands = commands.to_vec();
155            let var_name = var_name.clone();
156            let base_scope = base_ctx.scope.clone();
157            let backend = base_ctx.backend.clone();
158            let cwd = base_ctx.cwd.clone();
159
160            let handle = tokio::spawn(async move {
161                let _permit = permit; // Hold permit until done
162
163                // Create context for this worker
164                let mut scope = base_scope;
165                scope.set(&var_name, Value::String(item.clone()));
166
167                let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
168                ctx.set_cwd(cwd);
169
170                // Run the commands using sequential runner (no nested scatter/gather)
171                let result = run_sequential_pipeline_owned(tools, commands, &mut ctx).await;
172
173                ScatterResult { item, result }
174            });
175
176            handles.push(handle);
177        }
178
179        // Collect results
180        let mut results = Vec::with_capacity(handles.len());
181        for handle in handles {
182            match handle.await {
183                Ok(result) => results.push(result),
184                Err(e) => {
185                    results.push(ScatterResult {
186                        item: String::new(),
187                        result: ExecResult::failure(1, format!("Task panicked: {}", e)),
188                    });
189                }
190            }
191        }
192
193        results
194    }
195}
196
197/// Split input into items (by newlines or JSON array).
198fn split_input(input: &str) -> Vec<String> {
199    let trimmed = input.trim();
200
201    // Try to parse as JSON array first
202    if trimmed.starts_with('[')
203        && let Ok(arr) = serde_json::from_str::<Vec<serde_json::Value>>(trimmed) {
204            return arr
205                .into_iter()
206                .map(|v| match v {
207                    serde_json::Value::String(s) => s,
208                    other => other.to_string(),
209                })
210                .collect();
211        }
212
213    // Fall back to line splitting
214    trimmed
215        .lines()
216        .map(|s| s.to_string())
217        .filter(|s| !s.is_empty())
218        .collect()
219}
220
221/// Gather results into output string.
222fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
223    let results_to_use = if opts.first > 0 && opts.first < results.len() {
224        &results[..opts.first]
225    } else {
226        results
227    };
228
229    if opts.format == "json" {
230        // Output as JSON array of objects
231        let json_results: Vec<serde_json::Value> = results_to_use
232            .iter()
233            .map(|r| {
234                serde_json::json!({
235                    "item": r.item,
236                    "ok": r.result.ok(),
237                    "code": r.result.code,
238                    "out": r.result.out.trim(),
239                    "err": r.result.err.trim(),
240                })
241            })
242            .collect();
243
244        serde_json::to_string_pretty(&json_results).unwrap_or_default()
245    } else {
246        // Output as lines (stdout from each, separated by newlines)
247        results_to_use
248            .iter()
249            .filter(|r| r.result.ok())
250            .map(|r| r.result.out.trim())
251            .collect::<Vec<_>>()
252            .join("\n")
253    }
254}
255
256/// Parse scatter options from tool args.
257pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
258    let mut opts = ScatterOptions::default();
259
260    if let Some(Value::String(name)) = args.named.get("as") {
261        opts.var_name = name.clone();
262    }
263
264    if let Some(Value::Int(n)) = args.named.get("limit") {
265        opts.limit = (*n).max(1) as usize;
266    }
267
268    opts
269}
270
271/// Parse gather options from tool args.
272pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
273    let mut opts = GatherOptions::default();
274
275    if args.has_flag("progress") {
276        opts.progress = true;
277    }
278
279    if let Some(Value::Int(n)) = args.named.get("first") {
280        opts.first = (*n).max(0) as usize;
281    }
282
283    if let Some(Value::String(fmt)) = args.named.get("format") {
284        opts.format = fmt.clone();
285    }
286
287    opts
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    // --- split_input ---

    #[test]
    fn test_split_input_lines() {
        let input = "one\ntwo\nthree\n";
        let items = split_input(input);
        assert_eq!(items, vec!["one", "two", "three"]);
    }

    #[test]
    fn test_split_input_json_array() {
        let input = r#"["a", "b", "c"]"#;
        let items = split_input(input);
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_split_input_json_mixed() {
        let input = r#"[1, "two", true]"#;
        let items = split_input(input);
        assert_eq!(items, vec!["1", "two", "true"]);
    }

    #[test]
    fn test_split_input_json_nested_values() {
        // Non-string elements are rendered as compact JSON text.
        let items = split_input(r#"[{"a": 1}, [2, 3], null]"#);
        assert_eq!(items, vec![r#"{"a":1}"#, "[2,3]", "null"]);
    }

    #[test]
    fn test_split_input_invalid_json_falls_back_to_lines() {
        let items = split_input("[not json\nsecond");
        assert_eq!(items, vec!["[not json", "second"]);
    }

    #[test]
    fn test_split_input_skips_empty_lines_and_outer_whitespace() {
        let items = split_input("\n\n  one\n\ntwo  \n\n");
        assert_eq!(items, vec!["one", "two"]);
    }

    #[test]
    fn test_split_input_empty() {
        let input = "";
        let items = split_input(input);
        assert!(items.is_empty());
    }

    // --- gather_results ---

    #[test]
    fn test_gather_results_lines() {
        let results = vec![
            ScatterResult {
                item: "a".to_string(),
                result: ExecResult::success("result_a"),
            },
            ScatterResult {
                item: "b".to_string(),
                result: ExecResult::success("result_b"),
            },
        ];

        let opts = GatherOptions::default();
        let output = gather_results(&results, &opts);
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_lines_skips_failures() {
        let results = vec![
            ScatterResult {
                item: "a".to_string(),
                result: ExecResult::success("ok_a"),
            },
            ScatterResult {
                item: "b".to_string(),
                result: ExecResult::failure(1, "boom".to_string()),
            },
            ScatterResult {
                item: "c".to_string(),
                result: ExecResult::success("ok_c"),
            },
        ];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "ok_a\nok_c");
    }

    #[test]
    fn test_gather_results_json() {
        let results = vec![ScatterResult {
            item: "test".to_string(),
            result: ExecResult::success("output"),
        }];

        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_empty() {
        let json = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };
        assert_eq!(gather_results(&[], &json), "[]");
        assert_eq!(gather_results(&[], &GatherOptions::default()), "");
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = vec![
            ScatterResult {
                item: "a".to_string(),
                result: ExecResult::success("1"),
            },
            ScatterResult {
                item: "b".to_string(),
                result: ExecResult::success("2"),
            },
            ScatterResult {
                item: "c".to_string(),
                result: ExecResult::success("3"),
            },
        ];

        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert_eq!(output, "1\n2");

        // `first` larger than the number of results keeps everything.
        let opts = GatherOptions {
            first: 10,
            ..Default::default()
        };
        assert_eq!(gather_results(&results, &opts), "1\n2\n3");
    }

    // --- option parsing ---

    #[test]
    fn test_default_options() {
        let scatter = ScatterOptions::default();
        assert_eq!(scatter.var_name, "ITEM");
        assert_eq!(scatter.limit, 8);

        let gather = GatherOptions::default();
        assert!(!gather.progress);
        assert_eq!(gather.first, 0);
        assert_eq!(gather.format, "lines");
    }

    #[test]
    fn test_parse_scatter_options() {
        let mut args = ToolArgs::new();
        args.named.insert("as".to_string(), Value::String("URL".to_string()));
        args.named.insert("limit".to_string(), Value::Int(4));

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_options_clamp_out_of_range_ints() {
        let mut args = ToolArgs::new();
        args.named.insert("limit".to_string(), Value::Int(0));
        args.named.insert("first".to_string(), Value::Int(-3));

        assert_eq!(parse_scatter_options(&args).limit, 1);
        assert_eq!(parse_gather_options(&args).first, 0);
    }

    #[test]
    fn test_parse_scatter_options_ignores_wrong_types() {
        let mut args = ToolArgs::new();
        args.named.insert("as".to_string(), Value::Int(1));
        args.named.insert("limit".to_string(), Value::String("4".to_string()));

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "ITEM");
        assert_eq!(opts.limit, 8);
    }

    #[test]
    fn test_parse_gather_options() {
        let mut args = ToolArgs::new();
        args.named.insert("first".to_string(), Value::Int(5));
        args.named.insert("format".to_string(), Value::String("json".to_string()));

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }
}