1use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use anyhow::Result;
8use tokio::sync::RwLock;
9use tokio::task::JoinHandle;
10use tracing::{debug, error, info};
11use uuid::Uuid;
12
13use agent_diva_core::bus::{InboundMessage, MessageBus};
14use agent_diva_core::utils::truncate;
15use agent_diva_providers::base::{LLMProvider, Message};
16use agent_diva_tooling::ToolRegistry;
17
18use crate::tool_assembly::ToolAssembly;
19use crate::tool_config::builtin::BuiltInToolsConfig;
20use crate::tool_config::network::NetworkToolConfig;
21use agent_diva_core::config::MCPServerConfig;
22
23pub struct SubagentManager {
29 provider: Arc<dyn LLMProvider>,
30 workspace: PathBuf,
31 bus: MessageBus,
32 model: String,
33 builtin_tools: BuiltInToolsConfig,
34 network_config: Arc<RwLock<NetworkToolConfig>>,
35 exec_timeout: u64,
36 restrict_to_workspace: bool,
37 mcp_servers: Arc<RwLock<HashMap<String, MCPServerConfig>>>,
38 running_tasks: Arc<tokio::sync::Mutex<HashMap<String, JoinHandle<()>>>>,
39}
40
41impl SubagentManager {
42 #[allow(clippy::too_many_arguments)]
44 pub fn new(
45 provider: Arc<dyn LLMProvider>,
46 workspace: PathBuf,
47 bus: MessageBus,
48 model: Option<String>,
49 builtin_tools: BuiltInToolsConfig,
50 network_config: NetworkToolConfig,
51 exec_timeout: Option<u64>,
52 restrict_to_workspace: bool,
53 mcp_servers: HashMap<String, MCPServerConfig>,
54 ) -> Self {
55 let model = model.unwrap_or_else(|| provider.get_default_model());
56 let exec_timeout = exec_timeout.unwrap_or(30);
57
58 Self {
59 provider,
60 workspace,
61 bus,
62 model,
63 builtin_tools,
64 network_config: Arc::new(RwLock::new(network_config)),
65 exec_timeout,
66 restrict_to_workspace,
67 mcp_servers: Arc::new(RwLock::new(mcp_servers)),
68 running_tasks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
69 }
70 }
71
72 pub async fn update_network_config(&self, network_config: NetworkToolConfig) {
73 let mut guard = self.network_config.write().await;
74 *guard = network_config;
75 }
76
77 pub async fn update_mcp_servers(&self, mcp_servers: HashMap<String, MCPServerConfig>) {
78 let mut guard = self.mcp_servers.write().await;
79 *guard = mcp_servers;
80 }
81
82 pub async fn spawn(
93 &self,
94 task: String,
95 label: Option<String>,
96 origin_channel: String,
97 origin_chat_id: String,
98 ) -> Result<String> {
99 let task_id = Uuid::new_v4().to_string()[..8].to_string();
100 let display_label = label.unwrap_or_else(|| {
101 if task.len() > 30 {
102 let mut end = 30;
103 while !task.is_char_boundary(end) {
104 end -= 1;
105 }
106 format!("{}...", &task[..end])
107 } else {
108 task.clone()
109 }
110 });
111
112 let provider = Arc::clone(&self.provider);
113 let workspace = self.workspace.clone();
114 let bus = self.bus.clone();
115 let model = self.model.clone();
116 let builtin_tools = self.builtin_tools.clone();
117 let network_config = self.network_config.read().await.clone();
118 let exec_timeout = self.exec_timeout;
119 let restrict_to_workspace = self.restrict_to_workspace;
120 let mcp_servers = self.mcp_servers.read().await.clone();
121
122 let task_id_clone = task_id.clone();
123 let display_label_clone = display_label.clone();
124 let running_tasks = Arc::clone(&self.running_tasks);
125
126 let bg_task = tokio::spawn(async move {
128 Self::run_subagent(
129 task_id_clone.clone(),
130 task.clone(),
131 display_label_clone.clone(),
132 origin_channel,
133 origin_chat_id,
134 provider,
135 workspace,
136 bus.clone(),
137 model,
138 builtin_tools,
139 network_config,
140 exec_timeout,
141 restrict_to_workspace,
142 mcp_servers,
143 )
144 .await;
145
146 let mut tasks = running_tasks.lock().await;
148 tasks.remove(&task_id_clone);
149 });
150
151 let mut tasks = self.running_tasks.lock().await;
153 tasks.insert(task_id.clone(), bg_task);
154 drop(tasks);
155
156 info!("Spawned subagent [{}]: {}", task_id, display_label);
157 Ok(format!(
158 "Subagent [{}] started (id: {}). I'll notify you when it completes.",
159 display_label, task_id
160 ))
161 }
162
163 #[allow(clippy::too_many_arguments)]
165 async fn run_subagent(
166 task_id: String,
167 task: String,
168 label: String,
169 origin_channel: String,
170 origin_chat_id: String,
171 provider: Arc<dyn LLMProvider>,
172 workspace: PathBuf,
173 bus: MessageBus,
174 model: String,
175 builtin_tools: BuiltInToolsConfig,
176 network_config: NetworkToolConfig,
177 exec_timeout: u64,
178 restrict_to_workspace: bool,
179 mcp_servers: HashMap<String, MCPServerConfig>,
180 ) {
181 info!("Subagent [{}] starting task: {}", task_id, label);
182
183 let result = Self::execute_subagent_task(
184 &task_id,
185 &task,
186 &provider,
187 &workspace,
188 &model,
189 &builtin_tools,
190 &network_config,
191 exec_timeout,
192 restrict_to_workspace,
193 &mcp_servers,
194 )
195 .await;
196
197 let (final_result, status) = match result {
198 Ok(content) => {
199 info!("Subagent [{}] completed successfully", task_id);
200 (content, "ok")
201 }
202 Err(e) => {
203 let error_msg = format!("Error: {}", e);
204 error!("Subagent [{}] failed: {}", task_id, e);
205 (error_msg, "error")
206 }
207 };
208
209 Self::announce_result(
210 &task_id,
211 &label,
212 &task,
213 &final_result,
214 &origin_channel,
215 &origin_chat_id,
216 status,
217 &bus,
218 )
219 .await;
220 }
221
222 #[allow(clippy::too_many_arguments)]
224 async fn execute_subagent_task(
225 task_id: &str,
226 task: &str,
227 provider: &Arc<dyn LLMProvider>,
228 workspace: &Path,
229 model: &str,
230 builtin_tools: &BuiltInToolsConfig,
231 network_config: &NetworkToolConfig,
232 exec_timeout: u64,
233 restrict_to_workspace: bool,
234 mcp_servers: &HashMap<String, MCPServerConfig>,
235 ) -> Result<String> {
236 let tools: ToolRegistry = ToolAssembly::new(workspace.to_path_buf())
237 .builtin(builtin_tools.clone())
238 .with_network_config(network_config.clone())
239 .with_exec_timeout(exec_timeout)
240 .restrict_to_workspace(restrict_to_workspace)
241 .mcp_servers(mcp_servers.clone())
242 .build_subagent_registry();
243
244 let system_prompt = Self::build_subagent_prompt(task, workspace);
246 let mut messages = vec![
247 Message::system(system_prompt),
248 Message::user(task.to_string()),
249 ];
250
251 let max_iterations = 15;
253 let mut iteration = 0;
254 let mut final_result: Option<String> = None;
255
256 while iteration < max_iterations {
257 iteration += 1;
258
259 let response = provider
260 .chat(
261 messages.clone(),
262 Some(tools.get_definitions()),
263 Some(model.to_string()),
264 2000,
265 0.7,
266 )
267 .await?;
268
269 if response.has_tool_calls() {
270 messages.push(Message {
272 role: "assistant".to_string(),
273 content: response.content.clone().unwrap_or_default(),
274 name: None,
275 tool_call_id: None,
276 tool_calls: Some(response.tool_calls.clone()),
277 reasoning_content: response.reasoning_content.clone(),
278 thinking_blocks: None,
279 });
280
281 for tool_call in &response.tool_calls {
283 let args_json = serde_json::to_value(&tool_call.arguments)?;
284 let args_str = serde_json::to_string(&tool_call.arguments)?;
285 debug!(
286 "Subagent [{}] executing: {} with arguments: {}",
287 task_id, tool_call.name, args_str
288 );
289 let result = tools.execute(&tool_call.name, args_json).await;
290 messages.push(Message::tool(result, tool_call.id.clone()));
291 }
292 } else {
293 final_result = response.content;
294 break;
295 }
296 }
297
298 Ok(final_result
299 .unwrap_or_else(|| "Task completed but no final response was generated.".to_string()))
300 }
301
302 #[allow(clippy::too_many_arguments)]
304 async fn announce_result(
305 task_id: &str,
306 label: &str,
307 task: &str,
308 result: &str,
309 origin_channel: &str,
310 origin_chat_id: &str,
311 status: &str,
312 bus: &MessageBus,
313 ) {
314 let status_text = if status == "ok" {
315 "completed successfully"
316 } else {
317 "failed"
318 };
319
320 let announce_content = format!(
321 "[Subagent '{}' {}]\n\nTask: {}\n\nResult:\n{}\n\nSummarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like \"subagent\" or task IDs.",
322 label, status_text, task, result
323 );
324
325 let msg = InboundMessage::new(origin_channel, "subagent", origin_chat_id, announce_content);
328
329 if let Err(e) = bus.publish_inbound(msg) {
330 error!("Failed to announce subagent result: {}", e);
331 }
332
333 debug!(
334 "Subagent [{}] announced result to {}:{}",
335 task_id, origin_channel, origin_chat_id
336 );
337 }
338
339 fn build_subagent_prompt(task: &str, workspace: &Path) -> String {
341 let soul_summary = Self::build_identity_summary(workspace);
342 format!(
343 r#"# Subagent
344
345You are a subagent spawned by the main agent to complete a specific task.
346
347## Your Task
348{}
349
350## Inherited Identity
351{}
352
353## Rules
3541. Stay focused - complete only the assigned task, nothing else
3552. Your final response will be reported back to the main agent
3563. Do not initiate conversations or take on side tasks
3574. Be concise but informative in your findings
358
359## What You Can Do
360- Read and write files in the workspace
361- Execute shell commands
362- Search the web and fetch web pages
363- Complete the task thoroughly
364
365## What You Cannot Do
366- Send messages directly to users (no message tool available)
367- Spawn other subagents
368- Access the main agent's conversation history
369
370## Workspace
371Your workspace is at: {}
372
373When you have completed the task, provide a clear summary of your findings or actions."#,
374 task,
375 soul_summary,
376 workspace.display()
377 )
378 }
379
380 fn build_identity_summary(workspace: &Path) -> String {
381 let mut sections = Vec::new();
382 for file in ["SOUL.md", "IDENTITY.md", "USER.md"] {
383 let path = workspace.join(file);
384 let Ok(raw) = std::fs::read_to_string(path) else {
385 continue;
386 };
387 let trimmed = raw.trim();
388 if trimmed.is_empty() {
389 continue;
390 }
391 let content = if trimmed.chars().count() > 800 {
392 truncate(trimmed, 3200)
393 } else {
394 trimmed.to_string()
395 };
396 sections.push(format!("### {}\n{}", file, content));
397 }
398
399 if sections.is_empty() {
400 "No persisted soul identity found. Follow the task faithfully, remain concise, and preserve user intent.".to_string()
401 } else {
402 sections.join("\n\n")
403 }
404 }
405
406 pub async fn get_running_count(&self) -> usize {
408 let tasks = self.running_tasks.lock().await;
409 tasks.len()
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::SubagentManager;

    /// Task text shared by the prompt tests; its content does not matter here.
    const TASK: &str = "analyze logs";

    /// With all three identity files present, each gets its own section under the
    /// inherited-identity heading.
    #[test]
    fn test_build_subagent_prompt_includes_identity_summary() {
        let workspace = tempfile::tempdir().unwrap();
        let fixtures = [
            ("SOUL.md", "# Soul\n\nKeep concise."),
            ("IDENTITY.md", "# Identity\n\nAgent Diva."),
            ("USER.md", "# User\n\nPrefer direct replies."),
        ];
        for (name, body) in fixtures {
            std::fs::write(workspace.path().join(name), body).unwrap();
        }

        let prompt = SubagentManager::build_subagent_prompt(TASK, workspace.path());

        assert!(prompt.contains("## Inherited Identity"));
        for (name, _) in fixtures {
            assert!(prompt.contains(&format!("### {}", name)));
        }
    }

    /// An empty workspace yields the fallback identity text.
    #[test]
    fn test_build_subagent_prompt_fallback_without_identity_files() {
        let workspace = tempfile::tempdir().unwrap();

        let prompt = SubagentManager::build_subagent_prompt(TASK, workspace.path());

        assert!(prompt.contains("No persisted soul identity found"));
    }
}