Skip to main content

kaish_kernel/scheduler/
scatter.rs

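//! Scatter/Gather — Parallel pipeline execution.
//!
//! Scatter takes a list of items and runs the pipeline stage between
//! `scatter` and `gather` once per item, in parallel.
//! Gather collects the parallel results.
//!
//! # Example
//!
//! ```text
//! cat urls.txt | scatter | fetch url=${ITEM} | gather
//! ```
//!
//! Each item is bound to `${ITEM}` (rename it with `as=`), at most `limit=`
//! workers (default 8) run at once, and `gather` then collects all results.
//! Scatter never splits plain text implicitly: to fan out per URL, the input
//! must arrive as structured data (a JSON array, as produced by split/seq/
//! glob/find). Plain text, even multi-line, is treated as a single item.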
14
15use std::sync::Arc;
16
17use tokio::sync::Semaphore;
18use tracing::Instrument;
19
20use crate::ast::{Command, Value};
21use crate::dispatch::CommandDispatcher;
22use crate::interpreter::ExecResult;
23use crate::tools::{ExecContext, ToolRegistry};
24
25use super::pipeline::PipelineRunner;
26
/// Options for scatter operation.
///
/// Controls how each scattered item is exposed to workers and how many
/// workers may run concurrently.
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Name of the variable each worker sees its item under (default: "ITEM").
    pub var_name: String,
    /// Upper bound on concurrently running workers (default: 8).
    pub limit: usize,
}

impl Default for ScatterOptions {
    /// `ITEM` as the binding name, eight workers at a time.
    fn default() -> Self {
        let var_name = String::from("ITEM");
        let limit = 8;
        ScatterOptions { var_name, limit }
    }
}
44
/// Options for gather operation.
///
/// Controls how the per-item results of a scatter are combined into one output.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Request a progress indicator (the `progress` flag).
    ///
    /// NOTE(review): `gather_results` in this module does not read this field;
    /// confirm it is honored elsewhere.
    pub progress: bool,
    /// Keep only the first N results, in input order (0 = keep all).
    ///
    /// In this module this only trims the gathered output. Every scattered
    /// item still runs to completion, and no worker is cancelled early.
    pub first: usize,
    /// Output format: `"json"` produces a JSON array of per-item records.
    /// Any other value (default `"lines"`) joins the trimmed stdout of the
    /// successful results with newlines.
    pub format: String,
}

impl Default for GatherOptions {
    /// No progress indicator, keep all results, newline-joined output.
    fn default() -> Self {
        Self {
            progress: false,
            first: 0,
            format: "lines".to_string(),
        }
    }
}
65
66/// Result from a single scatter worker.
67#[derive(Debug, Clone)]
68pub struct ScatterResult {
69    /// The input item that was processed.
70    pub item: String,
71    /// The execution result.
72    pub result: ExecResult,
73}
74
75/// Runs scatter/gather pipelines.
76///
77/// The dispatcher is used for pre_scatter, post_gather, and parallel worker
78/// command execution. This enables scatter blocks to run user tools, scripts,
79/// and external commands — not just builtins.
80pub struct ScatterGatherRunner {
81    tools: Arc<ToolRegistry>,
82    dispatcher: Arc<dyn CommandDispatcher>,
83}
84
85impl ScatterGatherRunner {
86    /// Create a new scatter/gather runner with the given dispatcher.
87    ///
88    /// The dispatcher handles the full command resolution chain for all
89    /// pipeline stages (pre_scatter, parallel workers, post_gather).
90    pub fn new(tools: Arc<ToolRegistry>, dispatcher: Arc<dyn CommandDispatcher>) -> Self {
91        Self { tools, dispatcher }
92    }
93
94    /// Execute a scatter/gather pipeline.
95    ///
96    /// The pipeline is split into three parts:
97    /// - pre_scatter: commands before scatter
98    /// - parallel: commands between scatter and gather
99    /// - post_gather: commands after gather
100    ///
101    /// Returns the final result after all stages complete.
102    #[tracing::instrument(level = "info", skip(self, pre_scatter, scatter_opts, parallel, gather_opts, post_gather, ctx), fields(item_count = tracing::field::Empty, parallelism = scatter_opts.limit))]
103    pub async fn run(
104        &self,
105        pre_scatter: &[Command],
106        scatter_opts: ScatterOptions,
107        parallel: &[Command],
108        gather_opts: GatherOptions,
109        post_gather: &[Command],
110        ctx: &mut ExecContext,
111    ) -> ExecResult {
112        let runner = PipelineRunner::new(self.tools.clone());
113
114        // Run pre-scatter commands to get input.
115        // Uses run_sequential to avoid async recursion (scatter → run → scatter).
116        let (text, data) = if pre_scatter.is_empty() {
117            // Use existing stdin
118            let data = ctx.take_stdin_data();
119            let text = ctx.take_stdin().unwrap_or_default();
120            (text, data)
121        } else {
122            let result = runner.run_sequential(pre_scatter, ctx, &*self.dispatcher).await;
123            if !result.ok() {
124                return result;
125            }
126            (result.out, result.data)
127        };
128
129        // Extract items from structured data or text
130        let items = match extract_items(data.as_ref(), &text) {
131            Ok(items) => items,
132            Err(msg) => return ExecResult::failure(1, msg),
133        };
134        if items.is_empty() {
135            return ExecResult::success("");
136        }
137
138        tracing::Span::current().record("item_count", items.len());
139
140        // Run parallel stage
141        let results = self
142            .run_parallel(&items, &scatter_opts, parallel, ctx)
143            .await;
144
145        // Gather results
146        let gathered = gather_results(&results, &gather_opts);
147
148        // Run post-gather commands if any
149        if post_gather.is_empty() {
150            ExecResult::success(gathered)
151        } else {
152            ctx.set_stdin(gathered);
153            runner.run_sequential(post_gather, ctx, &*self.dispatcher).await
154        }
155    }
156
157    /// Run the parallel stage for all items.
158    ///
159    /// # Safety constraint
160    ///
161    /// Parallel workers share the dispatcher via `Arc`. The dispatcher MUST be
162    /// stateless (like `BackendDispatcher`) when used from parallel workers.
163    /// Using a stateful dispatcher (like `Kernel`, which writes to `self.scope`)
164    /// would cause data races — parallel workers would stomp each other's scope.
165    ///
166    /// The `Kernel` dispatcher is safe for pre_scatter and post_gather (sequential),
167    /// but scatter workers must use `BackendDispatcher` until per-worker kernel
168    /// instances are implemented.
169    #[tracing::instrument(level = "debug", skip(self, items, opts, commands, base_ctx), fields(worker_count = items.len()))]
170    async fn run_parallel(
171        &self,
172        items: &[String],
173        opts: &ScatterOptions,
174        commands: &[Command],
175        base_ctx: &ExecContext,
176    ) -> Vec<ScatterResult> {
177        let semaphore = Arc::new(Semaphore::new(opts.limit));
178        let tools = self.tools.clone();
179        let dispatcher = self.dispatcher.clone();
180        let var_name = opts.var_name.clone();
181
182        // Spawn parallel tasks
183        let mut handles = Vec::with_capacity(items.len());
184
185        for item in items.iter().cloned() {
186            let permit = semaphore.clone().acquire_owned().await;
187            let tools = tools.clone();
188            let dispatcher = dispatcher.clone();
189            let commands = commands.to_vec();
190            let var_name = var_name.clone();
191            let base_scope = base_ctx.scope.clone();
192            let backend = base_ctx.backend.clone();
193            let cwd = base_ctx.cwd.clone();
194
195            let item_label = if item.len() > 64 {
196                format!("{}...", &item[..64])
197            } else {
198                item.clone()
199            };
200            let worker_span = tracing::debug_span!("scatter_worker", item = %item_label);
201            let handle = tokio::spawn(async move {
202                let _permit = permit; // Hold permit until done
203
204                // Create context for this worker
205                let mut scope = base_scope;
206                scope.set(&var_name, Value::String(item.clone()));
207
208                let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
209                ctx.set_cwd(cwd);
210
211                // Run through PipelineRunner + dispatcher (full resolution chain).
212                // Uses run_sequential to avoid async recursion and infinite future size.
213                let runner = PipelineRunner::new(tools);
214                let result = runner.run_sequential(&commands, &mut ctx, &*dispatcher).await;
215
216                ScatterResult { item, result }
217            }.instrument(worker_span));
218
219            handles.push(handle);
220        }
221
222        // Collect results
223        let mut results = Vec::with_capacity(handles.len());
224        for handle in handles {
225            match handle.await {
226                Ok(result) => results.push(result),
227                Err(e) => {
228                    results.push(ScatterResult {
229                        item: String::new(),
230                        result: ExecResult::failure(1, format!("Task panicked: {}", e)),
231                    });
232                }
233            }
234        }
235
236        results
237    }
238}
239
240/// Extract items from structured data or text.
241///
242/// kaish does not split implicitly — this function requires structured data
243/// (JSON array from split/seq/glob/find) for multi-item input. Single-line
244/// text is treated as one item. Multi-line text without structured data is
245/// an error.
246pub fn extract_items(data: Option<&Value>, text: &str) -> Result<Vec<String>, String> {
247    // 1. Structured data (JSON array from split/seq/glob/find) — use it
248    if let Some(Value::Json(serde_json::Value::Array(arr))) = data {
249        return Ok(arr.iter().map(|v| match v {
250            serde_json::Value::String(s) => s.clone(),
251            other => other.to_string(),
252        }).collect());
253    }
254    if let Some(Value::String(s)) = data {
255        return Ok(vec![s.clone()]);
256    }
257
258    // 2. Empty — return empty
259    let trimmed = text.trim();
260    if trimmed.is_empty() {
261        return Ok(vec![]);
262    }
263
264    // 3. Raw text without structured data — one item (no implicit splitting)
265    Ok(vec![trimmed.to_string()])
266}
267
268/// Gather results into output string.
269fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
270    let results_to_use = if opts.first > 0 && opts.first < results.len() {
271        &results[..opts.first]
272    } else {
273        results
274    };
275
276    if opts.format == "json" {
277        // Output as JSON array of objects
278        let json_results: Vec<serde_json::Value> = results_to_use
279            .iter()
280            .map(|r| {
281                serde_json::json!({
282                    "item": r.item,
283                    "ok": r.result.ok(),
284                    "code": r.result.code,
285                    "out": r.result.out.trim(),
286                    "err": r.result.err.trim(),
287                })
288            })
289            .collect();
290
291        serde_json::to_string_pretty(&json_results).unwrap_or_default()
292    } else {
293        // Output as lines (stdout from each, separated by newlines)
294        results_to_use
295            .iter()
296            .filter(|r| r.result.ok())
297            .map(|r| r.result.out.trim())
298            .collect::<Vec<_>>()
299            .join("\n")
300    }
301}
302
303/// Parse scatter options from tool args.
304pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
305    let mut opts = ScatterOptions::default();
306
307    if let Some(Value::String(name)) = args.named.get("as") {
308        opts.var_name = name.clone();
309    }
310
311    if let Some(Value::Int(n)) = args.named.get("limit") {
312        opts.limit = (*n).max(1) as usize;
313    }
314
315    opts
316}
317
318/// Parse gather options from tool args.
319pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
320    let mut opts = GatherOptions::default();
321
322    if args.has_flag("progress") {
323        opts.progress = true;
324    }
325
326    if let Some(Value::Int(n)) = args.named.get("first") {
327        opts.first = (*n).max(0) as usize;
328    }
329
330    if let Some(Value::String(fmt)) = args.named.get("format") {
331        opts.format = fmt.clone();
332    }
333
334    opts
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Build a successful worker result for `item` whose stdout is `out`.
    fn done(item: &str, out: &str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::success(out),
        }
    }

    #[test]
    fn test_extract_items_structured_json_array() {
        let data = Value::Json(serde_json::json!(["a", "b", "c"]));
        assert_eq!(extract_items(Some(&data), "").unwrap(), ["a", "b", "c"]);
    }

    #[test]
    fn test_extract_items_structured_mixed_types() {
        let data = Value::Json(serde_json::json!([1, "two", true]));
        assert_eq!(extract_items(Some(&data), "").unwrap(), ["1", "two", "true"]);
    }

    #[test]
    fn test_extract_items_structured_string() {
        let data = Value::String("single".into());
        assert_eq!(extract_items(Some(&data), "").unwrap(), ["single"]);
    }

    #[test]
    fn test_extract_items_single_line_text() {
        assert_eq!(extract_items(None, "hello").unwrap(), ["hello"]);
    }

    #[test]
    fn test_extract_items_empty() {
        assert!(extract_items(None, "").unwrap().is_empty());
    }

    #[test]
    fn test_extract_items_multiline_is_one_item() {
        // No implicit splitting: multi-line text stays a single item.
        assert_eq!(extract_items(None, "one\ntwo\nthree").unwrap(), ["one\ntwo\nthree"]);
    }

    #[test]
    fn test_extract_items_structured_overrides_text() {
        // Structured data wins; the text argument is ignored entirely.
        let data = Value::Json(serde_json::json!(["x", "y"]));
        assert_eq!(extract_items(Some(&data), "ignored\ntext").unwrap(), ["x", "y"]);
    }

    #[test]
    fn test_gather_results_lines() {
        let results = [done("a", "result_a"), done("b", "result_b")];
        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_json() {
        let results = [done("test", "output")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = [done("a", "1"), done("b", "2"), done("c", "3")];
        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };
        assert_eq!(gather_results(&results, &opts), "1\n2");
    }

    #[test]
    fn test_parse_scatter_options() {
        let mut args = ToolArgs::new();
        args.named.insert("as".to_string(), Value::String("URL".to_string()));
        args.named.insert("limit".to_string(), Value::Int(4));

        let opts = parse_scatter_options(&args);
        assert_eq!((opts.var_name.as_str(), opts.limit), ("URL", 4));
    }

    #[test]
    fn test_parse_gather_options() {
        let mut args = ToolArgs::new();
        args.named.insert("first".to_string(), Value::Int(5));
        args.named.insert("format".to_string(), Value::String("json".to_string()));

        let opts = parse_gather_options(&args);
        assert_eq!((opts.first, opts.format.as_str()), (5, "json"));
    }
}