use crate::core::action::{ActionInvoker, ActionMetadata, ToolInvoker};
use crate::core::agent::{Agent, AgentContext};
use crate::core::error::Result;
use crate::core::memory::{Memory, MemoryEntry, MemoryQuery};
use crate::core::message::Message as CeylonMessage;
use crate::llm::react::{ReActConfig, ReActEngine, ReActResult};
use crate::llm::types::{Message as LlmMessage, ToolSpec};
use crate::llm::{LLMClient, LLMConfig, LLMResponse, UniversalLLMClient};
use async_trait::async_trait;
use serde_json::{json, Value};
use std::sync::Arc;
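/// An agent whose behaviour is driven by an LLM: it keeps a conversation
/// history, dispatches tool calls through a [`ToolInvoker`], and can
/// optionally persist information via a [`Memory`] backend or run in
/// ReAct mode.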
pub struct LlmAgent {
    name: String,
    llm_client: UniversalLLMClient,
    llm_config: LLMConfig,
    system_prompt: String,
    conversation_history: Vec<LlmMessage>,
    tool_invoker: ToolInvoker,
    memory: Option<Arc<dyn Memory>>,
    react_config: Option<ReActConfig>,
}

impl LlmAgent {
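    /// Returns a builder for configuring an `LlmAgent`.
    ///
    /// A minimal usage sketch (the `provider::model` string format follows
    /// the test at the bottom of this file; error handling elided):
    ///
    /// ```ignore
    /// let agent = LlmAgent::builder("assistant", "ollama::llama2")
    ///     .with_system_prompt("You are a helpful AI assistant.")
    ///     .with_temperature(0.7)
    ///     .build()?;
    /// ```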
    pub fn builder(name: impl Into<String>, model: impl Into<String>) -> LlmAgentBuilder {
        LlmAgentBuilder::new(name, model)
    }

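    /// Creates an agent directly from a prepared [`LLMConfig`], registering
    /// the built-in memory tools when a memory backend is supplied.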
    pub fn new_with_config(
        name: impl Into<String>,
        config: LLMConfig,
        system_prompt: impl Into<String>,
        memory: Option<Arc<dyn Memory>>,
    ) -> Result<Self> {
        let client = UniversalLLMClient::new_with_config(config.clone())
            .map_err(crate::core::error::Error::MeshError)?;

        let mut agent = Self {
            name: name.into(),
            llm_client: client,
            llm_config: config,
            system_prompt: system_prompt.into(),
            conversation_history: Vec::new(),
            tool_invoker: ToolInvoker::default(),
            memory,
            react_config: None,
        };

        if agent.memory.is_some() {
            agent.register_memory_tools();
        }

        Ok(agent)
    }

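    /// Registers the built-in `save_memory` and `search_memory` actions so
    /// the LLM can invoke them as tools.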
    fn register_memory_tools(&mut self) {
        if let Some(memory) = &self.memory {
            self.tool_invoker
                .register(Box::new(SaveMemoryAction::new(memory.clone())));

            self.tool_invoker
                .register(Box::new(SearchMemoryAction::new(memory.clone())));
        }
    }

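    /// Converts an action's metadata into the [`ToolSpec`] shape expected by
    /// the LLM client.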
    fn action_to_tool_spec(action: &ActionMetadata) -> ToolSpec {
        ToolSpec {
            name: action.name.clone(),
            description: action.description.clone(),
            input_schema: action.input_schema.clone(),
        }
    }

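    /// Enables ReAct mode for this agent. Until this is called,
    /// [`send_message_react`](Self::send_message_react) returns an error.
    ///
    /// A minimal sketch (`react_config` and `ctx` stand for any prepared
    /// [`ReActConfig`] and [`AgentContext`] values):
    ///
    /// ```ignore
    /// agent.with_react(react_config);
    /// let result = agent.send_message_react("What's 2 + 2?", &mut ctx).await?;
    /// ```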
    pub fn with_react(&mut self, config: ReActConfig) {
        self.react_config = Some(config);
    }

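    /// Runs a single ReAct loop for `message` and returns the engine's
    /// result. Unlike [`send_message_and_get_response`](Self::send_message_and_get_response),
    /// this does not touch the agent's conversation history.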
    pub async fn send_message_react(
        &mut self,
        message: impl Into<String>,
        ctx: &mut AgentContext,
    ) -> Result<ReActResult> {
        let content = message.into();

        let react_config = self.react_config.clone().ok_or_else(|| {
            crate::core::error::Error::MeshError(
                "ReAct mode not enabled. Call with_react() first".to_string(),
            )
        })?;

        let engine = ReActEngine::new(react_config, None);

        let result = engine
            .execute(
                content,
                &self.llm_client,
                &self.llm_config,
                self.memory.as_ref(),
                ctx,
            )
            .await?;

        Ok(result)
    }

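    /// Appends `message` to the conversation as a user turn and drives the
    /// LLM, including any tool calls, until a final assistant reply is
    /// produced.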
    pub async fn send_message_and_get_response(
        &mut self,
        message: impl Into<String>,
        ctx: &mut AgentContext,
    ) -> Result<String> {
        let content = message.into();

        self.conversation_history.push(LlmMessage {
            role: "user".to_string(),
            content,
        });

        self.process_with_llm(ctx).await
    }

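    /// Returns the most recent assistant reply, if any.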
    pub fn last_response(&self) -> Option<String> {
        self.conversation_history
            .iter()
            .rev()
            .find(|m| m.role == "assistant")
            .map(|m| m.content.clone())
    }

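    /// Core completion loop: injects the system prompt on the first turn,
    /// advertises registered actions as tools, executes any tool calls the
    /// model requests, and recurses until the model returns a final answer.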
    async fn process_with_llm(&mut self, ctx: &mut AgentContext) -> Result<String> {
        if self.conversation_history.len() == 1 {
            self.conversation_history.insert(
                0,
                LlmMessage {
                    role: "system".to_string(),
                    content: self.system_prompt.clone(),
                },
            );
        }

        let actions = self.tool_invoker.list_actions();
        let tools: Vec<ToolSpec> = actions.iter().map(Self::action_to_tool_spec).collect();

        let response: LLMResponse<String> = self
            .llm_client
            .complete::<LLMResponse<String>, String>(&self.conversation_history, &tools)
            .await
            .map_err(crate::core::error::Error::MeshError)?;

        if !response.is_complete && !response.tool_calls.is_empty() {
            let mut tool_results = Vec::new();

            for tool_call in response.tool_calls {
                let result = self
                    .tool_invoker
                    .invoke(&tool_call.name, ctx, tool_call.input)
                    .await?;

                tool_results.push(format!("Tool {}: {}", tool_call.name, result));
            }

            let tool_result_message = LlmMessage {
                role: "tool".to_string(),
                content: tool_results.join("\n"),
            };
            self.conversation_history.push(tool_result_message);

            // Recurse so the model can consume the tool output; Box::pin keeps
            // the recursive async call's future sized.
            return Box::pin(self.process_with_llm(ctx)).await;
        }

        self.conversation_history.push(LlmMessage {
            role: "assistant".to_string(),
            content: response.content.clone(),
        });

        Ok(response.content)
    }
}

#[async_trait]
impl Agent for LlmAgent {
    fn name(&self) -> String {
        self.name.clone()
    }

    async fn on_message(&mut self, msg: CeylonMessage, ctx: &mut AgentContext) -> Result<()> {
        let content = String::from_utf8(msg.payload.clone()).map_err(|e| {
            crate::core::error::Error::MeshError(format!("Invalid UTF-8 in message payload: {}", e))
        })?;

        self.conversation_history.push(LlmMessage {
            role: "user".to_string(),
            content,
        });

        let _response = self.process_with_llm(ctx).await?;
        Ok(())
    }

    async fn on_generic_message(
        &mut self,
        msg: crate::core::message::GenericMessage,
        ctx: &mut AgentContext,
    ) -> Result<crate::core::message::GenericResponse> {
        let response_text = self.send_message_and_get_response(msg.content, ctx).await?;
        Ok(crate::core::message::GenericResponse::new(response_text))
    }

    fn tool_invoker(&self) -> Option<&ToolInvoker> {
        Some(&self.tool_invoker)
    }

    fn tool_invoker_mut(&mut self) -> Option<&mut ToolInvoker> {
        Some(&mut self.tool_invoker)
    }
}

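/// Fluent builder for [`LlmAgent`], created via [`LlmAgent::builder`].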
pub struct LlmAgentBuilder {
    name: String,
    model: String,
    api_key: Option<String>,
    system_prompt: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    memory: Option<Arc<dyn Memory>>,
}

impl LlmAgentBuilder {
    fn new(name: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            model: model.into(),
            api_key: None,
            system_prompt: "You are a helpful AI assistant.".to_string(),
            temperature: None,
            max_tokens: None,
            memory: None,
        }
    }

    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = prompt.into();
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    pub fn with_memory(mut self, memory: Arc<dyn Memory>) -> Self {
        self.memory = Some(memory);
        self
    }

    pub fn build(self) -> Result<LlmAgent> {
        let mut config = LLMConfig::new(self.model);

        if let Some(api_key) = self.api_key {
            config = config.with_api_key(api_key);
        }

        if let Some(temperature) = self.temperature {
            config = config.with_temperature(temperature);
        }

        if let Some(max_tokens) = self.max_tokens {
            config = config.with_max_tokens(max_tokens);
        }

        LlmAgent::new_with_config(self.name, config, self.system_prompt, self.memory)
    }
}

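/// Built-in tool that persists a string to the agent's [`Memory`] backend.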
struct SaveMemoryAction {
    memory: Arc<dyn Memory>,
    metadata: ActionMetadata,
}

impl SaveMemoryAction {
    fn new(memory: Arc<dyn Memory>) -> Self {
        Self {
            memory,
            metadata: ActionMetadata {
                name: "save_memory".to_string(),
                description: "Save information to memory for later retrieval.".to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "content": {
                            "type": "string",
                            "description": "The information to save."
                        }
                    },
                    "required": ["content"]
                }),
                output_schema: Some(json!({
                    "type": "object",
                    "properties": {
                        "status": { "type": "string" },
                        "id": { "type": "string" }
                    }
                })),
            },
        }
    }
}

#[async_trait]
impl ActionInvoker for SaveMemoryAction {
    async fn execute(&self, _ctx: &mut AgentContext, inputs: Value) -> Result<Value> {
        let content = inputs
            .get("content")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                crate::core::error::Error::ActionExecutionError(
                    "Missing 'content' in inputs".to_string(),
                )
            })?;

        let entry = MemoryEntry::new(content);
        let id = self.memory.store(entry).await?;

        Ok(json!({ "status": "success", "id": id }))
    }

    fn metadata(&self) -> &ActionMetadata {
        &self.metadata
    }
}

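/// Built-in tool that runs a semantic search over the agent's [`Memory`]
/// backend and returns matching entries as strings.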
struct SearchMemoryAction {
    memory: Arc<dyn Memory>,
    metadata: ActionMetadata,
}

impl SearchMemoryAction {
    fn new(memory: Arc<dyn Memory>) -> Self {
        Self {
            memory,
            metadata: ActionMetadata {
                name: "search_memory".to_string(),
                description: "Search memory for relevant information.".to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The query to search for."
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Max number of results (default 5)."
                        }
                    },
                    "required": ["query"]
                }),
                output_schema: Some(json!({
                    "type": "object",
                    "properties": {
                        "results": {
                            "type": "array",
                            "items": { "type": "string" }
                        }
                    }
                })),
            },
        }
    }
}

#[async_trait]
impl ActionInvoker for SearchMemoryAction {
    async fn execute(&self, _ctx: &mut AgentContext, inputs: Value) -> Result<Value> {
        let query_str = inputs
            .get("query")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                crate::core::error::Error::ActionExecutionError(
                    "Missing 'query' in inputs".to_string(),
                )
            })?;

        let limit = inputs
            .get("limit")
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .unwrap_or(5);

        let mut query = MemoryQuery::new().with_limit(limit);
        query.semantic_query = Some(query_str.to_string());

        let results = self.memory.search(query).await?;

        let result_strings: Vec<String> = results.into_iter().map(|e| e.content).collect();

        Ok(json!({ "results": result_strings }))
    }

    fn metadata(&self) -> &ActionMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_llm_agent_builder() {
        let builder = LlmAgent::builder("test", "ollama::llama2")
            .with_system_prompt("Custom prompt")
            .with_temperature(0.7)
            .with_max_tokens(1000);

        assert_eq!(builder.name, "test");
        assert_eq!(builder.model, "ollama::llama2");
        assert_eq!(builder.system_prompt, "Custom prompt");
        assert_eq!(builder.temperature, Some(0.7));
        assert_eq!(builder.max_tokens, Some(1000));
    }
}