rig/tool/server.rs

1use futures::{StreamExt, TryStreamExt, channel::oneshot::Canceled, stream};
2use tokio::sync::mpsc::{Sender, error::SendError};
3
4use crate::{
5    completion::{CompletionError, ToolDefinition},
6    tool::{Tool, ToolDyn, ToolError, ToolSet, ToolSetError},
7    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
8};
9
10pub struct ToolServer {
11    /// A list of static tool names.
12    /// These tools will always exist on the tool server for as long as they are not deleted.
13    static_tool_names: Vec<String>,
14    /// Dynamic tools. These tools will be dynamically fetched from a given vector store.
15    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
16    /// The toolset where tools are called (to be executed).
17    toolset: ToolSet,
18}
19
20impl Default for ToolServer {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl ToolServer {
27    pub fn new() -> Self {
28        Self {
29            static_tool_names: Vec::new(),
30            dynamic_tools: Vec::new(),
31            toolset: ToolSet::default(),
32        }
33    }
34
35    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
36        self.static_tool_names = names;
37        self
38    }
39
40    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
41        self.toolset = tools;
42        self
43    }
44
45    pub(crate) fn add_dynamic_tools(
46        mut self,
47        dyn_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
48    ) -> Self {
49        self.dynamic_tools = dyn_tools;
50        self
51    }
52
53    /// Add a static tool to the agent
54    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
55        let toolname = tool.name();
56        self.toolset.add_tool(tool);
57        self.static_tool_names.push(toolname);
58        self
59    }
60
61    // Add an MCP tool (from `rmcp`) to the agent
62    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
63    #[cfg(feature = "rmcp")]
64    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
65        use crate::tool::rmcp::McpTool;
66        let toolname = tool.name.clone();
67        self.toolset
68            .add_tool(McpTool::from_mcp_server(tool, client));
69        self.static_tool_names.push(toolname.to_string());
70        self
71    }
72
73    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
74    /// dynamic toolset will be inserted in the request.
75    pub fn dynamic_tools(
76        mut self,
77        sample: usize,
78        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
79        toolset: ToolSet,
80    ) -> Self {
81        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
82        self.toolset.add_tools(toolset);
83        self
84    }
85
86    pub fn run(mut self) -> ToolServerHandle {
87        let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
88
89        #[cfg(not(target_family = "wasm"))]
90        tokio::spawn(async move {
91            while let Some(message) = rx.recv().await {
92                self.handle_message(message).await;
93            }
94        });
95
96        // SAFETY: `rig` currently doesn't compile to WASM without the `worker` feature.
97        // Therefore, we can safely assume that the user won't try to compile to wasm without the worker feature.
98        #[cfg(all(feature = "worker", target_family = "wasm"))]
99        wasm_bindgen_futures::spawn_local(async move {
100            while let Some(message) = rx.recv().await {
101                self.handle_message(message).await;
102            }
103        });
104
105        ToolServerHandle(tx)
106    }
107
108    pub async fn handle_message(&mut self, message: ToolServerRequest) {
109        let ToolServerRequest {
110            callback_channel,
111            data,
112        } = message;
113
114        match data {
115            ToolServerRequestMessageKind::AddTool(tool) => {
116                self.static_tool_names.push(tool.name());
117                self.toolset.add_tool_boxed(tool);
118                callback_channel
119                    .send(ToolServerResponse::ToolAdded)
120                    .unwrap();
121            }
122            ToolServerRequestMessageKind::AppendToolset(tools) => {
123                self.toolset.add_tools(tools);
124                callback_channel
125                    .send(ToolServerResponse::ToolAdded)
126                    .unwrap();
127            }
128            ToolServerRequestMessageKind::RemoveTool { tool_name } => {
129                self.static_tool_names.retain(|x| *x != tool_name);
130                self.toolset.delete_tool(&tool_name);
131                callback_channel
132                    .send(ToolServerResponse::ToolDeleted)
133                    .unwrap();
134            }
135            ToolServerRequestMessageKind::CallTool { name, args } => {
136                match self.toolset.call(&name, args.clone()).await {
137                    Ok(result) => {
138                        let _ = callback_channel.send(ToolServerResponse::ToolExecuted { result });
139                    }
140                    Err(err) => {
141                        let _ = callback_channel.send(ToolServerResponse::ToolError {
142                            error: err.to_string(),
143                        });
144                    }
145                }
146            }
147            ToolServerRequestMessageKind::GetToolDefs { prompt } => {
148                let res = self.get_tool_definitions(prompt).await.unwrap();
149                callback_channel
150                    .send(ToolServerResponse::ToolDefinitions(res))
151                    .unwrap();
152            }
153        }
154    }
155
156    pub async fn get_tool_definitions(
157        &mut self,
158        text: Option<String>,
159    ) -> Result<Vec<ToolDefinition>, CompletionError> {
160        let static_tool_names = self.static_tool_names.clone();
161        let mut tools = if let Some(text) = text {
162            stream::iter(self.dynamic_tools.iter())
163                        .then(|(num_sample, index)| async {
164                            let req =
165                                VectorSearchRequest::builder()
166                                    .query(text.clone())
167                                    .samples(*num_sample as u64)
168                                    .build()
169                                    .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
170                            Ok::<_, VectorStoreError>(
171                        index.as_ref()
172                                    .top_n_ids(req.map_filter(Filter::interpret))
173                                    .await?
174                                    .into_iter()
175                                    .map(|(_, id)| id)
176                                    .collect::<Vec<String>>(),
177                            )
178                        })
179                        .try_fold(vec![], |mut acc, docs| async {
180                            for doc in docs {
181                                if let Some(tool) = self.toolset.get(&doc) {
182                                    acc.push(tool.definition(text.clone()).await)
183                                } else {
184                                    tracing::warn!("Tool implementation not found in toolset: {}", doc);
185                                }
186                            }
187                            Ok(acc)
188                        })
189                        .await
190                        .map_err(|e| CompletionError::RequestError(Box::new(e)))?
191        } else {
192            Vec::new()
193        };
194
195        for toolname in static_tool_names {
196            if let Some(tool) = self.toolset.get(&toolname) {
197                tools.push(tool.definition(String::new()).await)
198            } else {
199                tracing::warn!("Tool implementation not found in toolset: {}", toolname);
200            }
201        }
202
203        Ok(tools)
204    }
205}
206
207#[derive(Clone)]
208pub struct ToolServerHandle(Sender<ToolServerRequest>);
209
210impl ToolServerHandle {
211    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
212        let tool = Box::new(tool);
213
214        let (tx, rx) = futures::channel::oneshot::channel();
215
216        self.0
217            .send(ToolServerRequest {
218                callback_channel: tx,
219                data: ToolServerRequestMessageKind::AddTool(tool),
220            })
221            .await?;
222
223        let res = rx.await?;
224
225        let ToolServerResponse::ToolAdded = res else {
226            return Err(ToolServerError::InvalidMessage(res));
227        };
228
229        Ok(())
230    }
231
232    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
233        let (tx, rx) = futures::channel::oneshot::channel();
234
235        self.0
236            .send(ToolServerRequest {
237                callback_channel: tx,
238                data: ToolServerRequestMessageKind::AppendToolset(toolset),
239            })
240            .await?;
241
242        let res = rx.await?;
243
244        let ToolServerResponse::ToolAdded = res else {
245            return Err(ToolServerError::InvalidMessage(res));
246        };
247
248        Ok(())
249    }
250
251    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
252        let (tx, rx) = futures::channel::oneshot::channel();
253
254        self.0
255            .send(ToolServerRequest {
256                callback_channel: tx,
257                data: ToolServerRequestMessageKind::RemoveTool {
258                    tool_name: tool_name.to_string(),
259                },
260            })
261            .await?;
262
263        let res = rx.await?;
264
265        let ToolServerResponse::ToolDeleted = res else {
266            return Err(ToolServerError::InvalidMessage(res));
267        };
268
269        Ok(())
270    }
271
272    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
273        let (tx, rx) = futures::channel::oneshot::channel();
274
275        self.0
276            .send(ToolServerRequest {
277                callback_channel: tx,
278                data: ToolServerRequestMessageKind::CallTool {
279                    name: tool_name.to_string(),
280                    args: args.to_string(),
281                },
282            })
283            .await?;
284
285        let res = rx.await?;
286
287        match res {
288            ToolServerResponse::ToolExecuted { result, .. } => Ok(result),
289            ToolServerResponse::ToolError { error } => Err(ToolServerError::ToolsetError(
290                ToolSetError::ToolCallError(ToolError::ToolCallError(error.into())),
291            )),
292            invalid => Err(ToolServerError::InvalidMessage(invalid)),
293        }
294    }
295
296    pub async fn get_tool_defs(
297        &self,
298        prompt: Option<String>,
299    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
300        let (tx, rx) = futures::channel::oneshot::channel();
301
302        self.0
303            .send(ToolServerRequest {
304                callback_channel: tx,
305                data: ToolServerRequestMessageKind::GetToolDefs { prompt },
306            })
307            .await?;
308
309        let res = rx.await?;
310
311        let ToolServerResponse::ToolDefinitions(tooldefs) = res else {
312            return Err(ToolServerError::InvalidMessage(res));
313        };
314
315        Ok(tooldefs)
316    }
317}
318
319pub struct ToolServerRequest {
320    callback_channel: futures::channel::oneshot::Sender<ToolServerResponse>,
321    data: ToolServerRequestMessageKind,
322}
323
324pub enum ToolServerRequestMessageKind {
325    AddTool(Box<dyn ToolDyn>),
326    AppendToolset(ToolSet),
327    RemoveTool { tool_name: String },
328    CallTool { name: String, args: String },
329    GetToolDefs { prompt: Option<String> },
330}
331
332#[derive(PartialEq, Debug)]
333pub enum ToolServerResponse {
334    ToolAdded,
335    ToolDeleted,
336    ToolExecuted { result: String },
337    ToolError { error: String },
338    ToolDefinitions(Vec<ToolDefinition>),
339}
340
341#[derive(Debug, thiserror::Error)]
342pub enum ToolServerError {
343    #[error("Sending message was cancelled")]
344    Canceled(#[from] Canceled),
345    #[error("Toolset error: {0}")]
346    ToolsetError(#[from] ToolSetError),
347    #[error("Error while sending message: {0}")]
348    SendError(#[from] SendError<ToolServerRequest>),
349    #[error("An invalid message type was returned")]
350    InvalidMessage(ToolServerResponse),
351}
352
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, server::ToolServer},
    };

    /// Arguments accepted by the test `Adder` tool.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    /// Error type required by the `Tool` impl; `Adder` never returns it.
    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Minimal tool that sums two integers.
    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            let parameters = json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                },
                "required": ["x", "y"],
            });

            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters,
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let OperationArgs { x, y } = args;
            println!("[tool-call] Adding {} and {}", x, y);
            Ok(x + y)
        }
    }

    /// Round trip through a handle: add, list, call, remove, list again.
    #[tokio::test]
    pub async fn test_toolserver() {
        let handle = ToolServer::new().run();

        handle.add_tool(Adder).await.unwrap();
        let definitions = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(definitions.len(), 1);

        let call_args = json!({"x": 2, "y": 5}).to_string();
        let output = handle.call_tool("add", &call_args).await.unwrap();
        assert_eq!(output, "7");

        handle.remove_tool("add").await.unwrap();
        let definitions = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(definitions.len(), 0);
    }
}