Skip to main content

sqlite_graphrag/
embedding_api.rs

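//! HTTP client for the OpenRouter embeddings API.
//!
//! Sends embedding requests to the OpenAI-compatible endpoint at
//! `openrouter.ai/api/v1/embeddings` and returns dense `Vec<f32>`
//! vectors. Transient failures are retried: 5xx responses, non-timeout
//! transport errors and unusable 2xx bodies use exponential backoff with
//! jitter, while 429 waits for the server's `Retry-After` delay. Permanent
//! failures abort immediately: 401, 400, 404, request timeouts and a
//! structured error object returned inside a 2xx body.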
8
9use crate::errors::AppError;
10use secrecy::{ExposeSecret, SecretBox};
11use serde::{Deserialize, Serialize};
12use std::time::Duration;
13
/// Endpoint every embedding request is POSTed to.
const OPENROUTER_EMBEDDINGS_URL: &str = "https://openrouter.ai/api/v1/embeddings";
/// Overall per-request timeout, in seconds, configured on the HTTP client.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Connection timeout, in seconds, configured on the HTTP client.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Largest number of inputs sent in a single request; bigger batches are
/// split into chunks of this size.
const MAX_BATCH_SIZE: usize = 32;
/// Total number of attempts made per request (the first try included), since
/// the retry loop iterates over `0..MAX_RETRIES`.
const MAX_RETRIES: u32 = 4;
19
/// JSON body of an embeddings request.
///
/// `dimensions` and `input_type` are omitted from the serialized body when
/// they are `None`.
#[derive(Serialize)]
struct EmbeddingRequest<'a> {
    /// Model identifier, forwarded as given to the client.
    model: &'a str,
    /// One text or a batch of texts to embed.
    input: EmbeddingInput<'a>,
    /// Requested vector length; only set for models on the MRL list.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    /// Encoding of the returned vectors; the client always sends `"float"`.
    encoding_format: &'a str,
    /// Optional input-type hint supplied by the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    input_type: Option<&'a str>,
}
30
/// Value of the request's `input` field.
///
/// Being `untagged`, it serializes as the bare inner value: a JSON string for
/// `Single` and a JSON array of strings for `Batch`.
#[derive(Serialize)]
#[serde(untagged)]
enum EmbeddingInput<'a> {
    /// A single text.
    Single(&'a str),
    /// Several texts embedded in one request.
    Batch(Vec<&'a str>),
}
37
/// Successful embeddings payload: one `EmbeddingData` entry per input.
///
/// `data` is required here, so parsing an error-only body into this type
/// fails with a missing-field error.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}
42
/// One embedding returned by the provider.
#[derive(Deserialize)]
struct EmbeddingData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Position reported by the provider; batch results are ordered by it.
    index: usize,
}
48
/// Envelope that captures BOTH shapes the OpenRouter embeddings endpoint can
/// return: the success payload (`data`) and the structured error object
/// (`error`). OpenRouter sometimes returns the error object inside an HTTP 200
/// body (e.g. token/context-length overflow); a direct parse to
/// [`EmbeddingResponse`] would fail with a misleading missing-field error,
/// masking the real cause. Both fields are optional so the branch is decided
/// by inspection, not by a parse failure.
#[derive(Deserialize)]
struct EmbeddingEnvelope {
    /// Embeddings, present on success; `None` when the key is absent.
    #[serde(default)]
    data: Option<Vec<EmbeddingData>>,
    /// Provider error object; `None` when the key is absent.
    #[serde(default)]
    error: Option<ApiError>,
}
63
64/// Structured OpenRouter error object. `code` is a `serde_json::Value` because
65/// the provider sends it as either a JSON number or string depending on the
66/// failure; `message` defaults to empty so a malformed error object never
67/// re-introduces the missing-field masking.
68#[derive(Deserialize)]
69struct ApiError {
70    #[serde(default)]
71    code: Option<serde_json::Value>,
72    #[serde(default)]
73    message: String,
74}
75
76impl ApiError {
77    /// Renders `code` as a plain string without JSON quoting, falling back to
78    /// `unknown` when the provider omitted it.
79    fn code_string(&self) -> String {
80        match &self.code {
81            Some(serde_json::Value::String(s)) => s.clone(),
82            Some(other) => other.to_string(),
83            None => "unknown".to_string(),
84        }
85    }
86}
87
/// Client for the OpenRouter embeddings endpoint, bound to one model and one
/// output dimension.
pub struct OpenRouterClient {
    /// HTTP client built once in `new` and used for every request.
    client: reqwest::Client,
    /// API key, sent as a bearer token on every request.
    api_key: SecretBox<String>,
    /// Model identifier sent in every request.
    model: String,
    /// Length every returned embedding is cut down to.
    dim: usize,
    /// Whether the model is on the `model_supports_mrl` list; when true,
    /// requests carry the `dimensions` field.
    supports_mrl: bool,
    /// Input type chosen for the model by `model_default_input_type`.
    default_input_type: Option<&'static str>,
}
96
/// Reports whether `model` belongs to one of the model families for which the
/// client sends the `dimensions` request field.
///
/// Matching is a case-sensitive substring test, so provider prefixes and
/// version suffixes around the family name do not matter.
fn model_supports_mrl(model: &str) -> bool {
    const MRL_FAMILIES: [&str; 5] = [
        "qwen3-embedding",
        "text-embedding-3",
        "gemini-embedding",
        "llama-nemotron-embed",
        "bge-m3",
    ];

    MRL_FAMILIES.iter().any(|&family| model.contains(family))
}
104
/// Picks the default `input_type` value for `model`, matched by substring:
/// `passage` for `llama-nemotron-embed` models, none for `mistral-embed`
/// models, and `search_document` for everything else.
fn model_default_input_type(model: &str) -> Option<&'static str> {
    if model.contains("llama-nemotron-embed") {
        return Some("passage");
    }
    if model.contains("mistral-embed") {
        return None;
    }
    Some("search_document")
}
114
115impl OpenRouterClient {
116    pub fn new(api_key: SecretBox<String>, model: String, dim: usize) -> Result<Self, AppError> {
117        let client = reqwest::Client::builder()
118            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
119            .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
120            .user_agent("sqlite-graphrag/1.0.96")
121            .build()
122            .map_err(|e| AppError::Embedding(format!("failed to build HTTP client: {e}")))?;
123
124        let supports_mrl = model_supports_mrl(&model);
125        let default_input_type = model_default_input_type(&model);
126
127        Ok(Self {
128            client,
129            api_key,
130            model,
131            dim,
132            supports_mrl,
133            default_input_type,
134        })
135    }
136
    /// Returns the input type selected for this client's model at
    /// construction, or `None` when the model has no default.
    pub fn default_input_type(&self) -> Option<&'static str> {
        self.default_input_type
    }
140
141    pub async fn embed_single(
142        &self,
143        text: &str,
144        input_type: Option<&str>,
145    ) -> Result<Vec<f32>, AppError> {
146        // GAP-SG-02: reject an input that would overflow the model's token
147        // window BEFORE the HTTP request, surfacing a clear Validation error
148        // instead of a provider context-length rejection paid for round-trip.
149        crate::memory_guard::check_embedding_input_size(text)?;
150
151        let request = EmbeddingRequest {
152            model: &self.model,
153            input: EmbeddingInput::Single(text),
154            dimensions: if self.supports_mrl {
155                Some(self.dim)
156            } else {
157                None
158            },
159            encoding_format: "float",
160            input_type,
161        };
162
163        let response = self.execute_with_retry(&request).await?;
164
165        let embedding = response
166            .data
167            .into_iter()
168            .next()
169            .ok_or_else(|| AppError::Embedding("empty response from OpenRouter".into()))?
170            .embedding;
171
172        self.truncate_embedding(embedding)
173    }
174
175    pub async fn embed_batch(
176        &self,
177        texts: &[&str],
178        input_type: Option<&str>,
179    ) -> Result<Vec<Vec<f32>>, AppError> {
180        if texts.is_empty() {
181            return Ok(Vec::new());
182        }
183
184        // GAP-SG-02: validate every input before any HTTP request so an
185        // oversized member of the batch fails fast as Validation rather than a
186        // provider context-length rejection mid-batch.
187        for text in texts {
188            crate::memory_guard::check_embedding_input_size(text)?;
189        }
190
191        let mut all = Vec::with_capacity(texts.len());
192
193        for chunk in texts.chunks(MAX_BATCH_SIZE) {
194            let request = EmbeddingRequest {
195                model: &self.model,
196                input: EmbeddingInput::Batch(chunk.to_vec()),
197                dimensions: if self.supports_mrl {
198                    Some(self.dim)
199                } else {
200                    None
201                },
202                encoding_format: "float",
203                input_type,
204            };
205
206            let response = self.execute_with_retry(&request).await?;
207
208            if response.data.len() != chunk.len() {
209                return Err(AppError::Embedding(format!(
210                    "expected {} embeddings, got {}",
211                    chunk.len(),
212                    response.data.len()
213                )));
214            }
215
216            let mut sorted = response.data;
217            sorted.sort_by_key(|d| d.index);
218
219            for d in sorted {
220                all.push(self.truncate_embedding(d.embedding)?);
221            }
222        }
223
224        Ok(all)
225    }
226
227    fn truncate_embedding(&self, embedding: Vec<f32>) -> Result<Vec<f32>, AppError> {
228        if embedding.len() < self.dim {
229            return Err(AppError::Embedding(format!(
230                "embedding dimension {} < requested {}",
231                embedding.len(),
232                self.dim
233            )));
234        }
235        if embedding.len() == self.dim {
236            Ok(embedding)
237        } else {
238            Ok(embedding[..self.dim].to_vec())
239        }
240    }
241
242    async fn execute_with_retry(
243        &self,
244        request: &EmbeddingRequest<'_>,
245    ) -> Result<EmbeddingResponse, AppError> {
246        let mut last_err = None;
247
248        for attempt in 0..MAX_RETRIES {
249            let result = self
250                .client
251                .post(OPENROUTER_EMBEDDINGS_URL)
252                .header(
253                    "Authorization",
254                    format!("Bearer {}", self.api_key.expose_secret()),
255                )
256                .json(request)
257                .send()
258                .await;
259
260            let resp = match result {
261                Ok(r) => r,
262                Err(e) if e.is_timeout() => {
263                    return Err(AppError::Embedding("OpenRouter request timed out".into()));
264                }
265                Err(e) => {
266                    last_err = Some(AppError::Embedding(format!("HTTP request failed: {e}")));
267                    Self::backoff(attempt).await;
268                    continue;
269                }
270            };
271
272            let status = resp.status();
273
274            if status.is_success() {
275                let body = resp.text().await.map_err(|e| {
276                    AppError::Embedding(format!("failed to read response body: {e}"))
277                })?;
278                match serde_json::from_str::<EmbeddingEnvelope>(&body) {
279                    Ok(env) => {
280                        // A structured error object inside a 2xx body is a
281                        // PERMANENT provider rejection (e.g. context-length
282                        // overflow). Surface the REAL code/message instead of
283                        // masking it as a parse failure, and do not retry.
284                        if let Some(api_err) = env.error {
285                            return Err(AppError::ProviderError {
286                                code: api_err.code_string(),
287                                message: api_err.message,
288                            });
289                        }
290                        match env.data {
291                            Some(data) => return Ok(EmbeddingResponse { data }),
292                            None => {
293                                tracing::warn!(
294                                    attempt,
295                                    body_len = body.len(),
296                                    "HTTP 200 with neither data nor error (retrying)"
297                                );
298                                last_err = Some(AppError::Embedding(
299                                    "OpenRouter 200 response had neither data nor error".into(),
300                                ));
301                                Self::backoff(attempt).await;
302                                continue;
303                            }
304                        }
305                    }
306                    Err(e) => {
307                        tracing::warn!(
308                            attempt,
309                            body_len = body.len(),
310                            "HTTP 200 but JSON unparseable (retrying): {e}"
311                        );
312                        last_err = Some(AppError::Embedding(format!(
313                            "failed to parse embedding response: {e}"
314                        )));
315                        Self::backoff(attempt).await;
316                        continue;
317                    }
318                }
319            }
320
321            if status.as_u16() == 401 {
322                return Err(AppError::Embedding(
323                    "invalid OpenRouter API key (HTTP 401)".into(),
324                ));
325            }
326
327            if status.as_u16() == 400 || status.as_u16() == 404 {
328                let body = resp.text().await.unwrap_or_default();
329                return Err(AppError::Embedding(format!(
330                    "OpenRouter returned {status}: {body}"
331                )));
332            }
333
334            if status.as_u16() == 429 {
335                let retry_after = resp
336                    .headers()
337                    .get("retry-after")
338                    .and_then(|v| v.to_str().ok())
339                    .and_then(|v| v.parse::<u64>().ok())
340                    .unwrap_or(2);
341                tracing::warn!(
342                    attempt,
343                    retry_after_secs = retry_after,
344                    "OpenRouter rate limited, waiting"
345                );
346                // GAP-SG-56: surface the Retry-After delay to the caller. If
347                // every attempt is rate limited, the loop exits with this
348                // RateLimited error (retryable) carrying the server-advised
349                // wait, instead of a generic max-retries-exceeded message.
350                last_err = Some(AppError::RateLimited {
351                    detail: format!("OpenRouter HTTP 429 (retry-after {retry_after}s)"),
352                });
353                tokio::time::sleep(Duration::from_secs(retry_after)).await;
354                continue;
355            }
356
357            if status.is_server_error() {
358                tracing::warn!(attempt, status = %status, "OpenRouter server error, retrying");
359                last_err = Some(AppError::Embedding(format!(
360                    "OpenRouter server error: {status}"
361                )));
362                Self::backoff(attempt).await;
363                continue;
364            }
365
366            let body = resp.text().await.unwrap_or_default();
367            return Err(AppError::Embedding(format!(
368                "unexpected HTTP {status}: {body}"
369            )));
370        }
371
372        Err(last_err.unwrap_or_else(|| {
373            AppError::Embedding("max retries exceeded for OpenRouter request".into())
374        }))
375    }
376
377    async fn backoff(attempt: u32) {
378        let base_ms = 1000u64 * 2u64.pow(attempt);
379        let jitter = fastrand::u64(0..500);
380        let sleep_ms = base_ms + jitter;
381        tracing::debug!(attempt, sleep_ms, "exponential backoff");
382        tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
383    }
384}
385
/// Unit tests for model detection, embedding truncation, response-envelope
/// parsing and the pre-request input-size guard. None of them needs a
/// successful network call.
#[cfg(test)]
mod tests {
    use super::*;

    /// Known MRL model families are detected by substring; others are not.
    #[test]
    fn test_supports_mrl_detection() {
        assert!(model_supports_mrl("qwen/qwen3-embedding-8b"));
        assert!(model_supports_mrl("qwen/qwen3-embedding-4b"));
        assert!(model_supports_mrl("openai/text-embedding-3-small"));
        assert!(model_supports_mrl("openai/text-embedding-3-large"));
        assert!(model_supports_mrl("google/gemini-embedding-001"));
        assert!(model_supports_mrl("google/gemini-embedding-2"));
        assert!(model_supports_mrl(
            "nvidia/llama-nemotron-embed-vl-1b-v2:free"
        ));
        assert!(model_supports_mrl("baai/bge-m3"));

        assert!(!model_supports_mrl("perplexity/pplx-embed-v1-0.6b"));
        assert!(!model_supports_mrl("mistralai/mistral-embed-2312"));
        assert!(!model_supports_mrl("some-random-model"));
    }

    /// Each model family maps to its expected default input type.
    #[test]
    fn test_model_default_input_type() {
        assert_eq!(
            model_default_input_type("nvidia/llama-nemotron-embed-vl-1b-v2:free"),
            Some("passage")
        );
        assert_eq!(
            model_default_input_type("mistralai/mistral-embed-2312"),
            None
        );
        assert_eq!(
            model_default_input_type("qwen/qwen3-embedding-8b"),
            Some("search_document")
        );
        assert_eq!(
            model_default_input_type("openai/text-embedding-3-small"),
            Some("search_document")
        );
        assert_eq!(
            model_default_input_type("baai/bge-m3"),
            Some("search_document")
        );
    }

    /// Longer vectors are cut to `dim`, exact ones kept, shorter ones rejected.
    #[test]
    fn test_truncate_embedding() {
        let api_key = SecretBox::new(Box::new("test-key".to_string()));
        let client = OpenRouterClient::new(api_key, "test-model".into(), 3).unwrap();

        let full = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let truncated = client.truncate_embedding(full).unwrap();
        assert_eq!(truncated, vec![1.0, 2.0, 3.0]);

        let exact = vec![1.0, 2.0, 3.0];
        let kept = client.truncate_embedding(exact).unwrap();
        assert_eq!(kept, vec![1.0, 2.0, 3.0]);

        let short = vec![1.0, 2.0];
        let err = client.truncate_embedding(short);
        assert!(err.is_err());
    }

    #[test]
    fn embedding_envelope_surfaces_provider_error_not_missing_field() {
        // GAP-SG-01: a 200 body carrying an OpenRouter error object must yield
        // the REAL message, not the misleading missing-field parse failure.
        let body = r#"{"error":{"code":400,"message":"context length exceeded"}}"#;

        // Precondition: the legacy optimistic parse masked the cause. Match
        // instead of unwrap_err so EmbeddingResponse need not derive Debug.
        let legacy_err = match serde_json::from_str::<EmbeddingResponse>(body) {
            Ok(_) => panic!("legacy parse should have failed on an error body"),
            Err(e) => e.to_string(),
        };
        assert!(
            legacy_err.contains("missing field"),
            "precondition: legacy parse masks the cause as a missing field: {legacy_err}"
        );

        // The envelope captures the structured error instead.
        let env: EmbeddingEnvelope =
            serde_json::from_str(body).expect("envelope parses an error body");
        assert!(env.data.is_none());
        let api_err = env.error.expect("error object captured");
        assert_eq!(api_err.message, "context length exceeded");
        assert_eq!(api_err.code_string(), "400");
    }

    /// A success body fills `data` and leaves `error` empty.
    #[test]
    fn embedding_envelope_parses_success_body() {
        let body = r#"{"data":[{"embedding":[1.0,2.0,3.0],"index":0}]}"#;
        let env: EmbeddingEnvelope =
            serde_json::from_str(body).expect("envelope parses a success body");
        assert!(env.error.is_none());
        let data = env.data.expect("data present");
        assert_eq!(data.len(), 1);
        assert_eq!(data[0].embedding, vec![1.0, 2.0, 3.0]);
    }

    /// `code_string` handles numeric, string and absent codes.
    #[test]
    fn api_error_code_string_handles_number_string_and_missing() {
        let num: ApiError = serde_json::from_str(r#"{"code":429,"message":"slow down"}"#).unwrap();
        assert_eq!(num.code_string(), "429");

        let s: ApiError =
            serde_json::from_str(r#"{"code":"rate_limited","message":"slow down"}"#).unwrap();
        assert_eq!(s.code_string(), "rate_limited");

        let missing: ApiError = serde_json::from_str(r#"{"message":"oops"}"#).unwrap();
        assert_eq!(missing.code_string(), "unknown");
    }

    #[tokio::test]
    async fn embed_single_rejects_oversized_input_before_request() {
        // GAP-SG-02: an input above EMBEDDING_REQUEST_MAX_TOKENS must fail as
        // Validation WITHOUT any network call. The fake key/URL would error
        // distinctly (Embedding) if the guard let the request through.
        let api_key = SecretBox::new(Box::new("test-key".to_string()));
        let client = OpenRouterClient::new(api_key, "qwen/qwen3-embedding-8b".into(), 384).unwrap();
        let big = "word ".repeat(crate::constants::EMBEDDING_REQUEST_MAX_TOKENS + 5_000);
        match client.embed_single(&big, None).await {
            Err(AppError::Validation(msg)) => assert!(msg.contains("tokens")),
            // Any other outcome is an ordinary test failure, so report it
            // with `panic!` rather than claiming the branch is unreachable.
            other => panic!("expected Validation before request, got: {other:?}"),
        }
    }
}