Skip to main content

rig_core/tool/
server.rs

1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6    completion::{CompletionError, ToolDefinition},
7    tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
/// Record `name` in the advertised static-tool list, skipping it when the
/// list already holds an identical entry.
///
/// The toolset itself is last-wins: re-registering a name swaps in the new
/// implementation. The advertised list therefore only has to remember where
/// each name first appeared. It must never contain the same name twice,
/// because providers refuse duplicate function declarations.
fn push_unique_name(names: &mut Vec<String>, name: String) {
    let already_listed = names.iter().any(|existing| *existing == name);
    if already_listed {
        return;
    }
    names.push(name);
}
21
22/// Shared state behind a `ToolServerHandle`.
23struct ToolServerState {
24    /// Static tool names that persist until explicitly removed.
25    static_tool_names: Vec<String>,
26    /// Dynamic tools fetched from vector stores on each prompt.
27    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
28    /// The toolset where tools are registered and executed.
29    toolset: ToolSet,
30}
31
32/// Builder for constructing a [`ToolServerHandle`].
33///
34/// Accumulates tools and configuration, then produces a shared handle via
35/// [`run()`](ToolServer::run).
36pub struct ToolServer {
37    static_tool_names: Vec<String>,
38    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
39    toolset: ToolSet,
40}
41
42impl Default for ToolServer {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl ToolServer {
49    pub fn new() -> Self {
50        Self {
51            static_tool_names: Vec::new(),
52            dynamic_tools: Vec::new(),
53            toolset: ToolSet::default(),
54        }
55    }
56
57    pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
58        // Last-wins registration replaces the implementation but keeps the
59        // original position, so the advertised list dedupes to first
60        // occurrence (duplicate declarations are rejected by providers).
61        self.static_tool_names = Vec::with_capacity(names.len());
62        for name in names {
63            push_unique_name(&mut self.static_tool_names, name);
64        }
65        self
66    }
67
68    pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
69        self.toolset = tools;
70        self
71    }
72
73    pub(crate) fn add_dynamic_tools(
74        mut self,
75        dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
76    ) -> Self {
77        self.dynamic_tools = dyn_tools;
78        self
79    }
80
81    /// Add a static tool to the agent. Re-registering an existing name
82    /// replaces the implementation (last wins) and keeps its position.
83    pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
84        let toolname = tool.name();
85        self.toolset.add_tool(tool);
86        push_unique_name(&mut self.static_tool_names, toolname);
87        self
88    }
89
90    /// Add an MCP tool (from `rmcp`) to the agent, bounded by
91    /// [`DEFAULT_MCP_TOOL_TIMEOUT`](crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
92    /// (see issue #1914). Use [`rmcp_tool_with_timeout`](Self::rmcp_tool_with_timeout)
93    /// to change or disable it.
94    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
95    #[cfg(feature = "rmcp")]
96    pub fn rmcp_tool(self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
97        self.rmcp_tool_with_timeout(tool, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
98    }
99
100    /// Add an MCP tool (from `rmcp`) with a per-call timeout (see issue #1914).
101    ///
102    /// Pass a [`Duration`](std::time::Duration) to bound the call, or `None` to
103    /// disable the timeout (unbounded).
104    #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
105    #[cfg(feature = "rmcp")]
106    pub fn rmcp_tool_with_timeout(
107        mut self,
108        tool: rmcp::model::Tool,
109        client: rmcp::service::ServerSink,
110        timeout: impl Into<Option<std::time::Duration>>,
111    ) -> Self {
112        use crate::tool::rmcp::McpTool;
113        let toolname = tool.name.to_string();
114        self.toolset
115            .add_tool(McpTool::from_mcp_server(tool, client).with_timeout(timeout));
116        push_unique_name(&mut self.static_tool_names, toolname);
117        self
118    }
119
120    /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the
121    /// dynamic toolset will be inserted in the request.
122    pub fn dynamic_tools(
123        mut self,
124        sample: usize,
125        dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
126        toolset: ToolSet,
127    ) -> Self {
128        self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
129        self.toolset.add_tools(toolset);
130        self
131    }
132
133    /// Consume the builder and return a shared [`ToolServerHandle`].
134    pub fn run(self) -> ToolServerHandle {
135        ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
136            static_tool_names: self.static_tool_names,
137            dynamic_tools: self.dynamic_tools,
138            toolset: self.toolset,
139        })))
140    }
141}
142
143/// A cheaply-cloneable handle to the shared tool server state.
144///
145/// All operations acquire locks directly on the underlying state.
146/// Multiple handles (e.g. across agents) can share the same state
147/// without channel-based message routing.
148#[derive(Clone)]
149pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
150
151impl ToolServerHandle {
152    /// Register a new static tool. Re-registering an existing name replaces
153    /// the implementation (last wins) and keeps its position.
154    pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
155        let mut state = self.0.write().await;
156        let toolname = tool.name();
157        push_unique_name(&mut state.static_tool_names, toolname);
158        state.toolset.add_tool_boxed(Box::new(tool));
159        Ok(())
160    }
161
162    /// Merge an entire toolset into the server. Tool names from `toolset`
163    /// are appended to the static-tool list in `toolset`'s registration
164    /// order, so the tools become visible to the LLM via
165    /// [`Self::get_tool_defs`]. Existing names are replaced (last wins) and
166    /// keep their position.
167    pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
168        let mut state = self.0.write().await;
169        for name in toolset.ordered_names() {
170            push_unique_name(&mut state.static_tool_names, name.clone());
171        }
172        state.toolset.add_tools(toolset);
173        Ok(())
174    }
175
176    /// Remove a tool by name from both the toolset and the static list.
177    pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
178        let mut state = self.0.write().await;
179        state.static_tool_names.retain(|x| *x != tool_name);
180        state.toolset.delete_tool(tool_name);
181        Ok(())
182    }
183
184    /// Look up and execute a tool by name.
185    ///
186    /// The tool handle is cloned under a brief read lock so that
187    /// long-running tool executions never block writers.
188    pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
189        let tool = {
190            let state = self.0.read().await;
191            state.toolset.get(tool_name).cloned()
192        };
193
194        match tool {
195            Some(tool) => {
196                tracing::debug!(target: "rig",
197                    "Calling tool {tool_name} with args:\n{}",
198                    serde_json::to_string_pretty(&args).unwrap_or_default()
199                );
200                tool.call(args.to_string())
201                    .await
202                    .map_err(|e| ToolSetError::ToolCallError(e).into())
203            }
204            None => Err(ToolServerError::ToolsetError(
205                ToolSetError::ToolNotFoundError(tool_name.to_string()),
206            )),
207        }
208    }
209
210    /// Retrieve tool definitions, optionally using a prompt to select
211    /// dynamic tools from configured vector stores.
212    pub async fn get_tool_defs(
213        &self,
214        prompt: Option<String>,
215    ) -> Result<Vec<ToolDefinition>, ToolServerError> {
216        // Snapshot the metadata we need under a brief read lock
217        let (static_tool_names, dynamic_tools) = {
218            let state = self.0.read().await;
219            (state.static_tool_names.clone(), state.dynamic_tools.clone())
220        };
221
222        let mut tools = if let Some(ref text) = prompt {
223            // Create a future for each dynamic tool index
224            let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
225                let text = text.clone();
226                let num_sample = *num_sample;
227                let index = index.clone();
228
229                async move {
230                    let req = VectorSearchRequest::builder()
231                        .query(text)
232                        .samples(num_sample as u64)
233                        .build();
234
235                    let ids = index
236                        .as_ref()
237                        .top_n_ids(req.map_filter(Filter::interpret))
238                        .await?
239                        .into_iter()
240                        .map(|(_, id)| id)
241                        .collect::<Vec<String>>();
242
243                    Ok::<_, VectorStoreError>(ids)
244                }
245            });
246
247            // Execute searches concurrently and collect/flatten the IDs
248            let dynamic_tool_ids: Vec<String> = futures::future::try_join_all(search_futures)
249                .await
250                .map_err(|e| {
251                    ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
252                })?
253                .into_iter()
254                .flatten()
255                .collect();
256
257            let dynamic_tool_handles: Vec<_> = {
258                let state = self.0.read().await;
259                dynamic_tool_ids
260                    .iter()
261                    .filter_map(|doc| {
262                        let handle = state.toolset.get(doc).cloned();
263                        if handle.is_none() {
264                            tracing::warn!("Tool implementation not found in toolset: {}", doc);
265                        }
266                        handle
267                    })
268                    .collect()
269            };
270
271            let mut tools = Vec::new();
272            for tool in dynamic_tool_handles {
273                tools.push(tool.definition(text.clone()).await);
274            }
275            tools
276        } else {
277            Vec::new()
278        };
279
280        let static_tool_handles: Vec<_> = {
281            let state = self.0.read().await;
282            static_tool_names
283                .iter()
284                .filter_map(|toolname| {
285                    let handle = state.toolset.get(toolname).cloned();
286                    if handle.is_none() {
287                        tracing::warn!("Tool implementation not found in toolset: {}", toolname);
288                    }
289                    handle
290                })
291                .collect()
292        };
293
294        for tool in static_tool_handles {
295            tools.push(tool.definition(String::new()).await);
296        }
297
298        // One shared toolset backs both lists, so a name appearing in the
299        // dynamic AND static lists (or retrieved by two indexes) refers to
300        // the same tool. Keep the first definition and drop exact-name
301        // repeats: providers reject duplicate function declarations.
302        let mut seen = std::collections::HashSet::new();
303        tools.retain(|def| {
304            let fresh = seen.insert(def.name.clone());
305            if !fresh {
306                tracing::debug!(
307                    tool_name = %def.name,
308                    "dropping duplicate tool definition from the request"
309                );
310            }
311            fresh
312        });
313
314        Ok(tools)
315    }
316}
317
318#[derive(Debug, thiserror::Error)]
319pub enum ToolServerError {
320    #[error("Toolset error: {0}")]
321    ToolsetError(#[from] ToolSetError),
322    #[error("Failed to retrieve tool definitions: {0}")]
323    DefinitionError(CompletionError),
324}
325
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::{
        test_utils::{
            BarrierMockToolIndex, MockAddTool, MockBarrierTool, MockControlledTool,
            MockSubtractTool, MockToolIndex,
        },
        tool::{
            ToolSet, ToolSetError,
            server::{ToolServer, ToolServerError},
        },
    };

    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(MockAddTool).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    #[tokio::test]
    pub async fn call_tool_reports_unknown_tool() {
        let handle = ToolServer::new().tool(MockAddTool).run();

        let err = handle
            .call_tool("does_not_exist", "{}")
            .await
            .expect_err("calling an unregistered tool must fail");
        assert!(
            matches!(
                &err,
                ToolServerError::ToolsetError(ToolSetError::ToolNotFoundError(name))
                    if name == "does_not_exist"
            ),
            "unexpected error: {err:?}"
        );
    }

    #[tokio::test]
    pub async fn test_toolserver_append_toolset_matches_add_tool() {
        let mut via_add_tool = {
            let handle = ToolServer::new().run();
            handle.add_tool(MockAddTool).await.unwrap();
            handle.add_tool(MockSubtractTool).await.unwrap();
            handle.get_tool_defs(None).await.unwrap()
        };
        via_add_tool.sort_by(|a, b| a.name.cmp(&b.name));

        let mut via_append_toolset = {
            let handle = ToolServer::new().run();
            let mut toolset = ToolSet::default();
            toolset.add_tool(MockAddTool);
            toolset.add_tool(MockSubtractTool);
            handle.append_toolset(toolset).await.unwrap();
            handle.get_tool_defs(None).await.unwrap()
        };
        via_append_toolset.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(via_add_tool.len(), via_append_toolset.len());
        assert!(
            via_add_tool
                .iter()
                .zip(via_append_toolset.iter())
                .all(|(a, b)| a.name == b.name),
            "append_toolset must surface the same LLM-visible tools as add_tool",
        );
    }

    #[tokio::test]
    pub async fn get_tool_defs_dedupes_dynamic_and_static_overlap() {
        // One shared toolset backs both lists, so a dynamically retrieved
        // name that is also static must yield a single definition.
        let handle = ToolServer::new()
            .tool(MockAddTool)
            .dynamic_tools(1, MockToolIndex::new(["add"]), ToolSet::default())
            .run();

        let defs = handle
            .get_tool_defs(Some("add two numbers".to_string()))
            .await
            .unwrap();
        assert_eq!(
            defs.len(),
            1,
            "dynamic/static name overlap must not produce duplicate declarations: {:?}",
            defs.iter().map(|def| def.name.as_str()).collect::<Vec<_>>()
        );
        assert_eq!(defs[0].name, "add");
    }

    #[tokio::test]
    pub async fn duplicate_registration_advertises_one_definition() {
        let handle = ToolServer::new().tool(MockAddTool).run();
        handle.add_tool(MockAddTool).await.unwrap();

        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        handle.append_toolset(toolset).await.unwrap();

        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(
            defs.len(),
            1,
            "re-registering a name must not advertise duplicate declarations"
        );
        assert_eq!(defs[0].name, "add");
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        // Create a mock index that will return "subtract" as the dynamic tool
        let mock_index = MockToolIndex::new(["subtract"]);

        // Build server with static tool "add". "subtract" is only registered
        // through the dynamic source, so it is not advertised statically.
        let server = ToolServer::new().tool(MockAddTool).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![MockSubtractTool]),
        );

        let handle = server.run();

        // Test with None prompt - should only return static tools
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // Test with Some prompt - should return both static and dynamic tools
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        // Check that both tools are present (order may vary)
        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        // Create a mock index that returns a tool ID that doesn't exist in the toolset
        let mock_index = MockToolIndex::new(["nonexistent_tool"]);

        // Build server with only static tool, but dynamic index references missing tool
        let server =
            ToolServer::new()
                .tool(MockAddTool)
                .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        // Test with Some prompt - should only return static tool since dynamic tool is missing
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(MockBarrierTool::new(barrier.clone()));
        let handle = server.run();

        // Make concurrent calls
        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        // If execution is sequential, the first call will block at the barrier forever.
        // We use a 1-second timeout to fail fast instead of hanging the test runner.
        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        // All calls should succeed
        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        // Build server with the controlled tool. As used below, it notifies
        // `started` from inside `call()` and then waits for `allow_finish`.
        let tool = MockControlledTool::new(started.clone(), allow_finish.clone());

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        // Start tool call in background
        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until we are strictly inside `call()`
        started.notified().await;

        // Try to write to the state (add a tool) while the tool call is mid-execution.
        // If the read lock is incorrectly held across tool execution, this will deadlock.
        let add_result =
            tokio::time::timeout(Duration::from_secs(1), handle.add_tool(MockAddTool)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        // Allow the background tool to finish and clean up
        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    #[tokio::test]
    pub async fn test_toolserver_parallel_dynamic_tool_fetching() {
        // We expect exactly 2 parallel searches to hit the barrier at the same time
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockToolIndex::new(barrier.clone(), "add");
        let index2 = BarrierMockToolIndex::new(barrier.clone(), "subtract");

        // Put both tools in the toolset so they resolve correctly
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        toolset.add_tool(MockSubtractTool);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        // This will trigger a search across both indices.
        // If fetched sequentially, the first index will wait at the barrier forever.
        let get_defs = tokio::time::timeout(
            Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}