rig/embeddings/
embedding.rs

//! This module defines the [`EmbeddingModel`] trait, which represents an embedding model that can
//! generate embeddings for documents.
//!
//! The module also defines the [`Embedding`] struct, which represents a single document embedding.
//!
//! Finally, the module defines the [`EmbeddingError`] enum, which represents various errors that
//! can occur during embedding generation or processing.
8
9use crate::wasm_compat::WasmBoxedFuture;
10use crate::{http_client, wasm_compat::*};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, thiserror::Error)]
14pub enum EmbeddingError {
15    /// Http error (e.g.: connection error, timeout, etc.)
16    #[error("HttpError: {0}")]
17    HttpError(#[from] http_client::Error),
18
19    /// Json error (e.g.: serialization, deserialization)
20    #[error("JsonError: {0}")]
21    JsonError(#[from] serde_json::Error),
22
23    #[error("UrlError: {0}")]
24    UrlError(#[from] url::ParseError),
25
26    #[cfg(not(target_family = "wasm"))]
27    /// Error processing the document for embedding
28    #[error("DocumentError: {0}")]
29    DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
30
31    #[cfg(target_family = "wasm")]
32    /// Error processing the document for embedding
33    #[error("DocumentError: {0}")]
34    DocumentError(Box<dyn std::error::Error + 'static>),
35
36    /// Error parsing the completion response
37    #[error("ResponseError: {0}")]
38    ResponseError(String),
39
40    /// Error returned by the embedding model provider
41    #[error("ProviderError: {0}")]
42    ProviderError(String),
43}
44
45/// Trait for embedding models that can generate embeddings for documents.
46pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync {
47    /// The maximum number of documents that can be embedded in a single request.
48    const MAX_DOCUMENTS: usize;
49
50    type Client;
51
52    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self;
53
54    /// The number of dimensions in the embedding vector.
55    fn ndims(&self) -> usize;
56
57    /// Embed multiple text documents in a single request
58    fn embed_texts(
59        &self,
60        texts: impl IntoIterator<Item = String> + WasmCompatSend,
61    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + WasmCompatSend;
62
63    /// Embed a single text document.
64    fn embed_text(
65        &self,
66        text: &str,
67    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
68        async {
69            Ok(self
70                .embed_texts(vec![text.to_string()])
71                .await?
72                .pop()
73                .expect("There should be at least one embedding"))
74        }
75    }
76}
77
78#[deprecated(
79    since = "0.25.0",
80    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `EmbeddingModel` instead."
81)]
82pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync {
83    fn max_documents(&self) -> usize;
84    fn ndims(&self) -> usize;
85    fn embed_text<'a>(
86        &'a self,
87        text: &'a str,
88    ) -> WasmBoxedFuture<'a, Result<Embedding, EmbeddingError>>;
89    fn embed_texts(
90        &self,
91        texts: Vec<String>,
92    ) -> WasmBoxedFuture<'_, Result<Vec<Embedding>, EmbeddingError>>;
93}
94
95#[allow(deprecated)]
96impl<T> EmbeddingModelDyn for T
97where
98    T: EmbeddingModel + WasmCompatSend + WasmCompatSync,
99{
100    fn max_documents(&self) -> usize {
101        T::MAX_DOCUMENTS
102    }
103
104    fn ndims(&self) -> usize {
105        self.ndims()
106    }
107
108    fn embed_text<'a>(
109        &'a self,
110        text: &'a str,
111    ) -> WasmBoxedFuture<'a, Result<Embedding, EmbeddingError>> {
112        Box::pin(self.embed_text(text))
113    }
114
115    fn embed_texts(
116        &self,
117        texts: Vec<String>,
118    ) -> WasmBoxedFuture<'_, Result<Vec<Embedding>, EmbeddingError>> {
119        Box::pin(self.embed_texts(texts.into_iter().collect::<Vec<_>>()))
120    }
121}
122
123/// Trait for embedding models that can generate embeddings for images.
124pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
125    /// The maximum number of images that can be embedded in a single request.
126    const MAX_DOCUMENTS: usize;
127
128    /// The number of dimensions in the embedding vector.
129    fn ndims(&self) -> usize;
130
131    /// Embed multiple images in a single request from bytes.
132    fn embed_images(
133        &self,
134        images: impl IntoIterator<Item = Vec<u8>> + WasmCompatSend,
135    ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
136
137    /// Embed a single image from bytes.
138    fn embed_image<'a>(
139        &'a self,
140        bytes: &'a [u8],
141    ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
142        async move {
143            Ok(self
144                .embed_images(vec![bytes.to_owned()])
145                .await?
146                .pop()
147                .expect("There should be at least one embedding"))
148        }
149    }
150}
151
152/// Struct that holds a single document and its embedding.
153#[derive(Clone, Default, Deserialize, Serialize, Debug)]
154pub struct Embedding {
155    /// The document that was embedded. Used for debugging.
156    pub document: String,
157    /// The embedding vector
158    pub vec: Vec<f64>,
159}
160
161impl PartialEq for Embedding {
162    fn eq(&self, other: &Self) -> bool {
163        self.document == other.document
164    }
165}
166
167impl Eq for Embedding {}