1use std::collections::HashMap;
2use std::sync::Arc;
3use crate::llm::traits::LLM;
4use crate::message::Message;
5use crate::tools::{
6 traits::Tool,
7 schema::ToolSchema,
8};
9use serde_json::json;
10
11
12pub mod types;
13pub mod error;
14pub mod traits;
15
16use traits::AgentRunner;
17use types::{Agent,AgentResult,AgentExecuteResult};
18use error::AgentError;
19
20
21impl Agent {
22 pub fn new(name: impl Into<String>, llm: Arc<dyn LLM>,max_iterations:Option<usize>) -> Self {
24 Self {
25 name: name.into(),
26 llm,
27 tools: HashMap::new(),
28 memory: Vec::new(),
29 system_prompt: None,
30 max_iterations: max_iterations.unwrap_or(100) ,
31 }
32 }
33
34 pub fn register_tool(&mut self, name: Option<&str>, tool: Arc<dyn Tool>) -> &mut Self {
36 let name = name.unwrap_or_else(|| tool.name());
38 self.tools.insert(name.into(), tool);
39 self
40 }
41
42 pub fn change_max_iterations(&mut self, max_iterations: usize) {
44 self.max_iterations = max_iterations;
45 }
46
47 pub fn get_tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
49 self.tools.get(name).cloned()
50 }
51
52 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
54 self.system_prompt = Some(prompt.into());
55 }
56
57 pub fn generate_system_prompt(&self) -> Vec<Message> {
59 let mut msgs = Vec::new();
60 if let Some(prompt) = self.system_prompt.as_ref() {
61 msgs.push(Message::system(prompt.clone()));
62 }
63 if !self.tools.is_empty() {
64 msgs.push(Message::developer(
65 format!("I also provide some tools for you to choose from. If you want to call a tool, please include the following JSON format in your response: {}",
66 json!({
67 "tool_calls": [
68 {
69 "name": "tool_name",
70 "args": {
71 "param1": "value1",
72 "param2": "value2"
73 }
74 }
75 ]
76 }).to_string())
77 ));
78 }
79 msgs
80 }
81
82 pub fn generate_tools_prompt(&self) -> Vec<Message> {
84 self.tools.iter().map(|(name, tool)| {
85 let schema = ToolSchema {
86 name: name.clone(),
87 description: tool.description().to_string(),
88 args: tool.args(),
89 };
90
91 Message::system(serde_json::to_string(&schema).unwrap())
92 }).collect()
93 }
94}
95
96
97
98#[async_trait::async_trait]
99impl AgentRunner for Agent {
100 async fn call_llm(&self, prompt: &str) -> AgentExecuteResult {
101 let mut msgs: Vec<Message> = self.generate_system_prompt();
104 let tool_msgs = self.generate_tools_prompt();
105 msgs.extend(tool_msgs);
106 msgs.push(Message::user(prompt.to_string()));
107 let mut result = AgentResult::default();
108 let mut counter:usize = 0;
109
110 while counter < self.max_iterations {
111 let res = self.llm.generate(&msgs).await?;
113 let msg = Message::assistant(res.generation.clone());
115 result.tokens.prompt_tokens += res.tokens.prompt_tokens;
116 result.tokens.completion_tokens += res.tokens.completion_tokens;
117 result.tokens.total_tokens += res.tokens.total_tokens;
118 result.generation = res.generation.clone();
120
121 counter += 1;
122 if !res.call_tools.is_empty() {
124 msgs.push(msg);
126 for call_info in res.call_tools.into_iter(){
128 let name = call_info.name.clone();
129 if let Some(tool_impl) = self.tools.get(&name){
130 let tool_result = tool_impl.run(call_info.args).await?;
131 let tool_res_msg = Message::tool_res(
132 &call_info.name,
133 format!("Tool {} returned: {}", &name, tool_result));
134 msgs.push(tool_res_msg);
135 }else{
136 return Err(AgentError::ToolNotFound(call_info.name.clone()));
137 }
138 }
139 } else {
140 return Ok(result);
141 }
142 }
143 Err(AgentError::MaxIterationsExceeded(self.max_iterations))
144 }
145
146
147
148}