1use std::sync::Arc;
16
17use tokio::sync::Semaphore;
18use tracing::Instrument;
19
20use crate::ast::{Command, Value};
21use crate::dispatch::CommandDispatcher;
22use crate::interpreter::ExecResult;
23use crate::tools::{ExecContext, ToolRegistry};
24
25use super::pipeline::PipelineRunner;
26
/// Settings for the scatter (fan-out) stage.
#[derive(Clone, Debug)]
pub struct ScatterOptions {
    /// Name of the variable that receives the current item in each worker's scope.
    pub var_name: String,
    /// Upper bound on the number of workers running at the same time.
    pub limit: usize,
}

impl Default for ScatterOptions {
    /// Items are bound to `ITEM` and at most eight workers run at once.
    fn default() -> Self {
        ScatterOptions {
            var_name: String::from("ITEM"),
            limit: 8,
        }
    }
}
44
/// Settings for the gather (fan-in) stage.
#[derive(Clone, Debug)]
pub struct GatherOptions {
    /// Whether progress reporting was requested.
    // NOTE(review): no code visible in this file reads this flag; confirm where it is consumed.
    pub progress: bool,
    /// Keep only this many leading results; `0` keeps every result.
    pub first: usize,
    /// Output format name: `"json"` selects JSON output, any other value selects
    /// newline-joined text.
    pub format: String,
}

impl Default for GatherOptions {
    /// No progress reporting, no result limit, and the `lines` output format.
    fn default() -> Self {
        GatherOptions {
            progress: false,
            first: 0,
            format: String::from("lines"),
        }
    }
}
65
66#[derive(Debug, Clone)]
68pub struct ScatterResult {
69 pub item: String,
71 pub result: ExecResult,
73}
74
75pub struct ScatterGatherRunner {
81 tools: Arc<ToolRegistry>,
82 dispatcher: Arc<dyn CommandDispatcher>,
83}
84
85impl ScatterGatherRunner {
86 pub fn new(tools: Arc<ToolRegistry>, dispatcher: Arc<dyn CommandDispatcher>) -> Self {
91 Self { tools, dispatcher }
92 }
93
94 #[tracing::instrument(level = "info", skip(self, pre_scatter, scatter_opts, parallel, gather_opts, post_gather, ctx), fields(item_count = tracing::field::Empty, parallelism = scatter_opts.limit))]
103 pub async fn run(
104 &self,
105 pre_scatter: &[Command],
106 scatter_opts: ScatterOptions,
107 parallel: &[Command],
108 gather_opts: GatherOptions,
109 post_gather: &[Command],
110 ctx: &mut ExecContext,
111 ) -> ExecResult {
112 let runner = PipelineRunner::new(self.tools.clone());
113
114 let (text, data) = if pre_scatter.is_empty() {
117 let data = ctx.take_stdin_data();
119 let text = ctx.take_stdin().unwrap_or_default();
120 (text, data)
121 } else {
122 let result = runner.run_sequential(pre_scatter, ctx, &*self.dispatcher).await;
123 if !result.ok() {
124 return result;
125 }
126 (result.out, result.data)
127 };
128
129 let items = match extract_items(data.as_ref(), &text) {
131 Ok(items) => items,
132 Err(msg) => return ExecResult::failure(1, msg),
133 };
134 if items.is_empty() {
135 return ExecResult::success("");
136 }
137
138 tracing::Span::current().record("item_count", items.len());
139
140 let results = self
142 .run_parallel(&items, &scatter_opts, parallel, ctx)
143 .await;
144
145 let gathered = gather_results(&results, &gather_opts);
147
148 if post_gather.is_empty() {
150 ExecResult::success(gathered)
151 } else {
152 ctx.set_stdin(gathered);
153 runner.run_sequential(post_gather, ctx, &*self.dispatcher).await
154 }
155 }
156
157 #[tracing::instrument(level = "debug", skip(self, items, opts, commands, base_ctx), fields(worker_count = items.len()))]
170 async fn run_parallel(
171 &self,
172 items: &[String],
173 opts: &ScatterOptions,
174 commands: &[Command],
175 base_ctx: &ExecContext,
176 ) -> Vec<ScatterResult> {
177 let semaphore = Arc::new(Semaphore::new(opts.limit));
178 let tools = self.tools.clone();
179 let dispatcher = self.dispatcher.clone();
180 let var_name = opts.var_name.clone();
181
182 let mut handles = Vec::with_capacity(items.len());
184
185 for item in items.iter().cloned() {
186 let permit = semaphore.clone().acquire_owned().await;
187 let tools = tools.clone();
188 let dispatcher = dispatcher.clone();
189 let commands = commands.to_vec();
190 let var_name = var_name.clone();
191 let base_scope = base_ctx.scope.clone();
192 let backend = base_ctx.backend.clone();
193 let cwd = base_ctx.cwd.clone();
194
195 let item_label = if item.len() > 64 {
196 format!("{}...", &item[..64])
197 } else {
198 item.clone()
199 };
200 let worker_span = tracing::debug_span!("scatter_worker", item = %item_label);
201 let handle = tokio::spawn(async move {
202 let _permit = permit; let mut scope = base_scope;
206 scope.set(&var_name, Value::String(item.clone()));
207
208 let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
209 ctx.set_cwd(cwd);
210
211 let runner = PipelineRunner::new(tools);
214 let result = runner.run_sequential(&commands, &mut ctx, &*dispatcher).await;
215
216 ScatterResult { item, result }
217 }.instrument(worker_span));
218
219 handles.push(handle);
220 }
221
222 let mut results = Vec::with_capacity(handles.len());
224 for handle in handles {
225 match handle.await {
226 Ok(result) => results.push(result),
227 Err(e) => {
228 results.push(ScatterResult {
229 item: String::new(),
230 result: ExecResult::failure(1, format!("Task panicked: {}", e)),
231 });
232 }
233 }
234 }
235
236 results
237 }
238}
239
240pub fn extract_items(data: Option<&Value>, text: &str) -> Result<Vec<String>, String> {
247 if let Some(Value::Json(serde_json::Value::Array(arr))) = data {
249 return Ok(arr.iter().map(|v| match v {
250 serde_json::Value::String(s) => s.clone(),
251 other => other.to_string(),
252 }).collect());
253 }
254 if let Some(Value::String(s)) = data {
255 return Ok(vec![s.clone()]);
256 }
257
258 let trimmed = text.trim();
260 if trimmed.is_empty() {
261 return Ok(vec![]);
262 }
263
264 Ok(vec![trimmed.to_string()])
266}
267
268fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
270 let results_to_use = if opts.first > 0 && opts.first < results.len() {
271 &results[..opts.first]
272 } else {
273 results
274 };
275
276 if opts.format == "json" {
277 let json_results: Vec<serde_json::Value> = results_to_use
279 .iter()
280 .map(|r| {
281 serde_json::json!({
282 "item": r.item,
283 "ok": r.result.ok(),
284 "code": r.result.code,
285 "out": r.result.out.trim(),
286 "err": r.result.err.trim(),
287 })
288 })
289 .collect();
290
291 serde_json::to_string_pretty(&json_results).unwrap_or_default()
292 } else {
293 results_to_use
295 .iter()
296 .filter(|r| r.result.ok())
297 .map(|r| r.result.out.trim())
298 .collect::<Vec<_>>()
299 .join("\n")
300 }
301}
302
303pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
305 let mut opts = ScatterOptions::default();
306
307 if let Some(Value::String(name)) = args.named.get("as") {
308 opts.var_name = name.clone();
309 }
310
311 if let Some(Value::Int(n)) = args.named.get("limit") {
312 opts.limit = (*n).max(1) as usize;
313 }
314
315 opts
316}
317
318pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
320 let mut opts = GatherOptions::default();
321
322 if args.has_flag("progress") {
323 opts.progress = true;
324 }
325
326 if let Some(Value::Int(n)) = args.named.get("first") {
327 opts.first = (*n).max(0) as usize;
328 }
329
330 if let Some(Value::String(fmt)) = args.named.get("format") {
331 opts.format = fmt.clone();
332 }
333
334 opts
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Builds a successful worker result for `item` whose output is `out`.
    fn ok_result(item: &str, out: &str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::success(out),
        }
    }

    // --- extract_items -----------------------------------------------------

    #[test]
    fn test_extract_items_structured_json_array() {
        let data = Value::Json(serde_json::json!(["a", "b", "c"]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_extract_items_structured_mixed_types() {
        let data = Value::Json(serde_json::json!([1, "two", true]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["1", "two", "true"]);
    }

    #[test]
    fn test_extract_items_structured_string() {
        let data = Value::String("single".into());
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["single"]);
    }

    #[test]
    fn test_extract_items_single_line_text() {
        let items = extract_items(None, "hello").unwrap();
        assert_eq!(items, vec!["hello"]);
    }

    #[test]
    fn test_extract_items_empty() {
        let items = extract_items(None, "").unwrap();
        assert!(items.is_empty());
    }

    #[test]
    fn test_extract_items_multiline_is_one_item() {
        // Plain text is never split on newlines.
        let items = extract_items(None, "one\ntwo\nthree").unwrap();
        assert_eq!(items, vec!["one\ntwo\nthree"]);
    }

    #[test]
    fn test_extract_items_structured_overrides_text() {
        // Structured data wins even when text is also present.
        let data = Value::Json(serde_json::json!(["x", "y"]));
        let items = extract_items(Some(&data), "ignored\ntext").unwrap();
        assert_eq!(items, vec!["x", "y"]);
    }

    // --- gather_results ----------------------------------------------------

    #[test]
    fn test_gather_results_lines() {
        let results = vec![ok_result("a", "result_a"), ok_result("b", "result_b")];

        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_json() {
        let results = vec![ok_result("test", "output")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };

        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = vec![
            ok_result("a", "1"),
            ok_result("b", "2"),
            ok_result("c", "3"),
        ];
        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };

        let output = gather_results(&results, &opts);
        assert_eq!(output, "1\n2");
    }

    // --- option parsing ----------------------------------------------------

    #[test]
    fn test_parse_scatter_options() {
        let mut args = ToolArgs::new();
        args.named
            .insert("as".to_string(), Value::String("URL".to_string()));
        args.named.insert("limit".to_string(), Value::Int(4));

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_gather_options() {
        let mut args = ToolArgs::new();
        args.named.insert("first".to_string(), Value::Int(5));
        args.named
            .insert("format".to_string(), Value::String("json".to_string()));

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }
}