Skip to main content

agent_base/tool/
subagent.rs

1use async_trait::async_trait;
2use serde_json::{json, Value};
3use tokio::sync::Mutex;
4
5use crate::engine::AgentRuntime;
6use crate::types::{AgentError, AgentEvent, AgentResult, SessionId};
7use super::{Tool, ToolContext, ToolControlFlow, ToolOutput};
8
/// Session policy for a sub-agent tool: whether each delegated task runs in a
/// fresh sub-session or all tasks share one.
///
/// `Ephemeral` is the [`Default`], matching the policy used by the plain
/// constructor.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SubAgentSessionPolicy {
    /// Create a new session per call (the default).
    #[default]
    Ephemeral,
    /// Reuse the same session on every call; the sub-agent accumulates history.
    Persistent,
}
17
18pub struct SubAgentTool {
19    name: &'static str,
20    description: &'static str,
21    sub_runtime: Mutex<AgentRuntime>,
22    sub_session_id: Mutex<Option<SessionId>>,
23    session_policy: SubAgentSessionPolicy,
24}
25
26impl SubAgentTool {
27    pub fn new(
28        name: &'static str,
29        description: &'static str,
30        mut sub_runtime: AgentRuntime,
31    ) -> Self {
32        let sub_session_id = sub_runtime.create_session();
33        Self {
34            name,
35            description,
36            sub_runtime: Mutex::new(sub_runtime),
37            sub_session_id: Mutex::new(Some(sub_session_id)),
38            session_policy: SubAgentSessionPolicy::Ephemeral,
39        }
40    }
41
42    pub fn with_persistent(
43        name: &'static str,
44        description: &'static str,
45        mut sub_runtime: AgentRuntime,
46    ) -> Self {
47        let sub_session_id = sub_runtime.create_session();
48        Self {
49            name,
50            description,
51            sub_runtime: Mutex::new(sub_runtime),
52            sub_session_id: Mutex::new(Some(sub_session_id)),
53            session_policy: SubAgentSessionPolicy::Persistent,
54        }
55    }
56}
57
58#[async_trait]
59impl Tool for SubAgentTool {
60    fn name(&self) -> &'static str {
61        self.name
62    }
63
64    fn definition(&self) -> Value {
65        json!({
66            "type": "function",
67            "function": {
68                "name": self.name,
69                "description": self.description,
70                "parameters": {
71                    "type": "object",
72                    "properties": {
73                        "task": {
74                            "type": "string",
75                            "description": "Task description to delegate to the sub-agent"
76                        }
77                    },
78                    "required": ["task"]
79                }
80            }
81        })
82    }
83
84    async fn call(&self, args: &Value, ctx: &ToolContext) -> AgentResult<ToolOutput> {
85        let task = args
86            .get("task")
87            .and_then(Value::as_str)
88            .ok_or_else(|| AgentError::ToolArgsInvalid {
89                name: self.name.to_string(),
90                raw: args.to_string(),
91            })?;
92
93        if task.is_empty() {
94            return Ok(ToolOutput {
95                summary: "Task description is empty, cannot execute".to_string(),
96                raw: None,
97                control_flow: ToolControlFlow::Break,
98                truncated: false,
99            });
100        }
101
102        let parent_event_bus = ctx.event_bus.clone();
103        let parent_session_id = ctx.session_id.clone();
104
105        let sub_session_id = match self.session_policy {
106            SubAgentSessionPolicy::Ephemeral => {
107                let mut runtime = self.sub_runtime.lock().await;
108                let new_id = runtime.create_session();
109                let mut sid_guard = self.sub_session_id.lock().await;
110                *sid_guard = Some(new_id.clone());
111                new_id
112            }
113            SubAgentSessionPolicy::Persistent => {
114                let sid_guard = self.sub_session_id.lock().await;
115                sid_guard.clone().expect("sub session not initialized")
116            }
117        };
118
119        let (events, _outcome) = {
120            let mut runtime = self.sub_runtime.lock().await;
121            runtime
122                .run_turn_stream(sub_session_id, task)
123                .await
124                .map_err(|e| AgentError::ToolExecution {
125                    name: self.name.to_string(),
126                    source: Box::new(e),
127                })?
128        };
129
130        let mut final_text = String::new();
131        for event in &events {
132            match event {
133                AgentEvent::TextDelta { text, .. } => {
134                    final_text.push_str(text);
135                }
136                _ => {}
137            }
138            let _ = parent_event_bus.send(AgentEvent::Custom {
139                session_id: parent_session_id.clone(),
140                payload: json!({
141                    "type": "subagent_event",
142                    "subagent": self.name,
143                    "event": event_to_value(event),
144                }),
145            });
146        }
147
148        let summary = if final_text.is_empty() {
149            format!("Sub-agent [{}] finished", self.name)
150        } else {
151            final_text
152        };
153
154        Ok(ToolOutput {
155            summary,
156            raw: None,
157            control_flow: ToolControlFlow::Continue,
158            truncated: false,
159        })
160    }
161}
162
163fn event_to_value(event: &AgentEvent) -> Value {
164    match event {
165        AgentEvent::TextDelta { text, .. } => json!({"type": "TextDelta", "text": text}),
166        AgentEvent::ThoughtDelta { text, .. } => json!({"type": "ThoughtDelta", "text": text}),
167        AgentEvent::ToolCallStarted { tool_name, args_json, .. } => {
168            json!({"type": "ToolCallStarted", "tool_name": tool_name, "args_json": args_json})
169        }
170        AgentEvent::ToolCallFinished { tool_name, summary, .. } => {
171            json!({"type": "ToolCallFinished", "tool_name": tool_name, "summary": summary})
172        }
173        AgentEvent::AwaitingApproval { request, .. } => {
174            json!({"type": "AwaitingApproval", "title": request.title})
175        }
176        AgentEvent::Checkpoint { .. } => json!({"type": "Checkpoint"}),
177        AgentEvent::RunFinished { .. } => json!({"type": "RunFinished"}),
178        AgentEvent::Custom { payload, .. } => json!({"type": "Custom", "payload": payload}),
179    }
180}