rig/embeddings/
embedding.rs1use crate::wasm_compat::WasmBoxedFuture;
10use crate::{http_client, wasm_compat::*};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, thiserror::Error)]
14pub enum EmbeddingError {
15 #[error("HttpError: {0}")]
17 HttpError(#[from] http_client::Error),
18
19 #[error("JsonError: {0}")]
21 JsonError(#[from] serde_json::Error),
22
23 #[error("UrlError: {0}")]
24 UrlError(#[from] url::ParseError),
25
26 #[cfg(not(target_family = "wasm"))]
27 #[error("DocumentError: {0}")]
29 DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
30
31 #[cfg(target_family = "wasm")]
32 #[error("DocumentError: {0}")]
34 DocumentError(Box<dyn std::error::Error + 'static>),
35
36 #[error("ResponseError: {0}")]
38 ResponseError(String),
39
40 #[error("ProviderError: {0}")]
42 ProviderError(String),
43}
44
45pub trait EmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
47 const MAX_DOCUMENTS: usize;
49
50 fn ndims(&self) -> usize;
52
53 fn embed_texts(
55 &self,
56 texts: impl IntoIterator<Item = String> + WasmCompatSend,
57 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + WasmCompatSend;
58
59 fn embed_text(
61 &self,
62 text: &str,
63 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
64 async {
65 Ok(self
66 .embed_texts(vec![text.to_string()])
67 .await?
68 .pop()
69 .expect("There should be at least one embedding"))
70 }
71 }
72}
73
74pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync {
75 fn max_documents(&self) -> usize;
76 fn ndims(&self) -> usize;
77 fn embed_text<'a>(
78 &'a self,
79 text: &'a str,
80 ) -> WasmBoxedFuture<'a, Result<Embedding, EmbeddingError>>;
81 fn embed_texts(
82 &self,
83 texts: Vec<String>,
84 ) -> WasmBoxedFuture<'_, Result<Vec<Embedding>, EmbeddingError>>;
85}
86
87impl<T> EmbeddingModelDyn for T
88where
89 T: EmbeddingModel + WasmCompatSend + WasmCompatSync,
90{
91 fn max_documents(&self) -> usize {
92 T::MAX_DOCUMENTS
93 }
94
95 fn ndims(&self) -> usize {
96 self.ndims()
97 }
98
99 fn embed_text<'a>(
100 &'a self,
101 text: &'a str,
102 ) -> WasmBoxedFuture<'a, Result<Embedding, EmbeddingError>> {
103 Box::pin(self.embed_text(text))
104 }
105
106 fn embed_texts(
107 &self,
108 texts: Vec<String>,
109 ) -> WasmBoxedFuture<'_, Result<Vec<Embedding>, EmbeddingError>> {
110 Box::pin(self.embed_texts(texts.into_iter().collect::<Vec<_>>()))
111 }
112}
113
114pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
116 const MAX_DOCUMENTS: usize;
118
119 fn ndims(&self) -> usize;
121
122 fn embed_images(
124 &self,
125 images: impl IntoIterator<Item = Vec<u8>> + WasmCompatSend,
126 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
127
128 fn embed_image<'a>(
130 &'a self,
131 bytes: &'a [u8],
132 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
133 async move {
134 Ok(self
135 .embed_images(vec![bytes.to_owned()])
136 .await?
137 .pop()
138 .expect("There should be at least one embedding"))
139 }
140 }
141}
142
143#[derive(Clone, Default, Deserialize, Serialize, Debug)]
145pub struct Embedding {
146 pub document: String,
148 pub vec: Vec<f64>,
150}
151
152impl PartialEq for Embedding {
153 fn eq(&self, other: &Self) -> bool {
154 self.document == other.document
155 }
156}
157
158impl Eq for Embedding {}