1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6 completion::{CompletionError, ToolDefinition},
7 tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
/// Appends `name` to `names` only when no equal entry is already present.
///
/// Keeps first-insertion order, so the resulting list is an ordered set of names.
fn push_unique_name(names: &mut Vec<String>, name: String) {
    let already_present = names.iter().any(|existing| *existing == name);
    if already_present {
        return;
    }
    names.push(name);
}
21
22struct ToolServerState {
24 static_tool_names: Vec<String>,
26 dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
28 toolset: ToolSet,
30}
31
32pub struct ToolServer {
37 static_tool_names: Vec<String>,
38 dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
39 toolset: ToolSet,
40}
41
42impl Default for ToolServer {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl ToolServer {
49 pub fn new() -> Self {
50 Self {
51 static_tool_names: Vec::new(),
52 dynamic_tools: Vec::new(),
53 toolset: ToolSet::default(),
54 }
55 }
56
57 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
58 self.static_tool_names = Vec::with_capacity(names.len());
62 for name in names {
63 push_unique_name(&mut self.static_tool_names, name);
64 }
65 self
66 }
67
68 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
69 self.toolset = tools;
70 self
71 }
72
73 pub(crate) fn add_dynamic_tools(
74 mut self,
75 dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
76 ) -> Self {
77 self.dynamic_tools = dyn_tools;
78 self
79 }
80
81 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
84 let toolname = tool.name();
85 self.toolset.add_tool(tool);
86 push_unique_name(&mut self.static_tool_names, toolname);
87 self
88 }
89
90 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
95 #[cfg(feature = "rmcp")]
96 pub fn rmcp_tool(self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
97 self.rmcp_tool_with_timeout(tool, client, crate::tool::rmcp::DEFAULT_MCP_TOOL_TIMEOUT)
98 }
99
100 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
105 #[cfg(feature = "rmcp")]
106 pub fn rmcp_tool_with_timeout(
107 mut self,
108 tool: rmcp::model::Tool,
109 client: rmcp::service::ServerSink,
110 timeout: impl Into<Option<std::time::Duration>>,
111 ) -> Self {
112 use crate::tool::rmcp::McpTool;
113 let toolname = tool.name.to_string();
114 self.toolset
115 .add_tool(McpTool::from_mcp_server(tool, client).with_timeout(timeout));
116 push_unique_name(&mut self.static_tool_names, toolname);
117 self
118 }
119
120 pub fn dynamic_tools(
123 mut self,
124 sample: usize,
125 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
126 toolset: ToolSet,
127 ) -> Self {
128 self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
129 self.toolset.add_tools(toolset);
130 self
131 }
132
133 pub fn run(self) -> ToolServerHandle {
135 ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
136 static_tool_names: self.static_tool_names,
137 dynamic_tools: self.dynamic_tools,
138 toolset: self.toolset,
139 })))
140 }
141}
142
143#[derive(Clone)]
149pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
150
151impl ToolServerHandle {
152 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
155 let mut state = self.0.write().await;
156 let toolname = tool.name();
157 push_unique_name(&mut state.static_tool_names, toolname);
158 state.toolset.add_tool_boxed(Box::new(tool));
159 Ok(())
160 }
161
162 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
168 let mut state = self.0.write().await;
169 for name in toolset.ordered_names() {
170 push_unique_name(&mut state.static_tool_names, name.clone());
171 }
172 state.toolset.add_tools(toolset);
173 Ok(())
174 }
175
176 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
178 let mut state = self.0.write().await;
179 state.static_tool_names.retain(|x| *x != tool_name);
180 state.toolset.delete_tool(tool_name);
181 Ok(())
182 }
183
184 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
189 let tool = {
190 let state = self.0.read().await;
191 state.toolset.get(tool_name).cloned()
192 };
193
194 match tool {
195 Some(tool) => {
196 tracing::debug!(target: "rig",
197 "Calling tool {tool_name} with args:\n{}",
198 serde_json::to_string_pretty(&args).unwrap_or_default()
199 );
200 tool.call(args.to_string())
201 .await
202 .map_err(|e| ToolSetError::ToolCallError(e).into())
203 }
204 None => Err(ToolServerError::ToolsetError(
205 ToolSetError::ToolNotFoundError(tool_name.to_string()),
206 )),
207 }
208 }
209
210 pub async fn get_tool_defs(
213 &self,
214 prompt: Option<String>,
215 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
216 let (static_tool_names, dynamic_tools) = {
218 let state = self.0.read().await;
219 (state.static_tool_names.clone(), state.dynamic_tools.clone())
220 };
221
222 let mut tools = if let Some(ref text) = prompt {
223 let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
225 let text = text.clone();
226 let num_sample = *num_sample;
227 let index = index.clone();
228
229 async move {
230 let req = VectorSearchRequest::builder()
231 .query(text)
232 .samples(num_sample as u64)
233 .build();
234
235 let ids = index
236 .as_ref()
237 .top_n_ids(req.map_filter(Filter::interpret))
238 .await?
239 .into_iter()
240 .map(|(_, id)| id)
241 .collect::<Vec<String>>();
242
243 Ok::<_, VectorStoreError>(ids)
244 }
245 });
246
247 let dynamic_tool_ids: Vec<String> = futures::future::try_join_all(search_futures)
249 .await
250 .map_err(|e| {
251 ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
252 })?
253 .into_iter()
254 .flatten()
255 .collect();
256
257 let dynamic_tool_handles: Vec<_> = {
258 let state = self.0.read().await;
259 dynamic_tool_ids
260 .iter()
261 .filter_map(|doc| {
262 let handle = state.toolset.get(doc).cloned();
263 if handle.is_none() {
264 tracing::warn!("Tool implementation not found in toolset: {}", doc);
265 }
266 handle
267 })
268 .collect()
269 };
270
271 let mut tools = Vec::new();
272 for tool in dynamic_tool_handles {
273 tools.push(tool.definition(text.clone()).await);
274 }
275 tools
276 } else {
277 Vec::new()
278 };
279
280 let static_tool_handles: Vec<_> = {
281 let state = self.0.read().await;
282 static_tool_names
283 .iter()
284 .filter_map(|toolname| {
285 let handle = state.toolset.get(toolname).cloned();
286 if handle.is_none() {
287 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
288 }
289 handle
290 })
291 .collect()
292 };
293
294 for tool in static_tool_handles {
295 tools.push(tool.definition(String::new()).await);
296 }
297
298 let mut seen = std::collections::HashSet::new();
303 tools.retain(|def| {
304 let fresh = seen.insert(def.name.clone());
305 if !fresh {
306 tracing::debug!(
307 tool_name = %def.name,
308 "dropping duplicate tool definition from the request"
309 );
310 }
311 fresh
312 });
313
314 Ok(tools)
315 }
316}
317
318#[derive(Debug, thiserror::Error)]
319pub enum ToolServerError {
320 #[error("Toolset error: {0}")]
321 ToolsetError(#[from] ToolSetError),
322 #[error("Failed to retrieve tool definitions: {0}")]
323 DefinitionError(CompletionError),
324}
325
// Unit tests for `ToolServer` / `ToolServerHandle`. They rely on mock tools and indices from
// `crate::test_utils`, whose exact behaviour is defined outside this file.
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::{
        test_utils::{
            BarrierMockToolIndex, MockAddTool, MockBarrierTool, MockControlledTool,
            MockSubtractTool, MockToolIndex,
        },
        tool::{ToolSet, server::ToolServer},
    };

    /// Round-trip: add a tool at runtime, see it advertised, call it, remove it, and check it
    /// is no longer advertised.
    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(MockAddTool).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        // The "add" mock is expected to sum `x` and `y` (2 + 5 == 7).
        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    /// `append_toolset` must advertise exactly the same tool names as calling `add_tool` for
    /// each tool individually.
    #[tokio::test]
    pub async fn test_toolserver_append_toolset_matches_add_tool() {
        let mut via_add_tool = {
            let handle = ToolServer::new().run();
            handle.add_tool(MockAddTool).await.unwrap();
            handle.add_tool(MockSubtractTool).await.unwrap();
            handle.get_tool_defs(None).await.unwrap()
        };
        via_add_tool.sort_by(|a, b| a.name.cmp(&b.name));

        let mut via_append_toolset = {
            let handle = ToolServer::new().run();
            let mut toolset = ToolSet::default();
            toolset.add_tool(MockAddTool);
            toolset.add_tool(MockSubtractTool);
            handle.append_toolset(toolset).await.unwrap();
            handle.get_tool_defs(None).await.unwrap()
        };
        via_append_toolset.sort_by(|a, b| a.name.cmp(&b.name));

        // Compare names only; the order of the definitions was normalised by sorting above.
        assert_eq!(via_add_tool.len(), via_append_toolset.len());
        assert!(
            via_add_tool
                .iter()
                .zip(via_append_toolset.iter())
                .all(|(a, b)| a.name == b.name),
            "append_toolset must surface the same LLM-visible tools as add_tool",
        );
    }

    /// A tool registered statically AND returned by a dynamic index must be advertised once.
    #[tokio::test]
    pub async fn get_tool_defs_dedupes_dynamic_and_static_overlap() {
        // Presumably `MockToolIndex::new([...])` always returns the given ids for any query.
        let handle = ToolServer::new()
            .tool(MockAddTool)
            .dynamic_tools(1, MockToolIndex::new(["add"]), ToolSet::default())
            .run();

        let defs = handle
            .get_tool_defs(Some("add two numbers".to_string()))
            .await
            .unwrap();
        assert_eq!(
            defs.len(),
            1,
            "dynamic/static name overlap must not produce duplicate declarations: {:?}",
            defs.iter().map(|def| def.name.as_str()).collect::<Vec<_>>()
        );
        assert_eq!(defs[0].name, "add");
    }

    /// Registering the same tool via the builder, `add_tool`, and `append_toolset` must still
    /// advertise a single definition.
    #[tokio::test]
    pub async fn duplicate_registration_advertises_one_definition() {
        let handle = ToolServer::new().tool(MockAddTool).run();
        handle.add_tool(MockAddTool).await.unwrap();

        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        handle.append_toolset(toolset).await.unwrap();

        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(
            defs.len(),
            1,
            "re-registering a name must not advertise duplicate declarations"
        );
        assert_eq!(defs[0].name, "add");
    }

    /// Dynamic tools are only advertised when a prompt is supplied; static ones always are.
    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        // NOTE(review): this `toolset` is built but never passed to the server below; it looks
        // like a leftover and could be removed.
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        toolset.add_tool(MockSubtractTool);

        let mock_index = MockToolIndex::new(["subtract"]);

        let server = ToolServer::new().tool(MockAddTool).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![MockSubtractTool]),
        );

        let handle = server.run();

        // Without a prompt, only the static "add" tool is advertised.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // With a prompt, the index contributes "subtract" as well.
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    /// An index hit with no matching implementation is skipped (with a warning), not an error.
    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        let mock_index = MockToolIndex::new(["nonexistent_tool"]);

        let server =
            ToolServer::new()
                .tool(MockAddTool)
                .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// Tool calls must run concurrently. The barrier requires all `num_calls` calls to be in
    /// flight at once, so sequential execution would hit the 1s timeout.
    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(MockBarrierTool::new(barrier.clone()));
        let handle = server.run();

        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// A write (`add_tool`) must not block while a tool call is in progress, i.e. the read
    /// lock must not be held across tool execution.
    #[tokio::test]
    pub async fn test_toolserver_write_while_tool_running() {
        // Presumably the controlled mock signals `started` when invoked, then waits for
        // `allow_finish` before returning "42".
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        let tool = MockControlledTool::new(started.clone(), allow_finish.clone());

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until the tool is actually executing before attempting the write.
        started.notified().await;

        let add_result =
            tokio::time::timeout(Duration::from_secs(1), handle.add_tool(MockAddTool)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// Dynamic indices must be queried concurrently. Each index waits on a shared barrier of
    /// size 2, so sequential querying would hit the 1s timeout.
    #[tokio::test]
    pub async fn test_toolserver_parallel_dynamic_tool_fetching() {
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockToolIndex::new(barrier.clone(), "add");
        let index2 = BarrierMockToolIndex::new(barrier.clone(), "subtract");

        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        toolset.add_tool(MockSubtractTool);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        let get_defs = tokio::time::timeout(
            std::time::Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}