Skip to main content

agent_diva_agent/
agent_loop.rs

1//! Agent loop: the core processing engine
2
3use agent_diva_core::bus::{AgentEvent, InboundMessage, MessageBus, OutboundMessage};
4use agent_diva_core::config::MCPServerConfig;
5use agent_diva_core::cron::CronService;
6use agent_diva_core::error_context::ErrorContext;
7use agent_diva_core::session::SessionManager;
8use agent_diva_memory::WorkspaceMemoryService;
9use agent_diva_providers::LLMProvider;
10use agent_diva_tools::{
11    load_mcp_tools_sync, CronTool, DiaryListTool, DiaryReadTool, EditFileTool, ExecTool,
12    ListDirTool, MemoryGetTool, MemoryRecallTool, MemorySearchTool, ReadFileTool, SpawnTool,
13    ToolError, ToolRegistry, WriteFileTool,
14};
15use std::collections::{HashMap, HashSet, VecDeque};
16use std::path::PathBuf;
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::sync::mpsc;
20use tracing::{debug, error, info};
21use uuid::Uuid;
22
23use crate::consolidation;
24use crate::context::{ContextBuilder, SoulContextSettings};
25use crate::runtime_control::RuntimeControlCommand;
26use crate::subagent::SubagentManager;
27use crate::tool_config::network::NetworkToolConfig;
28
29mod loop_runtime_control;
30mod loop_tools;
31mod loop_turn;
32
/// Configuration for tool setup.
///
/// Consumed by [`AgentLoop::with_tools`], which uses it to decide which tools
/// to register and how to configure them and the subagent manager.
#[derive(Clone)]
pub struct ToolConfig {
    /// Network tool runtime config (used for web tools and passed to the
    /// subagent manager).
    pub network: NetworkToolConfig,
    /// Shell execution timeout in seconds (used by the exec tool and passed to
    /// the subagent manager).
    pub exec_timeout: u64,
    /// Whether to restrict file access to the workspace. When `true`, the
    /// workspace path is handed to the file-system tools as their allowed
    /// directory, and the flag is also forwarded to the exec tool and the
    /// subagent manager.
    pub restrict_to_workspace: bool,
    /// Configured MCP servers; tools discovered from them are registered at
    /// construction time.
    pub mcp_servers: HashMap<String, MCPServerConfig>,
    /// Optional cron service; the cron tool is only registered when this is set.
    pub cron_service: Option<Arc<CronService>>,
    /// Soul context settings applied to the context builder.
    pub soul_context: SoulContextSettings,
    /// Whether to append transparent notifications on soul updates.
    pub notify_on_soul_change: bool,
    /// Governance behavior for soul evolution transparency.
    pub soul_governance: SoulGovernanceSettings,
}
53
54impl Default for ToolConfig {
55    fn default() -> Self {
56        Self {
57            network: NetworkToolConfig::default(),
58            exec_timeout: 60,
59            restrict_to_workspace: false,
60            mcp_servers: HashMap::new(),
61            cron_service: None,
62            soul_context: SoulContextSettings::default(),
63            notify_on_soul_change: true,
64            soul_governance: SoulGovernanceSettings::default(),
65        }
66    }
67}
68
/// Runtime soft-governance settings for soul evolution.
///
/// `AgentLoop` clamps both `frequent_change_window_secs` and
/// `frequent_change_threshold` to at least 1 when evaluating them, so zero
/// values behave like 1 rather than disabling the check.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SoulGovernanceSettings {
    /// Rolling window, in seconds, over which soul-changing turns are counted
    /// for "frequent changes" hints.
    pub frequent_change_window_secs: u64,
    /// Minimum number of soul-changing turns inside the window needed to
    /// trigger a "frequent changes" hint.
    pub frequent_change_threshold: usize,
    /// Add a confirmation hint when SOUL.md changes.
    pub boundary_confirmation_hint: bool,
}

impl Default for SoulGovernanceSettings {
    /// Ten-minute window, three changes to trigger a hint, and confirmation
    /// hints enabled.
    fn default() -> Self {
        Self {
            frequent_change_window_secs: 600,
            frequent_change_threshold: 3,
            boundary_confirmation_hint: true,
        }
    }
}
89
/// The agent loop is the core processing engine.
///
/// It pulls inbound messages off the [`MessageBus`], runs them through the
/// provider/tool turn logic, and publishes responses back to the bus. An
/// optional runtime-control channel is serviced concurrently in [`AgentLoop::run`].
pub struct AgentLoop {
    /// Bus used to receive inbound messages and publish outbound responses.
    bus: MessageBus,
    /// LLM backend (also shared with the subagent manager).
    provider: Arc<dyn LLMProvider>,
    // NOTE(review): not read anywhere in this file; the allow silences the
    // dead-code lint. Confirm whether it is still needed.
    #[allow(dead_code)]
    workspace: PathBuf,
    // NOTE(review): resolved model name (explicit or provider default); not
    // read in this file, hence the allow. Confirm whether it is still needed.
    #[allow(dead_code)]
    model: String,
    /// Iteration cap for a turn (defaults to 20 when not supplied).
    /// Presumably bounds tool-call rounds per message — confirm in `loop_turn`.
    max_iterations: usize,
    /// History window size; initialised to `consolidation::DEFAULT_MEMORY_WINDOW`.
    memory_window: usize,
    /// Builds prompt context (skills and soul settings) for each turn.
    context: ContextBuilder,
    /// Per-workspace conversation sessions.
    sessions: SessionManager,
    /// Tools exposed to the model (empty when built via [`AgentLoop::new`]).
    tools: ToolRegistry,
    /// Workspace memory backing the memory/diary tools.
    memory_service: Arc<WorkspaceMemoryService>,
    /// Manager for spawned subagents (also captured by the spawn tool).
    subagent_manager: Arc<SubagentManager>,
    /// Optional runtime-control command channel; set to `None` once closed.
    runtime_control_rx: Option<mpsc::UnboundedReceiver<RuntimeControlCommand>>,
    /// Session identifiers flagged as cancelled — presumably populated by
    /// runtime-control commands; TODO confirm in `loop_runtime_control`.
    cancelled_sessions: HashSet<String>,
    /// Whether to append transparent notifications on soul updates.
    notify_on_soul_change: bool,
    /// Soft-governance thresholds for soul evolution hints.
    soul_governance: SoulGovernanceSettings,
    /// Timestamps of recent soul-changing turns, pruned to the governance
    /// window by `is_frequent_soul_change_turn`.
    soul_change_turns: VecDeque<Instant>,
}
111
112impl AgentLoop {
113    /// Create a new agent loop
114    pub fn new(
115        bus: MessageBus,
116        provider: Arc<dyn LLMProvider>,
117        workspace: PathBuf,
118        model: Option<String>,
119        max_iterations: Option<usize>,
120    ) -> Self {
121        let model = model.unwrap_or_else(|| provider.get_default_model());
122        let mut context = ContextBuilder::with_skills(workspace.clone(), None);
123        context.set_soul_settings(SoulContextSettings::default());
124        let sessions = SessionManager::new(workspace.clone());
125        let tools = ToolRegistry::new();
126        let memory_service = Arc::new(WorkspaceMemoryService::new(&workspace));
127
128        let subagent_manager = Arc::new(SubagentManager::new(
129            provider.clone(),
130            workspace.clone(),
131            bus.clone(),
132            Some(model.clone()),
133            NetworkToolConfig::default(),
134            None,
135            false,
136        ));
137
138        Self {
139            bus,
140            provider,
141            workspace,
142            model,
143            max_iterations: max_iterations.unwrap_or(20),
144            memory_window: consolidation::DEFAULT_MEMORY_WINDOW,
145            context,
146            sessions,
147            tools,
148            memory_service,
149            subagent_manager,
150            runtime_control_rx: None,
151            cancelled_sessions: HashSet::new(),
152            notify_on_soul_change: true,
153            soul_governance: SoulGovernanceSettings::default(),
154            soul_change_turns: VecDeque::new(),
155        }
156    }
157
158    /// Create a new agent loop with tool configuration
159    pub fn with_tools(
160        bus: MessageBus,
161        provider: Arc<dyn LLMProvider>,
162        workspace: PathBuf,
163        model: Option<String>,
164        max_iterations: Option<usize>,
165        tool_config: ToolConfig,
166        runtime_control_rx: Option<mpsc::UnboundedReceiver<RuntimeControlCommand>>,
167    ) -> Self {
168        let model = model.unwrap_or_else(|| provider.get_default_model());
169        let mut context = ContextBuilder::with_skills(workspace.clone(), None);
170        context.set_soul_settings(tool_config.soul_context.clone());
171        let sessions = SessionManager::new(workspace.clone());
172        let mut tools = ToolRegistry::new();
173        let memory_service = Arc::new(WorkspaceMemoryService::new(&workspace));
174
175        let subagent_manager = Arc::new(SubagentManager::new(
176            provider.clone(),
177            workspace.clone(),
178            bus.clone(),
179            Some(model.clone()),
180            tool_config.network.clone(),
181            Some(tool_config.exec_timeout),
182            tool_config.restrict_to_workspace,
183        ));
184
185        // Register spawn tool
186        let sm = subagent_manager.clone();
187        tools.register(Arc::new(SpawnTool::new(
188            move |task, label, channel, chat_id| {
189                let sm = sm.clone();
190                async move {
191                    sm.spawn(task, label, channel, chat_id)
192                        .await
193                        .map_err(|e| ToolError::ExecutionFailed(e.to_string()))
194                }
195            },
196        )));
197
198        // Register file system tools
199        let allowed_dir = if tool_config.restrict_to_workspace {
200            Some(workspace.clone())
201        } else {
202            None
203        };
204        tools.register(Arc::new(ReadFileTool::new(allowed_dir.clone())));
205        tools.register(Arc::new(WriteFileTool::new(allowed_dir.clone())));
206        tools.register(Arc::new(EditFileTool::new(allowed_dir.clone())));
207        tools.register(Arc::new(ListDirTool::new(allowed_dir)));
208        tools.register(Arc::new(MemoryRecallTool::new(memory_service.clone())));
209        tools.register(Arc::new(MemorySearchTool::new(memory_service.clone())));
210        tools.register(Arc::new(MemoryGetTool::new(memory_service.clone())));
211        tools.register(Arc::new(DiaryReadTool::new(memory_service.clone())));
212        tools.register(Arc::new(DiaryListTool::new(memory_service.clone())));
213
214        // Register shell tool
215        tools.register(Arc::new(ExecTool::with_config(
216            tool_config.exec_timeout,
217            Some(workspace.clone()),
218            tool_config.restrict_to_workspace,
219        )));
220
221        // Register web tools
222        Self::register_web_tools(&mut tools, &tool_config.network);
223
224        // Register MCP tools discovered from configured servers
225        for mcp_tool in load_mcp_tools_sync(&tool_config.mcp_servers) {
226            tools.register(mcp_tool);
227        }
228
229        // Register cron tool when scheduling is configured
230        if let Some(cron_service) = tool_config.cron_service.clone() {
231            tools.register(Arc::new(CronTool::new(cron_service)));
232        }
233
234        Self {
235            bus,
236            provider,
237            workspace,
238            model,
239            max_iterations: max_iterations.unwrap_or(20),
240            memory_window: consolidation::DEFAULT_MEMORY_WINDOW,
241            context,
242            sessions,
243            tools,
244            memory_service,
245            subagent_manager,
246            runtime_control_rx,
247            cancelled_sessions: HashSet::new(),
248            notify_on_soul_change: tool_config.notify_on_soul_change,
249            soul_governance: tool_config.soul_governance,
250            soul_change_turns: VecDeque::new(),
251        }
252    }
253
254    /// Run the agent loop, processing messages from the bus
255    pub async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
256        info!("Agent loop started");
257
258        // Take the inbound receiver
259        let Some(mut inbound_rx) = self.bus.take_inbound_receiver().await else {
260            error!("Failed to take inbound receiver");
261            return Err("Inbound receiver already taken".into());
262        };
263
264        loop {
265            if let Some(control_rx) = self.runtime_control_rx.as_mut() {
266                tokio::select! {
267                    control = control_rx.recv() => {
268                        match control {
269                            Some(cmd) => self.handle_runtime_control_command(cmd).await,
270                            None => {
271                                info!("Runtime control channel closed");
272                                self.runtime_control_rx = None;
273                            }
274                        }
275                    }
276                    maybe_msg = inbound_rx.recv() => {
277                        match maybe_msg {
278                            Some(msg) => self.handle_inbound(msg).await,
279                            None => {
280                                info!("Message bus closed, stopping agent loop");
281                                break;
282                            }
283                        }
284                    }
285                }
286            } else {
287                match tokio::time::timeout(std::time::Duration::from_secs(1), inbound_rx.recv())
288                    .await
289                {
290                    Ok(Some(msg)) => self.handle_inbound(msg).await,
291                    Ok(None) => {
292                        info!("Message bus closed, stopping agent loop");
293                        break;
294                    }
295                    Err(_) => continue,
296                }
297            }
298        }
299
300        info!("Agent loop stopped");
301        Ok(())
302    }
303
304    async fn handle_inbound(&mut self, msg: InboundMessage) {
305        debug!("Received message from {}:{}", msg.channel, msg.chat_id);
306        let event_msg = msg.clone();
307        match self.process_inbound_message(msg, None).await {
308            Ok(Some(response)) => {
309                if let Err(e) = self.bus.publish_outbound(response) {
310                    error!("Failed to publish response: {}", e);
311                }
312            }
313            Ok(None) => debug!("No response needed"),
314            Err(e) => {
315                let error_message = format!("Failed to process message: {}", e);
316                let ctx = ErrorContext::new("handle_inbound", &error_message)
317                    .with_metadata("channel", event_msg.channel.clone())
318                    .with_metadata("chat_id", event_msg.chat_id.clone())
319                    .with_metadata("sender_id", event_msg.sender_id.clone());
320                error!("{}", ctx.to_detailed_string());
321                self.emit_error_event(&event_msg, None, error_message);
322            }
323        }
324    }
325
326    /// Process a single inbound message
327    pub async fn process_inbound_message(
328        &mut self,
329        msg: InboundMessage,
330        event_tx: Option<&mpsc::UnboundedSender<AgentEvent>>,
331    ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
332        let trace_id = Uuid::new_v4().to_string();
333        use tracing::Instrument;
334        let span = tracing::info_span!("AgentSpan", trace_id = %trace_id);
335
336        self.process_inbound_message_inner(msg, event_tx, trace_id)
337            .instrument(span)
338            .await
339    }
340
341    /// Process a message directly (for CLI or testing)
342    pub async fn process_direct(
343        &mut self,
344        content: impl Into<String>,
345        _session_key: impl Into<String>,
346        channel: impl Into<String>,
347        chat_id: impl Into<String>,
348    ) -> Result<String, Box<dyn std::error::Error>> {
349        let content = content.into();
350        let channel = channel.into();
351        let chat_id = chat_id.into();
352
353        let msg = InboundMessage::new(channel, "user", chat_id, content);
354
355        let response = self.process_inbound_message(msg, None).await?;
356        Ok(response
357            .map(|r| {
358                let content = r.content;
359                if let Some(reasoning) = r.reasoning_content {
360                    if !reasoning.is_empty() {
361                        return format!("<think>\n{}\n</think>\n\n{}", reasoning, content);
362                    }
363                }
364                content
365            })
366            .unwrap_or_default())
367    }
368
369    /// Process a message directly and emit streaming events for UI consumers.
370    pub async fn process_direct_stream(
371        &mut self,
372        content: impl Into<String>,
373        _session_key: impl Into<String>,
374        channel: impl Into<String>,
375        chat_id: impl Into<String>,
376        event_tx: mpsc::UnboundedSender<AgentEvent>,
377    ) -> Result<String, Box<dyn std::error::Error>> {
378        let content = content.into();
379        let channel = channel.into();
380        let chat_id = chat_id.into();
381
382        let msg = InboundMessage::new(channel, "user", chat_id, content);
383
384        match self.process_inbound_message(msg, Some(&event_tx)).await {
385            Ok(response) => Ok(response.map(|r| r.content).unwrap_or_default()),
386            Err(err) => {
387                let _ = event_tx.send(AgentEvent::Error {
388                    message: err.to_string(),
389                });
390                Err(err)
391            }
392        }
393    }
394
395    fn is_frequent_soul_change_turn(&mut self) -> bool {
396        let window = Duration::from_secs(self.soul_governance.frequent_change_window_secs.max(1));
397        let now = Instant::now();
398        self.soul_change_turns.push_back(now);
399        while let Some(front) = self.soul_change_turns.front().copied() {
400            if now.duration_since(front) > window {
401                self.soul_change_turns.pop_front();
402            } else {
403                break;
404            }
405        }
406        self.soul_change_turns.len() >= self.soul_governance.frequent_change_threshold.max(1)
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use agent_diva_providers::{
        LLMResponse, LiteLLMClient, Message, ProviderError, ProviderEventStream, ProviderResult,
    };
    use async_trait::async_trait;
    use futures::stream;
    use tokio::time::{timeout, Duration};

    /// Deterministic provider: `chat` errors, and `chat_stream` yields a single
    /// "simulated stream failure" error. Lets tests exercise failure paths
    /// without any network access.
    struct FailingStreamProvider;

    #[async_trait]
    impl LLMProvider for FailingStreamProvider {
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<LLMResponse> {
            Err(ProviderError::ApiError(
                "chat should not be used".to_string(),
            ))
        }

        async fn chat_stream(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<ProviderEventStream> {
            Ok(Box::pin(stream::iter(vec![Err(ProviderError::ApiError(
                "simulated stream failure".to_string(),
            ))])))
        }

        fn get_default_model(&self) -> String {
            "test-model".to_string()
        }
    }

    #[tokio::test]
    async fn test_agent_loop_creation() {
        let bus = MessageBus::new();
        let provider = Arc::new(LiteLLMClient::default());
        // Use an isolated temp dir rather than a fixed `/tmp/test` path: the
        // constructor touches the workspace, and a hard-coded path is not
        // portable and leaks state between runs.
        let temp_dir = tempfile::tempdir().unwrap();
        let workspace = temp_dir.path().to_path_buf();
        let agent = AgentLoop::new(bus, provider, workspace, None, None);
        assert_eq!(agent.max_iterations, 20);
    }

    #[tokio::test]
    async fn test_process_direct() {
        let bus = MessageBus::new();
        // A deterministic failing provider keeps this test independent of
        // whether some LLM endpoint happens to be reachable.
        let provider = Arc::new(FailingStreamProvider);
        let temp_dir = tempfile::tempdir().unwrap();
        let workspace = temp_dir.path().to_path_buf();

        let mut agent = AgentLoop::new(bus, provider, workspace, None, Some(1));

        let result = agent
            .process_direct("Hello", "cli:test", "cli", "test")
            .await;

        // The provider always fails, so the error must surface to the caller.
        assert!(result.is_err());
    }

    #[test]
    fn test_soul_governance_defaults_are_non_zero() {
        let cfg = SoulGovernanceSettings::default();
        assert!(cfg.frequent_change_window_secs > 0);
        assert!(cfg.frequent_change_threshold > 0);
    }

    #[tokio::test]
    async fn test_handle_inbound_emits_error_event_on_provider_failure() {
        let bus = MessageBus::new();
        let mut event_rx = bus.subscribe_events();
        let provider = Arc::new(FailingStreamProvider);
        let temp_dir = tempfile::tempdir().unwrap();
        let workspace = temp_dir.path().to_path_buf();

        let mut agent = AgentLoop::new(bus.clone(), provider, workspace, None, Some(1));
        let msg = InboundMessage::new("gui", "user", "chat-1", "Hello");

        agent.handle_inbound(msg).await;

        // Skip any non-error events until the first error event arrives.
        let error_event = timeout(Duration::from_secs(1), async {
            loop {
                let bus_event = event_rx.recv().await.unwrap();
                if let AgentEvent::Error { message } = bus_event.event {
                    break (bus_event.channel, bus_event.chat_id, message);
                }
            }
        })
        .await
        .expect("timed out waiting for error event");

        assert_eq!(error_event.0, "gui");
        assert_eq!(error_event.1, "chat-1");
        assert!(error_event.2.contains("simulated stream failure"));
    }
}
517}