1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15
16use futures::Future;
17use serde::{Deserialize, Serialize};
18
19use crate::{
20 completion::{self, ToolDefinition},
21 embeddings::{embed::EmbedError, tool::ToolSchema},
22 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
23};
24
/// Error produced when a tool is invoked through [`ToolDyn::call`].
///
/// `Display` is implemented by hand for this type rather than through `#[error]`
/// attributes.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool's own `call` failed; its error is boxed as a trait object.
    /// On non-wasm targets the boxed error must also be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// The tool's own `call` failed (wasm variant: no `Send + Sync` requirement).
    #[cfg(target_family = "wasm")]
    ToolCallError(#[from] Box<dyn std::error::Error>),

    /// JSON (de)serialization failed: the argument string could not be parsed into
    /// the tool's `Args`, or the tool's `Output` could not be serialized.
    JsonError(#[from] serde_json::Error),
}
37
38impl fmt::Display for ToolError {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 ToolError::ToolCallError(e) => {
42 let error_str = e.to_string();
43 if error_str.starts_with("ToolCallError: ") {
46 write!(f, "{}", error_str)
47 } else {
48 write!(f, "ToolCallError: {}", error_str)
49 }
50 }
51 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
52 }
53 }
54}
55
/// A strongly typed tool that a model can invoke.
///
/// Implementors work with typed arguments and outputs. The blanket [`ToolDyn`]
/// impl in this module deserializes the JSON argument string into [`Tool::Args`]
/// and serializes [`Tool::Output`] back into a JSON string.
pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
    /// The tool's name; the default [`Tool::name`] returns this value.
    const NAME: &'static str;

    /// Error returned by [`Tool::call`]. When the tool is called through
    /// [`ToolDyn`], it is boxed into `ToolError::ToolCallError`.
    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
    /// Arguments, deserialized from the JSON string given to `ToolDyn::call`.
    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
    /// Result value, serialized to a JSON string by `ToolDyn::call`.
    type Output: Serialize;

    /// Returns the tool's name (defaults to [`Tool::NAME`]).
    /// `ToolSet` uses the name as the key under which the tool is registered.
    fn name(&self) -> String {
        Self::NAME.to_string()
    }

    /// Returns the [`ToolDefinition`] (name, description, parameters) shown to
    /// the model. `ToolSet`'s own helpers call this with an empty prompt.
    fn definition(
        &self,
        _prompt: String,
    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;

    /// Runs the tool with already-deserialized arguments.
    fn call(
        &self,
        args: Self::Args,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
}
142
/// A [`Tool`] that additionally exposes data for embedding-based lookup and can
/// be constructed from a state value plus a serializable context.
///
/// NOTE(review): the intended lifecycle (embed `embedding_docs`, persist
/// `context`, later rebuild with `init`) is inferred from the method shapes —
/// confirm against callers.
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Serializable per-instance data; `ToolEmbeddingDyn::context` converts it to
    /// a `serde_json::Value`.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Extra state supplied to [`ToolEmbedding::init`].
    type State: WasmCompatSend;

    /// Texts associated with this tool for embedding (presumably used to retrieve
    /// the tool by similarity — confirm).
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns this tool's context value.
    fn context(&self) -> Self::Context;

    /// Builds a tool instance from `state` and `context`.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
169
/// Object-safe tool interface that works on JSON strings.
///
/// Every [`Tool`] gets this through the blanket impl in this module, and
/// `rmcp::McpTool` implements it when the `rmcp` feature is enabled. `ToolSet`
/// stores tools as `Box<dyn ToolDyn>`.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// The tool's name, used by `ToolSet` as its lookup key.
    fn name(&self) -> String;

    /// Returns the tool's definition for the given prompt.
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Calls the tool with a JSON-encoded argument string and returns its output
    /// as a string.
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
178
179impl<T: Tool> ToolDyn for T {
180 fn name(&self) -> String {
181 self.name()
182 }
183
184 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
185 Box::pin(<Self as Tool>::definition(self, prompt))
186 }
187
188 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
189 Box::pin(async move {
190 match serde_json::from_str(&args) {
191 Ok(args) => <Self as Tool>::call(self, args)
192 .await
193 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
194 .and_then(|output| {
195 serde_json::to_string(&output).map_err(ToolError::JsonError)
196 }),
197 Err(e) => Err(ToolError::JsonError(e)),
198 }
199 })
200 }
201}
202
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapter that exposes tools served by an MCP server (through the `rmcp`
    //! crate) via the dynamic [`ToolDyn`] interface.

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use crate::wasm_compat::WasmBoxedFuture;
    use rmcp::model::RawContent;

    /// A tool hosted on an MCP server. It is invoked through the server sink it
    /// was created with.
    #[derive(Clone)]
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wraps an MCP tool description together with the sink used to call it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        /// Converts an MCP tool description into a [`ToolDefinition`].
        /// A missing description becomes an empty string.
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        /// Same conversion as the by-reference impl.
        fn from(val: rmcp::model::Tool) -> Self {
            Self::from(&val)
        }
    }

    /// Error raised when an MCP tool call fails, reports an error, or returns
    /// content that cannot be rendered as text.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Renders one item of an MCP tool result as text.
    ///
    /// Rendering rules:
    /// - text is passed through unchanged;
    /// - images become `data:<mime>;base64,<data>` URIs;
    /// - embedded resources become `[data:<mime>;]<uri>:<payload>`.
    ///
    /// Audio and any other content kind are reported as errors rather than
    /// panicking, because the content comes from an external server.
    fn raw_content_to_string(raw: RawContent) -> Result<String, ToolError> {
        match raw {
            RawContent::Text(text) => Ok(text.text),
            RawContent::Image(image) => {
                Ok(format!("data:{};base64,{}", image.mime_type, image.data))
            }
            RawContent::Resource(resource) => Ok(match resource.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            }),
            RawContent::Audio(_) => Err(McpToolError(
                "Support for audio results from an MCP tool is currently unimplemented. Come back later!"
                    .to_string(),
            )
            .into()),
            other => Err(McpToolError(format!("Unsupported type found: {other:?}")).into()),
        }
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        /// Returns the MCP tool's definition; `_prompt` is ignored.
        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            Box::pin(async move { ToolDefinition::from(&self.definition) })
        }

        /// Calls the tool on the MCP server and renders its result as text.
        ///
        /// # Errors
        /// - `ToolError::JsonError` if `args` is non-blank but not valid JSON.
        /// - `ToolError::ToolCallError` (wrapping [`McpToolError`]) in three cases:
        ///   - the request itself fails;
        ///   - the server flags the result as an error (its text parts are joined
        ///     with newlines);
        ///   - the result contains content that cannot be rendered as text.
        fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            let name = self.definition.name.clone();

            Box::pin(async move {
                // Blank input means "no arguments". Anything else must parse as JSON
                // rather than silently degrading to "no arguments".
                let arguments = if args.trim().is_empty() {
                    Default::default()
                } else {
                    serde_json::from_str(&args)?
                };

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParams {
                        name,
                        arguments,
                        meta: None,
                        task: None,
                    })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error == Some(true) {
                    // Keep every textual part. Non-text parts no longer discard the
                    // whole message.
                    let texts: Vec<String> = result
                        .content
                        .into_iter()
                        .filter_map(|content| match content.raw {
                            RawContent::Text(text) => Some(text.text),
                            _ => None,
                        })
                        .collect();
                    let message = if texts.is_empty() {
                        "No message returned".to_string()
                    } else {
                        texts.join("\n")
                    };
                    return Err(McpToolError(message).into());
                }

                result
                    .content
                    .into_iter()
                    .map(|content| raw_content_to_string(content.raw))
                    .collect::<Result<String, ToolError>>()
            })
        }
    }
}
357
/// Object-safe counterpart of [`ToolEmbedding`], layered on top of [`ToolDyn`].
///
/// Implemented for every `ToolEmbedding + 'static` type by the blanket impl in
/// this module. `ToolSet` stores such tools as `ToolType::Embedding`.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context serialized to a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the texts associated with this tool for embedding.
    fn embedding_docs(&self) -> Vec<String>;
}
364
365impl<T> ToolEmbeddingDyn for T
366where
367 T: ToolEmbedding + 'static,
368{
369 fn context(&self) -> serde_json::Result<serde_json::Value> {
370 serde_json::to_value(self.context())
371 }
372
373 fn embedding_docs(&self) -> Vec<String> {
374 self.embedding_docs()
375 }
376}
377
/// How a tool is stored inside a [`ToolSet`].
pub(crate) enum ToolType {
    /// A plain tool, registered via `ToolSet::add_tool`, `ToolSet::add_tool_boxed`,
    /// or `ToolSetBuilder::static_tool`.
    Simple(Box<dyn ToolDyn>),
    /// A tool that also exposes embedding data, registered via
    /// `ToolSetBuilder::dynamic_tool`. Only these contribute to `ToolSet::schemas`.
    Embedding(Box<dyn ToolEmbeddingDyn>),
}
382
383impl ToolType {
384 pub fn name(&self) -> String {
385 match self {
386 ToolType::Simple(tool) => tool.name(),
387 ToolType::Embedding(tool) => tool.name(),
388 }
389 }
390
391 pub async fn definition(&self, prompt: String) -> ToolDefinition {
392 match self {
393 ToolType::Simple(tool) => tool.definition(prompt).await,
394 ToolType::Embedding(tool) => tool.definition(prompt).await,
395 }
396 }
397
398 pub async fn call(&self, args: String) -> Result<String, ToolError> {
399 match self {
400 ToolType::Simple(tool) => tool.call(args).await,
401 ToolType::Embedding(tool) => tool.call(args).await,
402 }
403 }
404}
405
/// Error returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The tool was found, but calling it failed.
    ///
    /// NOTE(review): `ToolError`'s own `Display` already starts with
    /// `ToolCallError: ` or `JsonError: `, so this variant renders as e.g.
    /// `ToolCallError: ToolCallError: <message>`.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool is registered under the requested name (the payload is that name).
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// JSON serialization failed inside a `ToolSet` helper, e.g. while rendering
    /// tool definitions in `ToolSet::documents`.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A tool call was interrupted.
    ///
    /// NOTE(review): not constructed anywhere in this file; presumably produced by
    /// callers elsewhere in the crate — confirm.
    #[error("Tool call interrupted")]
    Interrupted,
}
424
/// A collection of tools, keyed by the name each tool reports.
///
/// Registering a tool under a name that is already present replaces the
/// previous tool (plain `HashMap::insert` semantics).
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools, keyed by `ToolDyn::name`.
    pub(crate) tools: HashMap<String, ToolType>,
}
430
431impl ToolSet {
432 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
434 let mut toolset = Self::default();
435 tools.into_iter().for_each(|tool| {
436 toolset.add_tool(tool);
437 });
438 toolset
439 }
440
441 pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
442 let mut toolset = Self::default();
443 tools.into_iter().for_each(|tool| {
444 toolset.add_tool_boxed(tool);
445 });
446 toolset
447 }
448
449 pub fn builder() -> ToolSetBuilder {
451 ToolSetBuilder::default()
452 }
453
454 pub fn contains(&self, toolname: &str) -> bool {
456 self.tools.contains_key(toolname)
457 }
458
459 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
461 self.tools
462 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
463 }
464
465 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
467 self.tools.insert(tool.name(), ToolType::Simple(tool));
468 }
469
470 pub fn delete_tool(&mut self, tool_name: &str) {
471 let _ = self.tools.remove(tool_name);
472 }
473
474 pub fn add_tools(&mut self, toolset: ToolSet) {
476 self.tools.extend(toolset.tools);
477 }
478
479 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
480 self.tools.get(toolname)
481 }
482
483 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
484 let mut defs = Vec::new();
485 for tool in self.tools.values() {
486 let def = tool.definition(String::new()).await;
487 defs.push(def);
488 }
489 Ok(defs)
490 }
491
492 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
494 if let Some(tool) = self.tools.get(toolname) {
495 tracing::debug!(target: "rig",
496 "Calling tool {toolname} with args:\n{}",
497 serde_json::to_string_pretty(&args).unwrap()
498 );
499 Ok(tool.call(args).await?)
500 } else {
501 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
502 }
503 }
504
505 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
507 let mut docs = Vec::new();
508 for tool in self.tools.values() {
509 match tool {
510 ToolType::Simple(tool) => {
511 docs.push(completion::Document {
512 id: tool.name(),
513 text: format!(
514 "\
515 Tool: {}\n\
516 Definition: \n\
517 {}\
518 ",
519 tool.name(),
520 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
521 ),
522 additional_props: HashMap::new(),
523 });
524 }
525 ToolType::Embedding(tool) => {
526 docs.push(completion::Document {
527 id: tool.name(),
528 text: format!(
529 "\
530 Tool: {}\n\
531 Definition: \n\
532 {}\
533 ",
534 tool.name(),
535 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
536 ),
537 additional_props: HashMap::new(),
538 });
539 }
540 }
541 }
542 Ok(docs)
543 }
544
545 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
549 self.tools
550 .values()
551 .filter_map(|tool_type| {
552 if let ToolType::Embedding(tool) = tool_type {
553 Some(ToolSchema::try_from(&**tool))
554 } else {
555 None
556 }
557 })
558 .collect::<Result<Vec<_>, _>>()
559 }
560}
561
/// Builder that queues tools and turns them into a [`ToolSet`] via `build`.
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools queued so far, in the order they were added.
    tools: Vec<ToolType>,
}
566
567impl ToolSetBuilder {
568 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
569 self.tools.push(ToolType::Simple(Box::new(tool)));
570 self
571 }
572
573 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
574 self.tools.push(ToolType::Embedding(Box::new(tool)));
575 self
576 }
577
578 pub fn build(self) -> ToolSet {
579 ToolSet {
580 tools: self
581 .tools
582 .into_iter()
583 .map(|tool| (tool.name(), tool))
584 .collect(),
585 }
586 }
587}
588
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a toolset with two arithmetic tools, `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);

        // HashMap iteration order is unspecified, so compare sorted names.
        let mut names: Vec<String> = tools.into_iter().map(|def| def.name).collect();
        names.sort();
        assert_eq!(names, ["add", "subtract"]);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_call_tool() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(difference, "-1");
    }

    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        match toolset.call("multiply", "{}".to_string()).await {
            Err(ToolSetError::ToolNotFoundError(name)) => assert_eq!(name, "multiply"),
            other => panic!("expected ToolNotFoundError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_call_with_invalid_args() {
        let toolset = get_test_toolset();
        let result = toolset.call("add", "not json".to_string()).await;
        assert!(matches!(
            result,
            Err(ToolSetError::ToolCallError(ToolError::JsonError(_)))
        ));
    }

    #[test]
    fn test_tool_error_display_does_not_repeat_prefix() {
        let inner = ToolError::ToolCallError("boom".into());
        assert_eq!(inner.to_string(), "ToolCallError: boom");

        // Wrapping a ToolError in another ToolError must not double the prefix.
        let nested = ToolError::ToolCallError(Box::new(inner));
        assert_eq!(nested.to_string(), "ToolCallError: boom");
    }
}