use std::collections::HashMap;
use cognis_core::error::Result;
use crate::chat_models::openai::{ChatOpenAI, ChatOpenAIBuilder};
/// Base URL that `ChatOpenRouterBuilder::new` installs on the wrapped OpenAI
/// builder; callers can replace it through `ChatOpenRouterBuilder::base_url`.
// NOTE(review): there is no `/v1` suffix here — presumably the OpenAI client
// appends the versioned path itself; confirm against `ChatOpenAI`.
const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api";
/// Entry point for configuring an OpenRouter-backed chat model.
///
/// This is a field-less marker type: it only hosts the `builder()` associated
/// function, and the value produced at the end of the builder chain is a
/// `ChatOpenAI`, not a `ChatOpenRouter`.
///
/// The common traits are derived so the public type can be logged and copied
/// like any other zero-sized marker.
#[derive(Debug, Clone, Copy)]
pub struct ChatOpenRouter;
impl ChatOpenRouter {
/// Returns a fresh [`ChatOpenRouterBuilder`].
///
/// This is the public way to obtain a builder, since
/// `ChatOpenRouterBuilder::new` is private to this module.
pub fn builder() -> ChatOpenRouterBuilder {
ChatOpenRouterBuilder::new()
}
}
/// Builder obtained from `ChatOpenRouter::builder()`.
///
/// Every setter consumes the builder and returns the updated one, so calls are
/// meant to be chained and finished with `build()`. The type is `#[must_use]`
/// because dropping the value returned by a setter silently throws the
/// configuration away.
#[must_use = "builder methods return the updated builder; finish the chain with `build()`"]
pub struct ChatOpenRouterBuilder {
    /// Wrapped OpenAI builder; every setter forwards to it.
    inner: ChatOpenAIBuilder,
}
impl ChatOpenRouterBuilder {
    /// Creates a builder whose wrapped OpenAI builder already points at
    /// `OPENROUTER_BASE_URL`.
    fn new() -> Self {
        let inner = ChatOpenAI::builder().base_url(OPENROUTER_BASE_URL);
        Self { inner }
    }

    /// Sets the model identifier by forwarding it to the wrapped builder.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            inner: self.inner.model(model),
        }
    }

    /// Sets the API key by forwarding it to the wrapped builder.
    pub fn api_key(self, key: impl Into<String>) -> Self {
        Self {
            inner: self.inner.api_key(key),
        }
    }

    /// Registers `name` as the `X-Title` extra header.
    ///
    /// OpenRouter's API documentation describes this header as the optional
    /// application title used for attribution.
    pub fn app_name(self, name: impl Into<String>) -> Self {
        Self {
            inner: self.inner.extra_header("X-Title", name),
        }
    }

    /// Registers `url` as the `HTTP-Referer` extra header.
    ///
    /// OpenRouter's API documentation describes this header as the optional
    /// application URL used for attribution.
    pub fn app_url(self, url: impl Into<String>) -> Self {
        Self {
            inner: self.inner.extra_header("HTTP-Referer", url),
        }
    }

    /// Replaces the base URL that `new` installed by default.
    pub fn base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            inner: self.inner.base_url(base_url),
        }
    }

    /// Forwards the sampling temperature to the wrapped builder.
    pub fn temperature(self, temperature: f64) -> Self {
        Self {
            inner: self.inner.temperature(temperature),
        }
    }

    /// Forwards the `max_tokens` setting to the wrapped builder.
    pub fn max_tokens(self, max_tokens: u32) -> Self {
        Self {
            inner: self.inner.max_tokens(max_tokens),
        }
    }

    /// Forwards the `max_retries` setting to the wrapped builder.
    pub fn max_retries(self, max_retries: u32) -> Self {
        Self {
            inner: self.inner.max_retries(max_retries),
        }
    }

    /// Forwards the `streaming` flag to the wrapped builder.
    pub fn streaming(self, streaming: bool) -> Self {
        Self {
            inner: self.inner.streaming(streaming),
        }
    }

    /// Registers a single extra header on the wrapped builder.
    pub fn extra_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        Self {
            inner: self.inner.extra_header(key, value),
        }
    }

    /// Passes a whole map of extra headers to the wrapped builder.
    // NOTE(review): whether this merges with or replaces headers set earlier
    // (including `app_name` / `app_url`) is decided by `ChatOpenAIBuilder` —
    // confirm before relying on call order.
    pub fn extra_headers(self, headers: HashMap<String, String>) -> Self {
        Self {
            inner: self.inner.extra_headers(headers),
        }
    }

    /// Finishes the chain by delegating to the wrapped builder's `build`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped builder's `build` reports; this
    /// wrapper adds no validation of its own.
    pub fn build(self) -> Result<ChatOpenAI> {
        let Self { inner } = self;
        inner.build()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder API key handed to every builder in this module.
    const TEST_KEY: &str = "sk-or-test";

    /// Starts a builder with the given model id and the shared placeholder key.
    fn seeded(model_id: &str) -> ChatOpenRouterBuilder {
        ChatOpenRouter::builder().model(model_id).api_key(TEST_KEY)
    }

    /// Looks up one extra header on a built model and exposes it as `&str`.
    fn header<'m>(model: &'m ChatOpenAI, name: &str) -> Option<&'m str> {
        model.extra_headers.get(name).map(String::as_str)
    }

    #[test]
    fn builder_defaults_base_url_to_openrouter() {
        let built = seeded("meta-llama/llama-3.3-70b-instruct:free")
            .build()
            .unwrap();

        assert_eq!(built.base_url, OPENROUTER_BASE_URL);
    }

    #[test]
    fn app_name_sets_x_title_header() {
        let built = seeded("gpt-4o").app_name("my-assistant").build().unwrap();

        assert_eq!(header(&built, "X-Title"), Some("my-assistant"));
    }

    #[test]
    fn app_url_sets_http_referer_header() {
        let built = seeded("gpt-4o")
            .app_url("https://example.com")
            .build()
            .unwrap();

        assert_eq!(header(&built, "HTTP-Referer"), Some("https://example.com"));
    }

    #[test]
    fn app_name_and_app_url_coexist_with_extra_header() {
        let built = seeded("gpt-4o")
            .app_name("assistant")
            .app_url("https://example.com")
            .extra_header("X-Custom", "value")
            .build()
            .unwrap();

        // Exactly the three headers set above, nothing implicit.
        assert_eq!(built.extra_headers.len(), 3);

        let expected = [
            ("X-Title", "assistant"),
            ("HTTP-Referer", "https://example.com"),
            ("X-Custom", "value"),
        ];
        for (name, value) in expected {
            assert_eq!(header(&built, name), Some(value));
        }
    }

    #[test]
    fn base_url_override_is_respected() {
        let built = seeded("gpt-4o")
            .base_url("https://staging.openrouter.example")
            .build()
            .unwrap();

        assert_eq!(built.base_url, "https://staging.openrouter.example");
    }
}