rig/embeddings/
tool.rs

//! This module defines the [ToolSchema] struct, which is used to embed an object that implements [crate::tool::ToolEmbedding].
2
3use crate::{Embed, tool::ToolEmbeddingDyn};
4use serde::Serialize;
5
6use super::embed::EmbedError;
7
8/// Embeddable document that is used as an intermediate representation of a tool when
9/// RAGging tools.
10#[derive(Clone, Serialize, Default, Eq, PartialEq)]
11pub struct ToolSchema {
12    pub name: String,
13    pub context: serde_json::Value,
14    pub embedding_docs: Vec<String>,
15}
16
17impl Embed for ToolSchema {
18    fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> {
19        for doc in &self.embedding_docs {
20            embedder.embed(doc.clone());
21        }
22        Ok(())
23    }
24}
25
26impl ToolSchema {
27    /// Convert item that implements [ToolEmbeddingDyn] to an [ToolSchema].
28    ///
29    /// # Example
30    /// ```rust
31    /// use rig::{
32    ///     completion::ToolDefinition,
33    ///     embeddings::ToolSchema,
34    ///     tool::{Tool, ToolEmbedding, ToolEmbeddingDyn},
35    /// };
36    /// use serde_json::json;
37    ///
38    /// #[derive(Debug, thiserror::Error)]
39    /// #[error("Math error")]
40    /// struct NothingError;
41    ///
42    /// #[derive(Debug, thiserror::Error)]
43    /// #[error("Init error")]
44    /// struct InitError;
45    ///
46    /// struct Nothing;
47    /// impl Tool for Nothing {
48    ///     const NAME: &'static str = "nothing";
49    ///
50    ///     type Error = NothingError;
51    ///     type Args = ();
52    ///     type Output = ();
53    ///
54    ///     async fn definition(&self, _prompt: String) -> ToolDefinition {
55    ///         serde_json::from_value(json!({
56    ///             "name": "nothing",
57    ///             "description": "nothing",
58    ///             "parameters": {}
59    ///         }))
60    ///         .expect("Tool Definition")
61    ///     }
62    ///
63    ///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
64    ///         Ok(())
65    ///     }
66    /// }
67    ///
68    /// impl ToolEmbedding for Nothing {
69    ///     type InitError = InitError;
70    ///     type Context = ();
71    ///     type State = ();
72    ///
73    ///     fn init(_state: Self::State, _context: Self::Context) -> Result<Self, Self::InitError> {
74    ///         Ok(Nothing)
75    ///     }
76    ///
77    ///     fn embedding_docs(&self) -> Vec<String> {
78    ///         vec!["Do nothing.".into()]
79    ///     }
80    ///
81    ///     fn context(&self) -> Self::Context {}
82    /// }
83    ///
84    /// let tool = ToolSchema::try_from(&Nothing).unwrap();
85    ///
86    /// assert_eq!(tool.name, "nothing".to_string());
87    /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]);
88    /// ```
89    pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbedError> {
90        Ok(ToolSchema {
91            name: tool.name(),
92            context: tool.context().map_err(EmbedError::new)?,
93            embedding_docs: tool.embedding_docs(),
94        })
95    }
96}