Skip to main content

atomcode_core/agent/
sub_agent.rs

1//! Sub-agent parallel execution for multi-file tasks.
2//!
3//! Each SubAgent handles one file with its own Conversation + TurnRunner,
4//! running in parallel via tokio::JoinSet. This keeps each sub-agent's
5//! context small (~3-4K tokens) so weak models perform well.
6
7use std::sync::Arc;
8use std::time::Duration;
9
10use tokio::sync::mpsc;
11use tokio_util::sync::CancellationToken;
12
13use crate::config::Config;
14use crate::conversation::Conversation;
15use crate::provider::LlmProvider;
16use crate::tool::{ToolContext, ToolRegistry};
17use crate::turn::event::{TurnEvent, TurnResult};
18use crate::turn::permission::{AutoPermissionDecider, AutoPermissionMode};
19use crate::turn::runner::TurnRunner;
20
/// Tunable knobs for the resilience layer of `SubAgentTask::execute`.
///
/// `SubAgentTask::execute` currently builds this with
/// `ResilienceConfig::default()`. NOTE(review): earlier docs said it is wired
/// from `Config::subagent` with defaults matching `SubAgentConfig::default()`;
/// that wiring is not visible here — confirm.
#[derive(Debug, Clone)]
pub struct ResilienceConfig {
    /// Starting per-task turn budget.
    pub initial_turns: usize,
    /// Hard cap regardless of progress signals.
    pub max_turns: usize,
    /// Minimum turns to run before honoring budget exhaustion (so a
    /// single bad turn can't end the sub-agent prematurely).
    pub min_turns: usize,
    /// Budget bonus applied on a turn that produced at least one
    /// successful edit.
    pub edit_bonus: usize,
    /// Budget penalty applied while `no_edit_runs >= idle_threshold`.
    pub idle_penalty: usize,
    /// Number of consecutive no-edit turns before the idle penalty applies.
    pub idle_threshold: usize,
    /// Number of consecutive no-edit turns that triggers an early kill
    /// (`SubAgentFailure::NoProgress`).
    pub idle_kill_threshold: usize,
    /// Max in-loop retries for stream-timeout class failures (network).
    pub max_call_retries: usize,
    /// Reads of the assigned file (with zero successful edits) that
    /// trigger the hallucination nudge.
    pub hallucination_read_threshold: usize,
}

impl Default for ResilienceConfig {
    fn default() -> Self {
        Self {
            initial_turns: 4,
            max_turns: 12,
            min_turns: 2,
            edit_bonus: 2,
            idle_penalty: 1,
            idle_threshold: 2,
            idle_kill_threshold: 4,
            max_call_retries: 1,
            hallucination_read_threshold: 3,
        }
    }
}

/// Per-execute progress signals, updated once per turn via `observe_turn`
/// (fed by `scan_turn_signals`). Drives the adaptive budget calculation and
/// the hallucination nudge.
#[derive(Debug, Default, Clone)]
struct ProgressTracker {
    /// All files that received at least one successful edit.
    edited_files: std::collections::HashSet<String>,
    /// Turn index of the most recent successful edit.
    last_edit_turn: Option<usize>,
    /// Consecutive turns with zero successful edits (0 right after an
    /// editing turn).
    no_edit_runs: usize,
    /// Per-file read-call counts (regardless of success/failure).
    read_count: std::collections::HashMap<String, usize>,
    /// How many stream-timeout retries have fired so far.
    timeouts: usize,
    /// How many hallucination nudges have been injected.
    hallucination_nudges_sent: usize,
}

impl ProgressTracker {
    /// Record the signals of one finished turn.
    ///
    /// `edited` lists files successfully edited this turn; a non-empty list
    /// marks the turn as productive and resets `no_edit_runs`. `reads` lists
    /// every `read_file` path requested this turn.
    fn observe_turn(&mut self, turn_idx: usize, edited: &[String], reads: &[String]) {
        if edited.is_empty() {
            self.no_edit_runs += 1;
        } else {
            self.edited_files.extend(edited.iter().cloned());
            self.last_edit_turn = Some(turn_idx);
            self.no_edit_runs = 0;
        }
        for f in reads {
            *self.read_count.entry(f.clone()).or_default() += 1;
        }
    }

    /// Signed budget change to apply after the most recently observed turn.
    ///
    /// `+edit_bonus` only if that turn landed an edit. `-idle_penalty` while
    /// the idle streak is at least `idle_threshold`.
    fn budget_adjustment(&self, cfg: &ResilienceConfig) -> i32 {
        let mut delta = 0_i32;
        // `observe_turn` zeroes `no_edit_runs` exactly on editing turns.
        // Checking `last_edit_turn` alone would keep paying the bonus on
        // every later idle turn and steadily inflate the budget.
        if self.last_edit_turn.is_some() && self.no_edit_runs == 0 {
            delta = delta.saturating_add(Self::to_delta(cfg.edit_bonus));
        }
        if self.no_edit_runs >= cfg.idle_threshold {
            delta = delta.saturating_sub(Self::to_delta(cfg.idle_penalty));
        }
        delta
    }

    /// Convert a config knob to a budget delta, saturating instead of
    /// wrapping for values beyond `i32::MAX`.
    fn to_delta(value: usize) -> i32 {
        i32::try_from(value).unwrap_or(i32::MAX)
    }

    /// Return a corrective nudge when `assigned_file` has been read at least
    /// `hallucination_read_threshold` times while no edit has landed yet.
    /// Otherwise return `None`.
    fn hallucination_detected(&self, assigned_file: &str, cfg: &ResilienceConfig) -> Option<String> {
        let count = self.read_count.get(assigned_file).copied().unwrap_or(0);
        if count >= cfg.hallucination_read_threshold && self.edited_files.is_empty() {
            Some(format!(
                "You have read `{}` {} times without editing. \
                 Stop reading; the file content is already in your prompt above. \
                 Call edit_file NOW with old_string + new_string.",
                assigned_file, count
            ))
        } else {
            None
        }
    }
}
133
134/// Walk new messages added during one turn (slice `&messages[prev_len..]`)
135/// and extract:
136///  - `edited`: files whose `edit_file` / `search_replace` call returned
137///    `success=true`. `search_replace` has no `file_path` so we record an
138///    empty string to mark "an edit occurred" (still counts toward
139///    last_edit_turn). Failed edits are excluded.
140///  - `reads`: every `read_file` call's `file_path`, regardless of result.
141fn scan_turn_signals(
142    messages: &[crate::conversation::message::Message],
143    prev_len: usize,
144) -> (Vec<String> /* edited */, Vec<String> /* reads */) {
145    use crate::conversation::message::MessageContent;
146
147    // First pass: collect (call_id → tool_name + file_path)
148    let mut call_meta: std::collections::HashMap<String, (String, String)> =
149        std::collections::HashMap::new();
150    for msg in &messages[prev_len..] {
151        if let MessageContent::AssistantWithToolCalls { tool_calls, .. } = &msg.content {
152            for tc in tool_calls {
153                let path = serde_json::from_str::<serde_json::Value>(&tc.arguments)
154                    .ok()
155                    .and_then(|v| {
156                        v.get("file_path")
157                            .and_then(|x| x.as_str())
158                            .map(str::to_string)
159                    })
160                    .unwrap_or_default();
161                call_meta.insert(tc.id.clone(), (tc.name.clone(), path));
162            }
163        }
164    }
165
166    // Second pass: pair tool_results with calls
167    let mut edited = Vec::new();
168    let mut reads = Vec::new();
169    for msg in &messages[prev_len..] {
170        if let MessageContent::ToolResult(r) = &msg.content {
171            if let Some((name, path)) = call_meta.get(&r.call_id) {
172                match name.as_str() {
173                    "edit_file" if r.success => edited.push(path.clone()),
174                    "search_replace" if r.success => edited.push(path.clone()),
175                    "read_file" => reads.push(path.clone()),
176                    _ => {}
177                }
178            }
179        }
180    }
181
182    (edited, reads)
183}
184
185/// Conservative classifier: was this runner-level error a network /
186/// transport hiccup (worth one retry) or a logic error (no point
187/// retrying)? Substring-matches a small known set.
188fn is_stream_timeout(err: &str) -> bool {
189    let lo = err.to_lowercase();
190    lo.contains("stream timeout")
191        || lo.contains("first token timeout")
192        || lo.contains("connection reset")
193        || lo.contains("eof")
194}
195
196/// Wrap a single `runner.run` call with up-to-N retries for stream-timeout
197/// class failures. Returns the final `TurnResult` plus a counter of how
198/// many retries fired (so the caller can bump `tracker.timeouts`). On
199/// retry, partial conversation state from the failed attempt is rolled
200/// back so the second attempt sends a clean prompt.
201async fn run_turn_with_retry(
202    runner: &mut TurnRunner,
203    conversation: &mut Conversation,
204    system_prompt: &str,
205    event_tx: &mpsc::UnboundedSender<TurnEvent>,
206    cancel: CancellationToken,
207    max_retries: usize,
208) -> (TurnResult, usize /* timeouts_fired */) {
209    let mut timeouts_fired = 0usize;
210    for attempt in 0..=max_retries {
211        let pre_msg_count = conversation.messages.len();
212        let result = runner
213            .run(conversation, system_prompt, event_tx, cancel.clone())
214            .await;
215        match &result {
216            TurnResult::Failed(err) if is_stream_timeout(err) && attempt < max_retries => {
217                timeouts_fired += 1;
218                // Roll conversation back to pre-attempt state so retry
219                // sends a clean prompt instead of a half-filled assistant
220                // message.
221                conversation.messages.truncate(pre_msg_count);
222                conversation.clear_stream_buffer();
223                continue;
224            }
225            _ => return (result, timeouts_fired),
226        }
227    }
228    unreachable!("run_turn_with_retry loop must exit via the inner return")
229}
230
231/// Construct a human-readable summary of what the sub-agent did.
232/// Replaces the previous "first 200 chars of last_text" approach with
233/// a compact, signal-aware multi-part line.
234fn build_summary(
235    assigned: &str,
236    tracker: &ProgressTracker,
237    last_text: &str,
238) -> String {
239    let mut parts: Vec<String> = Vec::new();
240    if tracker.edited_files.is_empty() {
241        parts.push(format!("Did not edit `{}`", assigned));
242    } else {
243        let edited: Vec<&str> = tracker.edited_files.iter().map(|s| s.as_str()).collect();
244        parts.push(format!(
245            "Edited {} file(s): {}",
246            edited.len(),
247            edited.join(", ")
248        ));
249    }
250    if tracker.timeouts > 0 {
251        parts.push(format!("{} timeout(s) recovered", tracker.timeouts));
252    }
253    if tracker.hallucination_nudges_sent > 0 {
254        parts.push(format!(
255            "{} hallucination nudge(s) sent",
256            tracker.hallucination_nudges_sent
257        ));
258    }
259    if !last_text.is_empty() {
260        let snippet: String = last_text.chars().take(120).collect();
261        parts.push(format!("model said: {}", snippet));
262    }
263    parts.join(" · ")
264}
265
266/// A single sub-agent task: one file to modify.
267pub struct SubAgentTask {
268    pub file_path: String,
269    pub file_content: String,
270    pub task_instruction: String,
271    pub contract: String,
272    pub sibling_skeletons: String,
273}
274
275/// Structured reason a sub-agent ended in failure. Replaces the old
276/// `errors: Vec<String>` so callers can match on discriminant instead
277/// of substring-matching on free text.
278#[derive(Debug, Clone)]
279pub enum SubAgentFailure {
280    /// Stream-timeout-class network failure that survived one in-loop retry.
281    StreamTimeoutAfterRetry,
282    /// Model read the same file ≥ `hallucination_read_threshold` times
283    /// without producing a successful edit, AND the recovery turn after
284    /// the nudge also failed to edit.
285    HallucinationLoop { reads: usize, file: String },
286    /// `no_edit_runs` reached `idle_kill_threshold`. May or may not be
287    /// preceded by a hallucination nudge.
288    NoProgress { idle_turns: usize },
289    /// Loop exited because turn budget was exhausted and zero edits had
290    /// landed. (If edits landed, exit is treated as success.)
291    BudgetExhaustedNoEdits,
292    /// Pool wall-time wrapper (default 5min) tripped.
293    SubAgentTimeout5min,
294    /// Provider returned a non-timeout-class error (network, 4xx/5xx).
295    ProviderError(String),
296    /// `tokio::JoinSet` reported the task panicked or was cancelled at
297    /// the runtime level.
298    JoinError(String),
299    /// User pressed Ctrl+C / cancellation token tripped.
300    Cancelled,
301}
302
303/// Per-task instrumentation snapshot. Populated as `ProgressTracker`
304/// observes turns; surfaced on `SubAgentResult` so the parent agent
305/// (and operators reading datalog) can diagnose without re-deriving
306/// from raw conversation history.
307#[derive(Debug, Clone, Default)]
308pub struct Diagnostic {
309    pub edited_files: Vec<String>,
310    pub read_counts: std::collections::HashMap<String, usize>,
311    pub timeouts: usize,
312    pub hallucination_nudges_sent: usize,
313    pub final_budget: usize,
314    pub turns_used: usize,
315}
316
317/// Result of a sub-agent execution.
318#[derive(Debug, Clone)]
319pub struct SubAgentResult {
320    pub file_path: String,
321    pub success: bool,
322    pub turns_used: usize,
323    pub summary: String,
324    pub failures: Vec<SubAgentFailure>,
325    pub diagnostic: Diagnostic,
326}
327
328/// Tool wrapper that delegates to an inner `ReadFileTool` but rejects any
329/// `read_file` whose `file_path` arg differs from `assigned_file`. Used by
330/// `filter_tools_for_subagent` to keep sub-agents from drifting into
331/// sibling exploration.
332struct ScopedReadFile {
333    inner: Arc<dyn crate::tool::Tool>,
334    assigned_file: String,
335}
336
337#[async_trait::async_trait]
338impl crate::tool::Tool for ScopedReadFile {
339    fn definition(&self) -> crate::tool::ToolDef {
340        // Delegate; the LLM sees the same schema as a normal read_file.
341        self.inner.definition()
342    }
343
344    fn approval(&self, args: &str) -> crate::tool::ApprovalRequirement {
345        self.inner.approval(args)
346    }
347
348    fn approval_with_context(
349        &self,
350        args: &str,
351        ctx: &crate::tool::ToolContext,
352    ) -> crate::tool::ApprovalRequirement {
353        self.inner.approval_with_context(args, ctx)
354    }
355
356    fn validate_args(&self, args: &str) -> std::result::Result<(), String> {
357        // First: inner schema check.
358        self.inner.validate_args(args)?;
359        // Second: scope check. Parse args to peek at file_path.
360        let parsed: serde_json::Value = serde_json::from_str(args)
361            .map_err(|e| format!("scope check parse: {e}"))?;
362        let path = parsed
363            .get("file_path")
364            .and_then(|v| v.as_str())
365            .unwrap_or("");
366        if path != self.assigned_file {
367            return Err(format!(
368                "Sub-agent only reads its assigned file `{}`. \
369                 Sibling content is in your prompt's skeleton section; \
370                 do not call read_file for path `{}`.",
371                self.assigned_file, path
372            ));
373        }
374        Ok(())
375    }
376
377    async fn execute(
378        &self,
379        args: &str,
380        ctx: &crate::tool::ToolContext,
381    ) -> anyhow::Result<crate::tool::ToolResult> {
382        self.inner.execute(args, ctx).await
383    }
384}
385
386/// Build a sub-agent ToolRegistry by selecting only whitelisted tools from
387/// the parent. `read_file` is wrapped in `ScopedReadFile` so it can only
388/// read `assigned_file`. All other tools (bash, web_*, change_dir, glob,
389/// list_directory, write_file, grep) are absent from the result —
390/// the runner's "tool not registered" path returns a structured error to
391/// the model, which routes back through the LLM as a re-think signal.
392///
393/// Async because the parent registry's lock is `tokio::sync::RwLock`. Called
394/// from `SubAgentTask::execute` (Task 8 wiring) which is itself async.
395async fn filter_tools_for_subagent(
396    parent: &ToolRegistry,
397    assigned_file: &str,
398) -> ToolRegistry {
399    let filtered = ToolRegistry::new();
400    for (name, tool) in parent.iter().await {
401        match name.as_str() {
402            "read_file" => {
403                let scoped = ScopedReadFile {
404                    inner: tool,
405                    assigned_file: assigned_file.to_string(),
406                };
407                filtered
408                    .register_arc("read_file".to_string(), Arc::new(scoped))
409                    .await;
410            }
411            "edit_file" | "search_replace" => {
412                filtered.register_arc(name, tool).await;
413            }
414            _ => {} // blacklist by omission
415        }
416    }
417    filtered
418}
419
420impl SubAgentTask {
421    /// Execute this sub-agent task with its own Conversation + TurnRunner.
422    /// Runs up to `max_turns` LLM round-trips. Auto-approves all tools.
423    pub async fn execute(
424        &self,
425        provider: Arc<dyn LlmProvider>,
426        tools: Arc<ToolRegistry>,
427        config: &Config,
428        working_dir: &std::path::Path,
429        max_turns: usize,
430    ) -> SubAgentResult {
431        // 1. Build minimal system prompt
432        let rules = crate::config::prompt_sections::build_rules();
433        let vue_warning = if self.file_path.ends_with(".vue") || self.file_path.ends_with(".svelte")
434        {
435            "\nCRITICAL: This is a Vue SFC. Edit <script> and <template> in SEPARATE edit_file calls. \
436             Use old_string/new_string for each edit. Keep each edit focused on one region."
437        } else {
438            ""
439        };
440
441        let system_prompt = format!(
442            "{}\n\n## SUB-AGENT RULES\n\
443             You are a sub-agent. Your ONLY job: edit `{}`.\n\
444             The file content is provided below — do NOT read_file, you already have it.\n\
445             Call edit_file IMMEDIATELY on your first turn. Do NOT analyze, summarize, or plan.\n\
446             Use old_string/new_string to find and replace text. One edit per call.\n\
447             You are responsible for ONE file only. Ignore other files.{}",
448            rules, self.file_path, vue_warning,
449        );
450
451        // 2. Create fresh Conversation with injected context
452        let mut conversation = Conversation::new();
453        let user_message = format!(
454            "## Task\n{}\n\n## Contract\n{}\n\n## File: {}\n```\n{}\n```\n\n## Sibling files (skeleton)\n{}",
455            self.task_instruction,
456            self.contract,
457            self.file_path,
458            self.file_content,
459            self.sibling_skeletons,
460        );
461        conversation.add_user_message(&user_message);
462
463        // 3. Create isolated ToolContext + TurnRunner
464        let tool_ctx = ToolContext::new(working_dir.to_path_buf());
465        let permission = Box::new(AutoPermissionDecider::new(AutoPermissionMode::BypassAll));
466
467        // Pick the same ctx strategy the parent AgentLoop would. Sub-agents
468        // run on the same provider, so `for_provider` returns the matching
469        // builder (DefaultCtx / OllamaCtx / future per-model strategies).
470        // Falls back to a synthetic 128K-window config if the provider name
471        // isn't in the config — matches AgentLoop::new's fallback.
472        let build_ctx = match config.providers.get(&config.default_provider) {
473            Some(pc) => crate::ctx::for_provider(pc),
474            None => crate::ctx::for_provider(&crate::config::provider::ProviderConfig {
475                provider_type: String::new(),
476                api_key: None,
477                model: String::new(),
478                base_url: None,
479                system_prompt: None,
480                user_agent: None,
481                context_window: 128_000,
482                max_tokens: None,
483                thinking_type: None,
484                thinking_keep: None,
485                reasoning_history: None,
486                thinking_enabled: None,
487                thinking_budget: None,
488                skip_tls_verify: false,
489                ephemeral: true,
490
491}),
492        };
493
494        // Sandbox: filter parent tools to the sub-agent whitelist
495        // (edit_file, search_replace, scoped read_file). Hands the
496        // filtered registry to the runner so blacklisted tools (bash,
497        // web_*, glob, list_directory, change_dir, write_file, grep)
498        // are absent — the runner returns "tool not registered" to the
499        // model, which routes it back via re-think.
500        let sandboxed_tools =
501            Arc::new(filter_tools_for_subagent(&tools, &self.file_path).await);
502
503        let hooks = crate::hook::json_config::load_hooks_config(working_dir);
504        let mut runner = TurnRunner {
505            provider,
506            tools: sandboxed_tools,
507            context: tool_ctx,
508            config: config.clone(),
509            ctx: build_ctx,
510            permission,
511            recently_edited_files: Vec::new(),
512            hook_executor: std::sync::Arc::new(
513                crate::hook::executor::HookExecutor::new(hooks)
514            ),
515            loop_guard: Default::default(),
516        };
517
518        // 4. Event channel (we drain but don't forward — sub-agent is silent)
519        let (event_tx, mut event_rx) = mpsc::unbounded_channel::<TurnEvent>();
520        let cancel = CancellationToken::new();
521
522        // 5. Run loop with resilience layer
523        let res_cfg = ResilienceConfig::default();
524        let mut tracker = ProgressTracker::default();
525        let cap = max_turns.min(res_cfg.max_turns);
526        let mut dynamic_budget = res_cfg.initial_turns as i32;
527        let mut last_text = String::new();
528        let mut failures: Vec<SubAgentFailure> = Vec::new();
529        let mut turns_used = 0usize;
530
531        for turn in 0..cap {
532            // 1. Idle kill check (hard exit, no recovery)
533            if tracker.no_edit_runs >= res_cfg.idle_kill_threshold {
534                failures.push(SubAgentFailure::NoProgress {
535                    idle_turns: tracker.no_edit_runs,
536                });
537                break;
538            }
539
540            // 2. Pre-turn hallucination check
541            if let Some(nudge) = tracker.hallucination_detected(&self.file_path, &res_cfg) {
542                conversation.add_user_message(&nudge);
543                tracker.hallucination_nudges_sent += 1;
544                // Grace turn so nudge has recovery room before budget check
545                dynamic_budget += 1;
546            }
547
548            // 3. Budget exhaustion check
549            if turn as i32 >= dynamic_budget && turn >= res_cfg.min_turns {
550                if tracker.edited_files.is_empty() {
551                    failures.push(SubAgentFailure::BudgetExhaustedNoEdits);
552                }
553                break;
554            }
555
556            // 4. Run turn with retry (1× for stream-timeout class)
557            let pre_msg_count = conversation.messages.len();
558            let (result, timeouts_fired) = run_turn_with_retry(
559                &mut runner,
560                &mut conversation,
561                &system_prompt,
562                &event_tx,
563                cancel.clone(),
564                res_cfg.max_call_retries,
565            )
566            .await;
567            tracker.timeouts += timeouts_fired;
568            turns_used = turn + 1;
569
570            // Drain any UI events the runner emitted (we don't forward them).
571            while event_rx.try_recv().is_ok() {}
572
573            // 5. Process result
574            match result {
575                TurnResult::Responded { text, .. } => {
576                    last_text = text;
577                    break;
578                }
579                TurnResult::UsedTools { text, .. } => {
580                    if let Some(t) = text {
581                        last_text = t;
582                    }
583                    let (edited, reads) =
584                        scan_turn_signals(&conversation.messages, pre_msg_count);
585                    tracker.observe_turn(turn, &edited, &reads);
586                    let delta = tracker.budget_adjustment(&res_cfg);
587                    dynamic_budget = (dynamic_budget + delta)
588                        .max(res_cfg.min_turns as i32)
589                        .min(res_cfg.max_turns as i32);
590                }
591                TurnResult::Failed(err) if is_stream_timeout(&err) => {
592                    failures.push(SubAgentFailure::StreamTimeoutAfterRetry);
593                    break;
594                }
595                TurnResult::Failed(err) => {
596                    failures.push(SubAgentFailure::ProviderError(err));
597                    break;
598                }
599                TurnResult::Cancelled => {
600                    failures.push(SubAgentFailure::Cancelled);
601                    break;
602                }
603            }
604        }
605
606        // 6. Build diagnostic + summary
607        let diagnostic = Diagnostic {
608            edited_files: tracker.edited_files.iter().cloned().collect(),
609            read_counts: tracker.read_count.clone(),
610            timeouts: tracker.timeouts,
611            hallucination_nudges_sent: tracker.hallucination_nudges_sent,
612            final_budget: dynamic_budget.max(0) as usize,
613            turns_used,
614        };
615        let success = !tracker.edited_files.is_empty() && failures.is_empty();
616        let summary = build_summary(&self.file_path, &tracker, &last_text);
617
618        SubAgentResult {
619            file_path: self.file_path.clone(),
620            success,
621            turns_used,
622            summary,
623            failures,
624            diagnostic,
625        }
626    }
627}
628
629/// Pool that runs multiple SubAgentTasks in parallel with concurrency limits.
630pub struct SubAgentPool {
631    pub tasks: Vec<SubAgentTask>,
632    pub max_concurrent: usize,
633    pub timeout_secs: u64,
634}
635
636impl SubAgentPool {
637    pub fn new(tasks: Vec<SubAgentTask>) -> Self {
638        Self {
639            tasks,
640            max_concurrent: 3,
641            timeout_secs: 300,
642        }
643    }
644
645    /// Execute all tasks in parallel, streaming progress events.
646    pub async fn execute_all(
647        self,
648        provider: Arc<dyn LlmProvider>,
649        tools: Arc<ToolRegistry>,
650        config: &Config,
651        working_dir: &std::path::Path,
652        event_tx: &tokio::sync::mpsc::UnboundedSender<super::AgentEvent>,
653    ) -> Vec<SubAgentResult> {
654        use tokio::task::JoinSet;
655
656        let timeout = Duration::from_secs(self.timeout_secs);
657        let total = self.tasks.len();
658        let mut results: Vec<SubAgentResult> = Vec::with_capacity(total);
659
660        // Process in batches of max_concurrent. Index is preserved across
661        // batches via the outer `task_idx` counter — UI uses it to find
662        // each task's display slot, so it must match the original
663        // dispatch order regardless of which batch the task lands in.
664        let mut chunks = self.tasks.into_iter().enumerate().peekable();
665        while chunks.peek().is_some() {
666            let batch: Vec<(usize, SubAgentTask)> =
667                (&mut chunks).take(self.max_concurrent).collect();
668            let mut set = JoinSet::new();
669
670            for (task_idx, task) in batch {
671                let provider = provider.clone();
672                let tools = tools.clone();
673                let config = config.clone();
674                let working_dir = working_dir.to_path_buf();
675                let tx = event_tx.clone();
676                let file_path_for_err = task.file_path.clone();
677
678                set.spawn(async move {
679                    let _ = tx.send(super::AgentEvent::SubAgentTaskStarted { index: task_idx });
680                    let start = std::time::Instant::now();
681
682                    let result = tokio::time::timeout(
683                        timeout,
684                        task.execute(provider, tools, &config, &working_dir, 5),
685                    )
686                    .await;
687
688                    let elapsed_ms = start.elapsed().as_millis() as u64;
689                    match &result {
690                        Ok(r) => {
691                            if r.success {
692                                let _ = tx.send(super::AgentEvent::SubAgentTaskDone {
693                                    index: task_idx,
694                                    elapsed_ms,
695                                    turns: r.turns_used,
696                                    summary: r.summary.clone(),
697                                });
698                            } else {
699                                let _ = tx.send(super::AgentEvent::SubAgentTaskFailed {
700                                    index: task_idx,
701                                    elapsed_ms,
702                                    turns: r.turns_used,
703                                    reason: r.summary.clone(),
704                                });
705                            }
706                        }
707                        Err(_) => {
708                            let _ = tx.send(super::AgentEvent::SubAgentTaskFailed {
709                                index: task_idx,
710                                elapsed_ms,
711                                turns: 0,
712                                reason: "timeout".to_string(),
713                            });
714                        }
715                    }
716                    (file_path_for_err, result)
717                });
718            }
719
720            while let Some(join_result) = set.join_next().await {
721                match join_result {
722                    Ok((_, Ok(result))) => results.push(result),
723                    Ok((name, Err(_timeout))) => {
724                        results.push(SubAgentResult {
725                            file_path: name,
726                            success: false,
727                            turns_used: 0,
728                            summary: "Timed out".to_string(),
729                            failures: vec![SubAgentFailure::SubAgentTimeout5min],
730                            diagnostic: Diagnostic::default(),
731                        });
732                    }
733                    Err(join_err) => {
734                        results.push(SubAgentResult {
735                            file_path: "unknown".to_string(),
736                            success: false,
737                            turns_used: 0,
738                            summary: "Task panicked".to_string(),
739                            failures: vec![SubAgentFailure::JoinError(format!("{}", join_err))],
740                            diagnostic: Diagnostic::default(),
741                        });
742                    }
743                }
744            }
745        }
746
747        results
748    }
749}
750
#[cfg(test)]
mod tests {
    //! Unit tests for sub-agent pool construction, tool scoping, stream-error
    //! classification, per-turn signal scanning, and the resilience
    //! `ProgressTracker` / `ResilienceConfig` logic.

    use super::*;
    use crate::conversation::message::{Message, MessageContent, Role};
    use crate::tool::{Tool, ToolCall, ToolResult};

    // ---- Fixtures ---------------------------------------------------------

    /// Wrap a `ReadFileTool` in a `ScopedReadFile` restricted to `assigned_file`.
    fn scoped_read_file(assigned_file: &str) -> ScopedReadFile {
        ScopedReadFile {
            inner: Arc::new(crate::tool::read::ReadFileTool) as Arc<dyn Tool>,
            assigned_file: assigned_file.to_string(),
        }
    }

    /// Build a registry containing all common tools, so filtering can be
    /// checked against both allowed and blocked entries.
    fn make_full_tool_registry() -> ToolRegistry {
        let mut r = ToolRegistry::new();
        r.register_sync(Box::new(crate::tool::read::ReadFileTool));
        r.register_sync(Box::new(crate::tool::write::WriteFileTool));
        r.register_sync(Box::new(crate::tool::edit::EditFileTool));
        r.register_sync(Box::new(crate::tool::bash::BashTool));
        r.register_sync(Box::new(crate::tool::cd::CdTool));
        r.register_sync(Box::new(crate::tool::grep::GrepTool));
        r.register_sync(Box::new(crate::tool::glob::GlobTool));
        r.register_sync(Box::new(crate::tool::list_dir::ListDirTool));
        r.register_sync(Box::new(crate::tool::web_fetch::WebFetchTool));
        r.register_sync(Box::new(crate::tool::search_replace::SearchReplaceTool));
        r
    }

    /// Collect the tool names from a registry via its public `iter` API.
    async fn collect_tool_names(r: &ToolRegistry) -> Vec<String> {
        r.iter().await.map(|(name, _)| name).collect()
    }

    /// Build an assistant message carrying a single tool call.
    fn make_assistant_with_tool_call(call_id: &str, name: &str, args: &str) -> Message {
        Message {
            role: Role::Assistant,
            content: MessageContent::AssistantWithToolCalls {
                text: None,
                tool_calls: vec![ToolCall {
                    id: call_id.into(),
                    name: name.into(),
                    arguments: args.into(),
                }],
                reasoning_content: None,
                thinking_blocks: Vec::new(),
            },
        }
    }

    /// Build a tool-result message answering `call_id`.
    fn make_tool_result(call_id: &str, success: bool, output: &str) -> Message {
        Message {
            role: Role::Tool,
            content: MessageContent::ToolResult(ToolResult {
                call_id: call_id.into(),
                output: output.into(),
                success,
            }),
        }
    }

    // ---- Pool construction ------------------------------------------------

    #[test]
    fn sub_agent_pool_creation() {
        // Both tasks share everything except the target file and its contract.
        // The instruction "美化样式" means "beautify the styling".
        let task = |file_path: &str, contract: &str| SubAgentTask {
            file_path: file_path.to_string(),
            file_content: "<template>...</template>".to_string(),
            task_instruction: "美化样式".to_string(),
            contract: contract.to_string(),
            sibling_skeletons: "App.vue: ...".to_string(),
        };
        let pool = SubAgentPool::new(vec![
            task("TopBar.vue", "emit('toggleSidebar')"),
            task("Sidebar.vue", "props: { collapsed: Boolean }"),
        ]);
        assert_eq!(pool.tasks.len(), 2);
        assert_eq!(pool.max_concurrent, 3);
        assert_eq!(pool.timeout_secs, 300);
    }

    // ---- Tool scoping -----------------------------------------------------

    #[test]
    fn scoped_read_file_rejects_sibling_path() {
        let scoped = scoped_read_file("/work/a.rs");
        let err = scoped
            .validate_args(r#"{"file_path":"/work/b.rs"}"#)
            .unwrap_err();
        assert!(
            err.contains("only reads its assigned file"),
            "expected scope rejection, got: {err}"
        );
    }

    #[test]
    fn scoped_read_file_allows_assigned_path() {
        let scoped = scoped_read_file("/work/a.rs");
        assert!(
            scoped
                .validate_args(r#"{"file_path":"/work/a.rs"}"#)
                .is_ok(),
            "assigned file must pass"
        );
    }

    #[tokio::test]
    async fn filter_tools_for_subagent_keeps_only_whitelisted() {
        let parent = make_full_tool_registry();
        let filtered = filter_tools_for_subagent(&parent, "/work/a.rs").await;
        let names = collect_tool_names(&filtered).await;
        let has = |tool: &str| names.iter().any(|name| name == tool);

        for allowed in ["edit_file", "search_replace", "read_file"] {
            assert!(has(allowed), "`{allowed}` must be kept; got {names:?}");
        }
        for blocked in ["bash", "web_fetch", "glob", "list_directory", "change_dir"] {
            assert!(!has(blocked), "`{blocked}` must be filtered out; got {names:?}");
        }
    }

    // ---- Stream-timeout classification -------------------------------------

    #[test]
    fn is_stream_timeout_matches_known_phrases() {
        // The last entry checks that matching is case-insensitive.
        for msg in [
            "stream timeout after 60s",
            "First token timeout",
            "connection reset by peer",
            "Unexpected EOF",
            "STREAM TIMEOUT",
        ] {
            assert!(is_stream_timeout(msg), "should be a stream timeout: {msg:?}");
        }
    }

    #[test]
    fn is_stream_timeout_rejects_other_errors() {
        for msg in [
            "401 Unauthorized",
            "missing field `content`",
            "Tool 'foo' was denied by the user",
            "",
        ] {
            assert!(!is_stream_timeout(msg), "should not be a stream timeout: {msg:?}");
        }
    }

    // ---- Per-turn signal scanning -----------------------------------------

    #[test]
    fn scan_signals_counts_successful_edit_only() {
        let msgs = vec![
            make_assistant_with_tool_call("c1", "edit_file", r#"{"file_path":"/a.rs"}"#),
            make_tool_result("c1", true, "Edited /a.rs"),
        ];
        let (edited, reads) = scan_turn_signals(&msgs, 0);
        assert_eq!(edited, vec!["/a.rs".to_string()]);
        assert!(reads.is_empty());
    }

    #[test]
    fn scan_signals_failed_edit_not_counted() {
        let msgs = vec![
            make_assistant_with_tool_call("c1", "edit_file", r#"{"file_path":"/a.rs"}"#),
            make_tool_result("c1", false, "old_string not found"),
        ];
        let (edited, _reads) = scan_turn_signals(&msgs, 0);
        assert!(edited.is_empty(), "failed edits must not count");
    }

    #[test]
    fn scan_signals_counts_read_regardless_of_success() {
        let msgs = vec![
            make_assistant_with_tool_call("c1", "read_file", r#"{"file_path":"/a.rs"}"#),
            make_tool_result("c1", true, "..."),
        ];
        let (edited, reads) = scan_turn_signals(&msgs, 0);
        assert!(edited.is_empty());
        assert_eq!(reads, vec!["/a.rs".to_string()]);
    }

    #[test]
    fn scan_signals_counts_search_replace_as_edit() {
        let msgs = vec![
            make_assistant_with_tool_call(
                "c1",
                "search_replace",
                r#"{"search":"a","replace":"b"}"#,
            ),
            make_tool_result("c1", true, "modified 3 files"),
        ];
        let (edited, _reads) = scan_turn_signals(&msgs, 0);
        // search_replace has no file_path. An empty string is recorded to mark
        // "an edit occurred", which still counts toward last_edit_turn.
        assert_eq!(edited.len(), 1, "search_replace must register as one edit");
    }

    #[test]
    fn scan_signals_respects_prev_len_offset() {
        let msgs = vec![
            make_assistant_with_tool_call("c0", "read_file", r#"{"file_path":"/a.rs"}"#),
            make_tool_result("c0", true, "..."),
            make_assistant_with_tool_call("c1", "edit_file", r#"{"file_path":"/a.rs"}"#),
            make_tool_result("c1", true, "Edited"),
        ];
        // prev_len = 2 skips the first call/result pair, so only the edit is seen.
        let (edited, reads) = scan_turn_signals(&msgs, 2);
        assert_eq!(edited, vec!["/a.rs".to_string()]);
        assert!(reads.is_empty());
    }

    // ---- ProgressTracker --------------------------------------------------

    #[test]
    fn progress_tracker_increments_on_successful_edit() {
        let mut tracker = ProgressTracker::default();
        tracker.observe_turn(0, &["a.rs".into()], &[]);
        assert_eq!(tracker.last_edit_turn, Some(0));
        assert_eq!(tracker.no_edit_runs, 0);
        assert!(tracker.edited_files.contains("a.rs"));
    }

    #[test]
    fn progress_tracker_failed_edit_doesnt_count() {
        // scan_turn_signals drops failed edits, so a failed-edit turn reaches
        // the tracker as an empty `edited` slice. The contract is that the
        // `edited` slice holds SUCCESSFUL edits only.
        let mut tracker = ProgressTracker::default();
        tracker.observe_turn(0, &[], &[]);
        assert_eq!(tracker.last_edit_turn, None);
        assert_eq!(tracker.no_edit_runs, 1);
    }

    #[test]
    fn progress_tracker_idle_runs_reset_on_edit() {
        let mut tracker = ProgressTracker::default();
        tracker.observe_turn(0, &[], &[]);
        tracker.observe_turn(1, &[], &[]);
        assert_eq!(tracker.no_edit_runs, 2);
        tracker.observe_turn(2, &["a.rs".into()], &[]);
        assert_eq!(tracker.no_edit_runs, 0);
    }

    /// The default `hallucination_read_threshold` is 3, hence the name. The
    /// read counts are derived from the config so the boundary stays exact
    /// if that default changes.
    #[test]
    fn hallucination_detected_at_3_reads_no_edit() {
        let cfg = ResilienceConfig::default();
        let threshold = cfg.hallucination_read_threshold;
        assert!(threshold >= 1, "test requires a positive read threshold");
        let mut tracker = ProgressTracker::default();

        // One read short of the threshold: no nudge yet.
        let mut turn = 0;
        for _ in 1..threshold {
            tracker.observe_turn(turn, &[], &["a.rs".into()]);
            turn += 1;
        }
        assert!(tracker.hallucination_detected("a.rs", &cfg).is_none());

        // The read that reaches the threshold triggers the nudge.
        tracker.observe_turn(turn, &[], &["a.rs".into()]);
        let nudge = tracker
            .hallucination_detected("a.rs", &cfg)
            .expect("nudge expected once the read threshold is reached");
        assert!(nudge.contains("Stop reading"), "unexpected nudge: {nudge}");
    }

    /// Reaching the read threshold must not nudge once the file has been
    /// successfully edited.
    #[test]
    fn hallucination_not_detected_when_already_edited() {
        let cfg = ResilienceConfig::default();
        let mut tracker = ProgressTracker::default();

        // The first turn both edits and reads. Further reads bring the total
        // read count up to the threshold.
        let mut turn = 0;
        tracker.observe_turn(turn, &["a.rs".into()], &["a.rs".into()]);
        turn += 1;
        for _ in 1..cfg.hallucination_read_threshold {
            tracker.observe_turn(turn, &[], &["a.rs".into()]);
            turn += 1;
        }
        assert!(tracker.hallucination_detected("a.rs", &cfg).is_none());
    }

    #[test]
    fn budget_adjustment_combines_signals() {
        let cfg = ResilienceConfig::default();
        let mut tracker = ProgressTracker::default();

        // Fresh tracker: no edit yet and no idle streak → no adjustment.
        assert_eq!(tracker.budget_adjustment(&cfg), 0);

        // A successful edit with no idle streak → +edit_bonus.
        let mut turn = 0;
        tracker.observe_turn(turn, &["a.rs".into()], &[]);
        assert_eq!(tracker.budget_adjustment(&cfg), cfg.edit_bonus as i32);

        // Go idle until the streak reaches idle_threshold. last_edit_turn is
        // still set, so the +bonus remains and the -penalty is added.
        for _ in 0..cfg.idle_threshold {
            turn += 1;
            tracker.observe_turn(turn, &[], &[]);
        }
        let delta = tracker.budget_adjustment(&cfg);
        assert_eq!(
            delta,
            cfg.edit_bonus as i32 - cfg.idle_penalty as i32,
            "edit happened earlier (+bonus) AND idle threshold hit (-penalty)"
        );
    }

    // ---- ResilienceConfig -------------------------------------------------

    #[test]
    fn resilience_config_default_values_sensible() {
        let cfg = ResilienceConfig::default();
        assert_eq!(cfg.initial_turns, 4);
        assert_eq!(cfg.max_turns, 12);
        assert_eq!(cfg.min_turns, 2);
        assert_eq!(cfg.edit_bonus, 2);
        assert_eq!(cfg.idle_penalty, 1);
        assert_eq!(cfg.idle_threshold, 2);
        assert_eq!(cfg.idle_kill_threshold, 4);
        assert_eq!(cfg.max_call_retries, 1);
        assert_eq!(cfg.hallucination_read_threshold, 3);
    }
}