1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15
16use futures::Future;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20 completion::{self, ToolDefinition},
21 embeddings::{embed::EmbedError, tool::ToolSchema},
22 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
23};
24
/// Error produced when invoking a tool through [`ToolDyn::call`].
///
/// No variant carries an `#[error(...)]` attribute; `Display` is implemented by hand in the
/// `impl fmt::Display for ToolError` block of this module.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool's own error, boxed. On non-wasm targets the boxed error must be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// The tool's own error, boxed (wasm targets drop the `Send + Sync` requirement).
    #[cfg(target_family = "wasm")]
    ToolCallError(#[from] Box<dyn std::error::Error>),

    /// Tool arguments could not be deserialized, or the tool output could not be serialized.
    JsonError(#[from] serde_json::Error),
}
37
38impl fmt::Display for ToolError {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 ToolError::ToolCallError(e) => {
42 let error_str = e.to_string();
43 if error_str.starts_with("ToolCallError: ") {
46 write!(f, "{}", error_str)
47 } else {
48 write!(f, "ToolCallError: {}", error_str)
49 }
50 }
51 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
52 }
53 }
54}
55
/// A strongly-typed tool that can be exposed to a model.
///
/// Every implementor automatically implements [`ToolDyn`] through a blanket impl. That impl
/// deserializes the JSON argument string into [`Tool::Args`] and serializes [`Tool::Output`] back
/// into a JSON string.
pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
    /// Tool name; returned by the default [`Tool::name`], and used as the key when the tool is
    /// added to a [`ToolSet`].
    const NAME: &'static str;

    /// Error returned by [`Tool::call`]. The blanket [`ToolDyn`] impl boxes it into
    /// [`ToolError::ToolCallError`].
    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
    /// Call arguments, deserialized from the JSON argument string with `serde_json`.
    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
    /// Call result, serialized to a JSON string with `serde_json`.
    type Output: Serialize;

    /// Returns the tool's name; defaults to [`Tool::NAME`].
    fn name(&self) -> String {
        Self::NAME.to_string()
    }

    /// Describes the tool: its name, description and JSON `parameters` value.
    ///
    /// The prompt is supplied by the caller; [`ToolSet`] passes an empty string when it lists
    /// definitions or renders tool documents.
    fn definition(
        &self,
        _prompt: String,
    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;

    /// Runs the tool with already-deserialized arguments.
    fn call(
        &self,
        args: Self::Args,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
}
142
/// A [`Tool`] that also exposes documents to embed and a serializable context from which it can
/// be re-created with [`ToolEmbedding::init`].
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Context returned by [`ToolEmbedding::context`] and accepted by [`ToolEmbedding::init`].
    /// `ToolEmbeddingDyn::context` serializes it to a JSON value; presumably it is deserialized
    /// again before `init` is called — TODO confirm against callers.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Extra state handed to [`ToolEmbedding::init`]; presumably supplied by whoever rebuilds
    /// the tool — TODO confirm.
    type State: WasmCompatSend;

    /// Documents describing this tool, exposed through `ToolEmbeddingDyn::embedding_docs`.
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns the context needed to rebuild this tool with [`ToolEmbedding::init`].
    fn context(&self) -> Self::Context;

    /// Constructs the tool from `state` and a `context`.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
169
/// Object-safe counterpart of [`Tool`], used to store heterogeneous tools as `Box<dyn ToolDyn>`.
///
/// Arguments and results cross this boundary as strings. Every [`Tool`] implements this trait
/// through the blanket impl in this module.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Name under which the tool is registered in a [`ToolSet`].
    fn name(&self) -> String;

    /// Returns the tool's definition; `prompt` is forwarded to the implementation.
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Invokes the tool with JSON-encoded `args` and returns its output as a string. The blanket
    /// impl for [`Tool`] JSON-serializes the output.
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
178
179impl<T: Tool> ToolDyn for T {
180 fn name(&self) -> String {
181 self.name()
182 }
183
184 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
185 Box::pin(<Self as Tool>::definition(self, prompt))
186 }
187
188 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
189 Box::pin(async move {
190 match serde_json::from_str(&args) {
191 Ok(args) => <Self as Tool>::call(self, args)
192 .await
193 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
194 .and_then(|output| {
195 serde_json::to_string(&output).map_err(ToolError::JsonError)
196 }),
197 Err(e) => Err(ToolError::JsonError(e)),
198 }
199 })
200 }
201}
202
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapters that expose tools hosted on an MCP server (via the `rmcp` crate) as [`ToolDyn`].

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use crate::wasm_compat::WasmBoxedFuture;
    use rmcp::model::RawContent;

    /// A tool served by an MCP server.
    ///
    /// Holds the tool's MCP definition plus the server handle used to invoke it.
    #[derive(Clone)]
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wraps the MCP tool `definition` so that calls are sent through `client`.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        /// Converts an MCP tool definition; a missing description becomes an empty string and
        /// the input schema becomes the `parameters` JSON value.
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        /// Same conversion as `From<&rmcp::model::Tool>`.
        fn from(val: rmcp::model::Tool) -> Self {
            Self::from(&val)
        }
    }

    /// Error reported by (or while talking to) an MCP tool.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Formats an optional MIME type as a `data:<mime>;` prefix, or `""` when absent.
    fn mime_prefix(mime_type: Option<impl std::fmt::Display>) -> String {
        mime_type.map(|m| format!("data:{m};")).unwrap_or_default()
    }

    /// Renders one item of an MCP tool result as text.
    ///
    /// The output depends on the content kind:
    /// - Text is returned verbatim.
    /// - Images become `data:<mime>;base64,<data>`.
    /// - Embedded resources become `[data:<mime>;]<uri>:<payload>`.
    ///
    /// Audio and any other content kinds are reported as an [`McpToolError`] rather than
    /// panicking, because this content comes from a remote server.
    fn raw_content_to_string(raw: RawContent) -> Result<String, McpToolError> {
        let rendered = match raw {
            RawContent::Text(text) => text.text,
            RawContent::Image(image) => {
                format!("data:{};base64,{}", image.mime_type, image.data)
            }
            RawContent::Resource(resource) => match resource.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!("{}{uri}:{text}", mime_prefix(mime_type)),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!("{}{uri}:{blob}", mime_prefix(mime_type)),
            },
            RawContent::Audio(_) => {
                return Err(McpToolError(
                    "Support for audio results from an MCP tool is currently unimplemented"
                        .to_string(),
                ));
            }
            // `RawContent` belongs to another crate and may gain variants.
            other => return Err(McpToolError(format!("Unsupported type found: {other:?}"))),
        };
        Ok(rendered)
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        /// Returns the MCP definition converted via `From<&rmcp::model::Tool>`; the prompt is
        /// ignored.
        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            Box::pin(async move { ToolDefinition::from(&self.definition) })
        }

        /// Calls the tool on the MCP server and concatenates the rendered result content.
        ///
        /// # Errors
        /// - Non-blank `args` that are not valid JSON produce [`ToolError::JsonError`].
        /// - A failed request produces an [`McpToolError`].
        /// - A result flagged as an error produces an [`McpToolError`] carrying its text parts.
        /// - Result content that cannot be rendered produces an [`McpToolError`].
        fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            let name = self.definition.name.clone();

            Box::pin(async move {
                // Blank input means "no arguments", so parameterless tools keep working. Anything
                // else must parse; malformed JSON is reported instead of being silently dropped.
                let arguments = if args.trim().is_empty() {
                    None
                } else {
                    serde_json::from_str(&args).map_err(ToolError::JsonError)?
                };

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if let Some(true) = result.is_error {
                    // Keep the textual parts of the error payload; non-text parts are skipped.
                    let messages: Vec<String> = result
                        .content
                        .into_iter()
                        .filter_map(|content| match content.raw {
                            RawContent::Text(text) => Some(text.text),
                            _ => None,
                        })
                        .collect();

                    let message = if messages.is_empty() {
                        "No message returned".to_string()
                    } else {
                        messages.join("\n")
                    };
                    return Err(McpToolError(message).into());
                }

                result
                    .content
                    .into_iter()
                    .map(|content| raw_content_to_string(content.raw))
                    .collect::<Result<String, McpToolError>>()
                    .map_err(ToolError::from)
            })
        }
    }
}
352
/// Object-safe counterpart of [`ToolEmbedding`], layered on [`ToolDyn`] so that embedding tools
/// can be stored as `Box<dyn ToolEmbeddingDyn>`.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context serialized as a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the documents to embed for this tool.
    fn embedding_docs(&self) -> Vec<String>;
}
359
360impl<T> ToolEmbeddingDyn for T
361where
362 T: ToolEmbedding + 'static,
363{
364 fn context(&self) -> serde_json::Result<serde_json::Value> {
365 serde_json::to_value(self.context())
366 }
367
368 fn embedding_docs(&self) -> Vec<String> {
369 self.embedding_docs()
370 }
371}
372
/// Internal wrapper recording how a tool was registered in a [`ToolSet`].
pub(crate) enum ToolType {
    /// A plain tool (registered via `add_tool`, `add_tool_boxed`, `from_tools` or
    /// `ToolSetBuilder::static_tool`).
    Simple(Box<dyn ToolDyn>),
    /// An embedding-capable tool (registered via `ToolSetBuilder::dynamic_tool`); only these
    /// contribute to `ToolSet::schemas`.
    Embedding(Box<dyn ToolEmbeddingDyn>),
}
377
378impl ToolType {
379 pub fn name(&self) -> String {
380 match self {
381 ToolType::Simple(tool) => tool.name(),
382 ToolType::Embedding(tool) => tool.name(),
383 }
384 }
385
386 pub async fn definition(&self, prompt: String) -> ToolDefinition {
387 match self {
388 ToolType::Simple(tool) => tool.definition(prompt).await,
389 ToolType::Embedding(tool) => tool.definition(prompt).await,
390 }
391 }
392
393 pub async fn call(&self, args: String) -> Result<String, ToolError> {
394 match self {
395 ToolType::Simple(tool) => tool.call(args).await,
396 ToolType::Embedding(tool) => tool.call(args).await,
397 }
398 }
399}
400
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The tool failed, or its arguments or output could not be (de)serialized.
    ///
    /// NOTE(review): `ToolError`'s `Display` already prefixes its `ToolCallError` variant with
    /// `"ToolCallError: "`, so that case renders the prefix twice here — confirm it is intended.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool with the given name is registered; carries the requested name.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// A JSON error raised by the tool set itself (e.g. while rendering tool documents).
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// The tool call was interrupted.
    ///
    /// NOTE(review): nothing in this module constructs this variant; presumably it is raised by
    /// callers elsewhere — confirm.
    #[error("Tool call interrupted")]
    Interrupted,
}
419
/// A collection of tools, keyed by tool name.
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools keyed by the name each tool reports. Registering a tool whose name is
    /// already present replaces the previous entry (plain `HashMap` insert semantics).
    pub(crate) tools: HashMap<String, ToolType>,
}
425
426impl ToolSet {
427 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
429 let mut toolset = Self::default();
430 tools.into_iter().for_each(|tool| {
431 toolset.add_tool(tool);
432 });
433 toolset
434 }
435
436 pub fn builder() -> ToolSetBuilder {
438 ToolSetBuilder::default()
439 }
440
441 pub fn contains(&self, toolname: &str) -> bool {
443 self.tools.contains_key(toolname)
444 }
445
446 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
448 self.tools
449 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
450 }
451
452 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
454 self.tools.insert(tool.name(), ToolType::Simple(tool));
455 }
456
457 pub fn delete_tool(&mut self, tool_name: &str) {
458 let _ = self.tools.remove(tool_name);
459 }
460
461 pub fn add_tools(&mut self, toolset: ToolSet) {
463 self.tools.extend(toolset.tools);
464 }
465
466 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
467 self.tools.get(toolname)
468 }
469
470 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
471 let mut defs = Vec::new();
472 for tool in self.tools.values() {
473 let def = tool.definition(String::new()).await;
474 defs.push(def);
475 }
476 Ok(defs)
477 }
478
479 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
481 if let Some(tool) = self.tools.get(toolname) {
482 tracing::debug!(target: "rig",
483 "Calling tool {toolname} with args:\n{}",
484 serde_json::to_string_pretty(&args).unwrap()
485 );
486 Ok(tool.call(args).await?)
487 } else {
488 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
489 }
490 }
491
492 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
494 let mut docs = Vec::new();
495 for tool in self.tools.values() {
496 match tool {
497 ToolType::Simple(tool) => {
498 docs.push(completion::Document {
499 id: tool.name(),
500 text: format!(
501 "\
502 Tool: {}\n\
503 Definition: \n\
504 {}\
505 ",
506 tool.name(),
507 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
508 ),
509 additional_props: HashMap::new(),
510 });
511 }
512 ToolType::Embedding(tool) => {
513 docs.push(completion::Document {
514 id: tool.name(),
515 text: format!(
516 "\
517 Tool: {}\n\
518 Definition: \n\
519 {}\
520 ",
521 tool.name(),
522 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
523 ),
524 additional_props: HashMap::new(),
525 });
526 }
527 }
528 }
529 Ok(docs)
530 }
531
532 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
536 self.tools
537 .values()
538 .filter_map(|tool_type| {
539 if let ToolType::Embedding(tool) = tool_type {
540 Some(ToolSchema::try_from(&**tool))
541 } else {
542 None
543 }
544 })
545 .collect::<Result<Vec<_>, _>>()
546 }
547}
548
/// Builder that accumulates tools and produces a [`ToolSet`] via `ToolSetBuilder::build`.
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools in registration order; they are keyed by name only when `build` runs.
    tools: Vec<ToolType>,
}
553
554impl ToolSetBuilder {
555 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
556 self.tools.push(ToolType::Simple(Box::new(tool)));
557 self
558 }
559
560 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
561 self.tools.push(ToolType::Embedding(Box::new(tool)));
562 self
563 }
564
565 pub fn build(self) -> ToolSet {
566 ToolSet {
567 tools: self
568 .tools
569 .into_iter()
570 .map(|tool| (tool.name(), tool))
571 .collect(),
572 }
573 }
574}
575
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a `ToolSet` holding two simple integer tools: `add` (x + y) and `subtract` (x - y).
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);

        // HashMap iteration order is unspecified, so compare sorted names.
        let mut names: Vec<String> = tools.into_iter().map(|def| def.name).collect();
        names.sort();
        assert_eq!(names, ["add", "subtract"]);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_tool_call_dispatches_by_name() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(difference, "-1");
    }

    #[tokio::test]
    async fn test_tool_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolSetError::ToolNotFoundError(ref name) if name == "multiply"));
    }

    #[tokio::test]
    async fn test_tool_call_invalid_args() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }
}