1use std::sync::Arc;
2
3use tokio::sync::RwLock;
4
5use crate::{
6 completion::{CompletionError, ToolDefinition},
7 tool::{Tool, ToolDyn, ToolSet, ToolSetError},
8 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
9};
10
/// Mutable state shared by every clone of a [`ToolServerHandle`], guarded by its lock.
struct ToolServerState {
    /// Names of tools whose definitions are returned by every `get_tool_defs` call.
    static_tool_names: Vec<String>,
    /// Vector-store indices used to pick tools by prompt, each paired with the number
    /// of samples to request from that index.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Implementations of all registered tools (static and dynamic), looked up by name.
    toolset: ToolSet,
}
20
/// Builder for a tool server.
///
/// Register static tools and dynamic-tool indices, then call `run` to obtain a
/// shareable [`ToolServerHandle`].
pub struct ToolServer {
    /// Names of tools whose definitions are always returned by `get_tool_defs`.
    static_tool_names: Vec<String>,
    /// Vector-store indices for prompt-based tool selection, each paired with the
    /// number of samples to request from that index.
    dynamic_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
    /// Implementations of every registered tool, static and dynamic.
    toolset: ToolSet,
}
30
31impl Default for ToolServer {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl ToolServer {
38 pub fn new() -> Self {
39 Self {
40 static_tool_names: Vec::new(),
41 dynamic_tools: Vec::new(),
42 toolset: ToolSet::default(),
43 }
44 }
45
46 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
47 self.static_tool_names = names;
48 self
49 }
50
51 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
52 self.toolset = tools;
53 self
54 }
55
56 pub(crate) fn add_dynamic_tools(
57 mut self,
58 dyn_tools: Vec<(usize, Arc<dyn VectorStoreIndexDyn + Send + Sync>)>,
59 ) -> Self {
60 self.dynamic_tools = dyn_tools;
61 self
62 }
63
64 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
66 let toolname = tool.name();
67 self.toolset.add_tool(tool);
68 self.static_tool_names.push(toolname);
69 self
70 }
71
72 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
74 #[cfg(feature = "rmcp")]
75 pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
76 use crate::tool::rmcp::McpTool;
77 let toolname = tool.name.clone();
78 self.toolset
79 .add_tool(McpTool::from_mcp_server(tool, client));
80 self.static_tool_names.push(toolname.to_string());
81 self
82 }
83
84 pub fn dynamic_tools(
87 mut self,
88 sample: usize,
89 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
90 toolset: ToolSet,
91 ) -> Self {
92 self.dynamic_tools.push((sample, Arc::new(dynamic_tools)));
93 self.toolset.add_tools(toolset);
94 self
95 }
96
97 pub fn run(self) -> ToolServerHandle {
99 ToolServerHandle(Arc::new(RwLock::new(ToolServerState {
100 static_tool_names: self.static_tool_names,
101 dynamic_tools: self.dynamic_tools,
102 toolset: self.toolset,
103 })))
104 }
105}
106
/// A cheaply cloneable handle to a running tool server.
///
/// All clones share the same state through an `Arc<RwLock<..>>`, so tools can be
/// added or removed at runtime from any clone.
#[derive(Clone)]
pub struct ToolServerHandle(Arc<RwLock<ToolServerState>>);
114
115impl ToolServerHandle {
116 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
118 let mut state = self.0.write().await;
119 state.static_tool_names.push(tool.name());
120 state.toolset.add_tool_boxed(Box::new(tool));
121 Ok(())
122 }
123
124 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
126 let mut state = self.0.write().await;
127 state.toolset.add_tools(toolset);
128 Ok(())
129 }
130
131 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
133 let mut state = self.0.write().await;
134 state.static_tool_names.retain(|x| *x != tool_name);
135 state.toolset.delete_tool(tool_name);
136 Ok(())
137 }
138
139 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
144 let tool = {
145 let state = self.0.read().await;
146 state.toolset.get(tool_name).cloned()
147 };
148
149 match tool {
150 Some(tool) => {
151 tracing::debug!(target: "rig",
152 "Calling tool {tool_name} with args:\n{}",
153 serde_json::to_string_pretty(&args).unwrap_or_default()
154 );
155 tool.call(args.to_string())
156 .await
157 .map_err(|e| ToolSetError::ToolCallError(e).into())
158 }
159 None => Err(ToolServerError::ToolsetError(
160 ToolSetError::ToolNotFoundError(tool_name.to_string()),
161 )),
162 }
163 }
164
165 pub async fn get_tool_defs(
168 &self,
169 prompt: Option<String>,
170 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
171 let (static_tool_names, dynamic_tools) = {
173 let state = self.0.read().await;
174 (state.static_tool_names.clone(), state.dynamic_tools.clone())
175 };
176
177 let mut tools = if let Some(ref text) = prompt {
178 let search_futures = dynamic_tools.iter().map(|(num_sample, index)| {
180 let text = text.clone();
181 let num_sample = *num_sample;
182 let index = index.clone();
183
184 async move {
185 let req = VectorSearchRequest::builder()
186 .query(text)
187 .samples(num_sample as u64)
188 .build();
189
190 let ids = index
191 .as_ref()
192 .top_n_ids(req.map_filter(Filter::interpret))
193 .await?
194 .into_iter()
195 .map(|(_, id)| id)
196 .collect::<Vec<String>>();
197
198 Ok::<_, VectorStoreError>(ids)
199 }
200 });
201
202 let dynamic_tool_ids: Vec<String> = futures::future::try_join_all(search_futures)
204 .await
205 .map_err(|e| {
206 ToolServerError::DefinitionError(CompletionError::RequestError(Box::new(e)))
207 })?
208 .into_iter()
209 .flatten()
210 .collect();
211
212 let dynamic_tool_handles: Vec<_> = {
213 let state = self.0.read().await;
214 dynamic_tool_ids
215 .iter()
216 .filter_map(|doc| {
217 let handle = state.toolset.get(doc).cloned();
218 if handle.is_none() {
219 tracing::warn!("Tool implementation not found in toolset: {}", doc);
220 }
221 handle
222 })
223 .collect()
224 };
225
226 let mut tools = Vec::new();
227 for tool in dynamic_tool_handles {
228 tools.push(tool.definition(text.clone()).await);
229 }
230 tools
231 } else {
232 Vec::new()
233 };
234
235 let static_tool_handles: Vec<_> = {
236 let state = self.0.read().await;
237 static_tool_names
238 .iter()
239 .filter_map(|toolname| {
240 let handle = state.toolset.get(toolname).cloned();
241 if handle.is_none() {
242 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
243 }
244 handle
245 })
246 .collect()
247 };
248
249 for tool in static_tool_handles {
250 tools.push(tool.definition(String::new()).await);
251 }
252
253 Ok(tools)
254 }
255}
256
257#[derive(Debug, thiserror::Error)]
258pub enum ToolServerError {
259 #[error("Toolset error: {0}")]
260 ToolsetError(#[from] ToolSetError),
261 #[error("Failed to retrieve tool definitions: {0}")]
262 DefinitionError(CompletionError),
263}
264
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use crate::{
        test_utils::{
            BarrierMockToolIndex, MockAddTool, MockBarrierTool, MockControlledTool,
            MockSubtractTool, MockToolIndex,
        },
        tool::{ToolSet, server::ToolServer},
    };

    /// Adds, lists, calls and removes a static tool through a running handle.
    #[tokio::test]
    async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(MockAddTool).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    /// Dynamic tools are only listed when a prompt is given; static tools always are.
    #[tokio::test]
    async fn test_toolserver_dynamic_tools() {
        // The mock index is seeded with "subtract"; its implementation comes from the
        // toolset registered alongside the index.
        let mock_index = MockToolIndex::new(["subtract"]);

        let server = ToolServer::new().tool(MockAddTool).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![MockSubtractTool]),
        );

        let handle = server.run();

        // Without a prompt no vector search runs: only the static tool is listed.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    /// A dynamic search hit with no registered implementation is skipped, not an error.
    #[tokio::test]
    async fn test_toolserver_dynamic_tools_missing_implementation() {
        let mock_index = MockToolIndex::new(["nonexistent_tool"]);

        let server = ToolServer::new()
            .tool(MockAddTool)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    /// Concurrent calls to the same tool must run in parallel; the barrier tool only
    /// finishes once `num_calls` calls are in flight at once.
    #[tokio::test]
    async fn test_toolserver_concurrent_tool_execution() {
        let num_calls = 3;
        let barrier = Arc::new(tokio::sync::Barrier::new(num_calls));

        let server = ToolServer::new().tool(MockBarrierTool::new(barrier.clone()));
        let handle = server.run();

        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("barrier_tool", "{}"))
            .collect();

        let result =
            tokio::time::timeout(Duration::from_secs(1), futures::future::join_all(futures)).await;

        assert!(
            result.is_ok(),
            "Tool execution deadlocked! Tools are executing sequentially instead of concurrently."
        );

        for res in result.unwrap() {
            assert!(res.is_ok(), "Tool call failed: {:?}", res);
            assert_eq!(res.unwrap(), "done");
        }
    }

    /// Writers must not be blocked while a tool call is in progress.
    #[tokio::test]
    async fn test_toolserver_write_while_tool_running() {
        let started = Arc::new(tokio::sync::Notify::new());
        let allow_finish = Arc::new(tokio::sync::Notify::new());

        let tool = MockControlledTool::new(started.clone(), allow_finish.clone());

        let server = ToolServer::new().tool(tool);
        let handle = server.run();

        let handle_clone = handle.clone();
        let call_task =
            tokio::spawn(async move { handle_clone.call_tool("controlled", "{}").await });

        // Wait until the tool is executing before attempting the write.
        started.notified().await;

        let add_result =
            tokio::time::timeout(Duration::from_secs(1), handle.add_tool(MockAddTool)).await;

        assert!(
            add_result.is_ok(),
            "Writing to ToolServer deadlocked! The read lock is being held across tool execution."
        );
        assert!(add_result.unwrap().is_ok());

        allow_finish.notify_one();
        let call_result = call_task.await.unwrap();
        assert_eq!(call_result.unwrap(), "42");
    }

    /// Dynamic-tool indices must be queried concurrently; each barrier index blocks
    /// until both searches have started.
    #[tokio::test]
    async fn test_toolserver_parallel_dynamic_tool_fetching() {
        let barrier = Arc::new(tokio::sync::Barrier::new(2));

        let index1 = BarrierMockToolIndex::new(barrier.clone(), "add");
        let index2 = BarrierMockToolIndex::new(barrier.clone(), "subtract");

        let mut toolset = ToolSet::default();
        toolset.add_tool(MockAddTool);
        toolset.add_tool(MockSubtractTool);

        let server = ToolServer::new()
            .dynamic_tools(1, index1, ToolSet::default())
            .dynamic_tools(1, index2, toolset);

        let handle = server.run();

        let get_defs = tokio::time::timeout(
            Duration::from_secs(1),
            handle.get_tool_defs(Some("do math".to_string())),
        )
        .await;

        assert!(
            get_defs.is_ok(),
            "Dynamic tools were fetched sequentially! The first query deadlocked waiting for the second query to start."
        );

        let defs = get_defs.unwrap().unwrap();
        assert_eq!(defs.len(), 2);

        let tool_names: Vec<&str> = defs.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }
}