1use crate::chat::{Completion, Document, Message, Request};
2use crate::executor::Executor;
3use crate::knowledge::Knowledge;
4use crate::mcp::{MCPClient, MCPError, setup_mcp_clients, sse_client, stdio_client};
5use crate::memory::Memory;
6use crate::store::{Storage, VectorStoreError};
7use crate::task::TaskError;
8use crate::tool::Tool;
9use crate::{Ref, make_ref};
10use futures::{StreamExt, TryStreamExt, stream};
11use std::collections::HashMap;
12use std::path::Path;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15use uuid::Uuid;
16
/// An agent that pairs a completion model with tools, vector stores,
/// knowledge sources, optional conversation memory and MCP clients.
///
/// `prompt`/`chat` assemble a `Request` from this configuration and run it
/// through a freshly constructed `Executor` on every call.
pub struct Agent<M: Completion> {
    /// The completion model, shared behind an async `RwLock`; every `chat`
    /// call hands a clone of this handle to its `Executor`.
    pub model: Ref<M>,
    /// Vector stores queried on every `chat` call, each paired with the
    /// sample-count argument passed to `Storage::search` for that store.
    /// Search hits are attached to the request as documents.
    pub store_indices: Vec<(usize, Box<dyn Storage>)>,
    /// Locally registered tools. Their definitions are advertised in every
    /// request, and the handle itself is shared with the `Executor`.
    pub tools: Ref<Vec<Box<dyn Tool>>>,
    /// Knowledge sources handed to the `Executor` on every `chat` call.
    pub knowledges: Arc<Vec<Box<dyn Knowledge>>>,
    /// Optional conversation memory. `prompt` converts its stored messages
    /// into the request history; it is also shared with the `Executor`.
    pub memory: Option<Ref<dyn Memory>>,
    /// Random (v4) identifier generated when the agent is constructed.
    pub id: Uuid,
    /// Name supplied at construction.
    pub name: String,
    /// Preamble text passed to `Request::new` alongside every prompt.
    pub preamble: String,
    /// Temperature value copied verbatim into every request.
    pub temperature: Option<f32>,
    /// `max_tokens` value copied verbatim into every request.
    pub max_tokens: Option<usize>,
    /// NOTE(review): not read anywhere in this file; unit and enforcement
    /// point are unknown — confirm against the rest of the crate.
    pub max_execution_time: Option<usize>,
    /// NOTE(review): not read anywhere in this file (constructors set it to
    /// `false`); confirm where it is honoured.
    pub respect_context_window: bool,
    /// NOTE(review): not read anywhere in this file (constructors set it to
    /// `false`); confirm where it is honoured.
    pub allow_code_execution: bool,
    /// MCP clients added through the `mcp_*` builders. Their tool definitions
    /// are advertised alongside the local tools, and the handle is shared
    /// with the `Executor`.
    mcp_clients: Ref<Vec<MCPClient>>,
}
47
48impl<M: Completion> Agent<M>
49where
50 M: Completion,
51{
52 pub fn new(name: impl ToString, model: M) -> Agent<M> {
54 Agent {
55 model: Arc::new(RwLock::new(model)),
56 tools: make_ref(vec![]),
57 store_indices: vec![],
58 id: Uuid::new_v4(),
59 name: name.to_string(),
60 preamble: String::new(),
61 temperature: None,
62 max_tokens: None,
63 max_execution_time: None,
64 knowledges: Arc::new(Vec::new()),
65 memory: None,
66 mcp_clients: make_ref(vec![]),
67 respect_context_window: false,
68 allow_code_execution: false,
69 }
70 }
71
72 pub fn new_with_tools<I>(name: impl ToString, model: M, tools: I) -> Agent<M>
74 where
75 I: IntoIterator<Item = Box<dyn Tool>>,
76 {
77 Agent {
78 model: Arc::new(RwLock::new(model)),
79 tools: make_ref(tools.into_iter().collect()),
80 store_indices: vec![],
81 id: Uuid::new_v4(),
82 name: name.to_string(),
83 preamble: String::new(),
84 temperature: None,
85 max_tokens: None,
86 max_execution_time: None,
87 knowledges: Arc::new(Vec::new()),
88 memory: None,
89 mcp_clients: make_ref(vec![]),
90 respect_context_window: false,
91 allow_code_execution: false,
92 }
93 }
94
95 pub async fn tool(self, tool: impl Tool + 'static) -> Self {
97 let mut self_tools = self.tools.write().await;
98 self_tools.push(Box::new(tool));
99 drop(self_tools);
100 self
101 }
102
103 pub async fn tools<I>(self, tools: I) -> Self
105 where
106 I: IntoIterator<Item = Box<dyn Tool>>,
107 {
108 let mut self_tools = self.tools.write().await;
109 for tool in tools.into_iter() {
110 self_tools.push(tool);
111 }
112 drop(self_tools);
113 self
114 }
115
116 pub fn memory(mut self, memory: impl Memory + 'static) -> Self {
118 self.memory = Some(Arc::new(RwLock::new(memory)));
119 self
120 }
121
122 pub fn store_index(mut self, sample: usize, store: impl Storage + 'static) -> Self {
124 self.store_indices.push((sample, Box::new(store)));
125 self
126 }
127
128 pub fn preamble(mut self, preamble: impl ToString) -> Self {
130 self.preamble = preamble.to_string();
131 self
132 }
133
134 pub async fn mcp_client(self, mcp_client: MCPClient) -> Self {
136 let mut mcp_clients = self.mcp_clients.write().await;
137 mcp_clients.push(mcp_client);
138 drop(mcp_clients);
139 self
140 }
141
142 pub async fn mcp_config_path<P: AsRef<Path>>(self, path: P) -> anyhow::Result<Self, MCPError> {
144 let clients = setup_mcp_clients(path).await?;
145 let mut mcp_clients = self.mcp_clients.write().await;
146 for (_, client) in clients {
147 mcp_clients.push(client);
148 }
149 drop(mcp_clients);
150 Ok(self)
151 }
152
153 #[inline]
155 pub async fn mcp_sse_client<S: AsRef<str> + 'static>(
156 self,
157 sse_url: S,
158 env: HashMap<String, String>,
159 ) -> anyhow::Result<Self> {
160 Ok(self.mcp_client(sse_client(sse_url, env).await?).await)
161 }
162
163 #[inline]
165 pub async fn mcp_stdio_client<S: AsRef<str> + 'static>(
166 self,
167 command: S,
168 args: Vec<S>,
169 env: HashMap<String, String>,
170 ) -> anyhow::Result<Self> {
171 Ok(self
172 .mcp_client(stdio_client(command, args, env).await?)
173 .await)
174 }
175
176 pub async fn prompt(&self, prompt: &str) -> Result<String, TaskError> {
178 let history = if let Some(memory) = &self.memory {
180 let memory = memory.read().await;
181 memory
182 .messages()
183 .iter()
184 .map(|m| Message {
185 role: m.message_type.type_string(),
186 content: m.content.clone(),
187 })
188 .collect()
189 } else {
190 vec![]
191 };
192 self.chat(prompt, history).await
193 }
194
195 pub async fn chat(&self, prompt: &str, history: Vec<Message>) -> Result<String, TaskError> {
197 let mut executor = Executor::new(
198 self.model.clone(),
199 self.knowledges.clone(),
200 self.tools.clone(),
201 self.memory.clone(),
202 self.mcp_clients.clone(),
203 );
204 let mut req = Request::new(prompt.to_string(), self.preamble.clone());
205 req.history = history;
206 req.max_tokens = self.max_tokens;
207 req.temperature = self.temperature;
208 let tools = self.tools.read().await;
209 req.tools = tools
210 .iter()
211 .map(|tool| tool.definition())
212 .collect::<Vec<_>>();
213 let mcp_clients = self.mcp_clients.read().await;
214 for client in mcp_clients.iter() {
215 for tool in client.tools.values() {
216 req.tools.push(tool.clone());
217 }
218 }
219 req.documents = stream::iter(self.store_indices.iter())
220 .then(|(num_sample, storage)| async {
221 Ok::<_, VectorStoreError>(
222 storage
223 .search(prompt, *num_sample, 0.5)
224 .await?
225 .into_iter()
226 .map(|(id, text, _)| Document {
227 id,
228 text,
229 additional_props: HashMap::new(),
230 })
231 .collect::<Vec<_>>(),
232 )
233 })
234 .try_fold(vec![], |mut acc, docs| async {
235 acc.extend(docs);
236 Ok(acc)
237 })
238 .await
239 .map_err(|err| TaskError::ExecutionError(err.to_string()))?;
240
241 let response = executor
242 .invoke(req)
243 .await
244 .map_err(|err| TaskError::ExecutionError(err.to_string()))?;
245
246 Ok(response)
247 }
248}