1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21 completion::{self, ToolDefinition},
22 embeddings::{embed::EmbedError, tool::ToolSchema},
23 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
/// Error produced when a tool is invoked through the type-erased [`ToolDyn::call`] interface.
///
/// No `#[error(...)]` attributes are used here: `Display` is implemented by hand so that a
/// nested tool error does not repeat the `ToolCallError: ` prefix.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool's own logic failed. The blanket `ToolDyn` adapter boxes the tool's
    /// `Tool::Error` value into this variant; `#[from]` also accepts any other boxed error.
    #[cfg(not(target_family = "wasm"))]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// wasm counterpart of the variant above: identical meaning, but without the
    /// `Send + Sync` bounds on the boxed error.
    #[cfg(target_family = "wasm")]
    ToolCallError(#[from] Box<dyn std::error::Error>),
    /// JSON handling failed: either the argument string did not deserialize into the tool's
    /// argument type, or the tool's output could not be serialized.
    JsonError(#[from] serde_json::Error),
}
38
39impl fmt::Display for ToolError {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 ToolError::ToolCallError(e) => {
43 let error_str = e.to_string();
44 if error_str.starts_with("ToolCallError: ") {
47 write!(f, "{}", error_str)
48 } else {
49 write!(f, "ToolCallError: {}", error_str)
50 }
51 }
52 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
53 }
54 }
55}
56
57pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
113 const NAME: &'static str;
116
117 type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
119 type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
121 type Output: Serialize;
123
124 fn name(&self) -> String {
126 Self::NAME.to_string()
127 }
128
129 fn definition(
132 &self,
133 _prompt: String,
134 ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
135
136 fn call(
140 &self,
141 args: Self::Args,
142 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
143}
144
/// A [`Tool`] that also carries a serializable context and can be rebuilt from it via
/// [`ToolEmbedding::init`].
///
/// Every implementor is automatically a `ToolEmbeddingDyn` (blanket impl in this module). It can
/// therefore be registered through `ToolSetBuilder::dynamic_tool`.
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Serializable per-tool context.
    ///
    /// `ToolEmbeddingDyn::context` converts it into a `serde_json::Value`, and `init` receives
    /// one to reconstruct the tool.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Runtime state handed to [`ToolEmbedding::init`] alongside the context.
    type State: WasmCompatSend;

    /// Text documents describing this tool, exposed through `ToolEmbeddingDyn::embedding_docs`.
    ///
    /// NOTE(review): presumably these are the texts embedded for tool retrieval; confirm against
    /// the `ToolSchema` conversion.
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns this tool's context value.
    fn context(&self) -> Self::Context;

    /// Builds the tool from runtime `state` and a `context`.
    ///
    /// The context is assumed to be one previously obtained from [`ToolEmbedding::context`]
    /// (TODO confirm).
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
171
/// Object-safe, type-erased view of a tool, so heterogeneous tools can be stored behind
/// `Arc<dyn ToolDyn>`.
///
/// Every [`Tool`] implements this trait through the blanket impl in this module.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Name under which the tool is registered in a [`ToolSet`].
    fn name(&self) -> String;

    /// Boxed future resolving to the tool's definition for the given prompt.
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Invokes the tool with JSON-encoded `args` and resolves to its output as a string.
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
180
181fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
182 match serde_json::to_value(output)? {
183 serde_json::Value::String(text) => Ok(text),
184 value => Ok(value.to_string()),
185 }
186}
187
188impl<T: Tool> ToolDyn for T {
189 fn name(&self) -> String {
190 self.name()
191 }
192
193 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
194 Box::pin(<Self as Tool>::definition(self, prompt))
195 }
196
197 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
198 Box::pin(async move {
199 match serde_json::from_str(&args) {
200 Ok(args) => <Self as Tool>::call(self, args)
201 .await
202 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
203 .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
204 Err(e) => Err(ToolError::JsonError(e)),
205 }
206 })
207 }
208}
209
210#[cfg(feature = "rmcp")]
211#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
212pub mod rmcp;
213
/// Object-safe view of a [`ToolEmbedding`].
///
/// The typed context is replaced by a `serde_json::Value`, so embedding-capable tools can be
/// stored behind `Arc<dyn ToolEmbeddingDyn>`.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// The tool's context serialized to a JSON value.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Text documents describing the tool (see [`ToolEmbedding::embedding_docs`]).
    fn embedding_docs(&self) -> Vec<String>;
}
220
221impl<T> ToolEmbeddingDyn for T
222where
223 T: ToolEmbedding + 'static,
224{
225 fn context(&self) -> serde_json::Result<serde_json::Value> {
226 serde_json::to_value(self.context())
227 }
228
229 fn embedding_docs(&self) -> Vec<String> {
230 self.embedding_docs()
231 }
232}
233
/// Storage wrapper distinguishing plain tools from embedding-capable tools inside a
/// [`ToolSet`].
///
/// Cloning only bumps a reference count, because both variants hold an `Arc`.
#[derive(Clone)]
pub(crate) enum ToolType {
    /// Registered via `ToolSet::add_tool`, `ToolSet::add_tool_boxed` or
    /// `ToolSetBuilder::static_tool`.
    Simple(Arc<dyn ToolDyn>),
    /// Registered via `ToolSetBuilder::dynamic_tool`. Only these tools contribute to
    /// `ToolSet::schemas`.
    Embedding(Arc<dyn ToolEmbeddingDyn>),
}
239
240impl ToolType {
241 pub fn name(&self) -> String {
242 match self {
243 ToolType::Simple(tool) => tool.name(),
244 ToolType::Embedding(tool) => tool.name(),
245 }
246 }
247
248 pub async fn definition(&self, prompt: String) -> ToolDefinition {
249 match self {
250 ToolType::Simple(tool) => tool.definition(prompt).await,
251 ToolType::Embedding(tool) => tool.definition(prompt).await,
252 }
253 }
254
255 pub async fn call(&self, args: String) -> Result<String, ToolError> {
256 match self {
257 ToolType::Simple(tool) => tool.call(args).await,
258 ToolType::Embedding(tool) => tool.call(args).await,
259 }
260 }
261}
262
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// Invoking a registered tool returned a [`ToolError`].
    ///
    /// NOTE(review): `ToolError`'s `Display` already starts with its own variant name
    /// (`ToolCallError: …` / `JsonError: …`). This message therefore reads e.g.
    /// `ToolCallError: ToolCallError: …`. Left unchanged because callers may match on the text.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool with the given name is registered (raised by `ToolSet::call`).
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// A JSON step inside the toolset failed, e.g. pretty-printing a definition in
    /// `ToolSet::documents`.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Not constructed in this file; presumably raised by callers that abort a tool call
    /// before it completes. Confirm at the construction sites.
    #[error("Tool call interrupted")]
    Interrupted,
}
281
/// A collection of tools keyed by tool name.
///
/// Registering a tool under a name that is already present replaces the earlier entry
/// (`HashMap::insert` semantics).
#[derive(Default)]
pub struct ToolSet {
    // Registered tools, keyed by the tool's `name()` at registration time.
    pub(crate) tools: HashMap<String, ToolType>,
}
287
288impl ToolSet {
289 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
291 let mut toolset = Self::default();
292 tools.into_iter().for_each(|tool| {
293 toolset.add_tool(tool);
294 });
295 toolset
296 }
297
298 pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
299 let mut toolset = Self::default();
300 tools.into_iter().for_each(|tool| {
301 toolset.add_tool_boxed(tool);
302 });
303 toolset
304 }
305
306 pub fn builder() -> ToolSetBuilder {
308 ToolSetBuilder::default()
309 }
310
311 pub fn contains(&self, toolname: &str) -> bool {
313 self.tools.contains_key(toolname)
314 }
315
316 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
318 self.tools
319 .insert(tool.name(), ToolType::Simple(Arc::new(tool)));
320 }
321
322 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
324 self.tools
325 .insert(tool.name(), ToolType::Simple(Arc::from(tool)));
326 }
327
328 pub fn delete_tool(&mut self, tool_name: &str) {
329 let _ = self.tools.remove(tool_name);
330 }
331
332 pub fn add_tools(&mut self, toolset: ToolSet) {
334 self.tools.extend(toolset.tools);
335 }
336
337 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
338 self.tools.get(toolname)
339 }
340
341 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
342 let mut defs = Vec::new();
343 for tool in self.tools.values() {
344 let def = tool.definition(String::new()).await;
345 defs.push(def);
346 }
347 Ok(defs)
348 }
349
350 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
352 if let Some(tool) = self.tools.get(toolname) {
353 tracing::debug!(target: "rig",
354 "Calling tool {toolname} with args:\n{}",
355 args
356 );
357 Ok(tool.call(args).await?)
358 } else {
359 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
360 }
361 }
362
363 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
365 let mut docs = Vec::new();
366 for tool in self.tools.values() {
367 match tool {
368 ToolType::Simple(tool) => {
369 docs.push(completion::Document {
370 id: tool.name(),
371 text: format!(
372 "\
373 Tool: {}\n\
374 Definition: \n\
375 {}\
376 ",
377 tool.name(),
378 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
379 ),
380 additional_props: HashMap::new(),
381 });
382 }
383 ToolType::Embedding(tool) => {
384 docs.push(completion::Document {
385 id: tool.name(),
386 text: format!(
387 "\
388 Tool: {}\n\
389 Definition: \n\
390 {}\
391 ",
392 tool.name(),
393 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
394 ),
395 additional_props: HashMap::new(),
396 });
397 }
398 }
399 }
400 Ok(docs)
401 }
402
403 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
407 self.tools
408 .values()
409 .filter_map(|tool_type| {
410 if let ToolType::Embedding(tool) = tool_type {
411 Some(ToolSchema::try_from(&**tool))
412 } else {
413 None
414 }
415 })
416 .collect::<Result<Vec<_>, _>>()
417 }
418}
419
/// Builder that accumulates static and dynamic tools before producing a [`ToolSet`].
#[derive(Default)]
pub struct ToolSetBuilder {
    // Tools in registration order; `build` keys them by name, and later duplicates win.
    tools: Vec<ToolType>,
}
424
425impl ToolSetBuilder {
426 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
427 self.tools.push(ToolType::Simple(Arc::new(tool)));
428 self
429 }
430
431 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
432 self.tools.push(ToolType::Embedding(Arc::new(tool)));
433 self
434 }
435
436 pub fn build(self) -> ToolSet {
437 ToolSet {
438 tools: self
439 .tools
440 .into_iter()
441 .map(|tool| (tool.name(), tool))
442 .collect(),
443 }
444 }
445}
446
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use serde_json::json;

    use super::*;

    /// Toolset with two integer tools, `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);

        // Definitions come out in HashMap order, so sort before comparing.
        let mut names: Vec<String> = tools.into_iter().map(|def| def.name).collect();
        names.sort();
        assert_eq!(names, vec!["add", "subtract"]);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn calling_a_tool_returns_its_serialized_output() {
        let toolset = get_test_toolset();

        let sum = toolset
            .call("add", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .expect("add should succeed");
        assert_eq!(sum, "5");

        let difference = toolset
            .call("subtract", r#"{"x": 2, "y": 3}"#.to_string())
            .await
            .expect("subtract should succeed");
        assert_eq!(difference, "-1");
    }

    #[tokio::test]
    async fn calling_an_unknown_tool_reports_not_found() {
        let toolset = get_test_toolset();
        let err = toolset
            .call("multiply", "{}".to_string())
            .await
            .unwrap_err();

        match err {
            ToolSetError::ToolNotFoundError(name) => assert_eq!(name, "multiply"),
            other => panic!("expected ToolNotFoundError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn malformed_arguments_surface_as_json_error() {
        let toolset = get_test_toolset();
        // `y` is required by `OperationArgs`, so deserialization must fail.
        let err = toolset
            .call("add", r#"{"x": 1}"#.to_string())
            .await
            .unwrap_err();

        assert!(
            matches!(err, ToolSetError::ToolCallError(ToolError::JsonError(_))),
            "expected a JSON error, got {err:?}"
        );
    }

    #[test]
    fn tool_call_error_display_adds_prefix_once() {
        let bare = ToolError::ToolCallError("boom".into());
        assert_eq!(bare.to_string(), "ToolCallError: boom");

        let already_prefixed = ToolError::ToolCallError("ToolCallError: boom".into());
        assert_eq!(already_prefixed.to_string(), "ToolCallError: boom");
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Test tool error")]
    struct TestToolError;

    #[derive(Deserialize, Serialize)]
    struct StringOutputTool;

    impl Tool for StringOutputTool {
        const NAME: &'static str = "string_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns a multiline string".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok("Hello\nWorld".to_string())
        }
    }

    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(StringOutputTool);

        let output = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert_eq!(output, "Hello\nWorld");
    }

    #[derive(Deserialize, Serialize)]
    struct ImageOutputTool;

    impl Tool for ImageOutputTool {
        const NAME: &'static str = "image_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns image JSON".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(json!({
                "type": "image",
                "data": "base64data==",
                "mimeType": "image/png"
            })
            .to_string())
        }
    }

    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(ImageOutputTool);

        let output = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(output);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    #[derive(Deserialize, Serialize)]
    struct ObjectOutputTool;

    impl Tool for ObjectOutputTool {
        const NAME: &'static str = "object_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = serde_json::Value;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns an object".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(json!({
                "status": "ok",
                "count": 42
            }))
        }
    }

    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(ObjectOutputTool);

        let output = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert!(output.starts_with('{'));
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&output).unwrap(),
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }
}