Skip to main content

agent_diva_agent/
subagent.rs

1//! Subagent management for background tasks
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6
7use anyhow::Result;
8use tokio::sync::RwLock;
9use tokio::task::JoinHandle;
10use tracing::{debug, error, info};
11use uuid::Uuid;
12
13use agent_diva_core::bus::{InboundMessage, MessageBus};
14use agent_diva_core::utils::truncate;
15use agent_diva_providers::base::{LLMProvider, Message};
16use agent_diva_tools::registry::ToolRegistry;
17use agent_diva_tools::{
18    filesystem::{ListDirTool, ReadFileTool, WriteFileTool},
19    shell::ExecTool,
20    web::{WebFetchTool, WebSearchTool},
21};
22
23use crate::tool_config::network::NetworkToolConfig;
24
25/// Subagent manager for background task execution.
26///
27/// Subagents are lightweight agent instances that run in the background
28/// to handle specific tasks. They share the same LLM provider but have
29/// isolated context and a focused system prompt.
30pub struct SubagentManager {
31    provider: Arc<dyn LLMProvider>,
32    workspace: PathBuf,
33    bus: MessageBus,
34    model: String,
35    network_config: Arc<RwLock<NetworkToolConfig>>,
36    exec_timeout: u64,
37    restrict_to_workspace: bool,
38    running_tasks: Arc<tokio::sync::Mutex<HashMap<String, JoinHandle<()>>>>,
39}
40
41impl SubagentManager {
42    /// Create a new subagent manager
43    #[allow(clippy::too_many_arguments)]
44    pub fn new(
45        provider: Arc<dyn LLMProvider>,
46        workspace: PathBuf,
47        bus: MessageBus,
48        model: Option<String>,
49        network_config: NetworkToolConfig,
50        exec_timeout: Option<u64>,
51        restrict_to_workspace: bool,
52    ) -> Self {
53        let model = model.unwrap_or_else(|| provider.get_default_model());
54        let exec_timeout = exec_timeout.unwrap_or(30);
55
56        Self {
57            provider,
58            workspace,
59            bus,
60            model,
61            network_config: Arc::new(RwLock::new(network_config)),
62            exec_timeout,
63            restrict_to_workspace,
64            running_tasks: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
65        }
66    }
67
68    pub async fn update_network_config(&self, network_config: NetworkToolConfig) {
69        let mut guard = self.network_config.write().await;
70        *guard = network_config;
71    }
72
73    /// Spawn a subagent to execute a task in the background.
74    ///
75    /// # Arguments
76    /// * `task` - The task description for the subagent
77    /// * `label` - Optional human-readable label for the task
78    /// * `origin_channel` - The channel to announce results to
79    /// * `origin_chat_id` - The chat ID to announce results to
80    ///
81    /// # Returns
82    /// Status message indicating the subagent was started
83    pub async fn spawn(
84        &self,
85        task: String,
86        label: Option<String>,
87        origin_channel: String,
88        origin_chat_id: String,
89    ) -> Result<String> {
90        let task_id = Uuid::new_v4().to_string()[..8].to_string();
91        let display_label = label.unwrap_or_else(|| {
92            if task.len() > 30 {
93                let mut end = 30;
94                while !task.is_char_boundary(end) {
95                    end -= 1;
96                }
97                format!("{}...", &task[..end])
98            } else {
99                task.clone()
100            }
101        });
102
103        let provider = Arc::clone(&self.provider);
104        let workspace = self.workspace.clone();
105        let bus = self.bus.clone();
106        let model = self.model.clone();
107        let network_config = self.network_config.read().await.clone();
108        let exec_timeout = self.exec_timeout;
109        let restrict_to_workspace = self.restrict_to_workspace;
110
111        let task_id_clone = task_id.clone();
112        let display_label_clone = display_label.clone();
113        let running_tasks = Arc::clone(&self.running_tasks);
114
115        // Create background task
116        let bg_task = tokio::spawn(async move {
117            Self::run_subagent(
118                task_id_clone.clone(),
119                task.clone(),
120                display_label_clone.clone(),
121                origin_channel,
122                origin_chat_id,
123                provider,
124                workspace,
125                bus.clone(),
126                model,
127                network_config,
128                exec_timeout,
129                restrict_to_workspace,
130            )
131            .await;
132
133            // Cleanup when done
134            let mut tasks = running_tasks.lock().await;
135            tasks.remove(&task_id_clone);
136        });
137
138        // Store the task handle
139        let mut tasks = self.running_tasks.lock().await;
140        tasks.insert(task_id.clone(), bg_task);
141        drop(tasks);
142
143        info!("Spawned subagent [{}]: {}", task_id, display_label);
144        Ok(format!(
145            "Subagent [{}] started (id: {}). I'll notify you when it completes.",
146            display_label, task_id
147        ))
148    }
149
150    /// Execute the subagent task and announce the result
151    #[allow(clippy::too_many_arguments)]
152    async fn run_subagent(
153        task_id: String,
154        task: String,
155        label: String,
156        origin_channel: String,
157        origin_chat_id: String,
158        provider: Arc<dyn LLMProvider>,
159        workspace: PathBuf,
160        bus: MessageBus,
161        model: String,
162        network_config: NetworkToolConfig,
163        exec_timeout: u64,
164        restrict_to_workspace: bool,
165    ) {
166        info!("Subagent [{}] starting task: {}", task_id, label);
167
168        let result = Self::execute_subagent_task(
169            &task_id,
170            &task,
171            &provider,
172            &workspace,
173            &model,
174            &network_config,
175            exec_timeout,
176            restrict_to_workspace,
177        )
178        .await;
179
180        let (final_result, status) = match result {
181            Ok(content) => {
182                info!("Subagent [{}] completed successfully", task_id);
183                (content, "ok")
184            }
185            Err(e) => {
186                let error_msg = format!("Error: {}", e);
187                error!("Subagent [{}] failed: {}", task_id, e);
188                (error_msg, "error")
189            }
190        };
191
192        Self::announce_result(
193            &task_id,
194            &label,
195            &task,
196            &final_result,
197            &origin_channel,
198            &origin_chat_id,
199            status,
200            &bus,
201        )
202        .await;
203    }
204
205    /// Execute the subagent task with LLM and tools
206    #[allow(clippy::too_many_arguments)]
207    async fn execute_subagent_task(
208        task_id: &str,
209        task: &str,
210        provider: &Arc<dyn LLMProvider>,
211        workspace: &Path,
212        model: &str,
213        network_config: &NetworkToolConfig,
214        _exec_timeout: u64,
215        restrict_to_workspace: bool,
216    ) -> Result<String> {
217        // Build subagent tools (no message tool, no spawn tool)
218        let mut tools = ToolRegistry::new();
219        let allowed_dir = if restrict_to_workspace {
220            Some(workspace.to_path_buf())
221        } else {
222            None
223        };
224
225        tools.register(Arc::new(ReadFileTool::new(allowed_dir.clone())));
226        tools.register(Arc::new(WriteFileTool::new(allowed_dir.clone())));
227        tools.register(Arc::new(ListDirTool::new(allowed_dir)));
228        tools.register(Arc::new(ExecTool::new()));
229        if network_config.web.search.enabled {
230            tools.register(Arc::new(WebSearchTool::with_provider_and_max_results(
231                network_config.web.search.provider.clone(),
232                network_config.web.search.api_key.clone(),
233                network_config.web.search.normalized_max_results(),
234            )));
235        }
236        if network_config.web.fetch.enabled {
237            tools.register(Arc::new(WebFetchTool::new()));
238        }
239
240        // Build messages with subagent-specific prompt
241        let system_prompt = Self::build_subagent_prompt(task, workspace);
242        let mut messages = vec![
243            Message::system(system_prompt),
244            Message::user(task.to_string()),
245        ];
246
247        // Run agent loop (limited iterations)
248        let max_iterations = 15;
249        let mut iteration = 0;
250        let mut final_result: Option<String> = None;
251
252        while iteration < max_iterations {
253            iteration += 1;
254
255            let response = provider
256                .chat(
257                    messages.clone(),
258                    Some(tools.get_definitions()),
259                    Some(model.to_string()),
260                    2000,
261                    0.7,
262                )
263                .await?;
264
265            if response.has_tool_calls() {
266                // Add assistant message with tool calls
267                messages.push(Message {
268                    role: "assistant".to_string(),
269                    content: response.content.clone().unwrap_or_default(),
270                    name: None,
271                    tool_call_id: None,
272                    tool_calls: Some(response.tool_calls.clone()),
273                    reasoning_content: response.reasoning_content.clone(),
274                    thinking_blocks: None,
275                });
276
277                // Execute tools
278                for tool_call in &response.tool_calls {
279                    let args_json = serde_json::to_value(&tool_call.arguments)?;
280                    let args_str = serde_json::to_string(&tool_call.arguments)?;
281                    debug!(
282                        "Subagent [{}] executing: {} with arguments: {}",
283                        task_id, tool_call.name, args_str
284                    );
285                    let result = tools.execute(&tool_call.name, args_json).await;
286                    messages.push(Message::tool(result, tool_call.id.clone()));
287                }
288            } else {
289                final_result = response.content;
290                break;
291            }
292        }
293
294        Ok(final_result
295            .unwrap_or_else(|| "Task completed but no final response was generated.".to_string()))
296    }
297
298    /// Announce the subagent result to the main agent via the message bus
299    #[allow(clippy::too_many_arguments)]
300    async fn announce_result(
301        task_id: &str,
302        label: &str,
303        task: &str,
304        result: &str,
305        origin_channel: &str,
306        origin_chat_id: &str,
307        status: &str,
308        bus: &MessageBus,
309    ) {
310        let status_text = if status == "ok" {
311            "completed successfully"
312        } else {
313            "failed"
314        };
315
316        let announce_content = format!(
317            "[Subagent '{}' {}]\n\nTask: {}\n\nResult:\n{}\n\nSummarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like \"subagent\" or task IDs.",
318            label, status_text, task, result
319        );
320
321        // Inject as system message to trigger main agent
322        // Use the origin channel/chat_id directly so the response routes back correctly
323        let msg = InboundMessage::new(origin_channel, "subagent", origin_chat_id, announce_content);
324
325        if let Err(e) = bus.publish_inbound(msg) {
326            error!("Failed to announce subagent result: {}", e);
327        }
328
329        debug!(
330            "Subagent [{}] announced result to {}:{}",
331            task_id, origin_channel, origin_chat_id
332        );
333    }
334
335    /// Build a focused system prompt for the subagent
336    fn build_subagent_prompt(task: &str, workspace: &Path) -> String {
337        let soul_summary = Self::build_identity_summary(workspace);
338        format!(
339            r#"# Subagent
340
341You are a subagent spawned by the main agent to complete a specific task.
342
343## Your Task
344{}
345
346## Inherited Identity
347{}
348
349## Rules
3501. Stay focused - complete only the assigned task, nothing else
3512. Your final response will be reported back to the main agent
3523. Do not initiate conversations or take on side tasks
3534. Be concise but informative in your findings
354
355## What You Can Do
356- Read and write files in the workspace
357- Execute shell commands
358- Search the web and fetch web pages
359- Complete the task thoroughly
360
361## What You Cannot Do
362- Send messages directly to users (no message tool available)
363- Spawn other subagents
364- Access the main agent's conversation history
365
366## Workspace
367Your workspace is at: {}
368
369When you have completed the task, provide a clear summary of your findings or actions."#,
370            task,
371            soul_summary,
372            workspace.display()
373        )
374    }
375
376    fn build_identity_summary(workspace: &Path) -> String {
377        let mut sections = Vec::new();
378        for file in ["SOUL.md", "IDENTITY.md", "USER.md"] {
379            let path = workspace.join(file);
380            let Ok(raw) = std::fs::read_to_string(path) else {
381                continue;
382            };
383            let trimmed = raw.trim();
384            if trimmed.is_empty() {
385                continue;
386            }
387            let content = if trimmed.chars().count() > 800 {
388                truncate(trimmed, 3200)
389            } else {
390                trimmed.to_string()
391            };
392            sections.push(format!("### {}\n{}", file, content));
393        }
394
395        if sections.is_empty() {
396            "No persisted soul identity found. Follow the task faithfully, remain concise, and preserve user intent.".to_string()
397        } else {
398            sections.join("\n\n")
399        }
400    }
401
402    /// Get the number of currently running subagents
403    pub async fn get_running_count(&self) -> usize {
404        let tasks = self.running_tasks.lock().await;
405        tasks.len()
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::SubagentManager;

    /// Write `contents` to `name` inside `dir`, panicking on I/O failure (test helper).
    fn write_file(dir: &std::path::Path, name: &str, contents: &str) {
        std::fs::write(dir.join(name), contents).unwrap();
    }

    #[test]
    fn test_build_subagent_prompt_includes_identity_summary() {
        let temp = tempfile::tempdir().unwrap();
        write_file(temp.path(), "SOUL.md", "# Soul\n\nKeep concise.");
        write_file(temp.path(), "IDENTITY.md", "# Identity\n\nAgent Diva.");
        write_file(temp.path(), "USER.md", "# User\n\nPrefer direct replies.");

        let prompt = SubagentManager::build_subagent_prompt("analyze logs", temp.path());
        assert!(prompt.contains("## Inherited Identity"));
        assert!(prompt.contains("### SOUL.md"));
        assert!(prompt.contains("### IDENTITY.md"));
        assert!(prompt.contains("### USER.md"));
        assert!(!prompt.contains("No persisted soul identity found"));
    }

    #[test]
    fn test_build_subagent_prompt_fallback_without_identity_files() {
        let temp = tempfile::tempdir().unwrap();
        let prompt = SubagentManager::build_subagent_prompt("analyze logs", temp.path());
        assert!(prompt.contains("No persisted soul identity found"));
    }

    #[test]
    fn test_build_subagent_prompt_ignores_blank_identity_files() {
        // Whitespace-only files must be skipped, falling back to the default summary.
        let temp = tempfile::tempdir().unwrap();
        write_file(temp.path(), "SOUL.md", "   \n\t\n");

        let prompt = SubagentManager::build_subagent_prompt("analyze logs", temp.path());
        assert!(!prompt.contains("### SOUL.md"));
        assert!(prompt.contains("No persisted soul identity found"));
    }

    #[test]
    fn test_build_subagent_prompt_includes_task_and_workspace() {
        let temp = tempfile::tempdir().unwrap();
        let prompt = SubagentManager::build_subagent_prompt("analyze logs", temp.path());
        assert!(prompt.contains("## Your Task\nanalyze logs"));
        assert!(prompt.contains(&temp.path().display().to_string()));
    }
}