1use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 completion::{self, ToolDefinition},
19 embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
22#[derive(Debug, thiserror::Error)]
23pub enum ToolError {
24 #[error("ToolCallError: {0}")]
26 ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
27
28 #[error("JsonError: {0}")]
29 JsonError(#[from] serde_json::Error),
30}
31
32pub trait Tool: Sized + Send + Sync {
88 const NAME: &'static str;
90
91 type Error: std::error::Error + Send + Sync + 'static;
93 type Args: for<'a> Deserialize<'a> + Send + Sync;
95 type Output: Serialize;
97
98 fn name(&self) -> String {
100 Self::NAME.to_string()
101 }
102
103 fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;
106
107 fn call(
111 &self,
112 args: Self::Args,
113 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
114}
115
116pub trait ToolEmbedding: Tool {
118 type InitError: std::error::Error + Send + Sync + 'static;
119
120 type Context: for<'a> Deserialize<'a> + Serialize;
125
126 type State: Send;
130
131 fn embedding_docs(&self) -> Vec<String>;
135
136 fn context(&self) -> Self::Context;
138
139 fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
141}
142
143pub trait ToolDyn: Send + Sync {
145 fn name(&self) -> String;
146
147 fn definition(
148 &self,
149 prompt: String,
150 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;
151
152 fn call(
153 &self,
154 args: String,
155 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>>;
156}
157
158impl<T: Tool> ToolDyn for T {
159 fn name(&self) -> String {
160 self.name()
161 }
162
163 fn definition(
164 &self,
165 prompt: String,
166 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167 Box::pin(<Self as Tool>::definition(self, prompt))
168 }
169
170 fn call(
171 &self,
172 args: String,
173 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
174 Box::pin(async move {
175 match serde_json::from_str(&args) {
176 Ok(args) => <Self as Tool>::call(self, args)
177 .await
178 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179 .and_then(|output| {
180 serde_json::to_string(&output).map_err(ToolError::JsonError)
181 }),
182 Err(e) => Err(ToolError::JsonError(e)),
183 }
184 })
185 }
186}
187
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapter exposing tools served by an MCP server (through the `rmcp` crate) as
    //! [`ToolDyn`] implementations.

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use rmcp::model::{RawContent, ResourceContents};
    use std::pin::Pin;

    /// An MCP tool definition together with the server sink used to invoke it.
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wraps a tool definition advertised by an MCP server and the sink that executes it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                // A missing description becomes an empty string.
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            // Delegate so the owned and borrowed conversions can never drift apart.
            Self::from(&val)
        }
    }

    /// Error reported by (or while calling) an MCP tool.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Renders one piece of MCP tool output as text.
    ///
    /// Audio content is not supported yet. Because the content comes from an external server,
    /// this returns an error instead of panicking.
    fn render_content(content: RawContent) -> Result<String, ToolError> {
        let rendered = match content {
            RawContent::Text(raw) => raw.text,
            RawContent::Image(raw) => format!("data:{};base64,{}", raw.mime_type, raw.data),
            RawContent::Resource(raw) => match raw.resource {
                ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            },
            RawContent::Audio(_) => {
                return Err(McpToolError(
                    "Support for audio results from an MCP tool is currently unimplemented. Come back later!"
                        .to_string(),
                )
                .into());
            }
        };
        Ok(rendered)
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(
            &self,
            _prompt: String,
        ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
            // Reuse the `From` conversion so both paths render the schema identically.
            Box::pin(std::future::ready(ToolDefinition::from(&self.definition)))
        }

        /// Calls the MCP tool with `args` and concatenates its textual output.
        ///
        /// # Errors
        /// - [`ToolError::JsonError`] if `args` is non-empty and not valid JSON for the
        ///   request's argument type.
        /// - [`ToolError::ToolCallError`] (wrapping [`McpToolError`]) if the transport call
        ///   fails, the server flags the result as an error, or the result contains audio.
        fn call(
            &self,
            args: String,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
            let name = self.definition.name.clone();
            // An empty payload keeps the previous behaviour of sending default (absent)
            // arguments. Anything else must parse: malformed JSON used to be silently
            // replaced with "no arguments", which hid the real problem from the caller.
            let arguments = if args.trim().is_empty() {
                Ok(Default::default())
            } else {
                serde_json::from_str(&args).map_err(ToolError::JsonError)
            };

            Box::pin(async move {
                let arguments = arguments?;
                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if let Some(true) = result.is_error {
                    // Only the first text item is used as the error message.
                    let error_msg = result
                        .content
                        .as_deref()
                        .and_then(|errors| errors.first())
                        .and_then(|error| error.as_text())
                        .map(|raw| raw.text.as_str())
                        .unwrap_or("No error message returned");
                    return Err(McpToolError(error_msg.to_string()).into());
                }

                result
                    .content
                    .into_iter()
                    .flatten()
                    .map(|c| render_content(c.raw))
                    .collect::<Result<String, ToolError>>()
            })
        }
    }
}
333
334pub trait ToolEmbeddingDyn: ToolDyn {
336 fn context(&self) -> serde_json::Result<serde_json::Value>;
337
338 fn embedding_docs(&self) -> Vec<String>;
339}
340
341impl<T> ToolEmbeddingDyn for T
342where
343 T: ToolEmbedding,
344{
345 fn context(&self) -> serde_json::Result<serde_json::Value> {
346 serde_json::to_value(self.context())
347 }
348
349 fn embedding_docs(&self) -> Vec<String> {
350 self.embedding_docs()
351 }
352}
353
354pub(crate) enum ToolType {
355 Simple(Box<dyn ToolDyn>),
356 Embedding(Box<dyn ToolEmbeddingDyn>),
357}
358
359impl ToolType {
360 pub fn name(&self) -> String {
361 match self {
362 ToolType::Simple(tool) => tool.name(),
363 ToolType::Embedding(tool) => tool.name(),
364 }
365 }
366
367 pub async fn definition(&self, prompt: String) -> ToolDefinition {
368 match self {
369 ToolType::Simple(tool) => tool.definition(prompt).await,
370 ToolType::Embedding(tool) => tool.definition(prompt).await,
371 }
372 }
373
374 pub async fn call(&self, args: String) -> Result<String, ToolError> {
375 match self {
376 ToolType::Simple(tool) => tool.call(args).await,
377 ToolType::Embedding(tool) => tool.call(args).await,
378 }
379 }
380}
381
382#[derive(Debug, thiserror::Error)]
383pub enum ToolSetError {
384 #[error("ToolCallError: {0}")]
386 ToolCallError(#[from] ToolError),
387
388 #[error("ToolNotFoundError: {0}")]
389 ToolNotFoundError(String),
390
391 #[error("JsonError: {0}")]
393 JsonError(#[from] serde_json::Error),
394}
395
396#[derive(Default)]
398pub struct ToolSet {
399 pub(crate) tools: HashMap<String, ToolType>,
400}
401
402impl ToolSet {
403 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
405 let mut toolset = Self::default();
406 tools.into_iter().for_each(|tool| {
407 toolset.add_tool(tool);
408 });
409 toolset
410 }
411
412 pub fn builder() -> ToolSetBuilder {
414 ToolSetBuilder::default()
415 }
416
417 pub fn contains(&self, toolname: &str) -> bool {
419 self.tools.contains_key(toolname)
420 }
421
422 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
424 self.tools
425 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
426 }
427
428 pub fn delete_tool(&mut self, tool_name: &str) {
429 let _ = self.tools.remove(tool_name);
430 }
431
432 pub fn add_tools(&mut self, toolset: ToolSet) {
434 self.tools.extend(toolset.tools);
435 }
436
437 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
438 self.tools.get(toolname)
439 }
440
441 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
442 let mut defs = Vec::new();
443 for tool in self.tools.values() {
444 let def = tool.definition(String::new()).await;
445 defs.push(def);
446 }
447 Ok(defs)
448 }
449
450 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
452 if let Some(tool) = self.tools.get(toolname) {
453 tracing::info!(target: "rig",
454 "Calling tool {toolname} with args:\n{}",
455 serde_json::to_string_pretty(&args).unwrap()
456 );
457 Ok(tool.call(args).await?)
458 } else {
459 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
460 }
461 }
462
463 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
465 let mut docs = Vec::new();
466 for tool in self.tools.values() {
467 match tool {
468 ToolType::Simple(tool) => {
469 docs.push(completion::Document {
470 id: tool.name(),
471 text: format!(
472 "\
473 Tool: {}\n\
474 Definition: \n\
475 {}\
476 ",
477 tool.name(),
478 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
479 ),
480 additional_props: HashMap::new(),
481 });
482 }
483 ToolType::Embedding(tool) => {
484 docs.push(completion::Document {
485 id: tool.name(),
486 text: format!(
487 "\
488 Tool: {}\n\
489 Definition: \n\
490 {}\
491 ",
492 tool.name(),
493 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
494 ),
495 additional_props: HashMap::new(),
496 });
497 }
498 }
499 }
500 Ok(docs)
501 }
502
503 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
507 self.tools
508 .values()
509 .filter_map(|tool_type| {
510 if let ToolType::Embedding(tool) = tool_type {
511 Some(ToolSchema::try_from(&**tool))
512 } else {
513 None
514 }
515 })
516 .collect::<Result<Vec<_>, _>>()
517 }
518}
519
520#[derive(Default)]
521pub struct ToolSetBuilder {
522 tools: Vec<ToolType>,
523}
524
525impl ToolSetBuilder {
526 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
527 self.tools.push(ToolType::Simple(Box::new(tool)));
528 self
529 }
530
531 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
532 self.tools.push(ToolType::Embedding(Box::new(tool)));
533 self
534 }
535
536 pub fn build(self) -> ToolSet {
537 ToolSet {
538 tools: self
539 .tools
540 .into_iter()
541 .map(|tool| (tool.name(), tool))
542 .collect(),
543 }
544 }
545}
546
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Operands shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    /// Error type for the arithmetic test tools.
    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool computing `x + y`; its definition is built as a struct literal.
    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"]
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(args.x + args.y)
        }
    }

    /// Test tool computing `x - y`; its definition is deserialized from JSON.
    #[derive(Deserialize, Serialize)]
    struct Subtract;

    impl Tool for Subtract {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"]
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(args.x - args.y)
        }
    }

    /// Builds a toolset holding the `add` and `subtract` tools.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();
        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let definitions = get_test_toolset().get_tool_definitions().await.unwrap();
        assert_eq!(definitions.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);

        toolset.delete_tool("add");

        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }
}
659}