1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21 completion::{self, ToolDefinition},
22 embeddings::{embed::EmbedError, tool::ToolSchema},
23 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
/// Errors produced when a tool is invoked through the dynamic [`ToolDyn`] interface.
///
/// No variant carries an `#[error(...)]` attribute, so `Display` for this type is implemented
/// by hand rather than generated by `thiserror`.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The tool's own `call` failed; the tool's error is boxed here.
    ///
    /// On non-wasm targets the boxed error must also be `Send + Sync`.
    #[cfg(not(target_family = "wasm"))]
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// The tool's own `call` failed (wasm build: the boxed error has no `Send + Sync` bound).
    #[cfg(target_family = "wasm")]
    ToolCallError(#[from] Box<dyn std::error::Error>),

    /// JSON handling failed: the blanket `ToolDyn` impl uses this both when the argument
    /// string cannot be deserialized and when the tool output cannot be serialized.
    JsonError(#[from] serde_json::Error),
}
38
39impl fmt::Display for ToolError {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 ToolError::ToolCallError(e) => {
43 let error_str = e.to_string();
44 if error_str.starts_with("ToolCallError: ") {
47 write!(f, "{}", error_str)
48 } else {
49 write!(f, "ToolCallError: {}", error_str)
50 }
51 }
52 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
53 }
54 }
55}
56
57pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
113 const NAME: &'static str;
115
116 type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
118 type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
120 type Output: Serialize;
122
123 fn name(&self) -> String {
125 Self::NAME.to_string()
126 }
127
128 fn definition(
131 &self,
132 _prompt: String,
133 ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
134
135 fn call(
139 &self,
140 args: Self::Args,
141 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
142}
143
/// A [`Tool`] that also exposes documents for embedding and a serializable context from which
/// it can be constructed again via [`ToolEmbedding::init`].
pub trait ToolEmbedding: Tool {
    /// Error returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Serializable data describing this tool instance; returned by [`ToolEmbedding::context`]
    /// and consumed by [`ToolEmbedding::init`]. The dynamic interface converts it to a
    /// `serde_json::Value`.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// Extra state handed to [`ToolEmbedding::init`] alongside the context.
    type State: WasmCompatSend;

    /// Text documents associated with this tool for embedding.
    /// NOTE(review): presumably used to retrieve the tool by semantic similarity — confirm.
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns this tool's context value.
    fn context(&self) -> Self::Context;

    /// Constructs the tool from `state` and `context`, or fails with `InitError`.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
170
/// Object-safe, string-based counterpart of [`Tool`].
///
/// Stored behind `Arc<dyn ToolDyn>` in this module. Every [`Tool`] implements it through the
/// blanket impl below, so heterogeneous tools can share one collection.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Name under which the tool is registered in a [`ToolSet`].
    fn name(&self) -> String;

    /// Returns the tool definition; the blanket impl forwards `prompt` to [`Tool::definition`].
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Invokes the tool with JSON-encoded `args` and returns its output as a string.
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
179
180fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
181 match serde_json::to_value(output)? {
182 serde_json::Value::String(text) => Ok(text),
183 value => Ok(value.to_string()),
184 }
185}
186
187impl<T: Tool> ToolDyn for T {
188 fn name(&self) -> String {
189 self.name()
190 }
191
192 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
193 Box::pin(<Self as Tool>::definition(self, prompt))
194 }
195
196 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
197 Box::pin(async move {
198 match serde_json::from_str(&args) {
199 Ok(args) => <Self as Tool>::call(self, args)
200 .await
201 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
202 .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
203 Err(e) => Err(ToolError::JsonError(e)),
204 }
205 })
206 }
207}
208
/// Submodule compiled only when the `rmcp` feature is enabled.
#[cfg(feature = "rmcp")]
#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
pub mod rmcp;
212
/// Object-safe counterpart of [`ToolEmbedding`], with the typed context erased to JSON.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// The tool's context serialized into a `serde_json::Value`.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Text documents associated with this tool for embedding.
    fn embedding_docs(&self) -> Vec<String>;
}
219
220impl<T> ToolEmbeddingDyn for T
221where
222 T: ToolEmbedding + 'static,
223{
224 fn context(&self) -> serde_json::Result<serde_json::Value> {
225 serde_json::to_value(self.context())
226 }
227
228 fn embedding_docs(&self) -> Vec<String> {
229 self.embedding_docs()
230 }
231}
232
/// A type-erased tool as stored in a [`ToolSet`].
///
/// Both variants hold an `Arc`, so cloning a `ToolType` shares the same underlying tool.
#[derive(Clone)]
pub(crate) enum ToolType {
    /// A plain tool (registered via `ToolSet::add_tool`, `add_tool_boxed`, or
    /// `ToolSetBuilder::static_tool`).
    Simple(Arc<dyn ToolDyn>),
    /// A tool that also exposes embedding documents and context
    /// (registered via `ToolSetBuilder::dynamic_tool`).
    Embedding(Arc<dyn ToolEmbeddingDyn>),
}
238
239impl ToolType {
240 pub fn name(&self) -> String {
241 match self {
242 ToolType::Simple(tool) => tool.name(),
243 ToolType::Embedding(tool) => tool.name(),
244 }
245 }
246
247 pub async fn definition(&self, prompt: String) -> ToolDefinition {
248 match self {
249 ToolType::Simple(tool) => tool.definition(prompt).await,
250 ToolType::Embedding(tool) => tool.definition(prompt).await,
251 }
252 }
253
254 pub async fn call(&self, args: String) -> Result<String, ToolError> {
255 match self {
256 ToolType::Simple(tool) => tool.call(args).await,
257 ToolType::Embedding(tool) => tool.call(args).await,
258 }
259 }
260}
261
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The dispatched tool failed; wraps its [`ToolError`].
    ///
    /// NOTE(review): `ToolError`'s `Display` already starts with `ToolCallError: ` or
    /// `JsonError: `, so this variant renders as e.g. `ToolCallError: ToolCallError: ...`.
    /// Confirm whether the doubled prefix is intended.
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool is registered under the requested name; the payload is that name.
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// JSON serialization failed (e.g. while rendering tool definitions in `ToolSet::documents`).
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A tool call was interrupted.
    /// NOTE(review): not constructed in this module; presumably raised by callers that cancel
    /// a pending call — confirm.
    #[error("Tool call interrupted")]
    Interrupted,
}
280
/// A collection of tools addressable by name.
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools, keyed by the name each tool reported when it was added.
    /// Adding a tool whose name is already present replaces the previous entry.
    pub(crate) tools: HashMap<String, ToolType>,
}
286
287impl ToolSet {
288 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
290 let mut toolset = Self::default();
291 tools.into_iter().for_each(|tool| {
292 toolset.add_tool(tool);
293 });
294 toolset
295 }
296
297 pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
298 let mut toolset = Self::default();
299 tools.into_iter().for_each(|tool| {
300 toolset.add_tool_boxed(tool);
301 });
302 toolset
303 }
304
305 pub fn builder() -> ToolSetBuilder {
307 ToolSetBuilder::default()
308 }
309
310 pub fn contains(&self, toolname: &str) -> bool {
312 self.tools.contains_key(toolname)
313 }
314
315 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
317 self.tools
318 .insert(tool.name(), ToolType::Simple(Arc::new(tool)));
319 }
320
321 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
323 self.tools
324 .insert(tool.name(), ToolType::Simple(Arc::from(tool)));
325 }
326
327 pub fn delete_tool(&mut self, tool_name: &str) {
328 let _ = self.tools.remove(tool_name);
329 }
330
331 pub fn add_tools(&mut self, toolset: ToolSet) {
333 self.tools.extend(toolset.tools);
334 }
335
336 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
337 self.tools.get(toolname)
338 }
339
340 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
341 let mut defs = Vec::new();
342 for tool in self.tools.values() {
343 let def = tool.definition(String::new()).await;
344 defs.push(def);
345 }
346 Ok(defs)
347 }
348
349 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
351 if let Some(tool) = self.tools.get(toolname) {
352 tracing::debug!(target: "rig",
353 "Calling tool {toolname} with args:\n{}",
354 serde_json::to_string_pretty(&args).unwrap()
355 );
356 Ok(tool.call(args).await?)
357 } else {
358 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
359 }
360 }
361
362 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
364 let mut docs = Vec::new();
365 for tool in self.tools.values() {
366 match tool {
367 ToolType::Simple(tool) => {
368 docs.push(completion::Document {
369 id: tool.name(),
370 text: format!(
371 "\
372 Tool: {}\n\
373 Definition: \n\
374 {}\
375 ",
376 tool.name(),
377 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
378 ),
379 additional_props: HashMap::new(),
380 });
381 }
382 ToolType::Embedding(tool) => {
383 docs.push(completion::Document {
384 id: tool.name(),
385 text: format!(
386 "\
387 Tool: {}\n\
388 Definition: \n\
389 {}\
390 ",
391 tool.name(),
392 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
393 ),
394 additional_props: HashMap::new(),
395 });
396 }
397 }
398 }
399 Ok(docs)
400 }
401
402 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
406 self.tools
407 .values()
408 .filter_map(|tool_type| {
409 if let ToolType::Embedding(tool) = tool_type {
410 Some(ToolSchema::try_from(&**tool))
411 } else {
412 None
413 }
414 })
415 .collect::<Result<Vec<_>, _>>()
416 }
417}
418
/// Incrementally collects tools and turns them into a [`ToolSet`].
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools in insertion order; their names are only resolved in `build`.
    tools: Vec<ToolType>,
}
423
424impl ToolSetBuilder {
425 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
426 self.tools.push(ToolType::Simple(Arc::new(tool)));
427 self
428 }
429
430 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
431 self.tools.push(ToolType::Embedding(Arc::new(tool)));
432 self
433 }
434
435 pub fn build(self) -> ToolSet {
436 ToolSet {
437 tools: self
438 .tools
439 .into_iter()
440 .map(|tool| (tool.name(), tool))
441 .collect(),
442 }
443 }
444}
445
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use serde_json::json;

    use super::*;

    /// Builds a toolset containing two arithmetic tools, `add` and `subtract`.
    fn get_test_toolset() -> ToolSet {
        let mut toolset = ToolSet::default();

        // Argument type shared by both arithmetic tools.
        #[derive(Deserialize)]
        struct OperationArgs {
            x: i32,
            y: i32,
        }

        #[derive(Debug, thiserror::Error)]
        #[error("Math error")]
        struct MathError;

        #[derive(Deserialize, Serialize)]
        struct Adder;

        impl Tool for Adder {
            const NAME: &'static str = "add";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            // Definition built directly as a `ToolDefinition` struct literal.
            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: "add".to_string(),
                    description: "Add x and y together".to_string(),
                    parameters: json!({
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The first number to add"
                            },
                            "y": {
                                "type": "number",
                                "description": "The second number to add"
                            }
                        },
                        "required": ["x", "y"]
                    }),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x + args.y;
                Ok(result)
            }
        }

        #[derive(Deserialize, Serialize)]
        struct Subtract;

        impl Tool for Subtract {
            const NAME: &'static str = "subtract";
            type Error = MathError;
            type Args = OperationArgs;
            type Output = i32;

            // Definition deserialized from JSON, so this path also exercises
            // `ToolDefinition`'s `Deserialize` impl.
            async fn definition(&self, _prompt: String) -> ToolDefinition {
                serde_json::from_value(json!({
                    "name": "subtract",
                    "description": "Subtract y from x (i.e.: x - y)",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {
                                "type": "number",
                                "description": "The number to subtract from"
                            },
                            "y": {
                                "type": "number",
                                "description": "The number to subtract"
                            }
                        },
                        "required": ["x", "y"]
                    }
                }))
                .expect("Tool Definition")
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                let result = args.x - args.y;
                Ok(result)
            }
        }

        toolset.add_tool(Adder);
        toolset.add_tool(Subtract);
        toolset
    }

    // One definition is returned per registered tool.
    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    // Deleting by name removes exactly that tool.
    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    // Error type shared by the output-format fixture tools below.
    #[derive(Debug, thiserror::Error)]
    #[error("Test tool error")]
    struct TestToolError;

    // Fixture: returns a plain multi-line `String`.
    #[derive(Deserialize, Serialize)]
    struct StringOutputTool;

    impl Tool for StringOutputTool {
        const NAME: &'static str = "string_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns a multiline string".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok("Hello\nWorld".to_string())
        }
    }

    // `String` outputs must come back unquoted and unescaped (see `serialize_tool_output`).
    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(StringOutputTool);

        let output = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert_eq!(output, "Hello\nWorld");
    }

    // Fixture: returns a `String` that itself contains JSON describing an image.
    #[derive(Deserialize, Serialize)]
    struct ImageOutputTool;

    impl Tool for ImageOutputTool {
        const NAME: &'static str = "image_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns image JSON".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(json!({
                "type": "image",
                "data": "base64data==",
                "mimeType": "image/png"
            })
            .to_string())
        }
    }

    // Because string outputs pass through verbatim, JSON-in-a-string stays parseable
    // by `ToolResultContent::from_tool_output`.
    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(ImageOutputTool);

        let output = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(output);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    // Fixture: returns a non-string JSON value (an object).
    #[derive(Deserialize, Serialize)]
    struct ObjectOutputTool;

    impl Tool for ObjectOutputTool {
        const NAME: &'static str = "object_output";
        type Error = TestToolError;
        type Args = serde_json::Value;
        type Output = serde_json::Value;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Returns an object".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {}
                }),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(json!({
                "status": "ok",
                "count": 42
            }))
        }
    }

    // Non-string outputs are still rendered as JSON text that round-trips.
    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(ObjectOutputTool);

        let output = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert!(output.starts_with('{'));
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&output).unwrap(),
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }
}