use std::time::Instant;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use super::{LlmProvider, LlmRequest, LlmResponse, error::LlmError};
/// OpenRouter's chat-completions endpoint; every request is a `POST` here.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Connection-establishment timeout in seconds (passed to `reqwest`'s `connect_timeout`).
const CONNECT_TIMEOUT_SECS: u64 = 10;
/// Per-request timeout in seconds, passed to `reqwest::ClientBuilder::timeout`.
/// NOTE: despite the name, that setting bounds the whole request (connect through
/// body), not only the read phase.
const READ_TIMEOUT_SECS: u64 = 120;
/// Value of the `HTTP-Referer` header sent with every request.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-tools";
/// Value of the `X-Title` header sent with every request.
const X_TITLE: &str = "trusty-review";
/// `response_format` request payload asking for JSON-schema-constrained output.
///
/// Serializes as `{"type": <type_>, "json_schema": {...}}`.
#[derive(Debug, Serialize)]
struct OrcResponseFormat<'a> {
// `type` is a Rust keyword, so the field is renamed on the wire.
#[serde(rename = "type")]
type_: &'static str,
/// The named schema the model output must follow.
json_schema: OrcJsonSchema<'a>,
}
/// Named JSON schema sent as `response_format.json_schema`; borrows from the request.
#[derive(Debug, Serialize)]
struct OrcJsonSchema<'a> {
/// Schema name taken from the caller's response schema.
name: &'a str,
/// Whether strict schema adherence is requested.
strict: bool,
/// The JSON Schema document itself.
schema: &'a serde_json::Value,
}
/// JSON body of a chat-completions `POST`. Borrows everything from the caller to
/// avoid copying messages and schemas.
#[derive(Debug, Serialize)]
struct OrcRequest<'a> {
/// Model slug to route the request to.
model: &'a str,
/// Conversation to complete, including any leading system message.
messages: &'a [OrcMessage],
/// Whether to stream the response.
stream: bool,
/// Sampling temperature forwarded from the caller.
temperature: f32,
/// Upper bound on generated tokens, forwarded from the caller.
max_tokens: u32,
/// Structured-output constraint; the key is omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<OrcResponseFormat<'a>>,
}
/// One outgoing chat message in `{ "role": ..., "content": ... }` form.
#[derive(Debug, Serialize)]
struct OrcMessage {
/// Speaker role, e.g. `"system"`.
role: String,
/// Message text.
content: String,
}
/// The subset of the chat-completions response body this provider reads.
/// Unknown fields are ignored by serde's default behaviour.
#[derive(Debug, Deserialize)]
struct OrcResponse {
/// Completion candidates. Required: a body without `choices` fails to deserialize.
choices: Vec<OrcChoice>,
/// Token accounting; `None` when absent.
#[serde(default)]
usage: Option<OrcUsage>,
/// Model id reported by the server, if any.
model: Option<String>,
}
/// One completion candidate in the response.
#[derive(Debug, Deserialize)]
struct OrcChoice {
/// The generated assistant message.
message: OrcChoiceMessage,
/// Why generation stopped, as reported by the server; `None` when absent.
#[serde(default)]
finish_reason: Option<String>,
}
/// Message payload of a choice; only the text content is read.
#[derive(Debug, Deserialize)]
struct OrcChoiceMessage {
/// Generated text; `None` when the field is absent or `null`.
content: Option<String>,
}
#[derive(Debug, Deserialize)]
struct OrcUsage {
prompt_tokens: u32,
completion_tokens: u32,
}
/// Returns `(input, output)` USD prices per million tokens for `model`.
///
/// The listed OpenAI snapshots must match exactly. Any other id is delegated to the
/// Gemini heuristic, which yields `(0.0, 0.0)` for unrecognised models.
fn cost_per_million(model: &str) -> (f64, f64) {
    // (model slug, input USD / 1M tokens, output USD / 1M tokens)
    const OPENAI_PRICES: [(&str, f64, f64); 5] = [
        ("openai/gpt-5.5-pro-20260423", 30.00, 180.00),
        ("openai/gpt-5.5-20260423", 5.00, 30.00),
        ("openai/gpt-5.4-20260305", 2.50, 15.00),
        ("openai/gpt-5.4-mini-20260317", 0.75, 4.50),
        ("openai/gpt-5.4-nano-20260317", 0.20, 1.25),
    ];

    OPENAI_PRICES
        .iter()
        .find(|(slug, _, _)| *slug == model)
        .map(|&(_, input, output)| (input, output))
        .unwrap_or_else(|| gemini_cost_per_million(model))
}
/// Heuristic `(input, output)` USD prices per million tokens for Gemini models.
///
/// Matching is case-insensitive on substrings of the id: `flash-lite`, then `flash`,
/// and anything else containing `gemini` is priced at the top tier. Ids that do not
/// mention `gemini` return `(0.0, 0.0)`.
fn gemini_cost_per_million(model: &str) -> (f64, f64) {
    let name = model.to_ascii_lowercase();
    if !name.contains("gemini") {
        return (0.0, 0.0);
    }

    // Every "flash-lite" id also contains "flash", so lite must win the first arm.
    match (name.contains("flash-lite"), name.contains("flash")) {
        (true, _) => (0.10, 0.40),
        (false, true) => (0.30, 2.50),
        (false, false) => (1.25, 10.00),
    }
}
/// Estimates the USD cost of a completion from its token counts.
///
/// Uses the per-million-token price table for `model`. Models without a known price
/// cost `0.0`, so the result is a best-effort estimate, not a billing figure.
pub fn estimate_cost_usd(model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
    const TOKENS_PER_MILLION: f64 = 1_000_000.0;

    let (input_rate, output_rate) = cost_per_million(model);
    let input_cost = f64::from(input_tokens) / TOKENS_PER_MILLION * input_rate;
    let output_cost = f64::from(output_tokens) / TOKENS_PER_MILLION * output_rate;
    input_cost + output_cost
}
/// LLM provider that talks to OpenRouter's chat-completions endpoint.
///
/// `Debug` is implemented by hand rather than derived so the API key is never
/// written to logs, error reports or panic messages that format the provider.
pub struct OpenRouterProvider {
    /// Bearer token sent in the `Authorization` header.
    api_key: String,
    /// Model slug requested on every call (e.g. `openai/gpt-5.4-20260305`).
    model: String,
    /// HTTP client reused across requests.
    client: reqwest::Client,
}

impl std::fmt::Debug for OpenRouterProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenRouterProvider")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("client", &self.client)
            .finish()
    }
}
impl OpenRouterProvider {
    /// Creates a provider that sends requests for `model`, authenticated with `api_key`.
    ///
    /// Leading and trailing whitespace is stripped from the key before it is
    /// validated and stored.
    ///
    /// # Errors
    ///
    /// - [`LlmError::AccessDenied`] if the key is empty or whitespace-only.
    /// - [`LlmError::Transport`] if the `reqwest` client cannot be built.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self, LlmError> {
        let raw_key: String = api_key.into();
        // Keys loaded from env files or secret mounts often carry a trailing newline.
        // Sent verbatim, that is not a valid `Authorization` header value. A
        // whitespace-only key would pass an emptiness check and only fail later, at
        // request time.
        let trimmed = raw_key.trim();
        if trimmed.is_empty() {
            return Err(LlmError::AccessDenied(
                "OPENROUTER_API_KEY is empty".to_string(),
            ));
        }
        let api_key = trimmed.to_owned();

        // `connect_timeout` bounds connection setup; `timeout` bounds the whole request.
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS))
            .timeout(std::time::Duration::from_secs(READ_TIMEOUT_SECS))
            .build()
            .map_err(|e| LlmError::Transport(format!("build reqwest client: {e}")))?;

        Ok(Self {
            api_key,
            model: model.into(),
            client,
        })
    }

    /// Creates a provider using the OpenRouter API key stored in `config`.
    ///
    /// # Errors
    ///
    /// Same as [`OpenRouterProvider::new`].
    pub fn from_config(
        config: &crate::config::ReviewConfig,
        model: impl Into<String>,
    ) -> Result<Self, LlmError> {
        Self::new(config.openrouter_api_key.clone(), model)
    }
}
#[async_trait]
impl LlmProvider for OpenRouterProvider {
fn name(&self) -> &str {
"openrouter"
}
async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
debug!(
model = %self.model,
structured = req.response_schema.is_some(),
"openrouter complete request"
);
let mut messages = Vec::new();
if !req.system.is_empty() {
messages.push(OrcMessage {
role: "system".to_string(),
content: req.system.clone(),
});
}
for msg in &req.messages {
messages.push(OrcMessage {
role: msg.role.clone(),
content: msg.content.clone(),
});
}
let response_format = req.response_schema.as_ref().map(|s| OrcResponseFormat {
type_: "json_schema",
json_schema: OrcJsonSchema {
name: &s.name,
strict: true,
schema: &s.schema,
},
});
let body = OrcRequest {
model: &self.model,
messages: &messages,
stream: false,
temperature: req.temperature,
max_tokens: req.max_tokens,
response_format,
};
let start = Instant::now();
let http_resp = self
.client
.post(OPENROUTER_URL)
.bearer_auth(&self.api_key)
.header("HTTP-Referer", HTTP_REFERER)
.header("X-Title", X_TITLE)
.json(&body)
.send()
.await
.map_err(|e| LlmError::Transport(e.to_string()))?;
let latency_ms = start.elapsed().as_millis() as u64;
let status = http_resp.status();
if !status.is_success() {
let body_text = http_resp.text().await.unwrap_or_default();
return Err(match status.as_u16() {
401 | 403 => LlmError::AccessDenied(body_text),
404 => LlmError::ModelNotFound(format!("model={}: {body_text}", self.model)),
422 => LlmError::Validation(body_text),
429 => LlmError::RateLimited,
_ => LlmError::Upstream {
status: status.as_u16(),
body: body_text,
},
});
}
let orc: OrcResponse = http_resp.json().await.map_err(|e| {
warn!("failed to parse OpenRouter response: {e}");
LlmError::Upstream {
status: status.as_u16(),
body: e.to_string(),
}
})?;
let first_choice = orc.choices.into_iter().next();
let finish_reason = first_choice
.as_ref()
.and_then(|c| c.finish_reason.clone())
.map(|r| r.trim().to_ascii_lowercase());
let text = first_choice
.and_then(|c| c.message.content)
.unwrap_or_default();
let (input_tokens, output_tokens) = orc
.usage
.map(|u| (u.prompt_tokens, u.completion_tokens))
.unwrap_or((0, 0));
let model_used = orc.model.unwrap_or_else(|| self.model.clone());
let cost_usd = estimate_cost_usd(&model_used, input_tokens, output_tokens);
Ok(LlmResponse {
text,
model: model_used,
input_tokens,
output_tokens,
latency_ms,
cost_usd,
finish_reason,
})
}
}
// Unit tests for this module live in `openrouter_tests.rs` (resolved via `#[path]`)
// and are compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "openrouter_tests.rs"]
mod tests;