1use std::sync::Arc;
16
17use tokio::sync::Semaphore;
18
19use crate::ast::{Command, Value};
20use crate::interpreter::ExecResult;
21use crate::tools::{ExecContext, ToolRegistry};
22
23use super::pipeline::{run_sequential_pipeline, run_sequential_pipeline_owned};
24
/// Options controlling how input items are fanned out to parallel pipelines.
#[derive(Debug, Clone)]
pub struct ScatterOptions {
    /// Name of the scope variable that holds the current item inside each pipeline.
    pub var_name: String,
    /// Maximum number of items whose pipelines run at the same time.
    pub limit: usize,
}

impl Default for ScatterOptions {
    /// Binds each item to `ITEM` and runs at most 8 items concurrently.
    fn default() -> Self {
        let var_name = String::from("ITEM");
        let limit = 8;
        ScatterOptions { var_name, limit }
    }
}
42
/// Options controlling how per-item results are combined into one output.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Set by the `progress` flag.
    pub progress: bool,
    /// Keep only the first N results (in input order); 0 keeps every result.
    pub first: usize,
    /// Output format: `"json"` selects a JSON array of per-item records; any other
    /// value joins the stdout of successful items with newlines.
    pub format: String,
}

impl Default for GatherOptions {
    /// No progress, no result limit, newline-joined (`"lines"`) output.
    fn default() -> Self {
        GatherOptions {
            format: "lines".into(),
            first: 0,
            progress: false,
        }
    }
}
63
/// The outcome of running the parallel command pipeline for one scatter item.
#[derive(Debug, Clone)]
pub struct ScatterResult {
    /// The scatter item this result corresponds to.
    pub item: String,
    /// Result of the per-item pipeline; a task that panicked is reported as a
    /// failure with exit code 1.
    pub result: ExecResult,
}
72
/// Runs scatter/gather pipelines: splits input into items, runs a command pipeline
/// for each item concurrently, and merges the per-item outputs.
pub struct ScatterGatherRunner {
    /// Tool registry handed to every pipeline stage this runner executes.
    tools: Arc<ToolRegistry>,
}
77
78impl ScatterGatherRunner {
79 pub fn new(tools: Arc<ToolRegistry>) -> Self {
81 Self { tools }
82 }
83
84 pub async fn run(
93 &self,
94 pre_scatter: &[Command],
95 scatter_opts: ScatterOptions,
96 parallel: &[Command],
97 gather_opts: GatherOptions,
98 post_gather: &[Command],
99 ctx: &mut ExecContext,
100 ) -> ExecResult {
101 let input = if pre_scatter.is_empty() {
103 ctx.take_stdin().unwrap_or_default()
105 } else {
106 let result = run_sequential_pipeline(&self.tools, pre_scatter, ctx).await;
107 if !result.ok() {
108 return result;
109 }
110 result.out
111 };
112
113 let items = split_input(&input);
115 if items.is_empty() {
116 return ExecResult::success("");
117 }
118
119 let results = self
121 .run_parallel(&items, &scatter_opts, parallel, ctx)
122 .await;
123
124 let gathered = gather_results(&results, &gather_opts);
126
127 if post_gather.is_empty() {
129 ExecResult::success(gathered)
130 } else {
131 ctx.set_stdin(gathered);
132 run_sequential_pipeline(&self.tools, post_gather, ctx).await
133 }
134 }
135
136 async fn run_parallel(
138 &self,
139 items: &[String],
140 opts: &ScatterOptions,
141 commands: &[Command],
142 base_ctx: &ExecContext,
143 ) -> Vec<ScatterResult> {
144 let semaphore = Arc::new(Semaphore::new(opts.limit));
145 let tools = self.tools.clone();
146 let var_name = opts.var_name.clone();
147
148 let mut handles = Vec::with_capacity(items.len());
150
151 for item in items.iter().cloned() {
152 let permit = semaphore.clone().acquire_owned().await;
153 let tools = tools.clone();
154 let commands = commands.to_vec();
155 let var_name = var_name.clone();
156 let base_scope = base_ctx.scope.clone();
157 let backend = base_ctx.backend.clone();
158 let cwd = base_ctx.cwd.clone();
159
160 let handle = tokio::spawn(async move {
161 let _permit = permit; let mut scope = base_scope;
165 scope.set(&var_name, Value::String(item.clone()));
166
167 let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
168 ctx.set_cwd(cwd);
169
170 let result = run_sequential_pipeline_owned(tools, commands, &mut ctx).await;
172
173 ScatterResult { item, result }
174 });
175
176 handles.push(handle);
177 }
178
179 let mut results = Vec::with_capacity(handles.len());
181 for handle in handles {
182 match handle.await {
183 Ok(result) => results.push(result),
184 Err(e) => {
185 results.push(ScatterResult {
186 item: String::new(),
187 result: ExecResult::failure(1, format!("Task panicked: {}", e)),
188 });
189 }
190 }
191 }
192
193 results
194 }
195}
196
197fn split_input(input: &str) -> Vec<String> {
199 let trimmed = input.trim();
200
201 if trimmed.starts_with('[')
203 && let Ok(arr) = serde_json::from_str::<Vec<serde_json::Value>>(trimmed) {
204 return arr
205 .into_iter()
206 .map(|v| match v {
207 serde_json::Value::String(s) => s,
208 other => other.to_string(),
209 })
210 .collect();
211 }
212
213 trimmed
215 .lines()
216 .map(|s| s.to_string())
217 .filter(|s| !s.is_empty())
218 .collect()
219}
220
221fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
223 let results_to_use = if opts.first > 0 && opts.first < results.len() {
224 &results[..opts.first]
225 } else {
226 results
227 };
228
229 if opts.format == "json" {
230 let json_results: Vec<serde_json::Value> = results_to_use
232 .iter()
233 .map(|r| {
234 serde_json::json!({
235 "item": r.item,
236 "ok": r.result.ok(),
237 "code": r.result.code,
238 "out": r.result.out.trim(),
239 "err": r.result.err.trim(),
240 })
241 })
242 .collect();
243
244 serde_json::to_string_pretty(&json_results).unwrap_or_default()
245 } else {
246 results_to_use
248 .iter()
249 .filter(|r| r.result.ok())
250 .map(|r| r.result.out.trim())
251 .collect::<Vec<_>>()
252 .join("\n")
253 }
254}
255
256pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
258 let mut opts = ScatterOptions::default();
259
260 if let Some(Value::String(name)) = args.named.get("as") {
261 opts.var_name = name.clone();
262 }
263
264 if let Some(Value::Int(n)) = args.named.get("limit") {
265 opts.limit = (*n).max(1) as usize;
266 }
267
268 opts
269}
270
271pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
273 let mut opts = GatherOptions::default();
274
275 if args.has_flag("progress") {
276 opts.progress = true;
277 }
278
279 if let Some(Value::Int(n)) = args.named.get("first") {
280 opts.first = (*n).max(0) as usize;
281 }
282
283 if let Some(Value::String(fmt)) = args.named.get("format") {
284 opts.format = fmt.clone();
285 }
286
287 opts
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Builds a successful `ScatterResult` for `item` whose stdout is `out`.
    fn ok_result(item: &str, out: &str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::success(out),
        }
    }

    #[test]
    fn test_split_input_lines() {
        assert_eq!(split_input("one\ntwo\nthree\n"), ["one", "two", "three"]);
    }

    #[test]
    fn test_split_input_json_array() {
        assert_eq!(split_input(r#"["a", "b", "c"]"#), ["a", "b", "c"]);
    }

    #[test]
    fn test_split_input_json_mixed() {
        // Non-string JSON values are rendered as their JSON text.
        assert_eq!(split_input(r#"[1, "two", true]"#), ["1", "two", "true"]);
    }

    #[test]
    fn test_split_input_empty() {
        assert!(split_input("").is_empty());
    }

    #[test]
    fn test_gather_results_lines() {
        let results = [ok_result("a", "result_a"), ok_result("b", "result_b")];
        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_json() {
        let results = [ok_result("test", "output")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };

        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = [ok_result("a", "1"), ok_result("b", "2"), ok_result("c", "3")];
        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };

        assert_eq!(gather_results(&results, &opts), "1\n2");
    }

    #[test]
    fn test_parse_scatter_options() {
        let mut args = ToolArgs::new();
        args.named.insert("as".to_string(), Value::String("URL".to_string()));
        args.named.insert("limit".to_string(), Value::Int(4));

        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_gather_options() {
        let mut args = ToolArgs::new();
        args.named.insert("first".to_string(), Value::Int(5));
        args.named.insert("format".to_string(), Value::String("json".to_string()));

        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }
}