1use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 completion::{self, ToolDefinition},
19 embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
22#[derive(Debug, thiserror::Error)]
23pub enum ToolError {
24 #[error("ToolCallError: {0}")]
26 ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
27
28 #[error("JsonError: {0}")]
29 JsonError(#[from] serde_json::Error),
30}
31
32pub trait Tool: Sized + Send + Sync {
88 const NAME: &'static str;
90
91 type Error: std::error::Error + Send + Sync + 'static;
93 type Args: for<'a> Deserialize<'a> + Send + Sync;
95 type Output: Serialize;
97
98 fn name(&self) -> String {
100 Self::NAME.to_string()
101 }
102
103 fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;
106
107 fn call(
111 &self,
112 args: Self::Args,
113 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
114}
115
116pub trait ToolEmbedding: Tool {
118 type InitError: std::error::Error + Send + Sync + 'static;
119
120 type Context: for<'a> Deserialize<'a> + Serialize;
125
126 type State: Send;
130
131 fn embedding_docs(&self) -> Vec<String>;
135
136 fn context(&self) -> Self::Context;
138
139 fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
141}
142
143pub trait ToolDyn: Send + Sync {
145 fn name(&self) -> String;
146
147 fn definition(
148 &self,
149 prompt: String,
150 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;
151
152 fn call(
153 &self,
154 args: String,
155 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>>;
156}
157
158impl<T: Tool> ToolDyn for T {
159 fn name(&self) -> String {
160 self.name()
161 }
162
163 fn definition(
164 &self,
165 prompt: String,
166 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167 Box::pin(<Self as Tool>::definition(self, prompt))
168 }
169
170 fn call(
171 &self,
172 args: String,
173 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
174 Box::pin(async move {
175 match serde_json::from_str(&args) {
176 Ok(args) => <Self as Tool>::call(self, args)
177 .await
178 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179 .and_then(|output| {
180 serde_json::to_string(&output).map_err(ToolError::JsonError)
181 }),
182 Err(e) => Err(ToolError::JsonError(e)),
183 }
184 })
185 }
186}
187
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapter that exposes tools served by an MCP server (through the `rmcp` crate) as
    //! [`ToolDyn`] implementations.

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use rmcp::model::RawContent;
    use std::pin::Pin;

    /// A tool advertised by a remote MCP server.
    ///
    /// It is invoked through the `rmcp` client handle it was created with.
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Pairs an MCP tool `definition` with the `client` used to call it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                // A missing description becomes an empty string.
                description: val.description.as_deref().unwrap_or("").to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            // Single source of truth for the conversion: delegate to the borrowed impl.
            Self::from(&val)
        }
    }

    /// Error raised when an MCP tool call fails, or when its result cannot be rendered as text.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(
            &self,
            _prompt: String,
        ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
            // Reuse the `From` conversion so the schema is produced the same way everywhere.
            Box::pin(async move { ToolDefinition::from(&self.definition) })
        }

        /// Calls the remote tool and flattens its result content into one string.
        ///
        /// # Errors
        /// - [`ToolError::JsonError`] if non-empty `args` are not a valid JSON argument object.
        /// - [`ToolError::ToolCallError`] (wrapping [`McpToolError`]) in three cases: the
        ///   request fails, the server flags the result as an error, or the result contains
        ///   content with no textual rendering.
        fn call(
            &self,
            args: String,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
            let name = self.definition.name.clone();

            Box::pin(async move {
                // Blank args keep meaning "no arguments", as before. Anything else must parse;
                // malformed JSON is reported instead of being silently replaced by no arguments.
                let arguments = if args.trim().is_empty() {
                    None
                } else {
                    serde_json::from_str(&args)?
                };

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error == Some(true) {
                    // Keep every text part the server sent, even if non-text parts are mixed in.
                    let texts: Vec<&str> = result
                        .content
                        .iter()
                        .filter_map(|content| content.raw.as_text())
                        .map(|text| text.text.as_str())
                        .collect();
                    let message = if texts.is_empty() {
                        "No message returned".to_string()
                    } else {
                        texts.join("\n")
                    };
                    return Err(McpToolError(message).into());
                }

                result
                    .content
                    .into_iter()
                    .map(|content| content_to_string(content.raw))
                    .collect::<Result<String, McpToolError>>()
                    .map_err(ToolError::from)
            })
        }
    }

    /// Renders one piece of MCP result content as text.
    ///
    /// Rendering by kind:
    /// - Text is returned verbatim.
    /// - Images become `data:<mime>;base64,<data>` URIs.
    /// - Embedded resources become `[data:<mime>;]<uri>:<payload>`.
    ///
    /// Content kinds without a textual rendering are returned as errors rather than
    /// panicking, because they come from a remote server.
    fn content_to_string(content: RawContent) -> Result<String, McpToolError> {
        match content {
            RawContent::Text(raw) => Ok(raw.text),
            RawContent::Image(raw) => Ok(format!("data:{};base64,{}", raw.mime_type, raw.data)),
            RawContent::Resource(raw) => Ok(match raw.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            }),
            RawContent::Audio(_) => Err(McpToolError(
                "Support for audio results from an MCP tool is currently unimplemented. Come back later!"
                    .to_string(),
            )),
            other => Err(McpToolError(format!("Unsupported type found: {other:?}"))),
        }
    }
}
342
343pub trait ToolEmbeddingDyn: ToolDyn {
345 fn context(&self) -> serde_json::Result<serde_json::Value>;
346
347 fn embedding_docs(&self) -> Vec<String>;
348}
349
350impl<T> ToolEmbeddingDyn for T
351where
352 T: ToolEmbedding,
353{
354 fn context(&self) -> serde_json::Result<serde_json::Value> {
355 serde_json::to_value(self.context())
356 }
357
358 fn embedding_docs(&self) -> Vec<String> {
359 self.embedding_docs()
360 }
361}
362
363pub(crate) enum ToolType {
364 Simple(Box<dyn ToolDyn>),
365 Embedding(Box<dyn ToolEmbeddingDyn>),
366}
367
368impl ToolType {
369 pub fn name(&self) -> String {
370 match self {
371 ToolType::Simple(tool) => tool.name(),
372 ToolType::Embedding(tool) => tool.name(),
373 }
374 }
375
376 pub async fn definition(&self, prompt: String) -> ToolDefinition {
377 match self {
378 ToolType::Simple(tool) => tool.definition(prompt).await,
379 ToolType::Embedding(tool) => tool.definition(prompt).await,
380 }
381 }
382
383 pub async fn call(&self, args: String) -> Result<String, ToolError> {
384 match self {
385 ToolType::Simple(tool) => tool.call(args).await,
386 ToolType::Embedding(tool) => tool.call(args).await,
387 }
388 }
389}
390
391#[derive(Debug, thiserror::Error)]
392pub enum ToolSetError {
393 #[error("ToolCallError: {0}")]
395 ToolCallError(#[from] ToolError),
396
397 #[error("ToolNotFoundError: {0}")]
398 ToolNotFoundError(String),
399
400 #[error("JsonError: {0}")]
402 JsonError(#[from] serde_json::Error),
403}
404
405#[derive(Default)]
407pub struct ToolSet {
408 pub(crate) tools: HashMap<String, ToolType>,
409}
410
411impl ToolSet {
412 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
414 let mut toolset = Self::default();
415 tools.into_iter().for_each(|tool| {
416 toolset.add_tool(tool);
417 });
418 toolset
419 }
420
421 pub fn builder() -> ToolSetBuilder {
423 ToolSetBuilder::default()
424 }
425
426 pub fn contains(&self, toolname: &str) -> bool {
428 self.tools.contains_key(toolname)
429 }
430
431 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
433 self.tools
434 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
435 }
436
437 pub fn delete_tool(&mut self, tool_name: &str) {
438 let _ = self.tools.remove(tool_name);
439 }
440
441 pub fn add_tools(&mut self, toolset: ToolSet) {
443 self.tools.extend(toolset.tools);
444 }
445
446 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
447 self.tools.get(toolname)
448 }
449
450 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
451 let mut defs = Vec::new();
452 for tool in self.tools.values() {
453 let def = tool.definition(String::new()).await;
454 defs.push(def);
455 }
456 Ok(defs)
457 }
458
459 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
461 if let Some(tool) = self.tools.get(toolname) {
462 tracing::info!(target: "rig",
463 "Calling tool {toolname} with args:\n{}",
464 serde_json::to_string_pretty(&args).unwrap()
465 );
466 Ok(tool.call(args).await?)
467 } else {
468 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
469 }
470 }
471
472 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
474 let mut docs = Vec::new();
475 for tool in self.tools.values() {
476 match tool {
477 ToolType::Simple(tool) => {
478 docs.push(completion::Document {
479 id: tool.name(),
480 text: format!(
481 "\
482 Tool: {}\n\
483 Definition: \n\
484 {}\
485 ",
486 tool.name(),
487 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
488 ),
489 additional_props: HashMap::new(),
490 });
491 }
492 ToolType::Embedding(tool) => {
493 docs.push(completion::Document {
494 id: tool.name(),
495 text: format!(
496 "\
497 Tool: {}\n\
498 Definition: \n\
499 {}\
500 ",
501 tool.name(),
502 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
503 ),
504 additional_props: HashMap::new(),
505 });
506 }
507 }
508 }
509 Ok(docs)
510 }
511
512 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
516 self.tools
517 .values()
518 .filter_map(|tool_type| {
519 if let ToolType::Embedding(tool) = tool_type {
520 Some(ToolSchema::try_from(&**tool))
521 } else {
522 None
523 }
524 })
525 .collect::<Result<Vec<_>, _>>()
526 }
527}
528
529#[derive(Default)]
530pub struct ToolSetBuilder {
531 tools: Vec<ToolType>,
532}
533
534impl ToolSetBuilder {
535 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
536 self.tools.push(ToolType::Simple(Box::new(tool)));
537 self
538 }
539
540 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
541 self.tools.push(ToolType::Embedding(Box::new(tool)));
542 self
543 }
544
545 pub fn build(self) -> ToolSet {
546 ToolSet {
547 tools: self
548 .tools
549 .into_iter()
550 .map(|tool| (tool.name(), tool))
551 .collect(),
552 }
553 }
554}
555
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a toolset containing two simple arithmetic tools: `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    /// `ToolSet::call` parses JSON args, runs the typed tool and serializes its output.
    #[tokio::test]
    async fn test_call_dispatches_to_tool() {
        let toolset = get_test_toolset();
        let args = r#"{"x": 2, "y": 3}"#;

        let sum = toolset.call("add", args.to_string()).await.unwrap();
        assert_eq!(sum, "5");

        let difference = toolset.call("subtract", args.to_string()).await.unwrap();
        assert_eq!(difference, "-1");
    }

    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolSetError::ToolNotFoundError(name) if name == "multiply"));
    }

    #[tokio::test]
    async fn test_call_invalid_args() {
        let toolset = get_test_toolset();
        let err = toolset.call("add", "not json".to_string()).await.unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    #[tokio::test]
    async fn test_documents() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);
        for doc in docs {
            let header = format!("Tool: {}\nDefinition: \n", doc.id);
            assert!(doc.text.starts_with(&header));
        }
    }
}