1use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::time::Duration;
18
19use tokio::sync::Semaphore;
20use tracing::Instrument;
21
22use crate::ast::{Command, Value};
23use crate::dispatch::CommandDispatcher;
24use crate::duration::parse_duration;
25use crate::interpreter::ExecResult;
26use crate::tools::{ExecContext, ToolRegistry};
27
28use super::pipeline::PipelineRunner;
29
/// Options controlling the scatter (fan-out) stage.
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Name of the scope variable that each worker gets bound to its item.
    pub var_name: String,
    /// Number of semaphore permits, i.e. how many workers may run at once.
    pub limit: usize,
    /// Optional per-worker time budget; once it elapses the worker's
    /// cancellation token is cancelled. `None` disables the timer.
    pub timeout: Option<Duration>,
}
42
/// Options controlling how worker results are gathered into one output.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Set when the `progress` flag is passed.
    /// NOTE(review): the gathering code visible in this file does not read it.
    pub progress: bool,
    /// When greater than zero and smaller than the number of results, only
    /// the first `first` results are gathered; `0` keeps all of them.
    pub first: usize,
    /// Output format: `"json"` produces a pretty-printed JSON array; any other
    /// value produces the newline-joined output of the successful workers.
    pub format: String,
}
53
54impl Default for ScatterOptions {
55 fn default() -> Self {
56 Self {
57 var_name: "ITEM".to_string(),
58 limit: 8,
59 timeout: None,
60 }
61 }
62}
63
64impl Default for GatherOptions {
65 fn default() -> Self {
66 Self {
67 progress: false,
68 first: 0,
69 format: "lines".to_string(),
70 }
71 }
72}
73
/// Outcome of one scatter worker.
#[derive(Debug, Clone)]
pub struct ScatterResult {
    /// The input item the worker was run for.
    pub item: String,
    /// Result of running the worker's command sequence.
    pub result: ExecResult,
    /// `true` when the worker's timeout timer fired before the worker finished.
    pub timed_out: bool,
}
84
/// Runs a scatter/gather pipeline: an optional sequential prefix, a parallel
/// fan-out over the extracted items, and an optional sequential suffix.
pub struct ScatterGatherRunner {
    /// Tool registry handed to every `PipelineRunner` this runner creates.
    tools: Arc<ToolRegistry>,
    /// Dispatcher used for the sequential stages; each parallel worker gets
    /// its own dispatcher obtained from it via `fork_attached`.
    sequential_dispatcher: Arc<dyn CommandDispatcher>,
}
97
98impl ScatterGatherRunner {
99 pub fn new(
104 tools: Arc<ToolRegistry>,
105 dispatcher: Arc<dyn CommandDispatcher>,
106 ) -> Self {
107 Self { tools, sequential_dispatcher: dispatcher }
108 }
109
110 #[tracing::instrument(level = "info", skip(self, pre_scatter, scatter_opts, parallel, gather_opts, post_gather, ctx), fields(item_count = tracing::field::Empty, parallelism = scatter_opts.limit))]
119 pub async fn run(
120 &self,
121 pre_scatter: &[Command],
122 scatter_opts: ScatterOptions,
123 parallel: &[Command],
124 gather_opts: GatherOptions,
125 post_gather: &[Command],
126 ctx: &mut ExecContext,
127 ) -> ExecResult {
128 let runner = PipelineRunner::new(self.tools.clone());
129
130 let (text, data) = if pre_scatter.is_empty() {
133 let data = ctx.take_stdin_data();
135 let text = ctx.take_stdin().unwrap_or_default();
136 (text, data)
137 } else {
138 let result = runner.run_sequential(pre_scatter, ctx, &*self.sequential_dispatcher).await;
139 if !result.ok() {
140 return result;
141 }
142 (result.text_out().into_owned(), result.data)
143 };
144
145 let items = match extract_items(data.as_ref(), &text) {
147 Ok(items) => items,
148 Err(msg) => return ExecResult::failure(1, msg),
149 };
150 if items.is_empty() {
151 return ExecResult::success("");
152 }
153
154 tracing::Span::current().record("item_count", items.len());
155
156 let results = self
158 .run_parallel(&items, &scatter_opts, parallel, ctx)
159 .await;
160
161 let gathered = gather_results(&results, &gather_opts);
163
164 if post_gather.is_empty() {
166 ExecResult::success(gathered)
167 } else {
168 ctx.set_stdin(gathered);
169 runner.run_sequential(post_gather, ctx, &*self.sequential_dispatcher).await
170 }
171 }
172
173 #[tracing::instrument(level = "debug", skip(self, items, opts, commands, base_ctx), fields(worker_count = items.len()))]
182 async fn run_parallel(
183 &self,
184 items: &[String],
185 opts: &ScatterOptions,
186 commands: &[Command],
187 base_ctx: &ExecContext,
188 ) -> Vec<ScatterResult> {
189 let semaphore = Arc::new(Semaphore::new(opts.limit));
190 let tools = self.tools.clone();
191 let var_name = opts.var_name.clone();
192
193 let mut handles = Vec::with_capacity(items.len());
195
196 for item in items.iter().cloned() {
197 let permit = semaphore.clone().acquire_owned().await;
198 let tools = tools.clone();
199 let worker_dispatcher = self.sequential_dispatcher.fork_attached().await;
204 let commands = commands.to_vec();
205 let var_name = var_name.clone();
206 let base_scope = base_ctx.scope.clone();
207 let backend = base_ctx.backend.clone();
208 let cwd = base_ctx.cwd.clone();
209 let parent_token = base_ctx.cancel.clone();
210 let worker_token = parent_token.child_token();
211
212 let timed_out_flag = Arc::new(AtomicBool::new(false));
218 let timer_handle: Option<tokio::task::JoinHandle<()>> = opts.timeout.map(|d| {
219 let cancel = worker_token.clone();
220 let flag = timed_out_flag.clone();
221 tokio::spawn(async move {
222 tokio::time::sleep(d).await;
223 flag.store(true, Ordering::SeqCst);
224 cancel.cancel();
225 })
226 });
227 let timed_out_check = timed_out_flag.clone();
228
229 let item_label = if item.len() > 64 {
230 format!("{}...", &item[..64])
231 } else {
232 item.clone()
233 };
234 let worker_span = tracing::debug_span!("scatter_worker", item = %item_label);
235 let handle = tokio::spawn(async move {
236 let _permit = permit; let mut scope = base_scope;
240 scope.set(&var_name, Value::String(item.clone()));
241
242 let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
243 ctx.set_cwd(cwd);
244 ctx.cancel = worker_token;
245
246 let runner = PipelineRunner::new(tools);
249 let result = runner.run_sequential(&commands, &mut ctx, &*worker_dispatcher).await;
250
251 if let Some(h) = timer_handle {
254 h.abort();
255 }
256
257 let timed_out = timed_out_check.load(Ordering::SeqCst);
258 ScatterResult { item, result, timed_out }
259 }.instrument(worker_span));
260
261 handles.push(handle);
262 }
263
264 let mut results = Vec::with_capacity(handles.len());
266 for handle in handles {
267 match handle.await {
268 Ok(result) => results.push(result),
269 Err(e) => {
270 results.push(ScatterResult {
271 item: String::new(),
272 result: ExecResult::failure(1, format!("Task panicked: {}", e)),
273 timed_out: false,
274 });
275 }
276 }
277 }
278
279 results
280 }
281}
282
283pub fn extract_items(data: Option<&Value>, text: &str) -> Result<Vec<String>, String> {
290 if let Some(Value::Json(serde_json::Value::Array(arr))) = data {
292 return Ok(arr.iter().map(|v| match v {
293 serde_json::Value::String(s) => s.clone(),
294 other => other.to_string(),
295 }).collect());
296 }
297 if let Some(Value::String(s)) = data {
298 return Ok(vec![s.clone()]);
299 }
300
301 let trimmed = text.trim();
303 if trimmed.is_empty() {
304 return Ok(vec![]);
305 }
306
307 Ok(vec![trimmed.to_string()])
309}
310
311fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
313 let results_to_use = if opts.first > 0 && opts.first < results.len() {
314 &results[..opts.first]
315 } else {
316 results
317 };
318
319 if opts.format == "json" {
320 let json_results: Vec<serde_json::Value> = results_to_use
322 .iter()
323 .map(|r| {
324 serde_json::json!({
325 "item": r.item,
326 "ok": r.result.ok(),
327 "code": r.result.code,
328 "out": r.result.text_out().trim(),
329 "err": r.result.err.trim(),
330 "timed_out": r.timed_out,
331 })
332 })
333 .collect();
334
335 serde_json::to_string_pretty(&json_results).unwrap_or_default()
336 } else {
337 results_to_use
339 .iter()
340 .filter(|r| r.result.ok())
341 .map(|r| r.result.text_out())
342 .map(|t| t.trim().to_string())
343 .collect::<Vec<_>>()
344 .join("\n")
345 }
346}
347
348pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
350 let mut opts = ScatterOptions::default();
351
352 if let Some(Value::String(name)) = args.named.get("as") {
353 opts.var_name = name.clone();
354 }
355
356 if let Some(Value::Int(n)) = args.named.get("limit") {
357 let requested = *n;
358 let clamped = requested.clamp(1, SCATTER_LIMIT_MAX as i64);
359 if requested > SCATTER_LIMIT_MAX as i64 {
360 tracing::warn!(
361 target: "kaish::scatter",
362 requested = requested,
363 ceiling = SCATTER_LIMIT_MAX,
364 "scatter limit clamped to ceiling"
365 );
366 }
367 opts.limit = clamped as usize;
368 }
369
370 if let Some(Value::String(s)) = args.named.get("timeout") {
374 match parse_duration(s) {
375 Some(d) => opts.timeout = Some(d),
376 None => tracing::warn!(
377 target: "kaish::scatter",
378 value = %s,
379 "scatter --timeout: invalid duration (try: 30, 5s, 500ms, 2m, 1h)"
380 ),
381 }
382 } else if let Some(Value::Int(n)) = args.named.get("timeout") {
383 if *n >= 0 {
384 opts.timeout = Some(Duration::from_secs(*n as u64));
385 }
386 }
387
388 opts
389}
390
/// Upper bound applied to the `limit` argument by `parse_scatter_options`;
/// larger requests are clamped to this value with a warning.
pub const SCATTER_LIMIT_MAX: usize = 10_000;
395
396pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
398 let mut opts = GatherOptions::default();
399
400 if args.has_flag("progress") {
401 opts.progress = true;
402 }
403
404 if let Some(Value::Int(n)) = args.named.get("first") {
405 opts.first = (*n).max(0) as usize;
406 }
407
408 if let Some(Value::String(fmt)) = args.named.get("format") {
409 opts.format = fmt.clone();
410 }
411
412 opts
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Builds a successful, non-timed-out worker result.
    fn succeeded(item: &str, output: &str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::success(output),
            timed_out: false,
        }
    }

    /// Builds `ToolArgs` carrying the given named arguments.
    fn named_args(pairs: Vec<(&str, Value)>) -> ToolArgs {
        let mut args = ToolArgs::new();
        for (key, value) in pairs {
            args.named.insert(key.to_string(), value);
        }
        args
    }

    /// Parses scatter options from a lone `limit` argument and returns the limit.
    fn parsed_limit(requested: i64) -> usize {
        let args = named_args(vec![("limit", Value::Int(requested))]);
        parse_scatter_options(&args).limit
    }

    // --- extract_items ---------------------------------------------------

    #[test]
    fn test_extract_items_structured_json_array() {
        let data = Value::Json(serde_json::json!(["a", "b", "c"]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_extract_items_structured_mixed_types() {
        let data = Value::Json(serde_json::json!([1, "two", true]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["1", "two", "true"]);
    }

    #[test]
    fn test_extract_items_structured_string() {
        let data = Value::String("single".into());
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["single"]);
    }

    #[test]
    fn test_extract_items_single_line_text() {
        let items = extract_items(None, "hello").unwrap();
        assert_eq!(items, vec!["hello"]);
    }

    #[test]
    fn test_extract_items_empty() {
        let items = extract_items(None, "").unwrap();
        assert!(items.is_empty());
    }

    #[test]
    fn test_extract_items_multiline_is_one_item() {
        // Plain text is never split on newlines.
        let items = extract_items(None, "one\ntwo\nthree").unwrap();
        assert_eq!(items, vec!["one\ntwo\nthree"]);
    }

    #[test]
    fn test_extract_items_structured_overrides_text() {
        let data = Value::Json(serde_json::json!(["x", "y"]));
        let items = extract_items(Some(&data), "ignored\ntext").unwrap();
        assert_eq!(items, vec!["x", "y"]);
    }

    // --- gather_results --------------------------------------------------

    #[test]
    fn test_gather_results_lines() {
        let results = vec![succeeded("a", "result_a"), succeeded("b", "result_b")];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_json() {
        let results = vec![succeeded("test", "output")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };

        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = vec![
            succeeded("a", "1"),
            succeeded("b", "2"),
            succeeded("c", "3"),
        ];
        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };

        let output = gather_results(&results, &opts);
        assert_eq!(output, "1\n2");
    }

    // --- option parsing --------------------------------------------------

    #[test]
    fn test_parse_scatter_options() {
        let args = named_args(vec![
            ("as", Value::String("URL".to_string())),
            ("limit", Value::Int(4)),
        ]);

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_gather_options() {
        let args = named_args(vec![
            ("first", Value::Int(5)),
            ("format", Value::String("json".to_string())),
        ]);

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }

    #[test]
    fn scatter_limit_clamps_to_ceiling() {
        assert_eq!(parsed_limit(999_999), SCATTER_LIMIT_MAX);
    }

    #[test]
    fn scatter_limit_raises_zero_to_one() {
        assert_eq!(parsed_limit(0), 1);
    }

    #[test]
    fn scatter_limit_raises_negative_to_one() {
        assert_eq!(parsed_limit(-42), 1);
    }

    #[test]
    fn scatter_limit_preserves_valid_values() {
        assert_eq!(parsed_limit(500), 500);
    }
}