rig/embeddings/tool.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
//! This module defines the [ToolSchema] struct, which is used to embed an object that
//! implements [crate::tool::ToolEmbedding] so that tools can be retrieved via RAG.
use crate::{tool::ToolEmbeddingDyn, Embed};
use serde::Serialize;
use super::embed::EmbedError;
/// Embeddable document that is used as an intermediate representation of a tool when
/// RAGging tools.
///
/// Only [`embedding_docs`](ToolSchema::embedding_docs) is fed to the embedder; `name` and
/// `context` are carried alongside as metadata describing the tool.
#[derive(Clone, Debug, Serialize, Default, Eq, PartialEq)]
pub struct ToolSchema {
    /// Name of the tool, as reported by [ToolEmbeddingDyn::name].
    pub name: String,
    /// Tool context as a JSON value, as produced by [ToolEmbeddingDyn::context].
    pub context: serde_json::Value,
    /// Text documents that are embedded to represent this tool.
    pub embedding_docs: Vec<String>,
}
impl Embed for ToolSchema {
    /// Hands each entry of `embedding_docs` to `embedder`, in order.
    ///
    /// `name` and `context` are not embedded. Always returns `Ok(())`.
    fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> {
        self.embedding_docs.iter().cloned().for_each(|text| {
            embedder.embed(text);
        });
        Ok(())
    }
}
impl ToolSchema {
    /// Build a [ToolSchema] from any tool implementing [ToolEmbeddingDyn].
    ///
    /// The returned schema contains three things taken from the tool:
    /// - its name;
    /// - its context, as a JSON value;
    /// - its embedding documents.
    ///
    /// # Errors
    ///
    /// Returns an [EmbedError] wrapping the underlying error when
    /// [ToolEmbeddingDyn::context] fails.
    ///
    /// # Example
    /// ```rust
    /// use rig::{
    ///     completion::ToolDefinition,
    ///     embeddings::ToolSchema,
    ///     tool::{Tool, ToolEmbedding, ToolEmbeddingDyn},
    /// };
    /// use serde_json::json;
    ///
    /// #[derive(Debug, thiserror::Error)]
    /// #[error("Math error")]
    /// struct NothingError;
    ///
    /// #[derive(Debug, thiserror::Error)]
    /// #[error("Init error")]
    /// struct InitError;
    ///
    /// struct Nothing;
    /// impl Tool for Nothing {
    ///     const NAME: &'static str = "nothing";
    ///
    ///     type Error = NothingError;
    ///     type Args = ();
    ///     type Output = ();
    ///
    ///     async fn definition(&self, _prompt: String) -> ToolDefinition {
    ///         serde_json::from_value(json!({
    ///             "name": "nothing",
    ///             "description": "nothing",
    ///             "parameters": {}
    ///         }))
    ///         .expect("Tool Definition")
    ///     }
    ///
    ///     async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
    ///         Ok(())
    ///     }
    /// }
    ///
    /// impl ToolEmbedding for Nothing {
    ///     type InitError = InitError;
    ///     type Context = ();
    ///     type State = ();
    ///
    ///     fn init(_state: Self::State, _context: Self::Context) -> Result<Self, Self::InitError> {
    ///         Ok(Nothing)
    ///     }
    ///
    ///     fn embedding_docs(&self) -> Vec<String> {
    ///         vec!["Do nothing.".into()]
    ///     }
    ///
    ///     fn context(&self) -> Self::Context {}
    /// }
    ///
    /// let tool = ToolSchema::try_from(&Nothing).unwrap();
    ///
    /// assert_eq!(tool.name, "nothing".to_string());
    /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]);
    /// ```
    pub fn try_from(tool: &dyn ToolEmbeddingDyn) -> Result<Self, EmbedError> {
        let name = tool.name();
        // The only fallible step: surface context failures as an `EmbedError`.
        let context = tool.context().map_err(EmbedError::new)?;
        let embedding_docs = tool.embedding_docs();

        Ok(Self {
            name,
            context,
            embedding_docs,
        })
    }
}