Skip to main content

agent_diva_agent/
subagent.rs

1//! Subagent management for background tasks
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use anyhow::Result;
8use tokio::sync::RwLock;
9use tokio::task::JoinHandle;
10use tracing::{debug, error, info};
11use uuid::Uuid;
12
13use agent_diva_core::bus::{InboundMessage, MessageBus};
14use agent_diva_core::utils::truncate;
15use agent_diva_providers::base::{LLMProvider, Message};
16use agent_diva_tooling::ToolRegistry;
17
18use crate::tool_assembly::ToolAssembly;
19use crate::tool_config::builtin::BuiltInToolsConfig;
20use crate::tool_config::network::NetworkToolConfig;
21use agent_diva_core::config::MCPServerConfig;
22
23/// Subagent manager for background task execution.
24///
25/// Subagents are lightweight agent instances that run in the background
26/// to handle specific tasks. They share the same LLM provider but have
27/// isolated context and a focused system prompt.
28pub struct SubagentManager {
29    provider: Arc<dyn LLMProvider>,
30    workspace: PathBuf,
31    bus: MessageBus,
32    model: String,
33    builtin_tools: BuiltInToolsConfig,
34    network_config: Arc<RwLock<NetworkToolConfig>>,
35    exec_timeout: u64,
36    restrict_to_workspace: bool,
37    mcp_servers: Arc<RwLock<HashMap<String, MCPServerConfig>>>,
38    running_tasks: Arc<tokio::sync::Mutex<HashMap<String, JoinHandle<()>>>>,
39}
40
41impl SubagentManager {
42    /// Create a new subagent manager
43    #[allow(clippy::too_many_arguments)]
44    pub fn new(
45        provider: Arc<dyn LLMProvider>,
46        workspace: PathBuf,
47        bus: MessageBus,
48        model: Option<String>,
49        builtin_tools: BuiltInToolsConfig,
50        network_config: NetworkToolConfig,
51        exec_timeout: Option<u64>,
52        restrict_to_workspace: bool,
53        mcp_servers: HashMap<String, MCPServerConfig>,
54    ) -> Self {
55        let model = model.unwrap_or_else(|| provider.get_default_model());
56        let exec_timeout = exec_timeout.unwrap_or(30);
57
58        Self {
59            provider,
60            workspace,
61            bus,
62            model,
63            builtin_tools,
64            network_config: Arc::new(RwLock::new(network_config)),
65            exec_timeout,
66            restrict_to_workspace,
67            mcp_servers: Arc::new(RwLock::new(mcp_servers)),
68            running_tasks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
69        }
70    }
71
72    pub async fn update_network_config(&self, network_config: NetworkToolConfig) {
73        let mut guard = self.network_config.write().await;
74        *guard = network_config;
75    }
76
77    pub async fn update_mcp_servers(&self, mcp_servers: HashMap<String, MCPServerConfig>) {
78        let mut guard = self.mcp_servers.write().await;
79        *guard = mcp_servers;
80    }
81
82    /// Spawn a subagent to execute a task in the background.
83    ///
84    /// # Arguments
85    /// * `task` - The task description for the subagent
86    /// * `label` - Optional human-readable label for the task
87    /// * `origin_channel` - The channel to announce results to
88    /// * `origin_chat_id` - The chat ID to announce results to
89    ///
90    /// # Returns
91    /// Status message indicating the subagent was started
92    pub async fn spawn(
93        &self,
94        task: String,
95        label: Option<String>,
96        origin_channel: String,
97        origin_chat_id: String,
98    ) -> Result<String> {
99        let task_id = Uuid::new_v4().to_string()[..8].to_string();
100        let display_label = label.unwrap_or_else(|| {
101            if task.len() > 30 {
102                let mut end = 30;
103                while !task.is_char_boundary(end) {
104                    end -= 1;
105                }
106                format!("{}...", &task[..end])
107            } else {
108                task.clone()
109            }
110        });
111
112        let provider = Arc::clone(&self.provider);
113        let workspace = self.workspace.clone();
114        let bus = self.bus.clone();
115        let model = self.model.clone();
116        let builtin_tools = self.builtin_tools.clone();
117        let network_config = self.network_config.read().await.clone();
118        let exec_timeout = self.exec_timeout;
119        let restrict_to_workspace = self.restrict_to_workspace;
120        let mcp_servers = self.mcp_servers.read().await.clone();
121
122        let task_id_clone = task_id.clone();
123        let display_label_clone = display_label.clone();
124        let running_tasks = Arc::clone(&self.running_tasks);
125
126        // Create background task
127        let bg_task = tokio::spawn(async move {
128            Self::run_subagent(
129                task_id_clone.clone(),
130                task.clone(),
131                display_label_clone.clone(),
132                origin_channel,
133                origin_chat_id,
134                provider,
135                workspace,
136                bus.clone(),
137                model,
138                builtin_tools,
139                network_config,
140                exec_timeout,
141                restrict_to_workspace,
142                mcp_servers,
143            )
144            .await;
145
146            // Cleanup when done
147            let mut tasks = running_tasks.lock().await;
148            tasks.remove(&task_id_clone);
149        });
150
151        // Store the task handle
152        let mut tasks = self.running_tasks.lock().await;
153        tasks.insert(task_id.clone(), bg_task);
154        drop(tasks);
155
156        info!("Spawned subagent [{}]: {}", task_id, display_label);
157        Ok(format!(
158            "Subagent [{}] started (id: {}). I'll notify you when it completes.",
159            display_label, task_id
160        ))
161    }
162
163    /// Execute the subagent task and announce the result
164    #[allow(clippy::too_many_arguments)]
165    async fn run_subagent(
166        task_id: String,
167        task: String,
168        label: String,
169        origin_channel: String,
170        origin_chat_id: String,
171        provider: Arc<dyn LLMProvider>,
172        workspace: PathBuf,
173        bus: MessageBus,
174        model: String,
175        builtin_tools: BuiltInToolsConfig,
176        network_config: NetworkToolConfig,
177        exec_timeout: u64,
178        restrict_to_workspace: bool,
179        mcp_servers: HashMap<String, MCPServerConfig>,
180    ) {
181        info!("Subagent [{}] starting task: {}", task_id, label);
182
183        let result = Self::execute_subagent_task(
184            &task_id,
185            &task,
186            &provider,
187            &workspace,
188            &model,
189            &builtin_tools,
190            &network_config,
191            exec_timeout,
192            restrict_to_workspace,
193            &mcp_servers,
194        )
195        .await;
196
197        let (final_result, status) = match result {
198            Ok(content) => {
199                info!("Subagent [{}] completed successfully", task_id);
200                (content, "ok")
201            }
202            Err(e) => {
203                let error_msg = format!("Error: {}", e);
204                error!("Subagent [{}] failed: {}", task_id, e);
205                (error_msg, "error")
206            }
207        };
208
209        Self::announce_result(
210            &task_id,
211            &label,
212            &task,
213            &final_result,
214            &origin_channel,
215            &origin_chat_id,
216            status,
217            &bus,
218        )
219        .await;
220    }
221
222    /// Execute the subagent task with LLM and tools
223    #[allow(clippy::too_many_arguments)]
224    async fn execute_subagent_task(
225        task_id: &str,
226        task: &str,
227        provider: &Arc<dyn LLMProvider>,
228        workspace: &Path,
229        model: &str,
230        builtin_tools: &BuiltInToolsConfig,
231        network_config: &NetworkToolConfig,
232        exec_timeout: u64,
233        restrict_to_workspace: bool,
234        mcp_servers: &HashMap<String, MCPServerConfig>,
235    ) -> Result<String> {
236        let tools: ToolRegistry = ToolAssembly::new(workspace.to_path_buf())
237            .builtin(builtin_tools.clone())
238            .with_network_config(network_config.clone())
239            .with_exec_timeout(exec_timeout)
240            .restrict_to_workspace(restrict_to_workspace)
241            .mcp_servers(mcp_servers.clone())
242            .build_subagent_registry();
243
244        // Build messages with subagent-specific prompt
245        let system_prompt = Self::build_subagent_prompt(task, workspace);
246        let mut messages = vec![
247            Message::system(system_prompt),
248            Message::user(task.to_string()),
249        ];
250
251        // Run agent loop (limited iterations)
252        let max_iterations = 15;
253        let mut iteration = 0;
254        let mut final_result: Option<String> = None;
255
256        while iteration < max_iterations {
257            iteration += 1;
258
259            let response = provider
260                .chat(
261                    messages.clone(),
262                    Some(tools.get_definitions()),
263                    Some(model.to_string()),
264                    2000,
265                    0.7,
266                )
267                .await?;
268
269            if response.has_tool_calls() {
270                // Add assistant message with tool calls
271                messages.push(Message {
272                    role: "assistant".to_string(),
273                    content: response.content.clone().unwrap_or_default(),
274                    name: None,
275                    tool_call_id: None,
276                    tool_calls: Some(response.tool_calls.clone()),
277                    reasoning_content: response.reasoning_content.clone(),
278                    thinking_blocks: None,
279                });
280
281                // Execute tools
282                for tool_call in &response.tool_calls {
283                    let args_json = serde_json::to_value(&tool_call.arguments)?;
284                    let args_str = serde_json::to_string(&tool_call.arguments)?;
285                    debug!(
286                        "Subagent [{}] executing: {} with arguments: {}",
287                        task_id, tool_call.name, args_str
288                    );
289                    let result = tools.execute(&tool_call.name, args_json).await;
290                    messages.push(Message::tool(result, tool_call.id.clone()));
291                }
292            } else {
293                final_result = response.content;
294                break;
295            }
296        }
297
298        Ok(final_result
299            .unwrap_or_else(|| "Task completed but no final response was generated.".to_string()))
300    }
301
302    /// Announce the subagent result to the main agent via the message bus
303    #[allow(clippy::too_many_arguments)]
304    async fn announce_result(
305        task_id: &str,
306        label: &str,
307        task: &str,
308        result: &str,
309        origin_channel: &str,
310        origin_chat_id: &str,
311        status: &str,
312        bus: &MessageBus,
313    ) {
314        let status_text = if status == "ok" {
315            "completed successfully"
316        } else {
317            "failed"
318        };
319
320        let announce_content = format!(
321            "[Subagent '{}' {}]\n\nTask: {}\n\nResult:\n{}\n\nSummarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like \"subagent\" or task IDs.",
322            label, status_text, task, result
323        );
324
325        // Inject as system message to trigger main agent
326        // Use the origin channel/chat_id directly so the response routes back correctly
327        let msg = InboundMessage::new(origin_channel, "subagent", origin_chat_id, announce_content);
328
329        if let Err(e) = bus.publish_inbound(msg) {
330            error!("Failed to announce subagent result: {}", e);
331        }
332
333        debug!(
334            "Subagent [{}] announced result to {}:{}",
335            task_id, origin_channel, origin_chat_id
336        );
337    }
338
339    /// Build a focused system prompt for the subagent
340    fn build_subagent_prompt(task: &str, workspace: &Path) -> String {
341        let soul_summary = Self::build_identity_summary(workspace);
342        format!(
343            r#"# Subagent
344
345You are a subagent spawned by the main agent to complete a specific task.
346
347## Your Task
348{}
349
350## Inherited Identity
351{}
352
353## Rules
3541. Stay focused - complete only the assigned task, nothing else
3552. Your final response will be reported back to the main agent
3563. Do not initiate conversations or take on side tasks
3574. Be concise but informative in your findings
358
359## What You Can Do
360- Read and write files in the workspace
361- Execute shell commands
362- Search the web and fetch web pages
363- Complete the task thoroughly
364
365## What You Cannot Do
366- Send messages directly to users (no message tool available)
367- Spawn other subagents
368- Access the main agent's conversation history
369
370## Workspace
371Your workspace is at: {}
372
373When you have completed the task, provide a clear summary of your findings or actions."#,
374            task,
375            soul_summary,
376            workspace.display()
377        )
378    }
379
380    fn build_identity_summary(workspace: &Path) -> String {
381        let mut sections = Vec::new();
382        for file in ["SOUL.md", "IDENTITY.md", "USER.md"] {
383            let path = workspace.join(file);
384            let Ok(raw) = std::fs::read_to_string(path) else {
385                continue;
386            };
387            let trimmed = raw.trim();
388            if trimmed.is_empty() {
389                continue;
390            }
391            let content = if trimmed.chars().count() > 800 {
392                truncate(trimmed, 3200)
393            } else {
394                trimmed.to_string()
395            };
396            sections.push(format!("### {}\n{}", file, content));
397        }
398
399        if sections.is_empty() {
400            "No persisted soul identity found. Follow the task faithfully, remain concise, and preserve user intent.".to_string()
401        } else {
402            sections.join("\n\n")
403        }
404    }
405
406    /// Get the number of currently running subagents
407    pub async fn get_running_count(&self) -> usize {
408        let tasks = self.running_tasks.lock().await;
409        tasks.len()
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::SubagentManager;

    /// Identity files seeded into the temporary workspace, as (file name, contents).
    const IDENTITY_FILES: [(&str, &str); 3] = [
        ("SOUL.md", "# Soul\n\nKeep concise."),
        ("IDENTITY.md", "# Identity\n\nAgent Diva."),
        ("USER.md", "# User\n\nPrefer direct replies."),
    ];

    #[test]
    fn test_build_subagent_prompt_includes_identity_summary() {
        let workspace = tempfile::tempdir().unwrap();
        for (name, body) in IDENTITY_FILES {
            std::fs::write(workspace.path().join(name), body).unwrap();
        }

        let prompt = SubagentManager::build_subagent_prompt("analyze logs", workspace.path());

        assert!(prompt.contains("## Inherited Identity"));
        for (name, _) in IDENTITY_FILES {
            let heading = format!("### {}", name);
            assert!(prompt.contains(&heading), "prompt is missing {}", heading);
        }
    }

    #[test]
    fn test_build_subagent_prompt_fallback_without_identity_files() {
        let empty_workspace = tempfile::tempdir().unwrap();
        let prompt =
            SubagentManager::build_subagent_prompt("analyze logs", empty_workspace.path());
        assert!(prompt.contains("No persisted soul identity found"));
    }
}