1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15
16use futures::Future;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20 completion::{self, ToolDefinition},
21 embeddings::{embed::EmbedError, tool::ToolSchema},
22 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
23};
24
/// Error produced when invoking a single tool through [`ToolDyn::call`].
///
/// `Display` is implemented by hand (not through `#[error(...)]` attributes).
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool's own call failed; the underlying error is boxed so any error type fits.
    ///
    /// On non-wasm targets the boxed error must be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// The tool's own call failed; the underlying error is boxed so any error type fits.
    ///
    /// On wasm targets the `Send + Sync` bounds are not required.
    #[cfg(target_family = "wasm")]
    ToolCallError(#[from] Box<dyn std::error::Error>),
    /// The JSON arguments could not be deserialized, or the tool output could not be
    /// serialized to JSON.
    JsonError(#[from] serde_json::Error),
}
37
38impl fmt::Display for ToolError {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 ToolError::ToolCallError(e) => {
42 let error_str = e.to_string();
43 if error_str.starts_with("ToolCallError: ") {
46 write!(f, "{}", error_str)
47 } else {
48 write!(f, "ToolCallError: {}", error_str)
49 }
50 }
51 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
52 }
53 }
54}
55
/// A strongly-typed tool that can be invoked with structured arguments.
///
/// Implementors declare their argument and output types. The blanket [`ToolDyn`]
/// impl in this module handles JSON (de)serialization, so any `Tool` can be stored
/// and called dynamically (e.g. inside a `ToolSet`).
pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
    /// Tool name; returned by the default [`Tool::name`].
    const NAME: &'static str;

    /// Error type returned by [`Tool::call`].
    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
    /// Argument type, deserialized from the JSON string given to [`ToolDyn::call`].
    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
    /// Output type, serialized to a JSON string by [`ToolDyn::call`].
    type Output: Serialize;

    /// Returns the tool's name. Defaults to [`Tool::NAME`].
    fn name(&self) -> String {
        Self::NAME.to_string()
    }

    /// Returns the tool's definition (name, description and JSON parameter schema).
    ///
    /// `_prompt` is forwarded by callers; the `ToolSet` helpers in this module pass
    /// an empty string.
    fn definition(
        &self,
        _prompt: String,
    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;

    /// Executes the tool with already-deserialized arguments.
    fn call(
        &self,
        args: Self::Args,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
}
142
/// A [`Tool`] that also exposes documents for embedding and can be rebuilt from a
/// serializable context plus runtime state.
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Serializable data from which the tool can be rebuilt via [`ToolEmbedding::init`].
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Runtime state supplied to [`ToolEmbedding::init`]; it carries no serialization bound.
    type State: WasmCompatSend;

    /// Returns the documents to embed for this tool.
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns this tool's context (serialized to JSON by [`ToolEmbeddingDyn::context`]).
    fn context(&self) -> Self::Context;

    /// Rebuilds the tool from runtime `state` and a previously saved `context`.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
169
/// Object-safe, string-in/string-out view of a tool, so heterogeneous tools can be
/// stored as `Box<dyn ToolDyn>`.
///
/// Every [`Tool`] implements this automatically through a blanket impl.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Name of the tool; `ToolSet` uses it as the lookup key.
    fn name(&self) -> String;

    /// Returns the tool definition; `prompt` is forwarded to the implementation.
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Calls the tool with JSON-encoded `args` and returns its output as a string
    /// (JSON-encoded for [`Tool`] implementors).
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
178
179impl<T: Tool> ToolDyn for T {
180 fn name(&self) -> String {
181 self.name()
182 }
183
184 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
185 Box::pin(<Self as Tool>::definition(self, prompt))
186 }
187
188 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
189 Box::pin(async move {
190 match serde_json::from_str(&args) {
191 Ok(args) => <Self as Tool>::call(self, args)
192 .await
193 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
194 .and_then(|output| {
195 serde_json::to_string(&output).map_err(ToolError::JsonError)
196 }),
197 Err(e) => Err(ToolError::JsonError(e)),
198 }
199 })
200 }
201}
202
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapter exposing tools advertised by an MCP server (through the `rmcp` crate)
    //! as [`ToolDyn`] implementations.

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use crate::wasm_compat::WasmBoxedFuture;
    use rmcp::model::RawContent;

    /// A tool hosted on an MCP server; calls are forwarded through `client`.
    #[derive(Clone)]
    pub struct McpTool {
        /// Tool metadata (name, description, input schema) advertised by the server.
        definition: rmcp::model::Tool,
        /// Handle used to send `call_tool` requests to the server.
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wraps a server-advertised tool `definition` together with the `client`
        /// used to invoke it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        /// Converts MCP tool metadata. A missing description becomes `""`.
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                description: val.description.as_deref().unwrap_or("").to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            // Delegate to the borrowing impl so the two conversions cannot drift apart.
            Self::from(&val)
        }
    }

    /// Error raised when an MCP tool call fails.
    ///
    /// This covers both a failed `call_tool` request and a result the server flagged
    /// as an error.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Parses the JSON argument string for an MCP call.
    ///
    /// An empty or whitespace-only string (and JSON `null`) means "no arguments".
    /// Anything else must be a JSON object. Malformed input is reported as
    /// [`ToolError::JsonError`] rather than being silently dropped.
    fn parse_arguments(
        args: &str,
    ) -> Result<Option<serde_json::Map<String, serde_json::Value>>, ToolError> {
        if args.trim().is_empty() {
            return Ok(None);
        }
        serde_json::from_str(args).map_err(ToolError::JsonError)
    }

    /// Renders one piece of MCP result content as text.
    ///
    /// Content kinds this adapter cannot represent (audio, or anything unknown) are
    /// returned as errors instead of panicking.
    fn render_content(content: RawContent) -> Result<String, ToolError> {
        match content {
            RawContent::Text(raw) => Ok(raw.text),
            RawContent::Image(raw) => Ok(format!("data:{};base64,{}", raw.mime_type, raw.data)),
            RawContent::Resource(raw) => Ok(match raw.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            }),
            RawContent::Audio(_) => Err(McpToolError(
                "Support for audio results from an MCP tool is currently unimplemented"
                    .to_string(),
            )
            .into()),
            other => Err(McpToolError(format!("Unsupported type found: {other:?}")).into()),
        }
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            // Reuse the `From<&rmcp::model::Tool>` conversion so every path builds the
            // definition the same way.
            let definition = ToolDefinition::from(&self.definition);
            Box::pin(async move { definition })
        }

        fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            let name = self.definition.name.clone();

            Box::pin(async move {
                let arguments = parse_arguments(&args)?;

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam {
                        name,
                        arguments,
                        task: None,
                    })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error == Some(true) {
                    // Surface every text part of the error; other content kinds carry no
                    // usable message here.
                    let texts: Vec<String> = result
                        .content
                        .into_iter()
                        .filter_map(|c| match c.raw {
                            RawContent::Text(raw) => Some(raw.text),
                            _ => None,
                        })
                        .collect();
                    let message = if texts.is_empty() {
                        "No message returned".to_string()
                    } else {
                        texts.join("\n")
                    };
                    return Err(McpToolError(message).into());
                }

                result
                    .content
                    .into_iter()
                    .map(|c| render_content(c.raw))
                    .collect::<Result<String, ToolError>>()
            })
        }
    }
}
356
/// Object-safe counterpart of [`ToolEmbedding`].
///
/// It is implemented for every `ToolEmbedding + 'static` type through a blanket impl.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context serialized as a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the documents to embed for this tool.
    fn embedding_docs(&self) -> Vec<String>;
}
363
364impl<T> ToolEmbeddingDyn for T
365where
366 T: ToolEmbedding + 'static,
367{
368 fn context(&self) -> serde_json::Result<serde_json::Value> {
369 serde_json::to_value(self.context())
370 }
371
372 fn embedding_docs(&self) -> Vec<String> {
373 self.embedding_docs()
374 }
375}
376
/// A tool stored in a [`ToolSet`]: either a plain tool or one that also supports
/// embedding (see [`ToolEmbeddingDyn`]).
pub(crate) enum ToolType {
    /// Plain tool; in this file created by `ToolSet::add_tool`, `add_tool_boxed`,
    /// `from_tools` and `ToolSetBuilder::static_tool`.
    Simple(Box<dyn ToolDyn>),
    /// Embedding-capable tool; in this file created by `ToolSetBuilder::dynamic_tool`.
    Embedding(Box<dyn ToolEmbeddingDyn>),
}
381
382impl ToolType {
383 pub fn name(&self) -> String {
384 match self {
385 ToolType::Simple(tool) => tool.name(),
386 ToolType::Embedding(tool) => tool.name(),
387 }
388 }
389
390 pub async fn definition(&self, prompt: String) -> ToolDefinition {
391 match self {
392 ToolType::Simple(tool) => tool.definition(prompt).await,
393 ToolType::Embedding(tool) => tool.definition(prompt).await,
394 }
395 }
396
397 pub async fn call(&self, args: String) -> Result<String, ToolError> {
398 match self {
399 ToolType::Simple(tool) => tool.call(args).await,
400 ToolType::Embedding(tool) => tool.call(args).await,
401 }
402 }
403}
404
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The tool itself failed (bad JSON arguments/output or the tool's own error).
    ///
    /// NOTE(review): `ToolError`'s `Display` already prefixes tool-call errors with
    /// `ToolCallError: `, so this message can show that prefix twice — confirm intended.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool with the given name is registered.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// JSON serialization failed outside a tool call (e.g. in `ToolSet::documents`).
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// The tool call was interrupted (not constructed in this file; see callers).
    #[error("Tool call interrupted")]
    Interrupted,
}
423
/// A collection of tools keyed by tool name.
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools keyed by their name. Inserting a tool whose name is already
    /// present replaces the earlier entry.
    pub(crate) tools: HashMap<String, ToolType>,
}
429
430impl ToolSet {
431 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
433 let mut toolset = Self::default();
434 tools.into_iter().for_each(|tool| {
435 toolset.add_tool(tool);
436 });
437 toolset
438 }
439
440 pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
441 let mut toolset = Self::default();
442 tools.into_iter().for_each(|tool| {
443 toolset.add_tool_boxed(tool);
444 });
445 toolset
446 }
447
448 pub fn builder() -> ToolSetBuilder {
450 ToolSetBuilder::default()
451 }
452
453 pub fn contains(&self, toolname: &str) -> bool {
455 self.tools.contains_key(toolname)
456 }
457
458 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
460 self.tools
461 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
462 }
463
464 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
466 self.tools.insert(tool.name(), ToolType::Simple(tool));
467 }
468
469 pub fn delete_tool(&mut self, tool_name: &str) {
470 let _ = self.tools.remove(tool_name);
471 }
472
473 pub fn add_tools(&mut self, toolset: ToolSet) {
475 self.tools.extend(toolset.tools);
476 }
477
478 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
479 self.tools.get(toolname)
480 }
481
482 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
483 let mut defs = Vec::new();
484 for tool in self.tools.values() {
485 let def = tool.definition(String::new()).await;
486 defs.push(def);
487 }
488 Ok(defs)
489 }
490
491 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
493 if let Some(tool) = self.tools.get(toolname) {
494 tracing::debug!(target: "rig",
495 "Calling tool {toolname} with args:\n{}",
496 serde_json::to_string_pretty(&args).unwrap()
497 );
498 Ok(tool.call(args).await?)
499 } else {
500 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
501 }
502 }
503
504 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
506 let mut docs = Vec::new();
507 for tool in self.tools.values() {
508 match tool {
509 ToolType::Simple(tool) => {
510 docs.push(completion::Document {
511 id: tool.name(),
512 text: format!(
513 "\
514 Tool: {}\n\
515 Definition: \n\
516 {}\
517 ",
518 tool.name(),
519 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
520 ),
521 additional_props: HashMap::new(),
522 });
523 }
524 ToolType::Embedding(tool) => {
525 docs.push(completion::Document {
526 id: tool.name(),
527 text: format!(
528 "\
529 Tool: {}\n\
530 Definition: \n\
531 {}\
532 ",
533 tool.name(),
534 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
535 ),
536 additional_props: HashMap::new(),
537 });
538 }
539 }
540 }
541 Ok(docs)
542 }
543
544 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
548 self.tools
549 .values()
550 .filter_map(|tool_type| {
551 if let ToolType::Embedding(tool) = tool_type {
552 Some(ToolSchema::try_from(&**tool))
553 } else {
554 None
555 }
556 })
557 .collect::<Result<Vec<_>, _>>()
558 }
559}
560
/// Builder for a [`ToolSet`] that mixes simple ("static") and embedding-capable
/// ("dynamic") tools.
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools in the order they were added; keyed by name only in `build`.
    tools: Vec<ToolType>,
}
565
566impl ToolSetBuilder {
567 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
568 self.tools.push(ToolType::Simple(Box::new(tool)));
569 self
570 }
571
572 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
573 self.tools.push(ToolType::Embedding(Box::new(tool)));
574 self
575 }
576
577 pub fn build(self) -> ToolSet {
578 ToolSet {
579 tools: self
580 .tools
581 .into_iter()
582 .map(|tool| (tool.name(), tool))
583 .collect(),
584 }
585 }
586}
587
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a tool set holding two simple arithmetic tools: `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_call_dispatches_to_named_tool() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", json!({"x": 2, "y": 3}).to_string())
            .await
            .unwrap();
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", json!({"x": 5, "y": 3}).to_string())
            .await
            .unwrap();
        assert_eq!(difference, "2");
    }

    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolSetError::ToolNotFoundError(ref name) if name == "multiply"));
    }

    #[tokio::test]
    async fn test_call_with_malformed_args() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[tokio::test]
    async fn test_documents_contain_tool_name_and_definition() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);

        let add_doc = docs
            .iter()
            .find(|doc| doc.id == "add")
            .expect("document for `add`");
        assert!(add_doc.text.starts_with("Tool: add\nDefinition: \n"));
    }
}