Skip to main content

rig/tool/
server.rs

1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6    completion::{CompletionError, ToolDefinition},
7    tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
/// Shared state behind a `ToolServerHandle`.
///
/// Every clone of a handle points at the same instance, wrapped in
/// `Arc<RwLock<_>>` (see `ToolServerHandle`).
struct ToolServerState {
    /// Static tool names that persist until explicitly removed.
    /// Their definitions are reported on every `get_tool_defs` call.
    static_tool_names: Vec<String>,
    /// Dynamic tools fetched from vector stores on each prompt.
    /// Each entry pairs the number of ids to request from the index with the
    /// index itself.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// The toolset where tools are registered and executed.
    /// Both static and dynamic tool implementations are looked up here by name.
    toolset: ToolSet,
}
20
21/// Builder for constructing a [`ToolServerHandle`].
22///
23/// Accumulates tools and configuration, then produces a shared handle via
24/// [`run()`](ToolServer::run).
25pub struct ToolServer {
26    static_tool_names: Vec<String>,
27    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
28    toolset: ToolSet,
29}
30
31impl Default for ToolServer {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl ToolServer {
38    pub fn new() -> Self {
39        Self {
40            static_tool_names: Vec::new(),
41            dynamic_tools: Vec::new(),
42            toolset: ToolSet::default(),
43        }
44    }
45
46    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
47        self.static_tool_names = names;
48        self
49    }
50
51    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
52        self.toolset = tools;
53        self
54    }
55
56    pub(crate) fn add_dynamic_tools(
57        mut self,
58        dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
59    ) -> Self {
60        self.dynamic_tools = dyn_tools;
61        self
62    }
63
64    /// Add a static tool to the agent
65    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
66        let toolname = tool.name();
67        self.toolset.add_tool(tool);
68        self.static_tool_names.push(toolname);
69        self
70    }
71
72    /// Add an MCP tool (from `rmcp`) to the agent
73    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
74    #[cfg(feature = "rmcp")]
75    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
76        use crate::tool::rmcp::McpTool;
77        let toolname = tool.name.clone();
78        self.toolset
79            .add_tool(McpTool::from_mcp_server(tool, client));
80        self.static_tool_names.push(toolname.to_string());
81        self
82    }
83
84    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
85    /// dynamic toolset will be inserted in the request.
86    pub fn dynamic_tools(
87        mut self,
88        sample: usize,
89        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
90        toolset: ToolSet,
91    ) -> Self {
92        self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
93        self.toolset.add_tools(toolset);
94        self
95    }
96
97    /// Consume the builder and return a shared [`ToolServerHandle`].
98    pub fn run(self) -> ToolServerHandle {
99        ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
100            static_tool_names: self.static_tool_names,
101            dynamic_tools: self.dynamic_tools,
102            toolset: self.toolset,
103        })))
104    }
105}
106
/// A cheaply-cloneable handle to the shared tool server state.
///
/// All operations acquire locks directly on the underlying state.
/// Multiple handles (e.g. across agents) can share the same state
/// without channel-based message routing.
///
/// Cloning only bumps the `Arc` reference count, so every clone observes the
/// same registered tools. Locks are held only briefly: tool execution and
/// vector store queries run after the guard has been released.
#[derive(Clone)]
pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
114
115impl ToolServerHandle {
116    /// Register a new static tool.
117    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
118        let mut state = self.0.write().await;
119        state.static_tool_names.push(tool.name());
120        state.toolset.add_tool_boxed(Box::new(tool));
121        Ok(())
122    }
123
124    /// Merge an entire toolset into the server.
125    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
126        let mut state = self.0.write().await;
127        state.toolset.add_tools(toolset);
128        Ok(())
129    }
130
131    /// Remove a tool by name from both the toolset and the static list.
132    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
133        let mut state = self.0.write().await;
134        state.static_tool_names.retain(|x| *x != tool_name);
135        state.toolset.delete_tool(tool_name);
136        Ok(())
137    }
138
139    /// Look up and execute a tool by name.
140    ///
141    /// The tool handle is cloned under a brief read lock so that
142    /// long-running tool executions never block writers.
143    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
144        let tool = {
145            let state = self.0.read().await;
146            state.toolset.get(tool_name).cloned()
147        };
148
149        match tool {
150            Some(tool) => {
151                tracing::debug!(target: "rig",
152                    "Calling tool {tool_name} with args:\n{}",
153                    serde_json::to_string_pretty(&args).unwrap_or_default()
154                );
155                tool.call(args.to_string())
156                    .await
157                    .map_err(|e| ToolSetError::ToolCallError(e).into())
158            }
159            None => Err(ToolServerError::ToolsetError(
160                ToolSetError::ToolNotFoundError(tool_name.to_string()),
161            )),
162        }
163    }
164
165    /// Retrieve tool definitions, optionally using a prompt to select
166    /// dynamic tools from configured vector stores.
167    pub async fn get_tool_defs(
168        &self,
169        prompt: Option<String>,
170    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
171        // Snapshot the metadata we need under a brief read lock
172        let (static_tool_names, dynamic_tools) = {
173            let state = self.0.read().await;
174            (state.static_tool_names.clone(), state.dynamic_tools.clone())
175        };
176
177        let mut tools = if let Some(ref text) = prompt {
178            let futs: Vec<_> = dynamic_tools
179                .into_iter()
180                .map(|(num_sample, index)| {
181                    let text = text.clone();
182                    async move {
183                        let req = VectorSearchRequest::builder()
184                            .query(text)
185                            .samples(num_sample as u64)
186                            .build()
187                            .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
188                        Ok::<_, VectorStoreError>(
189                            index
190                                .top_n_ids(req.map_filter(Filter::interpret))
191                                .await?
192                                .into_iter()
193                                .map(|(_, id)| id)
194                                .collect::<Vec<String>>(),
195                        )
196                    }
197                })
198                .collect();
199
200            let results = futures::future::try_join_all(futs).await.map_err(|e| {
201                ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
202            })?;
203
204            let dynamic_tool_ids: Vec<String> = results.into_iter().flatten().collect();
205
206            let dynamic_tool_handles: Vec<_> = {
207                let state = self.0.read().await;
208                dynamic_tool_ids
209                    .iter()
210                    .filter_map(|doc| {
211                        let handle = state.toolset.get(doc).cloned();
212                        if handle.is_none() {
213                            tracing::warn!("Tool implementation not found in toolset: {}", doc);
214                        }
215                        handle
216                    })
217                    .collect()
218            };
219
220            let mut tools = Vec::new();
221            for tool in dynamic_tool_handles {
222                tools.push(tool.definition(text.clone()).await);
223            }
224            tools
225        } else {
226            Vec::new()
227        };
228
229        let static_tool_handles: Vec<_> = {
230            let state = self.0.read().await;
231            static_tool_names
232                .iter()
233                .filter_map(|toolname| {
234                    let handle = state.toolset.get(toolname).cloned();
235                    if handle.is_none() {
236                        tracing::warn!("Tool implementation not found in toolset: {}", toolname);
237                    }
238                    handle
239                })
240                .collect()
241        };
242
243        for tool in static_tool_handles {
244            tools.push(tool.definition(String::new()).await);
245        }
246
247        Ok(tools)
248    }
249}
250
/// Errors returned by [`ToolServerHandle`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolServerError {
    /// A toolset-level failure: the requested tool is not registered, or the
    /// tool call itself returned an error.
    #[error("Toolset error: {0}")]
    ToolsetError(#[from] ToolSetError),
    /// Querying a dynamic-tool vector store failed while building the list of
    /// tool definitions.
    #[error("Failed to retrieve tool definitions: {0}")]
    DefinitionError(CompletionError),
}
258
// Unit tests covering:
// * static tool registration and removal;
// * dynamic tool selection;
// * lock scoping while tools run;
// * concurrency of tool calls and index queries.
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolSet, server::ToolServer},
        vector_store::{
            VectorStoreError, VectorStoreIndex,
            request::{Filter, VectorSearchRequest},
        },
        wasm_compat::WasmCompatSend,
    };

    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    #[derive(Deserialize, Serialize)]
    struct Subtractor;
    impl Tool for Subtractor {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "subtract".to_string(),
                description: "Subtract y from x".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x - args.y;
            Ok(result)
        }
    }

    /// A mock vector store index that returns a predefined list of tool IDs.
    struct MockToolIndex {
        tool_ids: Vec<String>,
    }

    impl VectorStoreIndex for MockToolIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // Not used by get_tool_definitions, but required by trait
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(self
                .tool_ids
                .iter()
                .enumerate()
                .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
                .collect())
        }
    }

    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(Adder).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        // Create a mock index that will return "subtract" as the dynamic tool
        let mock_index = MockToolIndex {
            tool_ids: vec!["subtract".to_string()],
        };

        // Build a server with static tool "add" and dynamic tools from the mock
        // index. The "subtract" implementation is supplied alongside the index.
        let server = ToolServer::new().tool(Adder).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![Subtractor]),
        );

        let handle = server.run();

        // Test with None prompt - should only return static tools
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // Test with Some prompt - should return both static and dynamic tools
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        // Check that both tools are present (order may vary)
        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        // Create a mock index that returns a tool ID that doesn't exist in the toolset
        let mock_index = MockToolIndex {
            tool_ids: vec!["nonexistent_tool".to_string()],
        };

        // Build server with only static tool, but dynamic index references missing tool
        let server = ToolServer::new()
            .tool(Adder)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        // Test with Some prompt - should only return static tool since dynamic tool is missing
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// A tool that waits at a barrier to test concurrency of tool execution.
    #[derive(Clone)]
    struct BarrierTool {
        barrier: Arc<tokio::sync::Barrier>,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Barrier error")]
    struct BarrierError;

    impl Tool for BarrierTool {
        const NAME: &'static str = "barrier_tool";
        type Error = BarrierError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "barrier_tool".to_string(),
                description: "Waits at a barrier to test concurrency".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            // Wait for all concurrent invocations to reach this point
            self.barrier.wait().await;
            Ok("done".to_string())
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(BarrierTool {
            barrier: barrier.clone(),
        });
        let handle = server.run();

        // Make concurrent calls
        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        // If execution is sequential, the first call will block at the barrier forever.
        // We use a 1-second timeout to fail fast instead of hanging the test runner.
        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        // All calls should succeed
        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// A tool that can be controlled to test concurrent writes to the ToolServer.
    #[derive(Clone)]
    struct ControlledTool {
        started: Arc<tokio::sync::Notify>,
        allow_finish: Arc<tokio::sync::Notify>,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Controlled error")]
    struct ControlledError;

    impl Tool for ControlledTool {
        const NAME: &'static str = "controlled";
        type Error = ControlledError;
        type Args = serde_json::Value;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "controlled".to_string(),
                description: "Test tool".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            // 1. Signal that we are inside the call (lock should be dropped by now)
            self.started.notify_one();
            // 2. Wait indefinitely until the test allows us to finish
            self.allow_finish.notified().await;
            Ok(42)
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        // Build server with the controlled tool, which blocks mid-execution until
        // `allow_finish` is notified
        let tool = ControlledTool {
            started: started.clone(),
            allow_finish: allow_finish.clone(),
        };

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        // Start tool call in background
        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until we are strictly inside `call()`
        started.notified().await;

        // Try to write to the state (add a tool) while the tool call is mid-execution.
        // If the read lock is incorrectly held across tool execution, this will deadlock.
        let add_result = tokio::time::timeout(Duration::from_secs(1), handle.add_tool(Adder)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        // Allow the background tool to finish and clean up
        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// A mock vector store index that waits at a barrier to enforce parallel execution
    struct BarrierMockIndex {
        barrier: Arc<tokio::sync::Barrier>,
        tool_id: String,
    }

    impl VectorStoreIndex for BarrierMockIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            // Wait for all indices to reach this point simultaneously
            self.barrier.wait().await;
            Ok(vec![(1.0, self.tool_id.clone())])
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_parallel_dynamic_tool_fetching() {
        // We expect exactly 2 parallel searches to hit the barrier at the same time
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockIndex {
            barrier: barrier.clone(),
            tool_id: "add".to_string(),
        };

        let index2 = BarrierMockIndex {
            barrier: barrier.clone(),
            tool_id: "subtract".to_string(),
        };

        // Put both tools in the toolset so they resolve correctly
        let mut toolset = ToolSet::default();
        toolset.add_tool(Adder);
        toolset.add_tool(Subtractor);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        // This will trigger a search across both indices.
        // If fetched sequentially, the first index will wait at the barrier forever.
        let get_defs = tokio::time::timeout(
            std::time::Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}