mini_langchain/agent.rs

use std::collections::HashMap;
use std::sync::Arc;
use crate::llm::traits::LLM;
use crate::message::Message;
use crate::tools::{
    traits::Tool,
    schema::ToolSchema,
};
use serde_json::json;

pub mod types;
pub mod error;
pub mod traits;

use traits::AgentRunner;
use types::{Agent, AgentResult, AgentExecuteResult};
use error::AgentError;

impl Agent {
    /// Create a new Agent with the provided name and LLM. Tools and memory
    /// start empty; `max_iterations` defaults to 100 when `None`.
    pub fn new(name: impl Into<String>, llm: Arc<dyn LLM>, max_iterations: Option<usize>) -> Self {
        Self {
            name: name.into(),
            llm,
            tools: HashMap::new(),
            memory: Vec::new(),
            system_prompt: None,
            max_iterations: max_iterations.unwrap_or(100),
        }
    }

    /// Register a tool under the given name, or under the tool's own name when
    /// `name` is `None`. Replaces any existing tool with the same name.
    /// Returns `&mut Self` for chaining.
    pub fn register_tool(&mut self, name: Option<&str>, tool: Arc<dyn Tool>) -> &mut Self {
        // If no name is provided, fall back to the tool's own name.
        let name = name.unwrap_or_else(|| tool.name());
        self.tools.insert(name.into(), tool);
        self
    }

    /// Change the maximum number of iterations for the agent's tool-calling loop.
    pub fn change_max_iterations(&mut self, max_iterations: usize) {
        self.max_iterations = max_iterations;
    }

    /// Look up a registered tool by name.
    pub fn get_tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
        self.tools.get(name).cloned()
    }

    /// Set or replace the agent's system prompt.
    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
        self.system_prompt = Some(prompt.into());
    }

    /// Build the system-level messages: the system prompt (if any), plus a
    /// developer message describing the JSON format for tool calls when tools
    /// are registered.
    pub fn generate_system_prompt(&self) -> Vec<Message> {
        let mut msgs = Vec::new();
        if let Some(prompt) = self.system_prompt.as_ref() {
            msgs.push(Message::system(prompt.clone()));
        }
        if !self.tools.is_empty() {
            msgs.push(Message::developer(format!(
                "You also have access to a set of tools to choose from. If you want to call a tool, include the following JSON format in your response: {}",
                json!({
                    "tool_calls": [
                        {
                            "name": "tool_name",
                            "args": {
                                "param1": "value1",
                                "param2": "value2"
                            }
                        }
                    ]
                })
            )));
        }
        msgs
    }

    /// Build one system message per registered tool, containing the tool's
    /// serialized schema (name, description, and argument spec).
    pub fn generate_tools_prompt(&self) -> Vec<Message> {
        self.tools
            .iter()
            .map(|(name, tool)| {
                let schema = ToolSchema {
                    name: name.clone(),
                    description: tool.description().to_string(),
                    args: tool.args(),
                };
                Message::system(
                    serde_json::to_string(&schema)
                        .expect("tool schema serialization should not fail"),
                )
            })
            .collect()
    }
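
    // Each message above is the JSON form of a `ToolSchema`, roughly as shown
    // below (the exact shape of "args" depends on what `Tool::args()` returns;
    // "get_weather" is a hypothetical tool used only for illustration):
    //
    //     {"name":"get_weather","description":"Look up current weather","args":{...}}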
}
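
// Usage sketch (illustrative only): `MyLLM` and `WeatherTool` are hypothetical
// stand-ins for concrete `LLM` and `Tool` implementations provided elsewhere;
// only the `Agent` and `AgentRunner` calls themselves come from this module.
//
//     let llm: Arc<dyn LLM> = Arc::new(MyLLM::default());
//     let mut agent = Agent::new("assistant", llm, Some(10));
//     agent.set_system_prompt("You are a helpful assistant.");
//     agent.register_tool(None, Arc::new(WeatherTool::default()));
//     let result = agent.call_llm("What's the weather in Paris?").await?;
//     println!("{}", result.generation);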

#[async_trait::async_trait]
impl AgentRunner for Agent {
    async fn call_llm(&self, prompt: &str) -> AgentExecuteResult {
        // Build a sequence of messages so LLM implementations that support
        // system/user roles can consume them properly.
        let mut msgs: Vec<Message> = self.generate_system_prompt();
        msgs.extend(self.generate_tools_prompt());
        msgs.push(Message::user(prompt.to_string()));

        let mut result = AgentResult::default();
        let mut counter: usize = 0;

        while counter < self.max_iterations {
            // Call the LLM with the conversation so far.
            let res = self.llm.generate(&msgs).await?;
            // Record the assistant message and accumulate token usage.
            let msg = Message::assistant(res.generation.clone());
            result.tokens.prompt_tokens += res.tokens.prompt_tokens;
            result.tokens.completion_tokens += res.tokens.completion_tokens;
            result.tokens.total_tokens += res.tokens.total_tokens;
            // Keep the latest generation as the running answer.
            result.generation = res.generation.clone();

            counter += 1;
            // If the model requested tool calls, run them and feed the results back.
            if !res.call_tools.is_empty() {
                // Add the assistant message that contains the tool calls.
                msgs.push(msg);
                // Execute each requested tool and append its result.
                for call_info in res.call_tools.into_iter() {
                    let name = call_info.name.clone();
                    if let Some(tool_impl) = self.tools.get(&name) {
                        let tool_result = tool_impl.run(call_info.args).await?;
                        let tool_res_msg = Message::tool_res(
                            &name,
                            format!("Tool {} returned: {}", &name, tool_result),
                        );
                        msgs.push(tool_res_msg);
                    } else {
                        return Err(AgentError::ToolNotFound(name));
                    }
                }
            } else {
                // No tool calls: the model has produced its final answer.
                return Ok(result);
            }
        }
        Err(AgentError::MaxIterationsExceeded(self.max_iterations))
    }
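
    // Example round trip (illustrative only; "get_weather" and its output are
    // hypothetical):
    //
    //   assistant: {"tool_calls":[{"name":"get_weather","args":{"city":"Paris"}}]}
    //   tool_res : Tool get_weather returned: 18°C, clear skies
    //   assistant: It is currently 18°C and clear in Paris.
    //
    // The loop returns the first assistant response without tool calls as an
    // `AgentResult`, or `AgentError::MaxIterationsExceeded` after `max_iterations`
    // LLM calls.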
}