use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::Instrument;
use crate::ast::{Command, Value};
use crate::dispatch::CommandDispatcher;
use crate::duration::parse_duration;
use crate::interpreter::ExecResult;
use crate::tools::{ExecContext, ToolRegistry};
use super::pipeline::PipelineRunner;
/// Settings that control how items are fanned out to parallel workers.
#[derive(Clone, Debug)]
pub struct ScatterOptions {
    /// Name of the scope variable that is set to the current item inside each worker.
    pub var_name: String,
    /// Number of semaphore permits, i.e. how many workers may run at the same time.
    pub limit: usize,
    /// Optional per-worker time budget; once it elapses the worker's cancel token is
    /// cancelled. `None` means no timer is started.
    pub timeout: Option<Duration>,
}
/// Settings that control how worker results are collected into one output.
#[derive(Clone, Debug)]
pub struct GatherOptions {
    /// Set when the `progress` flag is passed. The gathering code visible in this
    /// file does not read it.
    pub progress: bool,
    /// When greater than zero and smaller than the number of results, only the
    /// first `first` results are used; `0` keeps every result.
    pub first: usize,
    /// Output format: `"json"` produces a pretty-printed JSON array, any other
    /// value produces the newline-joined output of the successful results.
    pub format: String,
}
impl Default for ScatterOptions {
    /// Defaults: items are bound to `ITEM`, at most 8 workers run at once, and
    /// no timeout is applied.
    fn default() -> Self {
        let var_name = String::from("ITEM");
        let limit = 8;
        let timeout = None;
        Self { var_name, limit, timeout }
    }
}
impl Default for GatherOptions {
    /// Defaults: no progress reporting, keep every result, `"lines"` output.
    fn default() -> Self {
        let format = String::from("lines");
        Self {
            progress: false,
            first: 0,
            format,
        }
    }
}
/// Outcome of running the parallel commands for a single scattered item.
#[derive(Clone, Debug)]
pub struct ScatterResult {
    /// The item this worker was started for.
    pub item: String,
    /// Result returned by the worker's command pipeline.
    pub result: ExecResult,
    /// `true` when the worker's timeout timer fired before the worker read the flag.
    pub timed_out: bool,
}
/// Runs scatter/gather pipelines: a sequential pre-stage, a parallel fan-out
/// over the extracted items, and a sequential post-stage.
pub struct ScatterGatherRunner {
    /// Tool registry handed to every `PipelineRunner` this runner creates.
    tools: Arc<ToolRegistry>,
    /// Dispatcher used for the sequential stages; each parallel worker gets its
    /// own dispatcher obtained from it via `fork_attached`.
    sequential_dispatcher: Arc<dyn CommandDispatcher>,
}
impl ScatterGatherRunner {
    /// Creates a runner that uses `tools` for every pipeline it builds and
    /// `dispatcher` for the sequential stages (and as the source of per-worker
    /// dispatchers).
    pub fn new(
        tools: Arc<ToolRegistry>,
        dispatcher: Arc<dyn CommandDispatcher>,
    ) -> Self {
        Self { tools, sequential_dispatcher: dispatcher }
    }

    /// Runs a complete scatter/gather pipeline and returns its final result.
    ///
    /// Steps:
    /// 1. Obtain the input: when `pre_scatter` is empty the stdin data and text
    ///    are taken from `ctx`; otherwise `pre_scatter` is run sequentially and
    ///    a failing result is returned immediately.
    /// 2. Split the input into items with [`extract_items`]. An extraction
    ///    error becomes a failure with exit code 1; no items yields an empty
    ///    success without running any further stage.
    /// 3. Run `parallel` once per item, bounded by `scatter_opts.limit`.
    /// 4. Gather the worker results into one string according to `gather_opts`.
    /// 5. When `post_gather` is empty, return the gathered string as a success;
    ///    otherwise feed it to `post_gather` as stdin and return that result.
    ///
    /// Worker failures do not fail `run` by themselves: they only show up in
    /// (or are filtered out of) the gathered text.
    #[tracing::instrument(level = "info", skip(self, pre_scatter, scatter_opts, parallel, gather_opts, post_gather, ctx), fields(item_count = tracing::field::Empty, parallelism = scatter_opts.limit))]
    pub async fn run(
        &self,
        pre_scatter: &[Command],
        scatter_opts: ScatterOptions,
        parallel: &[Command],
        gather_opts: GatherOptions,
        post_gather: &[Command],
        ctx: &mut ExecContext,
    ) -> ExecResult {
        let runner = PipelineRunner::new(self.tools.clone());

        let (text, data) = if pre_scatter.is_empty() {
            let data = ctx.take_stdin_data();
            let text = ctx.take_stdin().unwrap_or_default();
            (text, data)
        } else {
            let result = runner
                .run_sequential(pre_scatter, ctx, &*self.sequential_dispatcher)
                .await;
            if !result.ok() {
                return result;
            }
            (result.text_out().into_owned(), result.data)
        };

        let items = match extract_items(data.as_ref(), &text) {
            Ok(items) => items,
            Err(msg) => return ExecResult::failure(1, msg),
        };
        if items.is_empty() {
            return ExecResult::success("");
        }
        // The span field was declared empty in the attribute; fill it in now
        // that the item count is known.
        tracing::Span::current().record("item_count", items.len());

        let results = self
            .run_parallel(&items, &scatter_opts, parallel, ctx)
            .await;
        let gathered = gather_results(&results, &gather_opts);

        if post_gather.is_empty() {
            ExecResult::success(gathered)
        } else {
            ctx.set_stdin(gathered);
            runner
                .run_sequential(post_gather, ctx, &*self.sequential_dispatcher)
                .await
        }
    }

    /// Builds the tracing label for an item: the item itself when it is at most
    /// 64 bytes long, otherwise a prefix of at most 64 bytes followed by `...`.
    ///
    /// The cut is moved back to the nearest character boundary so that slicing
    /// never panics on multi-byte UTF-8 input.
    fn item_label(item: &str) -> String {
        const MAX_LABEL_BYTES: usize = 64;
        if item.len() <= MAX_LABEL_BYTES {
            return item.to_string();
        }
        let mut end = MAX_LABEL_BYTES;
        // Index 0 is always a boundary, so this loop terminates.
        while !item.is_char_boundary(end) {
            end -= 1;
        }
        format!("{}...", &item[..end])
    }

    /// Runs `commands` once per item on spawned tokio tasks and returns one
    /// [`ScatterResult`] per item, in the same order as `items`.
    ///
    /// Concurrency is bounded by a semaphore with `opts.limit` permits (a limit
    /// of zero is treated as one): the loop waits for a permit before spawning
    /// the next worker, and the worker holds the permit until it finishes.
    ///
    /// Each worker gets a clone of the base scope with `opts.var_name` set to
    /// its item, the base backend and cwd, a dispatcher forked from the
    /// sequential one, and a child of the base cancellation token. When
    /// `opts.timeout` is set, a timer task cancels that child token after the
    /// timeout and marks the result as timed out.
    ///
    /// A task that fails to join is reported as a failure with exit code 1 for
    /// the corresponding item.
    #[tracing::instrument(level = "debug", skip(self, items, opts, commands, base_ctx), fields(worker_count = items.len()))]
    async fn run_parallel(
        &self,
        items: &[String],
        opts: &ScatterOptions,
        commands: &[Command],
        base_ctx: &ExecContext,
    ) -> Vec<ScatterResult> {
        // A semaphore with zero permits would never grant one and the loop
        // below would wait forever, so always allow at least one worker.
        let semaphore = Arc::new(Semaphore::new(opts.limit.max(1)));
        let tools = self.tools.clone();
        let var_name = opts.var_name.clone();
        let mut handles = Vec::with_capacity(items.len());

        for item in items.iter().cloned() {
            // Acquire before spawning so that at most `limit` workers exist at once.
            let permit = semaphore.clone().acquire_owned().await;
            let tools = tools.clone();
            let worker_dispatcher = self.sequential_dispatcher.fork_attached().await;
            let commands = commands.to_vec();
            let var_name = var_name.clone();
            let base_scope = base_ctx.scope.clone();
            let backend = base_ctx.backend.clone();
            let cwd = base_ctx.cwd.clone();
            let parent_token = base_ctx.cancel.clone();
            let worker_token = parent_token.child_token();

            // The timer starts only once the permit is held, so time spent
            // waiting for a permit does not count against the timeout.
            let timed_out_flag = Arc::new(AtomicBool::new(false));
            let timer_handle: Option<tokio::task::JoinHandle<()>> = opts.timeout.map(|d| {
                let cancel = worker_token.clone();
                let flag = timed_out_flag.clone();
                tokio::spawn(async move {
                    tokio::time::sleep(d).await;
                    flag.store(true, Ordering::SeqCst);
                    cancel.cancel();
                })
            });
            let timed_out_check = timed_out_flag.clone();

            let item_label = Self::item_label(&item);
            let worker_span = tracing::debug_span!("scatter_worker", item = %item_label);

            let handle = tokio::spawn(async move {
                // Keep the permit for the whole lifetime of the worker.
                let _permit = permit;
                let mut scope = base_scope;
                scope.set(&var_name, Value::String(item.clone()));
                let mut ctx = ExecContext::with_backend_and_scope(backend, scope);
                ctx.set_cwd(cwd);
                ctx.cancel = worker_token;
                let runner = PipelineRunner::new(tools);
                let result = runner.run_sequential(&commands, &mut ctx, &*worker_dispatcher).await;
                // The work is done; stop the timer so it cannot fire afterwards.
                if let Some(h) = timer_handle {
                    h.abort();
                }
                let timed_out = timed_out_check.load(Ordering::SeqCst);
                ScatterResult { item, result, timed_out }
            }.instrument(worker_span));
            handles.push(handle);
        }

        // Handles were pushed in item order, so zipping them with `items`
        // recovers the item of a task that did not return a result.
        let mut results = Vec::with_capacity(handles.len());
        for (handle, item) in handles.into_iter().zip(items) {
            match handle.await {
                Ok(result) => results.push(result),
                Err(e) => {
                    results.push(ScatterResult {
                        item: item.clone(),
                        result: ExecResult::failure(1, format!("Task panicked: {}", e)),
                        timed_out: false,
                    });
                }
            }
        }
        results
    }
}
/// Splits pipeline input into the list of items to scatter over.
///
/// Structured `data` takes precedence over `text`:
/// - a JSON array yields one item per element (string elements as-is, every
///   other element in its JSON text form);
/// - a plain string value yields exactly one item.
///
/// Otherwise the trimmed `text` is used as a single item, or no items at all
/// when it is empty. Text is never split on newlines.
///
/// The visible implementation always returns `Ok`.
pub fn extract_items(data: Option<&Value>, text: &str) -> Result<Vec<String>, String> {
    match data {
        Some(Value::Json(serde_json::Value::Array(elements))) => {
            let mut items = Vec::with_capacity(elements.len());
            for element in elements {
                items.push(match element {
                    serde_json::Value::String(s) => s.clone(),
                    other => other.to_string(),
                });
            }
            return Ok(items);
        }
        Some(Value::String(single)) => return Ok(vec![single.clone()]),
        _ => {}
    }

    let whole = text.trim();
    Ok(if whole.is_empty() {
        Vec::new()
    } else {
        vec![whole.to_string()]
    })
}
/// Collects worker results into a single output string.
///
/// When `opts.first` is non-zero and smaller than the number of results, only
/// that many leading results are considered; otherwise all of them are.
///
/// With `opts.format == "json"` the output is a pretty-printed JSON array with
/// one object per considered result (`item`, `ok`, `code`, `out`, `err`,
/// `timed_out`), or an empty string if serialization fails. For any other
/// format the trimmed outputs of the successful results are joined with
/// newlines; failed results are left out.
fn gather_results(results: &[ScatterResult], opts: &GatherOptions) -> String {
    let keep = match opts.first {
        0 => results.len(),
        n => n.min(results.len()),
    };
    let selected = &results[..keep];

    if opts.format != "json" {
        let lines: Vec<String> = selected
            .iter()
            .filter(|entry| entry.result.ok())
            .map(|entry| entry.result.text_out().trim().to_string())
            .collect();
        return lines.join("\n");
    }

    let mut records: Vec<serde_json::Value> = Vec::with_capacity(selected.len());
    for entry in selected {
        records.push(serde_json::json!({
            "item": entry.item,
            "ok": entry.result.ok(),
            "code": entry.result.code,
            "out": entry.result.text_out().trim(),
            "err": entry.result.err.trim(),
            "timed_out": entry.timed_out,
        }));
    }
    serde_json::to_string_pretty(&records).unwrap_or_default()
}
/// Builds [`ScatterOptions`] from the named arguments of a scatter command,
/// starting from the defaults.
///
/// Recognized named arguments:
/// - `as` (string): name of the per-item variable.
/// - `limit` (integer): concurrency limit, clamped to `1..=SCATTER_LIMIT_MAX`;
///   a warning is logged when the request exceeds the ceiling.
/// - `timeout` (string or integer): a string is parsed with `parse_duration`,
///   a non-negative integer is taken as whole seconds. Values that cannot be
///   used (unparsable string, negative integer) are logged and leave the
///   timeout unset.
///
/// Arguments of any other value type are ignored.
pub fn parse_scatter_options(args: &crate::tools::ToolArgs) -> ScatterOptions {
    let mut opts = ScatterOptions::default();

    if let Some(Value::String(name)) = args.named.get("as") {
        opts.var_name = name.clone();
    }

    if let Some(Value::Int(n)) = args.named.get("limit") {
        let requested = *n;
        let clamped = requested.clamp(1, SCATTER_LIMIT_MAX as i64);
        if requested > SCATTER_LIMIT_MAX as i64 {
            tracing::warn!(
                target: "kaish::scatter",
                requested = requested,
                ceiling = SCATTER_LIMIT_MAX,
                "scatter limit clamped to ceiling"
            );
        }
        // `clamped` is within 1..=SCATTER_LIMIT_MAX, so the cast cannot truncate.
        opts.limit = clamped as usize;
    }

    match args.named.get("timeout") {
        Some(Value::String(s)) => match parse_duration(s) {
            Some(d) => opts.timeout = Some(d),
            None => tracing::warn!(
                target: "kaish::scatter",
                value = %s,
                "scatter --timeout: invalid duration (try: 30, 5s, 500ms, 2m, 1h)"
            ),
        },
        Some(Value::Int(n)) => {
            if *n >= 0 {
                // Non-negative, so the conversion to u64 is lossless.
                opts.timeout = Some(Duration::from_secs(*n as u64));
            } else {
                // Report the rejected value instead of dropping it silently,
                // matching the handling of unparsable string durations.
                tracing::warn!(
                    target: "kaish::scatter",
                    value = *n,
                    "scatter --timeout: negative value ignored"
                );
            }
        }
        _ => {}
    }

    opts
}
/// Upper bound for the scatter `limit` option: `parse_scatter_options` clamps
/// larger requests down to this value and logs a warning when it does so.
pub const SCATTER_LIMIT_MAX: usize = 10_000;
pub fn parse_gather_options(args: &crate::tools::ToolArgs) -> GatherOptions {
let mut opts = GatherOptions::default();
if args.has_flag("progress") {
opts.progress = true;
}
if let Some(Value::Int(n)) = args.named.get("first") {
opts.first = (*n).max(0) as usize;
}
if let Some(Value::String(fmt)) = args.named.get("format") {
opts.format = fmt.clone();
}
opts
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolArgs;

    /// Builds a successful, non-timed-out worker result.
    fn ok_result(item: &str, out: &str) -> ScatterResult {
        ScatterResult {
            item: item.to_string(),
            result: ExecResult::success(out),
            timed_out: false,
        }
    }

    /// Builds tool arguments that contain the given named values.
    fn named_args(pairs: Vec<(&str, Value)>) -> ToolArgs {
        let mut args = ToolArgs::new();
        for (key, value) in pairs {
            args.named.insert(key.to_string(), value);
        }
        args
    }

    /// Parses scatter options from a single `limit` argument and returns the limit.
    fn parsed_limit(requested: i64) -> usize {
        parse_scatter_options(&named_args(vec![("limit", Value::Int(requested))])).limit
    }

    // --- extract_items ---

    #[test]
    fn test_extract_items_structured_json_array() {
        let data = Value::Json(serde_json::json!(["a", "b", "c"]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_extract_items_structured_mixed_types() {
        let data = Value::Json(serde_json::json!([1, "two", true]));
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["1", "two", "true"]);
    }

    #[test]
    fn test_extract_items_structured_string() {
        let data = Value::String("single".into());
        let items = extract_items(Some(&data), "").unwrap();
        assert_eq!(items, vec!["single"]);
    }

    #[test]
    fn test_extract_items_single_line_text() {
        let items = extract_items(None, "hello").unwrap();
        assert_eq!(items, vec!["hello"]);
    }

    #[test]
    fn test_extract_items_empty() {
        let items = extract_items(None, "").unwrap();
        assert!(items.is_empty());
    }

    #[test]
    fn test_extract_items_multiline_is_one_item() {
        // Plain text is never split on newlines.
        let items = extract_items(None, "one\ntwo\nthree").unwrap();
        assert_eq!(items, vec!["one\ntwo\nthree"]);
    }

    #[test]
    fn test_extract_items_structured_overrides_text() {
        let data = Value::Json(serde_json::json!(["x", "y"]));
        let items = extract_items(Some(&data), "ignored\ntext").unwrap();
        assert_eq!(items, vec!["x", "y"]);
    }

    // --- gather_results ---

    #[test]
    fn test_gather_results_lines() {
        let results = vec![ok_result("a", "result_a"), ok_result("b", "result_b")];
        let output = gather_results(&results, &GatherOptions::default());
        assert_eq!(output, "result_a\nresult_b");
    }

    #[test]
    fn test_gather_results_json() {
        let results = vec![ok_result("test", "output")];
        let opts = GatherOptions {
            format: "json".to_string(),
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert!(output.contains("\"item\": \"test\""));
        assert!(output.contains("\"ok\": true"));
    }

    #[test]
    fn test_gather_results_first_n() {
        let results = vec![
            ok_result("a", "1"),
            ok_result("b", "2"),
            ok_result("c", "3"),
        ];
        let opts = GatherOptions {
            first: 2,
            ..Default::default()
        };
        let output = gather_results(&results, &opts);
        assert_eq!(output, "1\n2");
    }

    // --- option parsing ---

    #[test]
    fn test_parse_scatter_options() {
        let args = named_args(vec![
            ("as", Value::String("URL".to_string())),
            ("limit", Value::Int(4)),
        ]);
        let opts = parse_scatter_options(&args);
        assert_eq!(opts.var_name, "URL");
        assert_eq!(opts.limit, 4);
    }

    #[test]
    fn test_parse_gather_options() {
        let args = named_args(vec![
            ("first", Value::Int(5)),
            ("format", Value::String("json".to_string())),
        ]);
        let opts = parse_gather_options(&args);
        assert_eq!(opts.first, 5);
        assert_eq!(opts.format, "json");
    }

    #[test]
    fn scatter_limit_clamps_to_ceiling() {
        assert_eq!(parsed_limit(999_999), SCATTER_LIMIT_MAX);
    }

    #[test]
    fn scatter_limit_raises_zero_to_one() {
        assert_eq!(parsed_limit(0), 1);
    }

    #[test]
    fn scatter_limit_raises_negative_to_one() {
        assert_eq!(parsed_limit(-42), 1);
    }

    #[test]
    fn scatter_limit_preserves_valid_values() {
        assert_eq!(parsed_limit(500), 500);
    }
}