Skip to main content

rig_core/embeddings/
embedding.rs

//! This module defines the [`EmbeddingModel`] and [`ImageEmbeddingModel`] traits, which represent
//! embedding models that can generate embeddings for text documents and images respectively.
//!
//! The module also defines the [`Embedding`] struct, which represents a single document embedding.
//!
//! Finally, the module defines the [`EmbeddingError`] enum, which represents the various errors
//! that can occur during embedding generation or processing.
8
9use crate::{
10    http_client,
11    wasm_compat::{WasmCompatSend, WasmCompatSync},
12};
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, thiserror::Error)]
16pub enum EmbeddingError {
17    /// Http error (e.g.: connection error, timeout, etc.)
18    #[error("HttpError: {0}")]
19    HttpError(#[from] http_client::Error),
20
21    /// Json error (e.g.: serialization, deserialization)
22    #[error("JsonError: {0}")]
23    JsonError(#[from] serde_json::Error),
24
25    /// URL construction or parsing failed while preparing a provider request.
26    #[error("UrlError: {0}")]
27    UrlError(#[from] url::ParseError),
28
29    #[cfg(not(target_family = "wasm"))]
30    /// Error processing the document for embedding
31    #[error("DocumentError: {0}")]
32    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
33
34    #[cfg(target_family = "wasm")]
35    /// Error processing the document for embedding
36    #[error("DocumentError: {0}")]
37    DocumentError(Box<dyn std::error::Error + 'static>),
38
39    /// Error parsing the completion response
40    #[error("ResponseError: {0}")]
41    ResponseError(String),
42
43    /// Error returned by the embedding model provider
44    #[error("ProviderError: {0}")]
45    ProviderError(String),
46}
47
48/// Trait for embedding models that can generate embeddings for documents.
49pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync {
50    /// The maximum number of documents that can be embedded in a single request.
51    const MAX_DOCUMENTS: usize;
52
53    /// Provider client type used to construct this embedding model.
54    type Client;
55
56    /// Construct a model handle from a provider client, model identifier, and optional dimensions.
57    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self;
58
59    /// The number of dimensions in the embedding vector.
60    fn ndims(&self) -> usize;
61
62    /// Embed multiple text documents in a single request
63    fn embed_texts(
64        &self,
65        texts: impl IntoIterator<Item = String> + WasmCompatSend,
66    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + WasmCompatSend;
67
68    /// Embed a single text document.
69    fn embed_text(
70        &self,
71        text: &str,
72    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
73        async {
74            let mut embeddings = self.embed_texts(vec![text.to_string()]).await?;
75            embeddings.pop().ok_or_else(|| {
76                EmbeddingError::ResponseError(
77                    "embedding provider returned an empty response for embed_text".to_string(),
78                )
79            })
80        }
81    }
82}
83
84/// Trait for embedding models that can generate embeddings for images.
85pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
86    /// The maximum number of images that can be embedded in a single request.
87    const MAX_DOCUMENTS: usize;
88
89    /// The number of dimensions in the embedding vector.
90    fn ndims(&self) -> usize;
91
92    /// Embed multiple images in a single request from bytes.
93    ///
94    /// Implementations should preserve input order in the returned embeddings.
95    fn embed_images(
96        &self,
97        images: impl IntoIterator<Item = Vec<u8>> + WasmCompatSend,
98    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
99
100    /// Embed a single image from bytes.
101    fn embed_image<'a>(
102        &'a self,
103        bytes: &'a [u8],
104    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
105        async move {
106            let mut embeddings = self.embed_images(vec![bytes.to_owned()]).await?;
107            embeddings.pop().ok_or_else(|| {
108                EmbeddingError::ResponseError(
109                    "embedding provider returned an empty response for embed_image".to_string(),
110                )
111            })
112        }
113    }
114}
115
116/// Struct that holds a single document and its embedding.
117#[derive(Clone, Default, Deserialize, Serialize, Debug)]
118pub struct Embedding {
119    /// The document that was embedded. Used for debugging.
120    pub document: String,
121    /// The embedding vector
122    pub vec: Vec<f64>,
123}
124
125impl PartialEq for Embedding {
126    fn eq(&self, other: &Self) -> bool {
127        self.document == other.document
128    }
129}
130
131impl Eq for Embedding {}