alith_core/
agent.rs

1use crate::chat::{Completion, Document, Message, Request};
2use crate::executor::Executor;
3use crate::knowledge::Knowledge;
4use crate::mcp::{MCPClient, MCPError, setup_mcp_clients, sse_client, stdio_client};
5use crate::memory::Memory;
6use crate::store::{Storage, VectorStoreError};
7use crate::task::TaskError;
8use crate::tool::Tool;
9use crate::{Ref, make_ref};
10use futures::{StreamExt, TryStreamExt, stream};
11use std::collections::HashMap;
12use std::path::Path;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15use uuid::Uuid;
16
17pub struct Agent<M: Completion> {
18    /// The model to use.
19    pub model: Ref<M>,
20    /// Indexed storage for the agent.
21    pub store_indices: Vec<(usize, Box<dyn Storage>)>,
22    /// The tools to use.
23    pub tools: Ref<Vec<Box<dyn Tool>>>,
24    /// Knowledge sources for the agent.
25    pub knowledges: Arc<Vec<Box<dyn Knowledge>>>,
26    /// Agent memory.
27    pub memory: Option<Ref<dyn Memory>>,
28    /// The unique ID of the agent.
29    pub id: Uuid,
30    /// The name of the agent.
31    pub name: String,
32    /// System prompt for the agent.
33    pub preamble: String,
34    /// Temperature of the model.
35    pub temperature: Option<f32>,
36    /// Maximum number of tokens for the completion.
37    pub max_tokens: Option<usize>,
38    /// Maximum execution time for the agent to complete a task.
39    pub max_execution_time: Option<usize>,
40    /// Whether to respect the context window.
41    pub respect_context_window: bool,
42    /// Whether code execution is allowed.
43    pub allow_code_execution: bool,
44    /// The MCP client used to communicate with the MCP server
45    mcp_clients: Ref<Vec<MCPClient>>,
46}
47
48impl<M: Completion> Agent<M>
49where
50    M: Completion,
51{
52    /// Creates a new agent.
53    pub fn new(name: impl ToString, model: M) -> Agent<M> {
54        Agent {
55            model: Arc::new(RwLock::new(model)),
56            tools: make_ref(vec![]),
57            store_indices: vec![],
58            id: Uuid::new_v4(),
59            name: name.to_string(),
60            preamble: String::new(),
61            temperature: None,
62            max_tokens: None,
63            max_execution_time: None,
64            knowledges: Arc::new(Vec::new()),
65            memory: None,
66            mcp_clients: make_ref(vec![]),
67            respect_context_window: false,
68            allow_code_execution: false,
69        }
70    }
71
72    /// Creates a new agent with some tools
73    pub fn new_with_tools<I>(name: impl ToString, model: M, tools: I) -> Agent<M>
74    where
75        I: IntoIterator<Item = Box<dyn Tool>>,
76    {
77        Agent {
78            model: Arc::new(RwLock::new(model)),
79            tools: make_ref(tools.into_iter().collect()),
80            store_indices: vec![],
81            id: Uuid::new_v4(),
82            name: name.to_string(),
83            preamble: String::new(),
84            temperature: None,
85            max_tokens: None,
86            max_execution_time: None,
87            knowledges: Arc::new(Vec::new()),
88            memory: None,
89            mcp_clients: make_ref(vec![]),
90            respect_context_window: false,
91            allow_code_execution: false,
92        }
93    }
94
95    /// Add a tool into the agent
96    pub async fn tool(self, tool: impl Tool + 'static) -> Self {
97        let mut self_tools = self.tools.write().await;
98        self_tools.push(Box::new(tool));
99        drop(self_tools);
100        self
101    }
102
103    /// Add some tools into the agent
104    pub async fn tools<I>(self, tools: I) -> Self
105    where
106        I: IntoIterator<Item = Box<dyn Tool>>,
107    {
108        let mut self_tools = self.tools.write().await;
109        for tool in tools.into_iter() {
110            self_tools.push(tool);
111        }
112        drop(self_tools);
113        self
114    }
115
116    /// Adds a memory to the agent.
117    pub fn memory(mut self, memory: impl Memory + 'static) -> Self {
118        self.memory = Some(Arc::new(RwLock::new(memory)));
119        self
120    }
121
122    /// Adds a storage index to the agent.
123    pub fn store_index(mut self, sample: usize, store: impl Storage + 'static) -> Self {
124        self.store_indices.push((sample, Box::new(store)));
125        self
126    }
127
128    /// System prompt for the agent.
129    pub fn preamble(mut self, preamble: impl ToString) -> Self {
130        self.preamble = preamble.to_string();
131        self
132    }
133
134    /// Set the MCP client.
135    pub async fn mcp_client(self, mcp_client: MCPClient) -> Self {
136        let mut mcp_clients = self.mcp_clients.write().await;
137        mcp_clients.push(mcp_client);
138        drop(mcp_clients);
139        self
140    }
141
142    /// Set the MCP server config path.
143    pub async fn mcp_config_path<P: AsRef<Path>>(self, path: P) -> anyhow::Result<Self, MCPError> {
144        let clients = setup_mcp_clients(path).await?;
145        let mut mcp_clients = self.mcp_clients.write().await;
146        for (_, client) in clients {
147            mcp_clients.push(client);
148        }
149        drop(mcp_clients);
150        Ok(self)
151    }
152
153    /// Set the MCP sse client.
154    #[inline]
155    pub async fn mcp_sse_client<S: AsRef<str> + 'static>(
156        self,
157        sse_url: S,
158        env: HashMap<String, String>,
159    ) -> anyhow::Result<Self> {
160        Ok(self.mcp_client(sse_client(sse_url, env).await?).await)
161    }
162
163    /// Set the MCP sse client.
164    #[inline]
165    pub async fn mcp_stdio_client<S: AsRef<str> + 'static>(
166        self,
167        command: S,
168        args: Vec<S>,
169        env: HashMap<String, String>,
170    ) -> anyhow::Result<Self> {
171        Ok(self
172            .mcp_client(stdio_client(command, args, env).await?)
173            .await)
174    }
175
176    /// Processes a prompt using the agent.
177    pub async fn prompt(&self, prompt: &str) -> Result<String, TaskError> {
178        // Add chat conversion history.
179        let history = if let Some(memory) = &self.memory {
180            let memory = memory.read().await;
181            memory
182                .messages()
183                .iter()
184                .map(|m| Message {
185                    role: m.message_type.type_string(),
186                    content: m.content.clone(),
187                })
188                .collect()
189        } else {
190            vec![]
191        };
192        self.chat(prompt, history).await
193    }
194
195    /// Processes a prompt using the agent.
196    pub async fn chat(&self, prompt: &str, history: Vec<Message>) -> Result<String, TaskError> {
197        let mut executor = Executor::new(
198            self.model.clone(),
199            self.knowledges.clone(),
200            self.tools.clone(),
201            self.memory.clone(),
202            self.mcp_clients.clone(),
203        );
204        let mut req = Request::new(prompt.to_string(), self.preamble.clone());
205        req.history = history;
206        req.max_tokens = self.max_tokens;
207        req.temperature = self.temperature;
208        let tools = self.tools.read().await;
209        req.tools = tools
210            .iter()
211            .map(|tool| tool.definition())
212            .collect::<Vec<_>>();
213        let mcp_clients = self.mcp_clients.read().await;
214        for client in mcp_clients.iter() {
215            for tool in client.tools.values() {
216                req.tools.push(tool.clone());
217            }
218        }
219        req.documents = stream::iter(self.store_indices.iter())
220            .then(|(num_sample, storage)| async {
221                Ok::<_, VectorStoreError>(
222                    storage
223                        .search(prompt, *num_sample, 0.5)
224                        .await?
225                        .into_iter()
226                        .map(|(id, text, _)| Document {
227                            id,
228                            text,
229                            additional_props: HashMap::new(),
230                        })
231                        .collect::<Vec<_>>(),
232                )
233            })
234            .try_fold(vec![], |mut acc, docs| async {
235                acc.extend(docs);
236                Ok(acc)
237            })
238            .await
239            .map_err(|err| TaskError::ExecutionError(err.to_string()))?;
240
241        let response = executor
242            .invoke(req)
243            .await
244            .map_err(|err| TaskError::ExecutionError(err.to_string()))?;
245
246        Ok(response)
247    }
248}