1use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::time::Duration;
18
19use tokio::sync::Semaphore;
20use tracing::Instrument;
21
22use crate::ast::{Command, Value};
23use crate::dispatch::CommandDispatcher;
24use crate::duration::parse_duration;
25use crate::interpreter::ExecResult;
26use crate::tools::{ExecContext, ToolRegistry};
27
28use super::pipeline::PipelineRunner;
29
/// Options controlling the scatter (fan-out) stage.
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Name of the scope variable that is set to the current item in each worker's scope.
    pub var_name: String,
    /// Number of semaphore permits gating worker start-up, i.e. how many workers may hold
    /// a permit at the same time.
    pub limit: usize,
    /// Optional per-worker timeout. When set, a timer cancels the worker's cancellation
    /// token after this duration and the worker's result is flagged as timed out.
    pub timeout: Option<Duration>,
}
42
/// Options controlling the gather (fan-in) stage.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Set when the `progress` flag is passed.
    /// NOTE(review): the gather logic visible in this file never reads it — confirm where it is used.
    pub progress: bool,
    /// When greater than zero and smaller than the number of results, only that many
    /// leading results are gathered; `0` keeps every result.
    pub first: usize,
    /// Output format: `"json"` produces a JSON array with one object per task; any other
    /// value produces newline-joined output of the successful tasks.
    pub format: String,
}
53
54impl Default for ScatterOptions {
55 fn default() -> Self {
56 Self {
57 var_name: "ITEM".to_string(),
58 limit: 8,
59 timeout: None,
60 }
61 }
62}
63
64impl Default for GatherOptions {
65 fn default() -> Self {
66 Self {
67 progress: false,
68 first: 0,
69 format: "lines".to_string(),
70 }
71 }
72}
73
/// Outcome of one scatter worker.
#[derive(Debug, Clone)]
pub struct ScatterResult {
    /// The input item this worker ran with.
    pub item: String,
    /// Result of running the parallel commands for this item.
    pub result: ExecResult,
    /// `true` when the worker's timeout timer had fired by the time the worker finished.
    pub timed_out: bool,
}
84
/// Runs scatter/gather pipelines: an optional sequential pre-stage, a parallel
/// per-item stage, result gathering, and an optional sequential post-stage.
pub struct ScatterGatherRunner {
    /// Tool registry handed to every `PipelineRunner` this runner creates.
    tools: Arc<ToolRegistry>,
    /// Dispatcher used for the sequential stages; each parallel worker gets its own
    /// dispatcher obtained from this one via `fork_attached()`.
    sequential_dispatcher: Arc<dyn CommandDispatcher>,
}
97
98impl ScatterGatherRunner {
99 pub fn new(
104 tools: Arc<ToolRegistry>,
105 dispatcher: Arc<dyn CommandDispatcher>,
106 ) -> Self {
107 Self { tools, sequential_dispatcher: dispatcher }
108 }
109
110 #[tracing::instrument(level = "info", skip(self, pre_scatter, scatter_opts, parallel, gather_opts, post_gather, ctx), fields(item_count = tracing::field::Empty, parallelism = scatter_opts.limit))]
119 pub async fn run(
120 &self,
121 pre_scatter: &[Command],
122 scatter_opts: ScatterOptions,
123 parallel: &[Command],
124 gather_opts: GatherOptions,
125 post_gather: &[Command],
126 ctx: &mut ExecContext,
127 ) -> ExecResult {
128 let runner = PipelineRunner::new(self.tools.clone());
129
130 let (text, data) = if pre_scatter.is_empty() {
133 let data = ctx.take_stdin_data();
135 let text = ctx.take_stdin().unwrap_or_default();
136 (text, data)
137 } else {
138 let result = runner.run_sequential(pre_scatter, ctx, &*self.sequential_dispatcher).await;
139 if !result.ok() {
140 return result;
141 }
142 (result.text_out().into_owned(), result.data)
143 };
144
145 let items = match extract_items(data.as_ref(), &text) {
147 Ok(items) => items,
148 Err(msg) => return ExecResult::failure(1, msg),
149 };
150 if items.is_empty() {
151 return ExecResult::success("");
152 }
153
154 tracing::Span::current().record("item_count", items.len());
155
156 let results = self
158 .run_parallel(&items, &scatter_opts, parallel, ctx)
159 .await;
160
161 let GatherOutput {
163 text: gathered,
164 dropped_failures,
165 } = gather_results(&results, &gather_opts);
166
167 if !dropped_failures.is_empty() {
173 let err = format!(
174 "gather: {} task(s) failed and were omitted from line output: {} (use --json to capture per-task status)",
175 dropped_failures.len(),
176 dropped_failures.join(", ")
177 );
178 return ExecResult::from_output(1, gathered, err);
179 }
180
181 if post_gather.is_empty() {
183 ExecResult::success(gathered)
184 } else {
185 ctx.set_stdin(gathered);
186 runner.run_sequential(post_gather, ctx, &*self.sequential_dispatcher).await
187 }
188 }
189
190 #[tracing::instrument(level = "debug", skip(self, items, opts, commands, base_ctx), fields(worker_count = items.len()))]
199 async fn run_parallel(
200 &self,
201 items: &[String],
202 opts: &ScatterOptions,
203 commands: &[Command],
204 base_ctx: &ExecContext,
205 ) -> Vec<ScatterResult> {
206 let semaphore = Arc::new(Semaphore::new(opts.limit));
207 let tools = self.tools.clone();
208 let var_name = opts.var_name.clone();
209
210 let mut handles = Vec::with_capacity(items.len());
212
213 for item in items.iter().cloned() {
214 let permit = semaphore.clone().acquire_owned().await;
215 let tools = tools.clone();
216 let worker_dispatcher = self.sequential_dispatcher.fork_attached().await;
221 let commands = commands.to_vec();
222 let var_name = var_name.clone();
223 let base_scope = base_ctx.scope.clone();
224 let backend = base_ctx.backend.clone();
225 let cwd = base_ctx.cwd.clone();
226 let parent_token = base_ctx.cancel.clone();
227 let worker_token = parent_token.child_token();
228
229 let timed_out_flag = Arc::new(AtomicBool::new(false));
235 let timer_handle: Option<tokio::task::JoinHandle<()>> = opts.timeout.map(|d| {
236 let cancel = worker_token.clone();
237 let flag = timed_out_flag.clone();
238 tokio::spawn(async move {
239 tokio::time::sleep(d).await;
240 flag.store(true, Ordering::SeqCst);
241 cancel.cancel();
242 })
243 });
244 let timed_out_check = timed_out_flag.clone();
245
246 let item_label = if item.len() > 64 {
247 format!("{}...", &item[..64])
248 } else {
249 item.clone()
250 };
251 let worker_span = tracing::debug_span!("scatter_worker", item = %item_label);
252 let handle = tokio::spawn(crate::telemetry::bind_current_context(async move {
256 let _permit = permit; let mut scope = base_scope;
260 scope.set(&var_name, Value::String(item.clone()));
261
262 let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
263 ctx.set_cwd(cwd);
264 ctx.cancel = worker_token;
265
266 let runner = PipelineRunner::new(tools);
269 let result = runner.run_sequential(&commands, &mut ctx, &*worker_dispatcher).await;
270
271 if let Some(h) = timer_handle {
274 h.abort();
275 }
276
277 let timed_out = timed_out_check.load(Ordering::SeqCst);
278 ScatterResult { item, result, timed_out }
279 }.instrument(worker_span)));
280
281 handles.push(handle);
282 }
283
284 let mut results = Vec::with_capacity(handles.len());
286 for handle in handles {
287 match handle.await {
288 Ok(result) => results.push(result),
289 Err(e) => {
290 results.push(ScatterResult {
291 item: String::new(),
292 result: ExecResult::failure(1, format!("Task panicked: {}", e)),
293 timed_out: false,
294 });
295 }
296 }
297 }
298
299 results
300 }
301}
302
303pub fn extract_items(data: Option<&Value>, text: &str) -> Result<Vec<String>, String> {
312 if let Some(Value::Json(serde_json::Value::Array(arr))) = data {
314 return Ok(arr.iter().map(|v| match v {
315 serde_json::Value::String(s) => s.clone(),
316 other => other.to_string(),
317 }).collect());
318 }
319 if let Some(Value::String(s)) = data {
320 return Ok(vec![s.clone()]);
321 }
322
323 let trimmed = text.trim_end_matches(['\n', '\r']);
325 if trimmed.is_empty() {
326 return Ok(vec![]);
327 }
328 Ok(trimmed
329 .split('\n')
330 .map(|line| line.trim_end_matches('\r').to_string())
331 .collect())
332}
333
/// What gathering produces: the combined text plus the tasks that were left out of it.
struct GatherOutput {
    /// The gathered output (newline-joined lines or pretty-printed JSON).
    text: String,
    /// Items of failed tasks that were omitted from line output; always empty for JSON
    /// output, where failures are kept as rows.
    dropped_failures: Vec<String>,
}
343
344fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> GatherOutput {
352 let results_to_use = if opts.first > 0 && opts.first < results.len() {
353 &results[..opts.first]
354 } else {
355 results
356 };
357
358 if opts.format == "json" {
359 let json_results: Vec<serde_json::Value> = results_to_use
361 .iter()
362 .map(|r| {
363 serde_json::json!({
364 "item": r.item,
365 "ok": r.result.ok(),
366 "code": r.result.code,
367 "out": r.result.text_out().trim(),
368 "err": r.result.err.trim(),
369 "timed_out": r.timed_out,
370 })
371 })
372 .collect();
373
374 GatherOutput {
375 text: serde_json::to_string_pretty(&json_results).unwrap_or_default(),
376 dropped_failures: Vec::new(),
377 }
378 } else {
379 let text = results_to_use
384 .iter()
385 .filter(|r| r.result.ok())
386 .map(|r| r.result.text_out())
387 .map(|t| t.trim().to_string())
388 .collect::<Vec<_>>()
389 .join("\n");
390 let dropped_failures = results_to_use
391 .iter()
392 .filter(|r| !r.result.ok())
393 .map(|r| r.item.clone())
394 .collect();
395 GatherOutput {
396 text,
397 dropped_failures,
398 }
399 }
400}
401
402pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
404 let mut opts = ScatterOptions::default();
405
406 if let Some(Value::String(name)) = args.named.get("as") {
407 opts.var_name = name.clone();
408 }
409
410 if let Some(Value::Int(n)) = args.named.get("limit") {
411 let requested = *n;
412 let clamped = requested.clamp(1, SCATTER_LIMIT_MAX as i64);
413 if requested > SCATTER_LIMIT_MAX as i64 {
414 tracing::warn!(
415 target: "kaish::scatter",
416 requested = requested,
417 ceiling = SCATTER_LIMIT_MAX,
418 "scatter limit clamped to ceiling"
419 );
420 }
421 opts.limit = clamped as usize;
422 }
423
424 if let Some(Value::String(s)) = args.named.get("timeout") {
428 match parse_duration(s) {
429 Some(d) => opts.timeout = Some(d),
430 None => tracing::warn!(
431 target: "kaish::scatter",
432 value = %s,
433 "scatter --timeout: invalid duration (try: 30, 5s, 500ms, 2m, 1h)"
434 ),
435 }
436 } else if let Some(Value::Int(n)) = args.named.get("timeout") {
437 if *n >= 0 {
438 opts.timeout = Some(Duration::from_secs(*n as u64));
439 }
440 }
441
442 opts
443}
444
/// Upper bound for the scatter `limit` argument: `parse_scatter_options` clamps larger
/// requests down to this value and logs a warning when it does so.
pub const SCATTER_LIMIT_MAX: usize = 10_000;
449
450pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
452 let mut opts = GatherOptions::default();
453
454 if args.has_flag("progress") {
455 opts.progress = true;
456 }
457
458 if let Some(Value::Int(n)) = args.named.get("first") {
459 opts.first = (*n).max(0) as usize;
460 }
461
462 if let Some(Value::String(fmt)) = args.named.get("format") {
463 opts.format = fmt.clone();
464 }
465
466 opts
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Extracts items from plain text with no structured payload.
    fn split_text(text: &str) -> Vec<String> {
        extract_items(None, text).unwrap()
    }

    /// Extracts items when a structured payload accompanies `text`.
    fn split_data(data: Value, text: &str) -> Vec<String> {
        extract_items(Some(&data), text).unwrap()
    }

    /// A worker outcome for `item` that did not time out.
    fn outcome(item: &str, result: ExecResult) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result,
            timed_out: false,
        }
    }

    /// Gather options selecting JSON output, everything else default.
    fn json_gather() -> GatherOptions {
        GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        }
    }

    /// Tool arguments carrying the given named values.
    fn named_args(pairs: Vec<(&str, Value)>) -> ToolArgs {
        let mut args = ToolArgs::new();
        for (key, value) in pairs {
            args.named.insert(key.to_string(), value);
        }
        args
    }

    /// The effective scatter limit after parsing `limit=<requested>`.
    fn parsed_limit(requested: i64) -> usize {
        parse_scatter_options(&named_args(vec![("limit", Value::Int(requested))])).limit
    }

    #[test]
    fn test_extract_items_structured_json_array() {
        let items = split_data(Value::Json(serde_json::json!(["a", "b", "c"])), "");
        assert_eq!(items, ["a", "b", "c"]);
    }

    #[test]
    fn test_extract_items_structured_mixed_types() {
        let items = split_data(Value::Json(serde_json::json!([1, "two", true])), "");
        assert_eq!(items, ["1", "two", "true"]);
    }

    #[test]
    fn test_extract_items_structured_string() {
        let items = split_data(Value::String("single".into()), "");
        assert_eq!(items, ["single"]);
    }

    #[test]
    fn test_extract_items_single_line_text() {
        assert_eq!(split_text("hello"), ["hello"]);
    }

    #[test]
    fn test_extract_items_empty() {
        assert!(split_text("").is_empty());
    }

    #[test]
    fn test_extract_items_multiline_fans_out_per_line() {
        assert_eq!(split_text("one\ntwo\nthree"), ["one", "two", "three"]);
    }

    #[test]
    fn test_extract_items_trailing_newline_no_phantom_item() {
        assert_eq!(split_text("one\ntwo\n"), ["one", "two"]);
    }

    #[test]
    fn test_extract_items_crlf_per_line() {
        assert_eq!(split_text("one\r\ntwo\r\n"), ["one", "two"]);
    }

    #[test]
    fn test_extract_items_interior_blank_line_preserved() {
        assert_eq!(split_text("a\n\nb"), ["a", "", "b"]);
    }

    #[test]
    fn test_extract_items_whitespace_within_line_not_split() {
        assert_eq!(split_text("a b\nc d"), ["a b", "c d"]);
    }

    #[test]
    fn test_extract_items_only_newlines_is_empty() {
        assert!(split_text("\n\n").is_empty());
    }

    #[test]
    fn test_extract_items_structured_overrides_text() {
        let items = split_data(Value::Json(serde_json::json!(["x", "y"])), "ignored\ntext");
        assert_eq!(items, ["x", "y"]);
    }

    #[test]
    fn test_gather_results_lines() {
        let results = [
            outcome("a", ExecResult::success("result_a")),
            outcome("b", ExecResult::success("result_b")),
        ];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output.text, "result_a\nresult_b");
        assert!(output.dropped_failures.is_empty());
    }

    #[test]
    fn test_gather_results_lines_reports_dropped_failures() {
        let results = [
            outcome("a", ExecResult::success("result_a")),
            outcome("b", ExecResult::failure(1, "boom")),
        ];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output.text, "result_a");
        assert_eq!(output.dropped_failures, vec!["b".to_string()]);
    }

    #[test]
    fn test_gather_results_json_keeps_failures_as_rows() {
        let results = [outcome("b", ExecResult::failure(2, "boom"))];

        let output = gather_results(&results, &json_gather());
        assert!(output.dropped_failures.is_empty());
        assert!(output.text.contains("\"ok\": false"));
        assert!(output.text.contains("\"code\": 2"));
    }

    #[test]
    fn test_gather_results_json() {
        let results = [outcome("test", ExecResult::success("output"))];

        let output = gather_results(&results, &json_gather());
        assert!(output.text.contains("\"item\": \"test\""));
        assert!(output.text.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = [
            outcome("a", ExecResult::success("1")),
            outcome("b", ExecResult::success("2")),
            outcome("c", ExecResult::success("3")),
        ];
        let first_two = GatherOptions {
            first: 2,
            ..Default::default()
        };

        let output = gather_results(&results, &first_two);
        assert_eq!(output.text, "1\n2");
    }

    #[test]
    fn test_parse_scatter_options() {
        let args = named_args(vec![
            ("as", Value::String("URL".to_string())),
            ("limit", Value::Int(4)),
        ]);

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_gather_options() {
        let args = named_args(vec![
            ("first", Value::Int(5)),
            ("format", Value::String("json".to_string())),
        ]);

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }

    #[test]
    fn scatter_limit_clamps_to_ceiling() {
        assert_eq!(parsed_limit(999_999), SCATTER_LIMIT_MAX);
    }

    #[test]
    fn scatter_limit_raises_zero_to_one() {
        assert_eq!(parsed_limit(0), 1);
    }

    #[test]
    fn scatter_limit_raises_negative_to_one() {
        assert_eq!(parsed_limit(-42), 1);
    }

    #[test]
    fn scatter_limit_preserves_valid_values() {
        assert_eq!(parsed_limit(500), 500);
    }
}