1use crate::errors::AppError;
17use secrecy::{ExposeSecret, SecretBox};
18use serde::{Deserialize, Serialize};
19use std::time::Duration;
20
/// Public OpenRouter chat-completions endpoint used by `OpenRouterChatClient::new`.
const OPENROUTER_CHAT_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Whole-request timeout, in seconds, applied when the caller passes `0`.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Connection-establishment timeout, in seconds, for the HTTP client.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Upper bound of the send loop: total attempts per request, first try included.
const MAX_RETRIES: u32 = 4;

/// Value of the `name` field inside the `json_schema` section of each request.
const SCHEMA_NAME: &str = "enrich_output";
29
30#[derive(Serialize)]
31struct ChatRequest<'a> {
32 model: &'a str,
33 messages: Vec<ChatMessage<'a>>,
34 response_format: ResponseFormat,
35 provider: ProviderPrefs,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 reasoning: Option<ReasoningPrefs>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 max_tokens: Option<u32>,
40}
41
42#[derive(Serialize)]
43struct ChatMessage<'a> {
44 role: &'a str,
45 content: String,
46}
47
48#[derive(Serialize)]
49struct ResponseFormat {
50 #[serde(rename = "type")]
51 format_type: &'static str,
52 json_schema: JsonSchemaSpec,
53}
54
55#[derive(Serialize)]
56struct JsonSchemaSpec {
57 name: &'static str,
58 strict: bool,
59 schema: serde_json::Value,
60}
61
62#[derive(Serialize)]
63struct ProviderPrefs {
64 require_parameters: bool,
65}
66
67#[derive(Serialize)]
68struct ReasoningPrefs {
69 enabled: bool,
70}
71
72#[derive(Deserialize)]
73struct ChatResponse {
74 #[serde(default)]
75 choices: Vec<Choice>,
76 #[serde(default)]
77 usage: Option<Usage>,
78}
79
80#[derive(Deserialize)]
81struct Choice {
82 message: RespMessage,
83}
84
85#[derive(Deserialize)]
86struct RespMessage {
87 #[serde(default)]
88 content: Option<String>,
89}
90
91#[derive(Deserialize)]
92struct Usage {
93 #[serde(default)]
94 cost: Option<f64>,
95}
96
97pub struct OpenRouterChatClient {
100 client: reqwest::Client,
101 api_key: SecretBox<String>,
102 model: String,
103 base_url: String,
107}
108
109impl OpenRouterChatClient {
110 pub fn new(
115 api_key: SecretBox<String>,
116 model: String,
117 timeout_secs: u64,
118 ) -> Result<Self, AppError> {
119 let timeout_secs = if timeout_secs == 0 {
120 DEFAULT_TIMEOUT_SECS
121 } else {
122 timeout_secs
123 };
124 let client = reqwest::Client::builder()
125 .timeout(Duration::from_secs(timeout_secs))
126 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
127 .user_agent("sqlite-graphrag/1.0.95")
128 .build()
129 .map_err(|e| AppError::Validation(format!("failed to build HTTP client: {e}")))?;
130
131 Ok(Self {
132 client,
133 api_key,
134 model,
135 base_url: OPENROUTER_CHAT_URL.to_string(),
136 })
137 }
138
/// Test-only constructor: identical to `new`, except that requests are
/// posted to `base_url` instead of the public endpoint.
#[cfg(test)]
pub fn new_with_url(
    api_key: SecretBox<String>,
    model: String,
    base_url: String,
    timeout_secs: u64,
) -> Result<Self, AppError> {
    Self::new(api_key, model, timeout_secs).map(|built| Self { base_url, ..built })
}
153
154 pub fn model(&self) -> &str {
156 &self.model
157 }
158
159 pub async fn complete(
168 &self,
169 system_prompt: &str,
170 input_text: &str,
171 schema_str: &str,
172 max_tokens: Option<u32>,
173 ) -> Result<(serde_json::Value, f64, bool), AppError> {
174 let schema: serde_json::Value = serde_json::from_str(schema_str).map_err(|e| {
175 AppError::Validation(format!("invalid JSON schema for OpenRouter request: {e}"))
176 })?;
177
178 let primary = self.build_request(
185 schema.clone(),
186 system_prompt,
187 input_text,
188 max_tokens,
189 Some(ReasoningPrefs { enabled: false }),
190 );
191 let response = match self.execute_with_retry(&primary).await {
192 Ok(r) => r,
193 Err(first_err) => {
194 if reasoning_disable_rejected(&first_err) {
195 tracing::warn!(
196 model = %self.model,
197 "model rejected reasoning.enabled=false (mandatory); \
198 retrying once with reasoning omitted"
199 );
200 let fallback =
201 self.build_request(schema, system_prompt, input_text, max_tokens, None);
202 match self.execute_with_retry(&fallback).await {
203 Ok(r) => r,
204 Err(_) => return Err(first_err),
205 }
206 } else {
207 return Err(first_err);
208 }
209 }
210 };
211
212 let content = response
213 .choices
214 .into_iter()
215 .next()
216 .and_then(|c| c.message.content)
217 .filter(|c| !c.trim().is_empty())
218 .ok_or_else(|| {
219 AppError::Validation(format!(
220 "model '{}' returned no structured content (incompatible with \
221 structured outputs, or refused the request)",
222 self.model
223 ))
224 })?;
225
226 let value: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
227 AppError::Validation(format!(
228 "model '{}' returned non-JSON content despite strict schema: {e}",
229 self.model
230 ))
231 })?;
232
233 let cost = response.usage.and_then(|u| u.cost).unwrap_or(0.0);
234
235 Ok((value, cost, false))
236 }
237
238 fn build_request<'a>(
242 &'a self,
243 schema: serde_json::Value,
244 system_prompt: &str,
245 input_text: &str,
246 max_tokens: Option<u32>,
247 reasoning: Option<ReasoningPrefs>,
248 ) -> ChatRequest<'a> {
249 let mut messages = Vec::with_capacity(2);
250 messages.push(ChatMessage {
251 role: "system",
252 content: system_prompt.to_string(),
253 });
254 if !input_text.is_empty() {
255 messages.push(ChatMessage {
256 role: "user",
257 content: input_text.to_string(),
258 });
259 }
260 ChatRequest {
261 model: &self.model,
262 messages,
263 response_format: ResponseFormat {
264 format_type: "json_schema",
265 json_schema: JsonSchemaSpec {
266 name: SCHEMA_NAME,
267 strict: true,
268 schema,
269 },
270 },
271 provider: ProviderPrefs {
272 require_parameters: true,
273 },
274 reasoning,
275 max_tokens,
276 }
277 }
278
279 async fn execute_with_retry(
280 &self,
281 request: &ChatRequest<'_>,
282 ) -> Result<ChatResponse, AppError> {
283 let mut last_err = None;
284
285 for attempt in 0..MAX_RETRIES {
286 let result = self
287 .client
288 .post(&self.base_url)
289 .header(
290 "Authorization",
291 format!("Bearer {}", self.api_key.expose_secret()),
292 )
293 .json(request)
294 .send()
295 .await;
296
297 let resp = match result {
298 Ok(r) => r,
299 Err(e) if e.is_timeout() => {
300 return Err(AppError::Validation(
301 "OpenRouter chat request timed out".into(),
302 ));
303 }
304 Err(e) => {
305 last_err = Some(AppError::Validation(format!("HTTP request failed: {e}")));
306 Self::backoff(attempt).await;
307 continue;
308 }
309 };
310
311 let status = resp.status();
312
313 if status.is_success() {
314 let body = resp.text().await.map_err(|e| {
315 AppError::Validation(format!("failed to read response body: {e}"))
316 })?;
317 match serde_json::from_str::<ChatResponse>(&body) {
318 Ok(parsed) => return Ok(parsed),
319 Err(e) => {
320 tracing::warn!(
321 attempt,
322 body_len = body.len(),
323 "HTTP 200 but parse failed (retrying): {e}"
324 );
325 last_err = Some(AppError::Validation(format!(
326 "failed to parse chat response: {e}"
327 )));
328 Self::backoff(attempt).await;
329 continue;
330 }
331 }
332 }
333
334 if status.as_u16() == 401 {
335 return Err(AppError::Validation(
336 "invalid OpenRouter API key (HTTP 401)".into(),
337 ));
338 }
339
340 if status.as_u16() == 400 || status.as_u16() == 404 {
341 let body = resp.text().await.unwrap_or_default();
342 return Err(AppError::Validation(format!(
343 "OpenRouter returned {status} for model '{}': {body}",
344 self.model
345 )));
346 }
347
348 if status.as_u16() == 429 {
349 let retry_after = resp
350 .headers()
351 .get("retry-after")
352 .and_then(|v| v.to_str().ok())
353 .and_then(|v| v.parse::<u64>().ok())
354 .unwrap_or(2);
355 tracing::warn!(
356 attempt,
357 retry_after_secs = retry_after,
358 "OpenRouter rate limited, waiting"
359 );
360 tokio::time::sleep(Duration::from_secs(retry_after)).await;
361 continue;
362 }
363
364 if status.is_server_error() {
365 tracing::warn!(attempt, status = %status, "OpenRouter server error, retrying");
366 last_err = Some(AppError::Validation(format!(
367 "OpenRouter server error: {status}"
368 )));
369 Self::backoff(attempt).await;
370 continue;
371 }
372
373 let body = resp.text().await.unwrap_or_default();
374 return Err(AppError::Validation(format!(
375 "unexpected HTTP {status}: {body}"
376 )));
377 }
378
379 Err(last_err.unwrap_or_else(|| {
380 AppError::Validation("max retries exceeded for OpenRouter chat request".into())
381 }))
382 }
383
384 async fn backoff(attempt: u32) {
385 let base_ms = 1000u64 * 2u64.pow(attempt);
386 let jitter = fastrand::u64(0..500);
387 let sleep_ms = base_ms + jitter;
388 tracing::debug!(attempt, sleep_ms, "exponential backoff");
389 tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
390 }
391}
392
393fn reasoning_disable_rejected(err: &AppError) -> bool {
398 let msg = err.to_string().to_lowercase();
399 msg.contains("400") && msg.contains("reasoning")
400}
401
402#[cfg(test)]
403mod tests {
404 use super::*;
405 use serde_json::json;
406 use wiremock::matchers::{body_partial_json, method, path};
407 use wiremock::{Mock, MockServer, ResponseTemplate};
408
409 const TEST_SCHEMA: &str = r#"{"type":"object"}"#;
410
411 fn key() -> SecretBox<String> {
412 SecretBox::new(Box::new("test-key".to_string()))
413 }
414
415 fn success_body(content: &str, cost: Option<f64>) -> serde_json::Value {
419 let mut body = json!({
420 "choices": [{ "message": { "content": content } }]
421 });
422 if let Some(c) = cost {
423 body["usage"] = json!({ "cost": c });
424 }
425 body
426 }
427
428 async fn client_for(server: &MockServer, model: &str) -> OpenRouterChatClient {
429 OpenRouterChatClient::new_with_url(
430 key(),
431 model.to_string(),
432 format!("{}/chat/completions", server.uri()),
433 30,
434 )
435 .expect("test client builds")
436 }
437
/// `new` succeeds and the accessor returns the model it was given.
#[test]
fn new_builds_client_and_binds_model() {
    let built = OpenRouterChatClient::new(key(), "z-ai/glm-5.2".to_string(), 30);
    let client = built.expect("client builds");
    assert_eq!(client.model(), "z-ai/glm-5.2");
}
444
/// `new` points the client at the public endpoint constant.
#[test]
fn new_defaults_base_url_to_public_endpoint() {
    let built = OpenRouterChatClient::new(key(), "z-ai/glm-5.2".to_string(), 30);
    let client = built.expect("client builds");
    assert_eq!(client.base_url, OPENROUTER_CHAT_URL);
}
451
/// The serialized payload carries the strict schema, provider preference and
/// disabled reasoning, and leaves out `max_tokens` when it is `None`.
#[test]
fn request_serializes_with_strict_schema_and_disabled_reasoning() {
    let system_turn = ChatMessage {
        role: "system",
        content: "extract".to_string(),
    };
    let json_schema = JsonSchemaSpec {
        name: SCHEMA_NAME,
        strict: true,
        schema: serde_json::json!({"type": "object"}),
    };
    let request = ChatRequest {
        model: "deepseek/deepseek-v4-flash",
        messages: vec![system_turn],
        response_format: ResponseFormat {
            format_type: "json_schema",
            json_schema,
        },
        provider: ProviderPrefs {
            require_parameters: true,
        },
        reasoning: Some(ReasoningPrefs { enabled: false }),
        max_tokens: None,
    };

    let json = serde_json::to_value(&request).expect("serializes");

    assert_eq!(json["response_format"]["type"], "json_schema");
    assert_eq!(json["response_format"]["json_schema"]["strict"], true);
    assert_eq!(json["provider"]["require_parameters"], true);
    assert_eq!(json["reasoning"]["enabled"], false);
    // `None` must be dropped from the payload, not serialized as null.
    assert!(json.get("max_tokens").is_none());
}
482
483 #[tokio::test]
484 async fn complete_sends_wellformed_request_and_parses_content() {
485 let server = MockServer::start().await;
486 Mock::given(method("POST"))
487 .and(path("/chat/completions"))
488 .and(body_partial_json(json!({
489 "model": "deepseek/deepseek-v4-flash",
490 "response_format": {
491 "type": "json_schema",
492 "json_schema": { "name": "enrich_output", "strict": true }
493 },
494 "provider": { "require_parameters": true },
495 "reasoning": { "enabled": false }
496 })))
497 .respond_with(ResponseTemplate::new(200).set_body_json(success_body(
498 r#"{"entities":[],"relationships":[]}"#,
499 Some(0.0023),
500 )))
501 .expect(1)
502 .mount(&server)
503 .await;
504
505 let client = client_for(&server, "deepseek/deepseek-v4-flash").await;
506 let (value, cost, is_oauth) = client
507 .complete("system", "input", TEST_SCHEMA, None)
508 .await
509 .expect("completion succeeds");
510
511 assert_eq!(value, json!({"entities": [], "relationships": []}));
512 assert!((cost - 0.0023).abs() < f64::EPSILON);
513 assert!(!is_oauth);
514 }
515
516 #[tokio::test]
517 async fn complete_defaults_cost_to_zero_when_usage_absent() {
518 let server = MockServer::start().await;
519 Mock::given(method("POST"))
520 .respond_with(
521 ResponseTemplate::new(200).set_body_json(success_body(r#"{"entities":[]}"#, None)),
522 )
523 .mount(&server)
524 .await;
525
526 let client = client_for(&server, "z-ai/glm-5.2").await;
527 let (_, cost, _) = client
528 .complete("system", "", TEST_SCHEMA, Some(4096))
529 .await
530 .expect("completion succeeds");
531 assert_eq!(cost, 0.0);
532 }
533
534 #[tokio::test]
535 async fn complete_retries_on_429_honouring_retry_after() {
536 let server = MockServer::start().await;
537 Mock::given(method("POST"))
538 .respond_with(ResponseTemplate::new(429).insert_header("retry-after", "1"))
539 .up_to_n_times(1)
540 .expect(1)
541 .mount(&server)
542 .await;
543 Mock::given(method("POST"))
544 .respond_with(
545 ResponseTemplate::new(200).set_body_json(success_body(r#"{"ok":true}"#, Some(0.0))),
546 )
547 .expect(1)
548 .mount(&server)
549 .await;
550
551 let client = client_for(&server, "minimax/minimax-m3").await;
552 let (value, _, _) = client
553 .complete("system", "input", TEST_SCHEMA, None)
554 .await
555 .expect("retried completion succeeds");
556 assert_eq!(value, json!({"ok": true}));
557 }
558
559 #[tokio::test]
560 async fn complete_retries_on_5xx_with_backoff() {
561 let server = MockServer::start().await;
562 Mock::given(method("POST"))
563 .respond_with(ResponseTemplate::new(503))
564 .up_to_n_times(1)
565 .expect(1)
566 .mount(&server)
567 .await;
568 Mock::given(method("POST"))
569 .respond_with(
570 ResponseTemplate::new(200).set_body_json(success_body(r#"{"ok":1}"#, Some(0.0))),
571 )
572 .expect(1)
573 .mount(&server)
574 .await;
575
576 let client = client_for(&server, "openai/gpt-oss-120b").await;
577 let (value, _, _) = client
578 .complete("system", "input", TEST_SCHEMA, None)
579 .await
580 .expect("retried completion succeeds");
581 assert_eq!(value, json!({"ok": 1}));
582 }
583
584 #[tokio::test]
585 async fn complete_401_is_permanent_without_retry() {
586 let server = MockServer::start().await;
587 Mock::given(method("POST"))
588 .respond_with(ResponseTemplate::new(401))
589 .expect(1)
590 .mount(&server)
591 .await;
592
593 let client = client_for(&server, "z-ai/glm-5.2").await;
594 let err = client
595 .complete("system", "input", TEST_SCHEMA, None)
596 .await
597 .expect_err("401 is an error");
598 assert!(err.to_string().contains("401"), "got: {err}");
599 }
600
601 #[tokio::test]
602 async fn complete_400_returns_body_and_model_without_retry() {
603 let server = MockServer::start().await;
604 Mock::given(method("POST"))
605 .respond_with(ResponseTemplate::new(400).set_body_string("schema not supported"))
606 .expect(1)
607 .mount(&server)
608 .await;
609
610 let client = client_for(&server, "xiaomi/mimo-v2.5").await;
611 let err = client
612 .complete("system", "input", TEST_SCHEMA, None)
613 .await
614 .expect_err("400 is an error");
615 let msg = err.to_string();
616 assert!(msg.contains("400"), "got: {msg}");
617 assert!(msg.contains("xiaomi/mimo-v2.5"), "got: {msg}");
618 assert!(msg.contains("schema not supported"), "got: {msg}");
619 }
620
621 #[tokio::test]
622 async fn complete_empty_choices_errors_citing_model() {
623 let server = MockServer::start().await;
624 Mock::given(method("POST"))
625 .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "choices": [] })))
626 .mount(&server)
627 .await;
628
629 let client = client_for(&server, "minimax/minimax-m2.7").await;
630 let err = client
631 .complete("system", "input", TEST_SCHEMA, None)
632 .await
633 .expect_err("empty choices is an error");
634 let msg = err.to_string();
635 assert!(msg.contains("minimax/minimax-m2.7"), "got: {msg}");
636 assert!(msg.contains("no structured content"), "got: {msg}");
637 }
638
639 #[tokio::test]
640 async fn complete_empty_content_errors() {
641 let server = MockServer::start().await;
642 Mock::given(method("POST"))
643 .respond_with(ResponseTemplate::new(200).set_body_json(success_body(" ", Some(0.0))))
644 .mount(&server)
645 .await;
646
647 let client = client_for(&server, "z-ai/glm-5.2:nitro").await;
648 let err = client
649 .complete("system", "input", TEST_SCHEMA, None)
650 .await
651 .expect_err("blank content is an error");
652 assert!(
653 err.to_string().contains("no structured content"),
654 "got: {err}"
655 );
656 }
657
658 #[tokio::test]
659 async fn complete_non_json_content_errors_as_incompatible() {
660 let server = MockServer::start().await;
661 Mock::given(method("POST"))
662 .respond_with(
663 ResponseTemplate::new(200)
664 .set_body_json(success_body("this is not json", Some(0.0))),
665 )
666 .mount(&server)
667 .await;
668
669 let client = client_for(&server, "google/gemini-3.1-flash-lite").await;
670 let err = client
671 .complete("system", "input", TEST_SCHEMA, None)
672 .await
673 .expect_err("non-json content is an error");
674 let msg = err.to_string();
675 assert!(msg.contains("non-JSON content"), "got: {msg}");
676 assert!(msg.contains("google/gemini-3.1-flash-lite"), "got: {msg}");
677 }
678
679 #[tokio::test]
680 async fn complete_rejects_invalid_schema_before_network() {
681 let client = OpenRouterChatClient::new_with_url(
683 key(),
684 "z-ai/glm-5.2".to_string(),
685 "http://127.0.0.1:1/chat/completions".to_string(),
686 30,
687 )
688 .expect("client builds");
689 let err = client
690 .complete("system", "input", "{not valid json", None)
691 .await
692 .expect_err("invalid schema is rejected");
693 assert!(
694 err.to_string().contains("invalid JSON schema"),
695 "got: {err}"
696 );
697 }
698
699 #[tokio::test]
700 async fn complete_retries_with_reasoning_omitted_when_mandatory() {
701 let server = MockServer::start().await;
702 Mock::given(method("POST"))
706 .respond_with(
707 ResponseTemplate::new(400).set_body_string(
708 "reasoning is mandatory for this model and cannot be disabled",
709 ),
710 )
711 .up_to_n_times(1)
712 .expect(1)
713 .mount(&server)
714 .await;
715 Mock::given(method("POST"))
717 .respond_with(ResponseTemplate::new(200).set_body_json(success_body(
718 r#"{"entities":[],"relationships":[]}"#,
719 Some(0.0),
720 )))
721 .expect(1)
722 .mount(&server)
723 .await;
724
725 let client = client_for(&server, "minimax/minimax-m2.7").await;
726 let (value, _, _) = client
727 .complete("system", "input", TEST_SCHEMA, None)
728 .await
729 .expect("fallback completion succeeds");
730 assert_eq!(value, json!({"entities": [], "relationships": []}));
731
732 let requests = server
735 .received_requests()
736 .await
737 .expect("request recording is enabled");
738 assert_eq!(requests.len(), 2, "expected primary + fallback requests");
739 let first: serde_json::Value =
740 serde_json::from_slice(&requests[0].body).expect("first request body is JSON");
741 let second: serde_json::Value =
742 serde_json::from_slice(&requests[1].body).expect("second request body is JSON");
743 assert_eq!(
744 first["reasoning"]["enabled"],
745 json!(false),
746 "primary request must send reasoning.enabled=false"
747 );
748 assert!(
749 second.get("reasoning").is_none(),
750 "fallback request must omit the reasoning field, got: {second}"
751 );
752 }
753
754 #[tokio::test]
755 async fn complete_honours_configured_timeout() {
756 let server = MockServer::start().await;
760 Mock::given(method("POST"))
761 .respond_with(
762 ResponseTemplate::new(200)
763 .set_delay(std::time::Duration::from_secs(2))
764 .set_body_json(success_body(r#"{"ok":1}"#, Some(0.0))),
765 )
766 .mount(&server)
767 .await;
768
769 let client = OpenRouterChatClient::new_with_url(
770 key(),
771 "z-ai/glm-5.2".to_string(),
772 format!("{}/chat/completions", server.uri()),
773 1,
774 )
775 .expect("client builds");
776 let err = client
777 .complete("system", "input", TEST_SCHEMA, None)
778 .await
779 .expect_err("request exceeds the 1s timeout");
780 assert!(err.to_string().contains("timed out"), "got: {err}");
781 }
782}