rig/embeddings/
embedding.rs

//! This module defines the [EmbeddingModel] trait, which represents an embedding model that can
//! generate embeddings for text documents, along with its dyn-compatible counterpart
//! [EmbeddingModelDyn] and the [ImageEmbeddingModel] trait for embedding images.
//!
//! It also defines the [Embedding] struct, which represents a single document embedding.
//!
//! Finally, it defines the [EmbeddingError] enum, which represents the errors that can occur
//! during embedding generation or processing.
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, thiserror::Error)]
13pub enum EmbeddingError {
14    /// Http error (e.g.: connection error, timeout, etc.)
15    #[error("HttpError: {0}")]
16    HttpError(#[from] reqwest::Error),
17
18    /// Json error (e.g.: serialization, deserialization)
19    #[error("JsonError: {0}")]
20    JsonError(#[from] serde_json::Error),
21
22    #[error("UrlError: {0}")]
23    UrlError(#[from] url::ParseError),
24
25    /// Error processing the document for embedding
26    #[error("DocumentError: {0}")]
27    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
28
29    /// Error parsing the completion response
30    #[error("ResponseError: {0}")]
31    ResponseError(String),
32
33    /// Error returned by the embedding model provider
34    #[error("ProviderError: {0}")]
35    ProviderError(String),
36}
37
38/// Trait for embedding models that can generate embeddings for documents.
39pub trait EmbeddingModel: Clone + Sync + Send {
40    /// The maximum number of documents that can be embedded in a single request.
41    const MAX_DOCUMENTS: usize;
42
43    /// The number of dimensions in the embedding vector.
44    fn ndims(&self) -> usize;
45
46    /// Embed multiple text documents in a single request
47    fn embed_texts(
48        &self,
49        texts: impl IntoIterator<Item = String> + Send,
50    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
51
52    /// Embed a single text document.
53    fn embed_text(
54        &self,
55        text: &str,
56    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
57        async {
58            Ok(self
59                .embed_texts(vec![text.to_string()])
60                .await?
61                .pop()
62                .expect("There should be at least one embedding"))
63        }
64    }
65}
66
67pub trait EmbeddingModelDyn: Sync + Send {
68    fn max_documents(&self) -> usize;
69    fn ndims(&self) -> usize;
70    fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Embedding, EmbeddingError>>;
71    fn embed_texts(
72        &self,
73        texts: Vec<String>,
74    ) -> BoxFuture<'_, Result<Vec<Embedding>, EmbeddingError>>;
75}
76
77impl<T> EmbeddingModelDyn for T
78where
79    T: EmbeddingModel,
80{
81    fn max_documents(&self) -> usize {
82        T::MAX_DOCUMENTS
83    }
84
85    fn ndims(&self) -> usize {
86        self.ndims()
87    }
88
89    fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Embedding, EmbeddingError>> {
90        Box::pin(self.embed_text(text))
91    }
92
93    fn embed_texts(
94        &self,
95        texts: Vec<String>,
96    ) -> BoxFuture<'_, Result<Vec<Embedding>, EmbeddingError>> {
97        Box::pin(self.embed_texts(texts.into_iter().collect::<Vec<_>>()))
98    }
99}
100
101/// Trait for embedding models that can generate embeddings for images.
102pub trait ImageEmbeddingModel: Clone + Sync + Send {
103    /// The maximum number of images that can be embedded in a single request.
104    const MAX_DOCUMENTS: usize;
105
106    /// The number of dimensions in the embedding vector.
107    fn ndims(&self) -> usize;
108
109    /// Embed multiple images in a single request from bytes.
110    fn embed_images(
111        &self,
112        images: impl IntoIterator<Item = Vec<u8>> + Send,
113    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
114
115    /// Embed a single image from bytes.
116    fn embed_image<'a>(
117        &'a self,
118        bytes: &'a [u8],
119    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
120        async move {
121            Ok(self
122                .embed_images(vec![bytes.to_owned()])
123                .await?
124                .pop()
125                .expect("There should be at least one embedding"))
126        }
127    }
128}
129
130/// Struct that holds a single document and its embedding.
131#[derive(Clone, Default, Deserialize, Serialize, Debug)]
132pub struct Embedding {
133    /// The document that was embedded. Used for debugging.
134    pub document: String,
135    /// The embedding vector
136    pub vec: Vec<f64>,
137}
138
139impl PartialEq for Embedding {
140    fn eq(&self, other: &Self) -> bool {
141        self.document == other.document
142    }
143}
144
145impl Eq for Embedding {}