1use crate::errors::AppError;
10use crate::retry::AttemptOutcome;
11use secrecy::{ExposeSecret, SecretBox};
12use serde::{Deserialize, Serialize};
13use std::time::Duration;
14
/// Public OpenRouter endpoint that every embedding request is POSTed to.
const OPENROUTER_EMBEDDINGS_URL: &str = "https://openrouter.ai/api/v1/embeddings";
/// Whole-request timeout for the HTTP client, in seconds.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Connection-establishment timeout for the HTTP client, in seconds.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Maximum number of texts sent in one batch request; larger inputs are
/// split into sequential requests of at most this many texts.
const MAX_BATCH_SIZE: usize = 32;
19
20#[derive(Serialize)]
21struct EmbeddingRequest<'a> {
22 model: &'a str,
23 input: EmbeddingInput<'a>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 dimensions: Option<usize>,
26 encoding_format: &'a str,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 input_type: Option<&'a str>,
29}
30
31#[derive(Serialize)]
32#[serde(untagged)]
33enum EmbeddingInput<'a> {
34 Single(&'a str),
35 Batch(Vec<&'a str>),
36}
37
38#[derive(Deserialize)]
39struct EmbeddingResponse {
40 data: Vec<EmbeddingData>,
41}
42
43#[derive(Deserialize)]
44struct EmbeddingData {
45 embedding: Vec<f32>,
46 index: usize,
47}
48
49#[derive(Deserialize)]
57struct EmbeddingEnvelope {
58 #[serde(default)]
59 data: Option<Vec<EmbeddingData>>,
60 #[serde(default)]
61 error: Option<ApiError>,
62}
63
64use crate::openrouter_http::ApiError;
69
70#[derive(Debug)]
82pub struct EmbedError {
83 pub source: AppError,
85 pub retry_class: AttemptOutcome,
88}
89
90impl EmbedError {
91 fn new(source: AppError, retry_class: AttemptOutcome) -> Self {
92 Self {
93 source,
94 retry_class,
95 }
96 }
97}
98
99impl std::fmt::Display for EmbedError {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 std::fmt::Display::fmt(&self.source, f)
102 }
103}
104
105impl std::error::Error for EmbedError {
106 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
107 Some(&self.source)
108 }
109}
110
111impl From<AppError> for EmbedError {
120 fn from(source: AppError) -> Self {
121 Self::new(source, AttemptOutcome::HardFailure)
122 }
123}
124
125impl From<EmbedError> for AppError {
132 fn from(err: EmbedError) -> Self {
133 err.source
134 }
135}
136
137pub struct OpenRouterClient {
138 client: reqwest::Client,
139 api_key: SecretBox<String>,
140 model: String,
141 dim: usize,
142 supports_mrl: bool,
143 default_input_type: Option<&'static str>,
144 base_url: String,
149}
150
/// Returns `true` when the model id contains one of the known MRL-capable
/// model family markers (plain substring match, case-sensitive).
///
/// For these models the client sends the `dimensions` request field.
fn model_supports_mrl(model: &str) -> bool {
    const MRL_MARKERS: [&str; 5] = [
        "qwen3-embedding",
        "text-embedding-3",
        "gemini-embedding",
        "llama-nemotron-embed",
        "bge-m3",
    ];
    MRL_MARKERS.iter().any(|&marker| model.contains(marker))
}
158
/// Picks the default `input_type` for a model id.
///
/// Returns `"passage"` for ids containing `llama-nemotron-embed`, `None` for
/// ids containing `mistral-embed` (no field is sent), and `"search_document"`
/// for everything else. The nemotron check wins if both markers are present.
fn model_default_input_type(model: &str) -> Option<&'static str> {
    match model {
        id if id.contains("llama-nemotron-embed") => Some("passage"),
        id if id.contains("mistral-embed") => None,
        _ => Some("search_document"),
    }
}
168
169impl OpenRouterClient {
170 pub fn new(api_key: SecretBox<String>, model: String, dim: usize) -> Result<Self, AppError> {
171 let client = reqwest::Client::builder()
172 .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
173 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
174 .user_agent(concat!("sqlite-graphrag/", env!("CARGO_PKG_VERSION")))
177 .build()
178 .map_err(|e| AppError::Embedding(format!("failed to build HTTP client: {e}")))?;
179
180 let supports_mrl = model_supports_mrl(&model);
181 let default_input_type = model_default_input_type(&model);
182
183 Ok(Self {
184 client,
185 api_key,
186 model,
187 dim,
188 supports_mrl,
189 default_input_type,
190 base_url: OPENROUTER_EMBEDDINGS_URL.to_string(),
191 })
192 }
193
194 #[cfg(test)]
198 fn new_with_url(
199 api_key: SecretBox<String>,
200 model: String,
201 dim: usize,
202 base_url: String,
203 ) -> Result<Self, AppError> {
204 let mut client = Self::new(api_key, model, dim)?;
205 client.base_url = base_url;
206 Ok(client)
207 }
208
209 pub fn default_input_type(&self) -> Option<&'static str> {
210 self.default_input_type
211 }
212
213 pub async fn embed_single(
214 &self,
215 text: &str,
216 input_type: Option<&str>,
217 ) -> Result<Vec<f32>, EmbedError> {
218 crate::memory_guard::check_embedding_input_size(text)?;
222
223 let request = EmbeddingRequest {
224 model: &self.model,
225 input: EmbeddingInput::Single(text),
226 dimensions: if self.supports_mrl {
227 Some(self.dim)
228 } else {
229 None
230 },
231 encoding_format: "float",
232 input_type,
233 };
234
235 let response = self.execute_with_retry(&request).await?;
236
237 let embedding = response
238 .data
239 .into_iter()
240 .next()
241 .ok_or_else(|| AppError::Embedding("empty response from OpenRouter".into()))?
242 .embedding;
243
244 Ok(self.truncate_embedding(embedding)?)
245 }
246
247 pub async fn embed_batch(
248 &self,
249 texts: &[&str],
250 input_type: Option<&str>,
251 ) -> Result<Vec<Vec<f32>>, EmbedError> {
252 if texts.is_empty() {
253 return Ok(Vec::new());
254 }
255
256 for text in texts {
260 crate::memory_guard::check_embedding_input_size(text)?;
261 }
262
263 let mut all = Vec::with_capacity(texts.len());
264
265 for chunk in texts.chunks(MAX_BATCH_SIZE) {
266 let request = EmbeddingRequest {
267 model: &self.model,
268 input: EmbeddingInput::Batch(chunk.to_vec()),
269 dimensions: if self.supports_mrl {
270 Some(self.dim)
271 } else {
272 None
273 },
274 encoding_format: "float",
275 input_type,
276 };
277
278 let response = self.execute_with_retry(&request).await?;
279
280 if response.data.len() != chunk.len() {
281 return Err(AppError::Embedding(format!(
282 "expected {} embeddings, got {}",
283 chunk.len(),
284 response.data.len()
285 ))
286 .into());
287 }
288
289 let mut sorted = response.data;
290 sorted.sort_by_key(|d| d.index);
291
292 for d in sorted {
293 all.push(self.truncate_embedding(d.embedding)?);
294 }
295 }
296
297 Ok(all)
298 }
299
300 fn truncate_embedding(&self, embedding: Vec<f32>) -> Result<Vec<f32>, AppError> {
301 if embedding.len() < self.dim {
302 return Err(AppError::Embedding(format!(
303 "embedding dimension {} < requested {}",
304 embedding.len(),
305 self.dim
306 )));
307 }
308 if embedding.len() == self.dim {
309 Ok(embedding)
310 } else {
311 Ok(embedding[..self.dim].to_vec())
312 }
313 }
314
315 async fn execute_with_retry(
322 &self,
323 request: &EmbeddingRequest<'_>,
324 ) -> Result<EmbeddingResponse, EmbedError> {
325 let mut last_err: Option<EmbedError> = None;
326
327 for attempt in 0..crate::openrouter_http::MAX_RETRIES {
328 let result = self
329 .client
330 .post(&self.base_url)
331 .header(
332 "Authorization",
333 format!("Bearer {}", self.api_key.expose_secret()),
334 )
335 .json(request)
336 .send()
337 .await;
338
339 let resp = match result {
340 Ok(r) => r,
341 Err(e) if e.is_timeout() => {
342 return Err(EmbedError::new(
343 AppError::Embedding("OpenRouter request timed out".into()),
344 AttemptOutcome::Transient,
345 ));
346 }
347 Err(e) => {
348 last_err = Some(EmbedError::new(
349 AppError::Embedding(format!("HTTP request failed: {e}")),
350 AttemptOutcome::Transient,
351 ));
352 crate::openrouter_http::backoff(attempt).await;
353 continue;
354 }
355 };
356
357 let status = resp.status();
358
359 if status.is_success() {
360 let body = resp.text().await.map_err(|e| {
361 EmbedError::new(
362 AppError::Embedding(format!("failed to read response body: {e}")),
363 AttemptOutcome::Transient,
364 )
365 })?;
366 match serde_json::from_str::<EmbeddingEnvelope>(&body) {
367 Ok(env) => {
368 if let Some(api_err) = env.error {
373 let retry_class =
374 crate::openrouter_http::provider_error_retry_class(&api_err);
375 return Err(EmbedError::new(
376 AppError::ProviderError {
377 code: api_err.code_string(),
378 message: api_err.message,
379 },
380 retry_class,
381 ));
382 }
383 match env.data {
384 Some(data) => return Ok(EmbeddingResponse { data }),
385 None => {
386 tracing::warn!(
387 attempt,
388 body_len = body.len(),
389 "HTTP 200 with neither data nor error (retrying)"
390 );
391 last_err = Some(EmbedError::new(
392 AppError::Embedding(
393 "OpenRouter 200 response had neither data nor error".into(),
394 ),
395 AttemptOutcome::Transient,
396 ));
397 crate::openrouter_http::backoff(attempt).await;
398 continue;
399 }
400 }
401 }
402 Err(e) => {
403 tracing::warn!(
404 attempt,
405 body_len = body.len(),
406 "HTTP 200 but JSON unparseable (retrying): {e}"
407 );
408 last_err = Some(EmbedError::new(
409 AppError::Embedding(format!("failed to parse embedding response: {e}")),
410 AttemptOutcome::Transient,
411 ));
412 crate::openrouter_http::backoff(attempt).await;
413 continue;
414 }
415 }
416 }
417
418 if status.as_u16() == 401 {
419 return Err(EmbedError::new(
420 AppError::Embedding("invalid OpenRouter API key (HTTP 401)".into()),
421 AttemptOutcome::HardFailure,
422 ));
423 }
424
425 if status.as_u16() == 400 || status.as_u16() == 404 {
426 let body = resp.text().await.unwrap_or_default();
427 return Err(EmbedError::new(
428 AppError::Embedding(format!("OpenRouter returned {status}: {body}")),
429 AttemptOutcome::HardFailure,
430 ));
431 }
432
433 if status.as_u16() == 429 {
434 let retry_after = resp
435 .headers()
436 .get("retry-after")
437 .and_then(|v| v.to_str().ok())
438 .and_then(|v| v.parse::<u64>().ok())
439 .unwrap_or(2);
440 tracing::warn!(
441 attempt,
442 retry_after_secs = retry_after,
443 "OpenRouter rate limited, waiting"
444 );
445 last_err = Some(EmbedError::new(
450 AppError::RateLimited {
451 detail: format!("OpenRouter HTTP 429 (retry-after {retry_after}s)"),
452 },
453 AttemptOutcome::Transient,
454 ));
455 tokio::time::sleep(Duration::from_secs(retry_after)).await;
456 continue;
457 }
458
459 if status.is_server_error() {
460 tracing::warn!(attempt, status = %status, "OpenRouter server error, retrying");
461 last_err = Some(EmbedError::new(
462 AppError::Embedding(format!("OpenRouter server error: {status}")),
463 AttemptOutcome::Transient,
464 ));
465 crate::openrouter_http::backoff(attempt).await;
466 continue;
467 }
468
469 let body = resp.text().await.unwrap_or_default();
470 return Err(EmbedError::new(
471 AppError::Embedding(format!("unexpected HTTP {status}: {body}")),
472 crate::openrouter_http::status_retry_class(status),
473 ));
474 }
475
476 Err(last_err.unwrap_or_else(|| {
481 EmbedError::new(
482 AppError::Embedding("max retries exceeded for OpenRouter request".into()),
483 AttemptOutcome::Transient,
484 )
485 }))
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

    /// Placeholder key; these tests never reach the real API.
    fn test_key() -> SecretBox<String> {
        SecretBox::new(Box::new("test-key".to_string()))
    }

    /// Starts a mock server that answers every POST with `response`.
    async fn server_answering(response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    /// Builds a 384-dimension client whose endpoint is `<mock>/embeddings`.
    async fn client_for(server: &MockServer, model: &str) -> OpenRouterClient {
        let url = format!("{}/embeddings", server.uri());
        OpenRouterClient::new_with_url(test_key(), model.to_string(), 384, url)
            .expect("test client builds")
    }

    #[test]
    fn test_supports_mrl_detection() {
        let supported = [
            "qwen/qwen3-embedding-8b",
            "qwen/qwen3-embedding-4b",
            "openai/text-embedding-3-small",
            "openai/text-embedding-3-large",
            "google/gemini-embedding-001",
            "google/gemini-embedding-2",
            "nvidia/llama-nemotron-embed-vl-1b-v2:free",
            "baai/bge-m3",
        ];
        let unsupported = [
            "perplexity/pplx-embed-v1-0.6b",
            "mistralai/mistral-embed-2312",
            "some-random-model",
        ];

        for model in supported {
            assert!(model_supports_mrl(model), "expected MRL support for {model}");
        }
        for model in unsupported {
            assert!(!model_supports_mrl(model), "expected no MRL support for {model}");
        }
    }

    #[test]
    fn test_model_default_input_type() {
        let cases = [
            ("nvidia/llama-nemotron-embed-vl-1b-v2:free", Some("passage")),
            ("mistralai/mistral-embed-2312", None),
            ("qwen/qwen3-embedding-8b", Some("search_document")),
            ("openai/text-embedding-3-small", Some("search_document")),
            ("baai/bge-m3", Some("search_document")),
        ];
        for (model, expected) in cases {
            assert_eq!(model_default_input_type(model), expected, "model {model}");
        }
    }

    #[test]
    fn test_truncate_embedding() {
        let client = OpenRouterClient::new(test_key(), "test-model".into(), 3).unwrap();

        let from_longer = client.truncate_embedding(vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        assert_eq!(from_longer, vec![1.0, 2.0, 3.0]);

        let from_exact = client.truncate_embedding(vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(from_exact, vec![1.0, 2.0, 3.0]);

        assert!(client.truncate_embedding(vec![1.0, 2.0]).is_err());
    }

    #[test]
    fn embedding_envelope_surfaces_provider_error_not_missing_field() {
        let body = r#"{"error":{"code":400,"message":"context length exceeded"}}"#;

        // Precondition: parsing straight into `EmbeddingResponse` reports the
        // error body only as a missing `data` field.
        let legacy_err = serde_json::from_str::<EmbeddingResponse>(body)
            .err()
            .expect("legacy parse should have failed on an error body")
            .to_string();
        assert!(
            legacy_err.contains("missing field"),
            "precondition: legacy parse masks the cause as a missing field: {legacy_err}"
        );

        // The envelope keeps the provider's own error instead.
        let env: EmbeddingEnvelope =
            serde_json::from_str(body).expect("envelope parses an error body");
        assert!(env.data.is_none());
        let api_err = env.error.expect("error object captured");
        assert_eq!(api_err.message, "context length exceeded");
        assert_eq!(api_err.code_string(), "400");
    }

    #[test]
    fn embedding_envelope_parses_success_body() {
        let env: EmbeddingEnvelope =
            serde_json::from_str(r#"{"data":[{"embedding":[1.0,2.0,3.0],"index":0}]}"#)
                .expect("envelope parses a success body");
        assert!(env.error.is_none());
        let data = env.data.expect("data present");
        assert_eq!(data.len(), 1);
        assert_eq!(data[0].embedding, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn api_error_code_string_handles_number_string_and_missing() {
        let cases = [
            (r#"{"code":429,"message":"slow down"}"#, "429"),
            (r#"{"code":"rate_limited","message":"slow down"}"#, "rate_limited"),
            (r#"{"message":"oops"}"#, "unknown"),
        ];
        for (json, expected) in cases {
            let parsed: ApiError = serde_json::from_str(json).unwrap();
            assert_eq!(parsed.code_string(), expected);
        }
    }

    #[tokio::test]
    async fn embed_single_rejects_oversized_input_before_request() {
        let client =
            OpenRouterClient::new(test_key(), "qwen/qwen3-embedding-8b".into(), 384).unwrap();
        let big = "word ".repeat(crate::constants::EMBEDDING_REQUEST_MAX_TOKENS + 5_000);

        let (msg, retry_class) = match client.embed_single(&big, None).await {
            Err(EmbedError {
                source: AppError::Validation(msg),
                retry_class,
            }) => (msg, retry_class),
            other => unreachable!("expected Validation before request, got: {other:?}"),
        };
        assert!(msg.contains("tokens"));
        assert_eq!(
            retry_class,
            AttemptOutcome::HardFailure,
            "an oversized input is a permanent client error"
        );
    }

    #[tokio::test]
    async fn embed_single_401_is_hard_failure() {
        let server = server_answering(ResponseTemplate::new(401)).await;
        let client = client_for(&server, "qwen/qwen3-embedding-8b").await;

        let err = client
            .embed_single("hello", None)
            .await
            .expect_err("401 is an error");
        assert_eq!(err.retry_class, AttemptOutcome::HardFailure);
    }

    #[tokio::test]
    async fn embed_single_exhausted_5xx_is_transient() {
        let server = server_answering(ResponseTemplate::new(503)).await;
        let client = client_for(&server, "qwen/qwen3-embedding-8b").await;

        let err = client
            .embed_single("hello", None)
            .await
            .expect_err("persistent 5xx exhausts retries");
        assert_eq!(err.retry_class, AttemptOutcome::Transient);
    }

    #[tokio::test]
    async fn embed_single_provider_error_code_classifies_by_code_not_message() {
        let provider_error = serde_json::json!({
            "error": { "code": "context_length_exceeded", "message": "too many tokens" }
        });
        let server =
            server_answering(ResponseTemplate::new(200).set_body_json(provider_error)).await;
        let client = client_for(&server, "qwen/qwen3-embedding-8b").await;

        let err = client
            .embed_single("hello", None)
            .await
            .expect_err("provider error must surface");
        assert_eq!(err.retry_class, AttemptOutcome::HardFailure);
    }
}