rig/embeddings/
embed.rs

//! The module defines the [Embed] trait, which must be implemented for types
//! that can be embedded by the [crate::embeddings::EmbeddingsBuilder].
//!
//! The module also defines the [EmbedError] struct, which is returned when the [Embed::embed]
//! method of the [Embed] trait fails.
//!
//! The module also defines the [TextEmbedder] struct, which accumulates the string values that need to be embedded.
//! It is passed to [Embed::embed], which adds texts to it.
//!
//! Finally, the module implements [Embed] for many common primitive types.
11
12/// Error type used for when the [Embed::embed] method of the [Embed] trait fails.
13/// Used by default implementations of [Embed] for common types.
14#[derive(Debug, thiserror::Error)]
15#[error("{0}")]
16pub struct EmbedError(#[from] Box<dyn std::error::Error + Send + Sync>);
17
18impl EmbedError {
19    pub fn new<E: std::error::Error + Send + Sync + 'static>(error: E) -> Self {
20        EmbedError(Box::new(error))
21    }
22}
23
24/// Derive this trait for objects that need to be converted to vector embeddings.
25/// The [Embed::embed] method accumulates string values that need to be embedded by adding them to the [TextEmbedder].
26/// If an error occurs, the method should return [EmbedError].
27/// # Example
28/// ```rust
29/// use std::env;
30///
31/// use serde::{Deserialize, Serialize};
32/// use rig::{Embed, embeddings::{TextEmbedder, EmbedError}};
33///
34/// struct WordDefinition {
35///     id: String,
36///     word: String,
37///     definitions: String,
38/// }
39///
40/// impl Embed for WordDefinition {
41///     fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
42///        // Embeddings only need to be generated for `definition` field.
43///        // Split the definitions by comma and collect them into a vector of strings.
44///        // That way, different embeddings can be generated for each definition in the `definitions` string.
45///        self.definitions
46///            .split(",")
47///            .for_each(|s| {
48///                embedder.embed(s.to_string());
49///            });
50///
51///        Ok(())
52///     }
53/// }
54///
55/// let fake_definition = WordDefinition {
56///    id: "1".to_string(),
57///    word: "apple".to_string(),
58///    definitions: "a fruit, a tech company".to_string(),
59/// };
60///
61/// assert_eq!(embeddings::to_texts(fake_definition).unwrap(), vec!["a fruit", " a tech company"]);
62/// ```
63pub trait Embed {
64    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>;
65}
66
/// Accumulates string values that need to be embedded.
///
/// An [Embed] implementation receives a `&mut TextEmbedder` and adds to it every text
/// that should be embedded. Texts are kept in the order in which they were added.
#[derive(Debug, Default, Clone)]
pub struct TextEmbedder {
    /// Texts collected so far, in insertion order.
    pub(crate) texts: Vec<String>,
}

impl TextEmbedder {
    /// Adds input `text` string to the list of texts in the [TextEmbedder] that need to be embedded.
    pub fn embed(&mut self, text: String) {
        self.texts.push(text);
    }
}
80
81/// Utility function that returns a vector of strings that need to be embedded for a
82/// given object that implements the [Embed] trait.
83pub fn to_texts(item: impl Embed) -> Result<Vec<String>, EmbedError> {
84    let mut embedder = TextEmbedder::default();
85    item.embed(&mut embedder)?;
86    Ok(embedder.texts)
87}
88
89// ================================================================
90// Implementations of Embed for common types
91// ================================================================
92
93impl Embed for String {
94    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
95        embedder.embed(self.clone());
96        Ok(())
97    }
98}
99
100impl Embed for &str {
101    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
102        embedder.embed(self.to_string());
103        Ok(())
104    }
105}
106
107impl Embed for i8 {
108    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
109        embedder.embed(self.to_string());
110        Ok(())
111    }
112}
113
114impl Embed for i16 {
115    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
116        embedder.embed(self.to_string());
117        Ok(())
118    }
119}
120
121impl Embed for i32 {
122    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
123        embedder.embed(self.to_string());
124        Ok(())
125    }
126}
127
128impl Embed for i64 {
129    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
130        embedder.embed(self.to_string());
131        Ok(())
132    }
133}
134
135impl Embed for i128 {
136    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
137        embedder.embed(self.to_string());
138        Ok(())
139    }
140}
141
142impl Embed for f32 {
143    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
144        embedder.embed(self.to_string());
145        Ok(())
146    }
147}
148
149impl Embed for f64 {
150    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
151        embedder.embed(self.to_string());
152        Ok(())
153    }
154}
155
156impl Embed for bool {
157    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
158        embedder.embed(self.to_string());
159        Ok(())
160    }
161}
162
163impl Embed for char {
164    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
165        embedder.embed(self.to_string());
166        Ok(())
167    }
168}
169
170impl Embed for serde_json::Value {
171    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
172        embedder.embed(serde_json::to_string(self).map_err(EmbedError::new)?);
173        Ok(())
174    }
175}
176
177impl<T: Embed> Embed for &T {
178    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
179        (*self).embed(embedder)
180    }
181}
182
183impl<T: Embed> Embed for Vec<T> {
184    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
185        for item in self {
186            item.embed(embedder).map_err(EmbedError::new)?;
187        }
188        Ok(())
189    }
190}