1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6 completion::{CompletionError, ToolDefinition},
7 tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
11struct ToolServerState {
13 static_tool_names: Vec<String>,
15 dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
17 toolset: ToolSet,
19}
20
21pub struct ToolServer {
26 static_tool_names: Vec<String>,
27 dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
28 toolset: ToolSet,
29}
30
31impl Default for ToolServer {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl ToolServer {
38 pub fn new() -> Self {
39 Self {
40 static_tool_names: Vec::new(),
41 dynamic_tools: Vec::new(),
42 toolset: ToolSet::default(),
43 }
44 }
45
46 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
47 self.static_tool_names = names;
48 self
49 }
50
51 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
52 self.toolset = tools;
53 self
54 }
55
56 pub(crate) fn add_dynamic_tools(
57 mut self,
58 dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
59 ) -> Self {
60 self.dynamic_tools = dyn_tools;
61 self
62 }
63
64 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
66 let toolname = tool.name();
67 self.toolset.add_tool(tool);
68 self.static_tool_names.push(toolname);
69 self
70 }
71
72 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
74 #[cfg(feature = "rmcp")]
75 pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
76 use crate::tool::rmcp::McpTool;
77 let toolname = tool.name.clone();
78 self.toolset
79 .add_tool(McpTool::from_mcp_server(tool, client));
80 self.static_tool_names.push(toolname.to_string());
81 self
82 }
83
84 pub fn dynamic_tools(
87 mut self,
88 sample: usize,
89 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
90 toolset: ToolSet,
91 ) -> Self {
92 self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
93 self.toolset.add_tools(toolset);
94 self
95 }
96
97 pub fn run(self) -> ToolServerHandle {
99 ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
100 static_tool_names: self.static_tool_names,
101 dynamic_tools: self.dynamic_tools,
102 toolset: self.toolset,
103 })))
104 }
105}
106
107#[derive(Clone)]
113pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
114
115impl ToolServerHandle {
116 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
118 let mut state = self.0.write().await;
119 state.static_tool_names.push(tool.name());
120 state.toolset.add_tool_boxed(Box::new(tool));
121 Ok(())
122 }
123
124 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
128 let mut state = self.0.write().await;
129 state
130 .static_tool_names
131 .extend(toolset.tools.keys().cloned());
132 state.toolset.add_tools(toolset);
133 Ok(())
134 }
135
136 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
138 let mut state = self.0.write().await;
139 state.static_tool_names.retain(|x| *x != tool_name);
140 state.toolset.delete_tool(tool_name);
141 Ok(())
142 }
143
144 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
149 let tool = {
150 let state = self.0.read().await;
151 state.toolset.get(tool_name).cloned()
152 };
153
154 match tool {
155 Some(tool) => {
156 tracing::debug!(target: "rig",
157 "Calling tool {tool_name} with args:\n{}",
158 serde_json::to_string_pretty(&args).unwrap_or_default()
159 );
160 tool.call(args.to_string())
161 .await
162 .map_err(|e| ToolSetError::ToolCallError(e).into())
163 }
164 None => Err(ToolServerError::ToolsetError(
165 ToolSetError::ToolNotFoundError(tool_name.to_string()),
166 )),
167 }
168 }
169
170 pub async fn get_tool_defs(
173 &self,
174 prompt: Option<String>,
175 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
176 let (static_tool_names, dynamic_tools) = {
178 let state = self.0.read().await;
179 (state.static_tool_names.clone(), state.dynamic_tools.clone())
180 };
181
182 let mut tools = if let Some(ref text) = prompt {
183 let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
185 let text = text.clone();
186 let num_sample = *num_sample;
187 let index = index.clone();
188
189 async move {
190 let req = VectorSearchRequest::builder()
191 .query(text)
192 .samples(num_sample as u64)
193 .build();
194
195 let ids = index
196 .as_ref()
197 .top_n_ids(req.map_filter(Filter::interpret))
198 .await?
199 .into_iter()
200 .map(|(_, id)| id)
201 .collect::<Vec<String>>();
202
203 Ok::<_, VectorStoreError>(ids)
204 }
205 });
206
207 let dynamic_tool_ids: Vec<String> = futures::future::try_join_all(search_futures)
209 .await
210 .map_err(|e| {
211 ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
212 })?
213 .into_iter()
214 .flatten()
215 .collect();
216
217 let dynamic_tool_handles: Vec<_> = {
218 let state = self.0.read().await;
219 dynamic_tool_ids
220 .iter()
221 .filter_map(|doc| {
222 let handle = state.toolset.get(doc).cloned();
223 if handle.is_none() {
224 tracing::warn!("Tool implementation not found in toolset: {}", doc);
225 }
226 handle
227 })
228 .collect()
229 };
230
231 let mut tools = Vec::new();
232 for tool in dynamic_tool_handles {
233 tools.push(tool.definition(text.clone()).await);
234 }
235 tools
236 } else {
237 Vec::new()
238 };
239
240 let static_tool_handles: Vec<_> = {
241 let state = self.0.read().await;
242 static_tool_names
243 .iter()
244 .filter_map(|toolname| {
245 let handle = state.toolset.get(toolname).cloned();
246 if handle.is_none() {
247 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
248 }
249 handle
250 })
251 .collect()
252 };
253
254 for tool in static_tool_handles {
255 tools.push(tool.definition(String::new()).await);
256 }
257
258 Ok(tools)
259 }
260}
261
262#[derive(Debug, thiserror::Error)]
263pub enum ToolServerError {
264 #[error("Toolset error: {0}")]
265 ToolsetError(#[from] ToolSetError),
266 #[error("Failed to retrieve tool definitions: {0}")]
267 DefinitionError(CompletionError),
268}
269
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::{
        test_utils::{
            BarrierMockToolIndex, MockAddTool, MockBarrierTool, MockControlledTool,
            MockSubtractTool, MockToolIndex,
        },
        tool::{ToolSet, server::ToolServer},
    };

    /// Adding, calling and removing a single static tool through the handle.
    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(MockAddTool).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    /// `append_toolset` must expose the same tools as adding them one by one.
    #[tokio::test]
    pub async fn test_toolserver_append_toolset_matches_add_tool() {
        let mut via_add_tool = {
            let handle = ToolServer::new().run();
            handle.add_tool(MockAddTool).await.unwrap();
            handle.add_tool(MockSubtractTool).await.unwrap();
            handle.get_tool_defs(None).await.unwrap()
        };
        // Sort both sides by name so the comparison does not depend on listing order.
        via_add_tool.sort_by(|a, b| a.name.cmp(&b.name));

        let mut via_append_toolset = {
            let handle = ToolServer::new().run();
            let mut toolset = ToolSet::default();
            toolset.add_tool(MockAddTool);
            toolset.add_tool(MockSubtractTool);
            handle.append_toolset(toolset).await.unwrap();
            handle.get_tool_defs(None).await.unwrap()
        };
        via_append_toolset.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(via_add_tool.len(), via_append_toolset.len());
        assert!(
            via_add_tool
                .iter()
                .zip(via_append_toolset.iter())
                .all(|(a, b)| a.name == b.name),
            "append_toolset must surface the same LLM-visible tools as add_tool",
        );
    }

    /// Dynamic tools appear only when a prompt is supplied; static tools always appear.
    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        let mock_index = MockToolIndex::new(["subtract"]);

        let server = ToolServer::new().tool(MockAddTool).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![MockSubtractTool]),
        );

        let handle = server.run();

        // Without a prompt only the static tool is returned.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // With a prompt the dynamic tool from the index is returned as well.
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    /// An index hit without a matching implementation is skipped instead of failing.
    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        let mock_index = MockToolIndex::new(["nonexistent_tool"]);

        let server =
            ToolServer::new()
                .tool(MockAddTool)
                .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// Several calls to the same tool must be able to run at the same time.
    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(MockBarrierTool::new(barrier.clone()));
        let handle = server.run();

        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        // The barrier is sized for `num_calls` parties, so the calls can only finish
        // within the timeout when they run concurrently.
        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// Registering a tool must not block while another tool call is still in flight.
    #[tokio::test]
    pub async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        let tool = MockControlledTool::new(started.clone(), allow_finish.clone());

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until the spawned call signals that the tool has started.
        started.notified().await;

        let add_result =
            tokio::time::timeout(Duration::from_secs(1), handle.add_tool(MockAddTool)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        // Let the in-flight call finish and check its result.
        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// Searches against several dynamic indexes must be issued concurrently.
    #[tokio::test]
    pub async fn test_toolserver_parallel_dynamic_tool_fetching() {
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockToolIndex::new(barrier.clone(), "add");
        let index2 = BarrierMockToolIndex::new(barrier.clone(), "subtract");

        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        toolset.add_tool(MockSubtractTool);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        let get_defs = tokio::time::timeout(
            Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}