1use futures::{StreamExt, TryStreamExt, channel::oneshot::Canceled, stream};
2use tokio::sync::mpsc::{Sender, error::SendError};
3
4use crate::{
5 completion::{CompletionError, ToolDefinition},
6 tool::{Tool, ToolDyn, ToolError, ToolSet, ToolSetError},
7 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn, request::Filter},
8};
9
10pub struct ToolServer {
11 static_tool_names: Vec<String>,
14 dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
16 toolset: ToolSet,
18}
19
20impl Default for ToolServer {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl ToolServer {
27 pub fn new() -> Self {
28 Self {
29 static_tool_names: Vec::new(),
30 dynamic_tools: Vec::new(),
31 toolset: ToolSet::default(),
32 }
33 }
34
35 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
36 self.static_tool_names = names;
37 self
38 }
39
40 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
41 self.toolset = tools;
42 self
43 }
44
45 pub(crate) fn add_dynamic_tools(
46 mut self,
47 dyn_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
48 ) -> Self {
49 self.dynamic_tools = dyn_tools;
50 self
51 }
52
53 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
55 let toolname = tool.name();
56 self.toolset.add_tool(tool);
57 self.static_tool_names.push(toolname);
58 self
59 }
60
61 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
63 #[cfg(feature = "rmcp")]
64 pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
65 use crate::tool::rmcp::McpTool;
66 let toolname = tool.name.clone();
67 self.toolset
68 .add_tool(McpTool::from_mcp_server(tool, client));
69 self.static_tool_names.push(toolname.to_string());
70 self
71 }
72
73 pub fn dynamic_tools(
76 mut self,
77 sample: usize,
78 dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
79 toolset: ToolSet,
80 ) -> Self {
81 self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
82 self.toolset.add_tools(toolset);
83 self
84 }
85
86 pub fn run(mut self) -> ToolServerHandle {
87 let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
88
89 #[cfg(not(target_family = "wasm"))]
90 tokio::spawn(async move {
91 while let Some(message) = rx.recv().await {
92 self.handle_message(message).await;
93 }
94 });
95
96 #[cfg(all(feature = "worker", target_family = "wasm"))]
99 wasm_bindgen_futures::spawn_local(async move {
100 while let Some(message) = rx.recv().await {
101 self.handle_message(message).await;
102 }
103 });
104
105 ToolServerHandle(tx)
106 }
107
108 pub async fn handle_message(&mut self, message: ToolServerRequest) {
109 let ToolServerRequest {
110 callback_channel,
111 data,
112 } = message;
113
114 match data {
115 ToolServerRequestMessageKind::AddTool(tool) => {
116 self.static_tool_names.push(tool.name());
117 self.toolset.add_tool_boxed(tool);
118 callback_channel
119 .send(ToolServerResponse::ToolAdded)
120 .unwrap();
121 }
122 ToolServerRequestMessageKind::AppendToolset(tools) => {
123 self.toolset.add_tools(tools);
124 callback_channel
125 .send(ToolServerResponse::ToolAdded)
126 .unwrap();
127 }
128 ToolServerRequestMessageKind::RemoveTool { tool_name } => {
129 self.static_tool_names.retain(|x| *x != tool_name);
130 self.toolset.delete_tool(&tool_name);
131 callback_channel
132 .send(ToolServerResponse::ToolDeleted)
133 .unwrap();
134 }
135 ToolServerRequestMessageKind::CallTool { name, args } => {
136 match self.toolset.call(&name, args.clone()).await {
137 Ok(result) => {
138 let _ = callback_channel.send(ToolServerResponse::ToolExecuted { result });
139 }
140 Err(err) => {
141 let _ = callback_channel.send(ToolServerResponse::ToolError {
142 error: err.to_string(),
143 });
144 }
145 }
146 }
147 ToolServerRequestMessageKind::GetToolDefs { prompt } => {
148 let res = self.get_tool_definitions(prompt).await.unwrap();
149 callback_channel
150 .send(ToolServerResponse::ToolDefinitions(res))
151 .unwrap();
152 }
153 }
154 }
155
156 pub async fn get_tool_definitions(
157 &mut self,
158 text: Option<String>,
159 ) -> Result<Vec<ToolDefinition>, CompletionError> {
160 let static_tool_names = self.static_tool_names.clone();
161 let mut tools = if let Some(text) = text {
162 stream::iter(self.dynamic_tools.iter())
163 .then(|(num_sample, index)| async {
164 let req =
165 VectorSearchRequest::builder()
166 .query(text.clone())
167 .samples(*num_sample as u64)
168 .build()
169 .expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
170 Ok::<_, VectorStoreError>(
171 index.as_ref()
172 .top_n_ids(req.map_filter(Filter::interpret))
173 .await?
174 .into_iter()
175 .map(|(_, id)| id)
176 .collect::<Vec<String>>(),
177 )
178 })
179 .try_fold(vec![], |mut acc, docs| async {
180 for doc in docs {
181 if let Some(tool) = self.toolset.get(&doc) {
182 acc.push(tool.definition(text.clone()).await)
183 } else {
184 tracing::warn!("Tool implementation not found in toolset: {}", doc);
185 }
186 }
187 Ok(acc)
188 })
189 .await
190 .map_err(|e| CompletionError::RequestError(Box::new(e)))?
191 } else {
192 Vec::new()
193 };
194
195 for toolname in static_tool_names {
196 if let Some(tool) = self.toolset.get(&toolname) {
197 tools.push(tool.definition(String::new()).await)
198 } else {
199 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
200 }
201 }
202
203 Ok(tools)
204 }
205}
206
207#[derive(Clone)]
208pub struct ToolServerHandle(Sender<ToolServerRequest>);
209
210impl ToolServerHandle {
211 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
212 let tool = Box::new(tool);
213
214 let (tx, rx) = futures::channel::oneshot::channel();
215
216 self.0
217 .send(ToolServerRequest {
218 callback_channel: tx,
219 data: ToolServerRequestMessageKind::AddTool(tool),
220 })
221 .await?;
222
223 let res = rx.await?;
224
225 let ToolServerResponse::ToolAdded = res else {
226 return Err(ToolServerError::InvalidMessage(res));
227 };
228
229 Ok(())
230 }
231
232 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
233 let (tx, rx) = futures::channel::oneshot::channel();
234
235 self.0
236 .send(ToolServerRequest {
237 callback_channel: tx,
238 data: ToolServerRequestMessageKind::AppendToolset(toolset),
239 })
240 .await?;
241
242 let res = rx.await?;
243
244 let ToolServerResponse::ToolAdded = res else {
245 return Err(ToolServerError::InvalidMessage(res));
246 };
247
248 Ok(())
249 }
250
251 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
252 let (tx, rx) = futures::channel::oneshot::channel();
253
254 self.0
255 .send(ToolServerRequest {
256 callback_channel: tx,
257 data: ToolServerRequestMessageKind::RemoveTool {
258 tool_name: tool_name.to_string(),
259 },
260 })
261 .await?;
262
263 let res = rx.await?;
264
265 let ToolServerResponse::ToolDeleted = res else {
266 return Err(ToolServerError::InvalidMessage(res));
267 };
268
269 Ok(())
270 }
271
272 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
273 let (tx, rx) = futures::channel::oneshot::channel();
274
275 self.0
276 .send(ToolServerRequest {
277 callback_channel: tx,
278 data: ToolServerRequestMessageKind::CallTool {
279 name: tool_name.to_string(),
280 args: args.to_string(),
281 },
282 })
283 .await?;
284
285 let res = rx.await?;
286
287 match res {
288 ToolServerResponse::ToolExecuted { result, .. } => Ok(result),
289 ToolServerResponse::ToolError { error } => Err(ToolServerError::ToolsetError(
290 ToolSetError::ToolCallError(ToolError::ToolCallError(error.into())),
291 )),
292 invalid => Err(ToolServerError::InvalidMessage(invalid)),
293 }
294 }
295
296 pub async fn get_tool_defs(
297 &self,
298 prompt: Option<String>,
299 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
300 let (tx, rx) = futures::channel::oneshot::channel();
301
302 self.0
303 .send(ToolServerRequest {
304 callback_channel: tx,
305 data: ToolServerRequestMessageKind::GetToolDefs { prompt },
306 })
307 .await?;
308
309 let res = rx.await?;
310
311 let ToolServerResponse::ToolDefinitions(tooldefs) = res else {
312 return Err(ToolServerError::InvalidMessage(res));
313 };
314
315 Ok(tooldefs)
316 }
317}
318
319pub struct ToolServerRequest {
320 callback_channel: futures::channel::oneshot::Sender<ToolServerResponse>,
321 data: ToolServerRequestMessageKind,
322}
323
324pub enum ToolServerRequestMessageKind {
325 AddTool(Box<dyn ToolDyn>),
326 AppendToolset(ToolSet),
327 RemoveTool { tool_name: String },
328 CallTool { name: String, args: String },
329 GetToolDefs { prompt: Option<String> },
330}
331
332#[derive(PartialEq, Debug)]
333pub enum ToolServerResponse {
334 ToolAdded,
335 ToolDeleted,
336 ToolExecuted { result: String },
337 ToolError { error: String },
338 ToolDefinitions(Vec<ToolDefinition>),
339}
340
341#[derive(Debug, thiserror::Error)]
342pub enum ToolServerError {
343 #[error("Sending message was cancelled")]
344 Canceled(#[from] Canceled),
345 #[error("Toolset error: {0}")]
346 ToolsetError(#[from] ToolSetError),
347 #[error("Error while sending message: {0}")]
348 SendError(#[from] SendError<ToolServerRequest>),
349 #[error("An invalid message type was returned")]
350 InvalidMessage(ToolServerResponse),
351}
352
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, server::ToolServer},
    };

    /// Operands for [`Adder`].
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    /// Error type for [`Adder`]; `call` below never produces it.
    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Minimal tool that adds two integers.
    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            let parameters = json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                },
                "required": ["x", "y"]
            });

            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters,
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let OperationArgs { x, y } = args;
            println!("[tool-call] Adding {} and {}", x, y);
            Ok(x + y)
        }
    }

    /// Adds, lists, calls and removes a tool through a running server handle.
    #[tokio::test]
    pub async fn test_toolserver() {
        let handle = ToolServer::new().run();

        handle.add_tool(Adder).await.unwrap();
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);

        let args = json!({ "x": 2, "y": 5 }).to_string();
        let output = handle.call_tool("add", &args).await.unwrap();
        assert_eq!(output, "7");

        handle.remove_tool("add").await.unwrap();
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert!(defs.is_empty());
    }
}