1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6 completion::{CompletionError, ToolDefinition},
7 tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
11struct ToolServerState {
13 static_tool_names: Vec<String>,
15 dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
17 toolset: ToolSet,
19}
20
21pub struct ToolServer {
26 static_tool_names: Vec<String>,
27 dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
28 toolset: ToolSet,
29}
30
31impl Default for ToolServer {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl ToolServer {
38 pub fn new() -> Self {
39 Self {
40 static_tool_names: Vec::new(),
41 dynamic_tools: Vec::new(),
42 toolset: ToolSet::default(),
43 }
44 }
45
46 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
47 self.static_tool_names = names;
48 self
49 }
50
51 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
52 self.toolset = tools;
53 self
54 }
55
56 pub(crate) fn add_dynamic_tools(
57 mut self,
58 dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
59 ) -> Self {
60 self.dynamic_tools = dyn_tools;
61 self
62 }
63
64 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
66 let toolname = tool.name();
67 self.toolset.add_tool(tool);
68 self.static_tool_names.push(toolname);
69 self
70 }
71
72 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
74 #[cfg(feature = "rmcp")]
75 pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
76 use crate::tool::rmcp::McpTool;
77 let toolname = tool.name.clone();
78 self.toolset
79 .add_tool(McpTool::from_mcp_server(tool, client));
80 self.static_tool_names.push(toolname.to_string());
81 self
82 }
83
84 pub fn dynamic_tools(
87 mut self,
88 sample: usize,
89 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
90 toolset: ToolSet,
91 ) -> Self {
92 self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
93 self.toolset.add_tools(toolset);
94 self
95 }
96
97 pub fn run(self) -> ToolServerHandle {
99 ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
100 static_tool_names: self.static_tool_names,
101 dynamic_tools: self.dynamic_tools,
102 toolset: self.toolset,
103 })))
104 }
105}
106
107#[derive(Clone)]
113pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
114
115impl ToolServerHandle {
116 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
118 let mut state = self.0.write().await;
119 state.static_tool_names.push(tool.name());
120 state.toolset.add_tool_boxed(Box::new(tool));
121 Ok(())
122 }
123
124 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
126 let mut state = self.0.write().await;
127 state.toolset.add_tools(toolset);
128 Ok(())
129 }
130
131 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
133 let mut state = self.0.write().await;
134 state.static_tool_names.retain(|x| *x != tool_name);
135 state.toolset.delete_tool(tool_name);
136 Ok(())
137 }
138
139 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
144 let tool = {
145 let state = self.0.read().await;
146 state.toolset.get(tool_name).cloned()
147 };
148
149 match tool {
150 Some(tool) => {
151 tracing::debug!(target: "rig",
152 "Calling tool {tool_name} with args:\n{}",
153 serde_json::to_string_pretty(&args).unwrap_or_default()
154 );
155 tool.call(args.to_string())
156 .await
157 .map_err(|e| ToolSetError::ToolCallError(e).into())
158 }
159 None => Err(ToolServerError::ToolsetError(
160 ToolSetError::ToolNotFoundError(tool_name.to_string()),
161 )),
162 }
163 }
164
165 pub async fn get_tool_defs(
168 &self,
169 prompt: Option<String>,
170 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
171 let (static_tool_names, dynamic_tools) = {
173 let state = self.0.read().await;
174 (state.static_tool_names.clone(), state.dynamic_tools.clone())
175 };
176
177 let mut tools = if let Some(ref text) = prompt {
178 let futs: Vec<_> = dynamic_tools
179 .into_iter()
180 .map(|(num_sample, index)| {
181 let text = text.clone();
182 async move {
183 let req = VectorSearchRequest::builder()
184 .query(text)
185 .samples(num_sample as u64)
186 .build()
187 .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
188 Ok::<_, VectorStoreError>(
189 index
190 .top_n_ids(req.map_filter(Filter::interpret))
191 .await?
192 .into_iter()
193 .map(|(_, id)| id)
194 .collect::<Vec<String>>(),
195 )
196 }
197 })
198 .collect();
199
200 let results = futures::future::try_join_all(futs).await.map_err(|e| {
201 ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
202 })?;
203
204 let dynamic_tool_ids: Vec<String> = results.into_iter().flatten().collect();
205
206 let dynamic_tool_handles: Vec<_> = {
207 let state = self.0.read().await;
208 dynamic_tool_ids
209 .iter()
210 .filter_map(|doc| {
211 let handle = state.toolset.get(doc).cloned();
212 if handle.is_none() {
213 tracing::warn!("Tool implementation not found in toolset: {}", doc);
214 }
215 handle
216 })
217 .collect()
218 };
219
220 let mut tools = Vec::new();
221 for tool in dynamic_tool_handles {
222 tools.push(tool.definition(text.clone()).await);
223 }
224 tools
225 } else {
226 Vec::new()
227 };
228
229 let static_tool_handles: Vec<_> = {
230 let state = self.0.read().await;
231 static_tool_names
232 .iter()
233 .filter_map(|toolname| {
234 let handle = state.toolset.get(toolname).cloned();
235 if handle.is_none() {
236 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
237 }
238 handle
239 })
240 .collect()
241 };
242
243 for tool in static_tool_handles {
244 tools.push(tool.definition(String::new()).await);
245 }
246
247 Ok(tools)
248 }
249}
250
251#[derive(Debug, thiserror::Error)]
252pub enum ToolServerError {
253 #[error("Toolset error: {0}")]
254 ToolsetError(#[from] ToolSetError),
255 #[error("Failed to retrieve tool definitions: {0}")]
256 DefinitionError(CompletionError),
257}
258
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolSet, server::ToolServer},
        vector_store::{
            VectorStoreError, VectorStoreIndex,
            request::{Filter, VectorSearchRequest},
        },
        wasm_compat::WasmCompatSend,
    };

    /// Arguments shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    /// Error type for the arithmetic test tools; never actually produced.
    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool named `add` returning `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool named `subtract` returning `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtractor;
    impl Tool for Subtractor {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "subtract".to_string(),
                description: "Subtract y from x".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x - args.y;
            Ok(result)
        }
    }

    /// Index that returns a fixed list of tool ids (with descending scores) for any query.
    struct MockToolIndex {
        tool_ids: Vec<String>,
    }

    impl VectorStoreIndex for MockToolIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(self
                .tool_ids
                .iter()
                .enumerate()
                .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
                .collect())
        }
    }

    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(Adder).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        // `add` is static; `subtract` is only reachable through the dynamic index.
        let mock_index = MockToolIndex {
            tool_ids: vec!["subtract".to_string()],
        };

        let server = ToolServer::new().tool(Adder).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![Subtractor]),
        );

        let handle = server.run();

        // Without a prompt only the static tool is reported.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // With a prompt the dynamic index contributes `subtract` as well.
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        // The index returns an id with no implementation; it must be skipped, not fail.
        let mock_index = MockToolIndex {
            tool_ids: vec!["nonexistent_tool".to_string()],
        };

        let server = ToolServer::new()
            .tool(Adder)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// Tool whose call completes only once `barrier` has been reached by every caller.
    #[derive(Clone)]
    struct BarrierTool {
        barrier: Arc<tokio::sync::Barrier>,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Barrier error")]
    struct BarrierError;

    impl Tool for BarrierTool {
        const NAME: &'static str = "barrier_tool";
        type Error = BarrierError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "barrier_tool".to_string(),
                description: "Waits at a barrier to test concurrency".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            self.barrier.wait().await;
            Ok("done".to_string())
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(BarrierTool {
            barrier: barrier.clone(),
        });
        let handle = server.run();

        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        // The barrier only opens if all calls are in flight at the same time.
        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// Tool that signals `started` when called and then waits for `allow_finish`.
    #[derive(Clone)]
    struct ControlledTool {
        started: Arc<tokio::sync::Notify>,
        allow_finish: Arc<tokio::sync::Notify>,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Controlled error")]
    struct ControlledError;

    impl Tool for ControlledTool {
        const NAME: &'static str = "controlled";
        type Error = ControlledError;
        type Args = serde_json::Value;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "controlled".to_string(),
                description: "Test tool".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            self.started.notify_one();
            self.allow_finish.notified().await;
            Ok(42)
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        let tool = ControlledTool {
            started: started.clone(),
            allow_finish: allow_finish.clone(),
        };

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until the tool is running, then try to take the write lock.
        started.notified().await;

        let add_result = tokio::time::timeout(Duration::from_secs(1), handle.add_tool(Adder)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// Index whose query completes only once `barrier` has been reached by every index.
    struct BarrierMockIndex {
        barrier: Arc<tokio::sync::Barrier>,
        tool_id: String,
    }

    impl VectorStoreIndex for BarrierMockIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            self.barrier.wait().await;
            Ok(vec![(1.0, self.tool_id.clone())])
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_parallel_dynamic_tool_fetching() {
        // Both queries must be in flight at once for the barrier to open.
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockIndex {
            barrier: barrier.clone(),
            tool_id: "add".to_string(),
        };

        let index2 = BarrierMockIndex {
            barrier: barrier.clone(),
            tool_id: "subtract".to_string(),
        };

        let mut toolset = ToolSet::default();
        toolset.add_tool(Adder);
        toolset.add_tool(Subtractor);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        let get_defs = tokio::time::timeout(
            Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}