//! Lightweight agent loop implementation for agent-diva-nano.
//!
//! This module provides a simplified agent loop that gives the caller complete
//! control over tool registration, enabling fine-grained assembly of the tools
//! available to the model.

use agent_diva_core::bus::{AgentEvent, InboundMessage, MessageBus, OutboundMessage};
use agent_diva_core::error_context::ErrorContext;
use agent_diva_core::session::SessionManager;
#[cfg(feature = "files")]
use agent_diva_files::FileManager;
use agent_diva_providers::{LLMProvider, LLMStreamEvent, ProviderEventStream, ToolCallRequest};
use agent_diva_tooling::ToolRegistry;
use std::collections::{BTreeMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn, Instrument};

use crate::internal::context::{NanoContextBuilder, NanoSoulSettings};

22/// Configuration for the nano agent loop.
23#[derive(Clone)]
24pub struct NanoLoopConfig {
25    /// Maximum tool-call iterations per turn.
26    pub max_iterations: usize,
27    /// Memory window for context consolidation.
28    pub memory_window: usize,
29    /// Soul context settings.
30    pub soul_settings: NanoSoulSettings,
31    /// Notify on soul changes.
32    pub notify_on_soul_change: bool,
33}
34
35impl Default for NanoLoopConfig {
36    fn default() -> Self {
37        Self {
38            max_iterations: 20,
39            memory_window: 10,
40            soul_settings: NanoSoulSettings::default(),
41            notify_on_soul_change: true,
42        }
43    }
44}
45
46/// Lightweight agent loop with customizable tool registry.
47pub struct NanoAgentLoop {
48    bus: MessageBus,
49    provider: Arc<dyn LLMProvider>,
50    workspace: PathBuf,
51    model: String,
52    config: NanoLoopConfig,
53    sessions: SessionManager,
54    tools: ToolRegistry,
55    context: NanoContextBuilder,
56    #[cfg(feature = "files")]
57    file_manager: Arc<FileManager>,
58    cancelled_sessions: HashSet<String>,
59    runtime_control_rx: Option<mpsc::UnboundedReceiver<NanoRuntimeControlCommand>>,
60}
61
62/// Runtime control commands for nano agent loop.
63pub enum NanoRuntimeControlCommand {
64    /// Cancel a specific session.
65    CancelSession { chat_id: String },
66    /// Stop the agent loop entirely.
67    Stop,
68    /// Reload tools from assembly.
69    ReloadTools(ToolRegistry),
70}
71
72impl NanoAgentLoop {
73    /// Create a new nano agent loop with a pre-built tool registry.
74    pub async fn new(
75        bus: MessageBus,
76        provider: Arc<dyn LLMProvider>,
77        workspace: PathBuf,
78        model: Option<String>,
79        config: NanoLoopConfig,
80        tools: ToolRegistry,
81        #[cfg(feature = "files")] file_manager: Arc<FileManager>,
82    ) -> Result<Self, Box<dyn std::error::Error>> {
83        let model = model.unwrap_or_else(|| provider.get_default_model());
84        let context = NanoContextBuilder::new(workspace.clone())
85            .with_soul_settings(config.soul_settings.clone());
86        let sessions = SessionManager::new(workspace.clone());
87
88        Ok(Self {
89            bus,
90            provider,
91            workspace,
92            model,
93            config,
94            sessions,
95            tools,
96            context,
97            #[cfg(feature = "files")]
98            file_manager,
99            cancelled_sessions: HashSet::new(),
100            runtime_control_rx: None,
101        })
102    }
103
104    /// Create with runtime control channel.
105    pub fn with_runtime_control(
106        mut self,
107        rx: mpsc::UnboundedReceiver<NanoRuntimeControlCommand>,
108    ) -> Self {
109        self.runtime_control_rx = Some(rx);
110        self
111    }
112
113    /// Get the tool registry.
114    pub fn tools(&self) -> &ToolRegistry {
115        &self.tools
116    }
117
118    /// Get mutable tool registry for dynamic modification.
119    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
120        &mut self.tools
121    }
122
123    /// Get the file manager.
124    #[cfg(feature = "files")]
125    pub fn file_manager(&self) -> Arc<FileManager> {
126        self.file_manager.clone()
127    }
128
129    /// Run the agent loop, processing messages from the bus.
130    pub async fn run(&mut self) -> Result<(), Box<dyn std::error::Error>> {
131        info!("Nano agent loop started");
132
133        let Some(mut inbound_rx) = self.bus.take_inbound_receiver().await else {
134            error!("Failed to take inbound receiver");
135            return Err("Inbound receiver already taken".into());
136        };
137
138        loop {
139            if let Some(control_rx) = self.runtime_control_rx.as_mut() {
140                tokio::select! {
141                    control = control_rx.recv() => {
142                        match control {
143                            Some(cmd) => {
144                                if self.handle_runtime_control(cmd) {
145                                    info!("Nano agent loop stopped via control command");
146                                    return Ok(());
147                                }
148                            }
149                            None => {
150                                info!("Runtime control channel closed");
151                                self.runtime_control_rx = None;
152                            }
153                        }
154                    }
155                    maybe_msg = inbound_rx.recv() => {
156                        match maybe_msg {
157                            Some(msg) => self.handle_inbound(msg).await,
158                            None => {
159                                info!("Message bus closed, stopping nano agent loop");
160                                break;
161                            }
162                        }
163                    }
164                }
165            } else {
166                match tokio::time::timeout(Duration::from_secs(1), inbound_rx.recv()).await {
167                    Ok(Some(msg)) => self.handle_inbound(msg).await,
168                    Ok(None) => {
169                        info!("Message bus closed, stopping nano agent loop");
170                        break;
171                    }
172                    Err(_) => continue,
173                }
174            }
175        }
176
177        info!("Nano agent loop stopped");
178        Ok(())
179    }
180
181    /// Handle runtime control command.
182    /// Returns true if the loop should stop.
183    fn handle_runtime_control(&mut self, cmd: NanoRuntimeControlCommand) -> bool {
184        match cmd {
185            NanoRuntimeControlCommand::CancelSession { chat_id } => {
186                let chat_id_clone = chat_id.clone();
187                self.cancelled_sessions.insert(chat_id);
188                info!("Session {} marked for cancellation", chat_id_clone);
189                false
190            }
191            NanoRuntimeControlCommand::Stop => true,
192            NanoRuntimeControlCommand::ReloadTools(new_registry) => {
193                self.tools = new_registry;
194                info!("Tools reloaded, now have {} tools", self.tools.len());
195                false
196            }
197        }
198    }
199
200    /// Handle an inbound message.
201    async fn handle_inbound(&mut self, msg: InboundMessage) {
202        debug!("Received message from {}:{}", msg.channel, msg.chat_id);
203        
204        if self.cancelled_sessions.contains(&msg.chat_id) {
205            self.cancelled_sessions.remove(&msg.chat_id);
206            self.emit_event(&msg, AgentEvent::Error {
207                message: "Session was cancelled".to_string(),
208            });
209            return;
210        }
211
212        let event_msg = msg.clone();
213        match self.process_inbound_message(msg).await {
214            Ok(Some(response)) => {
215                if let Err(e) = self.bus.publish_outbound(response) {
216                    error!("Failed to publish response: {}", e);
217                }
218            }
219            Ok(None) => debug!("No response needed"),
220            Err(e) => {
221                let error_message = format!("Failed to process message: {}", e);
222                let ctx = ErrorContext::new("handle_inbound", &error_message)
223                    .with_metadata("channel", event_msg.channel.clone())
224                    .with_metadata("chat_id", event_msg.chat_id.clone())
225                    .with_metadata("sender_id", event_msg.sender_id.clone());
226                error!("{}", ctx.to_detailed_string());
227                self.emit_error_event(&event_msg, error_message);
228            }
229        }
230    }
231
232    /// Process a single inbound message.
233    async fn process_inbound_message(
234        &mut self,
235        msg: InboundMessage,
236    ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
237        let trace_id = uuid::Uuid::new_v4().to_string();
238        let span = tracing::info_span!("NanoAgentSpan", trace_id = %trace_id);
239
240        self.process_turn(msg, trace_id).instrument(span).await
241    }
242
243    /// Process a turn of conversation.
244    async fn process_turn(
245        &mut self,
246        msg: InboundMessage,
247        trace_id: String,
248    ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
249        // Build context for the turn
250        let session_key = format!("{}:{}", msg.channel, msg.chat_id);
251        let session = self.sessions.get_or_create(&session_key);
252
253        // Build messages for LLM
254        let messages = self.context.build_messages(
255            &msg,
256            session,
257            &self.tools,
258            self.config.memory_window,
259        )?;
260
261        // Get tool definitions
262        let tool_defs = self.tools.get_definitions();
263        let tools_param = if tool_defs.is_empty() {
264            None
265        } else {
266            Some(tool_defs)
267        };
268
269        // Stream from provider
270        let stream = self.provider.chat_stream(
271            messages,
272            tools_param,
273            Some(self.model.clone()),
274            4096,
275            0.7,
276        ).await?;
277
278        // Process the stream and handle tool calls
279        self.process_stream(stream, msg, session_key, trace_id).await
280    }
281
282    /// Process streaming response from provider.
283    async fn process_stream(
284        &mut self,
285        stream: ProviderEventStream,
286        msg: InboundMessage,
287        session_key: String,
288        trace_id: String,
289    ) -> Result<Option<OutboundMessage>, Box<dyn std::error::Error>> {
290        use futures::StreamExt;
291        let mut stream = stream;
292        let mut full_content = String::new();
293        let mut reasoning_content = String::new();
294        let mut tool_calls: Vec<ToolCallRequest> = Vec::new();
295        let mut tool_call_accumulator: std::collections::HashMap<usize, (Option<String>, Option<String>, String)> = std::collections::HashMap::new();
296        let mut iteration_count = 0;
297
298        loop {
299            match tokio::time::timeout(Duration::from_secs(120), stream.next()).await {
300                Ok(Some(event)) => {
301                    match event {
302                        Ok(LLMStreamEvent::TextDelta(delta)) => {
303                            full_content.push_str(&delta);
304                            self.emit_event(&msg, AgentEvent::AssistantDelta { text: delta });
305                        }
306                        Ok(LLMStreamEvent::ReasoningDelta(delta)) => {
307                            reasoning_content.push_str(&delta);
308                        }
309                        Ok(LLMStreamEvent::ToolCallDelta { index, id, name, arguments_delta }) => {
310                            // Accumulate tool call deltas by index
311                            let entry = tool_call_accumulator.entry(index).or_insert((None, None, String::new()));
312                            if let Some(id) = id {
313                                entry.0 = Some(id);
314                            }
315                            if let Some(name) = name {
316                                entry.1 = Some(name);
317                            }
318                            if let Some(args) = arguments_delta {
319                                entry.2.push_str(&args);
320                            }
321                        }
322                        Ok(LLMStreamEvent::Completed(response)) => {
323                            // Build tool calls from accumulated deltas or response
324                            if tool_call_accumulator.is_empty() && !response.tool_calls.is_empty() {
325                                tool_calls = response.tool_calls.clone();
326                            } else {
327                                // Build from accumulator
328                                for (_, (id, name, args)) in tool_call_accumulator.drain() {
329                                    if let (Some(id), Some(name)) = (id, name) {
330                                        let arguments = serde_json::from_str(&args)
331                                            .unwrap_or(std::collections::HashMap::new());
332                                        tool_calls.push(ToolCallRequest {
333                                            id,
334                                            call_type: "function".to_string(),
335                                            name,
336                                            arguments,
337                                        });
338                                    }
339                                }
340                            }
341
342                            // Check if we have tool calls to execute
343                            if !tool_calls.is_empty() && iteration_count < self.config.max_iterations {
344                                iteration_count += 1;
345                                
346                                // Execute tool calls
347                                let tool_results = self.execute_tool_calls(&tool_calls, &msg).await;
348                                
349                                // Build next request with tool results
350                                // (This is simplified - full implementation would need proper context management)
351                                tool_calls.clear();
352                                tool_call_accumulator.clear();
353                                continue;
354                            }
355                            
356                            // Final response - use response content if available
357                            let final_content = response.content.clone().unwrap_or(full_content.clone());
358                            
359                            self.emit_event(&msg, AgentEvent::FinalResponse {
360                                content: final_content.clone(),
361                            });
362
363                            // Update session
364                            if let Some(session) = self.sessions.get(&session_key) {
365                                // Clone session to add message since we can't modify through &Session
366                                let mut session_clone = session.clone();
367                                session_clone.add_message("user", msg.content.clone());
368                                session_clone.add_message("assistant", final_content.clone());
369                                self.sessions.save(&session_clone)?;
370                            }
371
372                            let mut outbound = OutboundMessage::new(
373                                &msg.channel,
374                                &msg.chat_id,
375                                final_content,
376                            );
377                            if !reasoning_content.is_empty() {
378                                outbound.reasoning_content = Some(reasoning_content);
379                            }
380                            return Ok(Some(outbound));
381                        }
382                        Err(e) => {
383                            self.emit_error_event(&msg, e.to_string());
384                            return Err(e.into());
385                        }
386                    }
387                }
388                Ok(None) => break,
389                Err(_) => {
390                    warn!("Stream timeout for trace {}", trace_id);
391                    self.emit_error_event(&msg, "Stream timeout".to_string());
392                    return Err("Stream timeout".into());
393                }
394            }
395        }
396
397        Ok(None)
398    }
399
400    /// Execute tool calls and return results.
401    async fn execute_tool_calls(
402        &mut self,
403        tool_calls: &[ToolCallRequest],
404        msg: &InboundMessage,
405    ) -> Vec<(String, String)> {
406        let mut results = Vec::new();
407
408        for tc in tool_calls {
409            // Build args_preview from arguments
410            let args_preview = serde_json::to_string(&tc.arguments)
411                .unwrap_or_default()
412                .chars()
413                .take(100)
414                .collect();
415            
416            self.emit_event(&msg, AgentEvent::ToolCallStarted {
417                name: tc.name.clone(),
418                args_preview,
419                call_id: tc.id.clone(),
420            });
421
422            // Convert arguments HashMap to JSON Value
423            let params = serde_json::to_value(&tc.arguments).unwrap_or(serde_json::Value::Null);
424
425            let result = self.tools.execute(&tc.name, params).await;
426            let is_error = result.starts_with("Error");
427
428            self.emit_event(&msg, AgentEvent::ToolCallFinished {
429                name: tc.name.clone(),
430                result: result.clone(),
431                is_error,
432                call_id: tc.id.clone(),
433            });
434
435            results.push((tc.id.clone(), result));
436        }
437
438        results
439    }
440
441    /// Emit an event to the bus.
442    fn emit_event(&self, msg: &InboundMessage, event: AgentEvent) {
443        if let Err(e) = self.bus.publish_event(&msg.channel, &msg.chat_id, event) {
444            warn!("Failed to emit event: {}", e);
445        }
446    }
447
448    /// Emit an error event.
449    fn emit_error_event(&self, msg: &InboundMessage, message: String) {
450        self.emit_event(msg, AgentEvent::Error { message });
451    }
452
453    /// Process a message directly (for CLI or testing).
454    pub async fn process_direct(
455        &mut self,
456        content: impl Into<String>,
457        channel: impl Into<String>,
458        chat_id: impl Into<String>,
459    ) -> Result<String, Box<dyn std::error::Error>> {
460        let content = content.into();
461        let channel = channel.into();
462        let chat_id = chat_id.into();
463
464        let msg = InboundMessage::new(channel, "user", chat_id, content);
465
466        let response = self.process_inbound_message(msg).await?;
467        Ok(response
468            .map(|r| {
469                let content = r.content;
470                if let Some(reasoning) = r.reasoning_content {
471                    if !reasoning.is_empty() {
472                        return format!("{}\n\n{}", reasoning, content);
473                    }
474                }
475                content
476            })
477            .unwrap_or_default())
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use agent_diva_providers::{LLMResponse, LLMStreamEvent, Message, ProviderResult};
    use async_trait::async_trait;
    use futures::stream;

    /// Provider stub that always answers "mock" and never requests tools.
    struct MockProvider;

    #[async_trait]
    impl LLMProvider for MockProvider {
        async fn chat(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<LLMResponse> {
            Ok(LLMResponse {
                content: Some("mock response".to_string()),
                reasoning_content: None,
                tool_calls: Vec::new(),
                finish_reason: "stop".to_string(),
                usage: std::collections::HashMap::new(),
            })
        }

        async fn chat_stream(
            &self,
            _messages: Vec<Message>,
            _tools: Option<Vec<serde_json::Value>>,
            _model: Option<String>,
            _max_tokens: i32,
            _temperature: f64,
        ) -> ProviderResult<ProviderEventStream> {
            Ok(Box::pin(stream::iter(vec![
                Ok(LLMStreamEvent::TextDelta("mock".to_string())),
                Ok(LLMStreamEvent::Completed(LLMResponse {
                    content: Some("mock".to_string()),
                    reasoning_content: None,
                    tool_calls: Vec::new(),
                    finish_reason: "stop".to_string(),
                    usage: std::collections::HashMap::new(),
                })),
            ])))
        }

        fn get_default_model(&self) -> String {
            "mock-model".to_string()
        }
    }

    #[tokio::test]
    async fn test_nano_agent_loop_creation() {
        let bus = MessageBus::new();
        let provider = Arc::new(MockProvider);
        // Portable, per-process scratch workspace instead of a hard-coded `/tmp` path.
        let workspace = std::env::temp_dir()
            .join(format!("agent-diva-nano-test-{}", std::process::id()));
        let tools = ToolRegistry::new();

        #[cfg(feature = "files")]
        let agent = {
            let storage_path = workspace.join(".agent-diva/files");
            let file_config = agent_diva_files::FileConfig::with_path(&storage_path);
            let file_manager = Arc::new(FileManager::new(file_config).await.unwrap());
            NanoAgentLoop::new(
                bus,
                provider,
                workspace,
                None,
                NanoLoopConfig::default(),
                tools,
                file_manager,
            )
            .await
        };

        #[cfg(not(feature = "files"))]
        let agent = NanoAgentLoop::new(
            bus,
            provider,
            workspace,
            None,
            NanoLoopConfig::default(),
            tools,
        )
        .await;

        let agent = agent.expect("nano agent loop should build with the default config");
        assert_eq!(agent.config.max_iterations, 20);
        // `model: None` falls back to the provider's default model.
        assert_eq!(agent.model, "mock-model");
    }
}