Skip to main content

rig_core/embeddings/
embedding.rs

//! This module defines the [`EmbeddingModel`] trait, which represents an embedding model that
//! can generate embeddings for documents.
//!
//! The module also defines the [`Embedding`] struct, which represents a single document embedding.
//!
//! Finally, the module defines the [`EmbeddingError`] enum, which represents various errors that
//! can occur during embedding generation or processing.
8
9use crate::{
10    completion::Usage,
11    http_client,
12    wasm_compat::{WasmCompatSend, WasmCompatSync},
13};
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, thiserror::Error)]
17pub enum EmbeddingError {
18    /// Http error (e.g.: connection error, timeout, etc.)
19    #[error("HttpError: {0}")]
20    HttpError(#[from] http_client::Error),
21
22    /// Json error (e.g.: serialization, deserialization)
23    #[error("JsonError: {0}")]
24    JsonError(#[from] serde_json::Error),
25
26    /// URL construction or parsing failed while preparing a provider request.
27    #[error("UrlError: {0}")]
28    UrlError(#[from] url::ParseError),
29
30    #[cfg(not(target_family = "wasm"))]
31    /// Error processing the document for embedding
32    #[error("DocumentError: {0}")]
33    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
34
35    #[cfg(target_family = "wasm")]
36    /// Error processing the document for embedding
37    #[error("DocumentError: {0}")]
38    DocumentError(Box<dyn std::error::Error + 'static>),
39
40    /// Error parsing the completion response
41    #[error("ResponseError: {0}")]
42    ResponseError(String),
43
44    /// Error returned by the embedding model provider
45    #[error("ProviderError: {0}")]
46    ProviderError(String),
47}
48
49/// Trait for embedding models that can generate embeddings for documents.
50pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync {
51    /// The maximum number of documents that can be embedded in a single request.
52    const MAX_DOCUMENTS: usize;
53
54    /// Provider client type used to construct this embedding model.
55    type Client;
56
57    /// Construct a model handle from a provider client, model identifier, and optional dimensions.
58    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self;
59
60    /// The number of dimensions in the embedding vector.
61    fn ndims(&self) -> usize;
62
63    /// Embed multiple text documents in a single request
64    fn embed_texts(
65        &self,
66        texts: impl IntoIterator<Item = String> + WasmCompatSend,
67    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + WasmCompatSend;
68
69    /// Embed a single text document.
70    fn embed_text(
71        &self,
72        text: &str,
73    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
74        async {
75            let mut embeddings = self.embed_texts(vec![text.to_string()]).await?;
76            embeddings.pop().ok_or_else(|| {
77                EmbeddingError::ResponseError(
78                    "embedding provider returned an empty response for embed_text".to_string(),
79                )
80            })
81        }
82    }
83
84    /// Embed multiple text documents in a single request and return token usage.
85    ///
86    /// The default implementation delegates to [`EmbeddingModel::embed_texts`] and returns
87    /// zero-valued usage. Providers that expose usage information from their embedding API
88    /// should override this method.
89    fn embed_texts_with_usage(
90        &self,
91        texts: impl IntoIterator<Item = String> + WasmCompatSend,
92    ) -> impl std::future::Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend
93    {
94        async {
95            let embeddings = self.embed_texts(texts).await?;
96            Ok(EmbeddingResponse {
97                embeddings,
98                usage: Usage::default(),
99            })
100        }
101    }
102
103    /// Embed a single text document and return token usage.
104    ///
105    /// The default implementation delegates to
106    /// [`EmbeddingModel::embed_texts_with_usage`].
107    fn embed_text_with_usage(
108        &self,
109        text: &str,
110    ) -> impl std::future::Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend
111    {
112        async {
113            let response = self.embed_texts_with_usage(vec![text.to_string()]).await?;
114            if response.embeddings.is_empty() {
115                return Err(EmbeddingError::ResponseError(
116                    "embedding provider returned an empty response for embed_text_with_usage"
117                        .to_string(),
118                ));
119            }
120            Ok(response)
121        }
122    }
123}
124
125/// Response from an embedding request containing the embeddings and token usage.
126#[derive(Debug, Clone)]
127pub struct EmbeddingResponse {
128    /// The embeddings returned by the provider, one per input text.
129    pub embeddings: Vec<Embedding>,
130    /// Token usage for this embedding request.
131    pub usage: Usage,
132}
133
134/// Trait for embedding models that can generate embeddings for images.
135pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
136    /// The maximum number of images that can be embedded in a single request.
137    const MAX_DOCUMENTS: usize;
138
139    /// The number of dimensions in the embedding vector.
140    fn ndims(&self) -> usize;
141
142    /// Embed multiple images in a single request from bytes.
143    ///
144    /// Implementations should preserve input order in the returned embeddings.
145    fn embed_images(
146        &self,
147        images: impl IntoIterator<Item = Vec<u8>> + WasmCompatSend,
148    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
149
150    /// Embed a single image from bytes.
151    fn embed_image<'a>(
152        &'a self,
153        bytes: &'a [u8],
154    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
155        async move {
156            let mut embeddings = self.embed_images(vec![bytes.to_owned()]).await?;
157            embeddings.pop().ok_or_else(|| {
158                EmbeddingError::ResponseError(
159                    "embedding provider returned an empty response for embed_image".to_string(),
160                )
161            })
162        }
163    }
164}
165
166/// Struct that holds a single document and its embedding.
167#[derive(Clone, Default, Deserialize, Serialize, Debug)]
168pub struct Embedding {
169    /// The document that was embedded. Used for debugging.
170    pub document: String,
171    /// The embedding vector
172    pub vec: Vec<f64>,
173}
174
175impl PartialEq for Embedding {
176    fn eq(&self, other: &Self) -> bool {
177        self.document == other.document
178    }
179}
180
181impl Eq for Embedding {}