// rig/embeddings/embedding.rs
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
11
12#[derive(Debug, thiserror::Error)]
13pub enum EmbeddingError {
14 #[error("HttpError: {0}")]
16 HttpError(#[from] reqwest::Error),
17
18 #[error("JsonError: {0}")]
20 JsonError(#[from] serde_json::Error),
21
22 #[error("UrlError: {0}")]
23 UrlError(#[from] url::ParseError),
24
25 #[error("DocumentError: {0}")]
27 DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
28
29 #[error("ResponseError: {0}")]
31 ResponseError(String),
32
33 #[error("ProviderError: {0}")]
35 ProviderError(String),
36}
37
38pub trait EmbeddingModel: Clone + Sync + Send {
40 const MAX_DOCUMENTS: usize;
42
43 fn ndims(&self) -> usize;
45
46 fn embed_texts(
48 &self,
49 texts: impl IntoIterator<Item = String> + Send,
50 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
51
52 fn embed_text(
54 &self,
55 text: &str,
56 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
57 async {
58 Ok(self
59 .embed_texts(vec![text.to_string()])
60 .await?
61 .pop()
62 .expect("There should be at least one embedding"))
63 }
64 }
65}
66
67pub trait EmbeddingModelDyn: Sync + Send {
68 fn max_documents(&self) -> usize;
69 fn ndims(&self) -> usize;
70 fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Embedding, EmbeddingError>>;
71 fn embed_texts(
72 &self,
73 texts: Vec<String>,
74 ) -> BoxFuture<'_, Result<Vec<Embedding>, EmbeddingError>>;
75}
76
77impl<T> EmbeddingModelDyn for T
78where
79 T: EmbeddingModel,
80{
81 fn max_documents(&self) -> usize {
82 T::MAX_DOCUMENTS
83 }
84
85 fn ndims(&self) -> usize {
86 self.ndims()
87 }
88
89 fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result<Embedding, EmbeddingError>> {
90 Box::pin(self.embed_text(text))
91 }
92
93 fn embed_texts(
94 &self,
95 texts: Vec<String>,
96 ) -> BoxFuture<'_, Result<Vec<Embedding>, EmbeddingError>> {
97 Box::pin(self.embed_texts(texts.into_iter().collect::<Vec<_>>()))
98 }
99}
100
101pub trait ImageEmbeddingModel: Clone + Sync + Send {
103 const MAX_DOCUMENTS: usize;
105
106 fn ndims(&self) -> usize;
108
109 fn embed_images(
111 &self,
112 images: impl IntoIterator<Item = Vec<u8>> + Send,
113 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
114
115 fn embed_image<'a>(
117 &'a self,
118 bytes: &'a [u8],
119 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + Send {
120 async move {
121 Ok(self
122 .embed_images(vec![bytes.to_owned()])
123 .await?
124 .pop()
125 .expect("There should be at least one embedding"))
126 }
127 }
128}
129
130#[derive(Clone, Default, Deserialize, Serialize, Debug)]
132pub struct Embedding {
133 pub document: String,
135 pub vec: Vec<f64>,
137}
138
139impl PartialEq for Embedding {
140 fn eq(&self, other: &Self) -> bool {
141 self.document == other.document
142 }
143}
144
145impl Eq for Embedding {}