Skip to main content

rig/embeddings/
embedding.rs

//! This module defines the [EmbeddingModel] trait, which represents an embedding model that can
//! generate embeddings for text documents, and the [ImageEmbeddingModel] trait, its counterpart
//! for generating embeddings from images.
//!
//! It also defines the [Embedding] struct, which represents a single document embedding.
//!
//! Finally, it defines the [EmbeddingError] enum, which represents the various errors that
//! can occur during embedding generation or processing.
8
9use crate::{
10    http_client,
11    wasm_compat::{WasmCompatSend, WasmCompatSync},
12};
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, thiserror::Error)]
16pub enum EmbeddingError {
17    /// Http error (e.g.: connection error, timeout, etc.)
18    #[error("HttpError: {0}")]
19    HttpError(#[from] http_client::Error),
20
21    /// Json error (e.g.: serialization, deserialization)
22    #[error("JsonError: {0}")]
23    JsonError(#[from] serde_json::Error),
24
25    #[error("UrlError: {0}")]
26    UrlError(#[from] url::ParseError),
27
28    #[cfg(not(target_family = "wasm"))]
29    /// Error processing the document for embedding
30    #[error("DocumentError: {0}")]
31    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
32
33    #[cfg(target_family = "wasm")]
34    /// Error processing the document for embedding
35    #[error("DocumentError: {0}")]
36    DocumentError(Box<dyn std::error::Error + 'static>),
37
38    /// Error parsing the completion response
39    #[error("ResponseError: {0}")]
40    ResponseError(String),
41
42    /// Error returned by the embedding model provider
43    #[error("ProviderError: {0}")]
44    ProviderError(String),
45}
46
47/// Trait for embedding models that can generate embeddings for documents.
48pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync {
49    /// The maximum number of documents that can be embedded in a single request.
50    const MAX_DOCUMENTS: usize;
51
52    type Client;
53
54    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self;
55
56    /// The number of dimensions in the embedding vector.
57    fn ndims(&self) -> usize;
58
59    /// Embed multiple text documents in a single request
60    fn embed_texts(
61        &self,
62        texts: impl IntoIterator<Item = String> + WasmCompatSend,
63    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + WasmCompatSend;
64
65    /// Embed a single text document.
66    fn embed_text(
67        &self,
68        text: &str,
69    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
70        async {
71            let mut embeddings = self.embed_texts(vec![text.to_string()]).await?;
72            embeddings.pop().ok_or_else(|| {
73                EmbeddingError::ResponseError(
74                    "embedding provider returned an empty response for embed_text".to_string(),
75                )
76            })
77        }
78    }
79}
80
81/// Trait for embedding models that can generate embeddings for images.
82pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
83    /// The maximum number of images that can be embedded in a single request.
84    const MAX_DOCUMENTS: usize;
85
86    /// The number of dimensions in the embedding vector.
87    fn ndims(&self) -> usize;
88
89    /// Embed multiple images in a single request from bytes.
90    fn embed_images(
91        &self,
92        images: impl IntoIterator<Item = Vec<u8>> + WasmCompatSend,
93    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
94
95    /// Embed a single image from bytes.
96    fn embed_image<'a>(
97        &'a self,
98        bytes: &'a [u8],
99    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
100        async move {
101            let mut embeddings = self.embed_images(vec![bytes.to_owned()]).await?;
102            embeddings.pop().ok_or_else(|| {
103                EmbeddingError::ResponseError(
104                    "embedding provider returned an empty response for embed_image".to_string(),
105                )
106            })
107        }
108    }
109}
110
111/// Struct that holds a single document and its embedding.
112#[derive(Clone, Default, Deserialize, Serialize, Debug)]
113pub struct Embedding {
114    /// The document that was embedded. Used for debugging.
115    pub document: String,
116    /// The embedding vector
117    pub vec: Vec<f64>,
118}
119
120impl PartialEq for Embedding {
121    fn eq(&self, other: &Self) -> bool {
122        self.document == other.document
123    }
124}
125
126impl Eq for Embedding {}