1use agent_diva_core::bus::{AgentEvent, InboundMessage, MessageBus, OutboundMessage};
7use agent_diva_core::error_context::ErrorContext;
8use agent_diva_core::session::SessionManager;
9#[cfg(feature = "files")]
10use agent_diva_files::FileManager;
11use agent_diva_providers::{LLMProvider, LLMStreamEvent, ProviderEventStream, ToolCallRequest};
12use agent_diva_tooling::ToolRegistry;
13use std::collections::HashSet;
14use std::path::PathBuf;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::mpsc;
18use tracing::{debug, error, info, warn, Instrument};
19
20use crate::internal::context::{NanoContextBuilder, NanoSoulSettings};
21
/// Tunables for [`NanoAgentLoop`].
#[derive(Clone)]
pub struct NanoLoopConfig {
    /// Maximum number of tool-call rounds executed for a single turn. Once the
    /// limit is reached, a completion that still requests tools is finalized as
    /// a plain response instead of running them.
    pub max_iterations: usize,
    /// Passed to `NanoContextBuilder::build_messages` on every turn
    /// (presumably the number of past session messages included — TODO confirm).
    pub memory_window: usize,
    /// Soul settings handed to the `NanoContextBuilder` when the loop is built.
    pub soul_settings: NanoSoulSettings,
    /// NOTE(review): not read by `NanoAgentLoop` in this file; presumably controls
    /// whether soul changes are announced to the user — confirm with the consumer.
    pub notify_on_soul_change: bool,
}
34
35impl Default for NanoLoopConfig {
36 fn default() -> Self {
37 Self {
38 max_iterations: 20,
39 memory_window: 10,
40 soul_settings: NanoSoulSettings::default(),
41 notify_on_soul_change: true,
42 }
43 }
44}
45
/// Single-agent message loop. It pulls inbound messages from the bus, streams a
/// completion from the LLM provider, runs any requested tools, and publishes the
/// reply and progress events back onto the bus.
pub struct NanoAgentLoop {
    /// Source of inbound messages and sink for outbound replies and agent events.
    bus: MessageBus,
    /// Provider used for streaming chat completions.
    provider: Arc<dyn LLMProvider>,
    /// Workspace root; also used to construct `context` and `sessions`.
    workspace: PathBuf,
    /// Model name sent with every request (the provider default if none was given).
    model: String,
    /// Loop tunables (tool-round limit, memory window, soul settings).
    config: NanoLoopConfig,
    /// Conversation history, keyed by `"{channel}:{chat_id}"`.
    sessions: SessionManager,
    /// Tools offered to the model; replaceable at runtime via `ReloadTools`.
    tools: ToolRegistry,
    /// Builds the prompt messages for each turn.
    context: NanoContextBuilder,
    /// Shared file manager, exposed through `file_manager()`.
    #[cfg(feature = "files")]
    file_manager: Arc<FileManager>,
    /// Chat ids whose next inbound message is rejected instead of processed.
    cancelled_sessions: HashSet<String>,
    /// Optional control channel that `run` listens to alongside the bus.
    runtime_control_rx: Option<mpsc::UnboundedReceiver<NanoRuntimeControlCommand>>,
}
61
/// Commands accepted on the optional runtime control channel of [`NanoAgentLoop`].
pub enum NanoRuntimeControlCommand {
    /// Marks `chat_id` as cancelled. The next inbound message for that chat is
    /// rejected with an error event rather than processed. Messages are handled
    /// one at a time, so a turn that is already running is not interrupted.
    CancelSession { chat_id: String },
    /// Stops `NanoAgentLoop::run`, which then returns `Ok(())`.
    Stop,
    /// Replaces the loop's tool registry wholesale.
    ReloadTools(ToolRegistry),
}
71
72impl NanoAgentLoop {
73 pub async fn new(
75 bus: MessageBus,
76 provider: Arc<dyn LLMProvider>,
77 workspace: PathBuf,
78 model: Option<String>,
79 config: NanoLoopConfig,
80 tools: ToolRegistry,
81 #[cfg(feature = "files")] file_manager: Arc<FileManager>,
82 ) -> Result<Self, Box<dyn std::error::Error>> {
83 let model = model.unwrap_or_else(|| provider.get_default_model());
84 let context = NanoContextBuilder::new(workspace.clone())
85 .with_soul_settings(config.soul_settings.clone());
86 let sessions = SessionManager::new(workspace.clone());
87
88 Ok(Self {
89 bus,
90 provider,
91 workspace,
92 model,
93 config,
94 sessions,
95 tools,
96 context,
97 #[cfg(feature = "files")]
98 file_manager,
99 cancelled_sessions: HashSet::new(),
100 runtime_control_rx: None,
101 })
102 }
103
104 pub fn with_runtime_control(
106 mut self,
107 rx: mpsc::UnboundedReceiver<NanoRuntimeControlCommand>,
108 ) -> Self {
109 self.runtime_control_rx = Some(rx);
110 self
111 }
112
113 pub fn tools(&self) -> &ToolRegistry {
115 &self.tools
116 }
117
118 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
120 &mut self.tools
121 }
122
123 #[cfg(feature = "files")]
125 pub fn file_manager(&self) -> Arc<FileManager> {
126 self.file_manager.clone()
127 }
128
129 pub async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
131 info!("Nano agent loop started");
132
133 let Some(mut inbound_rx) = self.bus.take_inbound_receiver().await else {
134 error!("Failed to take inbound receiver");
135 return Err("Inbound receiver already taken".into());
136 };
137
138 loop {
139 if let Some(control_rx) = self.runtime_control_rx.as_mut() {
140 tokio::select! {
141 control = control_rx.recv() => {
142 match control {
143 Some(cmd) => {
144 if self.handle_runtime_control(cmd) {
145 info!("Nano agent loop stopped via control command");
146 return Ok(());
147 }
148 }
149 None => {
150 info!("Runtime control channel closed");
151 self.runtime_control_rx = None;
152 }
153 }
154 }
155 maybe_msg = inbound_rx.recv() => {
156 match maybe_msg {
157 Some(msg) => self.handle_inbound(msg).await,
158 None => {
159 info!("Message bus closed, stopping nano agent loop");
160 break;
161 }
162 }
163 }
164 }
165 } else {
166 match tokio::time::timeout(Duration::from_secs(1), inbound_rx.recv()).await {
167 Ok(Some(msg)) => self.handle_inbound(msg).await,
168 Ok(None) => {
169 info!("Message bus closed, stopping nano agent loop");
170 break;
171 }
172 Err(_) => continue,
173 }
174 }
175 }
176
177 info!("Nano agent loop stopped");
178 Ok(())
179 }
180
181 fn handle_runtime_control(&mut self, cmd: NanoRuntimeControlCommand) -> bool {
184 match cmd {
185 NanoRuntimeControlCommand::CancelSession { chat_id } => {
186 let chat_id_clone = chat_id.clone();
187 self.cancelled_sessions.insert(chat_id);
188 info!("Session {} marked for cancellation", chat_id_clone);
189 false
190 }
191 NanoRuntimeControlCommand::Stop => true,
192 NanoRuntimeControlCommand::ReloadTools(new_registry) => {
193 self.tools = new_registry;
194 info!("Tools reloaded, now have {} tools", self.tools.len());
195 false
196 }
197 }
198 }
199
200 async fn handle_inbound(&mut self, msg: InboundMessage) {
202 debug!("Received message from {}:{}", msg.channel, msg.chat_id);
203
204 if self.cancelled_sessions.contains(&msg.chat_id) {
205 self.cancelled_sessions.remove(&msg.chat_id);
206 self.emit_event(&msg, AgentEvent::Error {
207 message: "Session was cancelled".to_string(),
208 });
209 return;
210 }
211
212 let event_msg = msg.clone();
213 match self.process_inbound_message(msg).await {
214 Ok(Some(response)) => {
215 if let Err(e) = self.bus.publish_outbound(response) {
216 error!("Failed to publish response: {}", e);
217 }
218 }
219 Ok(None) => debug!("No response needed"),
220 Err(e) => {
221 let error_message = format!("Failed to process message: {}", e);
222 let ctx = ErrorContext::new("handle_inbound", &error_message)
223 .with_metadata("channel", event_msg.channel.clone())
224 .with_metadata("chat_id", event_msg.chat_id.clone())
225 .with_metadata("sender_id", event_msg.sender_id.clone());
226 error!("{}", ctx.to_detailed_string());
227 self.emit_error_event(&event_msg, error_message);
228 }
229 }
230 }
231
232 async fn process_inbound_message(
234 &mut self,
235 msg: InboundMessage,
236 ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
237 let trace_id = uuid::Uuid::new_v4().to_string();
238 let span = tracing::info_span!("NanoAgentSpan", trace_id = %trace_id);
239
240 self.process_turn(msg, trace_id).instrument(span).await
241 }
242
243 async fn process_turn(
245 &mut self,
246 msg: InboundMessage,
247 trace_id: String,
248 ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
249 let session_key = format!("{}:{}", msg.channel, msg.chat_id);
251 let session = self.sessions.get_or_create(&session_key);
252
253 let messages = self.context.build_messages(
255 &msg,
256 session,
257 &self.tools,
258 self.config.memory_window,
259 )?;
260
261 let tool_defs = self.tools.get_definitions();
263 let tools_param = if tool_defs.is_empty() {
264 None
265 } else {
266 Some(tool_defs)
267 };
268
269 let stream = self.provider.chat_stream(
271 messages,
272 tools_param,
273 Some(self.model.clone()),
274 4096,
275 0.7,
276 ).await?;
277
278 self.process_stream(stream, msg, session_key, trace_id).await
280 }
281
282 async fn process_stream(
284 &mut self,
285 stream: ProviderEventStream,
286 msg: InboundMessage,
287 session_key: String,
288 trace_id: String,
289 ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
290 use futures::StreamExt;
291 let mut stream = stream;
292 let mut full_content = String::new();
293 let mut reasoning_content = String::new();
294 let mut tool_calls: Vec<ToolCallRequest> = Vec::new();
295 let mut tool_call_accumulator: std::collections::HashMap<usize, (Option<String>, Option<String>, String)> = std::collections::HashMap::new();
296 let mut iteration_count = 0;
297
298 loop {
299 match tokio::time::timeout(Duration::from_secs(120), stream.next()).await {
300 Ok(Some(event)) => {
301 match event {
302 Ok(LLMStreamEvent::TextDelta(delta)) => {
303 full_content.push_str(&delta);
304 self.emit_event(&msg, AgentEvent::AssistantDelta { text: delta });
305 }
306 Ok(LLMStreamEvent::ReasoningDelta(delta)) => {
307 reasoning_content.push_str(&delta);
308 }
309 Ok(LLMStreamEvent::ToolCallDelta { index, id, name, arguments_delta }) => {
310 let entry = tool_call_accumulator.entry(index).or_insert((None, None, String::new()));
312 if let Some(id) = id {
313 entry.0 = Some(id);
314 }
315 if let Some(name) = name {
316 entry.1 = Some(name);
317 }
318 if let Some(args) = arguments_delta {
319 entry.2.push_str(&args);
320 }
321 }
322 Ok(LLMStreamEvent::Completed(response)) => {
323 if tool_call_accumulator.is_empty() && !response.tool_calls.is_empty() {
325 tool_calls = response.tool_calls.clone();
326 } else {
327 for (_, (id, name, args)) in tool_call_accumulator.drain() {
329 if let (Some(id), Some(name)) = (id, name) {
330 let arguments = serde_json::from_str(&args)
331 .unwrap_or(std::collections::HashMap::new());
332 tool_calls.push(ToolCallRequest {
333 id,
334 call_type: "function".to_string(),
335 name,
336 arguments,
337 });
338 }
339 }
340 }
341
342 if !tool_calls.is_empty() && iteration_count < self.config.max_iterations {
344 iteration_count += 1;
345
346 let tool_results = self.execute_tool_calls(&tool_calls, &msg).await;
348
349 tool_calls.clear();
352 tool_call_accumulator.clear();
353 continue;
354 }
355
356 let final_content = response.content.clone().unwrap_or(full_content.clone());
358
359 self.emit_event(&msg, AgentEvent::FinalResponse {
360 content: final_content.clone(),
361 });
362
363 if let Some(session) = self.sessions.get(&session_key) {
365 let mut session_clone = session.clone();
367 session_clone.add_message("user", msg.content.clone());
368 session_clone.add_message("assistant", final_content.clone());
369 self.sessions.save(&session_clone)?;
370 }
371
372 let mut outbound = OutboundMessage::new(
373 &msg.channel,
374 &msg.chat_id,
375 final_content,
376 );
377 if !reasoning_content.is_empty() {
378 outbound.reasoning_content = Some(reasoning_content);
379 }
380 return Ok(Some(outbound));
381 }
382 Err(e) => {
383 self.emit_error_event(&msg, e.to_string());
384 return Err(e.into());
385 }
386 }
387 }
388 Ok(None) => break,
389 Err(_) => {
390 warn!("Stream timeout for trace {}", trace_id);
391 self.emit_error_event(&msg, "Stream timeout".to_string());
392 return Err("Stream timeout".into());
393 }
394 }
395 }
396
397 Ok(None)
398 }
399
400 async fn execute_tool_calls(
402 &mut self,
403 tool_calls: &[ToolCallRequest],
404 msg: &InboundMessage,
405 ) -> Vec<(String, String)> {
406 let mut results = Vec::new();
407
408 for tc in tool_calls {
409 let args_preview = serde_json::to_string(&tc.arguments)
411 .unwrap_or_default()
412 .chars()
413 .take(100)
414 .collect();
415
416 self.emit_event(&msg, AgentEvent::ToolCallStarted {
417 name: tc.name.clone(),
418 args_preview,
419 call_id: tc.id.clone(),
420 });
421
422 let params = serde_json::to_value(&tc.arguments).unwrap_or(serde_json::Value::Null);
424
425 let result = self.tools.execute(&tc.name, params).await;
426 let is_error = result.starts_with("Error");
427
428 self.emit_event(&msg, AgentEvent::ToolCallFinished {
429 name: tc.name.clone(),
430 result: result.clone(),
431 is_error,
432 call_id: tc.id.clone(),
433 });
434
435 results.push((tc.id.clone(), result));
436 }
437
438 results
439 }
440
441 fn emit_event(&self, msg: &InboundMessage, event: AgentEvent) {
443 if let Err(e) = self.bus.publish_event(&msg.channel, &msg.chat_id, event) {
444 warn!("Failed to emit event: {}", e);
445 }
446 }
447
448 fn emit_error_event(&self, msg: &InboundMessage, message: String) {
450 self.emit_event(msg, AgentEvent::Error { message });
451 }
452
453 pub async fn process_direct(
455 &mut self,
456 content: impl Into<String>,
457 channel: impl Into<String>,
458 chat_id: impl Into<String>,
459 ) -> Result<String, Box<dyn std::error::Error>> {
460 let content = content.into();
461 let channel = channel.into();
462 let chat_id = chat_id.into();
463
464 let msg = InboundMessage::new(channel, "user", chat_id, content);
465
466 let response = self.process_inbound_message(msg).await?;
467 Ok(response
468 .map(|r| {
469 let content = r.content;
470 if let Some(reasoning) = r.reasoning_content {
471 if !reasoning.is_empty() {
472 return format!("{}\n\n{}", reasoning, content);
473 }
474 }
475 content
476 })
477 .unwrap_or_default())
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use agent_diva_providers::{LLMResponse, LLMStreamEvent, Message, ProviderResult};
    use async_trait::async_trait;
    use futures::stream;

    /// Provider double that always answers with canned, tool-free responses.
    struct MockProvider;

    /// Builds a finished (`"stop"`) response carrying `text` and no tool calls.
    fn stop_response(text: &str) -> LLMResponse {
        LLMResponse {
            content: Some(text.to_string()),
            reasoning_content: None,
            tool_calls: Vec::new(),
            finish_reason: "stop".to_string(),
            usage: std::collections::HashMap::new(),
        }
    }

    #[async_trait]
    impl LLMProvider for MockProvider {
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<LLMResponse> {
            Ok(stop_response("mock response"))
        }

        async fn chat_stream(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<ProviderEventStream> {
            // One text delta followed by the completion that repeats it.
            Ok(Box::pin(stream::iter(vec![
                Ok(LLMStreamEvent::TextDelta("mock".to_string())),
                Ok(LLMStreamEvent::Completed(stop_response("mock"))),
            ])))
        }

        fn get_default_model(&self) -> String {
            "mock-model".to_string()
        }
    }

    #[tokio::test]
    async fn test_nano_agent_loop_creation() {
        let bus = MessageBus::new();
        let provider = Arc::new(MockProvider);
        let workspace = PathBuf::from("/tmp/test");
        let tools = ToolRegistry::new();

        #[cfg(feature = "files")]
        let result = {
            let storage_path = workspace.join(".agent-diva/files");
            let file_config = agent_diva_files::FileConfig::with_path(&storage_path);
            let file_manager = Arc::new(FileManager::new(file_config).await.unwrap());
            NanoAgentLoop::new(
                bus,
                provider,
                workspace,
                None,
                NanoLoopConfig::default(),
                tools,
                file_manager,
            )
            .await
        };

        #[cfg(not(feature = "files"))]
        let result = NanoAgentLoop::new(
            bus,
            provider,
            workspace,
            None,
            NanoLoopConfig::default(),
            tools,
        )
        .await;

        assert!(result.is_ok());
        let agent = result.unwrap();
        assert_eq!(agent.config.max_iterations, 20);
    }
}