rig/tool/
server.rs

1use futures::{StreamExt, TryStreamExt, channel::oneshot::Canceled, stream};
2use tokio::sync::mpsc::{Sender, error::SendError};
3
4use crate::{
5    completion::{CompletionError, ToolDefinition},
6    tool::{Tool, ToolDyn, ToolError, ToolSet, ToolSetError},
7    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn},
8};
9
10pub struct ToolServer {
11    /// A list of static tool names.
12    /// These tools will always exist on the tool server for as long as they are not deleted.
13    static_tool_names: Vec<String>,
14    /// Dynamic tools. These tools will be dynamically fetched from a given vector store.
15    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
16    /// The toolset where tools are called (to be executed).
17    toolset: ToolSet,
18}
19
20impl Default for ToolServer {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl ToolServer {
27    pub fn new() -> Self {
28        Self {
29            static_tool_names: Vec::new(),
30            dynamic_tools: Vec::new(),
31            toolset: ToolSet::default(),
32        }
33    }
34
35    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
36        self.static_tool_names = names;
37        self
38    }
39
40    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
41        self.toolset = tools;
42        self
43    }
44
45    pub(crate) fn add_dynamic_tools(
46        mut self,
47        dyn_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
48    ) -> Self {
49        self.dynamic_tools = dyn_tools;
50        self
51    }
52
53    /// Add a static tool to the agent
54    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
55        let toolname = tool.name();
56        self.toolset.add_tool(tool);
57        self.static_tool_names.push(toolname);
58        self
59    }
60
61    // Add an MCP tool (from `rmcp`) to the agent
62    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
63    #[cfg(feature = "rmcp")]
64    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
65        use crate::tool::rmcp::McpTool;
66        let toolname = tool.name.clone();
67        self.toolset
68            .add_tool(McpTool::from_mcp_server(tool, client));
69        self.static_tool_names.push(toolname.to_string());
70        self
71    }
72
73    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
74    /// dynamic toolset will be inserted in the request.
75    pub fn dynamic_tools(
76        mut self,
77        sample: usize,
78        dynamic_tools: impl VectorStoreIndexDyn + 'static,
79        toolset: ToolSet,
80    ) -> Self {
81        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
82        self.toolset.add_tools(toolset);
83        self
84    }
85
86    pub fn run(mut self) -> ToolServerHandle {
87        let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
88
89        #[cfg(not(target_family = "wasm"))]
90        tokio::spawn(async move {
91            while let Some(message) = rx.recv().await {
92                self.handle_message(message).await;
93            }
94        });
95
96        #[cfg(target_family = "wasm")]
97        tokio::task::spawn_local(async move {
98            while let Some(message) = rx.recv().await {
99                self.handle_message(message).await;
100            }
101        });
102
103        ToolServerHandle(tx)
104    }
105
106    pub async fn handle_message(&mut self, message: ToolServerRequest) {
107        let ToolServerRequest {
108            callback_channel,
109            data,
110        } = message;
111
112        match data {
113            ToolServerRequestMessageKind::AddTool(tool) => {
114                self.static_tool_names.push(tool.name());
115                self.toolset.add_tool_boxed(tool);
116                callback_channel
117                    .send(ToolServerResponse::ToolAdded)
118                    .unwrap();
119            }
120            ToolServerRequestMessageKind::AppendToolset(tools) => {
121                self.toolset.add_tools(tools);
122                callback_channel
123                    .send(ToolServerResponse::ToolAdded)
124                    .unwrap();
125            }
126            ToolServerRequestMessageKind::RemoveTool { tool_name } => {
127                self.static_tool_names.retain(|x| *x != tool_name);
128                self.toolset.delete_tool(&tool_name);
129                callback_channel
130                    .send(ToolServerResponse::ToolDeleted)
131                    .unwrap();
132            }
133            ToolServerRequestMessageKind::CallTool { name, args } => {
134                match self.toolset.call(&name, args.clone()).await {
135                    Ok(result) => {
136                        let _ = callback_channel.send(ToolServerResponse::ToolExecuted { result });
137                    }
138                    Err(err) => {
139                        let _ = callback_channel.send(ToolServerResponse::ToolError {
140                            error: err.to_string(),
141                        });
142                    }
143                }
144            }
145            ToolServerRequestMessageKind::GetToolDefs { prompt } => {
146                let res = self.get_tool_definitions(prompt).await.unwrap();
147                callback_channel
148                    .send(ToolServerResponse::ToolDefinitions(res))
149                    .unwrap();
150            }
151        }
152    }
153
154    pub async fn get_tool_definitions(
155        &mut self,
156        text: Option<String>,
157    ) -> Result<Vec<ToolDefinition>, CompletionError> {
158        let static_tool_names = self.static_tool_names.clone();
159        let mut tools = if let Some(text) = text {
160            stream::iter(self.dynamic_tools.iter())
161                        .then(|(num_sample, index)| async {
162                            let req = VectorSearchRequest::builder().query(text.clone()).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
163                            Ok::<_, VectorStoreError>(
164                                index
165                                    .top_n_ids(req)
166                                    .await?
167                                    .into_iter()
168                                    .map(|(_, id)| id)
169                                    .collect::<Vec<String>>(),
170                            )
171                        })
172                        .try_fold(vec![], |mut acc, docs| async {
173                            for doc in docs {
174                                if let Some(tool) = self.toolset.get(&doc) {
175                                    acc.push(tool.definition(text.clone()).await)
176                                } else {
177                                    tracing::warn!("Tool implementation not found in toolset: {}", doc);
178                                }
179                            }
180                            Ok(acc)
181                        })
182                        .await
183                        .map_err(|e| CompletionError::RequestError(Box::new(e)))?
184        } else {
185            Vec::new()
186        };
187
188        for toolname in static_tool_names {
189            if let Some(tool) = self.toolset.get(&toolname) {
190                tools.push(tool.definition(String::new()).await)
191            } else {
192                tracing::warn!("Tool implementation not found in toolset: {}", toolname);
193            }
194        }
195
196        Ok(tools)
197    }
198}
199
200#[derive(Clone)]
201pub struct ToolServerHandle(Sender<ToolServerRequest>);
202
203impl ToolServerHandle {
204    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
205        let tool = Box::new(tool);
206
207        let (tx, rx) = futures::channel::oneshot::channel();
208
209        self.0
210            .send(ToolServerRequest {
211                callback_channel: tx,
212                data: ToolServerRequestMessageKind::AddTool(tool),
213            })
214            .await?;
215
216        let res = rx.await?;
217
218        let ToolServerResponse::ToolAdded = res else {
219            return Err(ToolServerError::InvalidMessage(res));
220        };
221
222        Ok(())
223    }
224
225    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
226        let (tx, rx) = futures::channel::oneshot::channel();
227
228        self.0
229            .send(ToolServerRequest {
230                callback_channel: tx,
231                data: ToolServerRequestMessageKind::AppendToolset(toolset),
232            })
233            .await?;
234
235        let res = rx.await?;
236
237        let ToolServerResponse::ToolAdded = res else {
238            return Err(ToolServerError::InvalidMessage(res));
239        };
240
241        Ok(())
242    }
243
244    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
245        let (tx, rx) = futures::channel::oneshot::channel();
246
247        self.0
248            .send(ToolServerRequest {
249                callback_channel: tx,
250                data: ToolServerRequestMessageKind::RemoveTool {
251                    tool_name: tool_name.to_string(),
252                },
253            })
254            .await?;
255
256        let res = rx.await?;
257
258        let ToolServerResponse::ToolDeleted = res else {
259            return Err(ToolServerError::InvalidMessage(res));
260        };
261
262        Ok(())
263    }
264
265    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
266        let (tx, rx) = futures::channel::oneshot::channel();
267
268        self.0
269            .send(ToolServerRequest {
270                callback_channel: tx,
271                data: ToolServerRequestMessageKind::CallTool {
272                    name: tool_name.to_string(),
273                    args: args.to_string(),
274                },
275            })
276            .await?;
277
278        let res = rx.await?;
279
280        match res {
281            ToolServerResponse::ToolExecuted { result, .. } => Ok(result),
282            ToolServerResponse::ToolError { error } => Err(ToolServerError::ToolsetError(
283                ToolSetError::ToolCallError(ToolError::ToolCallError(error.into())),
284            )),
285            invalid => Err(ToolServerError::InvalidMessage(invalid)),
286        }
287    }
288
289    pub async fn get_tool_defs(
290        &self,
291        prompt: Option<String>,
292    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
293        let (tx, rx) = futures::channel::oneshot::channel();
294
295        self.0
296            .send(ToolServerRequest {
297                callback_channel: tx,
298                data: ToolServerRequestMessageKind::GetToolDefs { prompt },
299            })
300            .await?;
301
302        let res = rx.await?;
303
304        let ToolServerResponse::ToolDefinitions(tooldefs) = res else {
305            return Err(ToolServerError::InvalidMessage(res));
306        };
307
308        Ok(tooldefs)
309    }
310}
311
312pub struct ToolServerRequest {
313    callback_channel: futures::channel::oneshot::Sender<ToolServerResponse>,
314    data: ToolServerRequestMessageKind,
315}
316
317pub enum ToolServerRequestMessageKind {
318    AddTool(Box<dyn ToolDyn>),
319    AppendToolset(ToolSet),
320    RemoveTool { tool_name: String },
321    CallTool { name: String, args: String },
322    GetToolDefs { prompt: Option<String> },
323}
324
325#[derive(PartialEq, Debug)]
326pub enum ToolServerResponse {
327    ToolAdded,
328    ToolDeleted,
329    ToolExecuted { result: String },
330    ToolError { error: String },
331    ToolDefinitions(Vec<ToolDefinition>),
332}
333
334#[derive(Debug, thiserror::Error)]
335pub enum ToolServerError {
336    #[error("Sending message was cancelled")]
337    Canceled(#[from] Canceled),
338    #[error("Toolset error: {0}")]
339    ToolsetError(#[from] ToolSetError),
340    #[error("Error while sending message: {0}")]
341    SendError(#[from] SendError<ToolServerRequest>),
342    #[error("An invalid message type was returned")]
343    InvalidMessage(ToolServerResponse),
344}
345
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, server::ToolServer},
    };

    /// Arguments accepted by the arithmetic test tool.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    /// Error type of the arithmetic test tool (never produced by `Adder`).
    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Minimal tool that sums two integers.
    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            let parameters = json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                },
                "required": ["x", "y"],
            });

            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters,
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            Ok(args.x + args.y)
        }
    }

    /// Round-trips add / list / call / remove through a running server.
    #[tokio::test]
    pub async fn test_toolserver() {
        let handle = ToolServer::new().run();

        handle.add_tool(Adder).await.unwrap();
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);

        let call_args = serde_json::to_string(&json!({"x": 2, "y": 5})).unwrap();
        let output = handle.call_tool("add", &call_args).await.unwrap();
        assert_eq!(output, "7");

        handle.remove_tool("add").await.unwrap();
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 0);
    }
}