Skip to main content

rig/tool/
server.rs

1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6    completion::{CompletionError, ToolDefinition},
7    tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
11/// Shared state behind a `ToolServerHandle`.
12struct ToolServerState {
13    /// Static tool names that persist until explicitly removed.
14    static_tool_names: Vec<String>,
15    /// Dynamic tools fetched from vector stores on each prompt.
16    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
17    /// The toolset where tools are registered and executed.
18    toolset: ToolSet,
19}
20
21/// Builder for constructing a [`ToolServerHandle`].
22///
23/// Accumulates tools and configuration, then produces a shared handle via
24/// [`run()`](ToolServer::run).
25pub struct ToolServer {
26    static_tool_names: Vec<String>,
27    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
28    toolset: ToolSet,
29}
30
31impl Default for ToolServer {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl ToolServer {
38    pub fn new() -> Self {
39        Self {
40            static_tool_names: Vec::new(),
41            dynamic_tools: Vec::new(),
42            toolset: ToolSet::default(),
43        }
44    }
45
46    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
47        self.static_tool_names = names;
48        self
49    }
50
51    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
52        self.toolset = tools;
53        self
54    }
55
56    pub(crate) fn add_dynamic_tools(
57        mut self,
58        dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
59    ) -> Self {
60        self.dynamic_tools = dyn_tools;
61        self
62    }
63
64    /// Add a static tool to the agent
65    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
66        let toolname = tool.name();
67        self.toolset.add_tool(tool);
68        self.static_tool_names.push(toolname);
69        self
70    }
71
72    /// Add an MCP tool (from `rmcp`) to the agent
73    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
74    #[cfg(feature = "rmcp")]
75    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
76        use crate::tool::rmcp::McpTool;
77        let toolname = tool.name.clone();
78        self.toolset
79            .add_tool(McpTool::from_mcp_server(tool, client));
80        self.static_tool_names.push(toolname.to_string());
81        self
82    }
83
84    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
85    /// dynamic toolset will be inserted in the request.
86    pub fn dynamic_tools(
87        mut self,
88        sample: usize,
89        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
90        toolset: ToolSet,
91    ) -> Self {
92        self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
93        self.toolset.add_tools(toolset);
94        self
95    }
96
97    /// Consume the builder and return a shared [`ToolServerHandle`].
98    pub fn run(self) -> ToolServerHandle {
99        ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
100            static_tool_names: self.static_tool_names,
101            dynamic_tools: self.dynamic_tools,
102            toolset: self.toolset,
103        })))
104    }
105}
106
107/// A cheaply-cloneable handle to the shared tool server state.
108///
109/// All operations acquire locks directly on the underlying state.
110/// Multiple handles (e.g. across agents) can share the same state
111/// without channel-based message routing.
112#[derive(Clone)]
113pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
114
115impl ToolServerHandle {
116    /// Register a new static tool.
117    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
118        let mut state = self.0.write().await;
119        state.static_tool_names.push(tool.name());
120        state.toolset.add_tool_boxed(Box::new(tool));
121        Ok(())
122    }
123
124    /// Merge an entire toolset into the server.
125    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
126        let mut state = self.0.write().await;
127        state.toolset.add_tools(toolset);
128        Ok(())
129    }
130
131    /// Remove a tool by name from both the toolset and the static list.
132    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
133        let mut state = self.0.write().await;
134        state.static_tool_names.retain(|x| *x != tool_name);
135        state.toolset.delete_tool(tool_name);
136        Ok(())
137    }
138
139    /// Look up and execute a tool by name.
140    ///
141    /// The tool handle is cloned under a brief read lock so that
142    /// long-running tool executions never block writers.
143    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
144        let tool = {
145            let state = self.0.read().await;
146            state.toolset.get(tool_name).cloned()
147        };
148
149        match tool {
150            Some(tool) => {
151                tracing::debug!(target: "rig",
152                    "Calling tool {tool_name} with args:\n{}",
153                    serde_json::to_string_pretty(&args).unwrap_or_default()
154                );
155                tool.call(args.to_string())
156                    .await
157                    .map_err(|e| ToolSetError::ToolCallError(e).into())
158            }
159            None => Err(ToolServerError::ToolsetError(
160                ToolSetError::ToolNotFoundError(tool_name.to_string()),
161            )),
162        }
163    }
164
165    /// Retrieve tool definitions, optionally using a prompt to select
166    /// dynamic tools from configured vector stores.
167    pub async fn get_tool_defs(
168        &self,
169        prompt: Option<String>,
170    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
171        // Snapshot the metadata we need under a brief read lock
172        let (static_tool_names, dynamic_tools) = {
173            let state = self.0.read().await;
174            (state.static_tool_names.clone(), state.dynamic_tools.clone())
175        };
176
177        let mut tools = if let Some(ref text) = prompt {
178            // Create a future for each dynamic tool index
179            let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
180                let text = text.clone();
181                let num_sample = *num_sample;
182                let index = index.clone();
183
184                async move {
185                    let req = VectorSearchRequest::builder()
186                        .query(text)
187                        .samples(num_sample as u64)
188                        .build();
189
190                    let ids = index
191                        .as_ref()
192                        .top_n_ids(req.map_filter(Filter::interpret))
193                        .await?
194                        .into_iter()
195                        .map(|(_, id)| id)
196                        .collect::<Vec<String>>();
197
198                    Ok::<_, VectorStoreError>(ids)
199                }
200            });
201
202            // Execute searches concurrently and collect/flatten the IDs
203            let dynamic_tool_ids: Vec<String> = futures::future::try_join_all(search_futures)
204                .await
205                .map_err(|e| {
206                    ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
207                })?
208                .into_iter()
209                .flatten()
210                .collect();
211
212            let dynamic_tool_handles: Vec<_> = {
213                let state = self.0.read().await;
214                dynamic_tool_ids
215                    .iter()
216                    .filter_map(|doc| {
217                        let handle = state.toolset.get(doc).cloned();
218                        if handle.is_none() {
219                            tracing::warn!("Tool implementation not found in toolset: {}", doc);
220                        }
221                        handle
222                    })
223                    .collect()
224            };
225
226            let mut tools = Vec::new();
227            for tool in dynamic_tool_handles {
228                tools.push(tool.definition(text.clone()).await);
229            }
230            tools
231        } else {
232            Vec::new()
233        };
234
235        let static_tool_handles: Vec<_> = {
236            let state = self.0.read().await;
237            static_tool_names
238                .iter()
239                .filter_map(|toolname| {
240                    let handle = state.toolset.get(toolname).cloned();
241                    if handle.is_none() {
242                        tracing::warn!("Tool implementation not found in toolset: {}", toolname);
243                    }
244                    handle
245                })
246                .collect()
247        };
248
249        for tool in static_tool_handles {
250            tools.push(tool.definition(String::new()).await);
251        }
252
253        Ok(tools)
254    }
255}
256
257#[derive(Debug, thiserror::Error)]
258pub enum ToolServerError {
259    #[error("Toolset error: {0}")]
260    ToolsetError(#[from] ToolSetError),
261    #[error("Failed to retrieve tool definitions: {0}")]
262    DefinitionError(CompletionError),
263}
264
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolSet, server::ToolServer},
        vector_store::{
            VectorStoreError, VectorStoreIndex,
            request::{Filter, VectorSearchRequest},
        },
        wasm_compat::WasmCompatSend,
    };

    /// Arguments shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool that returns `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool that returns `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtractor;
    impl Tool for Subtractor {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "subtract".to_string(),
                description: "Subtract y from x".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x - args.y;
            Ok(result)
        }
    }

    /// A mock vector store index that returns a predefined list of tool IDs.
    struct MockToolIndex {
        tool_ids: Vec<String>,
    }

    impl VectorStoreIndex for MockToolIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // Not used by `get_tool_defs`, but required by the trait.
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(self
                .tool_ids
                .iter()
                .enumerate()
                .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
                .collect())
        }
    }

    #[tokio::test]
    async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(Adder).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    #[tokio::test]
    async fn test_toolserver_dynamic_tools() {
        // The mock index will return "subtract" as the dynamic tool.
        let mock_index = MockToolIndex {
            tool_ids: vec!["subtract".to_string()],
        };

        // Build server with static tool "add" and dynamic tools from the mock index.
        let server = ToolServer::new().tool(Adder).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![Subtractor]),
        );

        let handle = server.run();

        // Test with None prompt - should only return static tools.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // Test with Some prompt - should return both static and dynamic tools.
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        // Check that both tools are present (order may vary).
        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    #[tokio::test]
    async fn test_toolserver_dynamic_tools_missing_implementation() {
        // Create a mock index that returns a tool ID that doesn't exist in the toolset.
        let mock_index = MockToolIndex {
            tool_ids: vec!["nonexistent_tool".to_string()],
        };

        // Build server with only a static tool, but the dynamic index references a missing tool.
        let server = ToolServer::new()
            .tool(Adder)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        // Test with Some prompt - should only return the static tool since the dynamic tool is missing.
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// A tool that waits at a barrier to test concurrency of tool execution.
    #[derive(Clone)]
    struct BarrierTool {
        barrier: Arc<tokio::sync::Barrier>,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Barrier error")]
    struct BarrierError;

    impl Tool for BarrierTool {
        const NAME: &'static str = "barrier_tool";
        type Error = BarrierError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "barrier_tool".to_string(),
                description: "Waits at a barrier to test concurrency".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            // Wait for all concurrent invocations to reach this point.
            self.barrier.wait().await;
            Ok("done".to_string())
        }
    }

    #[tokio::test]
    async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(BarrierTool {
            barrier: barrier.clone(),
        });
        let handle = server.run();

        // Make concurrent calls.
        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        // If execution is sequential, the first call will block at the barrier forever.
        // We use a 1-second timeout to fail fast instead of hanging the test runner.
        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        // All calls should succeed.
        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// A tool that can be controlled to test concurrent writes to the ToolServer.
    #[derive(Clone)]
    struct ControlledTool {
        started: Arc<tokio::sync::Notify>,
        allow_finish: Arc<tokio::sync::Notify>,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Controlled error")]
    struct ControlledError;

    impl Tool for ControlledTool {
        const NAME: &'static str = "controlled";
        type Error = ControlledError;
        type Args = serde_json::Value;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "controlled".to_string(),
                description: "Test tool".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            // 1. Signal that we are inside the call (the lock should be dropped by now).
            self.started.notify_one();
            // 2. Wait until the test allows us to finish.
            self.allow_finish.notified().await;
            Ok(42)
        }
    }

    #[tokio::test]
    async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        // Build a server with a tool that blocks mid-execution until `allow_finish` is notified.
        let tool = ControlledTool {
            started: started.clone(),
            allow_finish: allow_finish.clone(),
        };

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        // Start the tool call in the background.
        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until we are strictly inside `call()`.
        started.notified().await;

        // Try to write to the state (add a tool) while the tool call is mid-execution.
        // If the read lock is incorrectly held across tool execution, this will deadlock.
        let add_result = tokio::time::timeout(Duration::from_secs(1), handle.add_tool(Adder)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        // Allow the background tool to finish and clean up.
        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// A mock vector store index that waits at a barrier to enforce parallel execution.
    struct BarrierMockIndex {
        barrier: Arc<tokio::sync::Barrier>,
        tool_id: String,
    }

    impl VectorStoreIndex for BarrierMockIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            // Wait for all indices to reach this point simultaneously.
            self.barrier.wait().await;
            Ok(vec![(1.0, self.tool_id.clone())])
        }
    }

    #[tokio::test]
    async fn test_toolserver_parallel_dynamic_tool_fetching() {
        // We expect exactly 2 parallel searches to hit the barrier at the same time.
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockIndex {
            barrier: barrier.clone(),
            tool_id: "add".to_string(),
        };

        let index2 = BarrierMockIndex {
            barrier: barrier.clone(),
            tool_id: "subtract".to_string(),
        };

        // Put both tools in the toolset so they resolve correctly.
        let mut toolset = ToolSet::default();
        toolset.add_tool(Adder);
        toolset.add_tool(Subtractor);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        // This will trigger a search across both indices.
        // If fetched sequentially, the first index will wait at the barrier forever.
        let get_defs = tokio::time::timeout(
            Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}