rig/embeddings/
tool.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
//! This module defines the [ToolSchema] struct, which is used to embed an object that implements [crate::tool::ToolEmbedding].

use crate::{tool::ToolEmbeddingDyn, Embed};
use serde::Serialize;

use super::embed::EmbedError;

/// Embeddable document that is used as an intermediate representation of a tool when
/// RAGging tools.
///
/// A value is normally produced by [ToolSchema::try_from]; its [Embed] implementation
/// hands every entry of `embedding_docs` to the embedder.
#[derive(Clone, Debug, Serialize, Default, Eq, PartialEq)]
pub struct ToolSchema {
    /// Name of the tool; [ToolSchema::try_from] fills it from the tool's `name()`.
    pub name: String,
    /// Tool context as a JSON value; [ToolSchema::try_from] fills it from the tool's
    /// `context()`.
    pub context: serde_json::Value,
    /// Text documents that are passed to the embedder when this schema is embedded.
    pub embedding_docs: Vec<String>,
}

impl Embed for ToolSchema {
    /// Passes a clone of every entry in `embedding_docs` to `embedder`, in order.
    ///
    /// The schema itself is only borrowed, which is why each document is cloned before
    /// being handed over. This implementation has no failure path of its own and
    /// returns `Ok(())` once all documents have been passed on.
    fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> {
        self.embedding_docs.iter().cloned().for_each(|text| {
            embedder.embed(text);
        });
        Ok(())
    }
}

impl ToolSchema {
    /// Build a [ToolSchema] from any item that implements [ToolEmbeddingDyn].
    ///
    /// The tool's name, context and embedding documents are read (in that order) and
    /// stored in the corresponding fields of the returned schema.
    ///
    /// # Errors
    /// If the tool's `context()` call fails, the error is wrapped with
    /// [EmbedError::new] and returned; no schema is produced in that case.
    ///
    /// # Example
    /// ```rust
    /// use rig::{
    ///     completion::ToolDefinition,
    ///     embeddings::ToolSchema,
    ///     tool::{Tool, ToolEmbedding, ToolEmbeddingDyn},
    /// };
    /// use serde_json::json;
    ///
    /// #[derive(Debug, thiserror::Error)]
    /// #[error("Math error")]
    /// struct NothingError;
    ///
    /// #[derive(Debug, thiserror::Error)]
    /// #[error("Init error")]
    /// struct InitError;
    ///
    /// struct Nothing;
    /// impl Tool for Nothing {
    ///     const NAME: &'static str = "nothing";
    ///
    ///     type Error = NothingError;
    ///     type Args = ();
    ///     type Output = ();
    ///
    ///     async fn definition(&self, _prompt: String) -> ToolDefinition {
    ///         serde_json::from_value(json!({
    ///             "name": "nothing",
    ///             "description": "nothing",
    ///             "parameters": {}
    ///         }))
    ///         .expect("Tool Definition")
    ///     }
    ///
    ///     async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
    ///         Ok(())
    ///     }
    /// }
    ///
    /// impl ToolEmbedding for Nothing {
    ///     type InitError = InitError;
    ///     type Context = ();
    ///     type State = ();
    ///
    ///     fn init(_state: Self::State, _context: Self::Context) -> Result<Self, Self::InitError> {
    ///         Ok(Nothing)
    ///     }
    ///
    ///     fn embedding_docs(&self) -> Vec<String> {
    ///         vec!["Do nothing.".into()]
    ///     }
    ///
    ///     fn context(&self) -> Self::Context {}
    /// }
    ///
    /// let tool = ToolSchema::try_from(&Nothing).unwrap();
    ///
    /// assert_eq!(tool.name, "nothing".to_string());
    /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]);
    /// ```
    pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbedError> {
        // Keep the calls in this order: it matches the field order the schema has
        // always been populated in, and a context failure aborts before the
        // embedding documents are requested.
        let name = tool.name();
        let context = tool.context().map_err(EmbedError::new)?;
        let embedding_docs = tool.embedding_docs();

        Ok(Self {
            name,
            context,
            embedding_docs,
        })
    }
}