1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15
16use futures::Future;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20 completion::{self, ToolDefinition},
21 embeddings::{embed::EmbedError, tool::ToolSchema},
22 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
23};
24
/// Error produced while invoking a single tool through [`ToolDyn::call`].
///
/// No variant carries an `#[error(...)]` attribute; `Display` is implemented
/// by hand for this type instead.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool's own `call` implementation returned an error, boxed here.
    /// On non-wasm targets the boxed error must also be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// The tool's own `call` implementation returned an error, boxed here.
    /// On wasm targets no `Send + Sync` bound is imposed.
    #[cfg(target_family = "wasm")]
    ToolCallError(#[from] Box<dyn std::error::Error>),
    /// The JSON arguments could not be deserialized into the tool's `Args`,
    /// or the tool's `Output` could not be serialized back to JSON.
    JsonError(#[from] serde_json::Error),
}
37
38impl fmt::Display for ToolError {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 ToolError::ToolCallError(e) => {
42 let error_str = e.to_string();
43 if error_str.starts_with("ToolCallError: ") {
46 write!(f, "{}", error_str)
47 } else {
48 write!(f, "ToolCallError: {}", error_str)
49 }
50 }
51 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
52 }
53 }
54}
55
56pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
112 const NAME: &'static str;
114
115 type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
117 type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
119 type Output: Serialize;
121
122 fn name(&self) -> String {
124 Self::NAME.to_string()
125 }
126
127 fn definition(
130 &self,
131 _prompt: String,
132 ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
133
134 fn call(
138 &self,
139 args: Self::Args,
140 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
141}
142
/// A [`Tool`] that additionally exposes documents for embedding and a
/// serializable context from which it can be built via [`ToolEmbedding::init`].
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Serializable data returned by [`ToolEmbedding::context`] and consumed
    /// by [`ToolEmbedding::init`].
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Extra input for [`ToolEmbedding::init`]; unlike `Context`, it is not
    /// required to be serializable.
    type State: WasmCompatSend;

    /// Returns the documents associated with this tool
    /// (presumably the text embedded to retrieve it; confirm against `ToolSchema`).
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns this tool's context.
    fn context(&self) -> Self::Context;

    /// Builds the tool from `state` and `context`.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
169
/// Object-safe, type-erased form of [`Tool`], usable as `Box<dyn ToolDyn>`.
///
/// A blanket impl covers every [`Tool`]; in that impl `args` is parsed as JSON
/// and the tool's output is serialized back to a JSON string.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Returns the tool's name.
    fn name(&self) -> String;

    /// Returns a boxed future resolving to the tool's definition for `prompt`.
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Invokes the tool with string `args`, returning a boxed future that
    /// resolves to the string output or a [`ToolError`].
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
178
179impl<T: Tool> ToolDyn for T {
180 fn name(&self) -> String {
181 self.name()
182 }
183
184 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
185 Box::pin(<Self as Tool>::definition(self, prompt))
186 }
187
188 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
189 Box::pin(async move {
190 match serde_json::from_str(&args) {
191 Ok(args) => <Self as Tool>::call(self, args)
192 .await
193 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
194 .and_then(|output| {
195 serde_json::to_string(&output).map_err(ToolError::JsonError)
196 }),
197 Err(e) => Err(ToolError::JsonError(e)),
198 }
199 })
200 }
201}
202
203#[cfg(feature = "rmcp")]
204#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
205pub mod rmcp;
206
/// Object-safe counterpart of [`ToolEmbedding`], usable as
/// `Box<dyn ToolEmbeddingDyn>`; implemented for every `'static` [`ToolEmbedding`].
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context serialized to a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the tool's embedding documents.
    fn embedding_docs(&self) -> Vec<String>;
}
213
214impl<T> ToolEmbeddingDyn for T
215where
216 T: ToolEmbedding + 'static,
217{
218 fn context(&self) -> serde_json::Result<serde_json::Value> {
219 serde_json::to_value(self.context())
220 }
221
222 fn embedding_docs(&self) -> Vec<String> {
223 self.embedding_docs()
224 }
225}
226
/// Internal storage for a tool held by a [`ToolSet`].
pub(crate) enum ToolType {
    /// A plain tool, registered via `add_tool`, `add_tool_boxed`, or
    /// `ToolSetBuilder::static_tool`.
    Simple(Box<dyn ToolDyn>),
    /// A tool registered via `ToolSetBuilder::dynamic_tool`. It additionally
    /// exposes embedding docs and a JSON-serializable context.
    Embedding(Box<dyn ToolEmbeddingDyn>),
}
231
232impl ToolType {
233 pub fn name(&self) -> String {
234 match self {
235 ToolType::Simple(tool) => tool.name(),
236 ToolType::Embedding(tool) => tool.name(),
237 }
238 }
239
240 pub async fn definition(&self, prompt: String) -> ToolDefinition {
241 match self {
242 ToolType::Simple(tool) => tool.definition(prompt).await,
243 ToolType::Embedding(tool) => tool.definition(prompt).await,
244 }
245 }
246
247 pub async fn call(&self, args: String) -> Result<String, ToolError> {
248 match self {
249 ToolType::Simple(tool) => tool.call(args).await,
250 ToolType::Embedding(tool) => tool.call(args).await,
251 }
252 }
253}
254
/// Error returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// Invoking the tool failed; wraps the underlying [`ToolError`].
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool is registered under the requested name (which is carried here).
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// A JSON step performed by the toolset itself failed, e.g. serializing
    /// tool definitions in `ToolSet::documents`.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Not constructed in this file; presumably signals that a caller
    /// cancelled a tool call. Confirm against its users.
    #[error("Tool call interrupted")]
    Interrupted,
}
273
/// A collection of tools keyed by name.
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools, keyed by the name each tool reported when it was added.
    pub(crate) tools: HashMap<String, ToolType>,
}
279
280impl ToolSet {
281 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
283 let mut toolset = Self::default();
284 tools.into_iter().for_each(|tool| {
285 toolset.add_tool(tool);
286 });
287 toolset
288 }
289
290 pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
291 let mut toolset = Self::default();
292 tools.into_iter().for_each(|tool| {
293 toolset.add_tool_boxed(tool);
294 });
295 toolset
296 }
297
298 pub fn builder() -> ToolSetBuilder {
300 ToolSetBuilder::default()
301 }
302
303 pub fn contains(&self, toolname: &str) -> bool {
305 self.tools.contains_key(toolname)
306 }
307
308 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
310 self.tools
311 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
312 }
313
314 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
316 self.tools.insert(tool.name(), ToolType::Simple(tool));
317 }
318
319 pub fn delete_tool(&mut self, tool_name: &str) {
320 let _ = self.tools.remove(tool_name);
321 }
322
323 pub fn add_tools(&mut self, toolset: ToolSet) {
325 self.tools.extend(toolset.tools);
326 }
327
328 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
329 self.tools.get(toolname)
330 }
331
332 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
333 let mut defs = Vec::new();
334 for tool in self.tools.values() {
335 let def = tool.definition(String::new()).await;
336 defs.push(def);
337 }
338 Ok(defs)
339 }
340
341 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
343 if let Some(tool) = self.tools.get(toolname) {
344 tracing::debug!(target: "rig",
345 "Calling tool {toolname} with args:\n{}",
346 serde_json::to_string_pretty(&args).unwrap()
347 );
348 Ok(tool.call(args).await?)
349 } else {
350 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
351 }
352 }
353
354 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
356 let mut docs = Vec::new();
357 for tool in self.tools.values() {
358 match tool {
359 ToolType::Simple(tool) => {
360 docs.push(completion::Document {
361 id: tool.name(),
362 text: format!(
363 "\
364 Tool: {}\n\
365 Definition: \n\
366 {}\
367 ",
368 tool.name(),
369 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
370 ),
371 additional_props: HashMap::new(),
372 });
373 }
374 ToolType::Embedding(tool) => {
375 docs.push(completion::Document {
376 id: tool.name(),
377 text: format!(
378 "\
379 Tool: {}\n\
380 Definition: \n\
381 {}\
382 ",
383 tool.name(),
384 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
385 ),
386 additional_props: HashMap::new(),
387 });
388 }
389 }
390 }
391 Ok(docs)
392 }
393
394 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
398 self.tools
399 .values()
400 .filter_map(|tool_type| {
401 if let ToolType::Embedding(tool) = tool_type {
402 Some(ToolSchema::try_from(&**tool))
403 } else {
404 None
405 }
406 })
407 .collect::<Result<Vec<_>, _>>()
408 }
409}
410
/// Builder for [`ToolSet`] that collects tools and keys them by name in `build`.
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools in insertion order; if two share a name, the later one wins in `build`.
    tools: Vec<ToolType>,
}
415
416impl ToolSetBuilder {
417 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
418 self.tools.push(ToolType::Simple(Box::new(tool)));
419 self
420 }
421
422 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
423 self.tools.push(ToolType::Embedding(Box::new(tool)));
424 self
425 }
426
427 pub fn build(self) -> ToolSet {
428 ToolSet {
429 tools: self
430 .tools
431 .into_iter()
432 .map(|tool| (tool.name(), tool))
433 .collect(),
434 }
435 }
436}
437
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"]
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x + args.y;
            Ok(result)
        }
    }

    #[derive(Deserialize, Serialize)]
    struct Subtract;

    impl Tool for Subtract {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"]
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            let result = args.x - args.y;
            Ok(result)
        }
    }

    /// A tool whose `call` always fails, used to exercise error propagation.
    struct AlwaysFails;

    impl Tool for AlwaysFails {
        const NAME: &'static str = "always_fails";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "always_fails".to_string(),
                description: "Always returns an error".to_string(),
                parameters: json!({ "type": "object", "properties": {} }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Err(MathError)
        }
    }

    /// Toolset containing the `add` and `subtract` tools.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();
        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_call_dispatches_to_named_tool() {
        let toolset = get_test_toolset();
        let args = json!({ "x": 2, "y": 3 }).to_string();

        let sum = toolset.call("add", args.clone()).await.unwrap();
        assert_eq!(sum, "5");

        let diff = toolset.call("subtract", args).await.unwrap();
        assert_eq!(diff, "-1");
    }

    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        match err {
            ToolSetError::ToolNotFoundError(name) => assert_eq!(name, "multiply"),
            other => panic!("expected ToolNotFoundError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_call_invalid_args_is_json_error() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[tokio::test]
    async fn test_call_propagates_tool_error() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(AlwaysFails);
        let err = toolset
            .call("always_fails", json!({ "x": 1, "y": 1 }).to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::ToolCallError(_))
        ));
    }

    #[test]
    fn test_tool_error_display_does_not_repeat_prefix() {
        let direct = ToolError::ToolCallError(Box::new(MathError));
        assert_eq!(direct.to_string(), "ToolCallError: Math error");

        let nested =
            ToolError::ToolCallError(Box::new(ToolError::ToolCallError(Box::new(MathError))));
        assert_eq!(nested.to_string(), "ToolCallError: Math error");
    }

    #[test]
    fn test_builder_registers_tools_by_name() {
        let toolset = ToolSet::builder()
            .static_tool(Adder)
            .static_tool(Subtract)
            .build();
        assert!(toolset.contains("add"));
        assert!(toolset.contains("subtract"));
        assert_eq!(toolset.tools.len(), 2);
    }

    #[tokio::test]
    async fn test_documents_describe_each_tool() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);

        let add = docs
            .iter()
            .find(|doc| doc.id == "add")
            .expect("document for the add tool");
        assert!(add.text.starts_with("Tool: add\nDefinition: \n"));
    }
}
550}