1use futures::{StreamExt, TryStreamExt, channel::oneshot::Canceled, stream};
2use tokio::sync::mpsc::{Sender, error::SendError};
3
4use crate::{
5 completion::{CompletionError, ToolDefinition},
6 tool::{Tool, ToolDyn, ToolError, ToolSet, ToolSetError},
7 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndexDyn},
8};
9
10pub struct ToolServer {
11 static_tool_names: Vec<String>,
14 dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
16 toolset: ToolSet,
18}
19
20impl Default for ToolServer {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl ToolServer {
27 pub fn new() -> Self {
28 Self {
29 static_tool_names: Vec::new(),
30 dynamic_tools: Vec::new(),
31 toolset: ToolSet::default(),
32 }
33 }
34
35 pub(crate) fn static_tool_names(mut self, names: Vec<String>) -> Self {
36 self.static_tool_names = names;
37 self
38 }
39
40 pub(crate) fn add_tools(mut self, tools: ToolSet) -> Self {
41 self.toolset = tools;
42 self
43 }
44
45 pub(crate) fn add_dynamic_tools(
46 mut self,
47 dyn_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
48 ) -> Self {
49 self.dynamic_tools = dyn_tools;
50 self
51 }
52
53 pub fn tool(mut self, tool: impl Tool + 'static) -> Self {
55 let toolname = tool.name();
56 self.toolset.add_tool(tool);
57 self.static_tool_names.push(toolname);
58 self
59 }
60
61 #[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
63 #[cfg(feature = "rmcp")]
64 pub fn rmcp_tool(mut self, tool: rmcp::model::Tool, client: rmcp::service::ServerSink) -> Self {
65 use crate::tool::rmcp::McpTool;
66 let toolname = tool.name.clone();
67 self.toolset
68 .add_tool(McpTool::from_mcp_server(tool, client));
69 self.static_tool_names.push(toolname.to_string());
70 self
71 }
72
73 pub fn dynamic_tools(
76 mut self,
77 sample: usize,
78 dynamic_tools: impl VectorStoreIndexDyn + 'static,
79 toolset: ToolSet,
80 ) -> Self {
81 self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
82 self.toolset.add_tools(toolset);
83 self
84 }
85
86 pub fn run(mut self) -> ToolServerHandle {
87 let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
88
89 #[cfg(not(target_family = "wasm"))]
90 tokio::spawn(async move {
91 while let Some(message) = rx.recv().await {
92 self.handle_message(message).await;
93 }
94 });
95
96 #[cfg(target_family = "wasm")]
97 tokio::task::spawn_local(async move {
98 while let Some(message) = rx.recv().await {
99 self.handle_message(message).await;
100 }
101 });
102
103 ToolServerHandle(tx)
104 }
105
106 pub async fn handle_message(&mut self, message: ToolServerRequest) {
107 let ToolServerRequest {
108 callback_channel,
109 data,
110 } = message;
111
112 match data {
113 ToolServerRequestMessageKind::AddTool(tool) => {
114 self.static_tool_names.push(tool.name());
115 self.toolset.add_tool_boxed(tool);
116 callback_channel
117 .send(ToolServerResponse::ToolAdded)
118 .unwrap();
119 }
120 ToolServerRequestMessageKind::AppendToolset(tools) => {
121 self.toolset.add_tools(tools);
122 callback_channel
123 .send(ToolServerResponse::ToolAdded)
124 .unwrap();
125 }
126 ToolServerRequestMessageKind::RemoveTool { tool_name } => {
127 self.static_tool_names.retain(|x| *x != tool_name);
128 self.toolset.delete_tool(&tool_name);
129 callback_channel
130 .send(ToolServerResponse::ToolDeleted)
131 .unwrap();
132 }
133 ToolServerRequestMessageKind::CallTool { name, args } => {
134 match self.toolset.call(&name, args.clone()).await {
135 Ok(result) => {
136 let _ = callback_channel.send(ToolServerResponse::ToolExecuted { result });
137 }
138 Err(err) => {
139 let _ = callback_channel.send(ToolServerResponse::ToolError {
140 error: err.to_string(),
141 });
142 }
143 }
144 }
145 ToolServerRequestMessageKind::GetToolDefs { prompt } => {
146 let res = self.get_tool_definitions(prompt).await.unwrap();
147 callback_channel
148 .send(ToolServerResponse::ToolDefinitions(res))
149 .unwrap();
150 }
151 }
152 }
153
154 pub async fn get_tool_definitions(
155 &mut self,
156 text: Option<String>,
157 ) -> Result<Vec<ToolDefinition>, CompletionError> {
158 let static_tool_names = self.static_tool_names.clone();
159 let mut tools = if let Some(text) = text {
160 stream::iter(self.dynamic_tools.iter())
161 .then(|(num_sample, index)| async {
162 let req = VectorSearchRequest::builder().query(text.clone()).samples(*num_sample as u64).build().expect("Creating VectorSearchRequest here shouldn't fail since the query and samples to return are always present");
163 Ok::<_, VectorStoreError>(
164 index
165 .top_n_ids(req)
166 .await?
167 .into_iter()
168 .map(|(_, id)| id)
169 .collect::<Vec<String>>(),
170 )
171 })
172 .try_fold(vec![], |mut acc, docs| async {
173 for doc in docs {
174 if let Some(tool) = self.toolset.get(&doc) {
175 acc.push(tool.definition(text.clone()).await)
176 } else {
177 tracing::warn!("Tool implementation not found in toolset: {}", doc);
178 }
179 }
180 Ok(acc)
181 })
182 .await
183 .map_err(|e| CompletionError::RequestError(Box::new(e)))?
184 } else {
185 Vec::new()
186 };
187
188 for toolname in static_tool_names {
189 if let Some(tool) = self.toolset.get(&toolname) {
190 tools.push(tool.definition(String::new()).await)
191 } else {
192 tracing::warn!("Tool implementation not found in toolset: {}", toolname);
193 }
194 }
195
196 Ok(tools)
197 }
198}
199
200#[derive(Clone)]
201pub struct ToolServerHandle(Sender<ToolServerRequest>);
202
203impl ToolServerHandle {
204 pub async fn add_tool(&self, tool: impl ToolDyn + 'static) -> Result<(), ToolServerError> {
205 let tool = Box::new(tool);
206
207 let (tx, rx) = futures::channel::oneshot::channel();
208
209 self.0
210 .send(ToolServerRequest {
211 callback_channel: tx,
212 data: ToolServerRequestMessageKind::AddTool(tool),
213 })
214 .await?;
215
216 let res = rx.await?;
217
218 let ToolServerResponse::ToolAdded = res else {
219 return Err(ToolServerError::InvalidMessage(res));
220 };
221
222 Ok(())
223 }
224
225 pub async fn append_toolset(&self, toolset: ToolSet) -> Result<(), ToolServerError> {
226 let (tx, rx) = futures::channel::oneshot::channel();
227
228 self.0
229 .send(ToolServerRequest {
230 callback_channel: tx,
231 data: ToolServerRequestMessageKind::AppendToolset(toolset),
232 })
233 .await?;
234
235 let res = rx.await?;
236
237 let ToolServerResponse::ToolAdded = res else {
238 return Err(ToolServerError::InvalidMessage(res));
239 };
240
241 Ok(())
242 }
243
244 pub async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolServerError> {
245 let (tx, rx) = futures::channel::oneshot::channel();
246
247 self.0
248 .send(ToolServerRequest {
249 callback_channel: tx,
250 data: ToolServerRequestMessageKind::RemoveTool {
251 tool_name: tool_name.to_string(),
252 },
253 })
254 .await?;
255
256 let res = rx.await?;
257
258 let ToolServerResponse::ToolDeleted = res else {
259 return Err(ToolServerError::InvalidMessage(res));
260 };
261
262 Ok(())
263 }
264
265 pub async fn call_tool(&self, tool_name: &str, args: &str) -> Result<String, ToolServerError> {
266 let (tx, rx) = futures::channel::oneshot::channel();
267
268 self.0
269 .send(ToolServerRequest {
270 callback_channel: tx,
271 data: ToolServerRequestMessageKind::CallTool {
272 name: tool_name.to_string(),
273 args: args.to_string(),
274 },
275 })
276 .await?;
277
278 let res = rx.await?;
279
280 match res {
281 ToolServerResponse::ToolExecuted { result, .. } => Ok(result),
282 ToolServerResponse::ToolError { error } => Err(ToolServerError::ToolsetError(
283 ToolSetError::ToolCallError(ToolError::ToolCallError(error.into())),
284 )),
285 invalid => Err(ToolServerError::InvalidMessage(invalid)),
286 }
287 }
288
289 pub async fn get_tool_defs(
290 &self,
291 prompt: Option<String>,
292 ) -> Result<Vec<ToolDefinition>, ToolServerError> {
293 let (tx, rx) = futures::channel::oneshot::channel();
294
295 self.0
296 .send(ToolServerRequest {
297 callback_channel: tx,
298 data: ToolServerRequestMessageKind::GetToolDefs { prompt },
299 })
300 .await?;
301
302 let res = rx.await?;
303
304 let ToolServerResponse::ToolDefinitions(tooldefs) = res else {
305 return Err(ToolServerError::InvalidMessage(res));
306 };
307
308 Ok(tooldefs)
309 }
310}
311
312pub struct ToolServerRequest {
313 callback_channel: futures::channel::oneshot::Sender<ToolServerResponse>,
314 data: ToolServerRequestMessageKind,
315}
316
317pub enum ToolServerRequestMessageKind {
318 AddTool(Box<dyn ToolDyn>),
319 AppendToolset(ToolSet),
320 RemoveTool { tool_name: String },
321 CallTool { name: String, args: String },
322 GetToolDefs { prompt: Option<String> },
323}
324
325#[derive(PartialEq, Debug)]
326pub enum ToolServerResponse {
327 ToolAdded,
328 ToolDeleted,
329 ToolExecuted { result: String },
330 ToolError { error: String },
331 ToolDefinitions(Vec<ToolDefinition>),
332}
333
334#[derive(Debug, thiserror::Error)]
335pub enum ToolServerError {
336 #[error("Sending message was cancelled")]
337 Canceled(#[from] Canceled),
338 #[error("Toolset error: {0}")]
339 ToolsetError(#[from] ToolSetError),
340 #[error("Error while sending message: {0}")]
341 SendError(#[from] SendError<ToolServerRequest>),
342 #[error("An invalid message type was returned")]
343 InvalidMessage(ToolServerResponse),
344}
345
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::{
        completion::ToolDefinition,
        tool::{Tool, server::ToolServer},
    };

    /// Arguments taken by the arithmetic test tool.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Minimal tool that adds its two integer arguments.
    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            let parameters = json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                },
                "required": ["x", "y"]
            });

            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Add x and y together".to_string(),
                parameters,
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let OperationArgs { x, y } = args;
            println!("[tool-call] Adding {} and {}", x, y);
            Ok(x + y)
        }
    }

    /// Round-trips add / list / call / remove / list through a running server.
    #[tokio::test]
    pub async fn test_toolserver() {
        let handle = ToolServer::new().run();

        handle.add_tool(Adder).await.unwrap();
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);

        let args = json!({ "x": 2, "y": 5 }).to_string();
        let sum = handle.call_tool("add", &args).await.unwrap();
        assert_eq!(sum, "7");

        handle.remove_tool("add").await.unwrap();
        let defs = handle.get_tool_defs(None).await.unwrap();
        assert!(defs.is_empty());
    }
}