Skip to main content

rs_agent/
agent.rs

//! Core agent orchestrator.
//!
//! This module provides the main `Agent` struct, which coordinates LLM calls, memory,
//! tool invocations, and UTCP integration. It mirrors the structure of go-agent's `agent.go`.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::anyhow;
10use chrono::Utc;
11use futures::stream::BoxStream;
12use futures::{FutureExt, StreamExt};
13use rs_utcp::plugins::codemode::{CodeModeUtcp, CodemodeOrchestrator};
14use rs_utcp::providers::base::Provider as UtcpProvider;
15use rs_utcp::providers::cli::CliProvider;
16use rs_utcp::tools::Tool as UtcpTool;
17use rs_utcp::tools::ToolInputOutputSchema;
18use rs_utcp::UtcpClientInterface;
19use serde_json::{json, Value};
20use toon_format::encode_default;
21use uuid::Uuid;
22
23use crate::agent_orchestrators::{build_orchestrator, format_codemode_value, CodeModeTool};
24use crate::agent_tool::{ensure_agent_cli_transport, InProcessTool};
25use crate::error::{AgentError, Result};
26use crate::memory::{mmr_rerank_records, MemoryRecord, SessionMemory};
27use crate::models::LLM;
28use crate::query::{classify_query, QueryType};
29use crate::tools::ToolCatalog;
30use crate::types::{
31    AgentOptions, AgentState, File, GenerationChunk, GenerationResponse, Message, Role, ToolRequest,
32};
33
/// System prompt used when `AgentOptions` does not supply one.
const DEFAULT_SYSTEM_PROMPT: &str = "You are a helpful AI assistant. \
    Provide concise, accurate answers and explain when you use tools.";
35
36/// Main Agent orchestrator
37///
38/// The Agent coordinates model calls, memory, tools, and sub-agents. It matches
39/// the structure from go-agent's Agent struct.
40pub struct Agent {
41    model: Arc<dyn LLM>,
42    memory: Arc<SessionMemory>,
43    system_prompt: String,
44    context_limit: usize,
45    tool_catalog: Arc<ToolCatalog>,
46    codemode: Option<Arc<CodeModeUtcp>>,
47    codemode_orchestrator: Option<Arc<CodemodeOrchestrator>>,
48}
49
50impl Agent {
51    /// Creates a new Agent with the given configuration
52    pub fn new(model: Arc<dyn LLM>, memory: Arc<SessionMemory>, options: AgentOptions) -> Self {
53        Self {
54            model,
55            memory,
56            system_prompt: options
57                .system_prompt
58                .unwrap_or_else(|| DEFAULT_SYSTEM_PROMPT.to_string()),
59            context_limit: options.context_limit.unwrap_or(8192),
60            tool_catalog: Arc::new(ToolCatalog::new()),
61            codemode: None,
62            codemode_orchestrator: None,
63        }
64    }
65
66    /// Sets the system prompt
67    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
68        self.system_prompt = prompt.into();
69        self
70    }
71
72    /// Sets the tool catalog
73    pub fn with_tools(mut self, catalog: Arc<ToolCatalog>) -> Self {
74        self.tool_catalog = catalog;
75        self
76    }
77
78    /// Enables CodeMode execution as a first-class tool (`codemode.run_code`).
79    pub fn with_codemode(mut self, engine: Arc<CodeModeUtcp>) -> Self {
80        self.set_codemode(engine);
81        self
82    }
83
84    /// Enables CodeMode plus the Codemode orchestrator for automatic tool routing.
85    /// If `orchestrator_model` is None, the primary agent model is reused.
86    pub fn with_codemode_orchestrator(
87        mut self,
88        engine: Arc<CodeModeUtcp>,
89        orchestrator_model: Option<Arc<dyn LLM>>,
90    ) -> Self {
91        self.set_codemode(engine.clone());
92
93        let llm = orchestrator_model.unwrap_or_else(|| Arc::clone(&self.model));
94        let orchestrator = build_orchestrator(engine, llm);
95        self.codemode_orchestrator = Some(Arc::new(orchestrator));
96        self
97    }
98
99    /// Registers a UTCP provider and loads its tools into the agent's catalog.
100    pub async fn register_utcp_provider(
101        &self,
102        client: Arc<dyn UtcpClientInterface>,
103        provider: Arc<dyn UtcpProvider>,
104    ) -> Result<Vec<UtcpTool>> {
105        let tools = client
106            .register_tool_provider(provider)
107            .await
108            .map_err(|e| AgentError::UtcpError(e.to_string()))?;
109
110        crate::utcp::register_utcp_tools(self.tool_catalog.as_ref(), client, tools.clone())?;
111        Ok(tools)
112    }
113
114    /// Registers a UTCP provider using a predefined set of tools and adds them to the catalog.
115    pub async fn register_utcp_provider_with_tools(
116        &self,
117        client: Arc<dyn UtcpClientInterface>,
118        provider: Arc<dyn UtcpProvider>,
119        tools: Vec<UtcpTool>,
120    ) -> Result<Vec<UtcpTool>> {
121        let registered_tools = client
122            .register_tool_provider_with_tools(provider, tools)
123            .await
124            .map_err(|e| AgentError::UtcpError(e.to_string()))?;
125
126        crate::utcp::register_utcp_tools(
127            self.tool_catalog.as_ref(),
128            client,
129            registered_tools.clone(),
130        )?;
131
132        Ok(registered_tools)
133    }
134
135    /// Registers UTCP tools into the agent's catalog without re-registering the provider.
136    pub fn register_utcp_tools(
137        &self,
138        client: Arc<dyn UtcpClientInterface>,
139        tools: Vec<UtcpTool>,
140    ) -> Result<()> {
141        crate::utcp::register_utcp_tools(self.tool_catalog.as_ref(), client, tools)
142    }
143
144    /// Returns a UTCP tool specification representing this agent as an in-process tool.
145    pub fn as_utcp_tool(
146        &self,
147        name: impl Into<String>,
148        description: impl Into<String>,
149    ) -> UtcpTool {
150        let name = name.into();
151        let description = description.into();
152        let provider_name = name
153            .split('.')
154            .next()
155            .map(str::trim)
156            .filter(|s| !s.is_empty())
157            .unwrap_or("agent")
158            .to_string();
159
160        let inputs = ToolInputOutputSchema {
161            type_: "object".to_string(),
162            properties: Some(HashMap::from([
163                (
164                    "instruction".to_string(),
165                    json!({
166                        "type": "string",
167                        "description": "The instruction or query for the agent."
168                    }),
169                ),
170                (
171                    "session_id".to_string(),
172                    json!({
173                        "type": "string",
174                        "description": "Optional session id; defaults to the provider-derived session."
175                    }),
176                ),
177            ])),
178            required: Some(vec!["instruction".to_string()]),
179            description: Some("Call the agent with an instruction".to_string()),
180            title: Some("AgentInvocation".to_string()),
181            items: None,
182            enum_: None,
183            minimum: None,
184            maximum: None,
185            format: None,
186        };
187
188        let outputs = ToolInputOutputSchema {
189            type_: "object".to_string(),
190            properties: Some(HashMap::from([
191                ("response".to_string(), json!({ "type": "string" })),
192                ("session_id".to_string(), json!({ "type": "string" })),
193            ])),
194            required: None,
195            description: Some("Agent response payload".to_string()),
196            title: Some("AgentResponse".to_string()),
197            items: None,
198            enum_: None,
199            minimum: None,
200            maximum: None,
201            format: None,
202        };
203
204        UtcpTool {
205            name,
206            description,
207            inputs,
208            outputs,
209            tags: vec![
210                "agent".to_string(),
211                "rs-agent".to_string(),
212                "inproc".to_string(),
213            ],
214            average_response_size: None,
215            provider: Some(json!({
216                "name": provider_name,
217                "provider_type": "cli",
218            })),
219        }
220    }
221
222    /// Registers this agent as a UTCP provider using an in-process CLI shim.
223    pub async fn register_as_utcp_provider(
224        self: Arc<Self>,
225        utcp_client: &dyn UtcpClientInterface,
226        name: impl Into<String>,
227        description: impl Into<String>,
228    ) -> Result<()> {
229        let name = name.into();
230        let description = description.into();
231
232        let provider_name = name
233            .split('.')
234            .next()
235            .map(str::trim)
236            .filter(|s| !s.is_empty())
237            .unwrap_or("agent")
238            .to_string();
239
240        let tool_spec = self.as_utcp_tool(&name, &description);
241        let default_session = format!("{}.session", provider_name);
242        let agent = Arc::clone(&self);
243        let handler = Arc::new(move |args: HashMap<String, Value>| {
244            let agent = Arc::clone(&agent);
245            let default_session = default_session.clone();
246            async move {
247                let instruction = args
248                    .get("instruction")
249                    .and_then(|v| v.as_str())
250                    .map(str::to_string)
251                    .filter(|s| !s.trim().is_empty())
252                    .ok_or_else(|| anyhow!("missing or invalid 'instruction'"))?;
253
254                let session_id = args
255                    .get("session_id")
256                    .and_then(|v| v.as_str())
257                    .map(str::to_string)
258                    .filter(|s| !s.trim().is_empty())
259                    .unwrap_or_else(|| default_session.clone());
260
261                let content = agent
262                    .generate(session_id, instruction)
263                    .await
264                    .map_err(|e| anyhow!(e.to_string()))?;
265
266                Ok(Value::String(content))
267            }
268            .boxed()
269        });
270
271        let inproc_tool = InProcessTool {
272            spec: tool_spec.clone(),
273            handler,
274        };
275
276        let transport = ensure_agent_cli_transport();
277        transport.register(&provider_name, inproc_tool);
278
279        let provider = CliProvider::new(
280            provider_name.clone(),
281            format!("rs-agent-{}", provider_name),
282            None,
283        );
284
285        utcp_client
286            .register_tool_provider_with_tools(Arc::new(provider), vec![tool_spec])
287            .await
288            .map_err(|e| AgentError::UtcpError(e.to_string()))?;
289
290        Ok(())
291    }
292
293    /// Generates a response for the given user input, encoded as TOON
294    pub async fn generate(
295        &self,
296        session_id: impl Into<String>,
297        user_input: impl Into<String>,
298    ) -> Result<String> {
299        let response = self
300            .generate_internal(session_id.into(), user_input.into(), None)
301            .await?;
302
303        encode_default(&response).map_err(|e| AgentError::ToonFormatError(e.to_string()))
304    }
305
306    /// Generates a response with file attachments
307    pub async fn generate_with_files(
308        &self,
309        session_id: impl Into<String>,
310        user_input: impl Into<String>,
311        files: Vec<File>,
312    ) -> Result<String> {
313        let response = self
314            .generate_internal(session_id.into(), user_input.into(), Some(files))
315            .await?;
316
317        Ok(response.content)
318    }
319
320    /// Generates a streaming response
321    pub async fn generate_stream(
322        &self,
323        session_id: impl Into<String>,
324        user_input: impl Into<String>,
325    ) -> Result<BoxStream<'static, Result<GenerationChunk>>> {
326        let session_id = session_id.into();
327        let user_input = user_input.into();
328
329        // Store user message
330        self.store_memory(&session_id, "user", &user_input, None)
331            .await?;
332
333        // Try codemode (basic version: just run it and return as a single chunk if it triggers)
334        // Note: For true codemode streaming, we'd need a different architecture.
335        // Here we just check: if codemode triggers, we yield it as one chunk.
336        if let Some((content, metadata)) = self
337            .try_codemode_orchestration(&session_id, &user_input)
338            .await?
339        {
340            self.store_memory(&session_id, "assistant", &content, metadata.clone())
341                .await?;
342
343            let chunk = GenerationChunk { content, metadata };
344            return Ok(futures::stream::once(async move { Ok(chunk) }).boxed());
345        }
346
347        // Build prompt
348        let messages = self.build_prompt(&session_id, &user_input).await?;
349
350        // Start streaming
351        let stream = self.model.stream_generate(messages, None).await?;
352        let memory = self.memory.clone();
353        let session_id_clone = session_id.clone();
354        
355        // Wrap stream to capture content
356        let wrapped = futures::stream::unfold(
357            (stream, memory, session_id_clone, String::new(), false),
358            |(mut stream, memory, session_id, mut accumulated, finished)| async move {
359                if finished {
360                    return None;
361                }
362
363                match stream.next().await {
364                    Some(Ok(chunk)) => {
365                        accumulated.push_str(&chunk.content);
366                        Some((
367                            Ok(chunk),
368                            (stream, memory, session_id, accumulated, false),
369                        ))
370                    }
371                    Some(Err(e)) => Some((Err(e), (stream, memory, session_id, accumulated, true))),
372                    None => {
373                        // Stream finished, save memory
374                        // We spawn this to avoid blocking the stream end, or we could return a final "empty" chunk?
375                        // Better to spawn a detached task for saving, as we can't easily yield an error here after None without changing stream type.
376                        // Or we could execute the save here and wait for it.
377                        let record = MemoryRecord {
378                            id: Uuid::new_v4(),
379                            session_id: session_id.clone(),
380                            role: "assistant".to_string(),
381                            content: accumulated,
382                            importance: 0.5,
383                            timestamp: Utc::now(),
384                            metadata: None,
385                            embedding: None,
386                        };
387
388                        if let Err(e) = memory.store(record).await {
389                             return Some((Err(AgentError::MemoryError(e.to_string())), (stream, memory, session_id, String::new(), true)));
390                        }
391                        
392                        None
393                    }
394                }
395            },
396        );
397
398        Ok(wrapped.boxed())
399    }
400
401    /// Invokes a tool by name
402    pub async fn invoke_tool(
403        &self,
404        session_id: impl Into<String>,
405        tool_name: &str,
406        arguments: HashMap<String, serde_json::Value>,
407    ) -> Result<String> {
408        let session_id = session_id.into();
409
410        let request = ToolRequest {
411            session_id: session_id.clone(),
412            arguments,
413        };
414
415        let response = self.tool_catalog.invoke(tool_name, request).await?;
416
417        // Store tool invocation in memory
418        self.store_memory(
419            &session_id,
420            "tool",
421            &format!("Called {}: {}", tool_name, response.content),
422            response.metadata,
423        )
424        .await?;
425
426        Ok(response.content)
427    }
428
429    /// Builds the prompt with system message and context, using RAG
430    async fn build_prompt(&self, session_id: &str, user_input: &str) -> Result<Vec<Message>> {
431        let mut messages = Vec::new();
432
433        // Add system prompt if set
434        if !self.system_prompt.is_empty() {
435            messages.push(Message {
436                role: Role::System,
437                content: self.system_prompt.clone(),
438                metadata: None,
439            });
440        }
441
442        let mut available_tokens = self.context_limit;
443
444        // Reserve space for user input (approx)
445        available_tokens = available_tokens.saturating_sub(user_input.len() / 4);
446
447        // Classify query to determine retrieval strategy
448        let query_type = classify_query(user_input);
449
450        // Strategy 1: Recent Conversation History (Short-term memory)
451        // Always retrieve recent context first
452        let recent_memories = self.memory.retrieve_recent(session_id).await?;
453        let mut context_messages = Vec::new();
454        let mut recent_ids = std::collections::HashSet::new();
455
456        // Take up to 60% of available tokens for recent history
457        let recent_token_limit = (available_tokens as f32 * 0.6) as usize;
458        let mut current_tokens = 0;
459
460        for record in recent_memories.iter().rev() {
461            let estimated_tokens = record.content.len() / 4;
462            if current_tokens + estimated_tokens > recent_token_limit {
463                break;
464            }
465
466            recent_ids.insert(record.id);
467            context_messages.push(record.clone());
468            current_tokens += estimated_tokens;
469        }
470
471        // Strategy 2: Semantic Search (RAG) for Complex queries or if we have space
472        if matches!(query_type, QueryType::Complex | QueryType::Math) || context_messages.len() < 5 {
473            let search_limit = 20; // Fetch more candidates for re-ranking
474            let embeddings = self.memory.embed(user_input).await.unwrap_or_default();
475            
476            if !embeddings.is_empty() {
477                // Perform semantic search
478                let search_results = self
479                    .memory
480                    .search(session_id, user_input, search_limit)
481                    .await?;
482
483                // Filter out memories already in recent history
484                let candidates: Vec<MemoryRecord> = search_results
485                    .into_iter()
486                    .filter(|r| !recent_ids.contains(&r.id))
487                    .collect();
488
489                // Apply MMR Re-ranking to diversify context
490                // Lambda 0.7 balances relevance (70%) with diversity (30%)
491                let reranked = mmr_rerank_records(&embeddings, candidates, 5, 0.7);
492
493                for record in reranked {
494                    let estimated_tokens = record.content.len() / 4;
495                    if current_tokens + estimated_tokens > available_tokens {
496                        break;
497                    }
498                    
499                    // Add as a 'system' or 'user' message with context note? 
500                    // Usually RAG context is inserted as a system message or implied context.
501                    // Here we'll insert them before correct chronological messages, but since 
502                    // we're rebuilding the whole list, we need to be careful with ordering.
503                    // Ideally, RAG context goes into the System Prompt or a specific Context block.
504                    // For now, we'll append to context_messages but marks them.
505                    
506                    // Actually, let's just add them to the list we're building.
507                    // To distinguish RAG context from conversation flow, we might wrap it.
508                    // specific for this implementation, we just mix them in as past messages.
509                    context_messages.push(record);
510                    current_tokens += estimated_tokens;
511                }
512            }
513        }
514
515        // Sort context messages by timestamp to maintain basic coherence where possible,
516        // though RAG injections might be out of order historically, they are relevant context.
517        context_messages.sort_by_key(|r| r.timestamp);
518
519        for record in context_messages {
520             messages.push(Message {
521                role: match record.role.as_str() {
522                    "user" => Role::User,
523                    "assistant" => Role::Assistant,
524                    "tool" => Role::Tool,
525                    _ => Role::User,
526                },
527                content: record.content.clone(),
528                metadata: record.metadata.clone(),
529            });
530        }
531
532        // Add current user input
533        messages.push(Message {
534            role: Role::User,
535            content: user_input.to_string(),
536            metadata: None,
537        });
538
539        Ok(messages)
540    }
541
542    async fn generate_internal(
543        &self,
544        session_id: String,
545        user_input: String,
546        files: Option<Vec<File>>,
547    ) -> Result<GenerationResponse> {
548        // Store user message in memory
549        self.store_memory(&session_id, "user", &user_input, None)
550            .await?;
551
552        // Try CodeMode orchestration before invoking the primary model
553        let has_files = files.as_ref().map(|f| !f.is_empty()).unwrap_or(false);
554        if !has_files {
555            if let Some((content, metadata)) = self
556                .try_codemode_orchestration(&session_id, &user_input)
557                .await?
558            {
559                self.store_memory(&session_id, "assistant", &content, metadata.clone())
560                    .await?;
561
562                return Ok(GenerationResponse { content, metadata });
563            }
564        }
565
566        // Build prompt with context
567        let messages = self.build_prompt(&session_id, &user_input).await?;
568
569        // Generate response
570        let response = self.model.generate(messages, files).await?;
571
572        // Store assistant response in memory
573        self.store_memory(&session_id, "assistant", &response.content, None)
574            .await?;
575
576        Ok(response)
577    }
578
579    fn set_codemode(&mut self, engine: Arc<CodeModeUtcp>) {
580        self.codemode = Some(engine.clone());
581        // Expose codemode.run_code as a tool; ignore duplicate registrations
582        let _ = self
583            .tool_catalog
584            .register(Box::new(CodeModeTool::new(engine)));
585    }
586
587    async fn try_codemode_orchestration(
588        &self,
589        _session_id: &str,
590        user_input: &str,
591    ) -> Result<Option<(String, Option<HashMap<String, String>>)>> {
592        let orchestrator = match self.codemode_orchestrator.as_ref() {
593            Some(o) => o,
594            None => return Ok(None),
595        };
596
597        let value = orchestrator
598            .call_prompt(user_input)
599            .await
600            .map_err(|e| AgentError::Other(e.to_string()))?;
601
602        if let Some(v) = value {
603            let content = format_codemode_value(&v);
604            let metadata = Some(HashMap::from([(
605                "source".to_string(),
606                "codemode_orchestrator".to_string(),
607            )]));
608            return Ok(Some((content, metadata)));
609        }
610
611        Ok(None)
612    }
613
614    /// Stores a memory record
615    async fn store_memory(
616        &self,
617        session_id: &str,
618        role: &str,
619        content: &str,
620        metadata: Option<HashMap<String, String>>,
621    ) -> Result<()> {
622        let record = MemoryRecord {
623            id: Uuid::new_v4(),
624            session_id: session_id.to_string(),
625            role: role.to_string(),
626            content: content.to_string(),
627            importance: 0.5, // Default importance
628            timestamp: Utc::now(),
629            metadata,
630            embedding: None,
631        };
632
633        self.memory.store(record).await
634    }
635
636    /// Flushes memory to persistent store
637    pub async fn flush(&self, _session_id: &str) -> Result<()> {
638        self.memory.flush().await
639    }
640
641    /// Returns the tool catalog
642    pub fn tools(&self) -> Arc<ToolCatalog> {
643        Arc::clone(&self.tool_catalog)
644    }
645
646    /// Checkpoints the agent state for persistence
647    pub async fn checkpoint(&self, session_id: &str) -> Result<Vec<u8>> {
648        let recent = self.memory.retrieve_recent(session_id).await?;
649
650        let state = AgentState {
651            system_prompt: self.system_prompt.clone(),
652            short_term: recent,
653            joined_spaces: None,
654            timestamp: Utc::now(),
655        };
656
657        serde_json::to_vec(&state).map_err(|e| AgentError::SerializationError(e))
658    }
659
660    /// Restores agent state from checkpoint
661    pub async fn restore(&self, _session_id: &str, data: &[u8]) -> Result<()> {
662        let state: AgentState =
663            serde_json::from_slice(data).map_err(|e| AgentError::SerializationError(e))?;
664
665        // Restore memories
666        for record in state.short_term {
667            self.memory.store(record).await?;
668        }
669
670        Ok(())
671    }
672}