1use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::anyhow;
10use chrono::Utc;
11use futures::stream::BoxStream;
12use futures::{FutureExt, StreamExt};
13use rs_utcp::plugins::codemode::{CodeModeUtcp, CodemodeOrchestrator};
14use rs_utcp::providers::base::Provider as UtcpProvider;
15use rs_utcp::providers::cli::CliProvider;
16use rs_utcp::tools::Tool as UtcpTool;
17use rs_utcp::tools::ToolInputOutputSchema;
18use rs_utcp::UtcpClientInterface;
19use serde_json::{json, Value};
20use toon_format::encode_default;
21use uuid::Uuid;
22
23use crate::agent_orchestrators::{build_orchestrator, format_codemode_value, CodeModeTool};
24use crate::agent_tool::{ensure_agent_cli_transport, InProcessTool};
25use crate::error::{AgentError, Result};
26use crate::memory::{mmr_rerank_records, MemoryRecord, SessionMemory};
27use crate::models::LLM;
28use crate::query::{classify_query, QueryType};
29use crate::tools::ToolCatalog;
30use crate::types::{
31 AgentOptions, AgentState, File, GenerationChunk, GenerationResponse, Message, Role, ToolRequest,
32};
33
/// System prompt used by `Agent::new` when `AgentOptions::system_prompt` is `None`.
const DEFAULT_SYSTEM_PROMPT: &str = "You are a helpful AI assistant. Provide concise, accurate answers and explain when you use tools.";
35
/// Conversational agent that ties together a language model, session memory,
/// a tool catalog and an optional CodeMode engine/orchestrator.
pub struct Agent {
    /// Language model used to generate (and stream) responses.
    model: Arc<dyn LLM>,
    /// Session memory that conversation turns are stored in and read back from.
    memory: Arc<SessionMemory>,
    /// System prompt placed first in every prompt; skipped when empty.
    system_prompt: String,
    /// Budget used when selecting memory records for a prompt. Record sizes
    /// are estimated as `content.len() / 4` and compared against this value.
    context_limit: usize,
    /// Catalog through which tools are registered and invoked.
    tool_catalog: Arc<ToolCatalog>,
    /// CodeMode engine, set through `with_codemode` or
    /// `with_codemode_orchestrator`.
    codemode: Option<Arc<CodeModeUtcp>>,
    /// Orchestrator that is given the user input before the model is called,
    /// when one has been configured.
    codemode_orchestrator: Option<Arc<CodemodeOrchestrator>>,
}
49
50impl Agent {
51 pub fn new(model: Arc<dyn LLM>, memory: Arc<SessionMemory>, options: AgentOptions) -> Self {
53 Self {
54 model,
55 memory,
56 system_prompt: options
57 .system_prompt
58 .unwrap_or_else(|| DEFAULT_SYSTEM_PROMPT.to_string()),
59 context_limit: options.context_limit.unwrap_or(8192),
60 tool_catalog: Arc::new(ToolCatalog::new()),
61 codemode: None,
62 codemode_orchestrator: None,
63 }
64 }
65
66 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
68 self.system_prompt = prompt.into();
69 self
70 }
71
72 pub fn with_tools(mut self, catalog: Arc<ToolCatalog>) -> Self {
74 self.tool_catalog = catalog;
75 self
76 }
77
78 pub fn with_codemode(mut self, engine: Arc<CodeModeUtcp>) -> Self {
80 self.set_codemode(engine);
81 self
82 }
83
84 pub fn with_codemode_orchestrator(
87 mut self,
88 engine: Arc<CodeModeUtcp>,
89 orchestrator_model: Option<Arc<dyn LLM>>,
90 ) -> Self {
91 self.set_codemode(engine.clone());
92
93 let llm = orchestrator_model.unwrap_or_else(|| Arc::clone(&self.model));
94 let orchestrator = build_orchestrator(engine, llm);
95 self.codemode_orchestrator = Some(Arc::new(orchestrator));
96 self
97 }
98
99 pub async fn register_utcp_provider(
101 &self,
102 client: Arc<dyn UtcpClientInterface>,
103 provider: Arc<dyn UtcpProvider>,
104 ) -> Result<Vec<UtcpTool>> {
105 let tools = client
106 .register_tool_provider(provider)
107 .await
108 .map_err(|e| AgentError::UtcpError(e.to_string()))?;
109
110 crate::utcp::register_utcp_tools(self.tool_catalog.as_ref(), client, tools.clone())?;
111 Ok(tools)
112 }
113
114 pub async fn register_utcp_provider_with_tools(
116 &self,
117 client: Arc<dyn UtcpClientInterface>,
118 provider: Arc<dyn UtcpProvider>,
119 tools: Vec<UtcpTool>,
120 ) -> Result<Vec<UtcpTool>> {
121 let registered_tools = client
122 .register_tool_provider_with_tools(provider, tools)
123 .await
124 .map_err(|e| AgentError::UtcpError(e.to_string()))?;
125
126 crate::utcp::register_utcp_tools(
127 self.tool_catalog.as_ref(),
128 client,
129 registered_tools.clone(),
130 )?;
131
132 Ok(registered_tools)
133 }
134
135 pub fn register_utcp_tools(
137 &self,
138 client: Arc<dyn UtcpClientInterface>,
139 tools: Vec<UtcpTool>,
140 ) -> Result<()> {
141 crate::utcp::register_utcp_tools(self.tool_catalog.as_ref(), client, tools)
142 }
143
144 pub fn as_utcp_tool(
146 &self,
147 name: impl Into<String>,
148 description: impl Into<String>,
149 ) -> UtcpTool {
150 let name = name.into();
151 let description = description.into();
152 let provider_name = name
153 .split('.')
154 .next()
155 .map(str::trim)
156 .filter(|s| !s.is_empty())
157 .unwrap_or("agent")
158 .to_string();
159
160 let inputs = ToolInputOutputSchema {
161 type_: "object".to_string(),
162 properties: Some(HashMap::from([
163 (
164 "instruction".to_string(),
165 json!({
166 "type": "string",
167 "description": "The instruction or query for the agent."
168 }),
169 ),
170 (
171 "session_id".to_string(),
172 json!({
173 "type": "string",
174 "description": "Optional session id; defaults to the provider-derived session."
175 }),
176 ),
177 ])),
178 required: Some(vec!["instruction".to_string()]),
179 description: Some("Call the agent with an instruction".to_string()),
180 title: Some("AgentInvocation".to_string()),
181 items: None,
182 enum_: None,
183 minimum: None,
184 maximum: None,
185 format: None,
186 };
187
188 let outputs = ToolInputOutputSchema {
189 type_: "object".to_string(),
190 properties: Some(HashMap::from([
191 ("response".to_string(), json!({ "type": "string" })),
192 ("session_id".to_string(), json!({ "type": "string" })),
193 ])),
194 required: None,
195 description: Some("Agent response payload".to_string()),
196 title: Some("AgentResponse".to_string()),
197 items: None,
198 enum_: None,
199 minimum: None,
200 maximum: None,
201 format: None,
202 };
203
204 UtcpTool {
205 name,
206 description,
207 inputs,
208 outputs,
209 tags: vec![
210 "agent".to_string(),
211 "rs-agent".to_string(),
212 "inproc".to_string(),
213 ],
214 average_response_size: None,
215 provider: Some(json!({
216 "name": provider_name,
217 "provider_type": "cli",
218 })),
219 }
220 }
221
222 pub async fn register_as_utcp_provider(
224 self: Arc<Self>,
225 utcp_client: &dyn UtcpClientInterface,
226 name: impl Into<String>,
227 description: impl Into<String>,
228 ) -> Result<()> {
229 let name = name.into();
230 let description = description.into();
231
232 let provider_name = name
233 .split('.')
234 .next()
235 .map(str::trim)
236 .filter(|s| !s.is_empty())
237 .unwrap_or("agent")
238 .to_string();
239
240 let tool_spec = self.as_utcp_tool(&name, &description);
241 let default_session = format!("{}.session", provider_name);
242 let agent = Arc::clone(&self);
243 let handler = Arc::new(move |args: HashMap<String, Value>| {
244 let agent = Arc::clone(&agent);
245 let default_session = default_session.clone();
246 async move {
247 let instruction = args
248 .get("instruction")
249 .and_then(|v| v.as_str())
250 .map(str::to_string)
251 .filter(|s| !s.trim().is_empty())
252 .ok_or_else(|| anyhow!("missing or invalid 'instruction'"))?;
253
254 let session_id = args
255 .get("session_id")
256 .and_then(|v| v.as_str())
257 .map(str::to_string)
258 .filter(|s| !s.trim().is_empty())
259 .unwrap_or_else(|| default_session.clone());
260
261 let content = agent
262 .generate(session_id, instruction)
263 .await
264 .map_err(|e| anyhow!(e.to_string()))?;
265
266 Ok(Value::String(content))
267 }
268 .boxed()
269 });
270
271 let inproc_tool = InProcessTool {
272 spec: tool_spec.clone(),
273 handler,
274 };
275
276 let transport = ensure_agent_cli_transport();
277 transport.register(&provider_name, inproc_tool);
278
279 let provider = CliProvider::new(
280 provider_name.clone(),
281 format!("rs-agent-{}", provider_name),
282 None,
283 );
284
285 utcp_client
286 .register_tool_provider_with_tools(Arc::new(provider), vec![tool_spec])
287 .await
288 .map_err(|e| AgentError::UtcpError(e.to_string()))?;
289
290 Ok(())
291 }
292
293 pub async fn generate(
295 &self,
296 session_id: impl Into<String>,
297 user_input: impl Into<String>,
298 ) -> Result<String> {
299 let response = self
300 .generate_internal(session_id.into(), user_input.into(), None)
301 .await?;
302
303 encode_default(&response).map_err(|e| AgentError::ToonFormatError(e.to_string()))
304 }
305
306 pub async fn generate_with_files(
308 &self,
309 session_id: impl Into<String>,
310 user_input: impl Into<String>,
311 files: Vec<File>,
312 ) -> Result<String> {
313 let response = self
314 .generate_internal(session_id.into(), user_input.into(), Some(files))
315 .await?;
316
317 Ok(response.content)
318 }
319
320 pub async fn generate_stream(
322 &self,
323 session_id: impl Into<String>,
324 user_input: impl Into<String>,
325 ) -> Result<BoxStream<'static, Result<GenerationChunk>>> {
326 let session_id = session_id.into();
327 let user_input = user_input.into();
328
329 self.store_memory(&session_id, "user", &user_input, None)
331 .await?;
332
333 if let Some((content, metadata)) = self
337 .try_codemode_orchestration(&session_id, &user_input)
338 .await?
339 {
340 self.store_memory(&session_id, "assistant", &content, metadata.clone())
341 .await?;
342
343 let chunk = GenerationChunk { content, metadata };
344 return Ok(futures::stream::once(async move { Ok(chunk) }).boxed());
345 }
346
347 let messages = self.build_prompt(&session_id, &user_input).await?;
349
350 let stream = self.model.stream_generate(messages, None).await?;
352 let memory = self.memory.clone();
353 let session_id_clone = session_id.clone();
354
355 let wrapped = futures::stream::unfold(
357 (stream, memory, session_id_clone, String::new(), false),
358 |(mut stream, memory, session_id, mut accumulated, finished)| async move {
359 if finished {
360 return None;
361 }
362
363 match stream.next().await {
364 Some(Ok(chunk)) => {
365 accumulated.push_str(&chunk.content);
366 Some((
367 Ok(chunk),
368 (stream, memory, session_id, accumulated, false),
369 ))
370 }
371 Some(Err(e)) => Some((Err(e), (stream, memory, session_id, accumulated, true))),
372 None => {
373 let record = MemoryRecord {
378 id: Uuid::new_v4(),
379 session_id: session_id.clone(),
380 role: "assistant".to_string(),
381 content: accumulated,
382 importance: 0.5,
383 timestamp: Utc::now(),
384 metadata: None,
385 embedding: None,
386 };
387
388 if let Err(e) = memory.store(record).await {
389 return Some((Err(AgentError::MemoryError(e.to_string())), (stream, memory, session_id, String::new(), true)));
390 }
391
392 None
393 }
394 }
395 },
396 );
397
398 Ok(wrapped.boxed())
399 }
400
401 pub async fn invoke_tool(
403 &self,
404 session_id: impl Into<String>,
405 tool_name: &str,
406 arguments: HashMap<String, serde_json::Value>,
407 ) -> Result<String> {
408 let session_id = session_id.into();
409
410 let request = ToolRequest {
411 session_id: session_id.clone(),
412 arguments,
413 };
414
415 let response = self.tool_catalog.invoke(tool_name, request).await?;
416
417 self.store_memory(
419 &session_id,
420 "tool",
421 &format!("Called {}: {}", tool_name, response.content),
422 response.metadata,
423 )
424 .await?;
425
426 Ok(response.content)
427 }
428
429 async fn build_prompt(&self, session_id: &str, user_input: &str) -> Result<Vec<Message>> {
431 let mut messages = Vec::new();
432
433 if !self.system_prompt.is_empty() {
435 messages.push(Message {
436 role: Role::System,
437 content: self.system_prompt.clone(),
438 metadata: None,
439 });
440 }
441
442 let mut available_tokens = self.context_limit;
443
444 available_tokens = available_tokens.saturating_sub(user_input.len() / 4);
446
447 let query_type = classify_query(user_input);
449
450 let recent_memories = self.memory.retrieve_recent(session_id).await?;
453 let mut context_messages = Vec::new();
454 let mut recent_ids = std::collections::HashSet::new();
455
456 let recent_token_limit = (available_tokens as f32 * 0.6) as usize;
458 let mut current_tokens = 0;
459
460 for record in recent_memories.iter().rev() {
461 let estimated_tokens = record.content.len() / 4;
462 if current_tokens + estimated_tokens > recent_token_limit {
463 break;
464 }
465
466 recent_ids.insert(record.id);
467 context_messages.push(record.clone());
468 current_tokens += estimated_tokens;
469 }
470
471 if matches!(query_type, QueryType::Complex | QueryType::Math) || context_messages.len() < 5 {
473 let search_limit = 20; let embeddings = self.memory.embed(user_input).await.unwrap_or_default();
475
476 if !embeddings.is_empty() {
477 let search_results = self
479 .memory
480 .search(session_id, user_input, search_limit)
481 .await?;
482
483 let candidates: Vec<MemoryRecord> = search_results
485 .into_iter()
486 .filter(|r| !recent_ids.contains(&r.id))
487 .collect();
488
489 let reranked = mmr_rerank_records(&embeddings, candidates, 5, 0.7);
492
493 for record in reranked {
494 let estimated_tokens = record.content.len() / 4;
495 if current_tokens + estimated_tokens > available_tokens {
496 break;
497 }
498
499 context_messages.push(record);
510 current_tokens += estimated_tokens;
511 }
512 }
513 }
514
515 context_messages.sort_by_key(|r| r.timestamp);
518
519 for record in context_messages {
520 messages.push(Message {
521 role: match record.role.as_str() {
522 "user" => Role::User,
523 "assistant" => Role::Assistant,
524 "tool" => Role::Tool,
525 _ => Role::User,
526 },
527 content: record.content.clone(),
528 metadata: record.metadata.clone(),
529 });
530 }
531
532 messages.push(Message {
534 role: Role::User,
535 content: user_input.to_string(),
536 metadata: None,
537 });
538
539 Ok(messages)
540 }
541
542 async fn generate_internal(
543 &self,
544 session_id: String,
545 user_input: String,
546 files: Option<Vec<File>>,
547 ) -> Result<GenerationResponse> {
548 self.store_memory(&session_id, "user", &user_input, None)
550 .await?;
551
552 let has_files = files.as_ref().map(|f| !f.is_empty()).unwrap_or(false);
554 if !has_files {
555 if let Some((content, metadata)) = self
556 .try_codemode_orchestration(&session_id, &user_input)
557 .await?
558 {
559 self.store_memory(&session_id, "assistant", &content, metadata.clone())
560 .await?;
561
562 return Ok(GenerationResponse { content, metadata });
563 }
564 }
565
566 let messages = self.build_prompt(&session_id, &user_input).await?;
568
569 let response = self.model.generate(messages, files).await?;
571
572 self.store_memory(&session_id, "assistant", &response.content, None)
574 .await?;
575
576 Ok(response)
577 }
578
579 fn set_codemode(&mut self, engine: Arc<CodeModeUtcp>) {
580 self.codemode = Some(engine.clone());
581 let _ = self
583 .tool_catalog
584 .register(Box::new(CodeModeTool::new(engine)));
585 }
586
587 async fn try_codemode_orchestration(
588 &self,
589 _session_id: &str,
590 user_input: &str,
591 ) -> Result<Option<(String, Option<HashMap<String, String>>)>> {
592 let orchestrator = match self.codemode_orchestrator.as_ref() {
593 Some(o) => o,
594 None => return Ok(None),
595 };
596
597 let value = orchestrator
598 .call_prompt(user_input)
599 .await
600 .map_err(|e| AgentError::Other(e.to_string()))?;
601
602 if let Some(v) = value {
603 let content = format_codemode_value(&v);
604 let metadata = Some(HashMap::from([(
605 "source".to_string(),
606 "codemode_orchestrator".to_string(),
607 )]));
608 return Ok(Some((content, metadata)));
609 }
610
611 Ok(None)
612 }
613
614 async fn store_memory(
616 &self,
617 session_id: &str,
618 role: &str,
619 content: &str,
620 metadata: Option<HashMap<String, String>>,
621 ) -> Result<()> {
622 let record = MemoryRecord {
623 id: Uuid::new_v4(),
624 session_id: session_id.to_string(),
625 role: role.to_string(),
626 content: content.to_string(),
627 importance: 0.5, timestamp: Utc::now(),
629 metadata,
630 embedding: None,
631 };
632
633 self.memory.store(record).await
634 }
635
636 pub async fn flush(&self, _session_id: &str) -> Result<()> {
638 self.memory.flush().await
639 }
640
641 pub fn tools(&self) -> Arc<ToolCatalog> {
643 Arc::clone(&self.tool_catalog)
644 }
645
646 pub async fn checkpoint(&self, session_id: &str) -> Result<Vec<u8>> {
648 let recent = self.memory.retrieve_recent(session_id).await?;
649
650 let state = AgentState {
651 system_prompt: self.system_prompt.clone(),
652 short_term: recent,
653 joined_spaces: None,
654 timestamp: Utc::now(),
655 };
656
657 serde_json::to_vec(&state).map_err(|e| AgentError::SerializationError(e))
658 }
659
660 pub async fn restore(&self, _session_id: &str, data: &[u8]) -> Result<()> {
662 let state: AgentState =
663 serde_json::from_slice(data).map_err(|e| AgentError::SerializationError(e))?;
664
665 for record in state.short_term {
667 self.memory.store(record).await?;
668 }
669
670 Ok(())
671 }
672}