use crate::errors::AppError;
use secrecy::{ExposeSecret, SecretBox};
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Public OpenRouter chat-completions endpoint; `new` stores it as the client's `base_url`.
const OPENROUTER_CHAT_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Request timeout in seconds used when the caller passes `0` to `new`.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Connect timeout in seconds handed to the HTTP client builder.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Total number of attempts made by the retry loop (the first attempt counts).
const MAX_RETRIES: u32 = 4;
/// Value sent as `response_format.json_schema.name` in every request.
const SCHEMA_NAME: &str = "enrich_output";
/// JSON body posted to the chat-completions endpoint.
///
/// `reasoning` and `max_tokens` are left out of the serialized JSON entirely
/// when they are `None`.
#[derive(Serialize)]
struct ChatRequest<'a> {
/// Model identifier, borrowed from the client.
model: &'a str,
/// Conversation turns: a system message, optionally followed by a user message.
messages: Vec<ChatMessage<'a>>,
/// Structured-output specification (JSON schema).
response_format: ResponseFormat,
/// Provider routing preferences.
provider: ProviderPrefs,
/// Reasoning toggle; omitted from the payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
reasoning: Option<ReasoningPrefs>,
/// Optional completion token limit; omitted from the payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
}
/// One message of the request conversation.
#[derive(Serialize)]
struct ChatMessage<'a> {
/// Speaker role; this file only ever sends `"system"` and `"user"`.
role: &'a str,
/// Message text (owned copy of the prompt or input).
content: String,
}
/// `response_format` section of the request.
#[derive(Serialize)]
struct ResponseFormat {
/// Serialized under the JSON key `type`; this file always sets it to `"json_schema"`.
#[serde(rename = "type")]
format_type: &'static str,
/// Schema the response content is expected to follow.
json_schema: JsonSchemaSpec,
}
/// Named JSON schema attached to the request's `response_format`.
#[derive(Serialize)]
struct JsonSchemaSpec {
/// Schema name; this file always uses `SCHEMA_NAME`.
name: &'static str,
/// Strict-mode flag; this file always sends `true`.
strict: bool,
/// The schema document itself, parsed from the caller-supplied string.
schema: serde_json::Value,
}
/// `provider` section of the request.
#[derive(Serialize)]
struct ProviderPrefs {
/// Always sent as `true` by this file.
// NOTE(review): presumably restricts routing to providers that support every
// parameter in the request — confirm against the OpenRouter routing docs.
require_parameters: bool,
}
/// `reasoning` section of the request.
#[derive(Serialize)]
struct ReasoningPrefs {
/// The primary request sends `false`; the fallback request omits the whole section.
enabled: bool,
}
/// Subset of the chat-completions response body that this client reads.
#[derive(Deserialize)]
struct ChatResponse {
/// Completion choices; defaults to an empty list when the key is absent.
#[serde(default)]
choices: Vec<Choice>,
/// Usage block; `None` when the key is absent.
#[serde(default)]
usage: Option<Usage>,
}
/// One completion choice; only the first choice is consumed by `complete`.
#[derive(Deserialize)]
struct Choice {
/// Message produced for this choice.
message: RespMessage,
}
/// Message payload of a completion choice.
#[derive(Deserialize)]
struct RespMessage {
/// Text content; `None` when the key is absent. `complete` expects it to hold JSON.
#[serde(default)]
content: Option<String>,
}
/// Usage block of the response.
#[derive(Deserialize)]
struct Usage {
/// Reported request cost; `complete` substitutes `0.0` when it is missing.
// NOTE(review): the unit (presumably USD / credits) is not visible here — confirm.
#[serde(default)]
cost: Option<f64>,
}
/// Client for structured-output chat completions against an OpenRouter-style endpoint.
///
/// One instance is bound to a single model. `Debug` is derived so the client can
/// appear in logs and in `Debug` output of owning types; the API key is held in a
/// `SecretBox`, whose `Debug` output is redacted, so deriving it does not leak the key.
#[derive(Debug)]
pub struct OpenRouterChatClient {
    /// Shared HTTP client carrying the configured timeouts and user agent.
    client: reqwest::Client,
    /// API key sent as a bearer token on every request.
    api_key: SecretBox<String>,
    /// Model identifier placed in every request body.
    model: String,
    /// Endpoint URL; the public endpoint unless overridden in tests.
    base_url: String,
}
impl OpenRouterChatClient {
/// Creates a client bound to `model` that talks to the public OpenRouter endpoint.
///
/// A `timeout_secs` of `0` selects `DEFAULT_TIMEOUT_SECS`; any other value is
/// used as the overall request timeout in seconds.
///
/// # Errors
///
/// Returns `AppError::Validation` when the underlying HTTP client cannot be built.
pub fn new(
    api_key: SecretBox<String>,
    model: String,
    timeout_secs: u64,
) -> Result<Self, AppError> {
    let request_timeout = match timeout_secs {
        0 => Duration::from_secs(DEFAULT_TIMEOUT_SECS),
        secs => Duration::from_secs(secs),
    };
    let connect_timeout = Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS);

    let built = reqwest::Client::builder()
        .timeout(request_timeout)
        .connect_timeout(connect_timeout)
        .user_agent("sqlite-graphrag/1.0.95")
        .build();

    match built {
        Ok(client) => Ok(Self {
            client,
            api_key,
            model,
            base_url: OPENROUTER_CHAT_URL.to_string(),
        }),
        Err(e) => Err(AppError::Validation(format!(
            "failed to build HTTP client: {e}"
        ))),
    }
}
/// Test-only constructor: same as `new`, but pointing at `base_url` instead of
/// the public endpoint.
#[cfg(test)]
pub fn new_with_url(
    api_key: SecretBox<String>,
    model: String,
    base_url: String,
    timeout_secs: u64,
) -> Result<Self, AppError> {
    Self::new(api_key, model, timeout_secs).map(|built| Self { base_url, ..built })
}
/// Returns the model identifier this client was constructed with.
pub fn model(&self) -> &str {
    self.model.as_str()
}
/// Runs one structured-output completion and returns the parsed JSON content.
///
/// `schema_str` is parsed up front, before any network traffic. The primary
/// request asks for reasoning to be disabled; if the endpoint rejects that with
/// a 400 mentioning reasoning, the request is retried once with the `reasoning`
/// section omitted.
///
/// Returns `(value, cost, flag)` where `value` is the JSON parsed from the first
/// choice's content, `cost` is the reported usage cost (`0.0` when absent) and
/// `flag` is always `false` for this client.
///
/// # Errors
///
/// Returns `AppError::Validation` when:
/// - `schema_str` is not valid JSON;
/// - the HTTP exchange fails (see `execute_with_retry`); if the fallback request
///   also fails, the *original* rejection is returned and the fallback error is
///   logged rather than silently discarded;
/// - the response has no choices, or the first choice's content is missing or blank;
/// - the content is not valid JSON.
pub async fn complete(
    &self,
    system_prompt: &str,
    input_text: &str,
    schema_str: &str,
    max_tokens: Option<u32>,
) -> Result<(serde_json::Value, f64, bool), AppError> {
    let schema: serde_json::Value = serde_json::from_str(schema_str).map_err(|e| {
        AppError::Validation(format!("invalid JSON schema for OpenRouter request: {e}"))
    })?;

    // The schema is cloned because the fallback request below needs its own copy.
    let primary = self.build_request(
        schema.clone(),
        system_prompt,
        input_text,
        max_tokens,
        Some(ReasoningPrefs { enabled: false }),
    );

    let response = match self.execute_with_retry(&primary).await {
        Ok(r) => r,
        Err(first_err) if reasoning_disable_rejected(&first_err) => {
            tracing::warn!(
                model = %self.model,
                "model rejected reasoning.enabled=false (mandatory); \
                 retrying once with reasoning omitted"
            );
            let fallback =
                self.build_request(schema, system_prompt, input_text, max_tokens, None);
            match self.execute_with_retry(&fallback).await {
                Ok(r) => r,
                Err(fallback_err) => {
                    // The caller still receives the original rejection, but the
                    // fallback failure is recorded so it can be diagnosed.
                    tracing::warn!(
                        model = %self.model,
                        error = %fallback_err,
                        "fallback request without reasoning also failed; \
                         reporting the original rejection"
                    );
                    return Err(first_err);
                }
            }
        }
        Err(first_err) => return Err(first_err),
    };

    // Only the first choice is considered; blank content counts as missing.
    let content = response
        .choices
        .into_iter()
        .next()
        .and_then(|c| c.message.content)
        .filter(|c| !c.trim().is_empty())
        .ok_or_else(|| {
            AppError::Validation(format!(
                "model '{}' returned no structured content (incompatible with \
                 structured outputs, or refused the request)",
                self.model
            ))
        })?;

    let value: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
        AppError::Validation(format!(
            "model '{}' returned non-JSON content despite strict schema: {e}",
            self.model
        ))
    })?;

    let cost = response.usage.and_then(|u| u.cost).unwrap_or(0.0);
    Ok((value, cost, false))
}
fn build_request<'a>(
&'a self,
schema: serde_json::Value,
system_prompt: &str,
input_text: &str,
max_tokens: Option<u32>,
reasoning: Option<ReasoningPrefs>,
) -> ChatRequest<'a> {
let mut messages = Vec::with_capacity(2);
messages.push(ChatMessage {
role: "system",
content: system_prompt.to_string(),
});
if !input_text.is_empty() {
messages.push(ChatMessage {
role: "user",
content: input_text.to_string(),
});
}
ChatRequest {
model: &self.model,
messages,
response_format: ResponseFormat {
format_type: "json_schema",
json_schema: JsonSchemaSpec {
name: SCHEMA_NAME,
strict: true,
schema,
},
},
provider: ProviderPrefs {
require_parameters: true,
},
reasoning,
max_tokens,
}
}
/// Posts `request` to `base_url`, making at most `MAX_RETRIES` attempts.
///
/// Per-attempt outcomes:
/// - **Returned immediately:** a transport timeout, a failure to read a 2xx
///   body, HTTP 401, HTTP 400/404 (body included in the error), and any other
///   status that is neither 429 nor 5xx.
/// - **Retried after exponential backoff:** other transport errors, 5xx
///   statuses, and 2xx bodies that do not parse as `ChatResponse`.
/// - **Retried after the `Retry-After` delay:** HTTP 429. The header is read as
///   whole seconds, defaults to 2 when missing or unparsable, and is capped so a
///   hostile or buggy header cannot stall the caller indefinitely.
///
/// No delay is taken after the final attempt: there is nothing left to wait for.
///
/// # Errors
///
/// Returns `AppError::Validation` describing the terminal failure, or the most
/// recent retryable failure once all attempts are used up.
async fn execute_with_retry(
    &self,
    request: &ChatRequest<'_>,
) -> Result<ChatResponse, AppError> {
    /// Delay used when a 429 carries no usable `Retry-After` header.
    const DEFAULT_RETRY_AFTER_SECS: u64 = 2;
    /// Upper bound on a server-supplied `Retry-After` delay.
    const MAX_RETRY_AFTER_SECS: u64 = 300;

    let mut last_err = None;
    for attempt in 0..MAX_RETRIES {
        // Sleeping only makes sense when another attempt will follow.
        let has_attempts_left = attempt + 1 < MAX_RETRIES;

        let result = self
            .client
            .post(&self.base_url)
            .bearer_auth(self.api_key.expose_secret())
            .json(request)
            .send()
            .await;
        let resp = match result {
            Ok(r) => r,
            Err(e) if e.is_timeout() => {
                return Err(AppError::Validation(
                    "OpenRouter chat request timed out".into(),
                ));
            }
            Err(e) => {
                last_err = Some(AppError::Validation(format!("HTTP request failed: {e}")));
                if has_attempts_left {
                    Self::backoff(attempt).await;
                }
                continue;
            }
        };

        let status = resp.status();
        if status.is_success() {
            let body = resp.text().await.map_err(|e| {
                AppError::Validation(format!("failed to read response body: {e}"))
            })?;
            match serde_json::from_str::<ChatResponse>(&body) {
                Ok(parsed) => return Ok(parsed),
                Err(e) => {
                    tracing::warn!(
                        attempt,
                        body_len = body.len(),
                        "HTTP 200 but parse failed (retrying): {e}"
                    );
                    last_err = Some(AppError::Validation(format!(
                        "failed to parse chat response: {e}"
                    )));
                    if has_attempts_left {
                        Self::backoff(attempt).await;
                    }
                    continue;
                }
            }
        }

        match status.as_u16() {
            401 => {
                return Err(AppError::Validation(
                    "invalid OpenRouter API key (HTTP 401)".into(),
                ));
            }
            400 | 404 => {
                // NOTE: `reasoning_disable_rejected` matches on this message format.
                let body = resp.text().await.unwrap_or_default();
                return Err(AppError::Validation(format!(
                    "OpenRouter returned {status} for model '{}': {body}",
                    self.model
                )));
            }
            429 => {
                let retry_after = resp
                    .headers()
                    .get("retry-after")
                    .and_then(|v| v.to_str().ok())
                    .and_then(|v| v.parse::<u64>().ok())
                    .unwrap_or(DEFAULT_RETRY_AFTER_SECS)
                    .min(MAX_RETRY_AFTER_SECS);
                // Record the rate limit so it is what gets reported if attempts run out.
                last_err = Some(AppError::Validation(
                    "max retries exceeded for OpenRouter chat request: rate limited (HTTP 429)"
                        .into(),
                ));
                if has_attempts_left {
                    tracing::warn!(
                        attempt,
                        retry_after_secs = retry_after,
                        "OpenRouter rate limited, waiting"
                    );
                    tokio::time::sleep(Duration::from_secs(retry_after)).await;
                }
                continue;
            }
            _ if status.is_server_error() => {
                tracing::warn!(attempt, status = %status, "OpenRouter server error, retrying");
                last_err = Some(AppError::Validation(format!(
                    "OpenRouter server error: {status}"
                )));
                if has_attempts_left {
                    Self::backoff(attempt).await;
                }
                continue;
            }
            _ => {
                let body = resp.text().await.unwrap_or_default();
                return Err(AppError::Validation(format!(
                    "unexpected HTTP {status}: {body}"
                )));
            }
        }
    }

    Err(last_err.unwrap_or_else(|| {
        AppError::Validation("max retries exceeded for OpenRouter chat request".into())
    }))
}
/// Sleeps for an exponentially growing delay before the next retry.
///
/// The base delay is `1000 ms * 2^attempt` (`attempt` is zero-based), plus a
/// random jitter in `0..500` ms. The exponential part is computed with checked
/// arithmetic and capped at `MAX_DELAY_MS`, so a large `attempt` can neither
/// overflow nor produce an absurd wait. For attempts 0 through 3 the delays are
/// exactly 1 s, 2 s, 4 s and 8 s plus jitter.
async fn backoff(attempt: u32) {
    const BASE_DELAY_MS: u64 = 1000;
    const MAX_DELAY_MS: u64 = 60_000;

    let base_ms = 2u64
        .checked_pow(attempt)
        .and_then(|factor| BASE_DELAY_MS.checked_mul(factor))
        .map_or(MAX_DELAY_MS, |ms| ms.min(MAX_DELAY_MS));
    let jitter = fastrand::u64(0..500);
    let sleep_ms = base_ms + jitter;
    tracing::debug!(attempt, sleep_ms, "exponential backoff");
    tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
}
}
/// Reports whether `err` is an HTTP 400 whose response body mentions reasoning.
///
/// The check is anchored on the message produced for 400 responses
/// (`OpenRouter returned 400 … for model '<model>': <body>`) and inspects only
/// the body that follows the model name. A looser substring test would also
/// fire when the model name itself contains "reasoning", or when a different
/// status merely carries "400" somewhere in its text, triggering a pointless
/// fallback request.
fn reasoning_disable_rejected(err: &AppError) -> bool {
    const MARKER: &str = "openrouter returned 400";

    let msg = err.to_string().to_lowercase();
    match msg.find(MARKER) {
        Some(start) => {
            let tail = &msg[start + MARKER.len()..];
            // Skip past the quoted model name; fall back to the whole tail if the
            // separator is unexpectedly absent.
            let body = tail.split_once("': ").map_or(tail, |(_, body)| body);
            body.contains("reasoning")
        }
        None => false,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use wiremock::matchers::{body_partial_json, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Minimal valid JSON schema passed as `schema_str` by the tests below.
const TEST_SCHEMA: &str = r#"{"type":"object"}"#;
/// Builds a throwaway API key for test clients.
fn key() -> SecretBox<String> {
    let raw = String::from("test-key");
    SecretBox::new(Box::new(raw))
}
/// Builds a successful response body with a single choice carrying `content`,
/// plus a `usage.cost` entry when `cost` is provided.
fn success_body(content: &str, cost: Option<f64>) -> serde_json::Value {
    let choices = json!([{ "message": { "content": content } }]);
    match cost {
        Some(c) => json!({ "choices": choices, "usage": { "cost": c } }),
        None => json!({ "choices": choices }),
    }
}
/// Creates a client for `model` aimed at the mock server, with a 30 s timeout.
async fn client_for(server: &MockServer, model: &str) -> OpenRouterChatClient {
    let endpoint = format!("{}/chat/completions", server.uri());
    let built = OpenRouterChatClient::new_with_url(key(), model.to_string(), endpoint, 30);
    built.expect("test client builds")
}
/// The constructor keeps the model id it was given and exposes it via `model()`.
#[test]
fn new_builds_client_and_binds_model() {
    let model_id = "z-ai/glm-5.2";
    let built = OpenRouterChatClient::new(key(), model_id.to_string(), 30);
    let client = built.expect("client builds");
    assert_eq!(client.model(), model_id);
}
/// Without an override, the client targets the public OpenRouter endpoint.
#[test]
fn new_defaults_base_url_to_public_endpoint() {
    let OpenRouterChatClient { base_url, .. } =
        OpenRouterChatClient::new(key(), "z-ai/glm-5.2".to_string(), 30)
            .expect("client builds");
    assert_eq!(base_url, OPENROUTER_CHAT_URL);
}
/// Serialization uses the `type` key, keeps strict mode and the disabled
/// reasoning flag, and drops `max_tokens` when it is `None`.
#[test]
fn request_serializes_with_strict_schema_and_disabled_reasoning() {
    let json_schema = JsonSchemaSpec {
        name: SCHEMA_NAME,
        strict: true,
        schema: serde_json::json!({"type": "object"}),
    };
    let system_turn = ChatMessage {
        role: "system",
        content: "extract".to_string(),
    };
    let request = ChatRequest {
        model: "deepseek/deepseek-v4-flash",
        messages: vec![system_turn],
        response_format: ResponseFormat {
            format_type: "json_schema",
            json_schema,
        },
        provider: ProviderPrefs {
            require_parameters: true,
        },
        reasoning: Some(ReasoningPrefs { enabled: false }),
        max_tokens: None,
    };

    let encoded = serde_json::to_value(&request).expect("serializes");
    let format = &encoded["response_format"];
    assert_eq!(format["type"], "json_schema");
    assert_eq!(format["json_schema"]["strict"], true);
    assert_eq!(encoded["provider"]["require_parameters"], true);
    assert_eq!(encoded["reasoning"]["enabled"], false);
    assert!(encoded.get("max_tokens").is_none());
}
/// Happy path: the request carries the expected fields and the response's
/// content, cost and flag are returned.
#[tokio::test]
async fn complete_sends_wellformed_request_and_parses_content() {
    let server = MockServer::start().await;
    let expected_request = json!({
        "model": "deepseek/deepseek-v4-flash",
        "response_format": {
            "type": "json_schema",
            "json_schema": { "name": "enrich_output", "strict": true }
        },
        "provider": { "require_parameters": true },
        "reasoning": { "enabled": false }
    });
    let reply = ResponseTemplate::new(200).set_body_json(success_body(
        r#"{"entities":[],"relationships":[]}"#,
        Some(0.0023),
    ));
    Mock::given(method("POST"))
        .and(path("/chat/completions"))
        .and(body_partial_json(expected_request))
        .respond_with(reply)
        .expect(1)
        .mount(&server)
        .await;

    let client = client_for(&server, "deepseek/deepseek-v4-flash").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let (value, cost, is_oauth) = outcome.expect("completion succeeds");

    assert_eq!(value, json!({"entities": [], "relationships": []}));
    assert!((cost - 0.0023).abs() < f64::EPSILON);
    assert!(!is_oauth);
}
/// A response without a `usage` block yields a cost of zero.
#[tokio::test]
async fn complete_defaults_cost_to_zero_when_usage_absent() {
    let server = MockServer::start().await;
    let reply =
        ResponseTemplate::new(200).set_body_json(success_body(r#"{"entities":[]}"#, None));
    Mock::given(method("POST"))
        .respond_with(reply)
        .mount(&server)
        .await;

    let client = client_for(&server, "z-ai/glm-5.2").await;
    let outcome = client.complete("system", "", TEST_SCHEMA, Some(4096)).await;
    let (_, cost, _) = outcome.expect("completion succeeds");

    assert_eq!(cost, 0.0);
}
/// A single 429 with `retry-after: 1` is followed by a successful retry.
#[tokio::test]
async fn complete_retries_on_429_honouring_retry_after() {
    let server = MockServer::start().await;

    // Mounted first so it answers the first request only.
    let rate_limited = ResponseTemplate::new(429).insert_header("retry-after", "1");
    Mock::given(method("POST"))
        .respond_with(rate_limited)
        .up_to_n_times(1)
        .expect(1)
        .mount(&server)
        .await;

    let success =
        ResponseTemplate::new(200).set_body_json(success_body(r#"{"ok":true}"#, Some(0.0)));
    Mock::given(method("POST"))
        .respond_with(success)
        .expect(1)
        .mount(&server)
        .await;

    let client = client_for(&server, "minimax/minimax-m3").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let (value, _, _) = outcome.expect("retried completion succeeds");

    assert_eq!(value, json!({"ok": true}));
}
/// A single 503 is followed by a successful retry.
#[tokio::test]
async fn complete_retries_on_5xx_with_backoff() {
    let server = MockServer::start().await;

    // Mounted first so it answers the first request only.
    let unavailable = ResponseTemplate::new(503);
    Mock::given(method("POST"))
        .respond_with(unavailable)
        .up_to_n_times(1)
        .expect(1)
        .mount(&server)
        .await;

    let success =
        ResponseTemplate::new(200).set_body_json(success_body(r#"{"ok":1}"#, Some(0.0)));
    Mock::given(method("POST"))
        .respond_with(success)
        .expect(1)
        .mount(&server)
        .await;

    let client = client_for(&server, "openai/gpt-oss-120b").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let (value, _, _) = outcome.expect("retried completion succeeds");

    assert_eq!(value, json!({"ok": 1}));
}
/// A 401 fails at once: the mock expects exactly one request.
#[tokio::test]
async fn complete_401_is_permanent_without_retry() {
    let server = MockServer::start().await;
    let unauthorized = ResponseTemplate::new(401);
    Mock::given(method("POST"))
        .respond_with(unauthorized)
        .expect(1)
        .mount(&server)
        .await;

    let client = client_for(&server, "z-ai/glm-5.2").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let err = outcome.expect_err("401 is an error");

    assert!(err.to_string().contains("401"), "got: {err}");
}
/// A 400 fails at once and the error names the status, the model and the body.
#[tokio::test]
async fn complete_400_returns_body_and_model_without_retry() {
    let server = MockServer::start().await;
    let bad_request = ResponseTemplate::new(400).set_body_string("schema not supported");
    Mock::given(method("POST"))
        .respond_with(bad_request)
        .expect(1)
        .mount(&server)
        .await;

    let client = client_for(&server, "xiaomi/mimo-v2.5").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let msg = outcome.expect_err("400 is an error").to_string();

    for needle in ["400", "xiaomi/mimo-v2.5", "schema not supported"] {
        assert!(msg.contains(needle), "got: {msg}");
    }
}
/// An empty `choices` array produces an error that names the model.
#[tokio::test]
async fn complete_empty_choices_errors_citing_model() {
    let server = MockServer::start().await;
    let no_choices = ResponseTemplate::new(200).set_body_json(json!({ "choices": [] }));
    Mock::given(method("POST"))
        .respond_with(no_choices)
        .mount(&server)
        .await;

    let client = client_for(&server, "minimax/minimax-m2.7").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let msg = outcome.expect_err("empty choices is an error").to_string();

    assert!(msg.contains("minimax/minimax-m2.7"), "got: {msg}");
    assert!(msg.contains("no structured content"), "got: {msg}");
}
/// Whitespace-only content is treated the same as missing content.
#[tokio::test]
async fn complete_empty_content_errors() {
    let server = MockServer::start().await;
    let blank = ResponseTemplate::new(200).set_body_json(success_body(" ", Some(0.0)));
    Mock::given(method("POST"))
        .respond_with(blank)
        .mount(&server)
        .await;

    let client = client_for(&server, "z-ai/glm-5.2:nitro").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let err = outcome.expect_err("blank content is an error");

    let rendered = err.to_string();
    assert!(rendered.contains("no structured content"), "got: {err}");
}
/// Content that is not JSON produces an error that names the model.
#[tokio::test]
async fn complete_non_json_content_errors_as_incompatible() {
    let server = MockServer::start().await;
    let plain_text =
        ResponseTemplate::new(200).set_body_json(success_body("this is not json", Some(0.0)));
    Mock::given(method("POST"))
        .respond_with(plain_text)
        .mount(&server)
        .await;

    let client = client_for(&server, "google/gemini-3.1-flash-lite").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let msg = outcome
        .expect_err("non-json content is an error")
        .to_string();

    assert!(msg.contains("non-JSON content"), "got: {msg}");
    assert!(msg.contains("google/gemini-3.1-flash-lite"), "got: {msg}");
}
/// A malformed schema string is rejected with an "invalid JSON schema" error;
/// the client is pointed at a local port-1 URL rather than a mock server.
#[tokio::test]
async fn complete_rejects_invalid_schema_before_network() {
    let unreachable_url = "http://127.0.0.1:1/chat/completions".to_string();
    let client =
        OpenRouterChatClient::new_with_url(key(), "z-ai/glm-5.2".to_string(), unreachable_url, 30)
            .expect("client builds");

    let outcome = client
        .complete("system", "input", "{not valid json", None)
        .await;
    let err = outcome.expect_err("invalid schema is rejected");

    let rendered = err.to_string();
    assert!(rendered.contains("invalid JSON schema"), "got: {err}");
}
/// When the first request is rejected with a 400 about mandatory reasoning, a
/// second request is sent without the `reasoning` field and its result is used.
#[tokio::test]
async fn complete_retries_with_reasoning_omitted_when_mandatory() {
    let server = MockServer::start().await;

    // Mounted first so it answers the primary request only.
    let rejection = ResponseTemplate::new(400)
        .set_body_string("reasoning is mandatory for this model and cannot be disabled");
    Mock::given(method("POST"))
        .respond_with(rejection)
        .up_to_n_times(1)
        .expect(1)
        .mount(&server)
        .await;

    let success = ResponseTemplate::new(200).set_body_json(success_body(
        r#"{"entities":[],"relationships":[]}"#,
        Some(0.0),
    ));
    Mock::given(method("POST"))
        .respond_with(success)
        .expect(1)
        .mount(&server)
        .await;

    let client = client_for(&server, "minimax/minimax-m2.7").await;
    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let (value, _, _) = outcome.expect("fallback completion succeeds");
    assert_eq!(value, json!({"entities": [], "relationships": []}));

    let requests = server
        .received_requests()
        .await
        .expect("request recording is enabled");
    assert_eq!(requests.len(), 2, "expected primary + fallback requests");

    let body_of = |idx: usize, why: &str| -> serde_json::Value {
        serde_json::from_slice(&requests[idx].body).expect(why)
    };
    let first = body_of(0, "first request body is JSON");
    let second = body_of(1, "second request body is JSON");

    assert_eq!(
        first["reasoning"]["enabled"],
        json!(false),
        "primary request must send reasoning.enabled=false"
    );
    assert!(
        second.get("reasoning").is_none(),
        "fallback request must omit the reasoning field, got: {second}"
    );
}
/// A response delayed by 2 s against a 1 s client timeout yields a
/// "timed out" error.
#[tokio::test]
async fn complete_honours_configured_timeout() {
    let server = MockServer::start().await;
    let slow_reply = ResponseTemplate::new(200)
        .set_delay(std::time::Duration::from_secs(2))
        .set_body_json(success_body(r#"{"ok":1}"#, Some(0.0)));
    Mock::given(method("POST"))
        .respond_with(slow_reply)
        .mount(&server)
        .await;

    let endpoint = format!("{}/chat/completions", server.uri());
    let client = OpenRouterChatClient::new_with_url(key(), "z-ai/glm-5.2".to_string(), endpoint, 1)
        .expect("client builds");

    let outcome = client.complete("system", "input", TEST_SCHEMA, None).await;
    let err = outcome.expect_err("request exceeds the 1s timeout");

    assert!(err.to_string().contains("timed out"), "got: {err}");
}
}