1pub mod server;
13use std::collections::HashMap;
14
15use futures::Future;
16use serde::{Deserialize, Serialize};
17
18use crate::{
19 completion::{self, ToolDefinition},
20 embeddings::{embed::EmbedError, tool::ToolSchema},
21 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
22};
23
24#[derive(Debug, thiserror::Error)]
25pub enum ToolError {
26 #[cfg(not(target_family = "wasm"))]
27 #[error("ToolCallError: {0}")]
29 ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
30
31 #[cfg(target_family = "wasm")]
32 #[error("ToolCallError: {0}")]
34 ToolCallError(#[from] Box<dyn std::error::Error>),
35
36 #[error("JsonError: {0}")]
37 JsonError(#[from] serde_json::Error),
38}
39
40pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
96 const NAME: &'static str;
98
99 type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
101 type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
103 type Output: Serialize;
105
106 fn name(&self) -> String {
108 Self::NAME.to_string()
109 }
110
111 fn definition(
114 &self,
115 _prompt: String,
116 ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
117
118 fn call(
122 &self,
123 args: Self::Args,
124 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
125}
126
127pub trait ToolEmbedding: Tool {
129 type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
130
131 type Context: for<'a> Deserialize<'a> + Serialize;
136
137 type State: WasmCompatSend;
141
142 fn embedding_docs(&self) -> Vec<String>;
146
147 fn context(&self) -> Self::Context;
149
150 fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
152}
153
154pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
156 fn name(&self) -> String;
157
158 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
159
160 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
161}
162
163impl<T: Tool> ToolDyn for T {
164 fn name(&self) -> String {
165 self.name()
166 }
167
168 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
169 Box::pin(<Self as Tool>::definition(self, prompt))
170 }
171
172 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
173 Box::pin(async move {
174 match serde_json::from_str(&args) {
175 Ok(args) => <Self as Tool>::call(self, args)
176 .await
177 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
178 .and_then(|output| {
179 serde_json::to_string(&output).map_err(ToolError::JsonError)
180 }),
181 Err(e) => Err(ToolError::JsonError(e)),
182 }
183 })
184 }
185}
186
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp {
    //! Adapter exposing tools served by an MCP (Model Context Protocol) server
    //! through the crate's `ToolDyn` interface.

    use crate::completion::ToolDefinition;
    use crate::tool::ToolDyn;
    use crate::tool::ToolError;
    use crate::wasm_compat::WasmBoxedFuture;
    use rmcp::model::RawContent;

    /// A tool advertised by an MCP server and invoked remotely through `client`.
    #[derive(Clone)]
    pub struct McpTool {
        definition: rmcp::model::Tool,
        client: rmcp::service::ServerSink,
    }

    impl McpTool {
        /// Pair an MCP tool `definition` with the server connection used to call it.
        pub fn from_mcp_server(
            definition: rmcp::model::Tool,
            client: rmcp::service::ServerSink,
        ) -> Self {
            Self { definition, client }
        }
    }

    impl From<&rmcp::model::Tool> for ToolDefinition {
        /// Convert an MCP tool description; a missing description becomes `""`.
        fn from(val: &rmcp::model::Tool) -> Self {
            Self {
                name: val.name.to_string(),
                // Borrow the description instead of cloning the whole `Cow` first.
                description: val.description.as_deref().unwrap_or_default().to_string(),
                parameters: val.schema_as_json_value(),
            }
        }
    }

    impl From<rmcp::model::Tool> for ToolDefinition {
        fn from(val: rmcp::model::Tool) -> Self {
            // Single source of truth for the conversion.
            Self::from(&val)
        }
    }

    /// Error raised while calling an MCP tool or interpreting its result.
    #[derive(Debug, thiserror::Error)]
    #[error("MCP tool error: {0}")]
    pub struct McpToolError(String);

    impl From<McpToolError> for ToolError {
        fn from(e: McpToolError) -> Self {
            ToolError::ToolCallError(Box::new(e))
        }
    }

    /// Render one piece of MCP tool output as text.
    ///
    /// # Errors
    /// Content kinds that cannot be rendered (audio, or any variant not handled
    /// below) yield a `ToolError` instead of panicking, because the content comes
    /// from a remote server.
    fn render_content(content: RawContent) -> Result<String, ToolError> {
        match content {
            RawContent::Text(raw) => Ok(raw.text),
            RawContent::Image(raw) => Ok(format!("data:{};base64,{}", raw.mime_type, raw.data)),
            RawContent::Resource(raw) => Ok(match raw.resource {
                rmcp::model::ResourceContents::TextResourceContents {
                    uri,
                    mime_type,
                    text,
                    ..
                } => format!(
                    "{mime_type}{uri}:{text}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
                rmcp::model::ResourceContents::BlobResourceContents {
                    uri,
                    mime_type,
                    blob,
                    ..
                } => format!(
                    "{mime_type}{uri}:{blob}",
                    mime_type = mime_type
                        .map(|m| format!("data:{m};"))
                        .unwrap_or_default(),
                ),
            }),
            RawContent::Audio(_) => Err(McpToolError(
                "Support for audio results from an MCP tool is currently unimplemented".to_string(),
            )
            .into()),
            other => Err(McpToolError(format!("Unsupported type found: {other:?}")).into()),
        }
    }

    impl ToolDyn for McpTool {
        fn name(&self) -> String {
            self.definition.name.to_string()
        }

        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            // Reuse the `From<&rmcp::model::Tool>` conversion so both paths agree.
            let definition = ToolDefinition::from(&self.definition);
            Box::pin(async move { definition })
        }

        fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            let name = self.definition.name.clone();
            // Blank `args` means "no arguments" (some providers send "" for
            // parameterless tools). Anything else must be valid JSON. Parse errors
            // are reported rather than silently replaced with "no arguments".
            let arguments = if args.trim().is_empty() {
                Ok(None)
            } else {
                serde_json::from_str(&args)
            };

            Box::pin(async move {
                let arguments = arguments.map_err(ToolError::JsonError)?;
                let result = self
                    .client
                    .call_tool(rmcp::model::CallToolRequestParam { name, arguments })
                    .await
                    .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;

                if result.is_error == Some(true) {
                    // Keep every textual part of the error, even if the server also
                    // attached non-text content.
                    let message = result
                        .content
                        .into_iter()
                        .filter_map(|content| match content.raw {
                            RawContent::Text(text) => Some(text.text),
                            _ => None,
                        })
                        .collect::<Vec<String>>()
                        .join("\n");
                    let message = if message.is_empty() {
                        "No message returned".to_string()
                    } else {
                        message
                    };
                    return Err(McpToolError(message).into());
                }

                result
                    .content
                    .into_iter()
                    .map(|content| render_content(content.raw))
                    .collect::<Result<String, ToolError>>()
            })
        }
    }
}
336
337pub trait ToolEmbeddingDyn: ToolDyn {
339 fn context(&self) -> serde_json::Result<serde_json::Value>;
340
341 fn embedding_docs(&self) -> Vec<String>;
342}
343
344impl<T> ToolEmbeddingDyn for T
345where
346 T: ToolEmbedding + 'static,
347{
348 fn context(&self) -> serde_json::Result<serde_json::Value> {
349 serde_json::to_value(self.context())
350 }
351
352 fn embedding_docs(&self) -> Vec<String> {
353 self.embedding_docs()
354 }
355}
356
357pub(crate) enum ToolType {
358 Simple(Box<dyn ToolDyn>),
359 Embedding(Box<dyn ToolEmbeddingDyn>),
360}
361
362impl ToolType {
363 pub fn name(&self) -> String {
364 match self {
365 ToolType::Simple(tool) => tool.name(),
366 ToolType::Embedding(tool) => tool.name(),
367 }
368 }
369
370 pub async fn definition(&self, prompt: String) -> ToolDefinition {
371 match self {
372 ToolType::Simple(tool) => tool.definition(prompt).await,
373 ToolType::Embedding(tool) => tool.definition(prompt).await,
374 }
375 }
376
377 pub async fn call(&self, args: String) -> Result<String, ToolError> {
378 match self {
379 ToolType::Simple(tool) => tool.call(args).await,
380 ToolType::Embedding(tool) => tool.call(args).await,
381 }
382 }
383}
384
385#[derive(Debug, thiserror::Error)]
386pub enum ToolSetError {
387 #[error("ToolCallError: {0}")]
389 ToolCallError(#[from] ToolError),
390
391 #[error("ToolNotFoundError: {0}")]
393 ToolNotFoundError(String),
394
395 #[error("JsonError: {0}")]
397 JsonError(#[from] serde_json::Error),
398
399 #[error("Tool call interrupted")]
401 Interrupted,
402}
403
404#[derive(Default)]
406pub struct ToolSet {
407 pub(crate) tools: HashMap<String, ToolType>,
408}
409
410impl ToolSet {
411 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
413 let mut toolset = Self::default();
414 tools.into_iter().for_each(|tool| {
415 toolset.add_tool(tool);
416 });
417 toolset
418 }
419
420 pub fn builder() -> ToolSetBuilder {
422 ToolSetBuilder::default()
423 }
424
425 pub fn contains(&self, toolname: &str) -> bool {
427 self.tools.contains_key(toolname)
428 }
429
430 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
432 self.tools
433 .insert(tool.name(), ToolType::Simple(Box::new(tool)));
434 }
435
436 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
438 self.tools.insert(tool.name(), ToolType::Simple(tool));
439 }
440
441 pub fn delete_tool(&mut self, tool_name: &str) {
442 let _ = self.tools.remove(tool_name);
443 }
444
445 pub fn add_tools(&mut self, toolset: ToolSet) {
447 self.tools.extend(toolset.tools);
448 }
449
450 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
451 self.tools.get(toolname)
452 }
453
454 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
455 let mut defs = Vec::new();
456 for tool in self.tools.values() {
457 let def = tool.definition(String::new()).await;
458 defs.push(def);
459 }
460 Ok(defs)
461 }
462
463 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
465 if let Some(tool) = self.tools.get(toolname) {
466 tracing::debug!(target: "rig",
467 "Calling tool {toolname} with args:\n{}",
468 serde_json::to_string_pretty(&args).unwrap()
469 );
470 Ok(tool.call(args).await?)
471 } else {
472 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
473 }
474 }
475
476 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
478 let mut docs = Vec::new();
479 for tool in self.tools.values() {
480 match tool {
481 ToolType::Simple(tool) => {
482 docs.push(completion::Document {
483 id: tool.name(),
484 text: format!(
485 "\
486 Tool: {}\n\
487 Definition: \n\
488 {}\
489 ",
490 tool.name(),
491 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
492 ),
493 additional_props: HashMap::new(),
494 });
495 }
496 ToolType::Embedding(tool) => {
497 docs.push(completion::Document {
498 id: tool.name(),
499 text: format!(
500 "\
501 Tool: {}\n\
502 Definition: \n\
503 {}\
504 ",
505 tool.name(),
506 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
507 ),
508 additional_props: HashMap::new(),
509 });
510 }
511 }
512 }
513 Ok(docs)
514 }
515
516 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
520 self.tools
521 .values()
522 .filter_map(|tool_type| {
523 if let ToolType::Embedding(tool) = tool_type {
524 Some(ToolSchema::try_from(&**tool))
525 } else {
526 None
527 }
528 })
529 .collect::<Result<Vec<_>, _>>()
530 }
531}
532
533#[derive(Default)]
534pub struct ToolSetBuilder {
535 tools: Vec<ToolType>,
536}
537
538impl ToolSetBuilder {
539 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
540 self.tools.push(ToolType::Simple(Box::new(tool)));
541 self
542 }
543
544 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
545 self.tools.push(ToolType::Embedding(Box::new(tool)));
546 self
547 }
548
549 pub fn build(self) -> ToolSet {
550 ToolSet {
551 tools: self
552 .tools
553 .into_iter()
554 .map(|tool| (tool.name(), tool))
555 .collect(),
556 }
557 }
558}
559
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Operands shared by the arithmetic test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: i32,
        y: i32,
    }

    /// Error type for the arithmetic test tools (never actually produced).
    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool computing `x + y`; definition built as a struct literal.
    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            let parameters = json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                },
                "required": ["x", "y"]
            });
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters,
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(args.x + args.y)
        }
    }

    /// Test tool computing `x - y`; definition deserialized from JSON.
    #[derive(Deserialize, Serialize)]
    struct Subtract;

    impl Tool for Subtract {
        const NAME: &'static str = "subtract";
        type Error = MathError;
        type Args = OperationArgs;
        type Output = i32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            let raw = json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"]
                }
            });
            serde_json::from_value(raw).expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(args.x - args.y)
        }
    }

    /// Toolset containing the `add` and `subtract` tools.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();
        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let definitions = get_test_toolset().get_tool_definitions().await.unwrap();
        assert_eq!(definitions.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);

        toolset.delete_tool("add");

        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }
}