1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use indexmap::IndexMap;
19use serde::{Deserialize, Serialize};
20
21use crate::{
22 completion::{self, ToolDefinition},
23 embeddings::{embed::EmbedError, tool::ToolSchema},
24 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
25};
26
/// Errors produced while invoking a tool through [`ToolDyn::call`].
///
/// `Display` is implemented by hand in this file (there are no `#[error(...)]`
/// attributes here) so that the `ToolCallError: ` prefix is not repeated when
/// the wrapped error already renders with it.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// An error raised while running the tool (non-wasm targets, where the
    /// boxed error must be `Send + Sync`). The blanket [`ToolDyn`] impl boxes
    /// the tool's own [`Tool::Error`] into this variant.
    #[cfg(not(target_family = "wasm"))]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// An error raised while running the tool (wasm targets, where the boxed
    /// error carries no `Send + Sync` bound).
    #[cfg(target_family = "wasm")]
    ToolCallError(#[from] Box<dyn std::error::Error>),
    /// Deserializing the tool arguments or serializing the tool output failed.
    JsonError(#[from] serde_json::Error),
}
39
40impl fmt::Display for ToolError {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 ToolError::ToolCallError(e) => {
44 let error_str = e.to_string();
45 if error_str.starts_with("ToolCallError: ") {
48 write!(f, "{}", error_str)
49 } else {
50 write!(f, "ToolCallError: {}", error_str)
51 }
52 }
53 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
54 }
55 }
56}
57
/// A strongly typed tool: it has a name, can describe itself as a
/// [`ToolDefinition`], and can be called with deserialized arguments.
///
/// Every `Tool` also implements [`ToolDyn`] through the blanket impl in this
/// file, which handles the JSON (de)serialization around [`Tool::call`].
pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
    /// The tool's name; the default [`Tool::name`] returns this value.
    const NAME: &'static str;

    /// Error type returned by [`Tool::call`].
    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
    /// Argument type accepted by [`Tool::call`]; must be deserializable.
    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
    /// Value produced by a successful [`Tool::call`]; must be serializable.
    type Output: Serialize;

    /// Returns the tool's name. Defaults to [`Tool::NAME`].
    fn name(&self) -> String {
        Self::NAME.to_string()
    }

    /// Produces the [`ToolDefinition`] describing this tool.
    ///
    /// NOTE(review): `_prompt` is presumably the user prompt, passed so the
    /// definition can be tailored to it — confirm against callers.
    fn definition(
        &self,
        _prompt: String,
    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;

    /// Runs the tool with already-deserialized `args`.
    ///
    /// Resolves to the tool's output, or to [`Tool::Error`] on failure.
    fn call(
        &self,
        args: Self::Args,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
}
145
/// A [`Tool`] that additionally exposes embedding documents and a
/// serializable context, and that can be constructed from a state plus that
/// context.
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Serializable (and deserializable) data returned by
    /// [`ToolEmbedding::context`] and consumed by [`ToolEmbedding::init`].
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Extra value handed to [`ToolEmbedding::init`].
    ///
    /// NOTE(review): presumably runtime-only state that cannot be serialized
    /// (clients, handles, ...) — confirm against implementors.
    type State: WasmCompatSend;

    /// Returns the documents associated with this tool for embedding.
    ///
    /// NOTE(review): presumably the texts that get embedded so the tool can
    /// be retrieved by similarity — confirm against the embeddings module.
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns the tool's context value.
    fn context(&self) -> Self::Context;

    /// Builds the tool from `state` and `context`.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
173
/// Dynamically dispatchable view of a tool: arguments and output travel as
/// strings and futures are boxed, so tools can be stored as `dyn ToolDyn`.
///
/// Implemented for every [`Tool`] by the blanket impl in this file.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Returns the tool's name.
    fn name(&self) -> String;

    /// Returns a boxed future resolving to the tool's [`ToolDefinition`].
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Calls the tool with string-encoded `args`, resolving to the
    /// string-encoded output or a [`ToolError`].
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
185
186fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
187 match serde_json::to_value(output)? {
188 serde_json::Value::String(text) => Ok(text),
189 value => Ok(value.to_string()),
190 }
191}
192
193impl<T: Tool> ToolDyn for T {
194 fn name(&self) -> String {
195 self.name()
196 }
197
198 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
199 Box::pin(<Self as Tool>::definition(self, prompt))
200 }
201
202 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
203 Box::pin(async move {
204 let args = match serde_json::from_str(&args) {
211 Ok(args) => Ok(args),
212 Err(err) if args.trim() == "null" => serde_json::from_str("{}").map_err(|_| err),
213 Err(err) => Err(err),
214 };
215 match args {
216 Ok(args) => <Self as Tool>::call(self, args)
217 .await
218 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
219 .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
220 Err(e) => Err(ToolError::JsonError(e)),
221 }
222 })
223 }
224}
225
226#[cfg(feature = "rmcp")]
227#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
228pub mod rmcp;
229
/// Dynamically dispatchable view of a [`ToolEmbedding`]: a [`ToolDyn`] that
/// also exposes its context as JSON and its embedding documents.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context serialized to a JSON value, or the
    /// serialization error.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the tool's embedding documents.
    fn embedding_docs(&self) -> Vec<String>;
}
238
239impl<T> ToolEmbeddingDyn for T
240where
241 T: ToolEmbedding + 'static,
242{
243 fn context(&self) -> serde_json::Result<serde_json::Value> {
244 serde_json::to_value(self.context())
245 }
246
247 fn embedding_docs(&self) -> Vec<String> {
248 self.embedding_docs()
249 }
250}
251
/// A registered tool, stored behind an `Arc` so that cloning the entry does
/// not clone the tool itself.
#[derive(Clone)]
pub(crate) enum ToolType {
    /// A tool exposing only the [`ToolDyn`] interface.
    Simple(Arc<dyn ToolDyn>),
    /// A tool that additionally exposes the [`ToolEmbeddingDyn`] interface
    /// (context and embedding documents).
    Embedding(Arc<dyn ToolEmbeddingDyn>),
}
257
258impl ToolType {
259 pub fn name(&self) -> String {
260 match self {
261 ToolType::Simple(tool) => tool.name(),
262 ToolType::Embedding(tool) => tool.name(),
263 }
264 }
265
266 pub async fn definition(&self, prompt: String) -> ToolDefinition {
267 match self {
268 ToolType::Simple(tool) => tool.definition(prompt).await,
269 ToolType::Embedding(tool) => tool.definition(prompt).await,
270 }
271 }
272
273 pub async fn call(&self, args: String) -> Result<String, ToolError> {
274 match self {
275 ToolType::Simple(tool) => tool.call(args).await,
276 ToolType::Embedding(tool) => tool.call(args).await,
277 }
278 }
279}
280
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The invoked tool failed; wraps the underlying [`ToolError`].
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool with the given name is registered; carries the requested name.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// JSON (de)serialization failed (e.g. while rendering a tool definition
    /// in [`ToolSet::documents`]).
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A tool call was interrupted. Not produced by any code in this file.
    #[error("Tool call interrupted")]
    Interrupted,
}
299
/// A collection of tools keyed by name.
///
/// Backed by an `IndexMap`, so iteration follows registration order.
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools, keyed by the name each tool reported when inserted.
    pub(crate) tools: IndexMap<String, ToolType>,
}
310
311impl ToolSet {
312 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
314 let mut toolset = Self::default();
315 tools.into_iter().for_each(|tool| {
316 toolset.add_tool(tool);
317 });
318 toolset
319 }
320
321 pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
323 let mut toolset = Self::default();
324 tools.into_iter().for_each(|tool| {
325 toolset.add_tool_boxed(tool);
326 });
327 toolset
328 }
329
330 pub fn builder() -> ToolSetBuilder {
332 ToolSetBuilder::default()
333 }
334
335 pub fn contains(&self, toolname: &str) -> bool {
337 self.tools.contains_key(toolname)
338 }
339
340 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
342 self.insert(ToolType::Simple(Arc::new(tool)));
343 }
344
345 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
347 self.insert(ToolType::Simple(Arc::from(tool)));
348 }
349
350 pub(crate) fn insert(&mut self, tool: ToolType) {
351 let name = tool.name();
352 if self.tools.insert(name.clone(), tool).is_some() {
356 tracing::warn!(
357 tool_name = %name,
358 "a tool named {name:?} was already registered; replacing it with the new registration"
359 );
360 }
361 }
362
363 pub fn delete_tool(&mut self, tool_name: &str) {
365 self.tools.shift_remove(tool_name);
368 }
369
370 pub fn add_tools(&mut self, toolset: ToolSet) {
373 for (_, tool) in toolset.tools {
374 self.insert(tool);
375 }
376 }
377
378 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
379 self.tools.get(toolname)
380 }
381
382 pub(crate) fn ordered_names(&self) -> impl Iterator<Item = &String> {
384 self.tools.keys()
385 }
386
387 fn ordered_tools(&self) -> impl Iterator<Item = &ToolType> {
389 self.tools.values()
390 }
391
392 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
395 let mut defs = Vec::new();
396 for tool in self.ordered_tools() {
397 let def = tool.definition(String::new()).await;
398 defs.push(def);
399 }
400 Ok(defs)
401 }
402
403 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
405 if let Some(tool) = self.tools.get(toolname) {
406 tracing::debug!(target: "rig",
407 "Calling tool {toolname} with args:\n{}",
408 args
409 );
410 Ok(tool.call(args).await?)
411 } else {
412 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
413 }
414 }
415
416 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
418 let mut docs = Vec::new();
419 for tool in self.ordered_tools() {
420 match tool {
421 ToolType::Simple(tool) => {
422 docs.push(completion::Document {
423 id: tool.name(),
424 text: format!(
425 "\
426 Tool: {}\n\
427 Definition: \n\
428 {}\
429 ",
430 tool.name(),
431 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
432 ),
433 additional_props: HashMap::new(),
434 });
435 }
436 ToolType::Embedding(tool) => {
437 docs.push(completion::Document {
438 id: tool.name(),
439 text: format!(
440 "\
441 Tool: {}\n\
442 Definition: \n\
443 {}\
444 ",
445 tool.name(),
446 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
447 ),
448 additional_props: HashMap::new(),
449 });
450 }
451 }
452 }
453 Ok(docs)
454 }
455
456 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
460 self.ordered_tools()
461 .filter_map(|tool_type| {
462 if let ToolType::Embedding(tool) = tool_type {
463 Some(ToolSchema::try_from(&**tool))
464 } else {
465 None
466 }
467 })
468 .collect::<Result<Vec<_>, _>>()
469 }
470}
471
/// Builder that accumulates tools and turns them into a [`ToolSet`].
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools queued so far, in the order they were added.
    tools: Vec<ToolType>,
}
477
478impl ToolSetBuilder {
479 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
481 self.tools.push(ToolType::Simple(Arc::new(tool)));
482 self
483 }
484
485 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
487 self.tools.push(ToolType::Embedding(Arc::new(tool)));
488 self
489 }
490
491 pub fn build(self) -> ToolSet {
493 let mut toolset = ToolSet::default();
494 for tool in self.tools {
495 toolset.insert(tool);
496 }
497 toolset
498 }
499}
500
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use crate::test_utils::{
        MockExampleTool, MockImageOutputTool, MockObjectOutputTool, MockStringOutputTool,
        mock_math_toolset,
    };
    use serde_json::json;

    use super::*;

    /// Hand-written [`ToolDyn`] whose name and description are chosen at
    /// runtime, so tests can register arbitrarily many distinct tools.
    struct NamedTool {
        name: String,
        description: String,
    }

    impl ToolDyn for NamedTool {
        fn name(&self) -> String {
            self.name.clone()
        }

        fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
            Box::pin(async move {
                ToolDefinition {
                    name: self.name.clone(),
                    description: self.description.clone(),
                    parameters: json!({ "type": "object", "properties": {} }),
                }
            })
        }

        fn call(&self, _args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
            // Echo the description so tests can tell which registration ran.
            let reply = format!("called {}", self.description);
            Box::pin(async move { Ok(reply) })
        }
    }

    /// Convenience constructor for [`NamedTool`].
    fn named_tool(name: &str, description: &str) -> NamedTool {
        NamedTool {
            name: name.to_string(),
            description: description.to_string(),
        }
    }

    /// Shared fixture for the definition and deletion tests.
    fn get_test_toolset() -> ToolSet {
        mock_math_toolset()
    }

    /// Snapshot of the registered tool names in iteration order.
    fn registered_names(toolset: &ToolSet) -> Vec<String> {
        toolset.ordered_names().cloned().collect()
    }

    /// Toolset holding exactly one tool.
    fn single_tool_set(tool: impl ToolDyn + 'static) -> ToolSet {
        let mut toolset = ToolSet::default();
        toolset.add_tool(tool);
        toolset
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let definitions = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(definitions.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);

        toolset.delete_tool("add");

        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
        assert_eq!(registered_names(&toolset), ["subtract"]);
    }

    #[test]
    fn deleting_a_middle_tool_preserves_order_of_survivors() {
        let mut toolset = ToolSet::default();
        ["alpha", "beta", "gamma", "delta"]
            .into_iter()
            .for_each(|name| toolset.add_tool(named_tool(name, "test tool")));

        toolset.delete_tool("beta");

        assert_eq!(
            registered_names(&toolset),
            ["alpha", "gamma", "delta"],
            "survivors must keep their registration order after a middle deletion"
        );
    }

    #[tokio::test]
    async fn tool_definitions_follow_registration_order() {
        let expected: Vec<String> = (0..32).map(|i| format!("tool_{i:02}")).collect();
        let mut toolset = ToolSet::default();
        for name in &expected {
            toolset.add_tool(named_tool(name, "test tool"));
        }

        let def_names: Vec<String> = toolset
            .get_tool_definitions()
            .await
            .unwrap()
            .into_iter()
            .map(|def| def.name)
            .collect();
        assert_eq!(def_names, expected);

        let doc_ids: Vec<String> = toolset
            .documents()
            .await
            .unwrap()
            .into_iter()
            .map(|doc| doc.id)
            .collect();
        assert_eq!(doc_ids, expected);
    }

    #[tokio::test]
    async fn duplicate_registration_replaces_in_place() {
        let mut toolset = ToolSet::default();
        for (name, description) in [
            ("alpha", "first alpha"),
            ("beta", "beta"),
            ("alpha", "second alpha"),
        ] {
            toolset.add_tool(named_tool(name, description));
        }

        let defs = toolset.get_tool_definitions().await.unwrap();
        let def_names: Vec<&str> = defs.iter().map(|def| def.name.as_str()).collect();
        assert_eq!(
            def_names,
            vec!["alpha", "beta"],
            "the duplicate should be deduped and keep its original position"
        );
        assert_eq!(
            defs[0].description, "second alpha",
            "the last registration should win"
        );

        let reply = toolset.call("alpha", "{}".to_string()).await.unwrap();
        assert_eq!(reply, "called second alpha");
    }

    #[tokio::test]
    async fn add_tools_merges_in_order_and_replaces_existing() {
        let mut base = ToolSet::default();
        base.add_tool(named_tool("alpha", "base alpha"));
        base.add_tool(named_tool("beta", "base beta"));

        let mut incoming = ToolSet::default();
        incoming.add_tool(named_tool("gamma", "incoming gamma"));
        incoming.add_tool(named_tool("alpha", "incoming alpha"));

        base.add_tools(incoming);

        let defs = base.get_tool_definitions().await.unwrap();
        let def_names: Vec<&str> = defs.iter().map(|def| def.name.as_str()).collect();
        assert_eq!(
            def_names,
            vec!["alpha", "beta", "gamma"],
            "merged tools should follow registration order with replaced names keeping position"
        );
        assert_eq!(defs[0].description, "incoming alpha");
    }

    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let toolset = single_tool_set(MockStringOutputTool);

        let reply = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert_eq!(reply, "Hello\nWorld");
    }

    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let toolset = single_tool_set(MockImageOutputTool);

        let reply = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(reply);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let toolset = single_tool_set(MockObjectOutputTool);

        let reply = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert!(reply.starts_with('{'));
        let parsed = serde_json::from_str::<serde_json::Value>(&reply).unwrap();
        assert_eq!(
            parsed,
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }

    #[tokio::test]
    async fn null_args_are_preserved_for_unit_args() {
        let toolset = single_tool_set(MockExampleTool);

        let reply = toolset
            .call("example_tool", "null".to_string())
            .await
            .expect("unit args should accept null without object fallback");

        assert_eq!(reply, "Example answer");
    }

    #[tokio::test]
    async fn null_args_are_normalized_to_empty_object() {
        use crate::test_utils::MockToolError;

        /// Argument struct with no required fields, so `{}` deserializes
        /// successfully.
        #[derive(serde::Deserialize, serde::Serialize)]
        struct NoRequiredArgs {
            label: Option<String>,
        }

        struct NoArgTool;

        impl Tool for NoArgTool {
            const NAME: &'static str = "no_arg_tool";
            type Error = MockToolError;
            type Args = NoRequiredArgs;
            type Output = String;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: Self::NAME.to_string(),
                    description: "Tool with no required arguments".to_string(),
                    parameters: json!({"type": "object", "properties": {}}),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                Ok(args.label.unwrap_or_else(|| "default".to_string()))
            }
        }

        let toolset = single_tool_set(NoArgTool);

        // `null` is sent as the raw argument string; the call is expected to
        // succeed via the `{}` fallback in the blanket `ToolDyn` impl.
        let reply = toolset
            .call("no_arg_tool", "null".to_string())
            .await
            .expect("null args should succeed after normalisation");

        assert_eq!(reply, "default");
    }
}