1use agent_diva_core::bus::{AgentEvent, InboundMessage, MessageBus, OutboundMessage};
4use agent_diva_core::config::MCPServerConfig;
5use agent_diva_core::cron::CronService;
6use agent_diva_core::error_context::ErrorContext;
7use agent_diva_core::session::SessionManager;
8use agent_diva_providers::LLMProvider;
9use agent_diva_tools::{
10 load_mcp_tools_sync, CronTool, EditFileTool, ExecTool, ListDirTool, ReadFileTool, SpawnTool,
11 ToolError, ToolRegistry, WriteFileTool,
12};
13use std::collections::{HashMap, HashSet, VecDeque};
14use std::path::PathBuf;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::sync::mpsc;
18use tracing::{debug, error, info};
19use uuid::Uuid;
20
21use crate::consolidation;
22use crate::context::{ContextBuilder, SoulContextSettings};
23use crate::runtime_control::RuntimeControlCommand;
24use crate::subagent::SubagentManager;
25use crate::tool_config::network::NetworkToolConfig;
26
27mod loop_runtime_control;
28mod loop_tools;
29mod loop_turn;
30
31#[derive(Clone)]
33pub struct ToolConfig {
34 pub network: NetworkToolConfig,
36 pub exec_timeout: u64,
38 pub restrict_to_workspace: bool,
40 pub mcp_servers: HashMap<String, MCPServerConfig>,
42 pub cron_service: Option<Arc<CronService>>,
44 pub soul_context: SoulContextSettings,
46 pub notify_on_soul_change: bool,
48 pub soul_governance: SoulGovernanceSettings,
50}
51
52impl Default for ToolConfig {
53 fn default() -> Self {
54 Self {
55 network: NetworkToolConfig::default(),
56 exec_timeout: 60,
57 restrict_to_workspace: false,
58 mcp_servers: HashMap::new(),
59 cron_service: None,
60 soul_context: SoulContextSettings::default(),
61 notify_on_soul_change: true,
62 soul_governance: SoulGovernanceSettings::default(),
63 }
64 }
65}
66
/// Thresholds that decide when changes to the agent's soul count as "frequent".
///
/// When the agent loop records a soul-changing turn, it counts how many such turns fall
/// inside a sliding window of `frequent_change_window_secs` seconds. It compares that count
/// against `frequent_change_threshold`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SoulGovernanceSettings {
    /// Length of the sliding window in seconds. A value of `0` is treated as `1`.
    pub frequent_change_window_secs: u64,
    /// Number of soul-changing turns inside the window at which changes are considered
    /// frequent. A value of `0` is treated as `1`.
    pub frequent_change_threshold: usize,
    /// NOTE(review): not read in this module; presumably toggles a confirmation hint for
    /// boundary-related soul changes — confirm against its consumer.
    pub boundary_confirmation_hint: bool,
}

impl Default for SoulGovernanceSettings {
    /// A 10-minute (600 s) window, a threshold of 3 turns, and the boundary hint enabled.
    fn default() -> Self {
        Self {
            frequent_change_window_secs: 600,
            frequent_change_threshold: 3,
            boundary_confirmation_hint: true,
        }
    }
}
87
/// Receives inbound messages from the [`MessageBus`], processes them one at a time, and
/// publishes any responses back to the bus.
pub struct AgentLoop {
    /// Bus used to receive inbound messages and publish outbound messages and events.
    bus: MessageBus,
    /// LLM provider used for model calls.
    provider: Arc<dyn LLMProvider>,
    /// Workspace root the agent was created with.
    // NOTE(review): the allow below suggests this field is never read in the crate;
    // confirm before removing either the field or the attribute.
    #[allow(dead_code)]
    workspace: PathBuf,
    /// Model name: the explicit constructor argument or the provider's default model.
    // NOTE(review): same unexplained dead_code allowance as `workspace` — confirm usage.
    #[allow(dead_code)]
    model: String,
    /// Iteration cap (20 unless the constructor was given a value); presumably per turn.
    max_iterations: usize,
    /// Memory window size, initialised to `consolidation::DEFAULT_MEMORY_WINDOW`.
    memory_window: usize,
    /// Builds the prompt context; created with workspace skills and soul settings.
    context: ContextBuilder,
    /// Session store rooted at the workspace.
    sessions: SessionManager,
    /// Tools available to the agent (empty when built with `AgentLoop::new`).
    tools: ToolRegistry,
    /// Subagent manager; shared with the spawn tool registered by `with_tools`.
    subagent_manager: Arc<SubagentManager>,
    /// Optional runtime-control command channel; reset to `None` once it closes.
    runtime_control_rx: Option<mpsc::UnboundedReceiver<RuntimeControlCommand>>,
    /// Presumably the keys of sessions cancelled via runtime control — confirm.
    cancelled_sessions: HashSet<String>,
    /// Whether soul changes should be announced (from `ToolConfig`, `true` by default).
    notify_on_soul_change: bool,
    /// Window/threshold settings used to detect frequent soul changes.
    soul_governance: SoulGovernanceSettings,
    /// Instants of recent soul-changing turns, oldest first, pruned to the governance window.
    soul_change_turns: VecDeque<Instant>,
}
108
109impl AgentLoop {
110 pub fn new(
112 bus: MessageBus,
113 provider: Arc<dyn LLMProvider>,
114 workspace: PathBuf,
115 model: Option<String>,
116 max_iterations: Option<usize>,
117 ) -> Self {
118 let model = model.unwrap_or_else(|| provider.get_default_model());
119 let mut context = ContextBuilder::with_skills(workspace.clone(), None);
120 context.set_soul_settings(SoulContextSettings::default());
121 let sessions = SessionManager::new(workspace.clone());
122 let tools = ToolRegistry::new();
123
124 let subagent_manager = Arc::new(SubagentManager::new(
125 provider.clone(),
126 workspace.clone(),
127 bus.clone(),
128 Some(model.clone()),
129 NetworkToolConfig::default(),
130 None,
131 false,
132 ));
133
134 Self {
135 bus,
136 provider,
137 workspace,
138 model,
139 max_iterations: max_iterations.unwrap_or(20),
140 memory_window: consolidation::DEFAULT_MEMORY_WINDOW,
141 context,
142 sessions,
143 tools,
144 subagent_manager,
145 runtime_control_rx: None,
146 cancelled_sessions: HashSet::new(),
147 notify_on_soul_change: true,
148 soul_governance: SoulGovernanceSettings::default(),
149 soul_change_turns: VecDeque::new(),
150 }
151 }
152
153 pub fn with_tools(
155 bus: MessageBus,
156 provider: Arc<dyn LLMProvider>,
157 workspace: PathBuf,
158 model: Option<String>,
159 max_iterations: Option<usize>,
160 tool_config: ToolConfig,
161 runtime_control_rx: Option<mpsc::UnboundedReceiver<RuntimeControlCommand>>,
162 ) -> Self {
163 let model = model.unwrap_or_else(|| provider.get_default_model());
164 let mut context = ContextBuilder::with_skills(workspace.clone(), None);
165 context.set_soul_settings(tool_config.soul_context.clone());
166 let sessions = SessionManager::new(workspace.clone());
167 let mut tools = ToolRegistry::new();
168
169 let subagent_manager = Arc::new(SubagentManager::new(
170 provider.clone(),
171 workspace.clone(),
172 bus.clone(),
173 Some(model.clone()),
174 tool_config.network.clone(),
175 Some(tool_config.exec_timeout),
176 tool_config.restrict_to_workspace,
177 ));
178
179 let sm = subagent_manager.clone();
181 tools.register(Arc::new(SpawnTool::new(
182 move |task, label, channel, chat_id| {
183 let sm = sm.clone();
184 async move {
185 sm.spawn(task, label, channel, chat_id)
186 .await
187 .map_err(|e| ToolError::ExecutionFailed(e.to_string()))
188 }
189 },
190 )));
191
192 let allowed_dir = if tool_config.restrict_to_workspace {
194 Some(workspace.clone())
195 } else {
196 None
197 };
198 tools.register(Arc::new(ReadFileTool::new(allowed_dir.clone())));
199 tools.register(Arc::new(WriteFileTool::new(allowed_dir.clone())));
200 tools.register(Arc::new(EditFileTool::new(allowed_dir.clone())));
201 tools.register(Arc::new(ListDirTool::new(allowed_dir)));
202
203 tools.register(Arc::new(ExecTool::with_config(
205 tool_config.exec_timeout,
206 Some(workspace.clone()),
207 tool_config.restrict_to_workspace,
208 )));
209
210 Self::register_web_tools(&mut tools, &tool_config.network);
212
213 for mcp_tool in load_mcp_tools_sync(&tool_config.mcp_servers) {
215 tools.register(mcp_tool);
216 }
217
218 if let Some(cron_service) = tool_config.cron_service.clone() {
220 tools.register(Arc::new(CronTool::new(cron_service)));
221 }
222
223 Self {
224 bus,
225 provider,
226 workspace,
227 model,
228 max_iterations: max_iterations.unwrap_or(20),
229 memory_window: consolidation::DEFAULT_MEMORY_WINDOW,
230 context,
231 sessions,
232 tools,
233 subagent_manager,
234 runtime_control_rx,
235 cancelled_sessions: HashSet::new(),
236 notify_on_soul_change: tool_config.notify_on_soul_change,
237 soul_governance: tool_config.soul_governance,
238 soul_change_turns: VecDeque::new(),
239 }
240 }
241
242 pub async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
244 info!("Agent loop started");
245
246 let Some(mut inbound_rx) = self.bus.take_inbound_receiver().await else {
248 error!("Failed to take inbound receiver");
249 return Err("Inbound receiver already taken".into());
250 };
251
252 loop {
253 if let Some(control_rx) = self.runtime_control_rx.as_mut() {
254 tokio::select! {
255 control = control_rx.recv() => {
256 match control {
257 Some(cmd) => self.handle_runtime_control_command(cmd).await,
258 None => {
259 info!("Runtime control channel closed");
260 self.runtime_control_rx = None;
261 }
262 }
263 }
264 maybe_msg = inbound_rx.recv() => {
265 match maybe_msg {
266 Some(msg) => self.handle_inbound(msg).await,
267 None => {
268 info!("Message bus closed, stopping agent loop");
269 break;
270 }
271 }
272 }
273 }
274 } else {
275 match tokio::time::timeout(std::time::Duration::from_secs(1), inbound_rx.recv())
276 .await
277 {
278 Ok(Some(msg)) => self.handle_inbound(msg).await,
279 Ok(None) => {
280 info!("Message bus closed, stopping agent loop");
281 break;
282 }
283 Err(_) => continue,
284 }
285 }
286 }
287
288 info!("Agent loop stopped");
289 Ok(())
290 }
291
292 async fn handle_inbound(&mut self, msg: InboundMessage) {
293 debug!("Received message from {}:{}", msg.channel, msg.chat_id);
294 let event_msg = msg.clone();
295 match self.process_inbound_message(msg, None).await {
296 Ok(Some(response)) => {
297 if let Err(e) = self.bus.publish_outbound(response) {
298 error!("Failed to publish response: {}", e);
299 }
300 }
301 Ok(None) => debug!("No response needed"),
302 Err(e) => {
303 let error_message = format!("Failed to process message: {}", e);
304 let ctx = ErrorContext::new("handle_inbound", &error_message)
305 .with_metadata("channel", event_msg.channel.clone())
306 .with_metadata("chat_id", event_msg.chat_id.clone())
307 .with_metadata("sender_id", event_msg.sender_id.clone());
308 error!("{}", ctx.to_detailed_string());
309 self.emit_error_event(&event_msg, None, error_message);
310 }
311 }
312 }
313
314 pub async fn process_inbound_message(
316 &mut self,
317 msg: InboundMessage,
318 event_tx: Option<&mpsc::UnboundedSender<AgentEvent>>,
319 ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
320 let trace_id = Uuid::new_v4().to_string();
321 use tracing::Instrument;
322 let span = tracing::info_span!("AgentSpan", trace_id = %trace_id);
323
324 self.process_inbound_message_inner(msg, event_tx, trace_id)
325 .instrument(span)
326 .await
327 }
328
329 pub async fn process_direct(
331 &mut self,
332 content: impl Into<String>,
333 _session_key: impl Into<String>,
334 channel: impl Into<String>,
335 chat_id: impl Into<String>,
336 ) -> Result<String, Box<dyn std::error::Error>> {
337 let content = content.into();
338 let channel = channel.into();
339 let chat_id = chat_id.into();
340
341 let msg = InboundMessage::new(channel, "user", chat_id, content);
342
343 let response = self.process_inbound_message(msg, None).await?;
344 Ok(response
345 .map(|r| {
346 let content = r.content;
347 if let Some(reasoning) = r.reasoning_content {
348 if !reasoning.is_empty() {
349 return format!("<think>\n{}\n</think>\n\n{}", reasoning, content);
350 }
351 }
352 content
353 })
354 .unwrap_or_default())
355 }
356
357 pub async fn process_direct_stream(
359 &mut self,
360 content: impl Into<String>,
361 _session_key: impl Into<String>,
362 channel: impl Into<String>,
363 chat_id: impl Into<String>,
364 event_tx: mpsc::UnboundedSender<AgentEvent>,
365 ) -> Result<String, Box<dyn std::error::Error>> {
366 let content = content.into();
367 let channel = channel.into();
368 let chat_id = chat_id.into();
369
370 let msg = InboundMessage::new(channel, "user", chat_id, content);
371
372 match self.process_inbound_message(msg, Some(&event_tx)).await {
373 Ok(response) => Ok(response.map(|r| r.content).unwrap_or_default()),
374 Err(err) => {
375 let _ = event_tx.send(AgentEvent::Error {
376 message: err.to_string(),
377 });
378 Err(err)
379 }
380 }
381 }
382
383 fn is_frequent_soul_change_turn(&mut self) -> bool {
384 let window = Duration::from_secs(self.soul_governance.frequent_change_window_secs.max(1));
385 let now = Instant::now();
386 self.soul_change_turns.push_back(now);
387 while let Some(front) = self.soul_change_turns.front().copied() {
388 if now.duration_since(front) > window {
389 self.soul_change_turns.pop_front();
390 } else {
391 break;
392 }
393 }
394 self.soul_change_turns.len() >= self.soul_governance.frequent_change_threshold.max(1)
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use agent_diva_providers::{
        LLMResponse, LiteLLMClient, Message, ProviderError, ProviderEventStream, ProviderResult,
    };
    use async_trait::async_trait;
    use futures::stream;
    use tokio::time::{timeout, Duration};

    /// Provider whose every call fails, so tests never depend on network access or
    /// credentials. `chat` errors outright; `chat_stream` yields a single error item.
    struct FailingStreamProvider;

    #[async_trait]
    impl LLMProvider for FailingStreamProvider {
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<LLMResponse> {
            Err(ProviderError::ApiError(
                "chat should not be used".to_string(),
            ))
        }

        async fn chat_stream(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<ProviderEventStream> {
            Ok(Box::pin(stream::iter(vec![Err(ProviderError::ApiError(
                "simulated stream failure".to_string(),
            ))])))
        }

        fn get_default_model(&self) -> String {
            "test-model".to_string()
        }
    }

    #[tokio::test]
    async fn test_agent_loop_creation() {
        let bus = MessageBus::new();
        let provider = Arc::new(LiteLLMClient::default());
        // Use a throwaway directory rather than a fixed `/tmp` path. That keeps the test
        // portable and stops it touching a shared location.
        let temp_dir = tempfile::tempdir().unwrap();
        let workspace = temp_dir.path().to_path_buf();
        let agent = AgentLoop::new(bus, provider, workspace, None, None);
        assert_eq!(agent.max_iterations, 20);
    }

    #[tokio::test]
    async fn test_process_direct() {
        let bus = MessageBus::new();
        // A deterministic failing provider keeps this test hermetic. A real client would
        // attempt a network request, and the outcome would depend on the environment.
        let provider = Arc::new(FailingStreamProvider);
        let temp_dir = tempfile::tempdir().unwrap();
        let workspace = temp_dir.path().to_path_buf();

        let mut agent = AgentLoop::new(bus, provider, workspace, None, Some(1));

        let result = agent
            .process_direct("Hello", "cli:test", "cli", "test")
            .await;

        assert!(result.is_err());
    }

    #[test]
    fn test_soul_governance_defaults_are_non_zero() {
        let cfg = SoulGovernanceSettings::default();
        assert!(cfg.frequent_change_window_secs > 0);
        assert!(cfg.frequent_change_threshold > 0);
    }

    #[tokio::test]
    async fn test_handle_inbound_emits_error_event_on_provider_failure() {
        let bus = MessageBus::new();
        let mut event_rx = bus.subscribe_events();
        let provider = Arc::new(FailingStreamProvider);
        let temp_dir = tempfile::tempdir().unwrap();
        let workspace = temp_dir.path().to_path_buf();

        let mut agent = AgentLoop::new(bus.clone(), provider, workspace, None, Some(1));
        let msg = InboundMessage::new("gui", "user", "chat-1", "Hello");

        agent.handle_inbound(msg).await;

        // Skip unrelated events until the error event arrives, bounded by a timeout.
        let error_event = timeout(Duration::from_secs(1), async {
            loop {
                let bus_event = event_rx.recv().await.unwrap();
                if let AgentEvent::Error { message } = bus_event.event {
                    break (bus_event.channel, bus_event.chat_id, message);
                }
            }
        })
        .await
        .expect("timed out waiting for error event");

        assert_eq!(error_event.0, "gui");
        assert_eq!(error_event.1, "chat-1");
        assert!(error_event.2.contains("simulated stream failure"));
    }
}