1use std::sync::Arc;
2
3use futures::{StreamExt, TryStreamExt, channel::oneshot::Canceled, stream};
4use tokio::sync::{
5 RwLock,
6 mpsc::{Sender, error::SendError},
7};
8
9use crate::{
10 completion::{CompletionError, ToolDefinition},
11 tool::{Tool, ToolDyn, ToolError, ToolSet, ToolSetError},
12 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
13};
14
15pub struct ToolServer {
16 static_tool_names: Vec<String>,
19 dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
21 toolset: Arc<RwLock<ToolSet>>,
24}
25
26impl Default for ToolServer {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl ToolServer {
33 pub fn new() -> Self {
34 Self {
35 static_tool_names: Vec::new(),
36 dynamic_tools: Vec::new(),
37 toolset: Arc::new(RwLock::new(ToolSet::default())),
38 }
39 }
40
41 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
42 self.static_tool_names = names;
43 self
44 }
45
46 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
47 self.toolset = Arc::new(RwLock::new(tools));
48 self
49 }
50
51 pub(crate) fn add_dynamic_tools(
52 mut self,
53 dyn_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
54 ) -> Self {
55 self.dynamic_tools = dyn_tools;
56 self
57 }
58
59 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
61 let toolname = tool.name();
62 Arc::get_mut(&mut self.toolset)
66 .expect("ToolServer::tool() called after run()")
67 .get_mut()
68 .add_tool(tool);
69 self.static_tool_names.push(toolname);
70 self
71 }
72
73 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
75 #[cfg(feature = "rmcp")]
76 pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
77 use crate::tool::rmcp::McpTool;
78 let toolname = tool.name.clone();
79 Arc::get_mut(&mut self.toolset)
83 .expect("ToolServer::rmcp_tool() called after run()")
84 .get_mut()
85 .add_tool(McpTool::from_mcp_server(tool, client));
86 self.static_tool_names.push(toolname.to_string());
87 self
88 }
89
90 pub fn dynamic_tools(
93 mut self,
94 sample: usize,
95 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
96 toolset: ToolSet,
97 ) -> Self {
98 self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
99 Arc::get_mut(&mut self.toolset)
103 .expect("ToolServer::dynamic_tools() called after run()")
104 .get_mut()
105 .add_tools(toolset);
106 self
107 }
108
109 pub fn run(mut self) -> ToolServerHandle {
110 let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
111
112 #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
113 tokio::spawn(async move {
114 while let Some(message) = rx.recv().await {
115 self.handle_message(message).await;
116 }
117 });
118
119 #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
122 wasm_bindgen_futures::spawn_local(async move {
123 while let Some(message) = rx.recv().await {
124 self.handle_message(message).await;
125 }
126 });
127
128 ToolServerHandle(tx)
129 }
130
131 pub async fn handle_message(&mut self, message: ToolServerRequest) {
132 let ToolServerRequest {
133 callback_channel,
134 data,
135 } = message;
136
137 match data {
138 ToolServerRequestMessageKind::AddTool(tool) => {
139 self.static_tool_names.push(tool.name());
140 self.toolset.write().await.add_tool_boxed(tool);
141 callback_channel
142 .send(ToolServerResponse::ToolAdded)
143 .unwrap();
144 }
145 ToolServerRequestMessageKind::AppendToolset(tools) => {
146 self.toolset.write().await.add_tools(tools);
147 callback_channel
148 .send(ToolServerResponse::ToolAdded)
149 .unwrap();
150 }
151 ToolServerRequestMessageKind::RemoveTool { tool_name } => {
152 self.static_tool_names.retain(|x| *x != tool_name);
153 self.toolset.write().await.delete_tool(&tool_name);
154 callback_channel
155 .send(ToolServerResponse::ToolDeleted)
156 .unwrap();
157 }
158 ToolServerRequestMessageKind::CallTool { name, args } => {
159 let toolset = Arc::clone(&self.toolset);
160
161 #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
162 tokio::spawn(async move {
163 match toolset.read().await.call(&name, args.clone()).await {
164 Ok(result) => {
165 let _ =
166 callback_channel.send(ToolServerResponse::ToolExecuted { result });
167 }
168 Err(err) => {
169 let _ = callback_channel.send(ToolServerResponse::ToolError {
170 error: err.to_string(),
171 });
172 }
173 }
174 });
175
176 #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
177 wasm_bindgen_futures::spawn_local(async move {
178 match toolset.read().await.call(&name, args.clone()).await {
179 Ok(result) => {
180 let _ =
181 callback_channel.send(ToolServerResponse::ToolExecuted { result });
182 }
183 Err(err) => {
184 let _ = callback_channel.send(ToolServerResponse::ToolError {
185 error: err.to_string(),
186 });
187 }
188 }
189 });
190 }
191 ToolServerRequestMessageKind::GetToolDefs { prompt } => {
192 let res = self.get_tool_definitions(prompt).await.unwrap();
193 callback_channel
194 .send(ToolServerResponse::ToolDefinitions(res))
195 .unwrap();
196 }
197 }
198 }
199
200 pub async fn get_tool_definitions(
201 &mut self,
202 text: Option<String>,
203 ) -> Result<Vec<ToolDefinition>, CompletionError> {
204 let static_tool_names = self.static_tool_names.clone();
205 let toolset = self.toolset.read().await;
206
207 let mut tools = if let Some(text) = text {
208 let dynamic_tool_ids: Vec<String> = stream::iter(self.dynamic_tools.iter())
210 .then(|(num_sample, index)| async {
211 let req = VectorSearchRequest::builder()
212 .query(text.clone())
213 .samples(*num_sample as u64)
214 .build()
215 .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
216 Ok::<_, VectorStoreError>(
217 index
218 .as_ref()
219 .top_n_ids(req.map_filter(Filter::interpret))
220 .await?
221 .into_iter()
222 .map(|(_, id)| id)
223 .collect::<Vec<String>>(),
224 )
225 })
226 .try_fold(vec![], |mut acc, docs| async {
227 acc.extend(docs);
228 Ok(acc)
229 })
230 .await
231 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
232
233 let mut tools = Vec::new();
235 for doc in dynamic_tool_ids {
236 if let Some(tool) = toolset.get(&doc) {
237 tools.push(tool.definition(text.clone()).await)
238 } else {
239 tracing::warn!("Tool implementation not found in toolset: {}", doc);
240 }
241 }
242 tools
243 } else {
244 Vec::new()
245 };
246
247 for toolname in static_tool_names {
248 if let Some(tool) = toolset.get(&toolname) {
249 tools.push(tool.definition(String::new()).await)
250 } else {
251 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
252 }
253 }
254
255 Ok(tools)
256 }
257}
258
259#[derive(Clone)]
260pub struct ToolServerHandle(Sender<ToolServerRequest>);
261
262impl ToolServerHandle {
263 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
264 let tool = Box::new(tool);
265
266 let (tx, rx) = futures::channel::oneshot::channel();
267
268 self.0
269 .send(ToolServerRequest {
270 callback_channel: tx,
271 data: ToolServerRequestMessageKind::AddTool(tool),
272 })
273 .await?;
274
275 let res = rx.await?;
276
277 let ToolServerResponse::ToolAdded = res else {
278 return Err(ToolServerError::InvalidMessage(res));
279 };
280
281 Ok(())
282 }
283
284 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
285 let (tx, rx) = futures::channel::oneshot::channel();
286
287 self.0
288 .send(ToolServerRequest {
289 callback_channel: tx,
290 data: ToolServerRequestMessageKind::AppendToolset(toolset),
291 })
292 .await?;
293
294 let res = rx.await?;
295
296 let ToolServerResponse::ToolAdded = res else {
297 return Err(ToolServerError::InvalidMessage(res));
298 };
299
300 Ok(())
301 }
302
303 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
304 let (tx, rx) = futures::channel::oneshot::channel();
305
306 self.0
307 .send(ToolServerRequest {
308 callback_channel: tx,
309 data: ToolServerRequestMessageKind::RemoveTool {
310 tool_name: tool_name.to_string(),
311 },
312 })
313 .await?;
314
315 let res = rx.await?;
316
317 let ToolServerResponse::ToolDeleted = res else {
318 return Err(ToolServerError::InvalidMessage(res));
319 };
320
321 Ok(())
322 }
323
324 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
325 let (tx, rx) = futures::channel::oneshot::channel();
326
327 self.0
328 .send(ToolServerRequest {
329 callback_channel: tx,
330 data: ToolServerRequestMessageKind::CallTool {
331 name: tool_name.to_string(),
332 args: args.to_string(),
333 },
334 })
335 .await?;
336
337 let res = rx.await?;
338
339 match res {
340 ToolServerResponse::ToolExecuted { result, .. } => Ok(result),
341 ToolServerResponse::ToolError { error } => Err(ToolServerError::ToolsetError(
342 ToolSetError::ToolCallError(ToolError::ToolCallError(error.into())),
343 )),
344 invalid => Err(ToolServerError::InvalidMessage(invalid)),
345 }
346 }
347
348 pub async fn get_tool_defs(
349 &self,
350 prompt: Option<String>,
351 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
352 let (tx, rx) = futures::channel::oneshot::channel();
353
354 self.0
355 .send(ToolServerRequest {
356 callback_channel: tx,
357 data: ToolServerRequestMessageKind::GetToolDefs { prompt },
358 })
359 .await?;
360
361 let res = rx.await?;
362
363 let ToolServerResponse::ToolDefinitions(tooldefs) = res else {
364 return Err(ToolServerError::InvalidMessage(res));
365 };
366
367 Ok(tooldefs)
368 }
369}
370
371pub struct ToolServerRequest {
372 callback_channel: futures::channel::oneshot::Sender<ToolServerResponse>,
373 data: ToolServerRequestMessageKind,
374}
375
376pub enum ToolServerRequestMessageKind {
377 AddTool(Box<dyn ToolDyn>),
378 AppendToolset(ToolSet),
379 RemoveTool { tool_name: String },
380 CallTool { name: String, args: String },
381 GetToolDefs { prompt: Option<String> },
382}
383
384#[derive(PartialEq, Debug)]
385pub enum ToolServerResponse {
386 ToolAdded,
387 ToolDeleted,
388 ToolExecuted { result: String },
389 ToolError { error: String },
390 ToolDefinitions(Vec<ToolDefinition>),
391}
392
393#[derive(Debug, thiserror::Error)]
394pub enum ToolServerError {
395 #[error("Sending message was cancelled")]
396 Canceled(#[from] Canceled),
397 #[error("Toolset error: {0}")]
398 ToolsetError(#[from] ToolSetError),
399 #[error("Error while sending message: {0}")]
400 SendError(#[from] SendError<ToolServerRequest>),
401 #[error("An invalid message type was returned")]
402 InvalidMessage(ToolServerResponse),
403}
404
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolSet, server::ToolServer},
        vector_store::{
            VectorStoreError, VectorStoreIndex,
            request::{Filter, VectorSearchRequest},
        },
        wasm_compat::WasmCompatSend,
    };

    /// Arguments shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool that adds `x` and `y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool that subtracts `y` from `x`.
    #[derive(Deserialize, Serialize)]
    struct Subtractor;
    impl Tool for Subtractor {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "subtract".to_string(),
                description: "Subtract y from x".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x - args.y;
            Ok(result)
        }
    }

    /// Vector index stub that ignores the query and returns `tool_ids` in order, with
    /// descending scores.
    struct MockToolIndex {
        tool_ids: Vec<String>,
    }

    impl VectorStoreIndex for MockToolIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            Ok(self
                .tool_ids
                .iter()
                .enumerate()
                .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
                .collect())
        }
    }

    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(Adder).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        let mock_index = MockToolIndex {
            tool_ids: vec!["subtract".to_string()],
        };

        let server = ToolServer::new().tool(Adder).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![Subtractor]),
        );

        let handle = server.run();

        // Without a prompt only the static tool is offered.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // With a prompt the index contributes the dynamic tool as well.
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        let mock_index = MockToolIndex {
            tool_ids: vec!["nonexistent_tool".to_string()],
        };

        let server = ToolServer::new()
            .tool(Adder)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        // Ids without an implementation are skipped rather than failing the request.
        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Sleeper error")]
    struct SleeperError;

    /// Test tool that sleeps for a configured duration, used to detect whether tool calls
    /// run concurrently.
    #[derive(Deserialize, Serialize, Clone)]
    struct SleeperTool {
        sleep_duration_ms: u64,
    }

    impl SleeperTool {
        fn new(sleep_duration_ms: u64) -> Self {
            Self { sleep_duration_ms }
        }
    }

    impl Tool for SleeperTool {
        const NAME: &'static str = "sleeper";
        type Error = SleeperError;
        type Args = serde_json::Value;
        type Output = u64;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "sleeper".to_string(),
                description: "Sleeps for configured duration".to_string(),
                parameters: json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            tokio::time::sleep(Duration::from_millis(self.sleep_duration_ms)).await;
            Ok(self.sleep_duration_ms)
        }
    }

    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        let sleep_ms: u64 = 100;
        let num_calls: u64 = 3;

        let server = ToolServer::new().tool(SleeperTool::new(sleep_ms));
        let handle = server.run();

        let start = std::time::Instant::now();

        let futures: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("sleeper", "{}"))
            .collect();
        let results = futures::future::join_all(futures).await;

        let elapsed = start.elapsed();

        for result in &results {
            assert!(result.is_ok(), "Tool call failed: {:?}", result);
        }

        // Sequential execution would take `num_calls * sleep_ms`; concurrent execution
        // should finish in roughly one `sleep_ms`.
        let max_concurrent_time = Duration::from_millis(sleep_ms * 2);
        assert!(
            elapsed < max_concurrent_time,
            "Expected concurrent execution in < {:?}, but took {:?}. Tools may be running sequentially.",
            max_concurrent_time,
            elapsed
        );
    }
}