1use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use anyhow::Result;
8use tokio::sync::RwLock;
9use tokio::task::JoinHandle;
10use tracing::{debug, error, info};
11use uuid::Uuid;
12
13use agent_diva_core::bus::{InboundMessage, MessageBus};
14use agent_diva_core::utils::truncate;
15use agent_diva_providers::base::{LLMProvider, Message};
16use agent_diva_tools::registry::ToolRegistry;
17use agent_diva_tools::{
18 filesystem::{ListDirTool, ReadFileTool, WriteFileTool},
19 shell::ExecTool,
20 web::{WebFetchTool, WebSearchTool},
21};
22
23use crate::tool_config::network::NetworkToolConfig;
24
25pub struct SubagentManager {
31 provider: Arc<dyn LLMProvider>,
32 workspace: PathBuf,
33 bus: MessageBus,
34 model: String,
35 network_config: Arc<RwLock<NetworkToolConfig>>,
36 exec_timeout: u64,
37 restrict_to_workspace: bool,
38 running_tasks: Arc<tokio::sync::Mutex<HashMap<String, JoinHandle<()>>>>,
39}
40
41impl SubagentManager {
42 #[allow(clippy::too_many_arguments)]
44 pub fn new(
45 provider: Arc<dyn LLMProvider>,
46 workspace: PathBuf,
47 bus: MessageBus,
48 model: Option<String>,
49 network_config: NetworkToolConfig,
50 exec_timeout: Option<u64>,
51 restrict_to_workspace: bool,
52 ) -> Self {
53 let model = model.unwrap_or_else(|| provider.get_default_model());
54 let exec_timeout = exec_timeout.unwrap_or(30);
55
56 Self {
57 provider,
58 workspace,
59 bus,
60 model,
61 network_config: Arc::new(RwLock::new(network_config)),
62 exec_timeout,
63 restrict_to_workspace,
64 running_tasks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
65 }
66 }
67
68 pub async fn update_network_config(&self, network_config: NetworkToolConfig) {
69 let mut guard = self.network_config.write().await;
70 *guard = network_config;
71 }
72
73 pub async fn spawn(
84 &self,
85 task: String,
86 label: Option<String>,
87 origin_channel: String,
88 origin_chat_id: String,
89 ) -> Result<String> {
90 let task_id = Uuid::new_v4().to_string()[..8].to_string();
91 let display_label = label.unwrap_or_else(|| {
92 if task.len() > 30 {
93 let mut end = 30;
94 while !task.is_char_boundary(end) {
95 end -= 1;
96 }
97 format!("{}...", &task[..end])
98 } else {
99 task.clone()
100 }
101 });
102
103 let provider = Arc::clone(&self.provider);
104 let workspace = self.workspace.clone();
105 let bus = self.bus.clone();
106 let model = self.model.clone();
107 let network_config = self.network_config.read().await.clone();
108 let exec_timeout = self.exec_timeout;
109 let restrict_to_workspace = self.restrict_to_workspace;
110
111 let task_id_clone = task_id.clone();
112 let display_label_clone = display_label.clone();
113 let running_tasks = Arc::clone(&self.running_tasks);
114
115 let bg_task = tokio::spawn(async move {
117 Self::run_subagent(
118 task_id_clone.clone(),
119 task.clone(),
120 display_label_clone.clone(),
121 origin_channel,
122 origin_chat_id,
123 provider,
124 workspace,
125 bus.clone(),
126 model,
127 network_config,
128 exec_timeout,
129 restrict_to_workspace,
130 )
131 .await;
132
133 let mut tasks = running_tasks.lock().await;
135 tasks.remove(&task_id_clone);
136 });
137
138 let mut tasks = self.running_tasks.lock().await;
140 tasks.insert(task_id.clone(), bg_task);
141 drop(tasks);
142
143 info!("Spawned subagent [{}]: {}", task_id, display_label);
144 Ok(format!(
145 "Subagent [{}] started (id: {}). I'll notify you when it completes.",
146 display_label, task_id
147 ))
148 }
149
150 #[allow(clippy::too_many_arguments)]
152 async fn run_subagent(
153 task_id: String,
154 task: String,
155 label: String,
156 origin_channel: String,
157 origin_chat_id: String,
158 provider: Arc<dyn LLMProvider>,
159 workspace: PathBuf,
160 bus: MessageBus,
161 model: String,
162 network_config: NetworkToolConfig,
163 exec_timeout: u64,
164 restrict_to_workspace: bool,
165 ) {
166 info!("Subagent [{}] starting task: {}", task_id, label);
167
168 let result = Self::execute_subagent_task(
169 &task_id,
170 &task,
171 &provider,
172 &workspace,
173 &model,
174 &network_config,
175 exec_timeout,
176 restrict_to_workspace,
177 )
178 .await;
179
180 let (final_result, status) = match result {
181 Ok(content) => {
182 info!("Subagent [{}] completed successfully", task_id);
183 (content, "ok")
184 }
185 Err(e) => {
186 let error_msg = format!("Error: {}", e);
187 error!("Subagent [{}] failed: {}", task_id, e);
188 (error_msg, "error")
189 }
190 };
191
192 Self::announce_result(
193 &task_id,
194 &label,
195 &task,
196 &final_result,
197 &origin_channel,
198 &origin_chat_id,
199 status,
200 &bus,
201 )
202 .await;
203 }
204
205 #[allow(clippy::too_many_arguments)]
207 async fn execute_subagent_task(
208 task_id: &str,
209 task: &str,
210 provider: &Arc<dyn LLMProvider>,
211 workspace: &Path,
212 model: &str,
213 network_config: &NetworkToolConfig,
214 _exec_timeout: u64,
215 restrict_to_workspace: bool,
216 ) -> Result<String> {
217 let mut tools = ToolRegistry::new();
219 let allowed_dir = if restrict_to_workspace {
220 Some(workspace.to_path_buf())
221 } else {
222 None
223 };
224
225 tools.register(Arc::new(ReadFileTool::new(allowed_dir.clone())));
226 tools.register(Arc::new(WriteFileTool::new(allowed_dir.clone())));
227 tools.register(Arc::new(ListDirTool::new(allowed_dir)));
228 tools.register(Arc::new(ExecTool::new()));
229 if network_config.web.search.enabled {
230 tools.register(Arc::new(WebSearchTool::with_provider_and_max_results(
231 network_config.web.search.provider.clone(),
232 network_config.web.search.api_key.clone(),
233 network_config.web.search.normalized_max_results(),
234 )));
235 }
236 if network_config.web.fetch.enabled {
237 tools.register(Arc::new(WebFetchTool::new()));
238 }
239
240 let system_prompt = Self::build_subagent_prompt(task, workspace);
242 let mut messages = vec![
243 Message::system(system_prompt),
244 Message::user(task.to_string()),
245 ];
246
247 let max_iterations = 15;
249 let mut iteration = 0;
250 let mut final_result: Option<String> = None;
251
252 while iteration < max_iterations {
253 iteration += 1;
254
255 let response = provider
256 .chat(
257 messages.clone(),
258 Some(tools.get_definitions()),
259 Some(model.to_string()),
260 2000,
261 0.7,
262 )
263 .await?;
264
265 if response.has_tool_calls() {
266 messages.push(Message {
268 role: "assistant".to_string(),
269 content: response.content.clone().unwrap_or_default(),
270 name: None,
271 tool_call_id: None,
272 tool_calls: Some(response.tool_calls.clone()),
273 reasoning_content: response.reasoning_content.clone(),
274 thinking_blocks: None,
275 });
276
277 for tool_call in &response.tool_calls {
279 let args_json = serde_json::to_value(&tool_call.arguments)?;
280 let args_str = serde_json::to_string(&tool_call.arguments)?;
281 debug!(
282 "Subagent [{}] executing: {} with arguments: {}",
283 task_id, tool_call.name, args_str
284 );
285 let result = tools.execute(&tool_call.name, args_json).await;
286 messages.push(Message::tool(result, tool_call.id.clone()));
287 }
288 } else {
289 final_result = response.content;
290 break;
291 }
292 }
293
294 Ok(final_result
295 .unwrap_or_else(|| "Task completed but no final response was generated.".to_string()))
296 }
297
298 #[allow(clippy::too_many_arguments)]
300 async fn announce_result(
301 task_id: &str,
302 label: &str,
303 task: &str,
304 result: &str,
305 origin_channel: &str,
306 origin_chat_id: &str,
307 status: &str,
308 bus: &MessageBus,
309 ) {
310 let status_text = if status == "ok" {
311 "completed successfully"
312 } else {
313 "failed"
314 };
315
316 let announce_content = format!(
317 "[Subagent '{}' {}]\n\nTask: {}\n\nResult:\n{}\n\nSummarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like \"subagent\" or task IDs.",
318 label, status_text, task, result
319 );
320
321 let msg = InboundMessage::new(origin_channel, "subagent", origin_chat_id, announce_content);
324
325 if let Err(e) = bus.publish_inbound(msg) {
326 error!("Failed to announce subagent result: {}", e);
327 }
328
329 debug!(
330 "Subagent [{}] announced result to {}:{}",
331 task_id, origin_channel, origin_chat_id
332 );
333 }
334
335 fn build_subagent_prompt(task: &str, workspace: &Path) -> String {
337 let soul_summary = Self::build_identity_summary(workspace);
338 format!(
339 r#"# Subagent
340
341You are a subagent spawned by the main agent to complete a specific task.
342
343## Your Task
344{}
345
346## Inherited Identity
347{}
348
349## Rules
3501. Stay focused - complete only the assigned task, nothing else
3512. Your final response will be reported back to the main agent
3523. Do not initiate conversations or take on side tasks
3534. Be concise but informative in your findings
354
355## What You Can Do
356- Read and write files in the workspace
357- Execute shell commands
358- Search the web and fetch web pages
359- Complete the task thoroughly
360
361## What You Cannot Do
362- Send messages directly to users (no message tool available)
363- Spawn other subagents
364- Access the main agent's conversation history
365
366## Workspace
367Your workspace is at: {}
368
369When you have completed the task, provide a clear summary of your findings or actions."#,
370 task,
371 soul_summary,
372 workspace.display()
373 )
374 }
375
376 fn build_identity_summary(workspace: &Path) -> String {
377 let mut sections = Vec::new();
378 for file in ["SOUL.md", "IDENTITY.md", "USER.md"] {
379 let path = workspace.join(file);
380 let Ok(raw) = std::fs::read_to_string(path) else {
381 continue;
382 };
383 let trimmed = raw.trim();
384 if trimmed.is_empty() {
385 continue;
386 }
387 let content = if trimmed.chars().count() > 800 {
388 truncate(trimmed, 3200)
389 } else {
390 trimmed.to_string()
391 };
392 sections.push(format!("### {}\n{}", file, content));
393 }
394
395 if sections.is_empty() {
396 "No persisted soul identity found. Follow the task faithfully, remain concise, and preserve user intent.".to_string()
397 } else {
398 sections.join("\n\n")
399 }
400 }
401
402 pub async fn get_running_count(&self) -> usize {
404 let tasks = self.running_tasks.lock().await;
405 tasks.len()
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::SubagentManager;

    /// Identity files written into the workspace, paired with their contents.
    const IDENTITY_FIXTURES: [(&str, &str); 3] = [
        ("SOUL.md", "# Soul\n\nKeep concise."),
        ("IDENTITY.md", "# Identity\n\nAgent Diva."),
        ("USER.md", "# User\n\nPrefer direct replies."),
    ];

    /// Builds the prompt for a fixed sample task inside `dir`.
    fn sample_prompt(dir: &std::path::Path) -> String {
        SubagentManager::build_subagent_prompt("analyze logs", dir)
    }

    #[test]
    fn test_build_subagent_prompt_includes_identity_summary() {
        let workspace = tempfile::tempdir().unwrap();
        for &(file_name, body) in IDENTITY_FIXTURES.iter() {
            std::fs::write(workspace.path().join(file_name), body).unwrap();
        }

        let prompt = sample_prompt(workspace.path());
        assert!(prompt.contains("## Inherited Identity"));
        for &(file_name, _) in IDENTITY_FIXTURES.iter() {
            assert!(prompt.contains(&format!("### {}", file_name)));
        }
    }

    #[test]
    fn test_build_subagent_prompt_fallback_without_identity_files() {
        let empty_workspace = tempfile::tempdir().unwrap();
        let prompt = sample_prompt(empty_workspace.path());
        assert!(prompt.contains("No persisted soul identity found"));
    }
}