rig_core/embeddings/
embedding.rs1use crate::{
10 completion::Usage,
11 http_client,
12 wasm_compat::{WasmCompatSend, WasmCompatSync},
13};
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, thiserror::Error)]
17pub enum EmbeddingError {
18 #[error("HttpError: {0}")]
20 HttpError(#[from] http_client::Error),
21
22 #[error("JsonError: {0}")]
24 JsonError(#[from] serde_json::Error),
25
26 #[error("UrlError: {0}")]
28 UrlError(#[from] url::ParseError),
29
30 #[cfg(not(target_family = "wasm"))]
31 #[error("DocumentError: {0}")]
33 DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),
34
35 #[cfg(target_family = "wasm")]
36 #[error("DocumentError: {0}")]
38 DocumentError(Box<dyn std::error::Error + 'static>),
39
40 #[error("ResponseError: {0}")]
42 ResponseError(String),
43
44 #[error("ProviderError: {0}")]
46 ProviderError(String),
47}
48
49pub trait EmbeddingModel: WasmCompatSend + WasmCompatSync {
51 const MAX_DOCUMENTS: usize;
53
54 type Client;
56
57 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self;
59
60 fn ndims(&self) -> usize;
62
63 fn embed_texts(
65 &self,
66 texts: impl IntoIterator<Item = String> + WasmCompatSend,
67 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + WasmCompatSend;
68
69 fn embed_text(
71 &self,
72 text: &str,
73 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
74 async {
75 let mut embeddings = self.embed_texts(vec![text.to_string()]).await?;
76 embeddings.pop().ok_or_else(|| {
77 EmbeddingError::ResponseError(
78 "embedding provider returned an empty response for embed_text".to_string(),
79 )
80 })
81 }
82 }
83
84 fn embed_texts_with_usage(
90 &self,
91 texts: impl IntoIterator<Item = String> + WasmCompatSend,
92 ) -> impl std::future::Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend
93 {
94 async {
95 let embeddings = self.embed_texts(texts).await?;
96 Ok(EmbeddingResponse {
97 embeddings,
98 usage: Usage::default(),
99 })
100 }
101 }
102
103 fn embed_text_with_usage(
108 &self,
109 text: &str,
110 ) -> impl std::future::Future<Output = Result<EmbeddingResponse, EmbeddingError>> + WasmCompatSend
111 {
112 async {
113 let response = self.embed_texts_with_usage(vec![text.to_string()]).await?;
114 if response.embeddings.is_empty() {
115 return Err(EmbeddingError::ResponseError(
116 "embedding provider returned an empty response for embed_text_with_usage"
117 .to_string(),
118 ));
119 }
120 Ok(response)
121 }
122 }
123}
124
125#[derive(Debug, Clone)]
127pub struct EmbeddingResponse {
128 pub embeddings: Vec<Embedding>,
130 pub usage: Usage,
132}
133
134pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync {
136 const MAX_DOCUMENTS: usize;
138
139 fn ndims(&self) -> usize;
141
142 fn embed_images(
146 &self,
147 images: impl IntoIterator<Item = Vec<u8>> + WasmCompatSend,
148 ) -> impl std::future::Future<Output = Result<Vec<Embedding>, EmbeddingError>> + Send;
149
150 fn embed_image<'a>(
152 &'a self,
153 bytes: &'a [u8],
154 ) -> impl std::future::Future<Output = Result<Embedding, EmbeddingError>> + WasmCompatSend {
155 async move {
156 let mut embeddings = self.embed_images(vec![bytes.to_owned()]).await?;
157 embeddings.pop().ok_or_else(|| {
158 EmbeddingError::ResponseError(
159 "embedding provider returned an empty response for embed_image".to_string(),
160 )
161 })
162 }
163 }
164}
165
166#[derive(Clone, Default, Deserialize, Serialize, Debug)]
168pub struct Embedding {
169 pub document: String,
171 pub vec: Vec<f64>,
173}
174
175impl PartialEq for Embedding {
176 fn eq(&self, other: &Self) -> bool {
177 self.document == other.document
178 }
179}
180
181impl Eq for Embedding {}