1use std::sync::Arc;
4
5use async_trait::async_trait;
6use secrecy::ExposeSecret;
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9
10use entelix_core::auth::CredentialProvider;
11use entelix_core::context::ExecutionContext;
12use entelix_core::error::{Error, Result};
13use entelix_memory::{Embedder, Embedding, EmbeddingUsage};
14
15use crate::error::{OpenAiEmbedderError, OpenAiEmbedderResult};
16
/// Model name sent to the API for OpenAI's `text-embedding-3-small`.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";

/// Vector length that [`OpenAiEmbedder::small`] configures as the model's
/// native dimension.
pub const TEXT_EMBEDDING_3_SMALL_DIMENSION: usize = 1536;

/// Model name sent to the API for OpenAI's `text-embedding-3-large`.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";

/// Vector length that [`OpenAiEmbedder::large`] configures as the model's
/// native dimension.
pub const TEXT_EMBEDDING_3_LARGE_DIMENSION: usize = 3072;

/// Base URL a builder starts with; the `/v1/embeddings` path is appended to
/// it when a request is issued.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com";
36
/// Embedder that obtains vectors by POSTing to a `/v1/embeddings` HTTP
/// endpoint.
///
/// Instances are produced by [`OpenAiEmbedderBuilder::build`]. The `Debug`
/// implementation deliberately omits the HTTP client and the credential
/// provider.
#[derive(Clone)]
pub struct OpenAiEmbedder {
    /// HTTP client used to send embedding requests.
    client: reqwest::Client,
    /// Root URL of the API; trailing slashes are stripped before the
    /// `/v1/embeddings` path is appended.
    base_url: Arc<str>,
    /// Source of the authentication header, resolved anew on every request.
    credentials: Arc<dyn CredentialProvider>,
    /// Model name placed in the `model` field of the request body.
    model: Arc<str>,
    /// Vector length every returned embedding is required to have.
    dimension: usize,
    /// When `Some`, sent to the server as the `dimensions` request field.
    dimension_override: Option<usize>,
}
53
54impl std::fmt::Debug for OpenAiEmbedder {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("OpenAiEmbedder")
57 .field("base_url", &self.base_url)
58 .field("model", &self.model)
59 .field("dimension", &self.dimension)
60 .field("dimension_override", &self.dimension_override)
61 .finish_non_exhaustive()
62 }
63}
64
65impl OpenAiEmbedder {
66 pub fn small() -> OpenAiEmbedderBuilder {
69 OpenAiEmbedderBuilder::new(TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_3_SMALL_DIMENSION)
70 }
71
72 pub fn large() -> OpenAiEmbedderBuilder {
75 OpenAiEmbedderBuilder::new(TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_LARGE_DIMENSION)
76 }
77
78 pub fn custom(model: impl Into<String>, dimension: usize) -> OpenAiEmbedderBuilder {
83 OpenAiEmbedderBuilder::new(model, dimension)
84 }
85
86 fn embeddings_url(&self) -> String {
87 format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'))
88 }
89
90 async fn call(&self, inputs: Vec<String>) -> OpenAiEmbedderResult<Vec<Embedding>> {
97 let credentials = self
98 .credentials
99 .resolve()
100 .await
101 .map_err(OpenAiEmbedderError::Credential)?;
102
103 let body = self.build_request_body(&inputs);
104 let response = self
105 .client
106 .post(self.embeddings_url())
107 .header(
108 credentials.header_name.clone(),
109 http::HeaderValue::from_str(credentials.header_value.expose_secret()).map_err(
110 |e| OpenAiEmbedderError::Config(format!("invalid credential header: {e}")),
111 )?,
112 )
113 .json(&body)
114 .send()
115 .await
116 .map_err(OpenAiEmbedderError::network)?;
117
118 let status = response.status();
119 if !status.is_success() {
120 let body = response.text().await.unwrap_or_default();
121 return Err(OpenAiEmbedderError::HttpStatus {
122 status: status.as_u16(),
123 body: truncate_for_error(&body),
124 });
125 }
126
127 let parsed: EmbeddingsResponse = response
128 .json()
129 .await
130 .map_err(OpenAiEmbedderError::network)?;
131 self.decode(&parsed, inputs.len())
132 }
133
134 fn build_request_body(&self, inputs: &[String]) -> serde_json::Value {
135 let mut body = json!({
136 "model": &*self.model,
137 "input": inputs,
138 "encoding_format": "float",
139 });
140 if let Some(dim) = self.dimension_override
141 && let Some(obj) = body.as_object_mut()
142 {
143 obj.insert("dimensions".into(), json!(dim));
144 }
145 body
146 }
147
148 fn decode(
149 &self,
150 parsed: &EmbeddingsResponse,
151 expected_len: usize,
152 ) -> OpenAiEmbedderResult<Vec<Embedding>> {
153 if parsed.data.len() != expected_len {
154 return Err(OpenAiEmbedderError::Malformed(format!(
155 "expected {expected_len} embeddings, server returned {}",
156 parsed.data.len()
157 )));
158 }
159 let mut sorted: Vec<&EmbeddingsDataItem> = parsed.data.iter().collect();
162 sorted.sort_by_key(|d| d.index);
163
164 let usage = parsed.usage.map(|u| EmbeddingUsage::new(u.prompt_tokens));
165 let mut out = Vec::with_capacity(expected_len);
166 for (i, item) in sorted.iter().enumerate() {
167 if item.embedding.len() != self.dimension {
168 return Err(OpenAiEmbedderError::Malformed(format!(
169 "embedding {} dimension {} does not match configured {}",
170 i,
171 item.embedding.len(),
172 self.dimension
173 )));
174 }
175 let mut emb = Embedding::new(item.embedding.clone());
179 if i == 0
180 && let Some(u) = usage
181 {
182 emb = emb.with_usage(u);
183 }
184 out.push(emb);
185 }
186 Ok(out)
187 }
188}
189
190#[async_trait]
191impl Embedder for OpenAiEmbedder {
192 fn dimension(&self) -> usize {
193 self.dimension
194 }
195
196 async fn embed(&self, text: &str, ctx: &ExecutionContext) -> Result<Embedding> {
197 if ctx.is_cancelled() {
198 return Err(Error::Cancelled);
199 }
200 let mut out = self
201 .call(vec![text.to_owned()])
202 .await
203 .map_err(Error::from)?;
204 out.pop()
205 .ok_or_else(|| Error::provider_network("OpenAI returned no embedding".to_owned()))
206 }
207
208 async fn embed_batch(
209 &self,
210 texts: &[String],
211 ctx: &ExecutionContext,
212 ) -> Result<Vec<Embedding>> {
213 if ctx.is_cancelled() {
214 return Err(Error::Cancelled);
215 }
216 if texts.is_empty() {
217 return Ok(Vec::new());
218 }
219 self.call(texts.to_vec()).await.map_err(Error::from)
222 }
223}
224
/// Builder for [`OpenAiEmbedder`].
///
/// Obtained from [`OpenAiEmbedder::small`], [`OpenAiEmbedder::large`] or
/// [`OpenAiEmbedder::custom`]; finished with `build`.
#[must_use]
pub struct OpenAiEmbedderBuilder {
    /// Model name that will be sent with every request.
    model: String,
    /// Expected vector length; starts as the model's native dimension and is
    /// replaced by `with_dimension`.
    dimension: usize,
    /// Set by `with_dimension`; `None` until then.
    dimension_override: Option<usize>,
    /// API root URL; starts as [`DEFAULT_BASE_URL`].
    base_url: String,
    /// Credential provider; `build` fails while this is `None`.
    credentials: Option<Arc<dyn CredentialProvider>>,
    /// Caller-supplied HTTP client; `build` falls back to a default client
    /// when this is `None`.
    client: Option<reqwest::Client>,
}
235
236impl OpenAiEmbedderBuilder {
237 fn new(model: impl Into<String>, native_dimension: usize) -> Self {
238 Self {
239 model: model.into(),
240 dimension: native_dimension,
241 dimension_override: None,
242 base_url: DEFAULT_BASE_URL.to_owned(),
243 credentials: None,
244 client: None,
245 }
246 }
247
248 pub fn with_credentials(mut self, credentials: Arc<dyn CredentialProvider>) -> Self {
250 self.credentials = Some(credentials);
251 self
252 }
253
254 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
258 self.base_url = url.into();
259 self
260 }
261
262 pub const fn with_dimension(mut self, dimension: usize) -> Self {
268 self.dimension_override = Some(dimension);
269 self.dimension = dimension;
270 self
271 }
272
273 pub fn with_client(mut self, client: reqwest::Client) -> Self {
277 self.client = Some(client);
278 self
279 }
280
281 pub fn build(self) -> OpenAiEmbedderResult<OpenAiEmbedder> {
285 let credentials = self
286 .credentials
287 .ok_or_else(|| OpenAiEmbedderError::Config("credentials required".into()))?;
288 if self.dimension == 0 {
289 return Err(OpenAiEmbedderError::Config("dimension must be > 0".into()));
290 }
291 let client = self.client.unwrap_or_default();
292 Ok(OpenAiEmbedder {
293 client,
294 base_url: self.base_url.into(),
295 credentials,
296 model: self.model.into(),
297 dimension: self.dimension,
298 dimension_override: self.dimension_override,
299 })
300 }
301}
302
/// Subset of the embeddings response body that this module reads.
#[derive(Debug, Deserialize)]
struct EmbeddingsResponse {
    /// One entry per embedded input.
    data: Vec<EmbeddingsDataItem>,
    /// Token accounting for the request; `None` when the field is absent.
    #[serde(default)]
    usage: Option<EmbeddingsUsageItem>,
}
311
/// A single embedding entry from the response's `data` array.
#[derive(Debug, Deserialize)]
struct EmbeddingsDataItem {
    /// The embedding vector itself.
    embedding: Vec<f32>,
    /// Server-reported position; `decode` sorts entries by this value.
    index: u32,
}
317
/// Token usage reported alongside an embeddings response.
#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize)]
struct EmbeddingsUsageItem {
    /// Value of the response's `prompt_tokens` field; forwarded to
    /// `EmbeddingUsage::new` by `decode`.
    prompt_tokens: u32,
}
322
/// Upper bound, in bytes, on how much of a response body is kept verbatim
/// inside an error.
const ERROR_BODY_TRUNCATION_BYTES: usize = 512;

/// Shortens `body` for inclusion in an error message.
///
/// Bodies of at most [`ERROR_BODY_TRUNCATION_BYTES`] bytes are returned
/// unchanged. Longer ones are cut at the last UTF-8 character boundary at or
/// below that limit and suffixed with a note giving the number of bytes
/// dropped.
fn truncate_for_error(body: &str) -> String {
    let total = body.len();
    if total <= ERROR_BODY_TRUNCATION_BYTES {
        return body.to_owned();
    }
    // Walk down from the limit to the nearest char boundary. Index 0 is
    // always a boundary, so the search cannot come up empty.
    let keep = (0..=ERROR_BODY_TRUNCATION_BYTES)
        .rev()
        .find(|&idx| body.is_char_boundary(idx))
        .unwrap_or(0);
    let (head, tail) = body.split_at(keep);
    format!("{}… ({} bytes truncated)", head, tail.len())
}
335
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use entelix_core::auth::ApiKeyProvider;

    /// Static bearer-token provider shared by every test.
    fn provider() -> Arc<dyn CredentialProvider> {
        let key = ApiKeyProvider::new("authorization", "Bearer test").unwrap();
        Arc::new(key)
    }

    /// Attaches test credentials to `builder` and builds it, panicking on
    /// failure.
    fn finish(builder: OpenAiEmbedderBuilder) -> OpenAiEmbedder {
        builder.with_credentials(provider()).build().unwrap()
    }

    /// Shorthand for one response entry.
    fn row(index: u32, embedding: Vec<f32>) -> EmbeddingsDataItem {
        EmbeddingsDataItem { embedding, index }
    }

    /// Shorthand for a whole response.
    fn reply(
        data: Vec<EmbeddingsDataItem>,
        usage: Option<EmbeddingsUsageItem>,
    ) -> EmbeddingsResponse {
        EmbeddingsResponse { data, usage }
    }

    #[test]
    fn small_builder_defaults_to_native_dimension() {
        let embedder = finish(OpenAiEmbedder::small());
        assert_eq!(embedder.dimension(), TEXT_EMBEDDING_3_SMALL_DIMENSION);
        assert_eq!(&*embedder.model, TEXT_EMBEDDING_3_SMALL);
    }

    #[test]
    fn large_builder_defaults_to_native_dimension() {
        let embedder = finish(OpenAiEmbedder::large());
        assert_eq!(embedder.dimension(), TEXT_EMBEDDING_3_LARGE_DIMENSION);
    }

    #[test]
    fn dimension_override_threads_into_request_body() {
        let embedder = finish(OpenAiEmbedder::small().with_dimension(512));
        assert_eq!(embedder.dimension(), 512);
        let inputs = ["hi".to_owned()];
        let body = embedder.build_request_body(&inputs);
        assert_eq!(body["dimensions"], 512);
    }

    #[test]
    fn missing_credentials_rejected_at_build() {
        let outcome = OpenAiEmbedder::small().build();
        let err = outcome.unwrap_err();
        assert!(matches!(err, OpenAiEmbedderError::Config(_)));
    }

    #[test]
    fn zero_dimension_rejected_at_build() {
        let builder = OpenAiEmbedder::custom("custom-model", 0).with_credentials(provider());
        let err = builder.build().unwrap_err();
        assert!(matches!(err, OpenAiEmbedderError::Config(_)));
    }

    #[test]
    fn embeddings_url_strips_trailing_slash() {
        let embedder = finish(OpenAiEmbedder::small().with_base_url("https://example.test/"));
        assert_eq!(
            embedder.embeddings_url(),
            "https://example.test/v1/embeddings"
        );
    }

    #[test]
    fn decode_attributes_usage_to_first_slot_only() {
        let embedder = finish(OpenAiEmbedder::custom("test-model", 3));
        let parsed = reply(
            vec![row(0, vec![0.1, 0.2, 0.3]), row(1, vec![0.4, 0.5, 0.6])],
            Some(EmbeddingsUsageItem { prompt_tokens: 7 }),
        );
        let out = embedder.decode(&parsed, 2).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].usage, Some(EmbeddingUsage::new(7)));
        assert!(
            out[1].usage.is_none(),
            "usage must NOT replicate across slots"
        );
    }

    #[test]
    fn decode_sorts_by_index_when_response_order_shuffled() {
        let embedder = finish(OpenAiEmbedder::custom("test-model", 2));
        // Entries arrive in reverse order of their `index`.
        let parsed = reply(vec![row(1, vec![0.9, 0.9]), row(0, vec![0.1, 0.1])], None);
        let out = embedder.decode(&parsed, 2).unwrap();
        assert_eq!(out[0].vector, vec![0.1, 0.1]);
        assert_eq!(out[1].vector, vec![0.9, 0.9]);
    }

    #[test]
    fn decode_rejects_dimension_mismatch() {
        let embedder = finish(OpenAiEmbedder::custom("test-model", 3));
        // Two components where three are configured.
        let parsed = reply(vec![row(0, vec![0.1, 0.2])], None);
        let err = embedder.decode(&parsed, 1).unwrap_err();
        assert!(matches!(err, OpenAiEmbedderError::Malformed(_)));
    }

    #[test]
    fn decode_rejects_count_mismatch() {
        let embedder = finish(OpenAiEmbedder::custom("test-model", 1));
        // One entry returned where two are expected.
        let parsed = reply(vec![row(0, vec![0.1])], None);
        let err = embedder.decode(&parsed, 2).unwrap_err();
        assert!(matches!(err, OpenAiEmbedderError::Malformed(_)));
    }

    #[test]
    fn truncate_for_error_caps_oversized_body() {
        let huge = "x".repeat(10_000);
        let truncated = truncate_for_error(&huge);
        assert!(truncated.contains("truncated"));
        assert!(truncated.len() < 1000);
    }
}