1use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 completion::{self, ToolDefinition},
19 embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
/// Errors produced while invoking a tool through the object-safe [`ToolDyn`] interface.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool's own call failed. The underlying error is boxed so that tools with
    /// arbitrary error types can be reported through one type.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// Decoding the JSON arguments or encoding the tool output as JSON failed.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
}
31
32pub trait Tool: Sized + Send + Sync {
88 const NAME: &'static str;
90
91 type Error: std::error::Error + Send + Sync + 'static;
93 type Args: for<'a> Deserialize<'a> + Send + Sync;
95 type Output: Serialize;
97
98 fn name(&self) -> String {
100 Self::NAME.to_string()
101 }
102
103 fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;
106
107 fn call(
111 &self,
112 args: Self::Args,
113 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send + Sync;
114}
115
/// A [`Tool`] that can additionally be described by embedding documents and rebuilt
/// from a serializable context.
///
/// Through the blanket [`ToolEmbeddingDyn`] impl, such tools can be registered with
/// [`ToolSetBuilder::dynamic_tool`] and turned into [`ToolSchema`]s via [`ToolSet::schemas`].
pub trait ToolEmbedding: Tool {
    /// Error returned when [`ToolEmbedding::init`] fails.
    type InitError: std::error::Error + Send + Sync + 'static;

    /// Serializable context for this tool instance. It is converted to JSON by
    /// [`ToolEmbeddingDyn::context`] and accepted back by [`ToolEmbedding::init`].
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Extra state required by [`ToolEmbedding::init`] to rebuild the tool.
    /// It has no serde bounds, so it is supplied separately from the context.
    type State: Send;

    /// Text documents describing the tool (presumably the text that gets embedded
    /// for tool retrieval; confirm against the embedding pipeline).
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns the context describing this tool instance.
    fn context(&self) -> Self::Context;

    /// Rebuilds the tool from `state` and a previously obtained `context`.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
142
/// Object-safe counterpart of [`Tool`], operating on string arguments and outputs so
/// that heterogeneous tools can be stored behind `Box<dyn ToolDyn>` (e.g. in a [`ToolSet`]).
pub trait ToolDyn: Send + Sync {
    /// Name the tool is registered under in a [`ToolSet`].
    fn name(&self) -> String;

    /// Returns a boxed future resolving to the tool's [`ToolDefinition`] for `prompt`.
    fn definition(
        &self,
        prompt: String,
    ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;

    /// Invokes the tool with the JSON-encoded `args`, resolving to its output as a `String`.
    fn call(
        &self,
        args: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>>;
}
157
158impl<T: Tool> ToolDyn for T {
159 fn name(&self) -> String {
160 self.name()
161 }
162
163 fn definition(
164 &self,
165 prompt: String,
166 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167 Box::pin(<Self as Tool>::definition(self, prompt))
168 }
169
170 fn call(
171 &self,
172 args: String,
173 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
174 Box::pin(async move {
175 match serde_json::from_str(&args) {
176 Ok(args) => <Self as Tool>::call(self, args)
177 .await
178 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179 .and_then(|output| {
180 serde_json::to_string(&output).map_err(ToolError::JsonError)
181 }),
182 Err(e) => Err(ToolError::JsonError(e)),
183 }
184 })
185 }
186}
187
#[deprecated(
    since = "0.16.0",
    note = "Since the official Rust MCP SDK (`rmcp`) has been added, the original Rig MCP integration with `mcp-core` has now been deprecated.
Please migrate over to the new integration - you can use the `rmcp` feature flag to do so.
This integration will be fully removed (and replaced with the `rmcp` one) by 0.18.0 at the earliest.
A guide can be found at `http://docs.rig.rs`, Rig's official docsite."
)]
#[cfg(feature = "mcp")]
pub mod mcp {
    //! Legacy MCP tool integration built on `mcp-core` (superseded by the `rmcp` module).

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use std::pin::Pin;

    /// A single tool advertised by an MCP server, invoked through an `mcp-core` client.
    pub struct McpTool<T: mcp_core::transport::Transport> {
        definition: mcp_core::types::Tool,
        client: mcp_core::client::Client<T>,
    }

    impl<T> McpTool<T>
    where
        T: mcp_core::transport::Transport,
    {
        /// Pairs a tool `definition` obtained from an MCP server with the `client`
        /// used to invoke it.
        pub fn from_mcp_server(
            definition: mcp_core::types::Tool,
            client: mcp_core::client::Client<T>,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&mcp_core::types::Tool> for ToolDefinition {
        fn from(val: &mcp_core::types::Tool) -> Self {
            Self {
                name: val.name.to_owned(),
                description: val.description.to_owned().unwrap_or_default(),
                parameters: val.input_schema.to_owned(),
            }
        }
    }

    impl From<mcp_core::types::Tool> for ToolDefinition {
        fn from(val: mcp_core::types::Tool) -> Self {
            Self {
                name: val.name,
                description: val.description.unwrap_or_default(),
                parameters: val.input_schema,
            }
        }
    }

    /// Error reported by, or while talking to, an MCP server during a tool call.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Parses the model-supplied argument string for an MCP tool call.
    ///
    /// An empty or whitespace-only string means "no arguments" and yields `None`.
    /// Anything else must be valid JSON. A parse failure is returned to the caller
    /// rather than silently replaced by `null` arguments.
    fn parse_arguments(args: &str) -> Result<Option<serde_json::Value>, serde_json::Error> {
        if args.trim().is_empty() {
            return Ok(None);
        }
        serde_json::from_str(args).map(Some)
    }

    /// Renders one piece of MCP tool output as text. Binary payloads become
    /// `data:` URIs; embedded resources are rendered as an optional
    /// `data:<mime>;` prefix followed by the resource URI.
    fn content_to_string(content: mcp_core::types::ToolResponseContent) -> String {
        use mcp_core::types::ToolResponseContent;

        match content {
            ToolResponseContent::Text(text_content) => text_content.text,
            ToolResponseContent::Image(image_content) => format!(
                "data:{};base64,{}",
                image_content.mime_type, image_content.data
            ),
            ToolResponseContent::Audio(audio_content) => format!(
                "data:{};base64,{}",
                audio_content.mime_type, audio_content.data
            ),
            ToolResponseContent::Resource(embedded_resource) => format!(
                "{}{}",
                embedded_resource
                    .resource
                    .mime_type
                    .map(|m| format!("data:{m};"))
                    .unwrap_or_default(),
                embedded_resource.resource.uri
            ),
        }
    }

    impl<T> ToolDyn for McpTool<T>
    where
        T: mcp_core::transport::Transport,
    {
        fn name(&self) -> String {
            self.definition.name.clone()
        }

        fn definition(
            &self,
            _prompt: String,
        ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
            // Reuse the `From` conversion so both paths build the definition identically.
            Box::pin(async move { ToolDefinition::from(&self.definition) })
        }

        /// Calls the remote tool.
        ///
        /// # Errors
        ///
        /// * [`ToolError::JsonError`] if `args` is non-empty and not valid JSON.
        /// * [`ToolError::ToolCallError`] (wrapping [`McpToolError`]) if the request
        ///   fails or the server flags the result as an error.
        fn call(
            &self,
            args: String,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
            let name = self.definition.name.clone();
            Box::pin(async move {
                let arguments = parse_arguments(&args)?;

                let result = self
                    .client
                    .call_tool(&name, arguments)
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error.unwrap_or(false) {
                    // Only a leading text item can be surfaced as a readable message.
                    let message = match result.content.first() {
                        Some(mcp_core::types::ToolResponseContent::Text(text_content)) => {
                            text_content.text.clone()
                        }
                        Some(_) => "Unsupported error type".to_string(),
                        None => "No error message returned".to_string(),
                    };
                    return Err(McpToolError(message).into());
                }

                Ok(result
                    .content
                    .into_iter()
                    .map(content_to_string)
                    .collect::<String>())
            })
        }
    }
}
343
#[cfg(feature = "rmcp")]
pub mod rmcp {
    //! MCP tool integration built on the official Rust MCP SDK (`rmcp`).

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use rmcp::model::RawContent;
    use std::borrow::Cow;
    use std::pin::Pin;

    /// A single tool advertised by an MCP server, invoked through an `rmcp` server sink.
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Pairs a tool `definition` obtained from an MCP server with the `client`
        /// used to invoke it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                description: val.description.as_deref().unwrap_or_default().to_owned(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            // Single source of truth for the conversion; see the borrowed impl above.
            Self::from(&val)
        }
    }

    /// Error reported by, or while talking to, an MCP server during a tool call.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Parses the model-supplied argument string into MCP call arguments.
    ///
    /// An empty or whitespace-only string, or JSON `null`, means "no arguments".
    /// Any other input must be a JSON object. A parse failure is returned to the
    /// caller rather than silently dropping the arguments.
    fn parse_arguments(
        args: &str,
    ) -> Result<Option<serde_json::Map<String, serde_json::Value>>, serde_json::Error> {
        if args.trim().is_empty() {
            return Ok(None);
        }
        serde_json::from_str(args)
    }

    /// Renders one piece of MCP tool output as text.
    ///
    /// Binary payloads (images, audio) become base64 `data:` URIs. Embedded resources
    /// are rendered as an optional `data:<mime>;` prefix, the URI, `:`, and the payload.
    fn content_to_string(content: RawContent) -> String {
        match content {
            RawContent::Text(raw) => raw.text,
            RawContent::Image(raw) => format!("data:{};base64,{}", raw.mime_type, raw.data),
            // Previously `unimplemented!()`: a server returning audio would panic the caller.
            RawContent::Audio(audio) => {
                format!("data:{};base64,{}", audio.mime_type, audio.data)
            }
            RawContent::Resource(raw) => match raw.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            },
        }
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(
            &self,
            _prompt: String,
        ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
            // Reuse the `From` conversion so both paths build the definition identically.
            Box::pin(async move { ToolDefinition::from(&self.definition) })
        }

        /// Calls the remote tool.
        ///
        /// # Errors
        ///
        /// * [`ToolError::JsonError`] if `args` is non-empty and not a JSON object (or `null`).
        /// * [`ToolError::ToolCallError`] (wrapping [`McpToolError`]) if the request
        ///   fails or the server flags the result as an error.
        fn call(
            &self,
            args: String,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + Sync + '_>> {
            let name: Cow<'static, str> = self.definition.name.clone();

            Box::pin(async move {
                let arguments = parse_arguments(&args)?;

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error.unwrap_or(false) {
                    // Only a leading text item can be surfaced as a readable message.
                    let message = match result.content.first() {
                        Some(content) => content
                            .as_text()
                            .map(|raw| raw.text.clone())
                            .unwrap_or_else(|| "Unsupported error type".to_string()),
                        None => "No error message returned".to_string(),
                    };
                    return Err(McpToolError(message).into());
                }

                Ok(result
                    .content
                    .into_iter()
                    .map(|content| content_to_string(content.raw))
                    .collect::<String>())
            })
        }
    }
}
489
/// Object-safe counterpart of [`ToolEmbedding`], layered on top of [`ToolDyn`].
pub trait ToolEmbeddingDyn: ToolDyn {
    /// The tool's context serialized to a JSON value; fails if serialization fails.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Text documents describing the tool (see [`ToolEmbedding::embedding_docs`]).
    fn embedding_docs(&self) -> Vec<String>;
}
496
497impl<T: ToolEmbedding> ToolEmbeddingDyn for T {
498 fn context(&self) -> serde_json::Result<serde_json::Value> {
499 serde_json::to_value(self.context())
500 }
501
502 fn embedding_docs(&self) -> Vec<String> {
503 self.embedding_docs()
504 }
505}
506
/// A tool stored in a [`ToolSet`], tagged by how it was registered.
pub(crate) enum ToolType {
    /// Registered via [`ToolSet::add_tool`] or [`ToolSetBuilder::static_tool`].
    Simple(Box<dyn ToolDyn>),
    /// Registered via [`ToolSetBuilder::dynamic_tool`]; only these contribute to
    /// [`ToolSet::schemas`].
    Embedding(Box<dyn ToolEmbeddingDyn>),
}
511
512impl ToolType {
513 pub fn name(&self) -> String {
514 match self {
515 ToolType::Simple(tool) => tool.name(),
516 ToolType::Embedding(tool) => tool.name(),
517 }
518 }
519
520 pub async fn definition(&self, prompt: String) -> ToolDefinition {
521 match self {
522 ToolType::Simple(tool) => tool.definition(prompt).await,
523 ToolType::Embedding(tool) => tool.definition(prompt).await,
524 }
525 }
526
527 pub async fn call(&self, args: String) -> Result<String, ToolError> {
528 match self {
529 ToolType::Simple(tool) => tool.call(args).await,
530 ToolType::Embedding(tool) => tool.call(args).await,
531 }
532 }
533}
534
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The tool was found, but invoking it failed.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool is registered under the requested name (carried in the variant).
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// JSON serialization failed, e.g. while rendering definitions in [`ToolSet::documents`].
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),
}
548
/// A collection of tools keyed by tool name.
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools. Inserting under an existing name replaces the previous tool.
    pub(crate) tools: HashMap<String, ToolType>,
}
554
555impl ToolSet {
556 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
558 let mut toolset = Self::default();
559 tools.into_iter().for_each(|tool| {
560 toolset.add_tool(tool);
561 });
562 toolset
563 }
564
565 pub fn builder() -> ToolSetBuilder {
567 ToolSetBuilder::default()
568 }
569
570 pub fn contains(&self, toolname: &str) -> bool {
572 self.tools.contains_key(toolname)
573 }
574
575 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
577 self.tools
578 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
579 }
580
581 pub fn add_tools(&mut self, toolset: ToolSet) {
583 self.tools.extend(toolset.tools);
584 }
585
586 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
587 self.tools.get(toolname)
588 }
589
590 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
592 if let Some(tool) = self.tools.get(toolname) {
593 tracing::info!(target: "rig",
594 "Calling tool {toolname} with args:\n{}",
595 serde_json::to_string_pretty(&args).unwrap()
596 );
597 Ok(tool.call(args).await?)
598 } else {
599 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
600 }
601 }
602
603 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
605 let mut docs = Vec::new();
606 for tool in self.tools.values() {
607 match tool {
608 ToolType::Simple(tool) => {
609 docs.push(completion::Document {
610 id: tool.name(),
611 text: format!(
612 "\
613 Tool: {}\n\
614 Definition: \n\
615 {}\
616 ",
617 tool.name(),
618 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
619 ),
620 additional_props: HashMap::new(),
621 });
622 }
623 ToolType::Embedding(tool) => {
624 docs.push(completion::Document {
625 id: tool.name(),
626 text: format!(
627 "\
628 Tool: {}\n\
629 Definition: \n\
630 {}\
631 ",
632 tool.name(),
633 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
634 ),
635 additional_props: HashMap::new(),
636 });
637 }
638 }
639 }
640 Ok(docs)
641 }
642
643 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
647 self.tools
648 .values()
649 .filter_map(|tool_type| {
650 if let ToolType::Embedding(tool) = tool_type {
651 Some(ToolSchema::try_from(&**tool))
652 } else {
653 None
654 }
655 })
656 .collect::<Result<Vec<_>, _>>()
657 }
658}
659
/// Incrementally assembles a [`ToolSet`] from static and dynamic tools.
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools in insertion order; they are keyed by name only in [`ToolSetBuilder::build`].
    tools: Vec<ToolType>,
}
664
665impl ToolSetBuilder {
666 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
667 self.tools.push(ToolType::Simple(Box::new(tool)));
668 self
669 }
670
671 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
672 self.tools.push(ToolType::Embedding(Box::new(tool)));
673 self
674 }
675
676 pub fn build(self) -> ToolSet {
677 ToolSet {
678 tools: self
679 .tools
680 .into_iter()
681 .map(|tool| (tool.name(), tool))
682 .collect(),
683 }
684 }
685}