rig_core/providers/openai/
embedding.rs1use super::{
2 client::{ApiErrorResponse, ApiResponse},
3 completion::Usage,
4};
5use crate::embeddings::EmbeddingError;
6use crate::http_client::HttpClientExt;
7use crate::{embeddings, http_client};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10
/// Model identifier for OpenAI's `text-embedding-3-large` (3,072 dimensions by default).
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Model identifier for OpenAI's `text-embedding-3-small` (1,536 dimensions by default).
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Model identifier for OpenAI's legacy `text-embedding-ada-002` (fixed at 1,536 dimensions).
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
20
21#[derive(Debug, Deserialize)]
22pub struct EmbeddingResponse {
23 pub object: String,
24 pub data: Vec<EmbeddingData>,
25 pub model: String,
26 pub usage: Usage,
27}
28
29impl From<ApiErrorResponse> for EmbeddingError {
30 fn from(err: ApiErrorResponse) -> Self {
31 EmbeddingError::ProviderError(err.message)
32 }
33}
34
35impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
36 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
37 match value {
38 ApiResponse::Ok(response) => Ok(response),
39 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
40 }
41 }
42}
43
44#[derive(Debug, Deserialize, Clone, Serialize)]
45#[serde(rename_all = "snake_case")]
46pub enum EncodingFormat {
47 Float,
48 Base64,
49}
50
51#[derive(Debug, Deserialize)]
52pub struct EmbeddingData {
53 pub object: String,
54 pub embedding: Vec<serde_json::Number>,
55 pub index: usize,
56}
57
58#[doc(hidden)]
59#[derive(Clone)]
60pub struct GenericEmbeddingModel<Ext = super::OpenAIResponsesExt, H = reqwest::Client> {
61 client: crate::client::Client<Ext, H>,
62 pub model: String,
63 pub encoding_format: Option<EncodingFormat>,
64 pub user: Option<String>,
65 ndims: usize,
66}
67
68pub type EmbeddingModel<H = reqwest::Client> = GenericEmbeddingModel<super::OpenAIResponsesExt, H>;
73
74fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
75 match identifier {
76 TEXT_EMBEDDING_3_LARGE => Some(3_072),
77 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
78 _ => None,
79 }
80}
81
82impl<Ext, H> embeddings::EmbeddingModel for GenericEmbeddingModel<Ext, H>
83where
84 crate::client::Client<Ext, H>: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
85 Ext: crate::client::Provider + Clone + 'static,
86 H: Clone + Default + std::fmt::Debug + 'static,
87{
88 const MAX_DOCUMENTS: usize = 1024;
89
90 type Client = crate::client::Client<Ext, H>;
91
92 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
93 let model = model.into();
94 let dims = ndims
95 .or(model_dimensions_from_identifier(&model))
96 .unwrap_or_default();
97
98 Self::new(client.clone(), model, dims)
99 }
100
101 fn ndims(&self) -> usize {
102 self.ndims
103 }
104
105 async fn embed_texts(
106 &self,
107 documents: impl IntoIterator<Item = String>,
108 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
109 let documents: Vec<String> = documents.into_iter().collect();
110 let response = self.embed_texts_with_usage(documents).await?;
111 Ok(response.embeddings)
112 }
113
114 async fn embed_texts_with_usage(
115 &self,
116 documents: impl IntoIterator<Item = String>,
117 ) -> Result<embeddings::EmbeddingResponse, EmbeddingError> {
118 let documents: Vec<String> = documents.into_iter().collect();
119
120 let mut body = json!({
121 "model": self.model,
122 "input": documents,
123 });
124
125 let body_object = body.as_object_mut().ok_or_else(|| {
126 EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
127 })?;
128
129 if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
130 body_object.insert("dimensions".to_owned(), json!(self.ndims));
131 }
132
133 if let Some(encoding_format) = &self.encoding_format {
134 body_object.insert("encoding_format".to_owned(), json!(encoding_format));
135 }
136
137 if let Some(user) = &self.user {
138 body_object.insert("user".to_owned(), json!(user));
139 }
140
141 let body = serde_json::to_vec(&body)?;
142
143 let req = self
144 .client
145 .post("/embeddings")?
146 .body(body)
147 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
148
149 let response = self.client.send(req).await?;
150
151 if response.status().is_success() {
152 let body: Vec<u8> = response.into_body().await?;
153 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
154
155 match body {
156 ApiResponse::Ok(response) => {
157 tracing::info!(target: "rig",
158 "OpenAI embedding token usage: {:?}",
159 response.usage
160 );
161
162 if response.data.len() != documents.len() {
163 return Err(EmbeddingError::ResponseError(
164 "Response data length does not match input length".into(),
165 ));
166 }
167
168 let usage = crate::completion::Usage {
169 input_tokens: response.usage.prompt_tokens as u64,
170 output_tokens: 0,
171 total_tokens: response.usage.total_tokens as u64,
172 cached_input_tokens: response
173 .usage
174 .prompt_tokens_details
175 .as_ref()
176 .map_or(0, |d| d.cached_tokens as u64),
177 cache_creation_input_tokens: 0,
178 tool_use_prompt_tokens: 0,
179 reasoning_tokens: 0,
180 };
181
182 let embeddings = response
183 .data
184 .into_iter()
185 .zip(documents.into_iter())
186 .map(|(embedding, document)| embeddings::Embedding {
187 document,
188 vec: embedding
189 .embedding
190 .into_iter()
191 .filter_map(|n| n.as_f64())
192 .collect(),
193 })
194 .collect();
195
196 Ok(embeddings::EmbeddingResponse { embeddings, usage })
197 }
198 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
199 }
200 } else {
201 let text = http_client::text(response).await?;
202 Err(EmbeddingError::ProviderError(text))
203 }
204 }
205}
206
207impl<Ext, H> GenericEmbeddingModel<Ext, H>
208where
209 Ext: crate::client::Provider,
210{
211 pub fn new(
212 client: crate::client::Client<Ext, H>,
213 model: impl Into<String>,
214 ndims: usize,
215 ) -> Self {
216 Self {
217 client,
218 model: model.into(),
219 encoding_format: None,
220 ndims,
221 user: None,
222 }
223 }
224
225 pub fn with_model(client: crate::client::Client<Ext, H>, model: &str, ndims: usize) -> Self {
226 Self {
227 client,
228 model: model.into(),
229 encoding_format: None,
230 ndims,
231 user: None,
232 }
233 }
234
235 pub fn with_encoding_format(
236 client: crate::client::Client<Ext, H>,
237 model: &str,
238 ndims: usize,
239 encoding_format: EncodingFormat,
240 ) -> Self {
241 Self {
242 client,
243 model: model.into(),
244 encoding_format: Some(encoding_format),
245 ndims,
246 user: None,
247 }
248 }
249
250 pub fn encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
251 self.encoding_format = Some(encoding_format);
252 self
253 }
254
255 pub fn user(mut self, user: impl Into<String>) -> Self {
256 self.user = Some(user.into());
257 self
258 }
259}