rig/embeddings/
embedding.rs

//! The module defines the [`EmbeddingModel`] trait, which represents an embedding model that can
//! generate embeddings for text documents, and the [`ImageEmbeddingModel`] trait, its counterpart
//! for images.
//!
//! The module also defines the [`Embedding`] struct, which represents a single document embedding.
//!
//! Finally, the module defines the [`EmbeddingError`] enum, which represents various errors that
//! can occur during embedding generation or processing.
8
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, thiserror::Error)]
12pub enum EmbeddingError {
13    /// Http error (e.g.: connection error, timeout, etc.)
14    #[error("HttpError: {0}")]
15    HttpError(#[from] reqwest::Error),
16
17    /// Json error (e.g.: serialization, deserialization)
18    #[error("JsonError: {0}")]
19    JsonError(#[from] serde_json::Error),
20
21    /// Error processing the document for embedding
22    #[error("DocumentError: {0}")]
23    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
24
25    /// Error parsing the completion response
26    #[error("ResponseError: {0}")]
27    ResponseError(String),
28
29    /// Error returned by the embedding model provider
30    #[error("ProviderError: {0}")]
31    ProviderError(String),
32}
33
34/// Trait for embedding models that can generate embeddings for documents.
35pub trait EmbeddingModel: Clone + Sync + Send {
36    /// The maximum number of documents that can be embedded in a single request.
37    const MAX_DOCUMENTS: usize;
38
39    /// The number of dimensions in the embedding vector.
40    fn ndims(&self) -> usize;
41
42    /// Embed multiple text documents in a single request
43    fn embed_texts(
44        &self,
45        texts: impl IntoIterator<Item = String> + Send,
46    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
47
48    /// Embed a single text document.
49    fn embed_text(
50        &self,
51        text: &str,
52    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
53        async {
54            Ok(self
55                .embed_texts(vec![text.to_string()])
56                .await?
57                .pop()
58                .expect("There should be at least one embedding"))
59        }
60    }
61}
62
63/// Trait for embedding models that can generate embeddings for images.
64pub trait ImageEmbeddingModel: Clone + Sync + Send {
65    /// The maximum number of images that can be embedded in a single request.
66    const MAX_DOCUMENTS: usize;
67
68    /// The number of dimensions in the embedding vector.
69    fn ndims(&self) -> usize;
70
71    /// Embed multiple images in a single request from bytes.
72    fn embed_images(
73        &self,
74        images: impl IntoIterator<Item = Vec<u8>> + Send,
75    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
76
77    /// Embed a single image from bytes.
78    fn embed_image<'a>(
79        &'a self,
80        bytes: &'a [u8],
81    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
82        async move {
83            Ok(self
84                .embed_images(vec![bytes.to_owned()])
85                .await?
86                .pop()
87                .expect("There should be at least one embedding"))
88        }
89    }
90}
91
92/// Struct that holds a single document and its embedding.
93#[derive(Clone, Default, Deserialize, Serialize, Debug)]
94pub struct Embedding {
95    /// The document that was embedded. Used for debugging.
96    pub document: String,
97    /// The embedding vector
98    pub vec: Vec<f64>,
99}
100
101impl PartialEq for Embedding {
102    fn eq(&self, other: &Self) -> bool {
103        self.document == other.document
104    }
105}
106
107impl Eq for Embedding {}