Skip to main content

difflore_core/context/embedding/
openai.rs

1use async_trait::async_trait;
2
3use crate::errors::CoreError;
4
5use super::{Embedder, embedding_http_client};
6
7/// OpenAI-compatible embedding provider.
8///
9/// Works with any backend that speaks the `OpenAI` `/embeddings` shape
10/// (`OpenAI`, Azure `OpenAI`, Together, `DeepInfra`, etc.).
11pub struct OpenAICompatEmbedder {
12    pub base_url: String,
13    pub api_key: String,
14    pub model: String,
15    pub dim: usize,
16    client: reqwest::Client,
17}
18
19impl OpenAICompatEmbedder {
20    pub fn new(base_url: String, api_key: String, model: String, dim: usize) -> Self {
21        Self {
22            base_url,
23            api_key,
24            model,
25            dim,
26            client: embedding_http_client(),
27        }
28    }
29
30    pub(crate) fn endpoint(&self) -> String {
31        let trimmed = self.base_url.trim_end_matches('/');
32        if trimmed.ends_with("/embeddings") {
33            trimmed.to_owned()
34        } else {
35            format!("{trimmed}/embeddings")
36        }
37    }
38
39    /// Build a POST request, attaching `Authorization: Bearer` only when a key
40    /// is configured. Keyless local providers (configured via
41    /// `difflore embeddings setup --no-key`) can reject any auth header, so an
42    /// empty key must send no header at all.
43    fn authed_post(&self, url: &str) -> reqwest::RequestBuilder {
44        let request = self.client.post(url);
45        if self.api_key.is_empty() {
46            request
47        } else {
48            request.bearer_auth(&self.api_key)
49        }
50    }
51}
52
53fn provider_status_error(status: reqwest::StatusCode) -> CoreError {
54    CoreError::Internal(format!(
55        "embedding provider returned {status}; check provider URL, model, dimensions, and API key"
56    ))
57}
58
59#[async_trait]
60impl Embedder for OpenAICompatEmbedder {
61    async fn embed(&self, text: &str) -> Result<Vec<f32>, CoreError> {
62        let url = self.endpoint();
63        // We deliberately do NOT send a `dimensions` parameter: many valid
64        // OpenAI-compatible models (e.g. text-embedding-ada-002) and strict
65        // local providers reject it, which would break configs whose `--dim`
66        // already matches the model's native size. Instead we validate the
67        // returned length below, so a mismatched `--dim` surfaces a clear error
68        // rather than silently storing wrong-length vectors.
69        let body = serde_json::json!({
70            "model": self.model,
71            "input": text,
72        });
73
74        let resp = self
75            .authed_post(&url)
76            .json(&body)
77            .send()
78            .await
79            .map_err(|e| CoreError::Internal(format!("embedding request failed: {e}")))?;
80
81        if !resp.status().is_success() {
82            let status = resp.status();
83            return Err(provider_status_error(status));
84        }
85
86        let json: serde_json::Value = resp
87            .json()
88            .await
89            .map_err(|e| CoreError::Internal(format!("embedding response parse error: {e}")))?;
90
91        let vec = json
92            .get("data")
93            .and_then(|d| d.get(0))
94            .and_then(|d| d.get("embedding"))
95            .and_then(|e| e.as_array())
96            .ok_or_else(|| {
97                CoreError::Internal("embedding response missing data[0].embedding".into())
98            })?
99            .iter()
100            .map(|v| v.as_f64().unwrap_or(0.0) as f32)
101            .collect::<Vec<f32>>();
102
103        // Refuse a length that disagrees with the configured profile rather than
104        // storing mismatched-length vectors under a `byok:<host>:<model>:<dim>`
105        // profile. The actionable message points at the fix.
106        if vec.len() != self.dim {
107            return Err(CoreError::Internal(format!(
108                "embedding provider returned {} dimensions but {} are configured; \
109                 re-run `difflore embeddings setup --dim {}` to match your provider/model",
110                vec.len(),
111                self.dim,
112                vec.len()
113            )));
114        }
115
116        Ok(vec)
117    }
118
119    async fn embed_batch(
120        &self,
121        texts: &[String],
122        _rule_ids: Option<&[String]>,
123    ) -> Result<Vec<Vec<f32>>, CoreError> {
124        if texts.is_empty() {
125            return Ok(Vec::new());
126        }
127        // OpenAI-compatible APIs accept batched `input`, which keeps BYOK indexing
128        // inside the bounded recall/fix/MCP timeouts.
129        let body = serde_json::json!({
130            "model": self.model,
131            "input": texts,
132        });
133        let resp = self
134            .authed_post(&self.endpoint())
135            .json(&body)
136            .send()
137            .await
138            .map_err(|e| CoreError::Internal(format!("embedding request failed: {e}")))?;
139        if !resp.status().is_success() {
140            let status = resp.status();
141            return Err(provider_status_error(status));
142        }
143        let json: serde_json::Value = resp
144            .json()
145            .await
146            .map_err(|e| CoreError::Internal(format!("embedding response parse error: {e}")))?;
147        let data = json
148            .get("data")
149            .and_then(|d| d.as_array())
150            .ok_or_else(|| CoreError::Internal("embedding response missing data array".into()))?;
151        if data.len() != texts.len() {
152            return Err(CoreError::Internal(format!(
153                "embedding response length mismatch: expected {}, got {}",
154                texts.len(),
155                data.len()
156            )));
157        }
158        // OpenAI returns each item with an `index`; order is normally preserved
159        // but we sort defensively so vectors line up with the input texts.
160        let mut indexed: Vec<(usize, Vec<f32>)> = Vec::with_capacity(data.len());
161        for item in data {
162            let index = item
163                .get("index")
164                .and_then(serde_json::Value::as_u64)
165                .map_or(indexed.len(), |i| i as usize);
166            let vector = item
167                .get("embedding")
168                .and_then(|e| e.as_array())
169                .ok_or_else(|| {
170                    CoreError::Internal("embedding response item missing embedding array".into())
171                })?
172                .iter()
173                .map(|v| v.as_f64().unwrap_or(0.0) as f32)
174                .collect::<Vec<f32>>();
175            if vector.len() != self.dim {
176                return Err(CoreError::Internal(format!(
177                    "embedding provider returned {} dimensions but {} are configured; \
178                     re-run `difflore embeddings setup --dim {}` to match your provider/model",
179                    vector.len(),
180                    self.dim,
181                    vector.len()
182                )));
183            }
184            indexed.push((index, vector));
185        }
186        indexed.sort_by_key(|(index, _)| *index);
187        Ok(indexed.into_iter().map(|(_, vector)| vector).collect())
188    }
189
190    fn dim(&self) -> usize {
191        self.dim
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::provider_status_error;

    /// The status error must name the status code and the remediation hint,
    /// and must not contain header names or key-like fragments.
    #[test]
    fn provider_status_error_does_not_echo_response_body() {
        let message = provider_status_error(reqwest::StatusCode::UNAUTHORIZED).to_string();

        for expected in ["401", "check provider URL"] {
            assert!(message.contains(expected));
        }
        for forbidden in ["Authorization", "sk-"] {
            assert!(!message.contains(forbidden));
        }
    }
}