1pub mod server;
13use std::collections::HashMap;
14
15use futures::Future;
16use serde::{Deserialize, Serialize};
17
18use crate::{
19 completion::{self, ToolDefinition},
20 embeddings::{embed::EmbedError, tool::ToolSchema},
21 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
22};
23
24#[derive(Debug, thiserror::Error)]
25pub enum ToolError {
26 #[cfg(not(target_family = "wasm"))]
27 #[error("ToolCallError: {0}")]
29 ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
30
31 #[cfg(target_family = "wasm")]
32 #[error("ToolCallError: {0}")]
34 ToolCallError(#[from] Box<dyn std::error::Error>),
35
36 #[error("JsonError: {0}")]
37 JsonError(#[from] serde_json::Error),
38}
39
40pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
96 const NAME: &'static str;
98
99 type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
101 type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
103 type Output: Serialize;
105
106 fn name(&self) -> String {
108 Self::NAME.to_string()
109 }
110
111 fn definition(
114 &self,
115 _prompt: String,
116 ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
117
118 fn call(
122 &self,
123 args: Self::Args,
124 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
125}
126
127pub trait ToolEmbedding: Tool {
129 type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
130
131 type Context: for<'a> Deserialize<'a> + Serialize;
136
137 type State: WasmCompatSend;
141
142 fn embedding_docs(&self) -> Vec<String>;
146
147 fn context(&self) -> Self::Context;
149
150 fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
152}
153
154pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
156 fn name(&self) -> String;
157
158 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
159
160 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
161}
162
163impl<T: Tool> ToolDyn for T {
164 fn name(&self) -> String {
165 self.name()
166 }
167
168 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
169 Box::pin(<Self as Tool>::definition(self, prompt))
170 }
171
172 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
173 Box::pin(async move {
174 match serde_json::from_str(&args) {
175 Ok(args) => <Self as Tool>::call(self, args)
176 .await
177 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
178 .and_then(|output| {
179 serde_json::to_string(&output).map_err(ToolError::JsonError)
180 }),
181 Err(e) => Err(ToolError::JsonError(e)),
182 }
183 })
184 }
185}
186
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Integration with the `rmcp` crate: exposes tools served by an MCP server as
    //! [`ToolDyn`] tools.

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use crate::wasm_compat::WasmBoxedFuture;
    use rmcp::model::{RawContent, ResourceContents};

    /// A tool hosted on an MCP server. It is called through the server's `ServerSink`.
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Wraps an MCP tool `definition` together with the `client` used to invoke it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        /// Converts an MCP tool description into a [`ToolDefinition`].
        ///
        /// A missing description becomes the empty string.
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            Self::from(&val)
        }
    }

    /// Error reported by an MCP tool call, carrying a human-readable message.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            // Reuse the `From` conversion so every code path reports the same schema.
            let definition = ToolDefinition::from(&self.definition);
            Box::pin(async move { definition })
        }

        /// Calls the remote tool and renders its content as one string.
        ///
        /// # Errors
        /// - [`ToolError::JsonError`] if `args` is non-blank and not valid JSON for the
        ///   MCP argument object.
        /// - [`ToolError::ToolCallError`] in three cases: the request fails, the server
        ///   flags the result as an error, or the result contains content that cannot
        ///   be rendered as text.
        fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            let name = self.definition.name.clone();

            Box::pin(async move {
                // A blank argument string means "no arguments". Anything else must parse;
                // otherwise the tool would silently run without the arguments the model
                // supplied.
                let arguments = if args.trim().is_empty() {
                    None
                } else {
                    serde_json::from_str(&args).map_err(ToolError::JsonError)?
                };

                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error == Some(true) {
                    // Keep every text part the server sent; non-text parts are skipped
                    // instead of discarding the whole message.
                    let texts: Vec<String> = result
                        .content
                        .into_iter()
                        .filter_map(|content| match content.raw {
                            RawContent::Text(raw) => Some(raw.text),
                            _ => None,
                        })
                        .collect();
                    let message = if texts.is_empty() {
                        "No message returned".to_string()
                    } else {
                        texts.join("\n")
                    };
                    return Err(ToolError::from(McpToolError(message)));
                }

                result
                    .content
                    .into_iter()
                    .map(|content| content_to_string(content.raw))
                    .collect::<Result<String, ToolError>>()
            })
        }
    }

    /// Renders one piece of MCP tool output as text.
    ///
    /// # Errors
    /// Returns a [`ToolError`] for audio and for any content variant without a text
    /// rendering. Content comes from an external server, so this must not panic.
    fn content_to_string(content: RawContent) -> Result<String, ToolError> {
        let rendered = match content {
            RawContent::Text(raw) => raw.text,
            RawContent::Image(raw) => format!("data:{};base64,{}", raw.mime_type, raw.data),
            RawContent::Resource(raw) => match raw.resource {
                ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!("{}{uri}:{text}", data_uri_prefix(mime_type)),
                ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!("{}{uri}:{blob}", data_uri_prefix(mime_type)),
            },
            RawContent::Audio(_) => {
                return Err(McpToolError(
                    "Support for audio results from an MCP tool is currently unimplemented."
                        .to_string(),
                )
                .into());
            }
            other => {
                return Err(McpToolError(format!("Unsupported type found: {other:?}")).into());
            }
        };
        Ok(rendered)
    }

    /// Returns `"data:<mime>;"` when a MIME type is present, otherwise an empty string.
    fn data_uri_prefix(mime_type: Option<String>) -> String {
        mime_type.map(|m| format!("data:{m};")).unwrap_or_default()
    }
}
335
336pub trait ToolEmbeddingDyn: ToolDyn {
338 fn context(&self) -> serde_json::Result<serde_json::Value>;
339
340 fn embedding_docs(&self) -> Vec<String>;
341}
342
343impl<T> ToolEmbeddingDyn for T
344where
345 T: ToolEmbedding + 'static,
346{
347 fn context(&self) -> serde_json::Result<serde_json::Value> {
348 serde_json::to_value(self.context())
349 }
350
351 fn embedding_docs(&self) -> Vec<String> {
352 self.embedding_docs()
353 }
354}
355
356pub(crate) enum ToolType {
357 Simple(Box<dyn ToolDyn>),
358 Embedding(Box<dyn ToolEmbeddingDyn>),
359}
360
361impl ToolType {
362 pub fn name(&self) -> String {
363 match self {
364 ToolType::Simple(tool) => tool.name(),
365 ToolType::Embedding(tool) => tool.name(),
366 }
367 }
368
369 pub async fn definition(&self, prompt: String) -> ToolDefinition {
370 match self {
371 ToolType::Simple(tool) => tool.definition(prompt).await,
372 ToolType::Embedding(tool) => tool.definition(prompt).await,
373 }
374 }
375
376 pub async fn call(&self, args: String) -> Result<String, ToolError> {
377 match self {
378 ToolType::Simple(tool) => tool.call(args).await,
379 ToolType::Embedding(tool) => tool.call(args).await,
380 }
381 }
382}
383
384#[derive(Debug, thiserror::Error)]
385pub enum ToolSetError {
386 #[error("ToolCallError: {0}")]
388 ToolCallError(#[from] ToolError),
389
390 #[error("ToolNotFoundError: {0}")]
392 ToolNotFoundError(String),
393
394 #[error("JsonError: {0}")]
396 JsonError(#[from] serde_json::Error),
397
398 #[error("Tool call interrupted")]
400 Interrupted,
401}
402
403#[derive(Default)]
405pub struct ToolSet {
406 pub(crate) tools: HashMap<String, ToolType>,
407}
408
409impl ToolSet {
410 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
412 let mut toolset = Self::default();
413 tools.into_iter().for_each(|tool| {
414 toolset.add_tool(tool);
415 });
416 toolset
417 }
418
419 pub fn builder() -> ToolSetBuilder {
421 ToolSetBuilder::default()
422 }
423
424 pub fn contains(&self, toolname: &str) -> bool {
426 self.tools.contains_key(toolname)
427 }
428
429 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
431 self.tools
432 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
433 }
434
435 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
437 self.tools.insert(tool.name(), ToolType::Simple(tool));
438 }
439
440 pub fn delete_tool(&mut self, tool_name: &str) {
441 let _ = self.tools.remove(tool_name);
442 }
443
444 pub fn add_tools(&mut self, toolset: ToolSet) {
446 self.tools.extend(toolset.tools);
447 }
448
449 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
450 self.tools.get(toolname)
451 }
452
453 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
454 let mut defs = Vec::new();
455 for tool in self.tools.values() {
456 let def = tool.definition(String::new()).await;
457 defs.push(def);
458 }
459 Ok(defs)
460 }
461
462 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
464 if let Some(tool) = self.tools.get(toolname) {
465 tracing::info!(target: "rig",
466 "Calling tool {toolname} with args:\n{}",
467 serde_json::to_string_pretty(&args).unwrap()
468 );
469 Ok(tool.call(args).await?)
470 } else {
471 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
472 }
473 }
474
475 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
477 let mut docs = Vec::new();
478 for tool in self.tools.values() {
479 match tool {
480 ToolType::Simple(tool) => {
481 docs.push(completion::Document {
482 id: tool.name(),
483 text: format!(
484 "\
485 Tool: {}\n\
486 Definition: \n\
487 {}\
488 ",
489 tool.name(),
490 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
491 ),
492 additional_props: HashMap::new(),
493 });
494 }
495 ToolType::Embedding(tool) => {
496 docs.push(completion::Document {
497 id: tool.name(),
498 text: format!(
499 "\
500 Tool: {}\n\
501 Definition: \n\
502 {}\
503 ",
504 tool.name(),
505 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
506 ),
507 additional_props: HashMap::new(),
508 });
509 }
510 }
511 }
512 Ok(docs)
513 }
514
515 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
519 self.tools
520 .values()
521 .filter_map(|tool_type| {
522 if let ToolType::Embedding(tool) = tool_type {
523 Some(ToolSchema::try_from(&**tool))
524 } else {
525 None
526 }
527 })
528 .collect::<Result<Vec<_>, _>>()
529 }
530}
531
532#[derive(Default)]
533pub struct ToolSetBuilder {
534 tools: Vec<ToolType>,
535}
536
537impl ToolSetBuilder {
538 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
539 self.tools.push(ToolType::Simple(Box::new(tool)));
540 self
541 }
542
543 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
544 self.tools.push(ToolType::Embedding(Box::new(tool)));
545 self
546 }
547
548 pub fn build(self) -> ToolSet {
549 ToolSet {
550 tools: self
551 .tools
552 .into_iter()
553 .map(|tool| (tool.name(), tool))
554 .collect(),
555 }
556 }
557}
558
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a [`ToolSet`] with two arithmetic tools: `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    /// `ToolSet::call` routes to the named tool and returns its JSON-serialized output.
    #[tokio::test]
    async fn test_call_dispatches_to_named_tool() {
        let toolset = get_test_toolset();
        let args = json!({ "x": 2, "y": 3 }).to_string();

        let sum = toolset.call("add", args.clone()).await.unwrap();
        assert_eq!(sum, "5");

        let difference = toolset.call("subtract", args).await.unwrap();
        assert_eq!(difference, "-1");
    }

    /// Calling an unregistered tool reports its name in `ToolNotFoundError`.
    #[tokio::test]
    async fn test_call_unknown_tool() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            &err,
            ToolSetError::ToolNotFoundError(name) if name == "multiply"
        ));
    }

    /// Arguments that are not valid JSON surface as a wrapped `JsonError`.
    #[tokio::test]
    async fn test_call_with_malformed_args() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("add", "not json".to_string())
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ToolSetError::ToolCallError(ToolError::JsonError(_))
        ));
    }

    /// Each tool is rendered as a document whose id is the tool name.
    #[tokio::test]
    async fn test_documents() {
        let toolset = get_test_toolset();
        let docs = toolset.documents().await.unwrap();
        assert_eq!(docs.len(), 2);

        let add = docs
            .iter()
            .find(|doc| doc.id == "add")
            .expect("document for the `add` tool");
        assert!(add.text.starts_with("Tool: add\nDefinition: \n"));
    }
}
671}