1use std::sync::Arc;
24
25use anyhow::Result;
26use async_trait::async_trait;
27use serde::Deserialize;
28use serde_json::json;
29use tokio::sync::mpsc;
30
31use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
32use crate::agent::sub_agent;
33use crate::agent::AgentEvent;
34use crate::config::Config;
35use crate::provider::LlmProvider;
36
37#[derive(Debug, Deserialize)]
40struct ParallelEditFile {
41 path: String,
42 instruction: String,
43}
44
45#[derive(Debug, Deserialize)]
46struct ParallelEditArgs {
47 files: Vec<ParallelEditFile>,
48 #[serde(default)]
55 contract: String,
56}
57
58pub struct ParallelEditTool {
59 pub provider: Arc<dyn LlmProvider>,
60 pub config: Config,
61 pub event_tx: mpsc::UnboundedSender<AgentEvent>,
62}
63
64#[async_trait]
65impl Tool for ParallelEditTool {
66 fn definition(&self) -> ToolDef {
67 ToolDef {
68 name: "parallel_edit_files",
69 description:
70 "Edit multiple INDEPENDENT files in parallel via fork sub-agents.\n\n\
71 Use ONLY when:\n\
72 - You have 2+ concrete files to edit, each with a clear instruction\n\
73 - Edits in different files don't depend on each other\n\
74 - You can express any cross-file invariants (shared trait/type/interface) in `contract`\n\n\
75 Do NOT use when:\n\
76 - You're still exploring or the edit isn't fully decided\n\
77 - Files have impl/decl splits that need coordinated edits (use sequential edit_file)\n\
78 - You want to read more files first (use read_file)\n\n\
79 Each sub-agent sees only its assigned file content + the contract you provide. \
80 Cross-file changes that aren't expressed in `contract` will be missed by the merge — \
81 the sub-agents cannot see each other's edits. After all sub-agents settle, the \
82 framework runs a build probe (cargo/npm/mvn/go) and surfaces compile errors so you \
83 can repair cross-file gaps."
84 .to_string(),
85 parameters: json!({
86 "type": "object",
87 "properties": {
88 "files": {
89 "type": "array",
90 "minItems": 2,
91 "maxItems": 12,
92 "items": {
93 "type": "object",
94 "properties": {
95 "path": {
96 "type": "string",
97 "description": "File path. Absolute, or relative to the working directory."
98 },
99 "instruction": {
100 "type": "string",
101 "description": "Concrete edit description for THIS file. Be specific: what to add/modify/remove and why. The sub-agent sees only this instruction + the file content + the contract — no other context."
102 }
103 },
104 "required": ["path", "instruction"]
105 }
106 },
107 "contract": {
108 "type": "string",
109 "description": "Cross-file invariants every sub-agent must honour: shared traits, type signatures, interface contracts, naming conventions. Empty if files are fully independent."
110 }
111 },
112 "required": ["files"]
113 }),
114 }
115 }
116
117 fn approval(&self, _args: &str) -> ApprovalRequirement {
118 ApprovalRequirement::AutoApprove
119 }
120
121 fn validate_args(&self, args: &str) -> std::result::Result<(), String> {
122 let parsed: ParallelEditArgs = serde_json::from_str(args).map_err(|e| {
123 format!(
124 "{} (parallel_edit_files arguments must be {{\"files\": [{{\"path\": \"…\", \"instruction\": \"…\"}}, …], \"contract\": \"…\"?}})",
125 e
126 )
127 })?;
128 if parsed.files.len() < 2 {
129 return Err(
130 "parallel_edit_files requires at least 2 files. For a single file, call edit_file directly."
131 .to_string(),
132 );
133 }
134 if parsed.files.len() > 12 {
135 return Err(format!(
136 "parallel_edit_files capped at 12 files; you sent {}. Split into smaller batches or run sequentially.",
137 parsed.files.len()
138 ));
139 }
140 for (i, f) in parsed.files.iter().enumerate() {
141 if f.path.trim().is_empty() {
142 return Err(format!("files[{}].path is empty", i));
143 }
144 if f.instruction.trim().is_empty() {
145 return Err(format!(
146 "files[{}].instruction is empty. Each file needs a concrete edit description; \
147 a sub-agent with no instruction will either fake an edit or burn its budget.",
148 i
149 ));
150 }
151 }
152 Ok(())
153 }
154
155 async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
156 let parsed: ParallelEditArgs = serde_json::from_str(args)?;
157
158 let working_dir = ctx.working_dir.read().await.clone();
159 let registry = match ctx.tool_registry.as_ref() {
160 Some(r) => r.clone(),
161 None => {
162 return Ok(ToolResult {
167 call_id: String::new(),
168 output: "parallel_edit_files unavailable: tool registry not wired in this context."
169 .to_string(),
170 success: false,
171 });
172 }
173 };
174
175 let mut all_file_contents: Vec<(String, String)> = Vec::with_capacity(parsed.files.len());
179 for spec in &parsed.files {
180 let path = if std::path::Path::new(&spec.path).is_absolute() {
181 std::path::PathBuf::from(&spec.path)
182 } else {
183 working_dir.join(&spec.path)
184 };
185 let content = match tokio::fs::read_to_string(&path).await {
186 Ok(c) => c,
187 Err(e) => {
188 return Ok(ToolResult {
189 call_id: String::new(),
190 output: format!(
191 "Cannot read `{}`: {}. Aborted dispatch — fix the path or use a different approach.",
192 spec.path, e
193 ),
194 success: false,
195 });
196 }
197 };
198 all_file_contents.push((path.to_string_lossy().to_string(), content));
199 }
200
201 let mut tasks = Vec::with_capacity(parsed.files.len());
205 for i in 0..parsed.files.len() {
206 let mut siblings = String::new();
207 for (j, (sib_path, sib_content)) in all_file_contents.iter().enumerate() {
208 if i == j {
209 continue;
210 }
211 let short = std::path::Path::new(sib_path)
212 .file_name()
213 .map(|n| n.to_string_lossy().to_string())
214 .unwrap_or_else(|| sib_path.clone());
215 let skeleton: String =
216 sib_content.lines().take(30).collect::<Vec<_>>().join("\n");
217 siblings.push_str(&format!("### {}\n```\n{}\n```\n\n", short, skeleton));
218 }
219 tasks.push(sub_agent::SubAgentTask {
220 file_path: all_file_contents[i].0.clone(),
221 file_content: all_file_contents[i].1.clone(),
222 task_instruction: parsed.files[i].instruction.clone(),
223 contract: parsed.contract.clone(),
224 sibling_skeletons: siblings,
225 });
226 }
227
228 let paths: Vec<&str> = tasks.iter().map(|t| t.file_path.as_str()).collect();
234 let task_infos = build_task_infos_with_dedup(&paths);
235 let _ = self
236 .event_tx
237 .send(AgentEvent::SubAgentDispatchStart { tasks: task_infos });
238
239 let pool = sub_agent::SubAgentPool {
240 tasks,
241 max_concurrent: self.config.subagent.max_concurrent,
242 timeout_secs: self.config.subagent.timeout_secs,
243 };
244 let results = pool
245 .execute_all(
246 self.provider.clone(),
247 registry,
248 &self.config,
249 &working_dir,
250 &self.event_tx,
251 )
252 .await;
253 let _ = self.event_tx.send(AgentEvent::SubAgentDispatchEnd);
254
255 let ok_count = results.iter().filter(|r| r.success).count();
269 let fail_count = results.len() - ok_count;
270 let mut summary = format!(
271 "Sub-agents: {} ok, {} fail (of {})\n",
272 ok_count,
273 fail_count,
274 results.len(),
275 );
276 let mut all_success = fail_count == 0;
277 for r in &results {
278 let icon = if r.success { "✓" } else { "✗" };
279 let one_line = r.summary.lines().next().unwrap_or("").trim();
284 summary.push_str(&format!(
285 " {} {} ({}T) — {}\n",
286 icon, r.file_path, r.turns_used, one_line,
287 ));
288 if !r.success {
289 all_success = false;
290 for failure in &r.failures {
291 summary.push_str(&format!(" reason: {:?}\n", failure));
292 }
293 }
294 }
295
296 if let Some((cmd, build_dir)) = find_build_command(&working_dir) {
300 let mut build_cmd = tokio::process::Command::new("sh");
301 build_cmd.args(["-c", &cmd])
302 .current_dir(&build_dir);
303 crate::process_utils::suppress_console_window(&mut build_cmd);
304 let output = build_cmd.output().await;
305 if let Ok(out) = output {
306 let stdout = String::from_utf8_lossy(&out.stdout);
307 let stderr = String::from_utf8_lossy(&out.stderr);
308 let combined = format!("{}{}", stdout, stderr);
309 if !out.status.success() || combined.to_lowercase().contains("error") {
310 let err_lines: String =
311 combined.lines().take(15).collect::<Vec<_>>().join("\n");
312 summary.push_str(&format!(
313 "\n⚠ BUILD ERRORS after merge:\n{}\nFix these before proceeding.\n",
314 err_lines
315 ));
316 all_success = false;
317 } else {
318 summary.push_str("\n✓ Build verification passed.\n");
319 }
320 }
321 }
322
323 Ok(ToolResult {
324 call_id: String::new(),
325 output: summary,
326 success: all_success,
327 })
328 }
329}
330
331fn build_task_infos_with_dedup(paths: &[&str]) -> Vec<crate::agent::SubAgentTaskInfo> {
341 use std::collections::HashMap;
342 let mut counts: HashMap<&str, usize> = HashMap::new();
343 let mut seen: HashMap<&str, usize> = HashMap::new();
344 for p in paths {
345 *counts.entry(*p).or_insert(0) += 1;
346 }
347 paths
348 .iter()
349 .map(|p| {
350 let total = counts.get(*p).copied().unwrap_or(1);
351 let dedup_suffix = if total > 1 {
352 let n = seen.entry(*p).or_insert(0);
353 *n += 1;
354 format!(" (#{})", *n)
355 } else {
356 String::new()
357 };
358 crate::agent::SubAgentTaskInfo {
359 path: p.to_string(),
360 dedup_suffix,
361 }
362 })
363 .collect()
364}
365
/// Chooses a build probe for the project at `wd`.
///
/// Looks for a build marker (`package.json`, `Cargo.toml`, `pom.xml`,
/// `go.mod`, in that priority) in `wd` itself and, failing that, in its
/// immediate subdirectories. Hidden directories, `node_modules` and `target`
/// are skipped, and subdirectories are visited in sorted order so the choice
/// does not depend on the order the OS lists them in.
///
/// Returns the shell script to run and the directory to run it in, or `None`
/// when no marker is found or `wd` cannot be listed.
///
/// The script captures the build's combined stdout/stderr, prints a trimmed
/// portion of it, and then exits with the build's own status. A plain
/// `build | tail` pipeline would exit with the status of `tail`, hiding
/// failures from a caller that checks the exit code.
fn find_build_command(wd: &std::path::Path) -> Option<(String, std::path::PathBuf)> {
    use std::path::{Path, PathBuf};

    // (marker file, build command, filter applied to the captured output)
    const PROBES: &[(&str, &str, &str)] = &[
        ("package.json", "npm run build", "head -30"),
        ("Cargo.toml", "cargo check", "tail -20"),
        ("pom.xml", "mvn compile -q", "tail -20"),
        ("go.mod", "go build ./...", "tail -20"),
    ];

    /// Returns the probe script for the first marker present in `dir`.
    fn probe_for(dir: &Path) -> Option<String> {
        PROBES
            .iter()
            .find(|(marker, _, _)| dir.join(marker).exists())
            .map(|(_, build, trim)| {
                format!(
                    "out=$({} 2>&1); rc=$?; printf '%s\\n' \"$out\" | {}; exit $rc",
                    build, trim
                )
            })
    }

    if let Some(cmd) = probe_for(wd) {
        return Some((cmd, wd.to_path_buf()));
    }

    let mut subdirs: Vec<PathBuf> = std::fs::read_dir(wd)
        .ok()?
        .flatten()
        .filter(|entry| entry.file_type().map(|t| t.is_dir()).unwrap_or(false))
        .filter(|entry| {
            let name = entry.file_name();
            let name = name.to_string_lossy();
            !(name.starts_with('.') || name == "node_modules" || name == "target")
        })
        .map(|entry| entry.path())
        .collect();
    subdirs.sort();

    subdirs
        .into_iter()
        .find_map(|sub| probe_for(&sub).map(|cmd| (cmd, sub)))
}
402
#[cfg(test)]
mod validate_args_tests {
    use super::*;
    use crate::stream::StreamEvent;
    use std::pin::Pin;
    use tokio::sync::mpsc;

    /// Boxed event stream returned by `LlmProvider::chat_stream`.
    type EventStream = Pin<Box<dyn futures::Stream<Item = anyhow::Result<StreamEvent>> + Send>>;

    /// Provider stand-in: only needed to construct the tool, never streamed from.
    struct StubProvider;

    impl LlmProvider for StubProvider {
        fn chat_stream(
            &self,
            _messages: &[crate::conversation::message::Message],
            _tools: Option<&[crate::tool::ToolDef]>,
        ) -> anyhow::Result<EventStream> {
            unimplemented!()
        }

        fn model_name(&self) -> &str {
            "stub"
        }
    }

    /// Configuration with every field at an empty or default value.
    fn blank_config() -> Config {
        Config {
            default_provider: String::new(),
            default_workdir: None,
            providers: std::collections::HashMap::new(),
            datalog: Default::default(),
            auto_update: true,
            notifications: Default::default(),
            telemetry: Default::default(),
            lsp: Default::default(),
            auto_commit: false,
            subagent: Default::default(),
            vision_preprocessor_provider: None,
            language: None,
            ui: Default::default(),
            plugin: Default::default(),
        }
    }

    /// Tool under test, wired to the stub provider and a throwaway channel.
    fn tool() -> ParallelEditTool {
        let (event_tx, _event_rx) = mpsc::unbounded_channel();
        ParallelEditTool {
            provider: Arc::new(StubProvider),
            config: blank_config(),
            event_tx,
        }
    }

    /// Validates `args`, which must be rejected, and returns the error text.
    fn rejection(args: &str) -> String {
        tool().validate_args(args).unwrap_err()
    }

    /// Asserts that the error text `err` contains `needle`.
    fn assert_mentions(err: &str, needle: &str) {
        assert!(err.contains(needle), "got: {}", err);
    }

    #[test]
    fn rejects_single_file_dispatch() {
        let err = rejection(r#"{"files":[{"path":"a.rs","instruction":"edit"}]}"#);
        assert_mentions(&err, "at least 2 files");
    }

    #[test]
    fn rejects_empty_instruction() {
        let err = rejection(
            r#"{"files":[{"path":"a.rs","instruction":"add field"},{"path":"b.rs","instruction":" "}]}"#,
        );
        assert_mentions(&err, "instruction is empty");
    }

    #[test]
    fn rejects_empty_path() {
        let err = rejection(
            r#"{"files":[{"path":"","instruction":"edit"},{"path":"b.rs","instruction":"edit"}]}"#,
        );
        assert_mentions(&err, "path is empty");
    }

    #[test]
    fn rejects_more_than_twelve_files() {
        let mut entries = Vec::with_capacity(13);
        for i in 0..13 {
            entries.push(format!(r#"{{"path":"f{}.rs","instruction":"edit"}}"#, i));
        }
        let args = format!(r#"{{"files":[{}]}}"#, entries.join(","));
        assert_mentions(&rejection(&args), "capped at 12");
    }

    #[test]
    fn accepts_valid_two_file_dispatch() {
        let args = r#"{"files":[{"path":"a.rs","instruction":"add field X"},{"path":"b.rs","instruction":"wire X into Y"}],"contract":"X is a u32"}"#;
        assert!(tool().validate_args(args).is_ok());
    }

    #[test]
    fn accepts_minimal_args_without_contract() {
        let args = r#"{"files":[{"path":"a.rs","instruction":"add log"},{"path":"b.rs","instruction":"add log"}]}"#;
        assert!(tool().validate_args(args).is_ok());
    }

    #[test]
    fn rejects_unparseable_json() {
        let err = rejection("not json at all");
        assert_mentions(&err, "parallel_edit_files arguments");
    }

    #[test]
    fn dedup_suffix_empty_for_unique_paths() {
        let unique = ["src/server/api.rs", "src/client/mod.rs", "src/server/mod.rs"];
        for info in super::build_task_infos_with_dedup(&unique).iter() {
            assert_eq!(info.dedup_suffix, "", "{} should be unique", info.path);
        }
    }

    #[test]
    fn dedup_suffix_numbers_repeats_in_order() {
        let infos = super::build_task_infos_with_dedup(&[
            "src/server/tunnel.rs",
            "src/client/tunnel.rs",
            "src/server/tunnel.rs",
            "src/server/tunnel.rs",
        ]);
        let expected = [" (#1)", "", " (#2)", " (#3)"];
        assert_eq!(infos[0].dedup_suffix, expected[0]);
        assert_eq!(infos[1].dedup_suffix, expected[1]);
        assert_eq!(infos[2].dedup_suffix, expected[2]);
        assert_eq!(infos[3].dedup_suffix, expected[3]);
    }

    #[test]
    fn dedup_suffix_preserves_input_order() {
        let paths = ["a.rs", "b.rs", "a.rs"];
        let infos = super::build_task_infos_with_dedup(&paths);
        assert_eq!(infos.len(), 3);
        assert_eq!(infos[0].path, "a.rs");
        assert_eq!(infos[1].path, "b.rs");
        assert_eq!(infos[2].path, "a.rs");
    }
}