Skip to main content

abu_agent/
lib.rs

1pub mod error;
2use abu_provider::ChatProvide;
3use abu_tool::ToolCallResult;
4use context::ContextBuilder;
5pub use error::*;
6use memory::Memory;
7
8pub mod kit;
9pub mod memory;
10pub mod context;
11pub mod prompt;
12pub mod build;
13
14pub use build::AgentBuilder;
15use abu_base::chat::{AssistantMessage, ChatMessage, ChatRequest, ChatRequestBuilder, ToolCall, ToolDefinition};
16use thiserrorctx::Context;
17use crate::kit::AgentKit;
18use tracing::{debug, info, warn};
19
/// Tunable parameters that control how an [`Agent`] talks to the LLM.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentConfig {
    /// Maximum number of LLM round-trips `Agent::run` performs before it
    /// stops without a final answer.
    pub max_iteration: usize,
    /// Sampling temperature written into every chat request the agent builds.
    pub temperature: f64,
}
25
26pub struct Agent<C: ChatProvide, M: Memory> {
27    pub config: AgentConfig,
28    pub llm: C,
29    pub model: String,
30    pub memory: M,
31    pub context_builder: ContextBuilder,
32    pub kit: AgentKit,
33}
34
35impl<C: ChatProvide, M: Memory> Agent<C, M> {
36    // pub async fn tool_list(&self) -> RwLockReadGuard<'_, [ToolDefinition]> {
37    //     let gurad = self.kit.read().await;
38    //     RwLockReadGuard::map(gurad, |kit| kit.tool_definitions())
39    // }
40
41    pub fn tool_list(&self) -> &[ToolDefinition] {
42        self.kit.tool_definitions()
43    }
44
45    pub fn system_prompt(&self) -> &str {
46        &self.context_builder.system_prompt
47    }
48
49    pub async fn run(&mut self, query: &str) -> AgentResult<String> {
50        info!(query = %query, "🤖 Agent started with user query");
51        
52        let mut request = self.init_chat_request(query, true).await?;
53
54        // agent loop
55        let mut final_result = None; 
56        for step in 0..self.config.max_iteration {
57            info!(step, "🔄 Agent step begin");
58            let ai_message = self.send_chat_request(&request).await?;
59
60            // insert ai response
61            request.messages.push(ai_message.clone().into());
62
63            info!(step, role = "AI", content = ai_message.content, "🗣️ LLM Text Response");
64            if !ai_message.tool_calls.is_empty() {
65                info!(step, count = ai_message.tool_calls.len(), "🛠️ LLM requested tool calls");
66            } else {
67                final_result = Some(ai_message.content);
68                break;
69            }
70
71            // tool calls
72            for tool_call in ai_message.tool_calls.into_iter() {
73                info!(step, tool = %tool_call.name, id = %tool_call.id, args = %tool_call.arguments, "🚀 Executing tool");
74
75                let (id, result) = self.execute_tool(tool_call).await.context("execute tool")?;
76                let tool_content = if result.is_error {
77                    info!(step, result = %result.context, "Tool execute failed!");
78                    format!("Tool execute failed for {}", result.context)
79                } else {
80                    info!(step, result = %result.context, "✅ Tool execution finished");
81                    format!("Tool execute success with output {}", result.context)
82                };
83
84                // insert tool response
85                request.messages.push(ChatMessage::tool(tool_content, id));
86            }
87        }
88
89        match final_result {
90            Some(final_result) => {
91                info!(output = final_result, "🛑 Finsh task with final output");
92                self.add_memory(query, &final_result).await?;
93                Ok(final_result)
94            }
95            None => {
96                warn!("Agent reached max steps without termination");
97                Ok("Task do not finish yet".to_string())
98            }
99        }
100    }
101
102    pub async fn chat(&mut self, query: &str) -> AgentResult<String> {
103        info!(query = %query, "🤖 Agent started with user query");
104
105        let request = self.init_chat_request(query, false).await?;
106        let ai_message = self.send_chat_request(&request).await?;
107
108        info!(role = "AI", content = ai_message.content, "🗣️ LLM Text Response");
109        self.add_memory(query, &ai_message.content).await?;
110
111        Ok(ai_message.content)
112    }
113
114    async fn execute_tool(&mut self, tool_call: ToolCall) -> AgentResult<(String, ToolCallResult)> {
115        let id = tool_call.id;
116        let name = tool_call.name;
117        let arguments = serde_json::from_str(&tool_call.arguments)?;
118        let result = self.kit.execute_tool(name, arguments).await.context("execute tool")?;
119        Ok((id, result))
120    }
121
122    async fn add_memory(&mut self, user_input: &str, ai_response: &str) -> AgentResult<()> {
123        debug!(user_input = user_input, ai_response = ai_response, "add memory");
124        self.memory
125            .add(user_input, ai_response).await
126            .map_err(|e| AgentError::Memory(Box::new(e)))
127            .context("add new memory")?;
128        Ok(())
129    }
130
131    async fn send_chat_request(&self, request: &ChatRequest) -> AgentResult<AssistantMessage> {
132        let response = self.llm
133            .chat(&request).await
134            .map_err(|e| AgentError::ChatProvider(Box::new(e)))?;
135        Ok(response.message)
136    }
137
138    async fn init_chat_request(&mut self, query: &str, with_tool: bool) -> AgentResult<ChatRequest> {
139        let memorys: Vec<ChatMessage> = self.memory.search(query).await
140            .map_err(|e| AgentError::Memory(Box::new(e)))
141            .context("search memory")?;
142
143        let messages: Vec<ChatMessage> = self.context_builder.build(query, memorys);
144
145        let mut builder = ChatRequestBuilder::default();
146        builder
147            .model(&self.model)
148            .messages(messages)
149            .temperature(self.config.temperature);
150        
151        if with_tool {
152            builder.tools(self.kit.tool_definitions());
153        } 
154
155        let req = builder.build()?;
156        Ok(req)
157    }
158}