1use crate::errors::AppError;
10use secrecy::{ExposeSecret, SecretBox};
11use serde::{Deserialize, Serialize};
12use std::time::Duration;
13
/// Endpoint every embedding request in this module is POSTed to.
const OPENROUTER_EMBEDDINGS_URL: &str = "https://openrouter.ai/api/v1/embeddings";
/// Overall per-request timeout, in seconds, configured on the HTTP client.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Connect timeout, in seconds, configured on the HTTP client.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Maximum number of input texts sent in a single batch request.
const MAX_BATCH_SIZE: usize = 32;
/// Number of attempts made per request. The retry loop iterates over
/// `0..MAX_RETRIES`, so this count includes the very first try.
const MAX_RETRIES: u32 = 4;
19
/// JSON body serialized into each request sent to the embeddings endpoint.
#[derive(Serialize)]
struct EmbeddingRequest<'a> {
    /// Identifier of the embedding model to run.
    model: &'a str,
    /// The text, or batch of texts, to embed.
    input: EmbeddingInput<'a>,
    /// Requested output dimensionality; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    /// Encoding requested for the returned vectors (this file always sends `"float"`).
    encoding_format: &'a str,
    /// Optional input-type hint; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    input_type: Option<&'a str>,
}
30
/// Payload of the request's `input` field.
///
/// Because the enum is `untagged`, the variant name never appears in the JSON:
/// `Single` serializes as a bare string and `Batch` as an array of strings.
#[derive(Serialize)]
#[serde(untagged)]
enum EmbeddingInput<'a> {
    /// Exactly one text.
    Single(&'a str),
    /// Several texts embedded in one request.
    Batch(Vec<&'a str>),
}
37
/// The part of the embeddings response body this module reads.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// One entry per embedded input.
    data: Vec<EmbeddingData>,
}
42
/// A single embedding returned inside [`EmbeddingResponse::data`].
#[derive(Deserialize)]
struct EmbeddingData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Index reported by the API; batch results are sorted by it before being returned.
    // NOTE(review): presumably the position of the matching input within the
    // request batch — confirm against the API contract.
    index: usize,
}
48
/// Client that requests text embeddings from the OpenRouter embeddings endpoint.
pub struct OpenRouterClient {
    /// HTTP client built once in `new` with the module's timeouts and user agent.
    client: reqwest::Client,
    /// API key, kept wrapped and only exposed while building the Authorization header.
    api_key: SecretBox<String>,
    /// Model identifier sent with every request.
    model: String,
    /// Length every returned embedding is cut down to (shorter vectors are rejected).
    dim: usize,
    /// Result of `model_supports_mrl` for `model`; when true, `dim` is also sent
    /// to the API as the `dimensions` request field.
    supports_mrl: bool,
    /// Result of `model_default_input_type` for `model`, exposed through a getter.
    default_input_type: Option<&'static str>,
}
57
/// Reports whether `model` belongs to a family flagged as MRL-capable.
///
/// The check is a plain substring match against a fixed list of family names,
/// so provider prefixes (`qwen/…`) and suffixes (`…:free`) do not matter.
// NOTE(review): "MRL" presumably stands for Matryoshka Representation
// Learning (embeddings that stay usable when shortened) — confirm.
fn model_supports_mrl(model: &str) -> bool {
    const MRL_FAMILIES: [&str; 5] = [
        "qwen3-embedding",
        "text-embedding-3",
        "gemini-embedding",
        "llama-nemotron-embed",
        "bge-m3",
    ];

    MRL_FAMILIES.iter().any(|family| model.contains(*family))
}
65
/// Picks the default `input_type` value for `model`, if it has one.
///
/// Matching is by substring and the first rule that applies wins:
/// `llama-nemotron-embed` models get `"passage"`, `mistral-embed` models get
/// no input type at all, and every other model gets `"search_document"`.
fn model_default_input_type(model: &str) -> Option<&'static str> {
    match model {
        name if name.contains("llama-nemotron-embed") => Some("passage"),
        name if name.contains("mistral-embed") => None,
        _ => Some("search_document"),
    }
}
75
76impl OpenRouterClient {
77 pub fn new(api_key: SecretBox<String>, model: String, dim: usize) -> Result<Self, AppError> {
78 let client = reqwest::Client::builder()
79 .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
80 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
81 .user_agent("sqlite-graphrag/1.0.93")
82 .build()
83 .map_err(|e| AppError::Embedding(format!("failed to build HTTP client: {e}")))?;
84
85 let supports_mrl = model_supports_mrl(&model);
86 let default_input_type = model_default_input_type(&model);
87
88 Ok(Self {
89 client,
90 api_key,
91 model,
92 dim,
93 supports_mrl,
94 default_input_type,
95 })
96 }
97
98 pub fn default_input_type(&self) -> Option<&'static str> {
99 self.default_input_type
100 }
101
102 pub async fn embed_single(
103 &self,
104 text: &str,
105 input_type: Option<&str>,
106 ) -> Result<Vec<f32>, AppError> {
107 let request = EmbeddingRequest {
108 model: &self.model,
109 input: EmbeddingInput::Single(text),
110 dimensions: if self.supports_mrl {
111 Some(self.dim)
112 } else {
113 None
114 },
115 encoding_format: "float",
116 input_type,
117 };
118
119 let response = self.execute_with_retry(&request).await?;
120
121 let embedding = response
122 .data
123 .into_iter()
124 .next()
125 .ok_or_else(|| AppError::Embedding("empty response from OpenRouter".into()))?
126 .embedding;
127
128 self.truncate_embedding(embedding)
129 }
130
131 pub async fn embed_batch(
132 &self,
133 texts: &[&str],
134 input_type: Option<&str>,
135 ) -> Result<Vec<Vec<f32>>, AppError> {
136 if texts.is_empty() {
137 return Ok(Vec::new());
138 }
139
140 let mut all = Vec::with_capacity(texts.len());
141
142 for chunk in texts.chunks(MAX_BATCH_SIZE) {
143 let request = EmbeddingRequest {
144 model: &self.model,
145 input: EmbeddingInput::Batch(chunk.to_vec()),
146 dimensions: if self.supports_mrl {
147 Some(self.dim)
148 } else {
149 None
150 },
151 encoding_format: "float",
152 input_type,
153 };
154
155 let response = self.execute_with_retry(&request).await?;
156
157 if response.data.len() != chunk.len() {
158 return Err(AppError::Embedding(format!(
159 "expected {} embeddings, got {}",
160 chunk.len(),
161 response.data.len()
162 )));
163 }
164
165 let mut sorted = response.data;
166 sorted.sort_by_key(|d| d.index);
167
168 for d in sorted {
169 all.push(self.truncate_embedding(d.embedding)?);
170 }
171 }
172
173 Ok(all)
174 }
175
176 fn truncate_embedding(&self, embedding: Vec<f32>) -> Result<Vec<f32>, AppError> {
177 if embedding.len() < self.dim {
178 return Err(AppError::Embedding(format!(
179 "embedding dimension {} < requested {}",
180 embedding.len(),
181 self.dim
182 )));
183 }
184 if embedding.len() == self.dim {
185 Ok(embedding)
186 } else {
187 Ok(embedding[..self.dim].to_vec())
188 }
189 }
190
191 async fn execute_with_retry(
192 &self,
193 request: &EmbeddingRequest<'_>,
194 ) -> Result<EmbeddingResponse, AppError> {
195 let mut last_err = None;
196
197 for attempt in 0..MAX_RETRIES {
198 let result = self
199 .client
200 .post(OPENROUTER_EMBEDDINGS_URL)
201 .header(
202 "Authorization",
203 format!("Bearer {}", self.api_key.expose_secret()),
204 )
205 .json(request)
206 .send()
207 .await;
208
209 let resp = match result {
210 Ok(r) => r,
211 Err(e) if e.is_timeout() => {
212 return Err(AppError::Embedding("OpenRouter request timed out".into()));
213 }
214 Err(e) => {
215 last_err = Some(AppError::Embedding(format!("HTTP request failed: {e}")));
216 Self::backoff(attempt).await;
217 continue;
218 }
219 };
220
221 let status = resp.status();
222
223 if status.is_success() {
224 let body = resp.text().await.map_err(|e| {
225 AppError::Embedding(format!("failed to read response body: {e}"))
226 })?;
227 match serde_json::from_str::<EmbeddingResponse>(&body) {
228 Ok(parsed) => return Ok(parsed),
229 Err(e) => {
230 tracing::warn!(
231 attempt,
232 body_len = body.len(),
233 "HTTP 200 but parse failed (retrying): {e}"
234 );
235 last_err = Some(AppError::Embedding(format!(
236 "failed to parse embedding response: {e}"
237 )));
238 Self::backoff(attempt).await;
239 continue;
240 }
241 }
242 }
243
244 if status.as_u16() == 401 {
245 return Err(AppError::Embedding(
246 "invalid OpenRouter API key (HTTP 401)".into(),
247 ));
248 }
249
250 if status.as_u16() == 400 || status.as_u16() == 404 {
251 let body = resp.text().await.unwrap_or_default();
252 return Err(AppError::Embedding(format!(
253 "OpenRouter returned {status}: {body}"
254 )));
255 }
256
257 if status.as_u16() == 429 {
258 let retry_after = resp
259 .headers()
260 .get("retry-after")
261 .and_then(|v| v.to_str().ok())
262 .and_then(|v| v.parse::<u64>().ok())
263 .unwrap_or(2);
264 tracing::warn!(
265 attempt,
266 retry_after_secs = retry_after,
267 "OpenRouter rate limited, waiting"
268 );
269 tokio::time::sleep(Duration::from_secs(retry_after)).await;
270 continue;
271 }
272
273 if status.is_server_error() {
274 tracing::warn!(attempt, status = %status, "OpenRouter server error, retrying");
275 last_err = Some(AppError::Embedding(format!(
276 "OpenRouter server error: {status}"
277 )));
278 Self::backoff(attempt).await;
279 continue;
280 }
281
282 let body = resp.text().await.unwrap_or_default();
283 return Err(AppError::Embedding(format!(
284 "unexpected HTTP {status}: {body}"
285 )));
286 }
287
288 Err(last_err.unwrap_or_else(|| {
289 AppError::Embedding("max retries exceeded for OpenRouter request".into())
290 }))
291 }
292
293 async fn backoff(attempt: u32) {
294 let base_ms = 1000u64 * 2u64.pow(attempt);
295 let jitter = fastrand::u64(0..500);
296 let sleep_ms = base_ms + jitter;
297 tracing::debug!(attempt, sleep_ms, "exponential backoff");
298 tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Known MRL-capable families are detected; unrelated models are not.
    #[test]
    fn test_supports_mrl_detection() {
        let mrl_models = [
            "qwen/qwen3-embedding-8b",
            "qwen/qwen3-embedding-4b",
            "openai/text-embedding-3-small",
            "openai/text-embedding-3-large",
            "google/gemini-embedding-001",
            "google/gemini-embedding-2",
            "nvidia/llama-nemotron-embed-vl-1b-v2:free",
            "baai/bge-m3",
        ];
        for name in mrl_models {
            assert!(model_supports_mrl(name), "{name}");
        }

        let other_models = [
            "perplexity/pplx-embed-v1-0.6b",
            "mistralai/mistral-embed-2312",
            "some-random-model",
        ];
        for name in other_models {
            assert!(!model_supports_mrl(name), "{name}");
        }
    }

    /// Each model family maps to its expected default input type.
    #[test]
    fn test_model_default_input_type() {
        let cases = [
            ("nvidia/llama-nemotron-embed-vl-1b-v2:free", Some("passage")),
            ("mistralai/mistral-embed-2312", None),
            ("qwen/qwen3-embedding-8b", Some("search_document")),
            ("openai/text-embedding-3-small", Some("search_document")),
            ("baai/bge-m3", Some("search_document")),
        ];
        for (name, expected) in cases {
            assert_eq!(model_default_input_type(name), expected, "{name}");
        }
    }

    /// Longer vectors are cut to `dim`, exact ones pass through, shorter ones fail.
    #[test]
    fn test_truncate_embedding() {
        let key = SecretBox::new(Box::new("test-key".to_string()));
        let client = OpenRouterClient::new(key, "test-model".into(), 3).unwrap();

        let shortened = client
            .truncate_embedding(vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .unwrap();
        assert_eq!(shortened, vec![1.0, 2.0, 3.0]);

        let unchanged = client.truncate_embedding(vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(unchanged, vec![1.0, 2.0, 3.0]);

        let too_short = client.truncate_embedding(vec![1.0, 2.0]);
        assert!(too_short.is_err());
    }
}