1use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::time::Duration;
18
19use tokio::sync::Semaphore;
20use tracing::Instrument;
21
22use crate::ast::{Command, Value};
23use crate::dispatch::CommandDispatcher;
24use crate::duration::parse_duration;
25use crate::interpreter::ExecResult;
26use crate::tools::{ExecContext, ToolRegistry};
27
28use super::pipeline::PipelineRunner;
29
/// Options controlling the scatter (fan-out) stage.
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Name of the scope variable under which each worker sees its item (default `"ITEM"`).
    pub var_name: String,
    /// Number of semaphore permits, i.e. how many workers may run at once (default `8`).
    pub limit: usize,
    /// Optional per-worker time budget; `None` (the default) disables the timeout timer.
    pub timeout: Option<Duration>,
}
42
/// Options controlling the gather (fan-in) stage.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Set when the `progress` flag is passed (default `false`).
    /// NOTE(review): nothing visible in this file reads this field — confirm where it is consumed.
    pub progress: bool,
    /// Keep only the first N results; `0` (the default) keeps all of them.
    pub first: usize,
    /// Output format: `"json"` produces a JSON array of per-task rows, any other value
    /// produces newline-joined output of the successful tasks (default `"lines"`).
    pub format: String,
}
53
54impl Default for ScatterOptions {
55 fn default() -> Self {
56 Self {
57 var_name: "ITEM".to_string(),
58 limit: 8,
59 timeout: None,
60 }
61 }
62}
63
64impl Default for GatherOptions {
65 fn default() -> Self {
66 Self {
67 progress: false,
68 first: 0,
69 format: "lines".to_string(),
70 }
71 }
72}
73
/// Outcome of a single scatter worker.
#[derive(Debug, Clone)]
pub struct ScatterResult {
    /// The scatter item associated with this result.
    pub item: String,
    /// Result of running the parallel command sequence for the item.
    pub result: ExecResult,
    /// `true` when the worker's timeout timer had fired by the time the result was recorded.
    pub timed_out: bool,
}
84
/// Runs scatter/gather pipelines: an optional pre-scatter stage, a fan-out over items,
/// a gather step, and an optional post-gather stage.
pub struct ScatterGatherRunner {
    /// Tool registry handed to every `PipelineRunner` this runner creates.
    tools: Arc<ToolRegistry>,
    /// Dispatcher used for the pre-scatter and post-gather stages; each parallel worker
    /// gets its own dispatcher obtained from it via `fork_attached()`.
    sequential_dispatcher: Arc<dyn CommandDispatcher>,
}
97
98impl ScatterGatherRunner {
99 pub fn new(
104 tools: Arc<ToolRegistry>,
105 dispatcher: Arc<dyn CommandDispatcher>,
106 ) -> Self {
107 Self { tools, sequential_dispatcher: dispatcher }
108 }
109
110 #[tracing::instrument(level = "info", skip(self, pre_scatter, scatter_opts, parallel, gather_opts, post_gather, ctx), fields(item_count = tracing::field::Empty, parallelism = scatter_opts.limit))]
119 pub async fn run(
120 &self,
121 pre_scatter: &[Command],
122 scatter_opts: ScatterOptions,
123 parallel: &[Command],
124 gather_opts: GatherOptions,
125 post_gather: &[Command],
126 ctx: &mut ExecContext,
127 ) -> ExecResult {
128 let runner = PipelineRunner::new(self.tools.clone());
129
130 let (text, data) = if pre_scatter.is_empty() {
133 let data = ctx.take_stdin_data();
135 let text = ctx.take_stdin().unwrap_or_default();
136 (text, data)
137 } else {
138 let result = runner.run_sequential(pre_scatter, ctx, &*self.sequential_dispatcher).await;
139 if !result.ok() {
140 return result;
141 }
142 (result.text_out().into_owned(), result.data)
143 };
144
145 let items = match extract_items(data.as_ref(), &text) {
147 Ok(items) => items,
148 Err(msg) => return ExecResult::failure(1, msg),
149 };
150 if items.is_empty() {
151 return ExecResult::success("");
152 }
153
154 tracing::Span::current().record("item_count", items.len());
155
156 let results = self
158 .run_parallel(&items, &scatter_opts, parallel, ctx)
159 .await;
160
161 let GatherOutput {
163 text: gathered,
164 dropped_failures,
165 } = gather_results(&results, &gather_opts);
166
167 if !dropped_failures.is_empty() {
173 let err = format!(
174 "gather: {} task(s) failed and were omitted from line output: {} (use --json to capture per-task status)",
175 dropped_failures.len(),
176 dropped_failures.join(", ")
177 );
178 return ExecResult::from_output(1, gathered, err);
179 }
180
181 if post_gather.is_empty() {
183 ExecResult::success(gathered)
184 } else {
185 ctx.set_stdin(gathered);
186 runner.run_sequential(post_gather, ctx, &*self.sequential_dispatcher).await
187 }
188 }
189
190 #[tracing::instrument(level = "debug", skip(self, items, opts, commands, base_ctx), fields(worker_count = items.len()))]
199 async fn run_parallel(
200 &self,
201 items: &[String],
202 opts: &ScatterOptions,
203 commands: &[Command],
204 base_ctx: &ExecContext,
205 ) -> Vec<ScatterResult> {
206 let semaphore = Arc::new(Semaphore::new(opts.limit));
207 let tools = self.tools.clone();
208 let var_name = opts.var_name.clone();
209
210 let mut handles = Vec::with_capacity(items.len());
212
213 for item in items.iter().cloned() {
214 let permit = semaphore.clone().acquire_owned().await;
215 let tools = tools.clone();
216 let worker_dispatcher = self.sequential_dispatcher.fork_attached().await;
221 let commands = commands.to_vec();
222 let var_name = var_name.clone();
223 let base_scope = base_ctx.scope.clone();
224 let backend = base_ctx.backend.clone();
225 let cwd = base_ctx.cwd.clone();
226 let parent_token = base_ctx.cancel.clone();
227 let worker_token = parent_token.child_token();
228
229 let timed_out_flag = Arc::new(AtomicBool::new(false));
235 let timer_handle: Option<tokio::task::JoinHandle<()>> = opts.timeout.map(|d| {
236 let cancel = worker_token.clone();
237 let flag = timed_out_flag.clone();
238 tokio::spawn(async move {
239 tokio::time::sleep(d).await;
240 flag.store(true, Ordering::SeqCst);
241 cancel.cancel();
242 })
243 });
244 let timed_out_check = timed_out_flag.clone();
245
246 let item_label = if item.len() > 64 {
247 format!("{}...", &item[..64])
248 } else {
249 item.clone()
250 };
251 let worker_span = tracing::debug_span!("scatter_worker", item = %item_label);
252 let handle = tokio::spawn(crate::telemetry::bind_current_context(async move {
256 let _permit = permit; let mut scope = base_scope;
260 scope.set(&var_name, Value::String(item.clone()));
261
262 let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
263 ctx.set_cwd(cwd);
264 ctx.cancel = worker_token;
265
266 let runner = PipelineRunner::new(tools);
269 let result = runner.run_sequential(&commands, &mut ctx, &*worker_dispatcher).await;
270
271 if let Some(h) = timer_handle {
274 h.abort();
275 }
276
277 let timed_out = timed_out_check.load(Ordering::SeqCst);
278 ScatterResult { item, result, timed_out }
279 }.instrument(worker_span)));
280
281 handles.push(handle);
282 }
283
284 let mut results = Vec::with_capacity(handles.len());
286 for handle in handles {
287 match handle.await {
288 Ok(result) => results.push(result),
289 Err(e) => {
290 results.push(ScatterResult {
291 item: String::new(),
292 result: ExecResult::failure(1, format!("Task panicked: {}", e)),
293 timed_out: false,
294 });
295 }
296 }
297 }
298
299 results
300 }
301}
302
303pub fn extract_items(data: Option<&Value>, text: &str) -> Result<Vec<String>, String> {
310 if let Some(Value::Json(serde_json::Value::Array(arr))) = data {
312 return Ok(arr.iter().map(|v| match v {
313 serde_json::Value::String(s) => s.clone(),
314 other => other.to_string(),
315 }).collect());
316 }
317 if let Some(Value::String(s)) = data {
318 return Ok(vec![s.clone()]);
319 }
320
321 let trimmed = text.trim();
323 if trimmed.is_empty() {
324 return Ok(vec![]);
325 }
326
327 Ok(vec![trimmed.to_string()])
329}
330
/// Combined output of the gather step.
struct GatherOutput {
    /// The gathered text: pretty-printed JSON, or newline-joined output of successful tasks.
    text: String,
    /// Items of failed tasks that were left out of `text`; always empty for JSON output,
    /// where failures are kept as rows.
    dropped_failures: Vec<String>,
}
340
341fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> GatherOutput {
349 let results_to_use = if opts.first > 0 && opts.first < results.len() {
350 &results[..opts.first]
351 } else {
352 results
353 };
354
355 if opts.format == "json" {
356 let json_results: Vec<serde_json::Value> = results_to_use
358 .iter()
359 .map(|r| {
360 serde_json::json!({
361 "item": r.item,
362 "ok": r.result.ok(),
363 "code": r.result.code,
364 "out": r.result.text_out().trim(),
365 "err": r.result.err.trim(),
366 "timed_out": r.timed_out,
367 })
368 })
369 .collect();
370
371 GatherOutput {
372 text: serde_json::to_string_pretty(&json_results).unwrap_or_default(),
373 dropped_failures: Vec::new(),
374 }
375 } else {
376 let text = results_to_use
381 .iter()
382 .filter(|r| r.result.ok())
383 .map(|r| r.result.text_out())
384 .map(|t| t.trim().to_string())
385 .collect::<Vec<_>>()
386 .join("\n");
387 let dropped_failures = results_to_use
388 .iter()
389 .filter(|r| !r.result.ok())
390 .map(|r| r.item.clone())
391 .collect();
392 GatherOutput {
393 text,
394 dropped_failures,
395 }
396 }
397}
398
399pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
401 let mut opts = ScatterOptions::default();
402
403 if let Some(Value::String(name)) = args.named.get("as") {
404 opts.var_name = name.clone();
405 }
406
407 if let Some(Value::Int(n)) = args.named.get("limit") {
408 let requested = *n;
409 let clamped = requested.clamp(1, SCATTER_LIMIT_MAX as i64);
410 if requested > SCATTER_LIMIT_MAX as i64 {
411 tracing::warn!(
412 target: "kaish::scatter",
413 requested = requested,
414 ceiling = SCATTER_LIMIT_MAX,
415 "scatter limit clamped to ceiling"
416 );
417 }
418 opts.limit = clamped as usize;
419 }
420
421 if let Some(Value::String(s)) = args.named.get("timeout") {
425 match parse_duration(s) {
426 Some(d) => opts.timeout = Some(d),
427 None => tracing::warn!(
428 target: "kaish::scatter",
429 value = %s,
430 "scatter --timeout: invalid duration (try: 30, 5s, 500ms, 2m, 1h)"
431 ),
432 }
433 } else if let Some(Value::Int(n)) = args.named.get("timeout") {
434 if *n >= 0 {
435 opts.timeout = Some(Duration::from_secs(*n as u64));
436 }
437 }
438
439 opts
440}
441
/// Ceiling applied to the scatter `limit` argument by `parse_scatter_options`;
/// larger requests are clamped to this value.
pub const SCATTER_LIMIT_MAX: usize = 10_000;
446
447pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
449 let mut opts = GatherOptions::default();
450
451 if args.has_flag("progress") {
452 opts.progress = true;
453 }
454
455 if let Some(Value::Int(n)) = args.named.get("first") {
456 opts.first = (*n).max(0) as usize;
457 }
458
459 if let Some(Value::String(fmt)) = args.named.get("format") {
460 opts.format = fmt.clone();
461 }
462
463 opts
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Wraps `result` as the non-timed-out outcome for `item`.
    fn outcome(item: &str, result: ExecResult) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result,
            timed_out: false,
        }
    }

    /// Gather options with the given format and all other fields at their defaults.
    fn with_format(format: &str) -> GatherOptions {
        GatherOptions {
            format: format.to_string(),
            ..Default::default()
        }
    }

    /// Tool arguments holding exactly the given named values.
    fn named_args(entries: Vec<(&str, Value)>) -> ToolArgs {
        let mut args = ToolArgs::new();
        for (key, value) in entries {
            args.named.insert(key.to_string(), value);
        }
        args
    }

    /// Scatter options parsed from a lone `limit` argument.
    fn scatter_opts_for_limit(limit: i64) -> ScatterOptions {
        parse_scatter_options(&named_args(vec![("limit", Value::Int(limit))]))
    }

    #[test]
    fn test_extract_items_structured_json_array() {
        let data = Value::Json(serde_json::json!(["a", "b", "c"]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_extract_items_structured_mixed_types() {
        let data = Value::Json(serde_json::json!([1, "two", true]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["1", "two", "true"]);
    }

    #[test]
    fn test_extract_items_structured_string() {
        let data = Value::String("single".into());
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["single"]);
    }

    #[test]
    fn test_extract_items_single_line_text() {
        let items = extract_items(None, "hello").unwrap();
        assert_eq!(items, vec!["hello"]);
    }

    #[test]
    fn test_extract_items_empty() {
        let items = extract_items(None, "").unwrap();
        assert!(items.is_empty());
    }

    #[test]
    fn test_extract_items_multiline_is_one_item() {
        // Plain text is never split on newlines.
        let items = extract_items(None, "one\ntwo\nthree").unwrap();
        assert_eq!(items, vec!["one\ntwo\nthree"]);
    }

    #[test]
    fn test_extract_items_structured_overrides_text() {
        // Structured data wins over whatever text is supplied alongside it.
        let data = Value::Json(serde_json::json!(["x", "y"]));
        let items = extract_items(Some(&data), "ignored\ntext").unwrap();
        assert_eq!(items, vec!["x", "y"]);
    }

    #[test]
    fn test_gather_results_lines() {
        let results = vec![
            outcome("a", ExecResult::success("result_a")),
            outcome("b", ExecResult::success("result_b")),
        ];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output.text, "result_a\nresult_b");
        assert!(output.dropped_failures.is_empty());
    }

    #[test]
    fn test_gather_results_lines_reports_dropped_failures() {
        let results = vec![
            outcome("a", ExecResult::success("result_a")),
            outcome("b", ExecResult::failure(1, "boom")),
        ];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output.text, "result_a");
        assert_eq!(output.dropped_failures, vec!["b".to_string()]);
    }

    #[test]
    fn test_gather_results_json_keeps_failures_as_rows() {
        let results = vec![outcome("b", ExecResult::failure(2, "boom"))];

        let output = gather_results(&results, &with_format("json"));
        assert!(output.dropped_failures.is_empty());
        assert!(output.text.contains("\"ok\": false"));
        assert!(output.text.contains("\"code\": 2"));
    }

    #[test]
    fn test_gather_results_json() {
        let results = vec![outcome("test", ExecResult::success("output"))];

        let output = gather_results(&results, &with_format("json"));
        assert!(output.text.contains("\"item\": \"test\""));
        assert!(output.text.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = vec![
            outcome("a", ExecResult::success("1")),
            outcome("b", ExecResult::success("2")),
            outcome("c", ExecResult::success("3")),
        ];

        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert_eq!(output.text, "1\n2");
    }

    #[test]
    fn test_parse_scatter_options() {
        let args = named_args(vec![
            ("as", Value::String("URL".to_string())),
            ("limit", Value::Int(4)),
        ]);

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_gather_options() {
        let args = named_args(vec![
            ("first", Value::Int(5)),
            ("format", Value::String("json".to_string())),
        ]);

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }

    #[test]
    fn scatter_limit_clamps_to_ceiling() {
        let opts = scatter_opts_for_limit(999_999);
        assert_eq!(opts.limit, SCATTER_LIMIT_MAX);
    }

    #[test]
    fn scatter_limit_raises_zero_to_one() {
        let opts = scatter_opts_for_limit(0);
        assert_eq!(opts.limit, 1);
    }

    #[test]
    fn scatter_limit_raises_negative_to_one() {
        let opts = scatter_opts_for_limit(-42);
        assert_eq!(opts.limit, 1);
    }

    #[test]
    fn scatter_limit_preserves_valid_values() {
        let opts = scatter_opts_for_limit(500);
        assert_eq!(opts.limit, 500);
    }
}