// rig/tool/server.rs

1use std::sync::Arc;
2
3use futures::{StreamExt, TryStreamExt, channel::oneshot::Canceled, stream};
4use tokio::sync::{
5    RwLock,
6    mpsc::{Sender, error::SendError},
7};
8use tracing::Instrument;
9
10use crate::{
11    completion::{CompletionError, ToolDefinition},
12    tool::{Tool, ToolDyn, ToolError, ToolSet, ToolSetError},
13    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
14};
15
16pub struct ToolServer {
17    /// A list of static tool names.
18    /// These tools will always exist on the tool server for as long as they are not deleted.
19    static_tool_names: Vec<String>,
20    /// Dynamic tools. These tools will be dynamically fetched from a given vector store.
21    dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
22    /// The toolset where tools are called (to be executed).
23    /// Wrapped in Arc<RwLock<...>> to allow concurrent tool execution.
24    toolset: Arc<RwLock<ToolSet>>,
25}
26
27impl Default for ToolServer {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl ToolServer {
34    pub fn new() -> Self {
35        Self {
36            static_tool_names: Vec::new(),
37            dynamic_tools: Vec::new(),
38            toolset: Arc::new(RwLock::new(ToolSet::default())),
39        }
40    }
41
42    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
43        self.static_tool_names = names;
44        self
45    }
46
47    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
48        self.toolset = Arc::new(RwLock::new(tools));
49        self
50    }
51
52    pub(crate) fn add_dynamic_tools(
53        mut self,
54        dyn_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
55    ) -> Self {
56        self.dynamic_tools = dyn_tools;
57        self
58    }
59
60    /// Add a static tool to the agent
61    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
62        let toolname = tool.name();
63        // This should be practically impossible to fail: cloning the Arc before calling
64        // .tool() is impossible since the toolset field is private, and the server cannot
65        // be running prior to run(), which consumes self.
66        Arc::get_mut(&mut self.toolset)
67            .expect("ToolServer::tool() called after run()")
68            .get_mut()
69            .add_tool(tool);
70        self.static_tool_names.push(toolname);
71        self
72    }
73
74    // Add an MCP tool (from `rmcp`) to the agent
75    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
76    #[cfg(feature = "rmcp")]
77    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
78        use crate::tool::rmcp::McpTool;
79        let toolname = tool.name.clone();
80        // This should be practically impossible to fail: cloning the Arc before calling
81        // .rmcp_tool() is impossible since the toolset field is private, and the server cannot
82        // be running prior to run(), which consumes self.
83        Arc::get_mut(&mut self.toolset)
84            .expect("ToolServer::rmcp_tool() called after run()")
85            .get_mut()
86            .add_tool(McpTool::from_mcp_server(tool, client));
87        self.static_tool_names.push(toolname.to_string());
88        self
89    }
90
91    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
92    /// dynamic toolset will be inserted in the request.
93    pub fn dynamic_tools(
94        mut self,
95        sample: usize,
96        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
97        toolset: ToolSet,
98    ) -> Self {
99        self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
100        // This should be practically impossible to fail: cloning the Arc before calling
101        // .dynamic_tools() is impossible since the toolset field is private, and the server cannot
102        // be running prior to run(), which consumes self.
103        Arc::get_mut(&mut self.toolset)
104            .expect("ToolServer::dynamic_tools() called after run()")
105            .get_mut()
106            .add_tools(toolset);
107        self
108    }
109
110    pub fn run(mut self) -> ToolServerHandle {
111        let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
112
113        #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
114        tokio::spawn(async move {
115            while let Some(message) = rx.recv().await {
116                self.handle_message(message).await;
117            }
118        });
119
120        // SAFETY: `rig` currently doesn't compile to WASM without the `worker` feature.
121        // Therefore, we can safely assume that the user won't try to compile to wasm without the worker feature.
122        #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
123        wasm_bindgen_futures::spawn_local(async move {
124            while let Some(message) = rx.recv().await {
125                self.handle_message(message).await;
126            }
127        });
128
129        ToolServerHandle(tx)
130    }
131
132    pub async fn handle_message(&mut self, message: ToolServerRequest) {
133        let ToolServerRequest {
134            callback_channel,
135            data,
136        } = message;
137
138        match data {
139            ToolServerRequestMessageKind::AddTool(tool) => {
140                self.static_tool_names.push(tool.name());
141                self.toolset.write().await.add_tool_boxed(tool);
142                callback_channel
143                    .send(ToolServerResponse::ToolAdded)
144                    .unwrap();
145            }
146            ToolServerRequestMessageKind::AppendToolset(tools) => {
147                self.toolset.write().await.add_tools(tools);
148                callback_channel
149                    .send(ToolServerResponse::ToolAdded)
150                    .unwrap();
151            }
152            ToolServerRequestMessageKind::RemoveTool { tool_name } => {
153                self.static_tool_names.retain(|x| *x != tool_name);
154                self.toolset.write().await.delete_tool(&tool_name);
155                callback_channel
156                    .send(ToolServerResponse::ToolDeleted)
157                    .unwrap();
158            }
159            ToolServerRequestMessageKind::CallTool { name, args, span } => {
160                let toolset = Arc::clone(&self.toolset);
161
162                #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
163                tokio::spawn(
164                    async move {
165                        match toolset.read().await.call(&name, args.clone()).await {
166                            Ok(result) => {
167                                let _ = callback_channel
168                                    .send(ToolServerResponse::ToolExecuted { result });
169                            }
170                            Err(err) => {
171                                let _ = callback_channel.send(ToolServerResponse::ToolError {
172                                    error: err.to_string(),
173                                });
174                            }
175                        }
176                    }
177                    .instrument(span),
178                );
179
180                #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
181                wasm_bindgen_futures::spawn_local(
182                    async move {
183                        match toolset.read().await.call(&name, args.clone()).await {
184                            Ok(result) => {
185                                let _ = callback_channel
186                                    .send(ToolServerResponse::ToolExecuted { result });
187                            }
188                            Err(err) => {
189                                let _ = callback_channel.send(ToolServerResponse::ToolError {
190                                    error: err.to_string(),
191                                });
192                            }
193                        }
194                    }
195                    .instrument(span),
196                );
197            }
198            ToolServerRequestMessageKind::GetToolDefs { prompt } => {
199                let res = self.get_tool_definitions(prompt).await.unwrap();
200                callback_channel
201                    .send(ToolServerResponse::ToolDefinitions(res))
202                    .unwrap();
203            }
204        }
205    }
206
207    pub async fn get_tool_definitions(
208        &mut self,
209        text: Option<String>,
210    ) -> Result<Vec<ToolDefinition>, CompletionError> {
211        let static_tool_names = self.static_tool_names.clone();
212        let toolset = self.toolset.read().await;
213
214        let mut tools = if let Some(text) = text {
215            // First, collect all dynamic tool IDs from vector stores
216            let dynamic_tool_ids: Vec<String> = stream::iter(self.dynamic_tools.iter())
217                .then(|(num_sample, index)| async {
218                    let req = VectorSearchRequest::builder()
219                        .query(text.clone())
220                        .samples(*num_sample as u64)
221                        .build()
222                        .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
223                    Ok::<_, VectorStoreError>(
224                        index
225                            .as_ref()
226                            .top_n_ids(req.map_filter(Filter::interpret))
227                            .await?
228                            .into_iter()
229                            .map(|(_, id)| id)
230                            .collect::<Vec<String>>(),
231                    )
232                })
233                .try_fold(vec![], |mut acc, docs| async {
234                    acc.extend(docs);
235                    Ok(acc)
236                })
237                .await
238                .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
239
240            // Then, get tool definitions for each ID
241            let mut tools = Vec::new();
242            for doc in dynamic_tool_ids {
243                if let Some(tool) = toolset.get(&doc) {
244                    tools.push(tool.definition(text.clone()).await)
245                } else {
246                    tracing::warn!("Tool implementation not found in toolset: {}", doc);
247                }
248            }
249            tools
250        } else {
251            Vec::new()
252        };
253
254        for toolname in static_tool_names {
255            if let Some(tool) = toolset.get(&toolname) {
256                tools.push(tool.definition(String::new()).await)
257            } else {
258                tracing::warn!("Tool implementation not found in toolset: {}", toolname);
259            }
260        }
261
262        Ok(tools)
263    }
264}
265
266#[derive(Clone)]
267pub struct ToolServerHandle(Sender<ToolServerRequest>);
268
269impl ToolServerHandle {
270    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
271        let tool = Box::new(tool);
272
273        let (tx, rx) = futures::channel::oneshot::channel();
274
275        self.0
276            .send(ToolServerRequest {
277                callback_channel: tx,
278                data: ToolServerRequestMessageKind::AddTool(tool),
279            })
280            .await?;
281
282        let res = rx.await?;
283
284        let ToolServerResponse::ToolAdded = res else {
285            return Err(ToolServerError::InvalidMessage(res));
286        };
287
288        Ok(())
289    }
290
291    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
292        let (tx, rx) = futures::channel::oneshot::channel();
293
294        self.0
295            .send(ToolServerRequest {
296                callback_channel: tx,
297                data: ToolServerRequestMessageKind::AppendToolset(toolset),
298            })
299            .await?;
300
301        let res = rx.await?;
302
303        let ToolServerResponse::ToolAdded = res else {
304            return Err(ToolServerError::InvalidMessage(res));
305        };
306
307        Ok(())
308    }
309
310    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
311        let (tx, rx) = futures::channel::oneshot::channel();
312
313        self.0
314            .send(ToolServerRequest {
315                callback_channel: tx,
316                data: ToolServerRequestMessageKind::RemoveTool {
317                    tool_name: tool_name.to_string(),
318                },
319            })
320            .await?;
321
322        let res = rx.await?;
323
324        let ToolServerResponse::ToolDeleted = res else {
325            return Err(ToolServerError::InvalidMessage(res));
326        };
327
328        Ok(())
329    }
330
331    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
332        let (tx, rx) = futures::channel::oneshot::channel();
333
334        self.0
335            .send(ToolServerRequest {
336                callback_channel: tx,
337                data: ToolServerRequestMessageKind::CallTool {
338                    name: tool_name.to_string(),
339                    args: args.to_string(),
340                    span: tracing::Span::current(),
341                },
342            })
343            .await?;
344
345        let res = rx.await?;
346
347        match res {
348            ToolServerResponse::ToolExecuted { result, .. } => Ok(result),
349            ToolServerResponse::ToolError { error } => Err(ToolServerError::ToolsetError(
350                ToolSetError::ToolCallError(ToolError::ToolCallError(error.into())),
351            )),
352            invalid => Err(ToolServerError::InvalidMessage(invalid)),
353        }
354    }
355
356    pub async fn get_tool_defs(
357        &self,
358        prompt: Option<String>,
359    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
360        let (tx, rx) = futures::channel::oneshot::channel();
361
362        self.0
363            .send(ToolServerRequest {
364                callback_channel: tx,
365                data: ToolServerRequestMessageKind::GetToolDefs { prompt },
366            })
367            .await?;
368
369        let res = rx.await?;
370
371        let ToolServerResponse::ToolDefinitions(tooldefs) = res else {
372            return Err(ToolServerError::InvalidMessage(res));
373        };
374
375        Ok(tooldefs)
376    }
377}
378
379pub struct ToolServerRequest {
380    callback_channel: futures::channel::oneshot::Sender<ToolServerResponse>,
381    data: ToolServerRequestMessageKind,
382}
383
384pub enum ToolServerRequestMessageKind {
385    AddTool(Box<dyn ToolDyn>),
386    AppendToolset(ToolSet),
387    RemoveTool {
388        tool_name: String,
389    },
390    CallTool {
391        name: String,
392        args: String,
393        span: tracing::Span,
394    },
395    GetToolDefs {
396        prompt: Option<String>,
397    },
398}
399
400#[derive(PartialEq, Debug)]
401pub enum ToolServerResponse {
402    ToolAdded,
403    ToolDeleted,
404    ToolExecuted { result: String },
405    ToolError { error: String },
406    ToolDefinitions(Vec<ToolDefinition>),
407}
408
409#[derive(Debug, thiserror::Error)]
410pub enum ToolServerError {
411    #[error("Sending message was cancelled")]
412    Canceled(#[from] Canceled),
413    #[error("Toolset error: {0}")]
414    ToolsetError(#[from] ToolSetError),
415    #[error("Error while sending message: {0}")]
416    SendError(#[from] SendError<ToolServerRequest>),
417    #[error("An invalid message type was returned")]
418    InvalidMessage(ToolServerResponse),
419}
420
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolSet, server::ToolServer},
        vector_store::{
            VectorStoreError, VectorStoreIndex,
            request::{Filter, VectorSearchRequest},
        },
        wasm_compat::WasmCompatSend,
    };

    /// Arguments shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool computing `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(args.x + args.y)
        }
    }

    /// Test tool computing `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtractor;
    impl Tool for Subtractor {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "subtract".to_string(),
                description: "Subtract y from x".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(args.x - args.y)
        }
    }

    /// A mock vector store index that returns a predefined list of tool IDs.
    struct MockToolIndex {
        tool_ids: Vec<String>,
    }

    impl VectorStoreIndex for MockToolIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // Not used by get_tool_definitions, but required by trait
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(self
                .tool_ids
                .iter()
                .enumerate()
                .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
                .collect())
        }
    }

    #[tokio::test]
    async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(Adder).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    #[tokio::test]
    async fn test_toolserver_dynamic_tools() {
        // Create a mock index that will return "subtract" as the dynamic tool
        let mock_index = MockToolIndex {
            tool_ids: vec!["subtract".to_string()],
        };

        // Build server with static tool "add" and dynamic tools from the mock index
        let server = ToolServer::new().tool(Adder).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![Subtractor]),
        );

        let handle = server.run();

        // Test with None prompt - should only return static tools
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // Test with Some prompt - should return both static and dynamic tools
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        // Check that both tools are present (order may vary)
        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    #[tokio::test]
    async fn test_toolserver_dynamic_tools_missing_implementation() {
        // Create a mock index that returns a tool ID that doesn't exist in the toolset
        let mock_index = MockToolIndex {
            tool_ids: vec!["nonexistent_tool".to_string()],
        };

        // Build server with only static tool, but dynamic index references missing tool
        let server = ToolServer::new()
            .tool(Adder)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        // Test with Some prompt - should only return static tool since dynamic tool is missing
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Sleeper error")]
    struct SleeperError;

    /// A tool that sleeps for a configurable duration, used to test concurrent execution.
    #[derive(Deserialize, Serialize, Clone)]
    struct SleeperTool {
        sleep_duration_ms: u64,
    }

    impl SleeperTool {
        fn new(sleep_duration_ms: u64) -> Self {
            Self { sleep_duration_ms }
        }
    }

    impl Tool for SleeperTool {
        const NAME: &'static str = "sleeper";
        type Error = SleeperError;
        type Args = serde_json::Value;
        type Output = u64;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "sleeper".to_string(),
                description: "Sleeps for configured duration".to_string(),
                parameters: json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            tokio::time::sleep(Duration::from_millis(self.sleep_duration_ms)).await;
            Ok(self.sleep_duration_ms)
        }
    }

    #[tokio::test]
    async fn test_toolserver_concurrent_tool_execution() {
        let sleep_ms: u64 = 100;
        let num_calls: u64 = 3;

        let server = ToolServer::new().tool(SleeperTool::new(sleep_ms));
        let handle = server.run();

        let start = std::time::Instant::now();

        // Issue all calls at once and wait for every one of them.
        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("sleeper", "{}"))
            .collect();
        let results = futures::future::join_all(futures).await;

        let elapsed = start.elapsed();

        // All calls should succeed
        for result in &results {
            assert!(result.is_ok(), "Tool call failed: {:?}", result);
        }

        // Sequential execution cannot finish before `num_calls * sleep_ms` (300ms), because
        // each sleep lasts at least `sleep_ms`. Concurrent execution takes about `sleep_ms`
        // (100ms) plus overhead. Using the sequential lower bound as the threshold still
        // detects fully serialised calls, while leaving headroom for scheduler jitter on
        // loaded CI machines.
        let sequential_lower_bound = Duration::from_millis(sleep_ms * num_calls);
        assert!(
            elapsed < sequential_lower_bound,
            "Expected concurrent execution in < {:?}, but took {:?}. Tools may be running sequentially.",
            sequential_lower_bound,
            elapsed
        );
    }
}