Skip to main content

oxi_agent/tools/
subagent.rs

/// Subagent tool — delegates tasks to specialized agents.
///
/// Spawns a separate `oxi --mode json` process for each invocation,
/// giving it an isolated context window.
///
/// Supports three modes:
///   - Single: `{ agent: "name", task: "..." }`
///   - Parallel: `{ tasks: [{ agent, task }, ...] }`
///   - Chain: `{ chain: [{ agent, task: "... {previous} ..." }, ...] }`
///
/// Agent definitions are markdown files with YAML frontmatter,
/// discovered from `~/.oxi/agents/` (user) and `.oxi/agents/` (project).
10use super::{AgentTool, AgentToolResult, ProgressCallback, ToolContext, ToolError};
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::collections::HashMap;
15use std::path::{Path, PathBuf};
16use tokio::io::{AsyncBufReadExt, BufReader};
17use tokio::sync::oneshot;
18
19// ── Constants ──────────────────────────────────────────────────────────
20
/// Upper bound on the number of entries accepted in a parallel-mode `tasks` array.
const MAX_PARALLEL_TASKS: usize = 8;
/// Maximum number of subagent tasks `run_parallel` spawns together in one batch.
const MAX_CONCURRENCY: usize = 4;
23
24// ── Progress callback type (reuse from tools.rs) ──────────────────────
25
/// Local alias for the `ProgressCallback` type imported from the parent module.
type ProgressFn = ProgressCallback;
27
28// ── Temp dir helper (no RAII — let OS clean up after subprocess exits) ──
29
30fn create_system_prompt_temp_dir(prefix: &str) -> Result<PathBuf, String> {
31    let path = std::env::temp_dir().join(format!("{}-{}", prefix, uuid::Uuid::new_v4()));
32    std::fs::create_dir_all(&path).map_err(|e| format!("Failed to create temp dir: {}", e))?;
33    Ok(path)
34}
35
36// ── Agent Discovery ────────────────────────────────────────────────────
37
38/// Agent scope for discovery.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41#[derive(Default)]
42pub enum AgentScope {
43    /// Only user-level agents (~/.oxi/agents/)
44    #[default]
45    User,
46    /// Only project-level agents (.oxi/agents/)
47    Project,
48    /// Both user and project agents
49    Both,
50}
51
/// A discovered agent definition.
///
/// Built by `parse_agent_file` from a markdown file and tagged with its origin
/// by `load_agents_from_dir`.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Agent name: the frontmatter `name` value, or the file stem when absent.
    pub name: String,
    /// Frontmatter `description`; empty when the file does not provide one.
    pub description: String,
    /// Optional frontmatter `model`, forwarded to the subprocess as `--model`.
    pub model: Option<String>,
    /// Optional tool list parsed from the comma-separated frontmatter `tools`
    /// value, forwarded to the subprocess as `--tools`.
    pub tools: Option<Vec<String>>,
    /// Trimmed markdown body of the agent file; written to a file and passed
    /// to the subprocess via `--append-system-prompt`.
    pub system_prompt: String,
    /// Where the definition was loaded from.
    pub source: String, // "user" or "project"
}
68
69/// Discover agents from user and/or project directories.
70pub fn discover_agents(cwd: &Path, scope: AgentScope) -> Vec<AgentConfig> {
71    let mut agents = Vec::new();
72    let mut seen_names = std::collections::HashSet::new();
73
74    // User-level agents
75    if scope == AgentScope::User || scope == AgentScope::Both {
76        if let Some(home) = dirs::home_dir() {
77            let user_dir = home.join(".oxi").join("agents");
78            load_agents_from_dir(&user_dir, "user", &mut agents, &mut seen_names);
79        }
80    }
81
82    // Project-level agents (walk up to .git boundary)
83    if scope == AgentScope::Project || scope == AgentScope::Both {
84        if let Some(project_dir) = find_project_agents_dir(cwd) {
85            load_agents_from_dir(&project_dir, "project", &mut agents, &mut seen_names);
86        }
87    }
88
89    agents
90}
91
/// Locates the nearest `.oxi/agents/` directory at or above `cwd`.
///
/// Each ancestor is checked in turn, starting with `cwd` itself. The search
/// ends without a result at the first directory that contains `.git` (that
/// directory's own `.oxi/agents/` is still considered), or when the
/// filesystem root has been passed.
fn find_project_agents_dir(cwd: &Path) -> Option<PathBuf> {
    for dir in cwd.ancestors() {
        let agents_dir = dir.join(".oxi").join("agents");
        if agents_dir.is_dir() {
            return Some(agents_dir);
        }
        // `.git` marks the project root — never look above it.
        if dir.join(".git").exists() {
            return None;
        }
    }
    None
}
108
109fn load_agents_from_dir(
110    dir: &Path,
111    source: &str,
112    agents: &mut Vec<AgentConfig>,
113    seen: &mut std::collections::HashSet<String>,
114) {
115    let entries = match std::fs::read_dir(dir) {
116        Ok(e) => e,
117        Err(_) => return,
118    };
119
120    for entry in entries.flatten() {
121        let path = entry.path();
122        if path.extension().and_then(|e| e.to_str()) != Some("md") {
123            continue;
124        }
125
126        let name = path
127            .file_stem()
128            .and_then(|s| s.to_str())
129            .unwrap_or("")
130            .to_string();
131
132        if name.is_empty() || seen.contains(&name) {
133            continue;
134        }
135
136        match parse_agent_file(&path) {
137            Ok(config) => {
138                seen.insert(name.clone());
139                let mut config = config;
140                config.source = source.to_string();
141                agents.push(config);
142            }
143            Err(e) => {
144                tracing::warn!("Failed to parse agent {}: {}", path.display(), e);
145            }
146        }
147    }
148}
149
150/// Parse an agent markdown file with optional YAML frontmatter.
151fn parse_agent_file(path: &Path) -> Result<AgentConfig, String> {
152    let content = std::fs::read_to_string(path).map_err(|e| format!("Failed to read: {}", e))?;
153
154    let (frontmatter, body) = parse_frontmatter(&content);
155
156    let name = frontmatter.get("name").cloned().unwrap_or_else(|| {
157        path.file_stem()
158            .and_then(|s| s.to_str())
159            .unwrap_or("unknown")
160            .to_string()
161    });
162
163    let description = frontmatter.get("description").cloned().unwrap_or_default();
164
165    let model = frontmatter.get("model").cloned();
166
167    let tools = frontmatter.get("tools").map(|s| {
168        s.split(',')
169            .map(|t| t.trim().to_string())
170            .filter(|t| !t.is_empty())
171            .collect()
172    });
173
174    Ok(AgentConfig {
175        name,
176        description,
177        model,
178        tools,
179        system_prompt: body.trim().to_string(),
180        source: String::new(),
181    })
182}
183
/// Splits markdown `content` into its frontmatter key/value pairs and body.
///
/// The frontmatter is the text between a leading `---` and the next line that
/// starts with `---`. Only flat `key: value` lines are understood (this is not
/// a full YAML parser):
///
/// * blank lines, `#` comment lines and lines without a `:` are ignored;
/// * keys and values are trimmed;
/// * a value wrapped in a matching pair of single or double quotes has the
///   quotes removed, so `name: "reviewer"` yields `reviewer`.
///
/// Returns `(map, body)` where `body` is everything after the closing `---`.
/// If there is no opening or no closing delimiter, the map is empty and the
/// body is the unmodified `content`.
fn parse_frontmatter(content: &str) -> (HashMap<String, String>, String) {
    /// Removes one matching pair of surrounding quotes, if present.
    fn strip_matching_quotes(value: &str) -> &str {
        for quote in ['"', '\''] {
            if let Some(inner) = value
                .strip_prefix(quote)
                .and_then(|rest| rest.strip_suffix(quote))
            {
                return inner;
            }
        }
        value
    }

    const CLOSING: &str = "\n---";

    let mut map = HashMap::new();

    let after_first = match content.trim_start().strip_prefix("---") {
        Some(rest) => rest,
        None => return (map, content.to_string()),
    };
    let end_idx = match after_first.find(CLOSING) {
        Some(idx) => idx,
        None => return (map, content.to_string()),
    };

    let yaml = &after_first[..end_idx];
    let body = after_first[end_idx + CLOSING.len()..].to_string();

    for line in yaml.lines() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        // Split on the first ':' only, so values may themselves contain colons.
        if let Some((key, value)) = line.split_once(':') {
            map.insert(
                key.trim().to_string(),
                strip_matching_quotes(value.trim()).to_string(),
            );
        }
    }

    (map, body)
}
208
209// ── Result Types ───────────────────────────────────────────────────────
210
/// Token and turn accounting accumulated for one subagent run.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct UsageStats {
    /// Sum of `input_tokens` over all `usage` events seen on the subprocess output.
    pub input_tokens: u64,
    /// Sum of `output_tokens` over all `usage` events seen on the subprocess output.
    pub output_tokens: u64,
    /// Cache-read counter. NOTE(review): never written by the code visible in
    /// this file; stays at its default — confirm whether it is filled elsewhere.
    pub cache_read: u64,
    /// Cache-write counter. NOTE(review): never written by the code visible in
    /// this file; stays at its default — confirm whether it is filled elsewhere.
    pub cache_write: u64,
    /// Cost figure. NOTE(review): never written by the code visible in this
    /// file; unit is not established here — confirm.
    pub cost: f64,
    /// Number of `usage` events received (incremented once per event).
    pub turns: u32,
}
227
/// Outcome of running one subagent task.
#[derive(Debug, Clone)]
pub struct SingleResult {
    /// Name of the agent that was requested.
    pub agent: String,
    /// `source` of the matched agent definition, or `"unknown"` when no agent matched.
    pub agent_source: String,
    /// Task text that was given to the agent.
    pub task: String,
    /// Exit code of the subprocess; `1` is also used for local failures
    /// (unknown agent, temp dir or spawn failure, unavailable exit status).
    pub exit_code: i32,
    /// Concatenation of all `text_delta` event texts emitted by the subprocess.
    pub output: String,
    /// Captured stderr of the subprocess, or a locally generated failure message.
    pub stderr: String,
    /// Usage counters accumulated from `usage` events.
    pub usage: UsageStats,
    /// Model configured on the agent definition, if any.
    pub model: Option<String>,
    /// `"complete"`, `"error"` or `"aborted"` once known; `None` otherwise.
    pub stop_reason: Option<String>,
    /// Human-readable failure description, when the run failed.
    pub error_message: Option<String>,
    /// 1-based step number in chain mode; `None` for single and parallel runs.
    pub step: Option<usize>,
}
254
/// The three invocation modes of the subagent tool, serialized in `snake_case`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubagentMode {
    /// One agent, one task.
    Single,
    /// Several agent/task pairs given together in a `tasks` array.
    Parallel,
    /// A sequence of agent/task steps given in a `chain` array.
    Chain,
}
266
/// Mode plus per-task results of one subagent tool invocation.
///
/// NOTE(review): not constructed anywhere in the code visible here — confirm
/// where it is consumed.
#[derive(Debug, Clone)]
pub struct SubagentDetails {
    /// Mode the tool was invoked in.
    pub mode: SubagentMode,
    /// One entry per task or step.
    pub results: Vec<SingleResult>,
}
275
276// ── JSON line processing ───────────────────────────────────────────────
277
278fn process_json_line(
279    line: &str,
280    result: &mut SingleResult,
281    text: &mut String,
282    _on_progress: &Option<ProgressFn>,
283) {
284    let event: Value = match serde_json::from_str(line) {
285        Ok(v) => v,
286        Err(_) => return,
287    };
288    match event["type"].as_str().unwrap_or("") {
289        "text_delta" => {
290            if let Some(t) = event["text"].as_str() {
291                text.push_str(t);
292            }
293        }
294        "usage" => {
295            result.usage.input_tokens += event["input_tokens"].as_u64().unwrap_or(0);
296            result.usage.output_tokens += event["output_tokens"].as_u64().unwrap_or(0);
297            result.usage.turns += 1;
298        }
299        "complete" => {
300            result.stop_reason = Some("complete".to_string());
301        }
302        "error" => {
303            result.error_message = Some(
304                event["message"]
305                    .as_str()
306                    .unwrap_or("Unknown error")
307                    .to_string(),
308            );
309            result.stop_reason = Some("error".to_string());
310        }
311        _ => {}
312    }
313}
314
315// ── Process Execution ──────────────────────────────────────────────────
316
317/// Build command-line arguments for launching a subagent process.
318fn build_agent_args(agent: &AgentConfig, tmp_dir: &Path, task: &str) -> Vec<String> {
319    let mut args = vec!["--mode".to_string(), "json".to_string(), "-p".to_string()];
320
321    if let Some(ref model) = agent.model {
322        args.push("--model".to_string());
323        args.push(model.clone());
324    }
325
326    if let Some(ref agent_tools) = agent.tools {
327        if !agent_tools.is_empty() {
328            args.push("--tools".to_string());
329            args.push(agent_tools.join(","));
330        }
331    }
332
333    if !agent.system_prompt.is_empty()
334        && std::fs::write(tmp_dir.join("system_prompt.md"), &agent.system_prompt).is_ok()
335    {
336        args.push("--append-system-prompt".to_string());
337        args.push(
338            tmp_dir
339                .join("system_prompt.md")
340                .to_str()
341                .unwrap_or_default()
342                .to_string(),
343        );
344    }
345
346    args.push(format!("Task: {}", task));
347    args
348}
349
350/// Gracefully terminate a child process (SIGTERM → wait → SIGKILL).
351async fn terminate_child(
352    child: &mut tokio::process::Child,
353    stderr_handle: tokio::task::JoinHandle<String>,
354    result: &mut SingleResult,
355) {
356    #[cfg(unix)]
357    {
358        if let Some(pid) = child.id() {
359            unsafe {
360                libc::kill(pid as i32, libc::SIGTERM);
361            }
362        }
363        let deadline = tokio::time::sleep(std::time::Duration::from_secs(5));
364        tokio::pin!(deadline);
365        tokio::select! {
366            _ = &mut deadline => { let _ = child.start_kill(); }
367            _ = child.wait() => {}
368        }
369    }
370    #[cfg(not(unix))]
371    {
372        let _ = child.start_kill();
373        let _ = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await;
374    }
375
376    // Collect stderr with short timeout
377    let _ = tokio::time::timeout(std::time::Duration::from_secs(1), async {
378        if let Ok(err) = stderr_handle.await {
379            result.stderr = err;
380        }
381    })
382    .await;
383}
384
385/// Run a single agent process with abort support.
386#[allow(clippy::too_many_arguments)]
387async fn run_single_agent(
388    cwd: &Path,
389    agents: &[AgentConfig],
390    agent_name: &str,
391    task: &str,
392    agent_cwd: Option<&str>,
393    step: Option<usize>,
394    signal: Option<oneshot::Receiver<()>>,
395    on_progress: Option<ProgressFn>,
396    binary_path: &Path,
397) -> SingleResult {
398    let agent = match agents.iter().find(|a| a.name == agent_name) {
399        Some(a) => a,
400        None => {
401            let available = agents
402                .iter()
403                .map(|a| format!("\"{}\"", a.name))
404                .collect::<Vec<_>>()
405                .join(", ");
406            return SingleResult {
407                agent: agent_name.to_string(),
408                agent_source: "unknown".to_string(),
409                task: task.to_string(),
410                exit_code: 1,
411                output: String::new(),
412                stderr: format!(
413                    "Unknown agent: \"{}\". Available: {}",
414                    agent_name, available
415                ),
416                usage: UsageStats::default(),
417                model: None,
418                stop_reason: None,
419                error_message: Some(format!("Unknown agent: {}", agent_name)),
420                step,
421            };
422        }
423    };
424
425    let mut result = SingleResult {
426        agent: agent_name.to_string(),
427        agent_source: agent.source.clone(),
428        task: task.to_string(),
429        exit_code: 0,
430        output: String::new(),
431        stderr: String::new(),
432        usage: UsageStats::default(),
433        model: agent.model.clone(),
434        stop_reason: None,
435        error_message: None,
436        step,
437    };
438
439    // Notify progress
440    if let Some(ref cb) = on_progress {
441        cb(format!("[{}] running...", agent_name));
442    }
443
444    // Build command args
445    let tmp_dir = match create_system_prompt_temp_dir("oxi-subagent") {
446        Ok(tmp) => Some(tmp),
447        Err(e) => {
448            result.exit_code = 1;
449            result.stderr = e.clone();
450            result.error_message = Some(e);
451            return result;
452        }
453    };
454
455    let args = match tmp_dir {
456        Some(ref tmp) => build_agent_args(agent, tmp, task),
457        None => vec![
458            "--mode".to_string(),
459            "json".to_string(),
460            "-p".to_string(),
461            format!("Task: {}", task),
462        ],
463    };
464
465    let working_dir = agent_cwd
466        .map(PathBuf::from)
467        .unwrap_or_else(|| cwd.to_path_buf());
468
469    let mut cmd = tokio::process::Command::new(binary_path);
470    cmd.args(&args)
471        .current_dir(&working_dir)
472        .stdout(std::process::Stdio::piped())
473        .stderr(std::process::Stdio::piped())
474        .stdin(std::process::Stdio::null());
475
476    let mut child = match cmd.spawn() {
477        Ok(c) => c,
478        Err(e) => {
479            result.exit_code = 1;
480            result.stderr = format!("Failed to spawn: {}", e);
481            result.error_message = Some(format!("Failed to spawn: {}", e));
482            return result;
483        }
484    };
485
486    let stdout = child.stdout.take().expect("stdout piped but missing");
487    let stderr = child.stderr.take().expect("stderr piped but missing");
488
489    // Spawn stdout reader → channel
490    let (line_tx, mut line_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
491    let _reader_handle = tokio::spawn(async move {
492        let reader = BufReader::new(stdout);
493        let mut lines = reader.lines();
494        while let Ok(Some(line)) = lines.next_line().await {
495            if line_tx.send(line).is_err() {
496                break;
497            }
498        }
499    });
500
501    // Spawn stderr reader
502    let stderr_handle = tokio::spawn(async move {
503        let mut err = String::new();
504        let reader = BufReader::new(stderr);
505        let mut lines = reader.lines();
506        while let Ok(Some(line)) = lines.next_line().await {
507            err.push_str(&line);
508            err.push('\n');
509        }
510        err
511    });
512
513    // Main loop: select between stdout lines and abort signal
514    let mut final_text = String::new();
515    let mut signal_rx = signal;
516    let mut aborted = false;
517
518    loop {
519        tokio::select! {
520            line = line_rx.recv() => {
521                match line {
522                    Some(line) => {
523                        process_json_line(&line, &mut result, &mut final_text, &on_progress);
524                    }
525                    None => break, // stdout EOF
526                }
527            }
528            _ = async {
529                match &mut signal_rx {
530                    Some(rx) => { let _ = rx.await; }
531                    None => std::future::pending::<()>().await,
532                }
533            } => {
534                aborted = true;
535                break;
536            }
537        }
538    }
539
540    if aborted {
541        result.stop_reason = Some("aborted".into());
542        result.error_message = Some("Aborted by user".into());
543        terminate_child(&mut child, stderr_handle, &mut result).await;
544    } else {
545        // Normal completion
546        if let Ok(err_output) = stderr_handle.await {
547            result.stderr = err_output;
548        }
549        match child.wait().await {
550            Ok(status) => result.exit_code = status.code().unwrap_or(1),
551            Err(_) => result.exit_code = 1,
552        }
553    }
554
555    result.output = final_text;
556
557    if let Some(ref cb) = on_progress {
558        let status = if result.exit_code == 0 {
559            "done"
560        } else {
561            "failed"
562        };
563        cb(format!("[{}] {}", agent_name, status));
564    }
565
566    result
567}
568
569/// Run multiple tasks with concurrency limit.
570async fn run_parallel(
571    cwd: &Path,
572    agents: &[AgentConfig],
573    tasks: Vec<ParallelTask>,
574    binary_path: PathBuf,
575    on_progress: Option<ProgressFn>,
576) -> Vec<SingleResult> {
577    let n = tasks.len();
578    if n == 0 {
579        return vec![];
580    }
581
582    let limit = MAX_CONCURRENCY.min(n);
583    let indexed_tasks: Vec<(usize, ParallelTask)> = tasks.into_iter().enumerate().collect();
584    let mut all_results: Vec<Option<SingleResult>> = vec![None; n];
585
586    let mut i = 0;
587    while i < indexed_tasks.len() {
588        let end = (i + limit).min(indexed_tasks.len());
589        let chunk: Vec<_> = indexed_tasks[i..end].to_vec();
590        let mut handles = Vec::new();
591
592        for (idx, task) in chunk {
593            let agents = agents.to_vec();
594            let cwd = cwd.to_path_buf();
595            let bp = binary_path.clone();
596            let progress = on_progress.clone();
597
598            handles.push((
599                idx,
600                tokio::spawn(async move {
601                    run_single_agent(
602                        &cwd,
603                        &agents,
604                        &task.agent,
605                        &task.task,
606                        task.cwd.as_deref(),
607                        None,
608                        None,
609                        progress,
610                        &bp,
611                    )
612                    .await
613                }),
614            ));
615        }
616
617        for (idx, handle) in handles {
618            if let Ok(r) = handle.await {
619                all_results[idx] = Some(r);
620            }
621        }
622
623        i = end;
624    }
625
626    all_results
627        .into_iter()
628        .map(|r| {
629            r.unwrap_or_else(|| SingleResult {
630                agent: "unknown".to_string(),
631                agent_source: "unknown".to_string(),
632                task: "unknown".to_string(),
633                exit_code: 1,
634                output: String::new(),
635                stderr: "Task did not complete".to_string(),
636                usage: UsageStats::default(),
637                model: None,
638                stop_reason: Some("error".to_string()),
639                error_message: Some("Task did not complete".to_string()),
640                step: None,
641            })
642        })
643        .collect()
644}
645
646// ── Parameter Types ────────────────────────────────────────────────────
647
/// One entry of the parallel-mode `tasks` array.
#[derive(Debug, Deserialize, Clone)]
struct ParallelTask {
    /// Name of the agent to run.
    agent: String,
    /// Task text handed to the agent.
    task: String,
    /// Optional working directory for this task; absent means the tool's working directory.
    #[serde(default)]
    cwd: Option<String>,
}
655
/// One entry of the chain-mode `chain` array.
#[derive(Debug, Deserialize)]
struct ChainStep {
    /// Name of the agent to run for this step.
    agent: String,
    /// Task text; every `{previous}` is replaced with the prior step's output.
    task: String,
    /// Optional working directory for this step; absent means the tool's working directory.
    #[serde(default)]
    cwd: Option<String>,
}
663
664// ── Tool Implementation ────────────────────────────────────────────────
665
/// Tool that delegates tasks to subagent processes (name `"subagent"`).
pub struct SubagentTool {
    /// Explicit working directory override. If None, uses ToolContext.root() at runtime.
    cwd: Option<PathBuf>,
    /// Explicit executable to launch. If None, `get_binary` falls back to the
    /// current executable and then to `"oxi"`. Both constructors visible here
    /// leave it as None.
    binary_path: Option<PathBuf>,
    /// Progress callback registered through `on_progress`; cloned for each `execute` call.
    progress_callback: parking_lot::Mutex<Option<ProgressFn>>,
}
673
/// `Default` is equivalent to `SubagentTool::new()`.
impl Default for SubagentTool {
    fn default() -> Self {
        Self::new()
    }
}
679
680impl SubagentTool {
681    /// Create with no explicit root (uses ToolContext.root() at runtime).
682    pub fn new() -> Self {
683        Self {
684            cwd: None,
685            binary_path: None,
686            progress_callback: parking_lot::Mutex::new(None),
687        }
688    }
689
690    /// Create with an explicit working directory (overrides ToolContext).
691    pub fn with_cwd(cwd: impl Into<PathBuf>) -> Self {
692        Self {
693            cwd: Some(cwd.into()),
694            binary_path: None,
695            progress_callback: parking_lot::Mutex::new(None),
696        }
697    }
698
699    fn get_binary(&self) -> PathBuf {
700        self.binary_path
701            .clone()
702            .or_else(|| std::env::current_exe().ok())
703            .unwrap_or_else(|| PathBuf::from("oxi"))
704    }
705}
706
707#[async_trait]
708impl AgentTool for SubagentTool {
709    fn name(&self) -> &str {
710        "subagent"
711    }
712
713    fn label(&self) -> &str {
714        "Subagent"
715    }
716
717    fn description(&self) -> &str {
718        "Delegate tasks to specialized subagents with isolated context. \
719         Modes: single (agent + task), parallel (tasks array), chain (sequential with {previous} placeholder). \
720         Agents are discovered from ~/.oxi/agents/ (user) and .oxi/agents/ (project)."
721    }
722
723    fn parameters_schema(&self) -> Value {
724        json!({
725            "type": "object",
726            "properties": {
727                "agent": {
728                    "type": "string",
729                    "description": "Agent name for single mode"
730                },
731                "task": {
732                    "type": "string",
733                    "description": "Task to delegate (single mode)"
734                },
735                "tasks": {
736                    "type": "array",
737                    "description": "Array of {agent, task} for parallel execution (max 8)",
738                    "items": {
739                        "type": "object",
740                        "properties": {
741                            "agent": { "type": "string" },
742                            "task": { "type": "string" },
743                            "cwd": { "type": "string" }
744                        },
745                        "required": ["agent", "task"]
746                    }
747                },
748                "chain": {
749                    "type": "array",
750                    "description": "Array of {agent, task} for sequential execution. Use {previous} in task for prior output.",
751                    "items": {
752                        "type": "object",
753                        "properties": {
754                            "agent": { "type": "string" },
755                            "task": { "type": "string" },
756                            "cwd": { "type": "string" }
757                        },
758                        "required": ["agent", "task"]
759                    }
760                },
761                "agentScope": {
762                    "type": "string",
763                    "description": "Agent discovery scope: 'user' (default), 'project', or 'both'",
764                    "enum": ["user", "project", "both"],
765                    "default": "user"
766                },
767                "cwd": {
768                    "type": "string",
769                    "description": "Working directory for single mode"
770                }
771            }
772        })
773    }
774
775    fn on_progress(&self, callback: ProgressCallback) {
776        *self.progress_callback.lock() = Some(callback);
777    }
778
779    async fn execute(
780        &self,
781        _tool_call_id: &str,
782        params: Value,
783        signal: Option<oneshot::Receiver<()>>,
784        ctx: &ToolContext,
785    ) -> Result<AgentToolResult, ToolError> {
786        // Use explicit cwd if set, else ctx.root()
787        let effective_cwd = self.cwd.as_deref().unwrap_or(ctx.root());
788
789        let scope: AgentScope = params
790            .get("agentScope")
791            .and_then(|v| serde_json::from_value(v.clone()).ok())
792            .unwrap_or(AgentScope::User);
793
794        let agents = discover_agents(effective_cwd, scope);
795        let binary = self.get_binary();
796        let progress = self.progress_callback.lock().clone();
797
798        let has_chain = params["chain"]
799            .as_array()
800            .map(|a| !a.is_empty())
801            .unwrap_or(false);
802        let has_tasks = params["tasks"]
803            .as_array()
804            .map(|a| !a.is_empty())
805            .unwrap_or(false);
806        let has_single = params["agent"].is_string() && params["task"].is_string();
807
808        let mode_count = [has_chain, has_tasks, has_single]
809            .iter()
810            .filter(|&&x| x)
811            .count();
812
813        if mode_count != 1 {
814            let available = agents
815                .iter()
816                .map(|a| format!("{} ({})", a.name, a.source))
817                .collect::<Vec<_>>()
818                .join(", ");
819            return Ok(AgentToolResult::error(format!(
820                "Provide exactly one mode: agent+task, tasks, or chain.\nAvailable agents: {}",
821                if available.is_empty() {
822                    "none".to_string()
823                } else {
824                    available
825                }
826            )));
827        }
828
829        // ── Chain mode ──
830        if has_chain {
831            return execute_chain_mode(effective_cwd, &agents, params, &binary, progress, signal)
832                .await;
833        }
834
835        // ── Parallel mode ──
836        if has_tasks {
837            return execute_parallel_mode(effective_cwd, &agents, params, &binary, progress).await;
838        }
839
840        // ── Single mode ──
841        if has_single {
842            return execute_single_mode(effective_cwd, &agents, params, &binary, progress, signal)
843                .await;
844        }
845
846        Ok(AgentToolResult::error("Invalid parameters".to_string()))
847    }
848}
849
850/// Execute chain mode: sequential agents where each step can reference {previous} output.
851async fn execute_chain_mode(
852    cwd: &Path,
853    agents: &[AgentConfig],
854    params: Value,
855    binary: &Path,
856    progress: Option<ProgressFn>,
857    signal: Option<oneshot::Receiver<()>>,
858) -> Result<AgentToolResult, ToolError> {
859    let steps: Vec<ChainStep> = serde_json::from_value(params["chain"].clone())
860        .map_err(|e| format!("Invalid chain parameter: {}", e))?;
861    let total = steps.len();
862    let mut results = Vec::new();
863    let mut previous_output = String::new();
864    let mut abort_signal = signal;
865
866    for (i, step) in steps.into_iter().enumerate() {
867        let task = step.task.replace("{previous}", &previous_output);
868        let step_signal = if i == total - 1 {
869            abort_signal.take()
870        } else {
871            None
872        };
873
874        let result = run_single_agent(
875            cwd,
876            agents,
877            &step.agent,
878            &task,
879            step.cwd.as_deref(),
880            Some(i + 1),
881            step_signal,
882            progress.clone(),
883            binary,
884        )
885        .await;
886
887        let is_error = result.exit_code != 0
888            || result.stop_reason.as_deref() == Some("error")
889            || result.stop_reason.as_deref() == Some("aborted");
890
891        if is_error {
892            let agent_name = result.agent.clone();
893            let error_msg = result
894                .error_message
895                .clone()
896                .unwrap_or_else(|| result.stderr.clone());
897            results.push(result);
898            return Ok(AgentToolResult::error(format!(
899                "Chain stopped at step {}/{} ({}): {}",
900                i + 1,
901                total,
902                agent_name,
903                error_msg
904            )));
905        }
906
907        previous_output = result.output.clone();
908        results.push(result);
909    }
910
911    let output = results.last().map(|r| r.output.clone()).unwrap_or_default();
912    Ok(AgentToolResult::success(if output.is_empty() {
913        "(no output)".to_string()
914    } else {
915        output
916    })
917    .with_metadata(json!({
918        "mode": "chain",
919        "steps": results.len(),
920    })))
921}
922
923/// Execute parallel mode: multiple agents running concurrently.
924async fn execute_parallel_mode(
925    cwd: &Path,
926    agents: &[AgentConfig],
927    params: Value,
928    binary: &Path,
929    progress: Option<ProgressFn>,
930) -> Result<AgentToolResult, ToolError> {
931    let tasks: Vec<ParallelTask> = serde_json::from_value(params["tasks"].clone())
932        .map_err(|e| format!("Invalid tasks parameter: {}", e))?;
933
934    if tasks.len() > MAX_PARALLEL_TASKS {
935        return Ok(AgentToolResult::error(format!(
936            "Too many parallel tasks ({}). Max is {}.",
937            tasks.len(),
938            MAX_PARALLEL_TASKS
939        )));
940    }
941
942    let results = run_parallel(cwd, agents, tasks, binary.to_path_buf(), progress).await;
943
944    let success_count = results.iter().filter(|r| r.exit_code == 0).count();
945    let summaries: Vec<String> = results
946        .iter()
947        .map(|r| {
948            let _preview = truncate_output(&r.output, 100);
949            format!(
950                "[{}]: {}",
951                r.agent,
952                if r.exit_code == 0 {
953                    "completed"
954                } else {
955                    "failed"
956                },
957            )
958        })
959        .collect();
960
961    Ok(AgentToolResult::success(format!(
962        "Parallel: {}/{} succeeded\n\n{}",
963        success_count,
964        results.len(),
965        summaries.join("\n\n")
966    ))
967    .with_metadata(json!({
968        "mode": "parallel",
969        "results": results.iter().map(|r| json!({
970            "agent": r.agent,
971            "exit_code": r.exit_code,
972        })).collect::<Vec<_>>()
973    })))
974}
975
976/// Execute single mode: one agent, one task.
977async fn execute_single_mode(
978    cwd: &Path,
979    agents: &[AgentConfig],
980    params: Value,
981    binary: &Path,
982    progress: Option<ProgressFn>,
983    signal: Option<oneshot::Receiver<()>>,
984) -> Result<AgentToolResult, ToolError> {
985    let agent_name = params["agent"]
986        .as_str()
987        .ok_or("Missing required parameter: agent")?;
988    let task = params["task"]
989        .as_str()
990        .ok_or("Missing required parameter: task")?;
991    let agent_cwd = params["cwd"].as_str();
992
993    let result = run_single_agent(
994        cwd, agents, agent_name, task, agent_cwd, None, signal, progress, binary,
995    )
996    .await;
997
998    let is_error = result.exit_code != 0
999        || result.stop_reason.as_deref() == Some("error")
1000        || result.stop_reason.as_deref() == Some("aborted");
1001
1002    if is_error {
1003        let error_msg = result.error_message.as_deref().unwrap_or(&result.stderr);
1004        return Ok(AgentToolResult::error(format!(
1005            "Agent {}: {}",
1006            result.stop_reason.as_deref().unwrap_or("failed"),
1007            error_msg
1008        )));
1009    }
1010
1011    Ok(AgentToolResult::success(if result.output.is_empty() {
1012        "(no output)".to_string()
1013    } else {
1014        result.output.clone()
1015    })
1016    .with_metadata(json!({
1017        "mode": "single",
1018        "agent": result.agent,
1019        "source": result.agent_source,
1020        "usage": {
1021            "input_tokens": result.usage.input_tokens,
1022            "output_tokens": result.usage.output_tokens,
1023            "turns": result.usage.turns,
1024        },
1025    })))
1026}
1027
1028// ── Helpers ────────────────────────────────────────────────────────────
1029
/// Shortens `text` to at most `max_chars` characters.
///
/// Text that already fits is returned unchanged. Longer text is cut after
/// `max_chars` characters and suffixed with `...`.
///
/// The limit is counted in Unicode scalar values (`char`s), not bytes, so the
/// cut always lands on a character boundary and multi-byte UTF-8 input cannot
/// cause a slicing panic.
fn truncate_output(text: &str, max_chars: usize) -> String {
    // `nth(max_chars)` finds the first character that does NOT fit. Its byte
    // offset is exactly where the kept prefix ends.
    match text.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &text[..cut]),
        None => text.to_string(),
    }
}
1037
1038// ── Tests ──────────────────────────────────────────────────────────────
1039
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SingleResult` with placeholder identity fields and nothing
    /// recorded yet, for feeding into `process_json_line`.
    fn blank_result() -> SingleResult {
        SingleResult {
            agent: "test".into(),
            agent_source: "user".into(),
            task: "t".into(),
            exit_code: 0,
            output: String::new(),
            stderr: String::new(),
            usage: UsageStats::default(),
            model: None,
            stop_reason: None,
            error_message: None,
            step: None,
        }
    }

    // ── Frontmatter parsing ──

    /// Every frontmatter key is extracted and the body follows the closing fence.
    #[test]
    fn test_parse_frontmatter_with_yaml() {
        let content = "---\nname: scout\ndescription: Fast recon\nmodel: haiku\ntools: read, grep\n---\nYou are a scout agent.";
        let (fm, body) = parse_frontmatter(content);

        let expected_pairs = [
            ("name", "scout"),
            ("description", "Fast recon"),
            ("model", "haiku"),
            ("tools", "read, grep"),
        ];
        for (key, expected) in expected_pairs {
            assert_eq!(fm.get(key).unwrap(), expected);
        }
        assert!(body.trim().starts_with("You are a scout agent."));
    }

    /// Content without a frontmatter fence yields no keys and the full text as body.
    #[test]
    fn test_parse_frontmatter_no_yaml() {
        let (fm, body) = parse_frontmatter("Just a plain system prompt.");
        assert!(fm.is_empty());
        assert_eq!(body.trim(), "Just a plain system prompt.");
    }

    // ── Agent file parsing ──

    /// Name and description come from frontmatter; the body becomes the system prompt.
    #[test]
    fn test_parse_agent_file() {
        let tmp = tempfile::tempdir().unwrap();
        let file_path = tmp.path().join("scout.md");
        let contents = "---\nname: scout\ndescription: Fast recon\n---\nYou are a scout.";
        std::fs::write(&file_path, contents).unwrap();

        let config = parse_agent_file(&file_path).unwrap();

        assert_eq!(config.name, "scout");
        assert_eq!(config.description, "Fast recon");
        assert_eq!(config.system_prompt, "You are a scout.");
    }

    /// Without frontmatter the agent name is taken from the file stem.
    #[test]
    fn test_parse_agent_file_no_frontmatter() {
        let tmp = tempfile::tempdir().unwrap();
        let file_path = tmp.path().join("worker.md");
        std::fs::write(&file_path, "You are a worker agent.").unwrap();

        let config = parse_agent_file(&file_path).unwrap();

        assert_eq!(config.name, "worker");
        assert_eq!(config.system_prompt, "You are a worker agent.");
    }

    /// The comma-separated `tools` value is split into individual tool names.
    #[test]
    fn test_tools_parsing() {
        let tmp = tempfile::tempdir().unwrap();
        let file_path = tmp.path().join("agent.md");
        let contents = "---\ntools: read, grep, find, ls\n---\nSystem prompt.";
        std::fs::write(&file_path, contents).unwrap();

        let config = parse_agent_file(&file_path).unwrap();

        let tools = config.tools.unwrap();
        assert_eq!(tools, vec!["read", "grep", "find", "ls"]);
    }

    // ── Agent discovery ──

    /// A directory containing no agent definitions discovers nothing.
    #[test]
    fn test_discover_agents_empty_dir() {
        let tmp = tempfile::tempdir().unwrap();
        assert!(discover_agents(tmp.path(), AgentScope::User).is_empty());
    }

    /// Only `.md` files under `.oxi/agents` are picked up for the project scope.
    #[test]
    fn test_discover_agents_with_files() {
        let tmp = tempfile::tempdir().unwrap();
        let agents_dir = tmp.path().join(".oxi").join("agents");
        std::fs::create_dir_all(&agents_dir).unwrap();

        let fixtures = [
            (
                "scout.md",
                "---\nname: scout\ndescription: Recon\n---\nBe a scout.",
            ),
            ("worker.md", "---\nname: worker\n---\nBe a worker."),
            ("ignore.txt", "ignore me"),
        ];
        for (file_name, contents) in fixtures {
            std::fs::write(agents_dir.join(file_name), contents).unwrap();
        }

        let agents = discover_agents(tmp.path(), AgentScope::Project);

        assert_eq!(agents.len(), 2);
        for expected in ["scout", "worker"] {
            assert!(agents.iter().any(|a| a.name == expected));
        }
    }

    /// Starting from a subdirectory, the search walks up to the `.oxi/agents` dir.
    #[test]
    fn test_find_project_agents_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();

        let agents_dir = root.join(".oxi").join("agents");
        let git_dir = root.join(".git");
        let sub = root.join("subdir");
        std::fs::create_dir_all(&agents_dir).unwrap();
        std::fs::create_dir_all(&git_dir).unwrap();
        std::fs::create_dir_all(&sub).unwrap();

        assert_eq!(find_project_agents_dir(&sub), Some(agents_dir));
    }

    /// A `.git` directory without `.oxi/agents` next to it yields `None`.
    #[test]
    fn test_find_project_agents_dir_stops_at_git() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::create_dir_all(tmp.path().join(".git")).unwrap();

        assert_eq!(find_project_agents_dir(tmp.path()), None);
    }

    /// The default discovery scope is user-level agents.
    #[test]
    fn test_agent_scope_default() {
        assert_eq!(AgentScope::default(), AgentScope::User);
    }

    // ── Tool schema and helpers ──

    /// The schema is an object exposing a property for every invocation mode.
    #[test]
    fn test_schema_structure() {
        let schema = SubagentTool::new().parameters_schema();

        assert_eq!(schema["type"], "object");
        for property in ["agent", "tasks", "chain", "agentScope"] {
            assert!(schema["properties"][property].is_object());
        }
    }

    /// Short text passes through; longer text is cut and suffixed with `...`.
    #[test]
    fn test_truncate_output() {
        let cases = [
            ("hello", 10, "hello"),
            ("hello world foo", 5, "hello..."),
        ];
        for (input, limit, expected) in cases {
            assert_eq!(truncate_output(input, limit), expected);
        }
    }

    // ── JSON event processing ──

    /// A `text_delta` event appends its text to the accumulated buffer.
    #[test]
    fn test_process_json_line_text_delta() {
        let mut result = blank_result();
        let mut text = String::new();

        process_json_line(
            r#"{"type":"text_delta","text":"hello"}"#,
            &mut result,
            &mut text,
            &None,
        );

        assert_eq!(text, "hello");
    }

    /// A `usage` event records token counts and counts as one turn.
    #[test]
    fn test_process_json_line_usage() {
        let mut result = blank_result();
        let mut text = String::new();

        process_json_line(
            r#"{"type":"usage","input_tokens":100,"output_tokens":50}"#,
            &mut result,
            &mut text,
            &None,
        );

        assert_eq!(result.usage.input_tokens, 100);
        assert_eq!(result.usage.output_tokens, 50);
        assert_eq!(result.usage.turns, 1);
    }
}