1use std::collections::HashMap;
7use std::sync::Arc;
8
9use anyhow::anyhow;
10use chrono::Utc;
11use futures::stream::BoxStream;
12use futures::{FutureExt, StreamExt};
13use rs_utcp::plugins::codemode::{CodeModeUtcp, CodemodeOrchestrator};
14use rs_utcp::providers::base::Provider as UtcpProvider;
15use rs_utcp::providers::cli::CliProvider;
16use rs_utcp::tools::Tool as UtcpTool;
17use rs_utcp::tools::ToolInputOutputSchema;
18use rs_utcp::UtcpClientInterface;
19use serde_json::{json, Value};
20use toon_format::encode_default;
21use uuid::Uuid;
22
23use crate::agent_orchestrators::{build_orchestrator, format_codemode_value, CodeModeTool};
24use crate::agent_tool::{ensure_agent_cli_transport, InProcessTool};
25use crate::error::{AgentError, Result};
26use crate::memory::{MemoryRecord, SessionMemory};
27use crate::models::LLM;
28use crate::tools::ToolCatalog;
29use crate::types::{
30 AgentOptions, AgentState, File, GenerationChunk, GenerationResponse, Message, Role, ToolRequest,
31};
32
/// System prompt used by `Agent::new` when `AgentOptions::system_prompt` is `None`.
const DEFAULT_SYSTEM_PROMPT: &str = "You are a helpful AI assistant. Provide concise, accurate answers and explain when you use tools.";
34
/// An LLM-backed agent that combines a model, per-session memory, a tool
/// catalog and an optional CodeMode engine/orchestrator.
pub struct Agent {
    /// Model that `generate` / `stream_generate` are called on.
    model: Arc<dyn LLM>,
    /// Store that conversation records are written to and read back from.
    memory: Arc<SessionMemory>,
    /// Text sent as the leading system message; skipped when empty.
    system_prompt: String,
    /// Budget, in estimated tokens (content length / 4), for the history
    /// included in a prompt. `Agent::new` defaults it to 8192.
    context_limit: usize,
    /// Catalog that `invoke_tool` dispatches to.
    tool_catalog: Arc<ToolCatalog>,
    /// CodeMode engine, set by `set_codemode`.
    codemode: Option<Arc<CodeModeUtcp>>,
    /// Orchestrator consulted before falling back to the plain model call.
    codemode_orchestrator: Option<Arc<CodemodeOrchestrator>>,
}
48
49impl Agent {
50 pub fn new(model: Arc<dyn LLM>, memory: Arc<SessionMemory>, options: AgentOptions) -> Self {
52 Self {
53 model,
54 memory,
55 system_prompt: options
56 .system_prompt
57 .unwrap_or_else(|| DEFAULT_SYSTEM_PROMPT.to_string()),
58 context_limit: options.context_limit.unwrap_or(8192),
59 tool_catalog: Arc::new(ToolCatalog::new()),
60 codemode: None,
61 codemode_orchestrator: None,
62 }
63 }
64
65 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
67 self.system_prompt = prompt.into();
68 self
69 }
70
71 pub fn with_tools(mut self, catalog: Arc<ToolCatalog>) -> Self {
73 self.tool_catalog = catalog;
74 self
75 }
76
77 pub fn with_codemode(mut self, engine: Arc<CodeModeUtcp>) -> Self {
79 self.set_codemode(engine);
80 self
81 }
82
83 pub fn with_codemode_orchestrator(
86 mut self,
87 engine: Arc<CodeModeUtcp>,
88 orchestrator_model: Option<Arc<dyn LLM>>,
89 ) -> Self {
90 self.set_codemode(engine.clone());
91
92 let llm = orchestrator_model.unwrap_or_else(|| Arc::clone(&self.model));
93 let orchestrator = build_orchestrator(engine, llm);
94 self.codemode_orchestrator = Some(Arc::new(orchestrator));
95 self
96 }
97
98 pub async fn register_utcp_provider(
100 &self,
101 client: Arc<dyn UtcpClientInterface>,
102 provider: Arc<dyn UtcpProvider>,
103 ) -> Result<Vec<UtcpTool>> {
104 let tools = client
105 .register_tool_provider(provider)
106 .await
107 .map_err(|e| AgentError::UtcpError(e.to_string()))?;
108
109 crate::utcp::register_utcp_tools(self.tool_catalog.as_ref(), client, tools.clone())?;
110 Ok(tools)
111 }
112
113 pub async fn register_utcp_provider_with_tools(
115 &self,
116 client: Arc<dyn UtcpClientInterface>,
117 provider: Arc<dyn UtcpProvider>,
118 tools: Vec<UtcpTool>,
119 ) -> Result<Vec<UtcpTool>> {
120 let registered_tools = client
121 .register_tool_provider_with_tools(provider, tools)
122 .await
123 .map_err(|e| AgentError::UtcpError(e.to_string()))?;
124
125 crate::utcp::register_utcp_tools(
126 self.tool_catalog.as_ref(),
127 client,
128 registered_tools.clone(),
129 )?;
130
131 Ok(registered_tools)
132 }
133
134 pub fn register_utcp_tools(
136 &self,
137 client: Arc<dyn UtcpClientInterface>,
138 tools: Vec<UtcpTool>,
139 ) -> Result<()> {
140 crate::utcp::register_utcp_tools(self.tool_catalog.as_ref(), client, tools)
141 }
142
143 pub fn as_utcp_tool(
145 &self,
146 name: impl Into<String>,
147 description: impl Into<String>,
148 ) -> UtcpTool {
149 let name = name.into();
150 let description = description.into();
151 let provider_name = name
152 .split('.')
153 .next()
154 .map(str::trim)
155 .filter(|s| !s.is_empty())
156 .unwrap_or("agent")
157 .to_string();
158
159 let inputs = ToolInputOutputSchema {
160 type_: "object".to_string(),
161 properties: Some(HashMap::from([
162 (
163 "instruction".to_string(),
164 json!({
165 "type": "string",
166 "description": "The instruction or query for the agent."
167 }),
168 ),
169 (
170 "session_id".to_string(),
171 json!({
172 "type": "string",
173 "description": "Optional session id; defaults to the provider-derived session."
174 }),
175 ),
176 ])),
177 required: Some(vec!["instruction".to_string()]),
178 description: Some("Call the agent with an instruction".to_string()),
179 title: Some("AgentInvocation".to_string()),
180 items: None,
181 enum_: None,
182 minimum: None,
183 maximum: None,
184 format: None,
185 };
186
187 let outputs = ToolInputOutputSchema {
188 type_: "object".to_string(),
189 properties: Some(HashMap::from([
190 ("response".to_string(), json!({ "type": "string" })),
191 ("session_id".to_string(), json!({ "type": "string" })),
192 ])),
193 required: None,
194 description: Some("Agent response payload".to_string()),
195 title: Some("AgentResponse".to_string()),
196 items: None,
197 enum_: None,
198 minimum: None,
199 maximum: None,
200 format: None,
201 };
202
203 UtcpTool {
204 name,
205 description,
206 inputs,
207 outputs,
208 tags: vec![
209 "agent".to_string(),
210 "rs-agent".to_string(),
211 "inproc".to_string(),
212 ],
213 average_response_size: None,
214 provider: Some(json!({
215 "name": provider_name,
216 "provider_type": "cli",
217 })),
218 }
219 }
220
221 pub async fn register_as_utcp_provider(
223 self: Arc<Self>,
224 utcp_client: &dyn UtcpClientInterface,
225 name: impl Into<String>,
226 description: impl Into<String>,
227 ) -> Result<()> {
228 let name = name.into();
229 let description = description.into();
230
231 let provider_name = name
232 .split('.')
233 .next()
234 .map(str::trim)
235 .filter(|s| !s.is_empty())
236 .unwrap_or("agent")
237 .to_string();
238
239 let tool_spec = self.as_utcp_tool(&name, &description);
240 let default_session = format!("{}.session", provider_name);
241 let agent = Arc::clone(&self);
242 let handler = Arc::new(move |args: HashMap<String, Value>| {
243 let agent = Arc::clone(&agent);
244 let default_session = default_session.clone();
245 async move {
246 let instruction = args
247 .get("instruction")
248 .and_then(|v| v.as_str())
249 .map(str::to_string)
250 .filter(|s| !s.trim().is_empty())
251 .ok_or_else(|| anyhow!("missing or invalid 'instruction'"))?;
252
253 let session_id = args
254 .get("session_id")
255 .and_then(|v| v.as_str())
256 .map(str::to_string)
257 .filter(|s| !s.trim().is_empty())
258 .unwrap_or_else(|| default_session.clone());
259
260 let content = agent
261 .generate(session_id, instruction)
262 .await
263 .map_err(|e| anyhow!(e.to_string()))?;
264
265 Ok(Value::String(content))
266 }
267 .boxed()
268 });
269
270 let inproc_tool = InProcessTool {
271 spec: tool_spec.clone(),
272 handler,
273 };
274
275 let transport = ensure_agent_cli_transport();
276 transport.register(&provider_name, inproc_tool);
277
278 let provider = CliProvider::new(
279 provider_name.clone(),
280 format!("rs-agent-{}", provider_name),
281 None,
282 );
283
284 utcp_client
285 .register_tool_provider_with_tools(Arc::new(provider), vec![tool_spec])
286 .await
287 .map_err(|e| AgentError::UtcpError(e.to_string()))?;
288
289 Ok(())
290 }
291
292 pub async fn generate(
294 &self,
295 session_id: impl Into<String>,
296 user_input: impl Into<String>,
297 ) -> Result<String> {
298 let response = self
299 .generate_internal(session_id.into(), user_input.into(), None)
300 .await?;
301
302 encode_default(&response).map_err(|e| AgentError::ToonFormatError(e.to_string()))
303 }
304
305 pub async fn generate_with_files(
307 &self,
308 session_id: impl Into<String>,
309 user_input: impl Into<String>,
310 files: Vec<File>,
311 ) -> Result<String> {
312 let response = self
313 .generate_internal(session_id.into(), user_input.into(), Some(files))
314 .await?;
315
316 Ok(response.content)
317 }
318
319 pub async fn generate_stream(
321 &self,
322 session_id: impl Into<String>,
323 user_input: impl Into<String>,
324 ) -> Result<BoxStream<'static, Result<GenerationChunk>>> {
325 let session_id = session_id.into();
326 let user_input = user_input.into();
327
328 self.store_memory(&session_id, "user", &user_input, None)
330 .await?;
331
332 if let Some((content, metadata)) = self
336 .try_codemode_orchestration(&session_id, &user_input)
337 .await?
338 {
339 self.store_memory(&session_id, "assistant", &content, metadata.clone())
340 .await?;
341
342 let chunk = GenerationChunk { content, metadata };
343 return Ok(futures::stream::once(async move { Ok(chunk) }).boxed());
344 }
345
346 let messages = self.build_prompt(&session_id, &user_input).await?;
348
349 let stream = self.model.stream_generate(messages, None).await?;
351 let memory = self.memory.clone();
352 let session_id_clone = session_id.clone();
353
354 let wrapped = futures::stream::unfold(
356 (stream, memory, session_id_clone, String::new(), false),
357 |(mut stream, memory, session_id, mut accumulated, finished)| async move {
358 if finished {
359 return None;
360 }
361
362 match stream.next().await {
363 Some(Ok(chunk)) => {
364 accumulated.push_str(&chunk.content);
365 Some((
366 Ok(chunk),
367 (stream, memory, session_id, accumulated, false),
368 ))
369 }
370 Some(Err(e)) => Some((Err(e), (stream, memory, session_id, accumulated, true))),
371 None => {
372 let record = MemoryRecord {
377 id: Uuid::new_v4(),
378 session_id: session_id.clone(),
379 role: "assistant".to_string(),
380 content: accumulated,
381 importance: 0.5,
382 timestamp: Utc::now(),
383 metadata: None,
384 embedding: None,
385 };
386
387 if let Err(e) = memory.store(record).await {
388 return Some((Err(AgentError::MemoryError(e.to_string())), (stream, memory, session_id, String::new(), true)));
389 }
390
391 None
392 }
393 }
394 },
395 );
396
397 Ok(wrapped.boxed())
398 }
399
400 pub async fn invoke_tool(
402 &self,
403 session_id: impl Into<String>,
404 tool_name: &str,
405 arguments: HashMap<String, serde_json::Value>,
406 ) -> Result<String> {
407 let session_id = session_id.into();
408
409 let request = ToolRequest {
410 session_id: session_id.clone(),
411 arguments,
412 };
413
414 let response = self.tool_catalog.invoke(tool_name, request).await?;
415
416 self.store_memory(
418 &session_id,
419 "tool",
420 &format!("Called {}: {}", tool_name, response.content),
421 response.metadata,
422 )
423 .await?;
424
425 Ok(response.content)
426 }
427
428 async fn build_prompt(&self, session_id: &str, user_input: &str) -> Result<Vec<Message>> {
430 let mut messages = Vec::new();
431
432 if !self.system_prompt.is_empty() {
434 messages.push(Message {
435 role: Role::System,
436 content: self.system_prompt.clone(),
437 metadata: None,
438 });
439 }
440
441 let recent_memories = self.memory.retrieve_recent(session_id).await?;
443
444 let mut token_count = 0;
446 for record in recent_memories.iter().rev() {
447 let estimated_tokens = record.content.len() / 4;
449 if token_count + estimated_tokens > self.context_limit {
450 break;
451 }
452
453 messages.push(Message {
454 role: match record.role.as_str() {
455 "user" => Role::User,
456 "assistant" => Role::Assistant,
457 "tool" => Role::Tool,
458 _ => Role::User,
459 },
460 content: record.content.clone(),
461 metadata: record.metadata.clone(),
462 });
463
464 token_count += estimated_tokens;
465 }
466
467 messages.push(Message {
469 role: Role::User,
470 content: user_input.to_string(),
471 metadata: None,
472 });
473
474 Ok(messages)
475 }
476
477 async fn generate_internal(
478 &self,
479 session_id: String,
480 user_input: String,
481 files: Option<Vec<File>>,
482 ) -> Result<GenerationResponse> {
483 self.store_memory(&session_id, "user", &user_input, None)
485 .await?;
486
487 let has_files = files.as_ref().map(|f| !f.is_empty()).unwrap_or(false);
489 if !has_files {
490 if let Some((content, metadata)) = self
491 .try_codemode_orchestration(&session_id, &user_input)
492 .await?
493 {
494 self.store_memory(&session_id, "assistant", &content, metadata.clone())
495 .await?;
496
497 return Ok(GenerationResponse { content, metadata });
498 }
499 }
500
501 let messages = self.build_prompt(&session_id, &user_input).await?;
503
504 let response = self.model.generate(messages, files).await?;
506
507 self.store_memory(&session_id, "assistant", &response.content, None)
509 .await?;
510
511 Ok(response)
512 }
513
514 fn set_codemode(&mut self, engine: Arc<CodeModeUtcp>) {
515 self.codemode = Some(engine.clone());
516 let _ = self
518 .tool_catalog
519 .register(Box::new(CodeModeTool::new(engine)));
520 }
521
522 async fn try_codemode_orchestration(
523 &self,
524 _session_id: &str,
525 user_input: &str,
526 ) -> Result<Option<(String, Option<HashMap<String, String>>)>> {
527 let orchestrator = match self.codemode_orchestrator.as_ref() {
528 Some(o) => o,
529 None => return Ok(None),
530 };
531
532 let value = orchestrator
533 .call_prompt(user_input)
534 .await
535 .map_err(|e| AgentError::Other(e.to_string()))?;
536
537 if let Some(v) = value {
538 let content = format_codemode_value(&v);
539 let metadata = Some(HashMap::from([(
540 "source".to_string(),
541 "codemode_orchestrator".to_string(),
542 )]));
543 return Ok(Some((content, metadata)));
544 }
545
546 Ok(None)
547 }
548
549 async fn store_memory(
551 &self,
552 session_id: &str,
553 role: &str,
554 content: &str,
555 metadata: Option<HashMap<String, String>>,
556 ) -> Result<()> {
557 let record = MemoryRecord {
558 id: Uuid::new_v4(),
559 session_id: session_id.to_string(),
560 role: role.to_string(),
561 content: content.to_string(),
562 importance: 0.5, timestamp: Utc::now(),
564 metadata,
565 embedding: None,
566 };
567
568 self.memory.store(record).await
569 }
570
571 pub async fn flush(&self, _session_id: &str) -> Result<()> {
573 self.memory.flush().await
574 }
575
576 pub fn tools(&self) -> Arc<ToolCatalog> {
578 Arc::clone(&self.tool_catalog)
579 }
580
581 pub async fn checkpoint(&self, session_id: &str) -> Result<Vec<u8>> {
583 let recent = self.memory.retrieve_recent(session_id).await?;
584
585 let state = AgentState {
586 system_prompt: self.system_prompt.clone(),
587 short_term: recent,
588 joined_spaces: None,
589 timestamp: Utc::now(),
590 };
591
592 serde_json::to_vec(&state).map_err(|e| AgentError::SerializationError(e))
593 }
594
595 pub async fn restore(&self, _session_id: &str, data: &[u8]) -> Result<()> {
597 let state: AgentState =
598 serde_json::from_slice(data).map_err(|e| AgentError::SerializationError(e))?;
599
600 for record in state.short_term {
602 self.memory.store(record).await?;
603 }
604
605 Ok(())
606 }
607}