1use std::sync::Arc;
2
3use futures::{StreamExt, TryStreamExt, channel::oneshot::Canceled, stream};
4use tokio::sync::{
5 RwLock,
6 mpsc::{Sender, error::SendError},
7};
8use tracing::Instrument;
9
10use crate::{
11 completion::{CompletionError, ToolDefinition},
12 tool::{Tool, ToolDyn, ToolError, ToolSet, ToolSetError},
13 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
14};
15
16pub struct ToolServer {
17 static_tool_names: Vec<String>,
20 dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
22 toolset: Arc<RwLock<ToolSet>>,
25}
26
27impl Default for ToolServer {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl ToolServer {
34 pub fn new() -> Self {
35 Self {
36 static_tool_names: Vec::new(),
37 dynamic_tools: Vec::new(),
38 toolset: Arc::new(RwLock::new(ToolSet::default())),
39 }
40 }
41
42 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
43 self.static_tool_names = names;
44 self
45 }
46
47 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
48 self.toolset = Arc::new(RwLock::new(tools));
49 self
50 }
51
52 pub(crate) fn add_dynamic_tools(
53 mut self,
54 dyn_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
55 ) -> Self {
56 self.dynamic_tools = dyn_tools;
57 self
58 }
59
60 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
62 let toolname = tool.name();
63 Arc::get_mut(&mut self.toolset)
67 .expect("ToolServer::tool() called after run()")
68 .get_mut()
69 .add_tool(tool);
70 self.static_tool_names.push(toolname);
71 self
72 }
73
74 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
76 #[cfg(feature = "rmcp")]
77 pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
78 use crate::tool::rmcp::McpTool;
79 let toolname = tool.name.clone();
80 Arc::get_mut(&mut self.toolset)
84 .expect("ToolServer::rmcp_tool() called after run()")
85 .get_mut()
86 .add_tool(McpTool::from_mcp_server(tool, client));
87 self.static_tool_names.push(toolname.to_string());
88 self
89 }
90
91 pub fn dynamic_tools(
94 mut self,
95 sample: usize,
96 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
97 toolset: ToolSet,
98 ) -> Self {
99 self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
100 Arc::get_mut(&mut self.toolset)
104 .expect("ToolServer::dynamic_tools() called after run()")
105 .get_mut()
106 .add_tools(toolset);
107 self
108 }
109
110 pub fn run(mut self) -> ToolServerHandle {
111 let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
112
113 #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
114 tokio::spawn(async move {
115 while let Some(message) = rx.recv().await {
116 self.handle_message(message).await;
117 }
118 });
119
120 #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
123 wasm_bindgen_futures::spawn_local(async move {
124 while let Some(message) = rx.recv().await {
125 self.handle_message(message).await;
126 }
127 });
128
129 ToolServerHandle(tx)
130 }
131
132 pub async fn handle_message(&mut self, message: ToolServerRequest) {
133 let ToolServerRequest {
134 callback_channel,
135 data,
136 } = message;
137
138 match data {
139 ToolServerRequestMessageKind::AddTool(tool) => {
140 self.static_tool_names.push(tool.name());
141 self.toolset.write().await.add_tool_boxed(tool);
142 callback_channel
143 .send(ToolServerResponse::ToolAdded)
144 .unwrap();
145 }
146 ToolServerRequestMessageKind::AppendToolset(tools) => {
147 self.toolset.write().await.add_tools(tools);
148 callback_channel
149 .send(ToolServerResponse::ToolAdded)
150 .unwrap();
151 }
152 ToolServerRequestMessageKind::RemoveTool { tool_name } => {
153 self.static_tool_names.retain(|x| *x != tool_name);
154 self.toolset.write().await.delete_tool(&tool_name);
155 callback_channel
156 .send(ToolServerResponse::ToolDeleted)
157 .unwrap();
158 }
159 ToolServerRequestMessageKind::CallTool { name, args, span } => {
160 let toolset = Arc::clone(&self.toolset);
161
162 #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
163 tokio::spawn(
164 async move {
165 match toolset.read().await.call(&name, args.clone()).await {
166 Ok(result) => {
167 let _ = callback_channel
168 .send(ToolServerResponse::ToolExecuted { result });
169 }
170 Err(err) => {
171 let _ = callback_channel.send(ToolServerResponse::ToolError {
172 error: err.to_string(),
173 });
174 }
175 }
176 }
177 .instrument(span),
178 );
179
180 #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
181 wasm_bindgen_futures::spawn_local(
182 async move {
183 match toolset.read().await.call(&name, args.clone()).await {
184 Ok(result) => {
185 let _ = callback_channel
186 .send(ToolServerResponse::ToolExecuted { result });
187 }
188 Err(err) => {
189 let _ = callback_channel.send(ToolServerResponse::ToolError {
190 error: err.to_string(),
191 });
192 }
193 }
194 }
195 .instrument(span),
196 );
197 }
198 ToolServerRequestMessageKind::GetToolDefs { prompt } => {
199 let res = self.get_tool_definitions(prompt).await.unwrap();
200 callback_channel
201 .send(ToolServerResponse::ToolDefinitions(res))
202 .unwrap();
203 }
204 }
205 }
206
207 pub async fn get_tool_definitions(
208 &mut self,
209 text: Option<String>,
210 ) -> Result<Vec<ToolDefinition>, CompletionError> {
211 let static_tool_names = self.static_tool_names.clone();
212 let toolset = self.toolset.read().await;
213
214 let mut tools = if let Some(text) = text {
215 let dynamic_tool_ids: Vec<String> = stream::iter(self.dynamic_tools.iter())
217 .then(|(num_sample, index)| async {
218 let req = VectorSearchRequest::builder()
219 .query(text.clone())
220 .samples(*num_sample as u64)
221 .build()
222 .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
223 Ok::<_, VectorStoreError>(
224 index
225 .as_ref()
226 .top_n_ids(req.map_filter(Filter::interpret))
227 .await?
228 .into_iter()
229 .map(|(_, id)| id)
230 .collect::<Vec<String>>(),
231 )
232 })
233 .try_fold(vec![], |mut acc, docs| async {
234 acc.extend(docs);
235 Ok(acc)
236 })
237 .await
238 .map_err(|e| CompletionError::RequestError(Box::new(e)))?;
239
240 let mut tools = Vec::new();
242 for doc in dynamic_tool_ids {
243 if let Some(tool) = toolset.get(&doc) {
244 tools.push(tool.definition(text.clone()).await)
245 } else {
246 tracing::warn!("Tool implementation not found in toolset: {}", doc);
247 }
248 }
249 tools
250 } else {
251 Vec::new()
252 };
253
254 for toolname in static_tool_names {
255 if let Some(tool) = toolset.get(&toolname) {
256 tools.push(tool.definition(String::new()).await)
257 } else {
258 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
259 }
260 }
261
262 Ok(tools)
263 }
264}
265
266#[derive(Clone)]
267pub struct ToolServerHandle(Sender<ToolServerRequest>);
268
269impl ToolServerHandle {
270 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
271 let tool = Box::new(tool);
272
273 let (tx, rx) = futures::channel::oneshot::channel();
274
275 self.0
276 .send(ToolServerRequest {
277 callback_channel: tx,
278 data: ToolServerRequestMessageKind::AddTool(tool),
279 })
280 .await?;
281
282 let res = rx.await?;
283
284 let ToolServerResponse::ToolAdded = res else {
285 return Err(ToolServerError::InvalidMessage(res));
286 };
287
288 Ok(())
289 }
290
291 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
292 let (tx, rx) = futures::channel::oneshot::channel();
293
294 self.0
295 .send(ToolServerRequest {
296 callback_channel: tx,
297 data: ToolServerRequestMessageKind::AppendToolset(toolset),
298 })
299 .await?;
300
301 let res = rx.await?;
302
303 let ToolServerResponse::ToolAdded = res else {
304 return Err(ToolServerError::InvalidMessage(res));
305 };
306
307 Ok(())
308 }
309
310 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
311 let (tx, rx) = futures::channel::oneshot::channel();
312
313 self.0
314 .send(ToolServerRequest {
315 callback_channel: tx,
316 data: ToolServerRequestMessageKind::RemoveTool {
317 tool_name: tool_name.to_string(),
318 },
319 })
320 .await?;
321
322 let res = rx.await?;
323
324 let ToolServerResponse::ToolDeleted = res else {
325 return Err(ToolServerError::InvalidMessage(res));
326 };
327
328 Ok(())
329 }
330
331 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
332 let (tx, rx) = futures::channel::oneshot::channel();
333
334 self.0
335 .send(ToolServerRequest {
336 callback_channel: tx,
337 data: ToolServerRequestMessageKind::CallTool {
338 name: tool_name.to_string(),
339 args: args.to_string(),
340 span: tracing::Span::current(),
341 },
342 })
343 .await?;
344
345 let res = rx.await?;
346
347 match res {
348 ToolServerResponse::ToolExecuted { result, .. } => Ok(result),
349 ToolServerResponse::ToolError { error } => Err(ToolServerError::ToolsetError(
350 ToolSetError::ToolCallError(ToolError::ToolCallError(error.into())),
351 )),
352 invalid => Err(ToolServerError::InvalidMessage(invalid)),
353 }
354 }
355
356 pub async fn get_tool_defs(
357 &self,
358 prompt: Option<String>,
359 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
360 let (tx, rx) = futures::channel::oneshot::channel();
361
362 self.0
363 .send(ToolServerRequest {
364 callback_channel: tx,
365 data: ToolServerRequestMessageKind::GetToolDefs { prompt },
366 })
367 .await?;
368
369 let res = rx.await?;
370
371 let ToolServerResponse::ToolDefinitions(tooldefs) = res else {
372 return Err(ToolServerError::InvalidMessage(res));
373 };
374
375 Ok(tooldefs)
376 }
377}
378
379pub struct ToolServerRequest {
380 callback_channel: futures::channel::oneshot::Sender<ToolServerResponse>,
381 data: ToolServerRequestMessageKind,
382}
383
384pub enum ToolServerRequestMessageKind {
385 AddTool(Box<dyn ToolDyn>),
386 AppendToolset(ToolSet),
387 RemoveTool {
388 tool_name: String,
389 },
390 CallTool {
391 name: String,
392 args: String,
393 span: tracing::Span,
394 },
395 GetToolDefs {
396 prompt: Option<String>,
397 },
398}
399
400#[derive(PartialEq, Debug)]
401pub enum ToolServerResponse {
402 ToolAdded,
403 ToolDeleted,
404 ToolExecuted { result: String },
405 ToolError { error: String },
406 ToolDefinitions(Vec<ToolDefinition>),
407}
408
409#[derive(Debug, thiserror::Error)]
410pub enum ToolServerError {
411 #[error("Sending message was cancelled")]
412 Canceled(#[from] Canceled),
413 #[error("Toolset error: {0}")]
414 ToolsetError(#[from] ToolSetError),
415 #[error("Error while sending message: {0}")]
416 SendError(#[from] SendError<ToolServerRequest>),
417 #[error("An invalid message type was returned")]
418 InvalidMessage(ToolServerResponse),
419}
420
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, ToolSet, server::ToolServer},
        vector_store::{
            VectorStoreError, VectorStoreIndex,
            request::{Filter, VectorSearchRequest},
        },
        wasm_compat::WasmCompatSend,
    };

    /// Arguments shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool that adds two numbers.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool that subtracts one number from another.
    #[derive(Deserialize, Serialize)]
    struct Subtractor;
    impl Tool for Subtractor {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "subtract".to_string(),
                description: "Subtract y from x".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"],
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x - args.y;
            Ok(result)
        }
    }

    /// Vector index stub that answers every id lookup with a fixed list of tool ids.
    struct MockToolIndex {
        tool_ids: Vec<String>,
    }

    impl VectorStoreIndex for MockToolIndex {
        type Filter = Filter<serde_json::Value>;

        async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
            Ok(vec![])
        }

        async fn top_n_ids(
            &self,
            _req: VectorSearchRequest,
        ) -> Result<Vec<(f64, String)>, VectorStoreError> {
            // Scores decrease with the position of the id in `tool_ids`.
            Ok(self
                .tool_ids
                .iter()
                .enumerate()
                .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
                .collect())
        }
    }

    /// Adding, calling and removing a tool through the handle.
    #[tokio::test]
    pub async fn test_toolserver() {
        let server = ToolServer::new();

        let handle = server.run();

        handle.add_tool(Adder).await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 1);

        let json_args_as_string =
            serde_json::to_string(&serde_json::json!({"x": 2, "y": 5})).unwrap();
        let res = handle.call_tool("add", &json_args_as_string).await.unwrap();
        assert_eq!(res, "7");

        handle.remove_tool("add").await.unwrap();
        let res = handle.get_tool_defs(None).await.unwrap();

        assert_eq!(res.len(), 0);
    }

    /// Dynamic tools are only listed when a prompt is supplied.
    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools() {
        let mock_index = MockToolIndex {
            tool_ids: vec!["subtract".to_string()],
        };

        let server = ToolServer::new().tool(Adder).dynamic_tools(
            1,
            mock_index,
            ToolSet::from_tools(vec![Subtractor]),
        );

        let handle = server.run();

        // Without a prompt only the static tool is returned.
        let res = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");

        // With a prompt the dynamic tool is returned as well.
        let res = handle
            .get_tool_defs(Some("calculate difference".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 2);

        let tool_names: Vec<&str> = res.iter().map(|t| t.name.as_str()).collect();
        assert!(tool_names.contains(&"add"));
        assert!(tool_names.contains(&"subtract"));
    }

    /// An id returned by the index without a matching implementation is skipped.
    #[tokio::test]
    pub async fn test_toolserver_dynamic_tools_missing_implementation() {
        let mock_index = MockToolIndex {
            tool_ids: vec!["nonexistent_tool".to_string()],
        };

        let server = ToolServer::new()
            .tool(Adder)
            .dynamic_tools(1, mock_index, ToolSet::default());

        let handle = server.run();

        let res = handle
            .get_tool_defs(Some("some query".to_string()))
            .await
            .unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].name, "add");
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Sleeper error")]
    struct SleeperError;

    /// Test tool that sleeps for a configured duration before returning it.
    #[derive(Deserialize, Serialize, Clone)]
    struct SleeperTool {
        sleep_duration_ms: u64,
    }

    impl SleeperTool {
        fn new(sleep_duration_ms: u64) -> Self {
            Self { sleep_duration_ms }
        }
    }

    impl Tool for SleeperTool {
        const NAME: &'static str = "sleeper";
        type Error = SleeperError;
        type Args = serde_json::Value;
        type Output = u64;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "sleeper".to_string(),
                description: "Sleeps for configured duration".to_string(),
                parameters: json!({"type": "object", "properties": {}}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            tokio::time::sleep(Duration::from_millis(self.sleep_duration_ms)).await;
            Ok(self.sleep_duration_ms)
        }
    }

    /// Several tool calls issued together must overlap instead of running back to back.
    #[tokio::test]
    pub async fn test_toolserver_concurrent_tool_execution() {
        let sleep_ms: u64 = 100;
        let num_calls: u64 = 3;

        let server = ToolServer::new().tool(SleeperTool::new(sleep_ms));
        let handle = server.run();

        let start = std::time::Instant::now();

        let pending_calls: Vec<_> = (0..num_calls)
            .map(|_| handle.call_tool("sleeper", "{}"))
            .collect();
        let results = futures::future::join_all(pending_calls).await;

        let elapsed = start.elapsed();

        for result in &results {
            assert!(result.is_ok(), "Tool call failed: {:?}", result);
        }

        // Sequential execution would take at least `num_calls * sleep_ms`.
        let max_concurrent_time = Duration::from_millis(sleep_ms * 2);
        assert!(
            elapsed < max_concurrent_time,
            "Expected concurrent execution in < {:?}, but took {:?}. Tools may be running sequentially.",
            max_concurrent_time,
            elapsed
        );
    }
}