// Source file: rig/tool/server.rs

use std::{collections::HashSet, sync::Arc};

use futures::{StreamExt, TryStreamExt, channel::oneshot::Canceled, stream};
use tokio::sync::{
    RwLock,
    mpsc::{Sender, error::SendError},
};

use crate::{
    completion::{CompletionError, ToolDefinition},
    tool::{Tool, ToolDyn, ToolError, ToolSet, ToolSetError},
    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
};

15pub struct ToolServer {
16    /// A list of static tool names.
17    /// These tools will always exist on the tool server for as long as they are not deleted.
18    static_tool_names: Vec<String>,
19    /// Dynamic tools. These tools will be dynamically fetched from a given vector store.
20    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
21    /// The toolset where tools are called (to be executed).
22    /// Wrapped in Arc<RwLock<...>> to allow concurrent tool execution.
23    toolset: Arc<RwLock<ToolSet>>,
24}
25
26impl Default for ToolServer {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl ToolServer {
33    pub fn new() -> Self {
34        Self {
35            static_tool_names: Vec::new(),
36            dynamic_tools: Vec::new(),
37            toolset: Arc::new(RwLock::new(ToolSet::default())),
38        }
39    }
40
41    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
42        self.static_tool_names = names;
43        self
44    }
45
46    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
47        self.toolset = Arc::new(RwLock::new(tools));
48        self
49    }
50
51    pub(crate) fn add_dynamic_tools(
52        mut self,
53        dyn_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
54    ) -> Self {
55        self.dynamic_tools = dyn_tools;
56        self
57    }
58
59    /// Add a static tool to the agent
60    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
61        let toolname = tool.name();
62        // This should be practically impossible to fail: cloning the Arc before calling
63        // .tool() is impossible since the toolset field is private, and the server cannot
64        // be running prior to run(), which consumes self.
65        Arc::get_mut(&mut self.toolset)
66            .expect("ToolServer::tool() called after run()")
67            .get_mut()
68            .add_tool(tool);
69        self.static_tool_names.push(toolname);
70        self
71    }
72
73    // Add an MCP tool (from `rmcp`) to the agent
74    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
75    #[cfg(feature = "rmcp")]
76    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
77        use crate::tool::rmcp::McpTool;
78        let toolname = tool.name.clone();
79        // This should be practically impossible to fail: cloning the Arc before calling
80        // .rmcp_tool() is impossible since the toolset field is private, and the server cannot
81        // be running prior to run(), which consumes self.
82        Arc::get_mut(&mut self.toolset)
83            .expect("ToolServer::rmcp_tool() called after run()")
84            .get_mut()
85            .add_tool(McpTool::from_mcp_server(tool, client));
86        self.static_tool_names.push(toolname.to_string());
87        self
88    }
89
90    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
91    /// dynamic toolset will be inserted in the request.
92    pub fn dynamic_tools(
93        mut self,
94        sample: usize,
95        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
96        toolset: ToolSet,
97    ) -> Self {
98        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
99        // This should be practically impossible to fail: cloning the Arc before calling
100        // .dynamic_tools() is impossible since the toolset field is private, and the server cannot
101        // be running prior to run(), which consumes self.
102        Arc::get_mut(&mut self.toolset)
103            .expect("ToolServer::dynamic_tools() called after run()")
104            .get_mut()
105            .add_tools(toolset);
106        self
107    }
108
109    pub fn run(mut self) -> ToolServerHandle {
110        let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
111
112        #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
113        tokio::spawn(async move {
114            while let Some(message) = rx.recv().await {
115                self.handle_message(message).await;
116            }
117        });
118
119        // SAFETY: `rig` currently doesn't compile to WASM without the `worker` feature.
120        // Therefore, we can safely assume that the user won't try to compile to wasm without the worker feature.
121        #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
122        wasm_bindgen_futures::spawn_local(async move {
123            while let Some(message) = rx.recv().await {
124                self.handle_message(message).await;
125            }
126        });
127
128        ToolServerHandle(tx)
129    }
130
131    pub async fn handle_message(&mut self, message: ToolServerRequest) {
132        let ToolServerRequest {
133            callback_channel,
134            data,
135        } = message;
136
137        match data {
138            ToolServerRequestMessageKind::AddTool(tool) => {
139                self.static_tool_names.push(tool.name());
140                self.toolset.write().await.add_tool_boxed(tool);
141                callback_channel
142                    .send(ToolServerResponse::ToolAdded)
143                    .unwrap();
144            }
145            ToolServerRequestMessageKind::AppendToolset(tools) => {
146                self.toolset.write().await.add_tools(tools);
147                callback_channel
148                    .send(ToolServerResponse::ToolAdded)
149                    .unwrap();
150            }
151            ToolServerRequestMessageKind::RemoveTool { tool_name } => {
152                self.static_tool_names.retain(|x| *x != tool_name);
153                self.toolset.write().await.delete_tool(&tool_name);
154                callback_channel
155                    .send(ToolServerResponse::ToolDeleted)
156                    .unwrap();
157            }
158            ToolServerRequestMessageKind::CallTool { name, args } => {
159                let toolset = Arc::clone(&self.toolset);
160
161                #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
162                tokio::spawn(async move {
163                    match toolset.read().await.call(&name, args.clone()).await {
164                        Ok(result) => {
165                            let _ =
166                                callback_channel.send(ToolServerResponse::ToolExecuted { result });
167                        }
168                        Err(err) => {
169                            let _ = callback_channel.send(ToolServerResponse::ToolError {
170                                error: err.to_string(),
171                            });
172                        }
173                    }
174                });
175
176                #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
177                wasm_bindgen_futures::spawn_local(async move {
178                    match toolset.read().await.call(&name, args.clone()).await {
179                        Ok(result) => {
180                            let _ =
181                                callback_channel.send(ToolServerResponse::ToolExecuted { result });
182                        }
183                        Err(err) => {
184                            let _ = callback_channel.send(ToolServerResponse::ToolError {
185                                error: err.to_string(),
186                            });
187                        }
188                    }
189                });
190            }
191            ToolServerRequestMessageKind::GetToolDefs { prompt } => {
192                let res = self.get_tool_definitions(prompt).await.unwrap();
193                callback_channel
194                    .send(ToolServerResponse::ToolDefinitions(res))
195                    .unwrap();
196            }
197        }
198    }
199
200    pub async fn get_tool_definitions(
201        &mut self,
202        text: Option<String>,
203    ) -> Result<Vec<ToolDefinition>, CompletionError> {
204        let static_tool_names = self.static_tool_names.clone();
205        let toolset = self.toolset.read().await;
206
207        let mut tools = if let Some(text) = text {
208            // First, collect all dynamic tool IDs from vector stores
209            let dynamic_tool_ids: Vec<String> = stream::iter(self.dynamic_tools.iter())
210                .then(|(num_sample, index)| async {
211                    let req = VectorSearchRequest::builder()
212                        .query(text.clone())
213                        .samples(*num_sample as u64)
214                        .build()
215                        .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
216                    Ok::<_, VectorStoreError>(
217                        index
218                            .as_ref()
219                            .top_n_ids(req.map_filter(Filter::interpret))
220                            .await?
221                            .into_iter()
222                            .map(|(_, id)| id)
223                            .collect::<Vec<String>>(),
224                    )
225                })
226                .try_fold(vec![], |mut acc, docs| async {
227                    acc.extend(docs);
228                    Ok(acc)
229                })
230                .await
231                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
232
233            // Then, get tool definitions for each ID
234            let mut tools = Vec::new();
235            for doc in dynamic_tool_ids {
236                if let Some(tool) = toolset.get(&doc) {
237                    tools.push(tool.definition(text.clone()).await)
238                } else {
239                    tracing::warn!("Tool implementation not found in toolset: {}", doc);
240                }
241            }
242            tools
243        } else {
244            Vec::new()
245        };
246
247        for toolname in static_tool_names {
248            if let Some(tool) = toolset.get(&toolname) {
249                tools.push(tool.definition(String::new()).await)
250            } else {
251                tracing::warn!("Tool implementation not found in toolset: {}", toolname);
252            }
253        }
254
255        Ok(tools)
256    }
257}
258
259#[derive(Clone)]
260pub struct ToolServerHandle(Sender<ToolServerRequest>);
261
262impl ToolServerHandle {
263    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
264        let tool = Box::new(tool);
265
266        let (tx, rx) = futures::channel::oneshot::channel();
267
268        self.0
269            .send(ToolServerRequest {
270                callback_channel: tx,
271                data: ToolServerRequestMessageKind::AddTool(tool),
272            })
273            .await?;
274
275        let res = rx.await?;
276
277        let ToolServerResponse::ToolAdded = res else {
278            return Err(ToolServerError::InvalidMessage(res));
279        };
280
281        Ok(())
282    }
283
284    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
285        let (tx, rx) = futures::channel::oneshot::channel();
286
287        self.0
288            .send(ToolServerRequest {
289                callback_channel: tx,
290                data: ToolServerRequestMessageKind::AppendToolset(toolset),
291            })
292            .await?;
293
294        let res = rx.await?;
295
296        let ToolServerResponse::ToolAdded = res else {
297            return Err(ToolServerError::InvalidMessage(res));
298        };
299
300        Ok(())
301    }
302
303    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
304        let (tx, rx) = futures::channel::oneshot::channel();
305
306        self.0
307            .send(ToolServerRequest {
308                callback_channel: tx,
309                data: ToolServerRequestMessageKind::RemoveTool {
310                    tool_name: tool_name.to_string(),
311                },
312            })
313            .await?;
314
315        let res = rx.await?;
316
317        let ToolServerResponse::ToolDeleted = res else {
318            return Err(ToolServerError::InvalidMessage(res));
319        };
320
321        Ok(())
322    }
323
324    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
325        let (tx, rx) = futures::channel::oneshot::channel();
326
327        self.0
328            .send(ToolServerRequest {
329                callback_channel: tx,
330                data: ToolServerRequestMessageKind::CallTool {
331                    name: tool_name.to_string(),
332                    args: args.to_string(),
333                },
334            })
335            .await?;
336
337        let res = rx.await?;
338
339        match res {
340            ToolServerResponse::ToolExecuted { result, .. } => Ok(result),
341            ToolServerResponse::ToolError { error } => Err(ToolServerError::ToolsetError(
342                ToolSetError::ToolCallError(ToolError::ToolCallError(error.into())),
343            )),
344            invalid => Err(ToolServerError::InvalidMessage(invalid)),
345        }
346    }
347
348    pub async fn get_tool_defs(
349        &self,
350        prompt: Option<String>,
351    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
352        let (tx, rx) = futures::channel::oneshot::channel();
353
354        self.0
355            .send(ToolServerRequest {
356                callback_channel: tx,
357                data: ToolServerRequestMessageKind::GetToolDefs { prompt },
358            })
359            .await?;
360
361        let res = rx.await?;
362
363        let ToolServerResponse::ToolDefinitions(tooldefs) = res else {
364            return Err(ToolServerError::InvalidMessage(res));
365        };
366
367        Ok(tooldefs)
368    }
369}
370
371pub struct ToolServerRequest {
372    callback_channel: futures::channel::oneshot::Sender<ToolServerResponse>,
373    data: ToolServerRequestMessageKind,
374}
375
376pub enum ToolServerRequestMessageKind {
377    AddTool(Box<dyn ToolDyn>),
378    AppendToolset(ToolSet),
379    RemoveTool { tool_name: String },
380    CallTool { name: String, args: String },
381    GetToolDefs { prompt: Option<String> },
382}
383
384#[derive(PartialEq, Debug)]
385pub enum ToolServerResponse {
386    ToolAdded,
387    ToolDeleted,
388    ToolExecuted { result: String },
389    ToolError { error: String },
390    ToolDefinitions(Vec<ToolDefinition>),
391}
392
393#[derive(Debug, thiserror::Error)]
394pub enum ToolServerError {
395    #[error("Sending message was cancelled")]
396    Canceled(#[from] Canceled),
397    #[error("Toolset error: {0}")]
398    ToolsetError(#[from] ToolSetError),
399    #[error("Error while sending message: {0}")]
400    SendError(#[from] SendError<ToolServerRequest>),
401    #[error("An invalid message type was returned")]
402    InvalidMessage(ToolServerResponse),
403}
404
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolSet, server::ToolServer},
        vector_store::{
            VectorStoreError, VectorStoreIndex,
            request::{Filter, VectorSearchRequest},
        },
        wasm_compat::WasmCompatSend,
    };

    /// Arguments shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool named "add" that returns `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool named "subtract" that returns `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtractor;
    impl Tool for Subtractor {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "subtract".to_string(),
                description: "Subtract y from x".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x - args.y;
            Ok(result)
        }
    }

    /// A mock vector store index that returns a predefined list of tool IDs.
    struct MockToolIndex {
        tool_ids: Vec<String>,
    }

    impl VectorStoreIndex for MockToolIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // Not used by get_tool_definitions, but required by trait
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            // Every configured ID is returned, with decreasing scores.
            Ok(self
                .tool_ids
                .iter()
                .enumerate()
                .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
                .collect())
        }
    }

    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(Adder).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        // Create a mock index that will return "subtract" as the dynamic tool
        let mock_index = MockToolIndex {
            tool_ids: vec!["subtract".to_string()],
        };

        // Build server with static tool "add" and dynamic tools from the mock index
        let server = ToolServer::new().tool(Adder).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![Subtractor]),
        );

        let handle = server.run();

        // Test with None prompt - should only return static tools
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // Test with Some prompt - should return both static and dynamic tools
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        // Check that both tools are present (order may vary)
        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        // Create a mock index that returns a tool ID that doesn't exist in the toolset
        let mock_index = MockToolIndex {
            tool_ids: vec!["nonexistent_tool".to_string()],
        };

        // Build server with only static tool, but dynamic index references missing tool
        let server = ToolServer::new()
            .tool(Adder)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        // Test with Some prompt - should only return static tool since dynamic tool is missing
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Sleeper error")]
    struct SleeperError;

    /// A tool that sleeps for a configurable duration, used to test concurrent execution.
    #[derive(Deserialize, Serialize, Clone)]
    struct SleeperTool {
        sleep_duration_ms: u64,
    }

    impl SleeperTool {
        fn new(sleep_duration_ms: u64) -> Self {
            Self { sleep_duration_ms }
        }
    }

    impl Tool for SleeperTool {
        const NAME: &'static str = "sleeper";
        type Error = SleeperError;
        type Args = serde_json::Value;
        type Output = u64;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "sleeper".to_string(),
                description: "Sleeps for configured duration".to_string(),
                parameters: json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            tokio::time::sleep(Duration::from_millis(self.sleep_duration_ms)).await;
            Ok(self.sleep_duration_ms)
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        let sleep_ms: u64 = 100;
        let num_calls: u64 = 3;

        let server = ToolServer::new().tool(SleeperTool::new(sleep_ms));
        let handle = server.run();

        let start = std::time::Instant::now();

        // Issue all calls at once and wait for every one of them.
        let calls: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("sleeper", "{}"))
            .collect();
        let results = futures::future::join_all(calls).await;

        let elapsed = start.elapsed();

        // Every call should succeed and report the configured sleep duration.
        let expected_output = sleep_ms.to_string();
        for result in &results {
            let output = result
                .as_ref()
                .unwrap_or_else(|err| panic!("Tool call failed: {err:?}"));
            assert_eq!(output, &expected_output);
        }

        // Each call sleeps for at least `sleep_ms`, so running the calls one after another can
        // never finish in under `num_calls * sleep_ms`. Staying below that bound proves the
        // calls overlapped, while leaving ample headroom (concurrent execution takes roughly
        // `sleep_ms`) so the test does not flake on slow machines.
        let sequential_lower_bound = Duration::from_millis(sleep_ms * num_calls);
        assert!(
            elapsed < sequential_lower_bound,
            "Expected concurrent execution in < {:?}, but took {:?}. Tools may be running sequentially.",
            sequential_lower_bound,
            elapsed
        );
    }
}