use std::collections::BTreeMap;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
/// Root URL of the OpenRouter HTTP API; `OpenRouterClient::from_env` installs it as the client's `base_url`.
const BASE_URL: &str = "https://openrouter.ai/api/v1";
/// HTTP client for the OpenRouter REST API.
///
/// Each request method appends an endpoint path to `base_url` and sends
/// `api_key` as a bearer token.
///
/// NOTE(review): `Debug` is not derived; presumably that keeps the API key
/// out of logs. Confirm before adding it.
#[derive(Clone)]
pub struct OpenRouterClient {
    /// Underlying `reqwest` client used to issue every request.
    http: reqwest::Client,
    /// Secret attached to each request through `bearer_auth`.
    api_key: String,
    /// API root that endpoint paths such as `/models` are appended to.
    base_url: String,
}
impl OpenRouterClient {
    /// Builds a client that targets `BASE_URL`, reading the API key from the
    /// `OPENROUTER_API_KEY` environment variable.
    ///
    /// # Errors
    ///
    /// Fails when `OPENROUTER_API_KEY` cannot be read from the environment.
    pub fn from_env() -> Result<Self> {
        std::env::var("OPENROUTER_API_KEY")
            .context("OPENROUTER_API_KEY environment variable is not set")
            .map(|api_key| Self {
                http: reqwest::Client::new(),
                api_key,
                base_url: BASE_URL.to_owned(),
            })
    }

    /// Test-only constructor that targets an arbitrary `base_url` with the
    /// given `api_key` (the tests in this file pass a mock server's URI).
    #[cfg(test)]
    pub(crate) fn with_base_url(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
        let http = reqwest::Client::new();
        let api_key = api_key.into();
        let base_url = base_url.into();
        Self {
            http,
            api_key,
            base_url,
        }
    }

    /// Fetches `GET {base_url}/models`, forwarding the filters in `query` as
    /// URL query parameters, and returns the `data` array of the response.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be sent, when the server answers with an
    /// error status, or when the body does not decode as a [`ModelsResponse`].
    pub async fn list_models(&self, query: &ModelsQuery) -> Result<Vec<Model>> {
        let url = format!("{}/models", self.base_url);
        let params = query.to_pairs();
        let request = self
            .http
            .get(url)
            .bearer_auth(&self.api_key)
            .query(&params);
        let response = request
            .send()
            .await
            .context("request to OpenRouter /models failed")?;
        // `error_for_status` turns 4xx/5xx responses into an error before decoding.
        let response = response
            .error_for_status()
            .context("OpenRouter /models returned an error status")?;
        let body = response
            .json::<ModelsResponse>()
            .await
            .context("failed to decode OpenRouter /models response")?;
        Ok(body.data)
    }
}
impl OpenRouterClient {
    /// Fetches `GET {base_url}/videos/models` and returns the `data` array of
    /// the response.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be sent, when the server answers with an
    /// error status, or when the body does not decode as a
    /// [`VideoModelsResponse`].
    pub async fn list_video_models(&self) -> Result<Vec<VideoModel>> {
        let url = format!("{}/videos/models", self.base_url);
        let response = self
            .http
            .get(url)
            .bearer_auth(&self.api_key)
            .send()
            .await
            .context("request to OpenRouter /videos/models failed")?;
        // `error_for_status` turns 4xx/5xx responses into an error before decoding.
        let response = response
            .error_for_status()
            .context("OpenRouter /videos/models returned an error status")?;
        let body = response
            .json::<VideoModelsResponse>()
            .await
            .context("failed to decode OpenRouter /videos/models response")?;
        Ok(body.data)
    }
}
/// Response envelope decoded by `OpenRouterClient::list_video_models`.
#[derive(Debug, Deserialize)]
pub struct VideoModelsResponse {
    /// Video models carried in the response's `data` array.
    pub data: Vec<VideoModel>,
}
/// One entry of the `/videos/models` listing.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct VideoModel {
    /// Model identifier (the tests in this file use `"google/veo"`).
    pub id: String,
    /// String-to-string map decoded from `pricing_skus`; empty when the field
    /// is absent from the payload.
    ///
    /// NOTE(review): presumably SKU name -> price; confirm against the API docs.
    #[serde(default)]
    pub pricing_skus: BTreeMap<String, String>,
}
/// Optional filters for `OpenRouterClient::list_models`.
///
/// Every field that is `Some` is sent as the URL query parameter of the same
/// name; `None` fields are left out of the request. `Default` yields a query
/// with no filters at all.
#[derive(Debug, Default)]
pub struct ModelsQuery {
    /// Value of the `q` parameter.
    pub q: Option<String>,
    /// Value of the `output_modalities` parameter (a test here uses `"image,text"`).
    pub output_modalities: Option<String>,
    /// Value of the `input_modalities` parameter.
    pub input_modalities: Option<String>,
    /// Value of the `supported_parameters` parameter.
    pub supported_parameters: Option<String>,
    /// Value of the `sort` parameter (a test here uses `"newest"`).
    pub sort: Option<String>,
    /// Value of the `context` parameter, rendered in decimal.
    pub context: Option<u64>,
}
impl ModelsQuery {
    /// Flattens the populated filters into `(name, value)` pairs suitable for
    /// a URL query string.
    ///
    /// Pairs are emitted in declaration order (`q`, `output_modalities`,
    /// `input_modalities`, `supported_parameters`, `sort`, `context`) and
    /// `None` fields are skipped entirely.
    fn to_pairs(&self) -> Vec<(&'static str, String)> {
        // String-valued filters share one code path; `context` is numeric and
        // is appended last to keep the original ordering.
        let text_filters = [
            ("q", &self.q),
            ("output_modalities", &self.output_modalities),
            ("input_modalities", &self.input_modalities),
            ("supported_parameters", &self.supported_parameters),
            ("sort", &self.sort),
        ];
        let mut pairs: Vec<(&'static str, String)> = text_filters
            .iter()
            .filter_map(|&(key, value)| value.clone().map(|text| (key, text)))
            .collect();
        pairs.extend(self.context.map(|limit| ("context", limit.to_string())));
        pairs
    }
}
/// Response envelope decoded by `OpenRouterClient::list_models`.
#[derive(Debug, Deserialize)]
pub struct ModelsResponse {
    /// Models carried in the response's `data` array.
    pub data: Vec<Model>,
}
/// Maximum number of models `apply_filters` keeps when its `all` flag is `false`.
pub const DEFAULT_MODEL_LIMIT: usize = 20;
/// Result of `apply_filters`: the models that were kept plus the number that
/// matched before any truncation.
///
/// Derives `Debug` and `Clone` so the value can be logged, asserted on and
/// duplicated like the `Model` values it holds.
#[derive(Debug, Clone)]
pub struct FilteredModels {
    /// Models retained after searching and (optionally) truncating.
    pub models: Vec<Model>,
    /// Number of models that matched the search, counted before truncation.
    pub total: usize,
}
impl FilteredModels {
    /// Number of matching models that were left out of `models`.
    ///
    /// This is `total - models.len()`, clamped at zero. Both fields are
    /// public, so a hand-built value may have `total < models.len()`; a plain
    /// subtraction would then panic in debug builds and wrap in release
    /// builds, whereas this returns `0`.
    pub fn truncated(&self) -> usize {
        self.total.saturating_sub(self.models.len())
    }
}
/// Narrows `models` to those matching `search` and, unless `all` is set, caps
/// the result at `DEFAULT_MODEL_LIMIT` entries.
///
/// * `search` - when `Some`, only models for which `Model::matches_search`
///   returns `true` are kept; `None` keeps every model.
/// * `all` - when `true`, no cap is applied.
///
/// The returned `total` is the number of matches counted before the cap, and
/// the relative order of the surviving models is preserved.
pub fn apply_filters(mut models: Vec<Model>, search: Option<&str>, all: bool) -> FilteredModels {
    models = match search {
        Some(needle) => models
            .into_iter()
            .filter(|model| model.matches_search(needle))
            .collect(),
        None => models,
    };
    // Count matches first so callers can tell how many entries the cap hid.
    let total = models.len();
    let kept = if all {
        total
    } else {
        total.min(DEFAULT_MODEL_LIMIT)
    };
    models.truncate(kept);
    FilteredModels { models, total }
}
impl OpenRouterClient {
    /// Posts `req` as JSON to `{base_url}/chat/completions` and decodes the
    /// reply.
    ///
    /// The returned [`ChatResponse`] pairs the decoded body with the value of
    /// the `x-generation-id` response header, when that header is present and
    /// is valid visible ASCII.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be sent, when the server answers with a
    /// non-success status (the error message then carries the status and the
    /// response body), or when a successful body does not decode as a
    /// [`ChatCompletion`].
    pub async fn chat_completion(&self, req: &ChatRequest) -> Result<ChatResponse> {
        let endpoint = format!("{}/chat/completions", self.base_url);
        let response = self
            .http
            .post(endpoint)
            .bearer_auth(&self.api_key)
            .json(req)
            .send()
            .await
            .context("request to OpenRouter /chat/completions failed")?;

        // Copy the header out first: reading the body below consumes `response`.
        let generation_id = match response.headers().get("x-generation-id") {
            Some(value) => value.to_str().ok().map(str::to_owned),
            None => None,
        };

        let status = response.status();
        if status.is_success() {
            let completion = response
                .json::<ChatCompletion>()
                .await
                .context("failed to decode OpenRouter /chat/completions response")?;
            return Ok(ChatResponse {
                completion,
                generation_id,
            });
        }

        // Best effort: an unreadable error body is reported as an empty string.
        let body = response.text().await.unwrap_or_default();
        Err(anyhow::anyhow!(
            "OpenRouter /chat/completions returned {status}: {body}"
        ))
    }
}
/// JSON body that `OpenRouterClient::chat_completion` posts to
/// `/chat/completions`.
///
/// Optional fields left as `None` are omitted from the serialized JSON;
/// `model`, `messages` and `stream` are always written.
#[derive(Debug, Serialize)]
pub struct ChatRequest {
    /// Serialized as `model`.
    pub model: String,
    /// Serialized as the `messages` array.
    pub messages: Vec<Message>,
    /// Serialized as `modalities` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
    /// Serialized as `image_config` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_config: Option<ImageConfig>,
    /// Serialized as `seed` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    /// Serialized as `stream`.
    ///
    /// NOTE(review): `chat_completion` decodes the reply as one JSON document,
    /// so callers presumably pass `false`; confirm.
    pub stream: bool,
}
/// One entry of [`ChatRequest::messages`].
#[derive(Debug, Serialize)]
pub struct Message {
    /// Serialized as `role`.
    pub role: String,
    /// Serialized as `content`: either a plain string or a list of parts.
    pub content: Content,
}
/// Payload of a [`Message`].
///
/// The enum is `untagged`, so no variant name appears in the JSON: `Text`
/// serializes as a bare string and `Parts` as an array of [`ContentPart`].
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
pub enum Content {
    /// Plain text content.
    Text(String),
    /// Content made of several typed parts.
    Parts(Vec<ContentPart>),
}
/// One element of [`Content::Parts`].
///
/// Internally tagged: the variant name, converted to snake_case, is written
/// into a `type` field next to the variant's own fields.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Serialized as `{"type": "text", "text": ...}`.
    Text { text: String },
    /// Serialized as `{"type": "image_url", "image_url": {...}}`.
    ImageUrl { image_url: ImageUrl },
}
/// Optional image settings sent as [`ChatRequest::image_config`].
///
/// Each field is omitted from the JSON when it is `None`.
#[derive(Debug, Serialize)]
pub struct ImageConfig {
    /// Serialized as `aspect_ratio` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aspect_ratio: Option<String>,
    /// Serialized as `image_size` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_size: Option<String>,
}
/// Wrapper object around an image URL.
///
/// Used on the request side by [`ContentPart::ImageUrl`] and on the response
/// side by [`OutImage`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
    /// The `url` string.
    ///
    /// NOTE(review): may be a remote URL or a data URL; not determinable from
    /// this file.
    pub url: String,
}
/// Value returned by `OpenRouterClient::chat_completion`.
///
/// `chat_completion` assembles this by hand from the decoded body and a
/// response header; it is not decoded from a single JSON document there.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
    /// Decoded `/chat/completions` response body.
    pub completion: ChatCompletion,
    /// Value of the `x-generation-id` response header, if one was sent.
    pub generation_id: Option<String>,
}
/// Body of a successful `/chat/completions` response.
///
/// Every field falls back to its default (`None` or an empty list) when it is
/// missing from the payload.
#[derive(Debug, Deserialize)]
pub struct ChatCompletion {
    /// Decoded from `id`.
    #[serde(default)]
    pub id: Option<String>,
    /// Decoded from `model`.
    // NOTE(review): the `dead_code` allowance suggests nothing reads this yet; confirm.
    #[allow(dead_code)]
    #[serde(default)]
    pub model: Option<String>,
    /// Decoded from `provider`.
    #[serde(default)]
    pub provider: Option<String>,
    /// Decoded from `choices`; empty when absent.
    #[serde(default)]
    pub choices: Vec<Choice>,
    /// Decoded from `usage`.
    #[serde(default)]
    pub usage: Option<Usage>,
}
/// One entry of [`ChatCompletion::choices`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Decoded from `message`; this field is required.
    pub message: ResponseMessage,
    /// Decoded from `finish_reason`; `None` when absent.
    // NOTE(review): the `dead_code` allowance suggests nothing reads this yet; confirm.
    #[allow(dead_code)]
    #[serde(default)]
    pub finish_reason: Option<String>,
}
/// Message carried by a [`Choice`].
#[derive(Debug, Deserialize)]
pub struct ResponseMessage {
    /// Decoded from `content`; `None` when absent.
    #[serde(default)]
    pub content: Option<String>,
    /// Decoded from `images`; empty when absent.
    #[serde(default)]
    pub images: Vec<OutImage>,
}
/// One entry of [`ResponseMessage::images`].
#[derive(Debug, Deserialize)]
pub struct OutImage {
    /// Decoded from `image_url`; this field is required.
    pub image_url: ImageUrl,
}
/// Usage block of a [`ChatCompletion`].
#[derive(Debug, Deserialize)]
pub struct Usage {
    /// Decoded from `cost`; `None` when absent.
    ///
    /// NOTE(review): the unit/currency is not shown in this file; confirm
    /// against the API docs.
    #[serde(default)]
    pub cost: Option<f64>,
}
/// One entry of the `/models` listing.
///
/// Only `id` is required; every other field decodes to `None` when it is
/// missing from the payload.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Model {
    /// Model identifier (the tests in this file use values such as `"openai/gpt"`).
    pub id: String,
    /// Decoded from `name`.
    #[serde(default)]
    pub name: Option<String>,
    /// Decoded from `description`.
    #[serde(default)]
    pub description: Option<String>,
    /// Decoded from `context_length`.
    #[serde(default)]
    pub context_length: Option<u64>,
    /// Decoded from the `architecture` object.
    #[serde(default)]
    pub architecture: Option<Architecture>,
    /// Decoded from the `pricing` object.
    #[serde(default)]
    pub pricing: Option<Pricing>,
}
impl Model {
    /// Reports whether `needle` occurs, ignoring case, in the model's `id`,
    /// `name` or `description`.
    ///
    /// Fields are checked in that order and the search stops at the first
    /// hit; a `None` name or description simply never matches. Both sides are
    /// compared after `to_lowercase`.
    pub fn matches_search(&self, needle: &str) -> bool {
        let wanted = needle.to_lowercase();
        std::iter::once(self.id.as_str())
            .chain(self.name.as_deref())
            .chain(self.description.as_deref())
            .any(|field| field.to_lowercase().contains(&wanted))
    }
}
/// Capability description attached to a [`Model`].
///
/// The container-level `#[serde(default)]` fills every field missing from the
/// payload from `Architecture::default()`: `None` for the optional strings
/// and an empty list for the modality vectors.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Architecture {
    /// Decoded from `modality` (a test here uses `"text+image->text"`).
    pub modality: Option<String>,
    /// Decoded from `input_modalities` (e.g. `["text", "image"]`).
    pub input_modalities: Vec<String>,
    /// Decoded from `output_modalities` (e.g. `["text"]`).
    pub output_modalities: Vec<String>,
    /// Decoded from `tokenizer` (a test here uses `"GPT"`).
    pub tokenizer: Option<String>,
}
/// Price table attached to a [`Model`].
///
/// The container-level `#[serde(default)]` fills every entry missing from the
/// payload from `Pricing::default()`, i.e. `None`. All prices are kept as the
/// strings the API sent (a test here uses `"0.00000125"`); only `discount` is
/// numeric.
///
/// NOTE(review): units and currency are not shown in this file; confirm
/// against the API docs before doing arithmetic on these values.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Pricing {
    /// Decoded from `prompt`.
    pub prompt: Option<String>,
    /// Decoded from `completion`.
    pub completion: Option<String>,
    /// Decoded from `request`.
    pub request: Option<String>,
    /// Decoded from `image`.
    pub image: Option<String>,
    /// Decoded from `image_output`.
    pub image_output: Option<String>,
    /// Decoded from `image_token`.
    pub image_token: Option<String>,
    /// Decoded from `audio`.
    pub audio: Option<String>,
    /// Decoded from `audio_output`.
    pub audio_output: Option<String>,
    /// Decoded from `web_search`.
    pub web_search: Option<String>,
    /// Decoded from `internal_reasoning`.
    pub internal_reasoning: Option<String>,
    /// Decoded from `input_audio_cache`.
    pub input_audio_cache: Option<String>,
    /// Decoded from `input_cache_read`.
    pub input_cache_read: Option<String>,
    /// Decoded from `input_cache_write`.
    pub input_cache_write: Option<String>,
    /// Decoded from `discount` as a floating-point number.
    pub discount: Option<f64>,
}
// Unit tests. The HTTP methods are exercised against a local wiremock
// `MockServer`; filtering, query-pair building and JSON decoding are tested directly.
#[cfg(test)]
mod tests {
use serde_json::json;
use wiremock::matchers::{method, path, query_param};
use wiremock::{Mock, MockServer, ResponseTemplate};
use super::*;
// `list_models` must forward `q` and `sort` as query parameters (the mock only
// matches when both are present) and return the decoded `data` array.
#[tokio::test]
async fn list_models_sends_query_params_and_parses_data() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/models"))
.and(query_param("q", "openai"))
.and(query_param("sort", "newest"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"data": [
{"id": "openai/gpt", "name": "GPT", "context_length": 128000}
]
})))
.mount(&server)
.await;
let client = OpenRouterClient::with_base_url(server.uri(), "test-key");
let query = ModelsQuery {
q: Some("openai".to_string()),
sort: Some("newest".to_string()),
..Default::default()
};
let models = client.list_models(&query).await.unwrap();
assert_eq!(models.len(), 1);
assert_eq!(models[0].id, "openai/gpt");
assert_eq!(models[0].context_length, Some(128_000));
}
// A 401 reply must surface as an error whose top-level message mentions "error status".
#[tokio::test]
async fn list_models_errors_on_non_success_status() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/models"))
.respond_with(ResponseTemplate::new(401))
.mount(&server)
.await;
let client = OpenRouterClient::with_base_url(server.uri(), "bad-key");
let err = client
.list_models(&ModelsQuery::default())
.await
.unwrap_err();
assert!(err.to_string().contains("error status"));
}
// `list_video_models` must hit `/videos/models` and decode the `pricing_skus` map.
#[tokio::test]
async fn list_video_models_parses_pricing_skus() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/videos/models"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"data": [
{"id": "google/veo", "pricing_skus": {"duration_seconds": "0.1"}}
]
})))
.mount(&server)
.await;
let client = OpenRouterClient::with_base_url(server.uri(), "test-key");
let vms = client.list_video_models().await.unwrap();
assert_eq!(vms.len(), 1);
assert_eq!(vms[0].id, "google/veo");
assert_eq!(
vms[0]
.pricing_skus
.get("duration_seconds")
.map(String::as_str),
Some("0.1")
);
}
// Fixture: `n` models with ids `model-0` .. `model-{n-1}` and every optional field unset.
fn models(n: usize) -> Vec<Model> {
(0..n)
.map(|i| Model {
id: format!("model-{i}"),
name: None,
description: None,
context_length: None,
architecture: None,
pricing: None,
})
.collect()
}
// Without `all`, 25 models are cut to DEFAULT_MODEL_LIMIT while `total` still reports 25.
#[test]
fn apply_filters_caps_at_default_limit_and_reports_total() {
let filtered = apply_filters(models(25), None, false);
assert_eq!(filtered.models.len(), DEFAULT_MODEL_LIMIT);
assert_eq!(filtered.total, 25);
assert_eq!(filtered.truncated(), 5);
}
// With `all`, nothing is cut and `truncated()` is zero.
#[test]
fn apply_filters_all_returns_everything_with_no_truncation() {
let filtered = apply_filters(models(25), None, true);
assert_eq!(filtered.models.len(), 25);
assert_eq!(filtered.total, 25);
assert_eq!(filtered.truncated(), 0);
}
// Fewer models than the limit pass through untouched.
#[test]
fn apply_filters_below_limit_is_not_truncated() {
let filtered = apply_filters(models(3), None, false);
assert_eq!(filtered.models.len(), 3);
assert_eq!(filtered.total, 3);
assert_eq!(filtered.truncated(), 0);
}
// Of `model-0..model-29`, the ids containing "model-2" are `model-2` and
// `model-20..=model-29`: 11 hits, below the cap, so none are truncated. The
// "special" name set on `model-1` must not produce a match.
#[test]
fn apply_filters_search_runs_before_truncation() {
let mut all = models(30);
all[1].name = Some("special".to_string());
let filtered = apply_filters(all, Some("model-2"), false);
assert_eq!(filtered.total, 11);
assert_eq!(filtered.models.len(), 11); assert!(filtered.models.iter().all(|m| m.id.contains("model-2")));
}
// The upper-case needle matches all 25 ids case-insensitively; `total` is
// counted before the cap is applied.
#[test]
fn apply_filters_search_then_cap_reports_pre_truncation_total() {
let filtered = apply_filters(models(25), Some("MODEL-"), false);
assert_eq!(filtered.total, 25);
assert_eq!(filtered.models.len(), DEFAULT_MODEL_LIMIT);
assert_eq!(filtered.truncated(), 5);
}
// `to_pairs` skips `None` fields and emits the rest in declaration order under their field names.
#[test]
fn query_pairs_omit_empty_fields_and_keep_expected_names() {
let query = ModelsQuery {
q: Some("openai".to_string()),
output_modalities: Some("image,text".to_string()),
input_modalities: None,
supported_parameters: Some("tools".to_string()),
sort: Some("newest".to_string()),
context: Some(128_000),
};
assert_eq!(
query.to_pairs(),
vec![
("q", "openai".to_string()),
("output_modalities", "image,text".to_string()),
("supported_parameters", "tools".to_string()),
("sort", "newest".to_string()),
("context", "128000".to_string()),
]
);
}
// One positive case per searched field (id, name, description) plus a negative case.
#[test]
fn matches_search_checks_id_name_and_description_case_insensitively() {
let model = Model {
id: "openai/gpt-audio-mini".to_string(),
name: Some("OpenAI: GPT Audio Mini".to_string()),
description: Some("A cost-efficient audio model.".to_string()),
context_length: None,
architecture: None,
pricing: None,
};
assert!(model.matches_search("OPENAI"));
assert!(model.matches_search("audio mini"));
assert!(model.matches_search("cost-efficient"));
assert!(!model.matches_search("anthropic"));
}
// A model carrying only `id` must still decode, with the optional fields left as `None`.
#[test]
fn models_response_decodes_missing_optional_fields() {
let json = r#"{
"data": [
{
"id": "provider/minimal"
}
]
}"#;
let parsed: ModelsResponse = serde_json::from_str(json).unwrap();
assert_eq!(parsed.data.len(), 1);
let model = &parsed.data[0];
assert_eq!(model.id, "provider/minimal");
assert!(model.name.is_none());
assert!(model.architecture.is_none());
assert!(model.pricing.is_none());
}
// A fully populated model decodes its nested `architecture` and `pricing` objects;
// pricing keys absent from the payload (here `image`) stay `None`.
#[test]
fn models_response_decodes_capabilities_and_pricing() {
let json = r#"{
"data": [
{
"id": "openai/example",
"name": "OpenAI Example",
"description": "Example model",
"context_length": 400000,
"architecture": {
"modality": "text+image->text",
"input_modalities": ["text", "image"],
"output_modalities": ["text"],
"tokenizer": "GPT"
},
"pricing": {
"prompt": "0.00000125",
"completion": "0.00001",
"web_search": "0.01",
"discount": 0.5
}
}
]
}"#;
let parsed: ModelsResponse = serde_json::from_str(json).unwrap();
let model = &parsed.data[0];
assert_eq!(model.context_length, Some(400_000));
let arch = model.architecture.as_ref().unwrap();
assert_eq!(arch.input_modalities, vec!["text", "image"]);
assert_eq!(arch.output_modalities, vec!["text"]);
assert_eq!(arch.tokenizer.as_deref(), Some("GPT"));
let pricing = model.pricing.as_ref().unwrap();
assert_eq!(pricing.prompt.as_deref(), Some("0.00000125"));
assert_eq!(pricing.completion.as_deref(), Some("0.00001"));
assert_eq!(pricing.web_search.as_deref(), Some("0.01"));
assert_eq!(pricing.discount, Some(0.5));
assert!(pricing.image.is_none());
}
}