Skip to main content

rig_core/embeddings/
embed.rs

//! The module defines the [Embed] trait, which must be implemented for types
//! that can be embedded by the [crate::embeddings::EmbeddingsBuilder].
//!
//! The module also defines the [EmbedError] struct, which is returned when the [Embed::embed]
//! method of the [Embed] trait fails.
//!
//! The module also defines the [TextEmbedder] struct, which accumulates string values that need to be embedded.
//! It is used directly with the [Embed] trait.
//!
//! Finally, the module implements [Embed] for many common primitive types.
11
12/// Error type used for when the [Embed::embed] method of the [Embed] trait fails.
13/// Used by default implementations of [Embed] for common types.
14#[derive(Debug, thiserror::Error)]
15#[error("{0}")]
16pub struct EmbedError(#[from] Box<dyn std::error::Error + Send + Sync>);
17
18impl EmbedError {
19    pub fn new<E: std::error::Error + Send + Sync + 'static>(error: E) -> Self {
20        EmbedError(Box::new(error))
21    }
22}
23
24/// Derive this trait for objects that need to be converted to vector embeddings.
25/// The [Embed::embed] method accumulates string values that need to be embedded by adding them to the [TextEmbedder].
26/// If an error occurs, the method should return [EmbedError].
27/// # Example
28/// ```rust
29/// use std::env;
30///
31/// use rig_core::{
32///     Embed,
33///     embeddings::{self, EmbedError, TextEmbedder},
34/// };
35///
36/// struct WordDefinition {
37///     id: String,
38///     word: String,
39///     definitions: String,
40/// }
41///
42/// impl Embed for WordDefinition {
43///     fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
44///        // Embeddings only need to be generated for `definition` field.
45///        // Split the definitions by comma and collect them into a vector of strings.
46///        // That way, different embeddings can be generated for each definition in the `definitions` string.
47///        self.definitions
48///            .split(",")
49///            .for_each(|s| {
50///                embedder.embed(s.to_string());
51///            });
52///
53///        Ok(())
54///     }
55/// }
56///
57/// let fake_definition = WordDefinition {
58///    id: "1".to_string(),
59///    word: "apple".to_string(),
60///    definitions: "a fruit, a tech company".to_string(),
61/// };
62///
63/// assert_eq!(embeddings::to_texts(fake_definition).unwrap(), vec!["a fruit", " a tech company"]);
64/// ```
65pub trait Embed {
66    /// Append all text fragments that should be embedded for this value.
67    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>;
68}
69
/// Accumulates string values that need to be embedded.
/// Used by the [Embed] trait: implementations push their text fragments into it.
///
/// `Debug` is derived so the collected texts can be inspected in logs and test output.
#[derive(Debug, Default)]
pub struct TextEmbedder {
    /// Texts collected so far, in the order they were added.
    pub(crate) texts: Vec<String>,
}

impl TextEmbedder {
    /// Adds input `text` string to the list of texts in the [TextEmbedder] that need to be embedded.
    pub fn embed(&mut self, text: String) {
        self.texts.push(text);
    }
}
83
84/// Utility function that returns a vector of strings that need to be embedded for a
85/// given object that implements the [Embed] trait.
86pub fn to_texts(item: impl Embed) -> Result<Vec<String>, EmbedError> {
87    let mut embedder = TextEmbedder::default();
88    item.embed(&mut embedder)?;
89    Ok(embedder.texts)
90}
91
92// ================================================================
93// Implementations of Embed for common types
94// ================================================================
95
96impl Embed for String {
97    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
98        embedder.embed(self.clone());
99        Ok(())
100    }
101}
102
103impl Embed for &str {
104    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
105        embedder.embed(self.to_string());
106        Ok(())
107    }
108}
109
110impl Embed for i8 {
111    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
112        embedder.embed(self.to_string());
113        Ok(())
114    }
115}
116
117impl Embed for i16 {
118    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
119        embedder.embed(self.to_string());
120        Ok(())
121    }
122}
123
124impl Embed for i32 {
125    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
126        embedder.embed(self.to_string());
127        Ok(())
128    }
129}
130
131impl Embed for i64 {
132    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
133        embedder.embed(self.to_string());
134        Ok(())
135    }
136}
137
138impl Embed for i128 {
139    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
140        embedder.embed(self.to_string());
141        Ok(())
142    }
143}
144
145impl Embed for f32 {
146    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
147        embedder.embed(self.to_string());
148        Ok(())
149    }
150}
151
152impl Embed for f64 {
153    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
154        embedder.embed(self.to_string());
155        Ok(())
156    }
157}
158
159impl Embed for bool {
160    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
161        embedder.embed(self.to_string());
162        Ok(())
163    }
164}
165
166impl Embed for char {
167    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
168        embedder.embed(self.to_string());
169        Ok(())
170    }
171}
172
173impl Embed for serde_json::Value {
174    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
175        embedder.embed(serde_json::to_string(self).map_err(EmbedError::new)?);
176        Ok(())
177    }
178}
179
180impl<T: Embed> Embed for &T {
181    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
182        (*self).embed(embedder)
183    }
184}
185
186impl<T: Embed> Embed for Vec<T> {
187    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
188        for item in self {
189            item.embed(embedder).map_err(EmbedError::new)?;
190        }
191        Ok(())
192    }
193}