1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6 completion::{CompletionError, ToolDefinition},
7 tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
/// Shared, lock-protected state behind every clone of a `ToolServerHandle`.
struct ToolServerState {
    /// Names of tools whose definitions are returned by every
    /// `get_tool_defs` call, regardless of the prompt.
    static_tool_names: Vec<String>,
    /// Vector-store indexes queried with the prompt in `get_tool_defs`, each
    /// paired with the sample count passed to the search request.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Tool implementations (static and dynamic), looked up by tool name.
    toolset: ToolSet,
}
20
/// Builder-style configuration for a tool server.
///
/// Register tools with [`ToolServer::tool`] / [`ToolServer::dynamic_tools`],
/// then call [`ToolServer::run`] to obtain a cloneable [`ToolServerHandle`].
pub struct ToolServer {
    /// Names of tools whose definitions are always advertised.
    static_tool_names: Vec<String>,
    /// Indexes used to select extra tools per prompt, each with its sample count.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Implementations of both static and dynamic tools, looked up by name.
    toolset: ToolSet,
}
30
31impl Default for ToolServer {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl ToolServer {
38 pub fn new() -> Self {
39 Self {
40 static_tool_names: Vec::new(),
41 dynamic_tools: Vec::new(),
42 toolset: ToolSet::default(),
43 }
44 }
45
46 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
47 self.static_tool_names = names;
48 self
49 }
50
51 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
52 self.toolset = tools;
53 self
54 }
55
56 pub(crate) fn add_dynamic_tools(
57 mut self,
58 dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
59 ) -> Self {
60 self.dynamic_tools = dyn_tools;
61 self
62 }
63
64 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
66 let toolname = tool.name();
67 self.toolset.add_tool(tool);
68 self.static_tool_names.push(toolname);
69 self
70 }
71
72 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
74 #[cfg(feature = "rmcp")]
75 pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
76 use crate::tool::rmcp::McpTool;
77 let toolname = tool.name.clone();
78 self.toolset
79 .add_tool(McpTool::from_mcp_server(tool, client));
80 self.static_tool_names.push(toolname.to_string());
81 self
82 }
83
84 pub fn dynamic_tools(
87 mut self,
88 sample: usize,
89 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
90 toolset: ToolSet,
91 ) -> Self {
92 self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
93 self.toolset.add_tools(toolset);
94 self
95 }
96
97 pub fn run(self) -> ToolServerHandle {
99 ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
100 static_tool_names: self.static_tool_names,
101 dynamic_tools: self.dynamic_tools,
102 toolset: self.toolset,
103 })))
104 }
105}
106
/// Cloneable handle to a running tool server.
///
/// All clones share the same state through an `Arc<RwLock<..>>`, so tools
/// added or removed through one clone are visible through every other.
#[derive(Clone)]
pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
114
115impl ToolServerHandle {
116 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
118 let mut state = self.0.write().await;
119 state.static_tool_names.push(tool.name());
120 state.toolset.add_tool_boxed(Box::new(tool));
121 Ok(())
122 }
123
124 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
126 let mut state = self.0.write().await;
127 state.toolset.add_tools(toolset);
128 Ok(())
129 }
130
131 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
133 let mut state = self.0.write().await;
134 state.static_tool_names.retain(|x| *x != tool_name);
135 state.toolset.delete_tool(tool_name);
136 Ok(())
137 }
138
139 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
144 let tool = {
145 let state = self.0.read().await;
146 state.toolset.get(tool_name).cloned()
147 };
148
149 match tool {
150 Some(tool) => {
151 tracing::debug!(target: "rig",
152 "Calling tool {tool_name} with args:\n{}",
153 serde_json::to_string_pretty(&args).unwrap_or_default()
154 );
155 tool.call(args.to_string())
156 .await
157 .map_err(|e| ToolSetError::ToolCallError(e).into())
158 }
159 None => Err(ToolServerError::ToolsetError(
160 ToolSetError::ToolNotFoundError(tool_name.to_string()),
161 )),
162 }
163 }
164
165 pub async fn get_tool_defs(
168 &self,
169 prompt: Option<String>,
170 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
171 let (static_tool_names, dynamic_tools) = {
173 let state = self.0.read().await;
174 (state.static_tool_names.clone(), state.dynamic_tools.clone())
175 };
176
177 let mut tools = if let Some(ref text) = prompt {
178 let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
180 let text = text.clone();
181 let num_sample = *num_sample;
182 let index = index.clone();
183
184 async move {
185 let req = VectorSearchRequest::builder()
186 .query(text)
187 .samples(num_sample as u64)
188 .build();
189
190 let ids = index
191 .as_ref()
192 .top_n_ids(req.map_filter(Filter::interpret))
193 .await?
194 .into_iter()
195 .map(|(_, id)| id)
196 .collect::<Vec<String>>();
197
198 Ok::<_, VectorStoreError>(ids)
199 }
200 });
201
202 let dynamic_tool_ids: Vec<String> = futures::future::try_join_all(search_futures)
204 .await
205 .map_err(|e| {
206 ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
207 })?
208 .into_iter()
209 .flatten()
210 .collect();
211
212 let dynamic_tool_handles: Vec<_> = {
213 let state = self.0.read().await;
214 dynamic_tool_ids
215 .iter()
216 .filter_map(|doc| {
217 let handle = state.toolset.get(doc).cloned();
218 if handle.is_none() {
219 tracing::warn!("Tool implementation not found in toolset: {}", doc);
220 }
221 handle
222 })
223 .collect()
224 };
225
226 let mut tools = Vec::new();
227 for tool in dynamic_tool_handles {
228 tools.push(tool.definition(text.clone()).await);
229 }
230 tools
231 } else {
232 Vec::new()
233 };
234
235 let static_tool_handles: Vec<_> = {
236 let state = self.0.read().await;
237 static_tool_names
238 .iter()
239 .filter_map(|toolname| {
240 let handle = state.toolset.get(toolname).cloned();
241 if handle.is_none() {
242 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
243 }
244 handle
245 })
246 .collect()
247 };
248
249 for tool in static_tool_handles {
250 tools.push(tool.definition(String::new()).await);
251 }
252
253 Ok(tools)
254 }
255}
256
257#[derive(Debug, thiserror::Error)]
258pub enum ToolServerError {
259 #[error("Toolset error: {0}")]
260 ToolsetError(#[from] ToolSetError),
261 #[error("Failed to retrieve tool definitions: {0}")]
262 DefinitionError(CompletionError),
263}
264
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolSet, server::ToolServer},
        vector_store::{
            VectorStoreError, VectorStoreIndex,
            request::{Filter, VectorSearchRequest},
        },
        wasm_compat::WasmCompatSend,
    };

    /// Arguments shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool returning `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool returning `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtractor;
    impl Tool for Subtractor {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "subtract".to_string(),
                description: "Subtract y from x".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x - args.y;
            Ok(result)
        }
    }

    /// Index stub whose `top_n_ids` returns every configured id (ignoring the
    /// request), with descending scores.
    struct MockToolIndex {
        tool_ids: Vec<String>,
    }

    impl VectorStoreIndex for MockToolIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            // Documents are never needed by the tool server; only ids are.
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(self
                .tool_ids
                .iter()
                .enumerate()
                .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
                .collect())
        }
    }

    #[tokio::test]
    async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        // A tool added at runtime is advertised and callable.
        handle.add_tool(Adder).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        // After removal it is no longer advertised.
        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    #[tokio::test]
    async fn test_toolserver_dynamic_tools() {
        // "add" is static; "subtract" is reachable only through the index.
        let mock_index = MockToolIndex {
            tool_ids: vec!["subtract".to_string()],
        };

        let server = ToolServer::new().tool(Adder).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![Subtractor]),
        );

        let handle = server.run();

        // Without a prompt only the static tool is advertised.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // With a prompt the dynamic hit is added.
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    #[tokio::test]
    async fn test_toolserver_dynamic_tools_missing_implementation() {
        // The index returns an id that has no implementation in the toolset.
        let mock_index = MockToolIndex {
            tool_ids: vec!["nonexistent_tool".to_string()],
        };

        let server = ToolServer::new()
            .tool(Adder)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        // The unknown id is skipped rather than failing the whole request.
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// Tool whose call only completes once `barrier` has been reached by all
    /// participants, proving calls run concurrently.
    #[derive(Clone)]
    struct BarrierTool {
        barrier: Arc<tokio::sync::Barrier>,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Barrier error")]
    struct BarrierError;

    impl Tool for BarrierTool {
        const NAME: &'static str = "barrier_tool";
        type Error = BarrierError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "barrier_tool".to_string(),
                description: "Waits at a barrier to test concurrency".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            // Blocks until every concurrent call has arrived.
            self.barrier.wait().await;
            Ok("done".to_string())
        }
    }

    #[tokio::test]
    async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(BarrierTool {
            barrier: barrier.clone(),
        });
        let handle = server.run();

        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        // If calls were serialised the barrier would never release.
        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// Tool that signals `started` when invoked and then waits for
    /// `allow_finish`, letting the test act while a call is in flight.
    #[derive(Clone)]
    struct ControlledTool {
        started: Arc<tokio::sync::Notify>,
        allow_finish: Arc<tokio::sync::Notify>,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Controlled error")]
    struct ControlledError;

    impl Tool for ControlledTool {
        const NAME: &'static str = "controlled";
        type Error = ControlledError;
        type Args = serde_json::Value;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "controlled".to_string(),
                description: "Test tool".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            // Tell the test the call is running...
            self.started.notify_one();
            // ...and stay in flight until it allows completion.
            self.allow_finish.notified().await;
            Ok(42)
        }
    }

    #[tokio::test]
    async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        let tool = ControlledTool {
            started: started.clone(),
            allow_finish: allow_finish.clone(),
        };

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        // Start a call on a separate task and wait until it is executing.
        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        started.notified().await;

        // A write must succeed while the tool is still running.
        let add_result = tokio::time::timeout(Duration::from_secs(1), handle.add_tool(Adder)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        // Release the tool and check the in-flight call still completes.
        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// Index stub whose `top_n_ids` waits on a shared barrier, so a search
    /// only completes when all indexes are queried concurrently.
    struct BarrierMockIndex {
        barrier: Arc<tokio::sync::Barrier>,
        tool_id: String,
    }

    impl VectorStoreIndex for BarrierMockIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            self.barrier.wait().await;
            Ok(vec![(1.0, self.tool_id.clone())])
        }
    }

    #[tokio::test]
    async fn test_toolserver_parallel_dynamic_tool_fetching() {
        // Two participants: both index searches must be in flight at once.
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockIndex {
            barrier: barrier.clone(),
            tool_id: "add".to_string(),
        };

        let index2 = BarrierMockIndex {
            barrier: barrier.clone(),
            tool_id: "subtract".to_string(),
        };

        let mut toolset = ToolSet::default();
        toolset.add_tool(Adder);
        toolset.add_tool(Subtractor);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        let get_defs = tokio::time::timeout(
            Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}