Skip to main content

rs_agent/
agent.rs

1//! Core Agent orchestrator
2//!
3//! This module provides the main Agent struct that coordinates LLM calls, memory,
4//! tool invocations, and UTCP integration. Matches the structure from go-agent's agent.go.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::anyhow;
10use chrono::Utc;
11use futures::stream::BoxStream;
12use futures::{FutureExt, StreamExt};
13use rs_utcp::plugins::codemode::{CodeModeUtcp, CodemodeOrchestrator};
14use rs_utcp::providers::base::Provider as UtcpProvider;
15use rs_utcp::providers::cli::CliProvider;
16use rs_utcp::tools::Tool as UtcpTool;
17use rs_utcp::tools::ToolInputOutputSchema;
18use rs_utcp::UtcpClientInterface;
19use serde_json::{json, Value};
20use toon_format::encode_default;
21use uuid::Uuid;
22
23use crate::agent_orchestrators::{build_orchestrator, format_codemode_value, CodeModeTool};
24use crate::agent_tool::{ensure_agent_cli_transport, InProcessTool};
25use crate::error::{AgentError, Result};
26use crate::memory::{MemoryRecord, SessionMemory};
27use crate::models::LLM;
28use crate::tools::ToolCatalog;
29use crate::types::{
30    AgentOptions, AgentState, File, GenerationChunk, GenerationResponse, Message, Role, ToolRequest,
31};
32
/// System prompt used by `Agent::new` when `AgentOptions::system_prompt` is `None`.
const DEFAULT_SYSTEM_PROMPT: &str = concat!(
    "You are a helpful AI assistant. ",
    "Provide concise, accurate answers and explain when you use tools."
);
34
35/// Main Agent orchestrator
36///
37/// The Agent coordinates model calls, memory, tools, and sub-agents. It matches
38/// the structure from go-agent's Agent struct.
39pub struct Agent {
40    model: Arc<dyn LLM>,
41    memory: Arc<SessionMemory>,
42    system_prompt: String,
43    context_limit: usize,
44    tool_catalog: Arc<ToolCatalog>,
45    codemode: Option<Arc<CodeModeUtcp>>,
46    codemode_orchestrator: Option<Arc<CodemodeOrchestrator>>,
47}
48
49impl Agent {
50    /// Creates a new Agent with the given configuration
51    pub fn new(model: Arc<dyn LLM>, memory: Arc<SessionMemory>, options: AgentOptions) -> Self {
52        Self {
53            model,
54            memory,
55            system_prompt: options
56                .system_prompt
57                .unwrap_or_else(|| DEFAULT_SYSTEM_PROMPT.to_string()),
58            context_limit: options.context_limit.unwrap_or(8192),
59            tool_catalog: Arc::new(ToolCatalog::new()),
60            codemode: None,
61            codemode_orchestrator: None,
62        }
63    }
64
65    /// Sets the system prompt
66    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
67        self.system_prompt = prompt.into();
68        self
69    }
70
71    /// Sets the tool catalog
72    pub fn with_tools(mut self, catalog: Arc<ToolCatalog>) -> Self {
73        self.tool_catalog = catalog;
74        self
75    }
76
77    /// Enables CodeMode execution as a first-class tool (`codemode.run_code`).
78    pub fn with_codemode(mut self, engine: Arc<CodeModeUtcp>) -> Self {
79        self.set_codemode(engine);
80        self
81    }
82
83    /// Enables CodeMode plus the Codemode orchestrator for automatic tool routing.
84    /// If `orchestrator_model` is None, the primary agent model is reused.
85    pub fn with_codemode_orchestrator(
86        mut self,
87        engine: Arc<CodeModeUtcp>,
88        orchestrator_model: Option<Arc<dyn LLM>>,
89    ) -> Self {
90        self.set_codemode(engine.clone());
91
92        let llm = orchestrator_model.unwrap_or_else(|| Arc::clone(&self.model));
93        let orchestrator = build_orchestrator(engine, llm);
94        self.codemode_orchestrator = Some(Arc::new(orchestrator));
95        self
96    }
97
98    /// Registers a UTCP provider and loads its tools into the agent's catalog.
99    pub async fn register_utcp_provider(
100        &self,
101        client: Arc<dyn UtcpClientInterface>,
102        provider: Arc<dyn UtcpProvider>,
103    ) -> Result<Vec<UtcpTool>> {
104        let tools = client
105            .register_tool_provider(provider)
106            .await
107            .map_err(|e| AgentError::UtcpError(e.to_string()))?;
108
109        crate::utcp::register_utcp_tools(self.tool_catalog.as_ref(), client, tools.clone())?;
110        Ok(tools)
111    }
112
113    /// Registers a UTCP provider using a predefined set of tools and adds them to the catalog.
114    pub async fn register_utcp_provider_with_tools(
115        &self,
116        client: Arc<dyn UtcpClientInterface>,
117        provider: Arc<dyn UtcpProvider>,
118        tools: Vec<UtcpTool>,
119    ) -> Result<Vec<UtcpTool>> {
120        let registered_tools = client
121            .register_tool_provider_with_tools(provider, tools)
122            .await
123            .map_err(|e| AgentError::UtcpError(e.to_string()))?;
124
125        crate::utcp::register_utcp_tools(
126            self.tool_catalog.as_ref(),
127            client,
128            registered_tools.clone(),
129        )?;
130
131        Ok(registered_tools)
132    }
133
134    /// Registers UTCP tools into the agent's catalog without re-registering the provider.
135    pub fn register_utcp_tools(
136        &self,
137        client: Arc<dyn UtcpClientInterface>,
138        tools: Vec<UtcpTool>,
139    ) -> Result<()> {
140        crate::utcp::register_utcp_tools(self.tool_catalog.as_ref(), client, tools)
141    }
142
143    /// Returns a UTCP tool specification representing this agent as an in-process tool.
144    pub fn as_utcp_tool(
145        &self,
146        name: impl Into<String>,
147        description: impl Into<String>,
148    ) -> UtcpTool {
149        let name = name.into();
150        let description = description.into();
151        let provider_name = name
152            .split('.')
153            .next()
154            .map(str::trim)
155            .filter(|s| !s.is_empty())
156            .unwrap_or("agent")
157            .to_string();
158
159        let inputs = ToolInputOutputSchema {
160            type_: "object".to_string(),
161            properties: Some(HashMap::from([
162                (
163                    "instruction".to_string(),
164                    json!({
165                        "type": "string",
166                        "description": "The instruction or query for the agent."
167                    }),
168                ),
169                (
170                    "session_id".to_string(),
171                    json!({
172                        "type": "string",
173                        "description": "Optional session id; defaults to the provider-derived session."
174                    }),
175                ),
176            ])),
177            required: Some(vec!["instruction".to_string()]),
178            description: Some("Call the agent with an instruction".to_string()),
179            title: Some("AgentInvocation".to_string()),
180            items: None,
181            enum_: None,
182            minimum: None,
183            maximum: None,
184            format: None,
185        };
186
187        let outputs = ToolInputOutputSchema {
188            type_: "object".to_string(),
189            properties: Some(HashMap::from([
190                ("response".to_string(), json!({ "type": "string" })),
191                ("session_id".to_string(), json!({ "type": "string" })),
192            ])),
193            required: None,
194            description: Some("Agent response payload".to_string()),
195            title: Some("AgentResponse".to_string()),
196            items: None,
197            enum_: None,
198            minimum: None,
199            maximum: None,
200            format: None,
201        };
202
203        UtcpTool {
204            name,
205            description,
206            inputs,
207            outputs,
208            tags: vec![
209                "agent".to_string(),
210                "rs-agent".to_string(),
211                "inproc".to_string(),
212            ],
213            average_response_size: None,
214            provider: Some(json!({
215                "name": provider_name,
216                "provider_type": "cli",
217            })),
218        }
219    }
220
221    /// Registers this agent as a UTCP provider using an in-process CLI shim.
222    pub async fn register_as_utcp_provider(
223        self: Arc<Self>,
224        utcp_client: &dyn UtcpClientInterface,
225        name: impl Into<String>,
226        description: impl Into<String>,
227    ) -> Result<()> {
228        let name = name.into();
229        let description = description.into();
230
231        let provider_name = name
232            .split('.')
233            .next()
234            .map(str::trim)
235            .filter(|s| !s.is_empty())
236            .unwrap_or("agent")
237            .to_string();
238
239        let tool_spec = self.as_utcp_tool(&name, &description);
240        let default_session = format!("{}.session", provider_name);
241        let agent = Arc::clone(&self);
242        let handler = Arc::new(move |args: HashMap<String, Value>| {
243            let agent = Arc::clone(&agent);
244            let default_session = default_session.clone();
245            async move {
246                let instruction = args
247                    .get("instruction")
248                    .and_then(|v| v.as_str())
249                    .map(str::to_string)
250                    .filter(|s| !s.trim().is_empty())
251                    .ok_or_else(|| anyhow!("missing or invalid 'instruction'"))?;
252
253                let session_id = args
254                    .get("session_id")
255                    .and_then(|v| v.as_str())
256                    .map(str::to_string)
257                    .filter(|s| !s.trim().is_empty())
258                    .unwrap_or_else(|| default_session.clone());
259
260                let content = agent
261                    .generate(session_id, instruction)
262                    .await
263                    .map_err(|e| anyhow!(e.to_string()))?;
264
265                Ok(Value::String(content))
266            }
267            .boxed()
268        });
269
270        let inproc_tool = InProcessTool {
271            spec: tool_spec.clone(),
272            handler,
273        };
274
275        let transport = ensure_agent_cli_transport();
276        transport.register(&provider_name, inproc_tool);
277
278        let provider = CliProvider::new(
279            provider_name.clone(),
280            format!("rs-agent-{}", provider_name),
281            None,
282        );
283
284        utcp_client
285            .register_tool_provider_with_tools(Arc::new(provider), vec![tool_spec])
286            .await
287            .map_err(|e| AgentError::UtcpError(e.to_string()))?;
288
289        Ok(())
290    }
291
292    /// Generates a response for the given user input, encoded as TOON
293    pub async fn generate(
294        &self,
295        session_id: impl Into<String>,
296        user_input: impl Into<String>,
297    ) -> Result<String> {
298        let response = self
299            .generate_internal(session_id.into(), user_input.into(), None)
300            .await?;
301
302        encode_default(&response).map_err(|e| AgentError::ToonFormatError(e.to_string()))
303    }
304
305    /// Generates a response with file attachments
306    pub async fn generate_with_files(
307        &self,
308        session_id: impl Into<String>,
309        user_input: impl Into<String>,
310        files: Vec<File>,
311    ) -> Result<String> {
312        let response = self
313            .generate_internal(session_id.into(), user_input.into(), Some(files))
314            .await?;
315
316        Ok(response.content)
317    }
318
319    /// Generates a streaming response
320    pub async fn generate_stream(
321        &self,
322        session_id: impl Into<String>,
323        user_input: impl Into<String>,
324    ) -> Result<BoxStream<'static, Result<GenerationChunk>>> {
325        let session_id = session_id.into();
326        let user_input = user_input.into();
327
328        // Store user message
329        self.store_memory(&session_id, "user", &user_input, None)
330            .await?;
331
332        // Try codemode (basic version: just run it and return as a single chunk if it triggers)
333        // Note: For true codemode streaming, we'd need a different architecture.
334        // Here we just check: if codemode triggers, we yield it as one chunk.
335        if let Some((content, metadata)) = self
336            .try_codemode_orchestration(&session_id, &user_input)
337            .await?
338        {
339            self.store_memory(&session_id, "assistant", &content, metadata.clone())
340                .await?;
341
342            let chunk = GenerationChunk { content, metadata };
343            return Ok(futures::stream::once(async move { Ok(chunk) }).boxed());
344        }
345
346        // Build prompt
347        let messages = self.build_prompt(&session_id, &user_input).await?;
348
349        // Start streaming
350        let stream = self.model.stream_generate(messages, None).await?;
351        let memory = self.memory.clone();
352        let session_id_clone = session_id.clone();
353        
354        // Wrap stream to capture content
355        let wrapped = futures::stream::unfold(
356            (stream, memory, session_id_clone, String::new(), false),
357            |(mut stream, memory, session_id, mut accumulated, finished)| async move {
358                if finished {
359                    return None;
360                }
361
362                match stream.next().await {
363                    Some(Ok(chunk)) => {
364                        accumulated.push_str(&chunk.content);
365                        Some((
366                            Ok(chunk),
367                            (stream, memory, session_id, accumulated, false),
368                        ))
369                    }
370                    Some(Err(e)) => Some((Err(e), (stream, memory, session_id, accumulated, true))),
371                    None => {
372                        // Stream finished, save memory
373                        // We spawn this to avoid blocking the stream end, or we could return a final "empty" chunk?
374                        // Better to spawn a detached task for saving, as we can't easily yield an error here after None without changing stream type.
375                        // Or we could execute the save here and wait for it.
376                        let record = MemoryRecord {
377                            id: Uuid::new_v4(),
378                            session_id: session_id.clone(),
379                            role: "assistant".to_string(),
380                            content: accumulated,
381                            importance: 0.5,
382                            timestamp: Utc::now(),
383                            metadata: None,
384                            embedding: None,
385                        };
386
387                        if let Err(e) = memory.store(record).await {
388                             return Some((Err(AgentError::MemoryError(e.to_string())), (stream, memory, session_id, String::new(), true)));
389                        }
390                        
391                        None
392                    }
393                }
394            },
395        );
396
397        Ok(wrapped.boxed())
398    }
399
400    /// Invokes a tool by name
401    pub async fn invoke_tool(
402        &self,
403        session_id: impl Into<String>,
404        tool_name: &str,
405        arguments: HashMap<String, serde_json::Value>,
406    ) -> Result<String> {
407        let session_id = session_id.into();
408
409        let request = ToolRequest {
410            session_id: session_id.clone(),
411            arguments,
412        };
413
414        let response = self.tool_catalog.invoke(tool_name, request).await?;
415
416        // Store tool invocation in memory
417        self.store_memory(
418            &session_id,
419            "tool",
420            &format!("Called {}: {}", tool_name, response.content),
421            response.metadata,
422        )
423        .await?;
424
425        Ok(response.content)
426    }
427
428    /// Builds the prompt with system message and context
429    async fn build_prompt(&self, session_id: &str, user_input: &str) -> Result<Vec<Message>> {
430        let mut messages = Vec::new();
431
432        // Add system prompt if set
433        if !self.system_prompt.is_empty() {
434            messages.push(Message {
435                role: Role::System,
436                content: self.system_prompt.clone(),
437                metadata: None,
438            });
439        }
440
441        // Retrieve recent conversation history
442        let recent_memories = self.memory.retrieve_recent(session_id).await?;
443
444        // Add context from memory (limited by context_limit)
445        let mut token_count = 0;
446        for record in recent_memories.iter().rev() {
447            // Simple token estimation (4 chars ≈ 1 token)
448            let estimated_tokens = record.content.len() / 4;
449            if token_count + estimated_tokens > self.context_limit {
450                break;
451            }
452
453            messages.push(Message {
454                role: match record.role.as_str() {
455                    "user" => Role::User,
456                    "assistant" => Role::Assistant,
457                    "tool" => Role::Tool,
458                    _ => Role::User,
459                },
460                content: record.content.clone(),
461                metadata: record.metadata.clone(),
462            });
463
464            token_count += estimated_tokens;
465        }
466
467        // Add current user input
468        messages.push(Message {
469            role: Role::User,
470            content: user_input.to_string(),
471            metadata: None,
472        });
473
474        Ok(messages)
475    }
476
477    async fn generate_internal(
478        &self,
479        session_id: String,
480        user_input: String,
481        files: Option<Vec<File>>,
482    ) -> Result<GenerationResponse> {
483        // Store user message in memory
484        self.store_memory(&session_id, "user", &user_input, None)
485            .await?;
486
487        // Try CodeMode orchestration before invoking the primary model
488        let has_files = files.as_ref().map(|f| !f.is_empty()).unwrap_or(false);
489        if !has_files {
490            if let Some((content, metadata)) = self
491                .try_codemode_orchestration(&session_id, &user_input)
492                .await?
493            {
494                self.store_memory(&session_id, "assistant", &content, metadata.clone())
495                    .await?;
496
497                return Ok(GenerationResponse { content, metadata });
498            }
499        }
500
501        // Build prompt with context
502        let messages = self.build_prompt(&session_id, &user_input).await?;
503
504        // Generate response
505        let response = self.model.generate(messages, files).await?;
506
507        // Store assistant response in memory
508        self.store_memory(&session_id, "assistant", &response.content, None)
509            .await?;
510
511        Ok(response)
512    }
513
514    fn set_codemode(&mut self, engine: Arc<CodeModeUtcp>) {
515        self.codemode = Some(engine.clone());
516        // Expose codemode.run_code as a tool; ignore duplicate registrations
517        let _ = self
518            .tool_catalog
519            .register(Box::new(CodeModeTool::new(engine)));
520    }
521
522    async fn try_codemode_orchestration(
523        &self,
524        _session_id: &str,
525        user_input: &str,
526    ) -> Result<Option<(String, Option<HashMap<String, String>>)>> {
527        let orchestrator = match self.codemode_orchestrator.as_ref() {
528            Some(o) => o,
529            None => return Ok(None),
530        };
531
532        let value = orchestrator
533            .call_prompt(user_input)
534            .await
535            .map_err(|e| AgentError::Other(e.to_string()))?;
536
537        if let Some(v) = value {
538            let content = format_codemode_value(&v);
539            let metadata = Some(HashMap::from([(
540                "source".to_string(),
541                "codemode_orchestrator".to_string(),
542            )]));
543            return Ok(Some((content, metadata)));
544        }
545
546        Ok(None)
547    }
548
549    /// Stores a memory record
550    async fn store_memory(
551        &self,
552        session_id: &str,
553        role: &str,
554        content: &str,
555        metadata: Option<HashMap<String, String>>,
556    ) -> Result<()> {
557        let record = MemoryRecord {
558            id: Uuid::new_v4(),
559            session_id: session_id.to_string(),
560            role: role.to_string(),
561            content: content.to_string(),
562            importance: 0.5, // Default importance
563            timestamp: Utc::now(),
564            metadata,
565            embedding: None,
566        };
567
568        self.memory.store(record).await
569    }
570
571    /// Flushes memory to persistent store
572    pub async fn flush(&self, _session_id: &str) -> Result<()> {
573        self.memory.flush().await
574    }
575
576    /// Returns the tool catalog
577    pub fn tools(&self) -> Arc<ToolCatalog> {
578        Arc::clone(&self.tool_catalog)
579    }
580
581    /// Checkpoints the agent state for persistence
582    pub async fn checkpoint(&self, session_id: &str) -> Result<Vec<u8>> {
583        let recent = self.memory.retrieve_recent(session_id).await?;
584
585        let state = AgentState {
586            system_prompt: self.system_prompt.clone(),
587            short_term: recent,
588            joined_spaces: None,
589            timestamp: Utc::now(),
590        };
591
592        serde_json::to_vec(&state).map_err(|e| AgentError::SerializationError(e))
593    }
594
595    /// Restores agent state from checkpoint
596    pub async fn restore(&self, _session_id: &str, data: &[u8]) -> Result<()> {
597        let state: AgentState =
598            serde_json::from_slice(data).map_err(|e| AgentError::SerializationError(e))?;
599
600        // Restore memories
601        for record in state.short_term {
602            self.memory.store(record).await?;
603        }
604
605        Ok(())
606    }
607}