rig/embeddings/tool.rs
1//! The module defines the [ToolSchema] struct, which is used to embed an object that implements [crate::tool::ToolEmbedding]
2
3use crate::{Embed, tool::ToolEmbeddingDyn};
4use serde::Serialize;
5
6use super::embed::EmbedError;
7
8/// Embeddable document that is used as an intermediate representation of a tool when
9/// RAGging tools.
10#[derive(Clone, Serialize, Default, Eq, PartialEq)]
11pub struct ToolSchema {
12 pub name: String,
13 pub context: serde_json::Value,
14 pub embedding_docs: Vec<String>,
15}
16
17impl Embed for ToolSchema {
18 fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> {
19 for doc in &self.embedding_docs {
20 embedder.embed(doc.clone());
21 }
22 Ok(())
23 }
24}
25
26impl ToolSchema {
27 /// Convert item that implements [ToolEmbeddingDyn] to an [ToolSchema].
28 ///
29 /// # Example
30 /// ```rust
31 /// use rig::{
32 /// completion::ToolDefinition,
33 /// embeddings::ToolSchema,
34 /// tool::{Tool, ToolEmbedding, ToolEmbeddingDyn},
35 /// };
36 /// use serde_json::json;
37 ///
38 /// #[derive(Debug, thiserror::Error)]
39 /// #[error("Math error")]
40 /// struct NothingError;
41 ///
42 /// #[derive(Debug, thiserror::Error)]
43 /// #[error("Init error")]
44 /// struct InitError;
45 ///
46 /// struct Nothing;
47 /// impl Tool for Nothing {
48 /// const NAME: &'static str = "nothing";
49 ///
50 /// type Error = NothingError;
51 /// type Args = ();
52 /// type Output = ();
53 ///
54 /// async fn definition(&self, _prompt: String) -> ToolDefinition {
55 /// serde_json::from_value(json!({
56 /// "name": "nothing",
57 /// "description": "nothing",
58 /// "parameters": {}
59 /// }))
60 /// .expect("Tool Definition")
61 /// }
62 ///
63 /// async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
64 /// Ok(())
65 /// }
66 /// }
67 ///
68 /// impl ToolEmbedding for Nothing {
69 /// type InitError = InitError;
70 /// type Context = ();
71 /// type State = ();
72 ///
73 /// fn init(_state: Self::State, _context: Self::Context) -> Result<Self, Self::InitError> {
74 /// Ok(Nothing)
75 /// }
76 ///
77 /// fn embedding_docs(&self) -> Vec<String> {
78 /// vec!["Do nothing.".into()]
79 /// }
80 ///
81 /// fn context(&self) -> Self::Context {}
82 /// }
83 ///
84 /// let tool = ToolSchema::try_from(&Nothing).unwrap();
85 ///
86 /// assert_eq!(tool.name, "nothing".to_string());
87 /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]);
88 /// ```
89 pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbedError> {
90 Ok(ToolSchema {
91 name: tool.name(),
92 context: tool.context().map_err(EmbedError::new)?,
93 embedding_docs: tool.embedding_docs(),
94 })
95 }
96}