Skip to main content

agent_diva_nano/
agent.rs

1//! Agent creation and management for agent-diva-nano.
2
3use crate::{NanoConfig, NanoError};
4use crate::tool_assembly::{ToolAssembly, BuiltInToolsConfig};
5use crate::nano_loop::{NanoAgentLoop, NanoLoopConfig, NanoRuntimeControlCommand};
6use crate::internal::context::NanoSoulSettings;
7use crate::internal::provider::{build_provider, build_tool_config};
8use agent_diva_agent::{AgentLoop, AgentLoopToolSet};
9use agent_diva_core::bus::{AgentEvent, InboundMessage, MessageBus};
10#[cfg(feature = "files")]
11use agent_diva_files::{FileManager, FileConfig};
12use agent_diva_providers::DynamicProvider;
13use agent_diva_tooling::{Tool, ToolRegistry};
14use std::path::PathBuf;
15use std::sync::Arc;
16use tokio::sync::mpsc;
17use tokio::task::JoinHandle;
18use tracing::{error, info};
19
/// Selects which agent loop implementation an [`Agent`] runs.
///
/// The enum is a plain, fieldless tag, so it is `Copy` and comparable;
/// callers can store it, pass it by value, and test it with `==`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum AgentLoopMode {
    /// Run agent-diva-agent's `AgentLoop`. This is the default mode.
    #[default]
    Standard,
    /// Run nano's lightweight `NanoAgentLoop`.
    Nano,
}
31
32/// A running agent instance.
33///
34/// Create with [`Agent::new`](Agent::new) and the builder pattern,
35/// then call [`start`](Agent::start) to run the background loop.
36pub struct Agent {
37    bus: MessageBus,
38    provider: Arc<DynamicProvider>,
39    mode: AgentLoopMode,
40    /// For standard mode: tool configuration
41    tool_config: Option<agent_diva_agent::ToolConfig>,
42    /// For nano mode: pre-built tool registry
43    tool_registry: Option<ToolRegistry>,
44    nano_loop_config: Option<NanoLoopConfig>,
45    workspace: PathBuf,
46    model: String,
47    max_iterations: usize,
48    #[cfg(feature = "files")]
49    file_manager: Arc<FileManager>,
50    runtime_control_tx: Option<mpsc::UnboundedSender<NanoRuntimeControlCommand>>,
51    agent_handle: Option<JoinHandle<()>>,
52    outbound_handle: Option<JoinHandle<()>>,
53}
54
55/// Builder for configuring an [`Agent`].
56pub struct AgentBuilder {
57    config: NanoConfig,
58    custom_tools: Vec<Arc<dyn Tool>>,
59    tool_assembly: Option<ToolAssembly>,
60    mode: AgentLoopMode,
61    system_prompt: Option<String>,
62}
63
64impl Agent {
65    /// Start configuring a new agent with default settings.
66    pub fn new(config: NanoConfig) -> AgentBuilder {
67        AgentBuilder {
68            config,
69            custom_tools: Vec::new(),
70            tool_assembly: None,
71            mode: AgentLoopMode::default(),
72            system_prompt: None,
73        }
74    }
75
76    /// Start the background agent loop.
77    pub async fn start(&mut self) -> Result<(), NanoError> {
78        if self.agent_handle.is_some() {
79            return Err(NanoError::Other("Agent already started".to_string()));
80        }
81
82        let bus = self.bus.clone();
83        let provider: Arc<dyn agent_diva_providers::LLMProvider> = self.provider.clone();
84        let model = self.model.clone();
85        let workspace = self.workspace.clone();
86        let max_iterations = self.max_iterations;
87        #[cfg(feature = "files")]
88        let file_manager = self.file_manager.clone();
89
90        let (runtime_control_tx, runtime_control_rx) = mpsc::unbounded_channel();
91        self.runtime_control_tx = Some(runtime_control_tx);
92
93        match self.mode {
94            AgentLoopMode::Standard => {
95                let tool_config = self.tool_config.clone().unwrap_or_default();
96                let registry = self
97                    .tool_registry
98                    .as_ref()
99                    .map(clone_registry)
100                    .unwrap_or_default();
101                
102                #[cfg(not(feature = "files"))]
103                {
104                    return Err(NanoError::Other("Standard mode requires 'files' feature. Use Nano mode or enable 'files' feature.".to_string()));
105                }
106                
107                #[cfg(feature = "files")]
108                {
109                    let mut agent_loop = AgentLoop::with_toolset(
110                        bus.clone(),
111                        provider,
112                        workspace,
113                        Some(model),
114                        Some(max_iterations),
115                        AgentLoopToolSet {
116                            registry,
117                            config: tool_config,
118                        },
119                        None, // No runtime control for standard mode (different type)
120                        file_manager,
121                    ).await.map_err(|e| NanoError::Other(e.to_string()))?;
122
123                    let agent_handle = tokio::spawn(async move {
124                        info!("Agent loop (standard) starting");
125                        if let Err(e) = agent_loop.run().await {
126                            error!("Agent loop error: {}", e);
127                        }
128                        info!("Agent loop (standard) stopped");
129                    });
130                    self.agent_handle = Some(agent_handle);
131                }
132            }
133            AgentLoopMode::Nano => {
134                let tool_registry = self
135                    .tool_registry
136                    .as_ref()
137                    .map(clone_registry)
138                    .unwrap_or_default();
139                let nano_config = self.nano_loop_config.clone().unwrap_or_default();
140
141                let mut nano_loop = NanoAgentLoop::new(
142                    bus.clone(),
143                    provider,
144                    workspace,
145                    Some(model),
146                    nano_config,
147                    tool_registry,
148                    #[cfg(feature = "files")]
149                    file_manager,
150                ).await.map_err(|e| NanoError::Other(e.to_string()))?;
151
152                nano_loop = nano_loop.with_runtime_control(runtime_control_rx);
153
154                let agent_handle = tokio::spawn(async move {
155                    info!("Nano agent loop starting");
156                    if let Err(e) = nano_loop.run().await {
157                        error!("Nano agent loop error: {}", e);
158                    }
159                    info!("Nano agent loop stopped");
160                });
161                self.agent_handle = Some(agent_handle);
162            }
163        }
164
165        let bus_for_outbound = self.bus.clone();
166        let outbound_handle = tokio::spawn(async move {
167            bus_for_outbound.dispatch_outbound_loop().await;
168        });
169        self.outbound_handle = Some(outbound_handle);
170
171        Ok(())
172    }
173
174    /// Send a message and wait for the complete text response.
175    pub async fn send(&self, message: impl Into<String>) -> Result<String, NanoError> {
176        let content = message.into();
177        let channel = "nano";
178        let chat_id = "default";
179
180        let mut event_rx = self.bus.subscribe_events();
181
182        let inbound = InboundMessage::new(channel, "user", chat_id, content);
183        self.bus
184            .publish_inbound(inbound)
185            .map_err(|e| NanoError::Agent(e.to_string()))?;
186
187        let mut full_response = String::new();
188        loop {
189            match tokio::time::timeout(
190                std::time::Duration::from_secs(300),
191                event_rx.recv(),
192            )
193            .await
194            {
195                Ok(Ok(bus_event)) => {
196                    if bus_event.channel != channel || bus_event.chat_id != chat_id {
197                        continue;
198                    }
199                    match bus_event.event {
200                        AgentEvent::AssistantDelta { text } => full_response.push_str(&text),
201                        AgentEvent::FinalResponse { content } => {
202                            full_response = content;
203                            break;
204                        }
205                        AgentEvent::Error { message } => {
206                            return Err(NanoError::Agent(message));
207                        }
208                        _ => {}
209                    }
210                }
211                Ok(Err(_)) => break,
212                Err(_) => return Err(NanoError::Timeout),
213            }
214        }
215
216        Ok(full_response)
217    }
218
219    /// Send a message and return a channel that receives all agent events.
220    pub async fn send_stream(
221        &self,
222        message: impl Into<String>,
223    ) -> Result<mpsc::UnboundedReceiver<AgentEvent>, NanoError> {
224        let content = message.into();
225        let channel = "nano";
226        let chat_id = "default";
227
228        let mut event_rx = self.bus.subscribe_events();
229
230        let inbound = InboundMessage::new(channel, "user", chat_id, content);
231        self.bus
232            .publish_inbound(inbound)
233            .map_err(|e| NanoError::Agent(e.to_string()))?;
234
235        let (tx, rx) = mpsc::unbounded_channel::<AgentEvent>();
236
237        tokio::spawn(async move {
238            loop {
239                match tokio::time::timeout(
240                    std::time::Duration::from_secs(300),
241                    event_rx.recv(),
242                )
243                .await
244                {
245                    Ok(Ok(bus_event)) => {
246                        if bus_event.channel != channel || bus_event.chat_id != chat_id {
247                            continue;
248                        }
249                        let is_final = matches!(
250                            bus_event.event,
251                            AgentEvent::FinalResponse { .. } | AgentEvent::Error { .. }
252                        );
253                        if tx.send(bus_event.event).is_err() {
254                            break;
255                        }
256                        if is_final {
257                            break;
258                        }
259                    }
260                    _ => break,
261                }
262            }
263        });
264
265        Ok(rx)
266    }
267
268    /// Dynamically reload tools (only works in Nano mode).
269    pub fn reload_tools(&self, registry: ToolRegistry) -> Result<(), NanoError> {
270        if let Some(ref tx) = self.runtime_control_tx {
271            tx.send(NanoRuntimeControlCommand::ReloadTools(registry))
272                .map_err(|e| NanoError::Other(e.to_string()))?;
273            Ok(())
274        } else {
275            Err(NanoError::Other("Runtime control not available (either agent not started or using Standard mode)".to_string()))
276        }
277    }
278
279    /// Cancel a specific session (only works in Nano mode).
280    pub fn cancel_session(&self, chat_id: impl Into<String>) -> Result<(), NanoError> {
281        if let Some(ref tx) = self.runtime_control_tx {
282            tx.send(NanoRuntimeControlCommand::CancelSession { chat_id: chat_id.into() })
283                .map_err(|e| NanoError::Other(e.to_string()))?;
284            Ok(())
285        } else {
286            Err(NanoError::Other("Runtime control not available".to_string()))
287        }
288    }
289
290    /// Stop the background agent loop.
291    pub async fn stop(&mut self) {
292        // Send stop command if in Nano mode
293        if let Some(ref tx) = self.runtime_control_tx {
294            let _ = tx.send(NanoRuntimeControlCommand::Stop);
295        }
296
297        if let Some(handle) = self.agent_handle.take() {
298            handle.abort();
299            let _ = handle.await;
300        }
301        if let Some(handle) = self.outbound_handle.take() {
302            handle.abort();
303            let _ = handle.await;
304        }
305        self.bus.stop().await;
306    }
307}
308
309impl AgentBuilder {
310    /// Set the model identifier.
311    pub fn model(mut self, model: impl Into<String>) -> Self {
312        self.config.model = model.into();
313        self
314    }
315
316    /// Set the API key.
317    pub fn api_key(mut self, key: impl Into<String>) -> Self {
318        self.config.api_key = key.into();
319        self
320    }
321
322    /// Set a custom API base URL.
323    pub fn api_base(mut self, base: impl Into<String>) -> Self {
324        self.config.api_base = Some(base.into());
325        self
326    }
327
328    /// Set the workspace directory.
329    pub fn workspace(mut self, path: impl Into<PathBuf>) -> Self {
330        self.config.workspace = path.into();
331        self
332    }
333
334    /// Set the maximum number of tool iterations.
335    pub fn max_iterations(mut self, n: usize) -> Self {
336        self.config.max_iterations = n;
337        self
338    }
339
340    /// Set the agent loop mode.
341    /// - `Standard`: Use agent-diva-agent's AgentLoop (default)
342    /// - `Nano`: Use nano's lightweight NanoAgentLoop with full tool control
343    pub fn mode(mut self, mode: AgentLoopMode) -> Self {
344        self.mode = mode;
345        self
346    }
347
348    /// Use nano mode for full tool control.
349    pub fn nano_mode(self) -> Self {
350        self.mode(AgentLoopMode::Nano)
351    }
352
353    /// Use standard mode (agent-diva-agent's AgentLoop).
354    pub fn standard_mode(self) -> Self {
355        self.mode(AgentLoopMode::Standard)
356    }
357
358    /// Add a custom tool.
359    /// In Standard mode, these will be added to the tool registry.
360    /// In Nano mode, use `with_tool_assembly` for more control.
361    pub fn with_tool(mut self, tool: Arc<dyn Tool>) -> Self {
362        self.custom_tools.push(tool);
363        self
364    }
365
366    /// Set a custom ToolAssembly for Nano mode.
367    /// This provides full control over which built-in and custom tools are available.
368    /// Note: Only effective in Nano mode. In Standard mode, use NanoConfig fields.
369    pub fn with_tool_assembly(mut self, assembly: ToolAssembly) -> Self {
370        self.tool_assembly = Some(assembly);
371        self.mode = AgentLoopMode::Standard;
372        self
373    }
374
375    /// Configure built-in tools using BuiltInToolsConfig.
376    /// Shortcut for creating a ToolAssembly.
377    pub fn builtin_tools(mut self, config: BuiltInToolsConfig) -> Self {
378        let workspace = self.config.workspace.clone();
379        let assembly = ToolAssembly::new(workspace)
380            .builtin(config);
381        self.tool_assembly = Some(assembly);
382        self.mode = AgentLoopMode::Standard;
383        self
384    }
385
386    /// Set a custom system prompt (only effective in Nano mode).
387    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
388        self.system_prompt = Some(prompt.into());
389        self
390    }
391
392    /// Build the [`Agent`].
393    pub async fn build(self) -> Result<Agent, NanoError> {
394        let config = self.config;
395        if config.model.is_empty() {
396            return Err(NanoError::Other("model must be set".to_string()));
397        }
398
399        let bus = MessageBus::new();
400        let client = build_provider(
401            &config.model,
402            &config.api_key,
403            config.api_base.as_deref(),
404        )?;
405        let provider = Arc::new(DynamicProvider::new(Arc::new(client)));
406        let workspace = config.workspace.clone();
407        let model = config.model.clone();
408        let max_iterations = config.max_iterations;
409
410        // Initialize file manager (only with files feature)
411        #[cfg(feature = "files")]
412        let file_manager = {
413            let storage_path = workspace.join(".agent-diva/files");
414            let file_config = FileConfig::with_path(&storage_path);
415            Arc::new(FileManager::new(file_config).await.map_err(|e| NanoError::Other(e.to_string()))?)
416        };
417
418        match self.mode {
419            AgentLoopMode::Standard => {
420                let tool_config = build_tool_config(&config);
421                let mut assembly = if let Some(assembly) = self.tool_assembly {
422                    assembly
423                } else {
424                    let builtin_config = config
425                        .builtin_tools
426                        .clone()
427                        .unwrap_or_else(BuiltInToolsConfig::default);
428                    let mut assembly = ToolAssembly::new(workspace.clone())
429                        .builtin(builtin_config)
430                        .restrict_to_workspace(config.restrict_to_workspace);
431                    if let Some(ref search) = config.web_search {
432                        assembly = assembly.web_config(crate::tool_assembly::WebToolConfig {
433                            search_enabled: true,
434                            fetch_enabled: true,
435                            search_provider: search.provider.clone(),
436                            search_api_key: search.api_key.clone(),
437                            max_results: search.max_results,
438                        });
439                    }
440                    if !config.mcp_servers.is_empty() {
441                        assembly = assembly.mcp_servers(config.mcp_servers.clone());
442                    }
443                    assembly
444                };
445                for tool in self.custom_tools {
446                    assembly = assembly.with_tool(tool);
447                }
448                #[cfg(feature = "files")]
449                {
450                    assembly = assembly.with_file_manager(file_manager.clone());
451                }
452                let tool_registry = assembly.build();
453
454                Ok(Agent {
455                    bus,
456                    provider,
457                    mode: AgentLoopMode::Standard,
458                    tool_config: Some(tool_config),
459                    tool_registry: Some(tool_registry),
460                    nano_loop_config: None,
461                    workspace,
462                    model,
463                    max_iterations,
464                    #[cfg(feature = "files")]
465                    file_manager,
466                    runtime_control_tx: None,
467                    agent_handle: None,
468                    outbound_handle: None,
469                })
470            }
471            AgentLoopMode::Nano => {
472                let tool_registry = if let Some(mut assembly) = self.tool_assembly {
473                    for tool in self.custom_tools {
474                        assembly = assembly.with_tool(tool);
475                    }
476                    #[cfg(feature = "files")]
477                    {
478                        assembly = assembly.with_file_manager(file_manager.clone());
479                    }
480                    assembly.build()
481                } else {
482                    let builtin_config = config.builtin_tools.clone()
483                        .unwrap_or_else(|| {
484                            if config.restrict_to_workspace {
485                                BuiltInToolsConfig::default()
486                            } else {
487                                BuiltInToolsConfig::all()
488                            }
489                        });
490                    
491                    let mut assembly = ToolAssembly::new(workspace.clone())
492                        .builtin(builtin_config)
493                        .restrict_to_workspace(config.restrict_to_workspace);
494                    if let Some(ref search) = config.web_search {
495                        assembly = assembly.web_config(crate::tool_assembly::WebToolConfig {
496                            search_enabled: true,
497                            fetch_enabled: true,
498                            search_provider: search.provider.clone(),
499                            search_api_key: search.api_key.clone(),
500                            max_results: search.max_results,
501                        });
502                    }
503
504                    for tool in self.custom_tools {
505                        assembly = assembly.with_tool(tool);
506                    }
507
508                    if !config.mcp_servers.is_empty() {
509                        assembly = assembly.mcp_servers(config.mcp_servers.clone());
510                    }
511                    #[cfg(feature = "files")]
512                    {
513                        assembly = assembly.with_file_manager(file_manager.clone());
514                    }
515
516                    assembly.build()
517                };
518
519                // Build NanoLoopConfig
520                let nano_loop_config = NanoLoopConfig {
521                    max_iterations,
522                    memory_window: 10,
523                    soul_settings: NanoSoulSettings {
524                        enabled: config.soul.enabled,
525                        max_chars: config.soul.max_chars,
526                        bootstrap_once: config.soul.bootstrap_once,
527                    },
528                    notify_on_soul_change: config.soul.notify_on_change,
529                };
530
531                Ok(Agent {
532                    bus,
533                    provider,
534                    mode: AgentLoopMode::Nano,
535                    tool_config: None,
536                    tool_registry: Some(tool_registry),
537                    nano_loop_config: Some(nano_loop_config),
538                    workspace,
539                    model,
540                    max_iterations,
541                    #[cfg(feature = "files")]
542                    file_manager,
543                    runtime_control_tx: None,
544                    agent_handle: None,
545                    outbound_handle: None,
546                })
547            }
548        }
549    }
550}
551
552fn clone_registry(registry: &ToolRegistry) -> ToolRegistry {
553    let mut cloned = ToolRegistry::new();
554    for name in registry.tool_names() {
555        if let Some(tool) = registry.get(&name) {
556            cloned.register(tool);
557        }
558    }
559    cloned
560}