Skip to main content

rig_core/tool/
server.rs

1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6    completion::{CompletionError, ToolDefinition},
7    tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
11/// Shared state behind a `ToolServerHandle`.
12struct ToolServerState {
13    /// Static tool names that persist until explicitly removed.
14    static_tool_names: Vec<String>,
15    /// Dynamic tools fetched from vector stores on each prompt.
16    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
17    /// The toolset where tools are registered and executed.
18    toolset: ToolSet,
19}
20
21/// Builder for constructing a [`ToolServerHandle`].
22///
23/// Accumulates tools and configuration, then produces a shared handle via
24/// [`run()`](ToolServer::run).
25pub struct ToolServer {
26    static_tool_names: Vec<String>,
27    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
28    toolset: ToolSet,
29}
30
31impl Default for ToolServer {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl ToolServer {
38    pub fn new() -> Self {
39        Self {
40            static_tool_names: Vec::new(),
41            dynamic_tools: Vec::new(),
42            toolset: ToolSet::default(),
43        }
44    }
45
46    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
47        self.static_tool_names = names;
48        self
49    }
50
51    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
52        self.toolset = tools;
53        self
54    }
55
56    pub(crate) fn add_dynamic_tools(
57        mut self,
58        dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
59    ) -> Self {
60        self.dynamic_tools = dyn_tools;
61        self
62    }
63
64    /// Add a static tool to the agent
65    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
66        let toolname = tool.name();
67        self.toolset.add_tool(tool);
68        self.static_tool_names.push(toolname);
69        self
70    }
71
72    /// Add an MCP tool (from `rmcp`) to the agent
73    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
74    #[cfg(feature = "rmcp")]
75    pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
76        use crate::tool::rmcp::McpTool;
77        let toolname = tool.name.clone();
78        self.toolset
79            .add_tool(McpTool::from_mcp_server(tool, client));
80        self.static_tool_names.push(toolname.to_string());
81        self
82    }
83
84    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
85    /// dynamic toolset will be inserted in the request.
86    pub fn dynamic_tools(
87        mut self,
88        sample: usize,
89        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
90        toolset: ToolSet,
91    ) -> Self {
92        self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
93        self.toolset.add_tools(toolset);
94        self
95    }
96
97    /// Consume the builder and return a shared [`ToolServerHandle`].
98    pub fn run(self) -> ToolServerHandle {
99        ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
100            static_tool_names: self.static_tool_names,
101            dynamic_tools: self.dynamic_tools,
102            toolset: self.toolset,
103        })))
104    }
105}
106
107/// A cheaply-cloneable handle to the shared tool server state.
108///
109/// All operations acquire locks directly on the underlying state.
110/// Multiple handles (e.g. across agents) can share the same state
111/// without channel-based message routing.
112#[derive(Clone)]
113pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
114
115impl ToolServerHandle {
116    /// Register a new static tool.
117    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
118        let mut state = self.0.write().await;
119        state.static_tool_names.push(tool.name());
120        state.toolset.add_tool_boxed(Box::new(tool));
121        Ok(())
122    }
123
124    /// Merge an entire toolset into the server. Tool names from `toolset`
125    /// are appended to the static-tool list, so the tools become visible
126    /// to the LLM via [`Self::get_tool_defs`].
127    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
128        let mut state = self.0.write().await;
129        state
130            .static_tool_names
131            .extend(toolset.tools.keys().cloned());
132        state.toolset.add_tools(toolset);
133        Ok(())
134    }
135
136    /// Remove a tool by name from both the toolset and the static list.
137    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
138        let mut state = self.0.write().await;
139        state.static_tool_names.retain(|x| *x != tool_name);
140        state.toolset.delete_tool(tool_name);
141        Ok(())
142    }
143
144    /// Look up and execute a tool by name.
145    ///
146    /// The tool handle is cloned under a brief read lock so that
147    /// long-running tool executions never block writers.
148    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
149        let tool = {
150            let state = self.0.read().await;
151            state.toolset.get(tool_name).cloned()
152        };
153
154        match tool {
155            Some(tool) => {
156                tracing::debug!(target: "rig",
157                    "Calling tool {tool_name} with args:\n{}",
158                    serde_json::to_string_pretty(&args).unwrap_or_default()
159                );
160                tool.call(args.to_string())
161                    .await
162                    .map_err(|e| ToolSetError::ToolCallError(e).into())
163            }
164            None => Err(ToolServerError::ToolsetError(
165                ToolSetError::ToolNotFoundError(tool_name.to_string()),
166            )),
167        }
168    }
169
170    /// Retrieve tool definitions, optionally using a prompt to select
171    /// dynamic tools from configured vector stores.
172    pub async fn get_tool_defs(
173        &self,
174        prompt: Option<String>,
175    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
176        // Snapshot the metadata we need under a brief read lock
177        let (static_tool_names, dynamic_tools) = {
178            let state = self.0.read().await;
179            (state.static_tool_names.clone(), state.dynamic_tools.clone())
180        };
181
182        let mut tools = if let Some(ref text) = prompt {
183            // Create a future for each dynamic tool index
184            let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
185                let text = text.clone();
186                let num_sample = *num_sample;
187                let index = index.clone();
188
189                async move {
190                    let req = VectorSearchRequest::builder()
191                        .query(text)
192                        .samples(num_sample as u64)
193                        .build();
194
195                    let ids = index
196                        .as_ref()
197                        .top_n_ids(req.map_filter(Filter::interpret))
198                        .await?
199                        .into_iter()
200                        .map(|(_, id)| id)
201                        .collect::<Vec<String>>();
202
203                    Ok::<_, VectorStoreError>(ids)
204                }
205            });
206
207            // Execute searches concurrently and collect/flatten the IDs
208            let dynamic_tool_ids: Vec<String> = futures::future::try_join_all(search_futures)
209                .await
210                .map_err(|e| {
211                    ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
212                })?
213                .into_iter()
214                .flatten()
215                .collect();
216
217            let dynamic_tool_handles: Vec<_> = {
218                let state = self.0.read().await;
219                dynamic_tool_ids
220                    .iter()
221                    .filter_map(|doc| {
222                        let handle = state.toolset.get(doc).cloned();
223                        if handle.is_none() {
224                            tracing::warn!("Tool implementation not found in toolset: {}", doc);
225                        }
226                        handle
227                    })
228                    .collect()
229            };
230
231            let mut tools = Vec::new();
232            for tool in dynamic_tool_handles {
233                tools.push(tool.definition(text.clone()).await);
234            }
235            tools
236        } else {
237            Vec::new()
238        };
239
240        let static_tool_handles: Vec<_> = {
241            let state = self.0.read().await;
242            static_tool_names
243                .iter()
244                .filter_map(|toolname| {
245                    let handle = state.toolset.get(toolname).cloned();
246                    if handle.is_none() {
247                        tracing::warn!("Tool implementation not found in toolset: {}", toolname);
248                    }
249                    handle
250                })
251                .collect()
252        };
253
254        for tool in static_tool_handles {
255            tools.push(tool.definition(String::new()).await);
256        }
257
258        Ok(tools)
259    }
260}
261
262#[derive(Debug, thiserror::Error)]
263pub enum ToolServerError {
264    #[error("Toolset error: {0}")]
265    ToolsetError(#[from] ToolSetError),
266    #[error("Failed to retrieve tool definitions: {0}")]
267    DefinitionError(CompletionError),
268}
269
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::{
        test_utils::{
            BarrierMockToolIndex, MockAddTool, MockBarrierTool, MockControlledTool,
            MockSubtractTool, MockToolIndex,
        },
        tool::{ToolSet, server::ToolServer},
    };

    /// How long a test waits before treating an operation as deadlocked.
    ///
    /// Every operation guarded by this timeout should finish almost
    /// immediately when the server behaves correctly. The timeout only turns
    /// a would-be hang into a fast failure.
    const DEADLOCK_TIMEOUT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn test_toolserver() {
        let handle = ToolServer::new().run();

        handle.add_tool(MockAddTool).await.unwrap();
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);

        let args = serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let output = handle.call_tool("add", &args).await.unwrap();
        assert_eq!(output, "7");

        handle.remove_tool("add").await.unwrap();
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 0);
    }

    #[tokio::test]
    async fn test_toolserver_append_toolset_matches_add_tool() {
        let mut via_add_tool: Vec<String> = {
            let handle = ToolServer::new().run();
            handle.add_tool(MockAddTool).await.unwrap();
            handle.add_tool(MockSubtractTool).await.unwrap();
            handle
                .get_tool_defs(None)
                .await
                .unwrap()
                .into_iter()
                .map(|def| def.name)
                .collect()
        };
        via_add_tool.sort();

        let mut via_append_toolset: Vec<String> = {
            let handle = ToolServer::new().run();
            let mut toolset = ToolSet::default();
            toolset.add_tool(MockAddTool);
            toolset.add_tool(MockSubtractTool);
            handle.append_toolset(toolset).await.unwrap();
            handle
                .get_tool_defs(None)
                .await
                .unwrap()
                .into_iter()
                .map(|def| def.name)
                .collect()
        };
        via_append_toolset.sort();

        assert_eq!(
            via_add_tool, via_append_toolset,
            "append_toolset must surface the same LLM-visible tools as add_tool",
        );
    }

    #[tokio::test]
    async fn test_toolserver_dynamic_tools() {
        // Mock index configured to return "subtract" as the dynamic tool.
        let mock_index = MockToolIndex::new(["subtract"]);

        // Static tool "add" plus dynamic tools served from the mock index.
        let handle = ToolServer::new()
            .tool(MockAddTool)
            .dynamic_tools(1, mock_index, ToolSet::from_tools(vec![MockSubtractTool]))
            .run();

        // Without a prompt only the static tool is reported.
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "add");

        // With a prompt both static and dynamic tools are reported (order may vary).
        let defs = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|def| def.name.as_str()).collect();
        assert!(names.contains(&"add"));
        assert!(names.contains(&"subtract"));
    }

    #[tokio::test]
    async fn test_toolserver_dynamic_tools_missing_implementation() {
        // The index returns an id that has no implementation in the toolset.
        let mock_index = MockToolIndex::new(["nonexistent_tool"]);

        let handle = ToolServer::new()
            .tool(MockAddTool)
            .dynamic_tools(1, mock_index, ToolSet::default())
            .run();

        // The missing dynamic tool is skipped; only the static tool remains.
        let defs = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "add");
    }

    #[tokio::test]
    async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let handle = ToolServer::new()
            .tool(MockBarrierTool::new(barrier.clone()))
            .run();

        let calls: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        // Every call waits at the barrier, so sequential execution would block the
        // first call forever; the timeout fails fast instead of hanging the runner.
        let outcome = tokio::time::timeout(DEADLOCK_TIMEOUT, futures::future::join_all(calls)).await;

        assert!(
            outcome.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        for res in outcome.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    #[tokio::test]
    async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        // The controlled tool signals `started`, then waits for `allow_finish`.
        let handle = ToolServer::new()
            .tool(MockControlledTool::new(started.clone(), allow_finish.clone()))
            .run();

        let background = handle.clone();
        let call_task =
            tokio::spawn(async move { background.call_tool("controlled", "{}").await });

        // Wait until execution is strictly inside `call()`.
        started.notified().await;

        // A write while the call is in flight would deadlock if the read lock were
        // held across tool execution.
        let add_result = tokio::time::timeout(DEADLOCK_TIMEOUT, handle.add_tool(MockAddTool)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        // Let the background call finish and check its result.
        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    #[tokio::test]
    async fn test_toolserver_parallel_dynamic_tool_fetching() {
        // Both index searches must reach the barrier at the same time.
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index_add = BarrierMockToolIndex::new(barrier.clone(), "add");
        let index_subtract = BarrierMockToolIndex::new(barrier.clone(), "subtract");

        // Both implementations live in the toolset so the ids resolve.
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        toolset.add_tool(MockSubtractTool);

        let handle = ToolServer::new()
            .dynamic_tools(1, index_add, ToolSet::default())
            .dynamic_tools(1, index_subtract, toolset)
            .run();

        // A sequential search would leave the first index waiting at the barrier forever.
        let get_defs = tokio::time::timeout(
            DEADLOCK_TIMEOUT,
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|def| def.name.as_str()).collect();
        assert!(names.contains(&"add"));
        assert!(names.contains(&"subtract"));
    }
}