rig/embeddings/
embedding.rs

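//! The module defines the [EmbeddingModel] trait, which represents an embedding model that can
//! generate embeddings for documents.
//!
//! It also defines the [ImageEmbeddingModel] trait, the counterpart for embedding images, and
//! the [EmbeddingModelDyn] trait, a boxed-future variant of [EmbeddingModel] that is implemented
//! for every [EmbeddingModel].
//!
//! The module also defines the [Embedding] struct, which represents a single document embedding.
//!
//! Finally, the module defines the [EmbeddingError] enum, which represents various errors that
//! can occur during embedding generation or processing.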
8
9use futures::future::BoxFuture;
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, thiserror::Error)]
13pub enum EmbeddingError {
14    /// Http error (e.g.: connection error, timeout, etc.)
15    #[error("HttpError: {0}")]
16    HttpError(#[from] reqwest::Error),
17
18    /// Json error (e.g.: serialization, deserialization)
19    #[error("JsonError: {0}")]
20    JsonError(#[from] serde_json::Error),
21
22    /// Error processing the document for embedding
23    #[error("DocumentError: {0}")]
24    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
25
26    /// Error parsing the completion response
27    #[error("ResponseError: {0}")]
28    ResponseError(String),
29
30    /// Error returned by the embedding model provider
31    #[error("ProviderError: {0}")]
32    ProviderError(String),
33}
34
35/// Trait for embedding models that can generate embeddings for documents.
36pub trait EmbeddingModel: Clone + Sync + Send {
37    /// The maximum number of documents that can be embedded in a single request.
38    const MAX_DOCUMENTS: usize;
39
40    /// The number of dimensions in the embedding vector.
41    fn ndims(&self) -> usize;
42
43    /// Embed multiple text documents in a single request
44    fn embed_texts(
45        &self,
46        texts: impl IntoIterator<Item = String> + Send,
47    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
48
49    /// Embed a single text document.
50    fn embed_text(
51        &self,
52        text: &str,
53    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
54        async {
55            Ok(self
56                .embed_texts(vec![text.to_string()])
57                .await?
58                .pop()
59                .expect("There should be at least one embedding"))
60        }
61    }
62}
63
64pub trait EmbeddingModelDyn: Sync + Send {
65    fn max_documents(&self) -> usize;
66    fn ndims(&self) -> usize;
67    fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Embedding, EmbeddingError>>;
68    fn embed_texts(
69        &self,
70        texts: Vec<String>,
71    ) -> BoxFuture<'_, Result<Vec<Embedding>, EmbeddingError>>;
72}
73
74impl<T: EmbeddingModel> EmbeddingModelDyn for T {
75    fn max_documents(&self) -> usize {
76        T::MAX_DOCUMENTS
77    }
78
79    fn ndims(&self) -> usize {
80        self.ndims()
81    }
82
83    fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Embedding, EmbeddingError>> {
84        Box::pin(self.embed_text(text))
85    }
86
87    fn embed_texts(&self, texts: Vec<String>) -> BoxFuture<Result<Vec<Embedding>, EmbeddingError>> {
88        Box::pin(self.embed_texts(texts.into_iter().collect::<Vec<_>>()))
89    }
90}
91
92/// Trait for embedding models that can generate embeddings for images.
93pub trait ImageEmbeddingModel: Clone + Sync + Send {
94    /// The maximum number of images that can be embedded in a single request.
95    const MAX_DOCUMENTS: usize;
96
97    /// The number of dimensions in the embedding vector.
98    fn ndims(&self) -> usize;
99
100    /// Embed multiple images in a single request from bytes.
101    fn embed_images(
102        &self,
103        images: impl IntoIterator<Item = Vec<u8>> + Send,
104    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
105
106    /// Embed a single image from bytes.
107    fn embed_image<'a>(
108        &'a self,
109        bytes: &'a [u8],
110    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
111        async move {
112            Ok(self
113                .embed_images(vec![bytes.to_owned()])
114                .await?
115                .pop()
116                .expect("There should be at least one embedding"))
117        }
118    }
119}
120
121/// Struct that holds a single document and its embedding.
122#[derive(Clone, Default, Deserialize, Serialize, Debug)]
123pub struct Embedding {
124    /// The document that was embedded. Used for debugging.
125    pub document: String,
126    /// The embedding vector
127    pub vec: Vec<f64>,
128}
129
130impl PartialEq for Embedding {
131    fn eq(&self, other: &Self) -> bool {
132        self.document == other.document
133    }
134}
135
136impl Eq for Embedding {}