rig/embeddings/
embedding.rs

//! This module defines the [`EmbeddingModel`] trait, which represents an embedding model that
//! can generate embeddings for text documents, together with its boxed-future counterpart
//! [`EmbeddingModelDyn`] and the [`ImageEmbeddingModel`] trait for embedding images.
//!
//! The module also defines the [`Embedding`] struct, which represents a single document embedding.
//!
//! Finally, the module defines the [`EmbeddingError`] enum, which represents various errors that
//! can occur during embedding generation or processing.
8
9use crate::wasm_compat::WasmBoxedFuture;
10use crate::{http_client, wasm_compat::*};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, thiserror::Error)]
14pub enum EmbeddingError {
15    /// Http error (e.g.: connection error, timeout, etc.)
16    #[error("HttpError: {0}")]
17    HttpError(#[from] http_client::Error),
18
19    /// Json error (e.g.: serialization, deserialization)
20    #[error("JsonError: {0}")]
21    JsonError(#[from] serde_json::Error),
22
23    #[error("UrlError: {0}")]
24    UrlError(#[from] url::ParseError),
25
26    #[cfg(not(target_family = "wasm"))]
27    /// Error processing the document for embedding
28    #[error("DocumentError: {0}")]
29    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
30
31    #[cfg(target_family = "wasm")]
32    /// Error processing the document for embedding
33    #[error("DocumentError: {0}")]
34    DocumentError(Box<dyn std::error::Error + 'static>),
35
36    /// Error parsing the completion response
37    #[error("ResponseError: {0}")]
38    ResponseError(String),
39
40    /// Error returned by the embedding model provider
41    #[error("ProviderError: {0}")]
42    ProviderError(String),
43}
44
45/// Trait for embedding models that can generate embeddings for documents.
46pub trait EmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
47    /// The maximum number of documents that can be embedded in a single request.
48    const MAX_DOCUMENTS: usize;
49
50    /// The number of dimensions in the embedding vector.
51    fn ndims(&self) -> usize;
52
53    /// Embed multiple text documents in a single request
54    fn embed_texts(
55        &self,
56        texts: impl IntoIterator<Item = String> + WasmCompatSend,
57    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + WasmCompatSend;
58
59    /// Embed a single text document.
60    fn embed_text(
61        &self,
62        text: &str,
63    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
64        async {
65            Ok(self
66                .embed_texts(vec![text.to_string()])
67                .await?
68                .pop()
69                .expect("There should be at least one embedding"))
70        }
71    }
72}
73
74pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync {
75    fn max_documents(&self) -> usize;
76    fn ndims(&self) -> usize;
77    fn embed_text<'a>(
78        &'a self,
79        text: &'a str,
80    ) -> WasmBoxedFuture<'a, Result<Embedding, EmbeddingError>>;
81    fn embed_texts(
82        &self,
83        texts: Vec<String>,
84    ) -> WasmBoxedFuture<'_, Result<Vec<Embedding>, EmbeddingError>>;
85}
86
87impl<T> EmbeddingModelDyn for T
88where
89    T: EmbeddingModel + WasmCompatSend + WasmCompatSync,
90{
91    fn max_documents(&self) -> usize {
92        T::MAX_DOCUMENTS
93    }
94
95    fn ndims(&self) -> usize {
96        self.ndims()
97    }
98
99    fn embed_text<'a>(
100        &'a self,
101        text: &'a str,
102    ) -> WasmBoxedFuture<'a, Result<Embedding, EmbeddingError>> {
103        Box::pin(self.embed_text(text))
104    }
105
106    fn embed_texts(
107        &self,
108        texts: Vec<String>,
109    ) -> WasmBoxedFuture<'_, Result<Vec<Embedding>, EmbeddingError>> {
110        Box::pin(self.embed_texts(texts.into_iter().collect::<Vec<_>>()))
111    }
112}
113
114/// Trait for embedding models that can generate embeddings for images.
115pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
116    /// The maximum number of images that can be embedded in a single request.
117    const MAX_DOCUMENTS: usize;
118
119    /// The number of dimensions in the embedding vector.
120    fn ndims(&self) -> usize;
121
122    /// Embed multiple images in a single request from bytes.
123    fn embed_images(
124        &self,
125        images: impl IntoIterator<Item = Vec<u8>> + WasmCompatSend,
126    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
127
128    /// Embed a single image from bytes.
129    fn embed_image<'a>(
130        &'a self,
131        bytes: &'a [u8],
132    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
133        async move {
134            Ok(self
135                .embed_images(vec![bytes.to_owned()])
136                .await?
137                .pop()
138                .expect("There should be at least one embedding"))
139        }
140    }
141}
142
143/// Struct that holds a single document and its embedding.
144#[derive(Clone, Default, Deserialize, Serialize, Debug)]
145pub struct Embedding {
146    /// The document that was embedded. Used for debugging.
147    pub document: String,
148    /// The embedding vector
149    pub vec: Vec<f64>,
150}
151
152impl PartialEq for Embedding {
153    fn eq(&self, other: &Self) -> bool {
154        self.document == other.document
155    }
156}
157
158impl Eq for Embedding {}