1use std::sync::Arc;
16
17use tokio::sync::Semaphore;
18
19use crate::ast::{Command, Value};
20use crate::dispatch::CommandDispatcher;
21use crate::interpreter::ExecResult;
22use crate::tools::{ExecContext, ToolRegistry};
23
24use super::pipeline::PipelineRunner;
25
/// Options controlling the scatter (fan-out) stage.
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Name of the scope variable each item is bound to while its pipeline runs.
    pub var_name: String,
    /// Maximum number of items whose pipelines run concurrently.
    pub limit: usize,
}

impl Default for ScatterOptions {
    /// Binds each item to `ITEM` and processes up to 8 items at a time.
    fn default() -> Self {
        let var_name = String::from("ITEM");
        ScatterOptions { var_name, limit: 8 }
    }
}
43
/// Options controlling how per-item results are combined in the gather stage.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Set by the `progress` flag.
    /// NOTE(review): not read by the gather code visible in this module; confirm intended use.
    pub progress: bool,
    /// When non-zero, only the first `first` results (in item order) are gathered.
    pub first: usize,
    /// `"json"` emits a JSON array of per-item records; any other value joins
    /// the successful outputs as lines.
    pub format: String,
}

impl Default for GatherOptions {
    /// No progress flag, every result kept, line-oriented output.
    fn default() -> Self {
        GatherOptions {
            format: String::from("lines"),
            first: 0,
            progress: false,
        }
    }
}
64
/// Outcome of running the parallel pipeline for a single scatter item.
#[derive(Debug, Clone)]
pub struct ScatterResult {
    /// The input item, i.e. the value bound to the scatter variable for this run.
    pub item: String,
    /// Result of the item's pipeline run, or a synthesized failure when the
    /// task running it could not be joined (e.g. it panicked).
    pub result: ExecResult,
}
73
74pub struct ScatterGatherRunner {
80 tools: Arc<ToolRegistry>,
81 dispatcher: Arc<dyn CommandDispatcher>,
82}
83
84impl ScatterGatherRunner {
85 pub fn new(tools: Arc<ToolRegistry>, dispatcher: Arc<dyn CommandDispatcher>) -> Self {
90 Self { tools, dispatcher }
91 }
92
93 pub async fn run(
102 &self,
103 pre_scatter: &[Command],
104 scatter_opts: ScatterOptions,
105 parallel: &[Command],
106 gather_opts: GatherOptions,
107 post_gather: &[Command],
108 ctx: &mut ExecContext,
109 ) -> ExecResult {
110 let runner = PipelineRunner::new(self.tools.clone());
111
112 let input = if pre_scatter.is_empty() {
115 ctx.take_stdin().unwrap_or_default()
117 } else {
118 let result = runner.run_sequential(pre_scatter, ctx, &*self.dispatcher).await;
119 if !result.ok() {
120 return result;
121 }
122 result.out
123 };
124
125 let items = split_input(&input);
127 if items.is_empty() {
128 return ExecResult::success("");
129 }
130
131 let results = self
133 .run_parallel(&items, &scatter_opts, parallel, ctx)
134 .await;
135
136 let gathered = gather_results(&results, &gather_opts);
138
139 if post_gather.is_empty() {
141 ExecResult::success(gathered)
142 } else {
143 ctx.set_stdin(gathered);
144 runner.run_sequential(post_gather, ctx, &*self.dispatcher).await
145 }
146 }
147
148 async fn run_parallel(
161 &self,
162 items: &[String],
163 opts: &ScatterOptions,
164 commands: &[Command],
165 base_ctx: &ExecContext,
166 ) -> Vec<ScatterResult> {
167 let semaphore = Arc::new(Semaphore::new(opts.limit));
168 let tools = self.tools.clone();
169 let dispatcher = self.dispatcher.clone();
170 let var_name = opts.var_name.clone();
171
172 let mut handles = Vec::with_capacity(items.len());
174
175 for item in items.iter().cloned() {
176 let permit = semaphore.clone().acquire_owned().await;
177 let tools = tools.clone();
178 let dispatcher = dispatcher.clone();
179 let commands = commands.to_vec();
180 let var_name = var_name.clone();
181 let base_scope = base_ctx.scope.clone();
182 let backend = base_ctx.backend.clone();
183 let cwd = base_ctx.cwd.clone();
184
185 let handle = tokio::spawn(async move {
186 let _permit = permit; let mut scope = base_scope;
190 scope.set(&var_name, Value::String(item.clone()));
191
192 let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
193 ctx.set_cwd(cwd);
194
195 let runner = PipelineRunner::new(tools);
198 let result = runner.run_sequential(&commands, &mut ctx, &*dispatcher).await;
199
200 ScatterResult { item, result }
201 });
202
203 handles.push(handle);
204 }
205
206 let mut results = Vec::with_capacity(handles.len());
208 for handle in handles {
209 match handle.await {
210 Ok(result) => results.push(result),
211 Err(e) => {
212 results.push(ScatterResult {
213 item: String::new(),
214 result: ExecResult::failure(1, format!("Task panicked: {}", e)),
215 });
216 }
217 }
218 }
219
220 results
221 }
222}
223
224fn split_input(input: &str) -> Vec<String> {
226 let trimmed = input.trim();
227
228 if trimmed.starts_with('[')
230 && let Ok(arr) = serde_json::from_str::<Vec<serde_json::Value>>(trimmed) {
231 return arr
232 .into_iter()
233 .map(|v| match v {
234 serde_json::Value::String(s) => s,
235 other => other.to_string(),
236 })
237 .collect();
238 }
239
240 trimmed
242 .lines()
243 .map(|s| s.to_string())
244 .filter(|s| !s.is_empty())
245 .collect()
246}
247
248fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
250 let results_to_use = if opts.first > 0 && opts.first < results.len() {
251 &results[..opts.first]
252 } else {
253 results
254 };
255
256 if opts.format == "json" {
257 let json_results: Vec<serde_json::Value> = results_to_use
259 .iter()
260 .map(|r| {
261 serde_json::json!({
262 "item": r.item,
263 "ok": r.result.ok(),
264 "code": r.result.code,
265 "out": r.result.out.trim(),
266 "err": r.result.err.trim(),
267 })
268 })
269 .collect();
270
271 serde_json::to_string_pretty(&json_results).unwrap_or_default()
272 } else {
273 results_to_use
275 .iter()
276 .filter(|r| r.result.ok())
277 .map(|r| r.result.out.trim())
278 .collect::<Vec<_>>()
279 .join("\n")
280 }
281}
282
283pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
285 let mut opts = ScatterOptions::default();
286
287 if let Some(Value::String(name)) = args.named.get("as") {
288 opts.var_name = name.clone();
289 }
290
291 if let Some(Value::Int(n)) = args.named.get("limit") {
292 opts.limit = (*n).max(1) as usize;
293 }
294
295 opts
296}
297
298pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
300 let mut opts = GatherOptions::default();
301
302 if args.has_flag("progress") {
303 opts.progress = true;
304 }
305
306 if let Some(Value::Int(n)) = args.named.get("first") {
307 opts.first = (*n).max(0) as usize;
308 }
309
310 if let Some(Value::String(fmt)) = args.named.get("format") {
311 opts.format = fmt.clone();
312 }
313
314 opts
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Builds a successful [`ScatterResult`] for `item` whose stdout is `out`.
    fn succeeded(item: &str, out: &str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::success(out),
        }
    }

    #[test]
    fn test_split_input_lines() {
        assert_eq!(split_input("one\ntwo\nthree\n"), vec!["one", "two", "three"]);
    }

    #[test]
    fn test_split_input_json_array() {
        assert_eq!(split_input(r#"["a", "b", "c"]"#), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_split_input_json_mixed() {
        assert_eq!(split_input(r#"[1, "two", true]"#), vec!["1", "two", "true"]);
    }

    #[test]
    fn test_split_input_empty() {
        assert!(split_input("").is_empty());
    }

    #[test]
    fn test_gather_results_lines() {
        let results = [succeeded("a", "result_a"), succeeded("b", "result_b")];
        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_json() {
        let results = [succeeded("test", "output")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };

        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = [succeeded("a", "1"), succeeded("b", "2"), succeeded("c", "3")];
        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };

        assert_eq!(gather_results(&results, &opts), "1\n2");
    }

    #[test]
    fn test_parse_scatter_options() {
        let mut args = ToolArgs::new();
        args.named.insert("as".to_string(), Value::String("URL".to_string()));
        args.named.insert("limit".to_string(), Value::Int(4));

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_gather_options() {
        let mut args = ToolArgs::new();
        args.named.insert("first".to_string(), Value::Int(5));
        args.named.insert("format".to_string(), Value::String("json".to_string()));

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }
}