rig/embeddings/
embedding.rs1use crate::wasm_compat::WasmBoxedFuture;
10use crate::{http_client, wasm_compat::*};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, thiserror::Error)]
14pub enum EmbeddingError {
15 #[error("HttpError: {0}")]
17 HttpError(#[from] http_client::Error),
18
19 #[error("JsonError: {0}")]
21 JsonError(#[from] serde_json::Error),
22
23 #[error("UrlError: {0}")]
24 UrlError(#[from] url::ParseError),
25
26 #[cfg(not(target_family = "wasm"))]
27 #[error("DocumentError: {0}")]
29 DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
30
31 #[cfg(target_family = "wasm")]
32 #[error("DocumentError: {0}")]
34 DocumentError(Box<dyn std::error::Error + 'static>),
35
36 #[error("ResponseError: {0}")]
38 ResponseError(String),
39
40 #[error("ProviderError: {0}")]
42 ProviderError(String),
43}
44
45pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync {
47 const MAX_DOCUMENTS: usize;
49
50 type Client;
51
52 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self;
53
54 fn ndims(&self) -> usize;
56
57 fn embed_texts(
59 &self,
60 texts: impl IntoIterator<Item = String> + WasmCompatSend,
61 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + WasmCompatSend;
62
63 fn embed_text(
65 &self,
66 text: &str,
67 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
68 async {
69 Ok(self
70 .embed_texts(vec![text.to_string()])
71 .await?
72 .pop()
73 .expect("There should be at least one embedding"))
74 }
75 }
76}
77
78#[deprecated(
79 since = "0.25.0",
80 note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `EmbeddingModel` instead."
81)]
82pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync {
83 fn max_documents(&self) -> usize;
84 fn ndims(&self) -> usize;
85 fn embed_text<'a>(
86 &'a self,
87 text: &'a str,
88 ) -> WasmBoxedFuture<'a, Result<Embedding, EmbeddingError>>;
89 fn embed_texts(
90 &self,
91 texts: Vec<String>,
92 ) -> WasmBoxedFuture<'_, Result<Vec<Embedding>, EmbeddingError>>;
93}
94
95#[allow(deprecated)]
96impl<T> EmbeddingModelDyn for T
97where
98 T: EmbeddingModel + WasmCompatSend + WasmCompatSync,
99{
100 fn max_documents(&self) -> usize {
101 T::MAX_DOCUMENTS
102 }
103
104 fn ndims(&self) -> usize {
105 self.ndims()
106 }
107
108 fn embed_text<'a>(
109 &'a self,
110 text: &'a str,
111 ) -> WasmBoxedFuture<'a, Result<Embedding, EmbeddingError>> {
112 Box::pin(self.embed_text(text))
113 }
114
115 fn embed_texts(
116 &self,
117 texts: Vec<String>,
118 ) -> WasmBoxedFuture<'_, Result<Vec<Embedding>, EmbeddingError>> {
119 Box::pin(self.embed_texts(texts.into_iter().collect::<Vec<_>>()))
120 }
121}
122
123pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
125 const MAX_DOCUMENTS: usize;
127
128 fn ndims(&self) -> usize;
130
131 fn embed_images(
133 &self,
134 images: impl IntoIterator<Item = Vec<u8>> + WasmCompatSend,
135 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
136
137 fn embed_image<'a>(
139 &'a self,
140 bytes: &'a [u8],
141 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
142 async move {
143 Ok(self
144 .embed_images(vec![bytes.to_owned()])
145 .await?
146 .pop()
147 .expect("There should be at least one embedding"))
148 }
149 }
150}
151
152#[derive(Clone, Default, Deserialize, Serialize, Debug)]
154pub struct Embedding {
155 pub document: String,
157 pub vec: Vec<f64>,
159}
160
161impl PartialEq for Embedding {
162 fn eq(&self, other: &Self) -> bool {
163 self.document == other.document
164 }
165}
166
167impl Eq for Embedding {}