1use std::{collections::HashMap, pin::Pin};
13
14use futures::Future;
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 completion::{self, ToolDefinition},
19 embeddings::{embed::EmbedError, tool::ToolSchema},
20};
21
22#[derive(Debug, thiserror::Error)]
23pub enum ToolError {
24 #[error("ToolCallError: {0}")]
26 ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
27
28 #[error("JsonError: {0}")]
29 JsonError(#[from] serde_json::Error),
30}
31
32pub trait Tool: Sized + Send + Sync {
88 const NAME: &'static str;
90
91 type Error: std::error::Error + Send + Sync + 'static;
93 type Args: for<'a> Deserialize<'a> + Send + Sync;
95 type Output: Serialize;
97
98 fn name(&self) -> String {
100 Self::NAME.to_string()
101 }
102
103 fn definition(&self, _prompt: String) -> impl Future<Output = ToolDefinition> + Send + Sync;
106
107 fn call(
111 &self,
112 args: Self::Args,
113 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;
114}
115
116pub trait ToolEmbedding: Tool {
118 type InitError: std::error::Error + Send + Sync + 'static;
119
120 type Context: for<'a> Deserialize<'a> + Serialize;
125
126 type State: Send;
130
131 fn embedding_docs(&self) -> Vec<String>;
135
136 fn context(&self) -> Self::Context;
138
139 fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
141}
142
143pub trait ToolDyn: Send + Sync {
145 fn name(&self) -> String;
146
147 fn definition(
148 &self,
149 prompt: String,
150 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>>;
151
152 fn call(
153 &self,
154 args: String,
155 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>>;
156}
157
158impl<T: Tool> ToolDyn for T {
159 fn name(&self) -> String {
160 self.name()
161 }
162
163 fn definition(
164 &self,
165 prompt: String,
166 ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
167 Box::pin(<Self as Tool>::definition(self, prompt))
168 }
169
170 fn call(
171 &self,
172 args: String,
173 ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
174 Box::pin(async move {
175 match serde_json::from_str(&args) {
176 Ok(args) => <Self as Tool>::call(self, args)
177 .await
178 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
179 .and_then(|output| {
180 serde_json::to_string(&output).map_err(ToolError::JsonError)
181 }),
182 Err(e) => Err(ToolError::JsonError(e)),
183 }
184 })
185 }
186}
187
/// Adapter exposing tools served by a Model Context Protocol server (via the `rmcp`
/// crate) as [`ToolDyn`] tools. Only compiled with the `rmcp` feature.
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
#[cfg(feature = "rmcp")]
pub mod rmcp {
    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use rmcp::model::RawContent;
    use std::pin::Pin;

    /// A tool hosted on an MCP server, invoked through the server's client sink.
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wraps the MCP tool `definition` together with the `client` used to call it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                // A missing description becomes an empty string.
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            // Same conversion as the borrowed form; kept in one place.
            Self::from(&val)
        }
    }

    /// Error reported by an MCP tool call (transport failure, tool-reported error,
    /// or unsupported result content).
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Renders one piece of MCP tool output as text.
    ///
    /// Text passes through unchanged; images become `data:<mime>;base64,<data>`;
    /// embedded resources become `[data:<mime>;]<uri>:<text-or-blob>`. Audio is not
    /// supported yet and yields an error instead of panicking.
    fn content_to_text(content: RawContent) -> Result<String, ToolError> {
        let rendered = match content {
            RawContent::Text(text) => text.text,
            RawContent::Image(image) => format!("data:{};base64,{}", image.mime_type, image.data),
            RawContent::Resource(embedded) => match embedded.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            },
            RawContent::Audio(_) => {
                // Report instead of panicking: a server's choice of content type must
                // not be able to crash the caller.
                return Err(McpToolError(
                    "Support for audio results from an MCP tool is currently unimplemented. Come back later!"
                        .to_string(),
                )
                .into());
            }
        };
        Ok(rendered)
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(
            &self,
            _prompt: String,
        ) -> Pin<Box<dyn Future<Output = ToolDefinition> + Send + Sync + '_>> {
            // Reuse the `From<&rmcp::model::Tool>` conversion so both paths agree.
            Box::pin(async move { ToolDefinition::from(&self.definition) })
        }

        /// Calls the remote tool.
        ///
        /// # Errors
        /// - [`ToolError::JsonError`] if non-empty `args` is not a valid JSON object (or `null`).
        /// - [`ToolError::ToolCallError`] if the request fails, the server flags the
        ///   result as an error, or the result contains unsupported (audio) content.
        fn call(
            &self,
            args: String,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + '_>> {
            let name = self.definition.name.clone();

            Box::pin(async move {
                // Parameterless tools are often called with an empty argument string;
                // treat that as "no arguments". Anything else must be valid JSON rather
                // than being silently dropped.
                let arguments = if args.trim().is_empty() {
                    None
                } else {
                    serde_json::from_str(&args).map_err(ToolError::JsonError)?
                };

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error == Some(true) {
                    // Only the first text item is used as the error message.
                    let error_msg = result
                        .content
                        .as_deref()
                        .and_then(|contents| contents.first())
                        .and_then(|content| content.as_text())
                        .map(|text| text.text.as_str())
                        .unwrap_or("No error message returned");
                    return Err(McpToolError(error_msg.to_string()).into());
                }

                // All content items are rendered and concatenated without a separator.
                result
                    .content
                    .into_iter()
                    .flatten()
                    .map(|content| content_to_text(content.raw))
                    .collect::<Result<String, ToolError>>()
            })
        }
    }
}
333
334pub trait ToolEmbeddingDyn: ToolDyn {
336 fn context(&self) -> serde_json::Result<serde_json::Value>;
337
338 fn embedding_docs(&self) -> Vec<String>;
339}
340
341impl<T: ToolEmbedding> ToolEmbeddingDyn for T {
342 fn context(&self) -> serde_json::Result<serde_json::Value> {
343 serde_json::to_value(self.context())
344 }
345
346 fn embedding_docs(&self) -> Vec<String> {
347 self.embedding_docs()
348 }
349}
350
351pub(crate) enum ToolType {
352 Simple(Box<dyn ToolDyn>),
353 Embedding(Box<dyn ToolEmbeddingDyn>),
354}
355
356impl ToolType {
357 pub fn name(&self) -> String {
358 match self {
359 ToolType::Simple(tool) => tool.name(),
360 ToolType::Embedding(tool) => tool.name(),
361 }
362 }
363
364 pub async fn definition(&self, prompt: String) -> ToolDefinition {
365 match self {
366 ToolType::Simple(tool) => tool.definition(prompt).await,
367 ToolType::Embedding(tool) => tool.definition(prompt).await,
368 }
369 }
370
371 pub async fn call(&self, args: String) -> Result<String, ToolError> {
372 match self {
373 ToolType::Simple(tool) => tool.call(args).await,
374 ToolType::Embedding(tool) => tool.call(args).await,
375 }
376 }
377}
378
379#[derive(Debug, thiserror::Error)]
380pub enum ToolSetError {
381 #[error("ToolCallError: {0}")]
383 ToolCallError(#[from] ToolError),
384
385 #[error("ToolNotFoundError: {0}")]
386 ToolNotFoundError(String),
387
388 #[error("JsonError: {0}")]
390 JsonError(#[from] serde_json::Error),
391}
392
393#[derive(Default)]
395pub struct ToolSet {
396 pub(crate) tools: HashMap<String, ToolType>,
397}
398
399impl ToolSet {
400 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
402 let mut toolset = Self::default();
403 tools.into_iter().for_each(|tool| {
404 toolset.add_tool(tool);
405 });
406 toolset
407 }
408
409 pub fn builder() -> ToolSetBuilder {
411 ToolSetBuilder::default()
412 }
413
414 pub fn contains(&self, toolname: &str) -> bool {
416 self.tools.contains_key(toolname)
417 }
418
419 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
421 self.tools
422 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
423 }
424
425 pub fn delete_tool(&mut self, tool_name: &str) {
426 let _ = self.tools.remove(tool_name);
427 }
428
429 pub fn add_tools(&mut self, toolset: ToolSet) {
431 self.tools.extend(toolset.tools);
432 }
433
434 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
435 self.tools.get(toolname)
436 }
437
438 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
439 let mut defs = Vec::new();
440 for tool in self.tools.values() {
441 let def = tool.definition(String::new()).await;
442 defs.push(def);
443 }
444 Ok(defs)
445 }
446
447 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
449 if let Some(tool) = self.tools.get(toolname) {
450 tracing::info!(target: "rig",
451 "Calling tool {toolname} with args:\n{}",
452 serde_json::to_string_pretty(&args).unwrap()
453 );
454 Ok(tool.call(args).await?)
455 } else {
456 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
457 }
458 }
459
460 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
462 let mut docs = Vec::new();
463 for tool in self.tools.values() {
464 match tool {
465 ToolType::Simple(tool) => {
466 docs.push(completion::Document {
467 id: tool.name(),
468 text: format!(
469 "\
470 Tool: {}\n\
471 Definition: \n\
472 {}\
473 ",
474 tool.name(),
475 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
476 ),
477 additional_props: HashMap::new(),
478 });
479 }
480 ToolType::Embedding(tool) => {
481 docs.push(completion::Document {
482 id: tool.name(),
483 text: format!(
484 "\
485 Tool: {}\n\
486 Definition: \n\
487 {}\
488 ",
489 tool.name(),
490 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
491 ),
492 additional_props: HashMap::new(),
493 });
494 }
495 }
496 }
497 Ok(docs)
498 }
499
500 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
504 self.tools
505 .values()
506 .filter_map(|tool_type| {
507 if let ToolType::Embedding(tool) = tool_type {
508 Some(ToolSchema::try_from(&**tool))
509 } else {
510 None
511 }
512 })
513 .collect::<Result<Vec<_>, _>>()
514 }
515}
516
517#[derive(Default)]
518pub struct ToolSetBuilder {
519 tools: Vec<ToolType>,
520}
521
522impl ToolSetBuilder {
523 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
524 self.tools.push(ToolType::Simple(Box::new(tool)));
525 self
526 }
527
528 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
529 self.tools.push(ToolType::Embedding(Box::new(tool)));
530 self
531 }
532
533 pub fn build(self) -> ToolSet {
534 ToolSet {
535 tools: self
536 .tools
537 .into_iter()
538 .map(|tool| (tool.name(), tool))
539 .collect(),
540 }
541 }
542}
543
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a toolset with two arithmetic tools: `add` (x + y) and `subtract` (x - y).
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    /// Successful calls return the tool output serialized as JSON.
    #[tokio::test]
    async fn test_call_tool() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", json!({"x": 2, "y": 3}).to_string())
            .await
            .unwrap();
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .unwrap();
        assert_eq!(difference, "-1");
    }

    /// Calling an unregistered tool reports its name.
    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolSetError::ToolNotFoundError(name) if name == "multiply"));
    }

    /// Arguments that are not valid JSON surface as a wrapped JSON error.
    #[tokio::test]
    async fn test_call_invalid_args() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    /// Each tool yields a document whose text starts with its name header.
    #[tokio::test]
    async fn test_documents() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);
        for doc in docs {
            let header = format!("Tool: {}\nDefinition: \n", doc.id);
            assert!(doc.text.starts_with(&header));
        }
    }
}