Skip to main content

agent_diva_agent/
agent_loop.rs

1//! Agent loop: the core processing engine
2
3use agent_diva_core::bus::{AgentEvent, InboundMessage, MessageBus, OutboundMessage};
4use agent_diva_core::config::MCPServerConfig;
5use agent_diva_core::cron::CronService;
6use agent_diva_core::error_context::ErrorContext;
7use agent_diva_core::session::SessionManager;
8use agent_diva_providers::LLMProvider;
9use agent_diva_tools::{
10    load_mcp_tools_sync, CronTool, EditFileTool, ExecTool, ListDirTool, ReadFileTool, SpawnTool,
11    ToolError, ToolRegistry, WriteFileTool,
12};
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::path::PathBuf;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::sync::mpsc;
18use tracing::{debug, error, info};
19use uuid::Uuid;
20
21use crate::consolidation;
22use crate::context::{ContextBuilder, SoulContextSettings};
23use crate::runtime_control::RuntimeControlCommand;
24use crate::subagent::SubagentManager;
25use crate::tool_config::network::NetworkToolConfig;
26
27mod loop_runtime_control;
28mod loop_tools;
29mod loop_turn;
30
/// Configuration for tool setup.
///
/// Consumed by [`AgentLoop::with_tools`], which reads these fields to decide
/// which tools get registered and how each one is configured.
#[derive(Clone)]
pub struct ToolConfig {
    /// Network tool runtime config.
    ///
    /// Used when registering the web tools and also forwarded to the
    /// sub-agent manager.
    pub network: NetworkToolConfig,
    /// Shell execution timeout in seconds.
    ///
    /// Handed to the shell tool and, wrapped in `Some`, to the sub-agent manager.
    pub exec_timeout: u64,
    /// Whether to restrict file access to workspace.
    ///
    /// When `true`, the file tools receive the workspace path as their allowed
    /// directory; when `false` they receive `None`. The flag is also forwarded
    /// to the shell tool and the sub-agent manager.
    pub restrict_to_workspace: bool,
    /// Configured MCP servers.
    ///
    /// Every tool discovered from these servers is registered at construction.
    // NOTE(review): the map key is presumably the server name — confirm against the config loader.
    pub mcp_servers: HashMap<String, MCPServerConfig>,
    /// Optional cron service for scheduling tools.
    ///
    /// The cron tool is registered only when this is `Some`.
    pub cron_service: Option<Arc<CronService>>,
    /// Soul context settings, applied to the context builder.
    pub soul_context: SoulContextSettings,
    /// Whether to append transparent notifications on soul updates
    pub notify_on_soul_change: bool,
    /// Governance behavior for soul evolution transparency
    pub soul_governance: SoulGovernanceSettings,
}
51
52impl Default for ToolConfig {
53    fn default() -> Self {
54        Self {
55            network: NetworkToolConfig::default(),
56            exec_timeout: 60,
57            restrict_to_workspace: false,
58            mcp_servers: HashMap::new(),
59            cron_service: None,
60            soul_context: SoulContextSettings::default(),
61            notify_on_soul_change: true,
62            soul_governance: SoulGovernanceSettings::default(),
63        }
64    }
65}
66
/// Soft-governance knobs applied at runtime to soul evolution.
#[derive(Debug, Clone)]
pub struct SoulGovernanceSettings {
    /// Length, in seconds, of the rolling window used for "frequent changes" hints.
    pub frequent_change_window_secs: u64,
    /// How many soul-changing turns must fall inside the window before hints fire.
    pub frequent_change_threshold: usize,
    /// When set, a confirmation hint is added whenever SOUL.md changes.
    pub boundary_confirmation_hint: bool,
}

impl Default for SoulGovernanceSettings {
    /// Ten-minute window, three changes to trigger a hint, confirmation hint on.
    fn default() -> Self {
        SoulGovernanceSettings {
            boundary_confirmation_hint: true,
            frequent_change_threshold: 3,
            // Ten minutes, expressed in seconds.
            frequent_change_window_secs: 10 * 60,
        }
    }
}
87
/// The agent loop is the core processing engine.
///
/// It owns the message-bus handle, the LLM provider, the tool registry and the
/// per-conversation state. [`AgentLoop::run`] handles inbound messages one at
/// a time: each message is fully processed before the next one is received.
pub struct AgentLoop {
    /// Bus the loop takes inbound messages from and publishes responses to.
    bus: MessageBus,
    /// LLM backend; also supplies the default model name when none is given.
    provider: Arc<dyn LLMProvider>,
    // NOTE(review): the lint allowance suggests nothing reads this field yet — confirm before removing.
    #[allow(dead_code)]
    workspace: PathBuf,
    // NOTE(review): same as `workspace` — stored at construction, apparently unused so far.
    #[allow(dead_code)]
    model: String,
    /// Iteration limit supplied by the caller; the constructors default it to 20.
    max_iterations: usize,
    /// Initialised from `consolidation::DEFAULT_MEMORY_WINDOW`.
    memory_window: usize,
    /// Context builder configured with the soul-context settings.
    context: ContextBuilder,
    /// Session manager rooted at the workspace path.
    sessions: SessionManager,
    /// Registry of the tools made available to the agent.
    tools: ToolRegistry,
    /// Shared sub-agent manager; the spawn tool holds another handle to it.
    subagent_manager: Arc<SubagentManager>,
    /// Optional runtime-control channel; reset to `None` once the sender side closes.
    runtime_control_rx: Option<mpsc::UnboundedReceiver<RuntimeControlCommand>>,
    // NOTE(review): presumably session keys cancelled via runtime control — populated outside this view.
    cancelled_sessions: HashSet<String>,
    /// Copied from [`ToolConfig::notify_on_soul_change`] (defaults to `true`).
    notify_on_soul_change: bool,
    /// Window and threshold used by the frequent-soul-change check.
    soul_governance: SoulGovernanceSettings,
    /// Timestamps of recent soul-changing turns, oldest first.
    soul_change_turns: VecDeque<Instant>,
}
108
109impl AgentLoop {
110    /// Create a new agent loop
111    pub fn new(
112        bus: MessageBus,
113        provider: Arc<dyn LLMProvider>,
114        workspace: PathBuf,
115        model: Option<String>,
116        max_iterations: Option<usize>,
117    ) -> Self {
118        let model = model.unwrap_or_else(|| provider.get_default_model());
119        let mut context = ContextBuilder::with_skills(workspace.clone(), None);
120        context.set_soul_settings(SoulContextSettings::default());
121        let sessions = SessionManager::new(workspace.clone());
122        let tools = ToolRegistry::new();
123
124        let subagent_manager = Arc::new(SubagentManager::new(
125            provider.clone(),
126            workspace.clone(),
127            bus.clone(),
128            Some(model.clone()),
129            NetworkToolConfig::default(),
130            None,
131            false,
132        ));
133
134        Self {
135            bus,
136            provider,
137            workspace,
138            model,
139            max_iterations: max_iterations.unwrap_or(20),
140            memory_window: consolidation::DEFAULT_MEMORY_WINDOW,
141            context,
142            sessions,
143            tools,
144            subagent_manager,
145            runtime_control_rx: None,
146            cancelled_sessions: HashSet::new(),
147            notify_on_soul_change: true,
148            soul_governance: SoulGovernanceSettings::default(),
149            soul_change_turns: VecDeque::new(),
150        }
151    }
152
153    /// Create a new agent loop with tool configuration
154    pub fn with_tools(
155        bus: MessageBus,
156        provider: Arc<dyn LLMProvider>,
157        workspace: PathBuf,
158        model: Option<String>,
159        max_iterations: Option<usize>,
160        tool_config: ToolConfig,
161        runtime_control_rx: Option<mpsc::UnboundedReceiver<RuntimeControlCommand>>,
162    ) -> Self {
163        let model = model.unwrap_or_else(|| provider.get_default_model());
164        let mut context = ContextBuilder::with_skills(workspace.clone(), None);
165        context.set_soul_settings(tool_config.soul_context.clone());
166        let sessions = SessionManager::new(workspace.clone());
167        let mut tools = ToolRegistry::new();
168
169        let subagent_manager = Arc::new(SubagentManager::new(
170            provider.clone(),
171            workspace.clone(),
172            bus.clone(),
173            Some(model.clone()),
174            tool_config.network.clone(),
175            Some(tool_config.exec_timeout),
176            tool_config.restrict_to_workspace,
177        ));
178
179        // Register spawn tool
180        let sm = subagent_manager.clone();
181        tools.register(Arc::new(SpawnTool::new(
182            move |task, label, channel, chat_id| {
183                let sm = sm.clone();
184                async move {
185                    sm.spawn(task, label, channel, chat_id)
186                        .await
187                        .map_err(|e| ToolError::ExecutionFailed(e.to_string()))
188                }
189            },
190        )));
191
192        // Register file system tools
193        let allowed_dir = if tool_config.restrict_to_workspace {
194            Some(workspace.clone())
195        } else {
196            None
197        };
198        tools.register(Arc::new(ReadFileTool::new(allowed_dir.clone())));
199        tools.register(Arc::new(WriteFileTool::new(allowed_dir.clone())));
200        tools.register(Arc::new(EditFileTool::new(allowed_dir.clone())));
201        tools.register(Arc::new(ListDirTool::new(allowed_dir)));
202
203        // Register shell tool
204        tools.register(Arc::new(ExecTool::with_config(
205            tool_config.exec_timeout,
206            Some(workspace.clone()),
207            tool_config.restrict_to_workspace,
208        )));
209
210        // Register web tools
211        Self::register_web_tools(&mut tools, &tool_config.network);
212
213        // Register MCP tools discovered from configured servers
214        for mcp_tool in load_mcp_tools_sync(&tool_config.mcp_servers) {
215            tools.register(mcp_tool);
216        }
217
218        // Register cron tool when scheduling is configured
219        if let Some(cron_service) = tool_config.cron_service.clone() {
220            tools.register(Arc::new(CronTool::new(cron_service)));
221        }
222
223        Self {
224            bus,
225            provider,
226            workspace,
227            model,
228            max_iterations: max_iterations.unwrap_or(20),
229            memory_window: consolidation::DEFAULT_MEMORY_WINDOW,
230            context,
231            sessions,
232            tools,
233            subagent_manager,
234            runtime_control_rx,
235            cancelled_sessions: HashSet::new(),
236            notify_on_soul_change: tool_config.notify_on_soul_change,
237            soul_governance: tool_config.soul_governance,
238            soul_change_turns: VecDeque::new(),
239        }
240    }
241
242    /// Run the agent loop, processing messages from the bus
243    pub async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
244        info!("Agent loop started");
245
246        // Take the inbound receiver
247        let Some(mut inbound_rx) = self.bus.take_inbound_receiver().await else {
248            error!("Failed to take inbound receiver");
249            return Err("Inbound receiver already taken".into());
250        };
251
252        loop {
253            if let Some(control_rx) = self.runtime_control_rx.as_mut() {
254                tokio::select! {
255                    control = control_rx.recv() => {
256                        match control {
257                            Some(cmd) => self.handle_runtime_control_command(cmd).await,
258                            None => {
259                                info!("Runtime control channel closed");
260                                self.runtime_control_rx = None;
261                            }
262                        }
263                    }
264                    maybe_msg = inbound_rx.recv() => {
265                        match maybe_msg {
266                            Some(msg) => self.handle_inbound(msg).await,
267                            None => {
268                                info!("Message bus closed, stopping agent loop");
269                                break;
270                            }
271                        }
272                    }
273                }
274            } else {
275                match tokio::time::timeout(std::time::Duration::from_secs(1), inbound_rx.recv())
276                    .await
277                {
278                    Ok(Some(msg)) => self.handle_inbound(msg).await,
279                    Ok(None) => {
280                        info!("Message bus closed, stopping agent loop");
281                        break;
282                    }
283                    Err(_) => continue,
284                }
285            }
286        }
287
288        info!("Agent loop stopped");
289        Ok(())
290    }
291
292    async fn handle_inbound(&mut self, msg: InboundMessage) {
293        debug!("Received message from {}:{}", msg.channel, msg.chat_id);
294        let event_msg = msg.clone();
295        match self.process_inbound_message(msg, None).await {
296            Ok(Some(response)) => {
297                if let Err(e) = self.bus.publish_outbound(response) {
298                    error!("Failed to publish response: {}", e);
299                }
300            }
301            Ok(None) => debug!("No response needed"),
302            Err(e) => {
303                let error_message = format!("Failed to process message: {}", e);
304                let ctx = ErrorContext::new("handle_inbound", &error_message)
305                    .with_metadata("channel", event_msg.channel.clone())
306                    .with_metadata("chat_id", event_msg.chat_id.clone())
307                    .with_metadata("sender_id", event_msg.sender_id.clone());
308                error!("{}", ctx.to_detailed_string());
309                self.emit_error_event(&event_msg, None, error_message);
310            }
311        }
312    }
313
314    /// Process a single inbound message
315    pub async fn process_inbound_message(
316        &mut self,
317        msg: InboundMessage,
318        event_tx: Option<&mpsc::UnboundedSender<AgentEvent>>,
319    ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
320        let trace_id = Uuid::new_v4().to_string();
321        use tracing::Instrument;
322        let span = tracing::info_span!("AgentSpan", trace_id = %trace_id);
323
324        self.process_inbound_message_inner(msg, event_tx, trace_id)
325            .instrument(span)
326            .await
327    }
328
329    /// Process a message directly (for CLI or testing)
330    pub async fn process_direct(
331        &mut self,
332        content: impl Into<String>,
333        _session_key: impl Into<String>,
334        channel: impl Into<String>,
335        chat_id: impl Into<String>,
336    ) -> Result<String, Box<dyn std::error::Error>> {
337        let content = content.into();
338        let channel = channel.into();
339        let chat_id = chat_id.into();
340
341        let msg = InboundMessage::new(channel, "user", chat_id, content);
342
343        let response = self.process_inbound_message(msg, None).await?;
344        Ok(response
345            .map(|r| {
346                let content = r.content;
347                if let Some(reasoning) = r.reasoning_content {
348                    if !reasoning.is_empty() {
349                        return format!("<think>\n{}\n</think>\n\n{}", reasoning, content);
350                    }
351                }
352                content
353            })
354            .unwrap_or_default())
355    }
356
357    /// Process a message directly and emit streaming events for UI consumers.
358    pub async fn process_direct_stream(
359        &mut self,
360        content: impl Into<String>,
361        _session_key: impl Into<String>,
362        channel: impl Into<String>,
363        chat_id: impl Into<String>,
364        event_tx: mpsc::UnboundedSender<AgentEvent>,
365    ) -> Result<String, Box<dyn std::error::Error>> {
366        let content = content.into();
367        let channel = channel.into();
368        let chat_id = chat_id.into();
369
370        let msg = InboundMessage::new(channel, "user", chat_id, content);
371
372        match self.process_inbound_message(msg, Some(&event_tx)).await {
373            Ok(response) => Ok(response.map(|r| r.content).unwrap_or_default()),
374            Err(err) => {
375                let _ = event_tx.send(AgentEvent::Error {
376                    message: err.to_string(),
377                });
378                Err(err)
379            }
380        }
381    }
382
383    fn is_frequent_soul_change_turn(&mut self) -> bool {
384        let window = Duration::from_secs(self.soul_governance.frequent_change_window_secs.max(1));
385        let now = Instant::now();
386        self.soul_change_turns.push_back(now);
387        while let Some(front) = self.soul_change_turns.front().copied() {
388            if now.duration_since(front) > window {
389                self.soul_change_turns.pop_front();
390            } else {
391                break;
392            }
393        }
394        self.soul_change_turns.len() >= self.soul_governance.frequent_change_threshold.max(1)
395    }
396}
397
// Unit tests for agent-loop construction, direct processing, governance
// defaults, and error-event emission.
#[cfg(test)]
mod tests {
    use super::*;
    use agent_diva_providers::{
        LLMResponse, LiteLLMClient, Message, ProviderError, ProviderEventStream, ProviderResult,
    };
    use async_trait::async_trait;
    use futures::stream;
    use tokio::time::{timeout, Duration};

    /// Test provider whose non-streaming call errors out and whose stream
    /// yields a single error item.
    struct FailingStreamProvider;

    #[async_trait]
    impl LLMProvider for FailingStreamProvider {
        // Always errors; the message signals that the streaming path was expected instead.
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<LLMResponse> {
            Err(ProviderError::ApiError(
                "chat should not be used".to_string(),
            ))
        }

        // Opening the stream succeeds; its only item is an API error.
        async fn chat_stream(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<ProviderEventStream> {
            Ok(Box::pin(stream::iter(vec![Err(ProviderError::ApiError(
                "simulated stream failure".to_string(),
            ))])))
        }

        fn get_default_model(&self) -> String {
            "test-model".to_string()
        }
    }

    /// `AgentLoop::new` falls back to 20 iterations when none are requested.
    #[tokio::test]
    async fn test_agent_loop_creation() {
        let bus = MessageBus::new();
        let provider = Arc::new(LiteLLMClient::default());
        let workspace = PathBuf::from("/tmp/test");
        let agent = AgentLoop::new(bus, provider, workspace, None, None);
        assert_eq!(agent.max_iterations, 20);
    }

    /// `process_direct` surfaces an error when the provider cannot be reached.
    // NOTE(review): relies on the default client failing to connect — confirm this holds in CI.
    #[tokio::test]
    async fn test_process_direct() {
        let bus = MessageBus::new();
        let provider = Arc::new(LiteLLMClient::default());
        let temp_dir = tempfile::tempdir().unwrap();
        let workspace = temp_dir.path().to_path_buf();

        let mut agent = AgentLoop::new(bus, provider, workspace, None, Some(1));

        // This will fail to connect to LLM, but tests the structure
        let result = agent
            .process_direct("Hello", "cli:test", "cli", "test")
            .await;

        // We expect an error since we don't have a real LLM connection
        assert!(result.is_err());
    }

    /// The default governance window and threshold are both non-zero.
    #[test]
    fn test_soul_governance_defaults_are_non_zero() {
        let cfg = SoulGovernanceSettings::default();
        assert!(cfg.frequent_change_window_secs > 0);
        assert!(cfg.frequent_change_threshold > 0);
    }

    /// A provider failure during `handle_inbound` produces an error event on
    /// the bus, tagged with the originating channel and chat id.
    #[tokio::test]
    async fn test_handle_inbound_emits_error_event_on_provider_failure() {
        let bus = MessageBus::new();
        // Subscribe before the message is handled so the event is not missed.
        let mut event_rx = bus.subscribe_events();
        let provider = Arc::new(FailingStreamProvider);
        let temp_dir = tempfile::tempdir().unwrap();
        let workspace = temp_dir.path().to_path_buf();

        let mut agent = AgentLoop::new(bus.clone(), provider, workspace, None, Some(1));
        let msg = InboundMessage::new("gui", "user", "chat-1", "Hello");

        agent.handle_inbound(msg).await;

        // Skip any non-error events; give up after one second.
        let error_event = timeout(Duration::from_secs(1), async {
            loop {
                let bus_event = event_rx.recv().await.unwrap();
                if let AgentEvent::Error { message } = bus_event.event {
                    break (bus_event.channel, bus_event.chat_id, message);
                }
            }
        })
        .await
        .expect("timed out waiting for error event");

        assert_eq!(error_event.0, "gui");
        assert_eq!(error_event.1, "chat-1");
        assert!(error_event.2.contains("simulated stream failure"));
    }
}