Skip to main content

entelix_memory_openai/
embedder.rs

1//! `OpenAiEmbedder` — concrete `Embedder` over OpenAI `/v1/embeddings`.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use secrecy::ExposeSecret;
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9
10use entelix_core::auth::CredentialProvider;
11use entelix_core::context::ExecutionContext;
12use entelix_core::error::{Error, Result};
13use entelix_memory::{Embedder, Embedding, EmbeddingUsage};
14
15use crate::error::{OpenAiEmbedderError, OpenAiEmbedderResult};
16
/// Model identifier of OpenAI's cheaper embedding model.
///
/// By default it produces [`TEXT_EMBEDDING_3_SMALL_DIMENSION`]-wide vectors.
/// A shorter vector can be requested through the API's `dimensions` request
/// parameter, which the builder exposes as `with_dimension`.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";

/// Native (un-reduced) output dimension of [`TEXT_EMBEDDING_3_SMALL`].
pub const TEXT_EMBEDDING_3_SMALL_DIMENSION: usize = 1_536;

/// Model identifier of OpenAI's higher-quality embedding model.
///
/// By default it produces [`TEXT_EMBEDDING_3_LARGE_DIMENSION`]-wide vectors.
/// Like the small model, it accepts the `dimensions` request parameter.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";

/// Native (un-reduced) output dimension of [`TEXT_EMBEDDING_3_LARGE`].
pub const TEXT_EMBEDDING_3_LARGE_DIMENSION: usize = 3_072;

/// API root used when the builder is not given one.
///
/// Replace it via [`OpenAiEmbedderBuilder::with_base_url`] to target a proxy,
/// a regional endpoint, or a local test fixture.
pub const DEFAULT_BASE_URL: &str = "https://api.openai.com";
36
37/// Concrete [`Embedder`] backed by OpenAI's `/v1/embeddings` HTTPS
38/// endpoint. Stateless beyond the connection pool inside
39/// `reqwest::Client`; clone freely or wrap in `Arc` per F10.
40#[derive(Clone)]
41pub struct OpenAiEmbedder {
42    client: reqwest::Client,
43    base_url: Arc<str>,
44    credentials: Arc<dyn CredentialProvider>,
45    model: Arc<str>,
46    dimension: usize,
47    /// `Some` when the operator explicitly chose a reduced
48    /// dimension via `with_dimension`. We forward this on the
49    /// `dimensions` field to the API; native dimension is sent as
50    /// `None` so OpenAI applies its own default.
51    dimension_override: Option<usize>,
52}
53
54impl std::fmt::Debug for OpenAiEmbedder {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("OpenAiEmbedder")
57            .field("base_url", &self.base_url)
58            .field("model", &self.model)
59            .field("dimension", &self.dimension)
60            .field("dimension_override", &self.dimension_override)
61            .finish_non_exhaustive()
62    }
63}
64
65impl OpenAiEmbedder {
66    /// Builder for the lower-cost `text-embedding-3-small` model
67    /// (native dimension 1536).
68    pub fn small() -> OpenAiEmbedderBuilder {
69        OpenAiEmbedderBuilder::new(TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_3_SMALL_DIMENSION)
70    }
71
72    /// Builder for the higher-quality `text-embedding-3-large` model
73    /// (native dimension 3072).
74    pub fn large() -> OpenAiEmbedderBuilder {
75        OpenAiEmbedderBuilder::new(TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_LARGE_DIMENSION)
76    }
77
78    /// Builder for an operator-supplied custom model identifier.
79    /// Use for OpenAI-compatible APIs (Azure OpenAI, vLLM-compatible
80    /// gateways) or for future models the SDK has not yet promoted
81    /// to a `pub const`.
82    pub fn custom(model: impl Into<String>, dimension: usize) -> OpenAiEmbedderBuilder {
83        OpenAiEmbedderBuilder::new(model, dimension)
84    }
85
86    fn embeddings_url(&self) -> String {
87        format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'))
88    }
89
90    /// Send one batch (1 or N inputs) to `/v1/embeddings` and decode
91    /// the response. The response's `usage` is split across the
92    /// returned `Embedding`s — OpenAI reports a single per-call
93    /// `prompt_tokens` count which we attribute to the first
94    /// `Embedding` to keep aggregate accounting accurate without
95    /// double-charging.
96    async fn call(&self, inputs: Vec<String>) -> OpenAiEmbedderResult<Vec<Embedding>> {
97        let credentials = self
98            .credentials
99            .resolve()
100            .await
101            .map_err(OpenAiEmbedderError::Credential)?;
102
103        let body = self.build_request_body(&inputs);
104        let response = self
105            .client
106            .post(self.embeddings_url())
107            .header(
108                credentials.header_name.clone(),
109                http::HeaderValue::from_str(credentials.header_value.expose_secret()).map_err(
110                    |e| OpenAiEmbedderError::Config(format!("invalid credential header: {e}")),
111                )?,
112            )
113            .json(&body)
114            .send()
115            .await
116            .map_err(OpenAiEmbedderError::network)?;
117
118        let status = response.status();
119        if !status.is_success() {
120            let body = response.text().await.unwrap_or_default();
121            return Err(OpenAiEmbedderError::HttpStatus {
122                status: status.as_u16(),
123                body: truncate_for_error(&body),
124            });
125        }
126
127        let parsed: EmbeddingsResponse = response
128            .json()
129            .await
130            .map_err(OpenAiEmbedderError::network)?;
131        self.decode(&parsed, inputs.len())
132    }
133
134    fn build_request_body(&self, inputs: &[String]) -> serde_json::Value {
135        let mut body = json!({
136            "model": &*self.model,
137            "input": inputs,
138            "encoding_format": "float",
139        });
140        if let Some(dim) = self.dimension_override
141            && let Some(obj) = body.as_object_mut()
142        {
143            obj.insert("dimensions".into(), json!(dim));
144        }
145        body
146    }
147
148    fn decode(
149        &self,
150        parsed: &EmbeddingsResponse,
151        expected_len: usize,
152    ) -> OpenAiEmbedderResult<Vec<Embedding>> {
153        if parsed.data.len() != expected_len {
154            return Err(OpenAiEmbedderError::Malformed(format!(
155                "expected {expected_len} embeddings, server returned {}",
156                parsed.data.len()
157            )));
158        }
159        // OpenAI does not guarantee response-`index` ordering matches
160        // request order — sort by `index` to match the input slot.
161        let mut sorted: Vec<&EmbeddingsDataItem> = parsed.data.iter().collect();
162        sorted.sort_by_key(|d| d.index);
163
164        let usage = parsed.usage.map(|u| EmbeddingUsage::new(u.prompt_tokens));
165        let mut out = Vec::with_capacity(expected_len);
166        for (i, item) in sorted.iter().enumerate() {
167            if item.embedding.len() != self.dimension {
168                return Err(OpenAiEmbedderError::Malformed(format!(
169                    "embedding {} dimension {} does not match configured {}",
170                    i,
171                    item.embedding.len(),
172                    self.dimension
173                )));
174            }
175            // Attribute the per-call usage to slot 0 only — downstream
176            // meters sum across the batch and would double-charge if
177            // we replicated the count on every slot.
178            let mut emb = Embedding::new(item.embedding.clone());
179            if i == 0
180                && let Some(u) = usage
181            {
182                emb = emb.with_usage(u);
183            }
184            out.push(emb);
185        }
186        Ok(out)
187    }
188}
189
190#[async_trait]
191impl Embedder for OpenAiEmbedder {
192    fn dimension(&self) -> usize {
193        self.dimension
194    }
195
196    async fn embed(&self, text: &str, ctx: &ExecutionContext) -> Result<Embedding> {
197        if ctx.is_cancelled() {
198            return Err(Error::Cancelled);
199        }
200        let mut out = self
201            .call(vec![text.to_owned()])
202            .await
203            .map_err(Error::from)?;
204        out.pop()
205            .ok_or_else(|| Error::provider_network("OpenAI returned no embedding".to_owned()))
206    }
207
208    async fn embed_batch(
209        &self,
210        texts: &[String],
211        ctx: &ExecutionContext,
212    ) -> Result<Vec<Embedding>> {
213        if ctx.is_cancelled() {
214            return Err(Error::Cancelled);
215        }
216        if texts.is_empty() {
217            return Ok(Vec::new());
218        }
219        // One HTTP call per batch — F10 amortization. Default
220        // sequential impl would issue N round-trips; we override.
221        self.call(texts.to_vec()).await.map_err(Error::from)
222    }
223}
224
225/// Builder for [`OpenAiEmbedder`].
226#[must_use]
227pub struct OpenAiEmbedderBuilder {
228    model: String,
229    dimension: usize,
230    dimension_override: Option<usize>,
231    base_url: String,
232    credentials: Option<Arc<dyn CredentialProvider>>,
233    client: Option<reqwest::Client>,
234}
235
236impl OpenAiEmbedderBuilder {
237    fn new(model: impl Into<String>, native_dimension: usize) -> Self {
238        Self {
239            model: model.into(),
240            dimension: native_dimension,
241            dimension_override: None,
242            base_url: DEFAULT_BASE_URL.to_owned(),
243            credentials: None,
244            client: None,
245        }
246    }
247
248    /// Attach a credential provider. Required.
249    pub fn with_credentials(mut self, credentials: Arc<dyn CredentialProvider>) -> Self {
250        self.credentials = Some(credentials);
251        self
252    }
253
254    /// Override the API base URL (defaults to
255    /// [`DEFAULT_BASE_URL`]). Used for Azure OpenAI, regional
256    /// endpoints, or test fixtures.
257    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
258        self.base_url = url.into();
259        self
260    }
261
262    /// Request a reduced dimension via the API's `dimensions`
263    /// parameter. Must be ≤ the model's native dimension. Storage
264    /// savings come from `text-embedding-3` family's matryoshka
265    /// representation — quality degrades gracefully toward the
266    /// chosen dimension.
267    pub const fn with_dimension(mut self, dimension: usize) -> Self {
268        self.dimension_override = Some(dimension);
269        self.dimension = dimension;
270        self
271    }
272
273    /// Override the underlying HTTP client. Useful when the operator
274    /// runs a shared `reqwest::Client` to consolidate connection
275    /// pools across embedder + chat + tool transports.
276    pub fn with_client(mut self, client: reqwest::Client) -> Self {
277        self.client = Some(client);
278        self
279    }
280
281    /// Finalize the builder. Returns
282    /// [`OpenAiEmbedderError::Config`] if credentials are missing
283    /// or the configured dimension exceeds the native maximum.
284    pub fn build(self) -> OpenAiEmbedderResult<OpenAiEmbedder> {
285        let credentials = self
286            .credentials
287            .ok_or_else(|| OpenAiEmbedderError::Config("credentials required".into()))?;
288        if self.dimension == 0 {
289            return Err(OpenAiEmbedderError::Config("dimension must be > 0".into()));
290        }
291        let client = self.client.unwrap_or_default();
292        Ok(OpenAiEmbedder {
293            client,
294            base_url: self.base_url.into(),
295            credentials,
296            model: self.model.into(),
297            dimension: self.dimension,
298            dimension_override: self.dimension_override,
299        })
300    }
301}
302
303// ── wire format ────────────────────────────────────────────────────────────
304
305#[derive(Debug, Deserialize)]
306struct EmbeddingsResponse {
307    data: Vec<EmbeddingsDataItem>,
308    #[serde(default)]
309    usage: Option<EmbeddingsUsageItem>,
310}
311
312#[derive(Debug, Deserialize)]
313struct EmbeddingsDataItem {
314    embedding: Vec<f32>,
315    index: u32,
316}
317
318#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize)]
319struct EmbeddingsUsageItem {
320    prompt_tokens: u32,
321}
322
/// Maximum number of bytes of an error response body kept in an error value.
const ERROR_BODY_TRUNCATION_BYTES: usize = 512;

/// Shorten `body` to at most [`ERROR_BODY_TRUNCATION_BYTES`] bytes and note
/// how many bytes were dropped. The cut never splits a UTF-8 character.
///
/// Bodies that already fit are returned unchanged.
fn truncate_for_error(body: &str) -> String {
    let limit = ERROR_BODY_TRUNCATION_BYTES;
    if body.len() > limit {
        // Search backwards from the limit for a char boundary. Offset 0 is
        // always a boundary, so the fallback is never actually taken.
        let cut = (0..=limit)
            .rev()
            .find(|&i| body.is_char_boundary(i))
            .unwrap_or(0);
        let (kept, dropped) = body.split_at(cut);
        format!("{kept}… ({} bytes truncated)", dropped.len())
    } else {
        body.to_owned()
    }
}
335
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use entelix_core::auth::ApiKeyProvider;

    /// Static bearer-token credentials shared by the tests.
    fn provider() -> Arc<dyn CredentialProvider> {
        Arc::new(ApiKeyProvider::new("authorization", "Bearer test").unwrap())
    }

    /// Attach the test credentials and build, panicking on failure.
    fn finish(builder: OpenAiEmbedderBuilder) -> OpenAiEmbedder {
        builder.with_credentials(provider()).build().unwrap()
    }

    /// Shorthand for a wire-format data item.
    fn item(index: u32, embedding: Vec<f32>) -> EmbeddingsDataItem {
        EmbeddingsDataItem { embedding, index }
    }

    #[test]
    fn small_builder_defaults_to_native_dimension() {
        let embedder = finish(OpenAiEmbedder::small());
        assert_eq!(embedder.dimension(), TEXT_EMBEDDING_3_SMALL_DIMENSION);
        assert_eq!(&*embedder.model, TEXT_EMBEDDING_3_SMALL);
    }

    #[test]
    fn large_builder_defaults_to_native_dimension() {
        let embedder = finish(OpenAiEmbedder::large());
        assert_eq!(embedder.dimension(), TEXT_EMBEDDING_3_LARGE_DIMENSION);
    }

    #[test]
    fn dimension_override_threads_into_request_body() {
        let embedder = finish(OpenAiEmbedder::small().with_dimension(512));
        assert_eq!(embedder.dimension(), 512);
        let request = embedder.build_request_body(&["hi".to_owned()]);
        assert_eq!(request["dimensions"], 512);
    }

    #[test]
    fn missing_credentials_rejected_at_build() {
        let outcome = OpenAiEmbedder::small().build();
        assert!(matches!(
            outcome.unwrap_err(),
            OpenAiEmbedderError::Config(_)
        ));
    }

    #[test]
    fn zero_dimension_rejected_at_build() {
        let outcome = OpenAiEmbedder::custom("custom-model", 0)
            .with_credentials(provider())
            .build();
        assert!(matches!(
            outcome.unwrap_err(),
            OpenAiEmbedderError::Config(_)
        ));
    }

    #[test]
    fn embeddings_url_strips_trailing_slash() {
        let embedder = finish(OpenAiEmbedder::small().with_base_url("https://example.test/"));
        assert_eq!(
            embedder.embeddings_url(),
            "https://example.test/v1/embeddings"
        );
    }

    #[test]
    fn decode_attributes_usage_to_first_slot_only() {
        let embedder = finish(OpenAiEmbedder::custom("test-model", 3));
        let response = EmbeddingsResponse {
            data: vec![item(0, vec![0.1, 0.2, 0.3]), item(1, vec![0.4, 0.5, 0.6])],
            usage: Some(EmbeddingsUsageItem { prompt_tokens: 7 }),
        };
        let decoded = embedder.decode(&response, 2).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].usage, Some(EmbeddingUsage::new(7)));
        assert!(
            decoded[1].usage.is_none(),
            "usage must NOT replicate across slots"
        );
    }

    #[test]
    fn decode_sorts_by_index_when_response_order_shuffled() {
        let embedder = finish(OpenAiEmbedder::custom("test-model", 2));
        let response = EmbeddingsResponse {
            data: vec![item(1, vec![0.9, 0.9]), item(0, vec![0.1, 0.1])],
            usage: None,
        };
        let decoded = embedder.decode(&response, 2).unwrap();
        assert_eq!(decoded[0].vector, vec![0.1, 0.1]);
        assert_eq!(decoded[1].vector, vec![0.9, 0.9]);
    }

    #[test]
    fn decode_rejects_dimension_mismatch() {
        let embedder = finish(OpenAiEmbedder::custom("test-model", 3));
        let response = EmbeddingsResponse {
            // Two components where three are configured.
            data: vec![item(0, vec![0.1, 0.2])],
            usage: None,
        };
        let failure = embedder.decode(&response, 1).unwrap_err();
        assert!(matches!(failure, OpenAiEmbedderError::Malformed(_)));
    }

    #[test]
    fn decode_rejects_count_mismatch() {
        let embedder = finish(OpenAiEmbedder::custom("test-model", 1));
        let response = EmbeddingsResponse {
            data: vec![item(0, vec![0.1])],
            usage: None,
        };
        let failure = embedder.decode(&response, 2).unwrap_err();
        assert!(matches!(failure, OpenAiEmbedderError::Malformed(_)));
    }

    #[test]
    fn truncate_for_error_caps_oversized_body() {
        let truncated = truncate_for_error(&"x".repeat(10_000));
        assert!(truncated.contains("truncated"));
        assert!(truncated.len() < 1000);
    }
}