Skip to main content

rig_core/tool/
server.rs

1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6    completion::{CompletionError, ToolDefinition},
7    tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
11/// Shared state behind a `ToolServerHandle`.
12struct ToolServerState {
13    /// Static tool names that persist until explicitly removed.
14    static_tool_names: Vec<String>,
15    /// Dynamic tools fetched from vector stores on each prompt.
16    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
17    /// The toolset where tools are registered and executed.
18    toolset: ToolSet,
19}
20
21/// Builder for constructing a [`ToolServerHandle`].
22///
23/// Accumulates tools and configuration, then produces a shared handle via
24/// [`run()`](ToolServer::run).
25pub struct ToolServer {
26    static_tool_names: Vec<String>,
27    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
28    toolset: ToolSet,
29}
30
31impl Default for ToolServer {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl ToolServer {
38    pub fn new() -> Self {
39        Self {
40            static_tool_names: Vec::new(),
41            dynamic_tools: Vec::new(),
42            toolset: ToolSet::default(),
43        }
44    }
45
46    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
47        self.static_tool_names = names;
48        self
49    }
50
51    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
52        self.toolset = tools;
53        self
54    }
55
56    pub(crate) fn add_dynamic_tools(
57        mut self,
58        dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
59    ) -> Self {
60        self.dynamic_tools = dyn_tools;
61        self
62    }
63
64    /// Add a static tool to the agent
65    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
66        let toolname = tool.name();
67        self.toolset.add_tool(tool);
68        self.static_tool_names.push(toolname);
69        self
70    }
71
72    /// Add an MCP tool (from `rmcp`) to the agent
73    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
74    #[cfg(feature = "rmcp")]
75    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
76        use crate::tool::rmcp::McpTool;
77        let toolname = tool.name.clone();
78        self.toolset
79            .add_tool(McpTool::from_mcp_server(tool, client));
80        self.static_tool_names.push(toolname.to_string());
81        self
82    }
83
84    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
85    /// dynamic toolset will be inserted in the request.
86    pub fn dynamic_tools(
87        mut self,
88        sample: usize,
89        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
90        toolset: ToolSet,
91    ) -> Self {
92        self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
93        self.toolset.add_tools(toolset);
94        self
95    }
96
97    /// Consume the builder and return a shared [`ToolServerHandle`].
98    pub fn run(self) -> ToolServerHandle {
99        ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
100            static_tool_names: self.static_tool_names,
101            dynamic_tools: self.dynamic_tools,
102            toolset: self.toolset,
103        })))
104    }
105}
106
107/// A cheaply-cloneable handle to the shared tool server state.
108///
109/// All operations acquire locks directly on the underlying state.
110/// Multiple handles (e.g. across agents) can share the same state
111/// without channel-based message routing.
112#[derive(Clone)]
113pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
114
115impl ToolServerHandle {
116    /// Register a new static tool.
117    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
118        let mut state = self.0.write().await;
119        state.static_tool_names.push(tool.name());
120        state.toolset.add_tool_boxed(Box::new(tool));
121        Ok(())
122    }
123
124    /// Merge an entire toolset into the server.
125    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
126        let mut state = self.0.write().await;
127        state.toolset.add_tools(toolset);
128        Ok(())
129    }
130
131    /// Remove a tool by name from both the toolset and the static list.
132    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
133        let mut state = self.0.write().await;
134        state.static_tool_names.retain(|x| *x != tool_name);
135        state.toolset.delete_tool(tool_name);
136        Ok(())
137    }
138
139    /// Look up and execute a tool by name.
140    ///
141    /// The tool handle is cloned under a brief read lock so that
142    /// long-running tool executions never block writers.
143    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
144        let tool = {
145            let state = self.0.read().await;
146            state.toolset.get(tool_name).cloned()
147        };
148
149        match tool {
150            Some(tool) => {
151                tracing::debug!(target: "rig",
152                    "Calling tool {tool_name} with args:\n{}",
153                    serde_json::to_string_pretty(&args).unwrap_or_default()
154                );
155                tool.call(args.to_string())
156                    .await
157                    .map_err(|e| ToolSetError::ToolCallError(e).into())
158            }
159            None => Err(ToolServerError::ToolsetError(
160                ToolSetError::ToolNotFoundError(tool_name.to_string()),
161            )),
162        }
163    }
164
165    /// Retrieve tool definitions, optionally using a prompt to select
166    /// dynamic tools from configured vector stores.
167    pub async fn get_tool_defs(
168        &self,
169        prompt: Option<String>,
170    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
171        // Snapshot the metadata we need under a brief read lock
172        let (static_tool_names, dynamic_tools) = {
173            let state = self.0.read().await;
174            (state.static_tool_names.clone(), state.dynamic_tools.clone())
175        };
176
177        let mut tools = if let Some(ref text) = prompt {
178            // Create a future for each dynamic tool index
179            let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
180                let text = text.clone();
181                let num_sample = *num_sample;
182                let index = index.clone();
183
184                async move {
185                    let req = VectorSearchRequest::builder()
186                        .query(text)
187                        .samples(num_sample as u64)
188                        .build();
189
190                    let ids = index
191                        .as_ref()
192                        .top_n_ids(req.map_filter(Filter::interpret))
193                        .await?
194                        .into_iter()
195                        .map(|(_, id)| id)
196                        .collect::<Vec<String>>();
197
198                    Ok::<_, VectorStoreError>(ids)
199                }
200            });
201
202            // Execute searches concurrently and collect/flatten the IDs
203            let dynamic_tool_ids: Vec<String> = futures::future::try_join_all(search_futures)
204                .await
205                .map_err(|e| {
206                    ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
207                })?
208                .into_iter()
209                .flatten()
210                .collect();
211
212            let dynamic_tool_handles: Vec<_> = {
213                let state = self.0.read().await;
214                dynamic_tool_ids
215                    .iter()
216                    .filter_map(|doc| {
217                        let handle = state.toolset.get(doc).cloned();
218                        if handle.is_none() {
219                            tracing::warn!("Tool implementation not found in toolset: {}", doc);
220                        }
221                        handle
222                    })
223                    .collect()
224            };
225
226            let mut tools = Vec::new();
227            for tool in dynamic_tool_handles {
228                tools.push(tool.definition(text.clone()).await);
229            }
230            tools
231        } else {
232            Vec::new()
233        };
234
235        let static_tool_handles: Vec<_> = {
236            let state = self.0.read().await;
237            static_tool_names
238                .iter()
239                .filter_map(|toolname| {
240                    let handle = state.toolset.get(toolname).cloned();
241                    if handle.is_none() {
242                        tracing::warn!("Tool implementation not found in toolset: {}", toolname);
243                    }
244                    handle
245                })
246                .collect()
247        };
248
249        for tool in static_tool_handles {
250            tools.push(tool.definition(String::new()).await);
251        }
252
253        Ok(tools)
254    }
255}
256
257#[derive(Debug, thiserror::Error)]
258pub enum ToolServerError {
259    #[error("Toolset error: {0}")]
260    ToolsetError(#[from] ToolSetError),
261    #[error("Failed to retrieve tool definitions: {0}")]
262    DefinitionError(CompletionError),
263}
264
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::{
        test_utils::{
            BarrierMockToolIndex, MockAddTool, MockBarrierTool, MockControlledTool,
            MockSubtractTool, MockToolIndex,
        },
        tool::{ToolSet, server::ToolServer},
    };

    /// A static tool can be added, listed, called and removed through a handle.
    #[tokio::test]
    async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(MockAddTool).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    /// Static tools are always listed; dynamic tools only when a prompt is given.
    #[tokio::test]
    async fn test_toolserver_dynamic_tools() {
        // Create a mock index that will return "subtract" as the dynamic tool
        let mock_index = MockToolIndex::new(["subtract"]);

        // Build server with static tool "add" and dynamic tools from the mock index
        let server = ToolServer::new().tool(MockAddTool).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![MockSubtractTool]),
        );

        let handle = server.run();

        // Test with None prompt - should only return static tools
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // Test with Some prompt - should return both static and dynamic tools
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        // Check that both tools are present (order may vary)
        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    /// Ids returned by an index that have no implementation are skipped.
    #[tokio::test]
    async fn test_toolserver_dynamic_tools_missing_implementation() {
        // Create a mock index that returns a tool ID that doesn't exist in the toolset
        let mock_index = MockToolIndex::new(["nonexistent_tool"]);

        // Build server with only static tool, but dynamic index references missing tool
        let server =
            ToolServer::new()
                .tool(MockAddTool)
                .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        // Test with Some prompt - should only return static tool since dynamic tool is missing
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// Several calls to the same tool must run concurrently, not one after another.
    #[tokio::test]
    async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(MockBarrierTool::new(barrier.clone()));
        let handle = server.run();

        // Make concurrent calls
        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        // If execution is sequential, the first call will block at the barrier forever.
        // We use a 1-second timeout to fail fast instead of hanging the test runner.
        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        // All calls should succeed
        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// Writers must not be blocked while a tool call is in progress.
    #[tokio::test]
    async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        // Build server with the controlled tool that waits at a barrier during execution
        let tool = MockControlledTool::new(started.clone(), allow_finish.clone());

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        // Start tool call in background
        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until we are strictly inside `call()`
        started.notified().await;

        // Try to write to the state (add a tool) while the tool call is mid-execution.
        // If the read lock is incorrectly held across tool execution, this will deadlock.
        let add_result =
            tokio::time::timeout(Duration::from_secs(1), handle.add_tool(MockAddTool)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        // Allow the background tool to finish and clean up
        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// Searches across several dynamic indices must run concurrently.
    #[tokio::test]
    async fn test_toolserver_parallel_dynamic_tool_fetching() {
        // We expect exactly 2 parallel searches to hit the barrier at the same time
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockToolIndex::new(barrier.clone(), "add");
        let index2 = BarrierMockToolIndex::new(barrier.clone(), "subtract");

        // Put both tools in the toolset so they resolve correctly
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        toolset.add_tool(MockSubtractTool);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        // This will trigger a search across both indices.
        // If fetched sequentially, the first index will wait at the barrier forever.
        let get_defs = tokio::time::timeout(
            Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}