use futures::{Stream, StreamExt};
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
Client,
};
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use tracing::{debug, instrument};
use crate::services::llm::models::errors::InferenceClientError;
use crate::services::llm::{
message::Message,
models::embedding::{EmbeddingsRequest, EmbeddingsResponse},
};
use crate::services::llm::{
models::chat::{ChatRequest, ChatResponse, ChatStreamChunk},
StructuredOuputFormat,
};
use crate::{
services::llm::models::base::{InferenceOptions, Role},
ClientConfig,
};
#[derive(Debug, Clone)]
pub struct OpenRouterClient {
client: Client,
base_url: String,
}
impl OpenRouterClient {
pub fn new(cfg: ClientConfig) -> Result<Self, InferenceClientError> {
let api_key = cfg
.api_key
.ok_or_else(|| InferenceClientError::Config("OpenRouter requires api_key".into()))?;
let base_url = cfg
.base_url
.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
InferenceClientError::Config(format!("Invalid api_key header: {e}"))
})?,
);
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
if let Some(extra) = cfg.extra_headers {
for (k, v) in extra.into_iter() {
let name = HeaderName::from_bytes(k.as_bytes()).map_err(|_| {
InferenceClientError::Config(format!("Invalid header name: {k}"))
})?;
let value = HeaderValue::from_str(&v).map_err(|_| {
InferenceClientError::Config(format!("Invalid header value for {k}"))
})?;
headers.insert(name, value);
}
}
let client = Client::builder().default_headers(headers).build()?;
Ok(Self { client, base_url })
}
fn map_messages(msgs: &[Message]) -> Vec<OrMessage> {
msgs.iter()
.map(|m| OrMessage {
role: match m.role {
Role::System => "system".to_string(),
Role::Developer => "system".to_string(),
Role::User => "user".to_string(),
Role::Assistant => "assistant".to_string(),
Role::Tool => "tool".to_string(),
},
content: m.content.clone().unwrap_or_default(),
})
.collect()
}
fn map_options(opts: &Option<InferenceOptions>) -> OrParams {
let mut p = OrParams::default();
if let Some(o) = opts {
p.temperature = o.temperature;
p.top_p = o.top_p;
p.top_k = o.top_k;
p.max_tokens = o.max_tokens.or(o.num_predict);
p.frequency_penalty = o.frequency_penalty;
p.presence_penalty = o.presence_penalty;
p.stop = o.stop.clone();
}
p
}
#[instrument(name = "openrouter.chat", skip_all)]
async fn chat_inner(
&self,
req: ChatRequest,
stream: bool,
) -> Result<reqwest::Response, InferenceClientError> {
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
let mut body = OrChatRequest::from(req);
body.stream = Some(stream);
let resp = self.client.post(url).json(&body).send().await?;
Ok(resp)
}
pub async fn chat(&self, req: ChatRequest) -> Result<ChatResponse, InferenceClientError> {
let resp = self.chat_inner(req, false).await?;
let status = resp.status();
let text = resp.text().await?;
if !status.is_success() {
if let Some(e) = parse_oopen_router_error(&text) {
return Err(e);
}
return Err(InferenceClientError::Api(format!(
"Request failed: {status} - {text}"
)));
}
if let Some(e) = parse_oopen_router_error(&text) {
return Err(e);
}
let or: OrChatResponse = serde_json::from_str(&text).map_err(|e| {
InferenceClientError::Serialization(format!("decode error: {e}; raw: {text}"))
})?;
let message = or
.choices
.first()
.map(|c| Message::assistant(c.message.content.clone()))
.unwrap_or(Message::assistant(String::new()));
Ok(ChatResponse {
model: or.model,
created_at: or.created.to_string(),
message,
done: true,
done_reason: or.choices.first().and_then(|c| c.finish_reason.clone()),
total_duration: None,
load_duration: None,
prompt_eval_count: None,
prompt_eval_duration: None,
eval_count: None,
eval_duration: None,
})
}
pub async fn chat_stream(
&self,
req: ChatRequest,
) -> Result<
Pin<Box<dyn Stream<Item = Result<ChatStreamChunk, InferenceClientError>> + Send + 'static>>,
InferenceClientError,
> {
use async_stream::try_stream;
let resp = self.chat_inner(req, true).await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await.unwrap_or_default();
if let Some(e) = parse_oopen_router_error(&text) {
return Err(e);
}
return Err(InferenceClientError::Api(format!(
"Request failed: {status} - {text}"
)));
}
let byte_stream = resp.bytes_stream();
let s = try_stream! {
let mut buf = Vec::<u8>::new();
futures::pin_mut!(byte_stream);
while let Some(chunk) = byte_stream.next().await {
let chunk = chunk.map_err(|e| InferenceClientError::Request(e.to_string()))?;
buf.extend_from_slice(&chunk);
while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
let line: Vec<u8> = buf.drain(..=pos).collect();
let line = String::from_utf8_lossy(&line).trim().to_string();
if line.is_empty() || line.starts_with(':') { continue; }
if !line.starts_with("data:") { continue; }
let data = line[5..].trim();
if data.contains("[DONE]") {
yield ChatStreamChunk {
model: String::new(),
created_at: String::new(),
message: None,
done: true,
done_reason: Some("stop".into()),
total_duration: None,
load_duration: None,
prompt_eval_count: None,
prompt_eval_duration: None,
eval_count: None,
eval_duration: None,
};
return;
}
if let Some(e) = parse_oopen_router_error(data) {
Err(e)?;
}
let parsed: OrStreamChunk = match serde_json::from_str(data) {
Ok(v) => v,
Err(e) => { debug!(err = %e, raw = %data, "stream json decode error"); continue; }
};
let mut out_msg: Option<Message> = None;
if let Some(choice) = parsed.choices.first() {
if let Some(content) = choice.delta.content.clone() {
out_msg = Some(Message::assistant(content));
}
}
yield ChatStreamChunk {
model: parsed.model,
created_at: parsed.created.to_string(),
message: out_msg,
done: false,
done_reason: None,
total_duration: None,
load_duration: None,
prompt_eval_count: None,
prompt_eval_duration: None,
eval_count: None,
eval_duration: None,
};
}
}
yield ChatStreamChunk {
model: String::new(),
created_at: String::new(),
message: None,
done: true,
done_reason: Some("eof".into()),
total_duration: None,
load_duration: None,
prompt_eval_count: None,
prompt_eval_duration: None,
eval_count: None,
eval_duration: None,
};
};
Ok(Box::pin(s))
}
pub async fn embeddings(
&self,
_req: EmbeddingsRequest,
) -> Result<EmbeddingsResponse, InferenceClientError> {
Err(InferenceClientError::Unsupported(
"OpenRouter embeddings are not available".into(),
))
}
}
#[derive(Serialize, Default)]
struct OrChatRequest {
model: String,
messages: Vec<OrMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_a: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
min_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
seed: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
logit_bias: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
logprobs: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
top_logprobs: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
parallel_tool_calls: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
structured_outputs: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
verbosity: Option<String>,
}
impl From<ChatRequest> for OrChatRequest {
fn from(value: ChatRequest) -> Self {
let ChatRequest {
base,
messages,
tools,
} = value;
let params = OpenRouterClient::map_options(&base.options);
let stop_vec = match base.options.as_ref().and_then(|o| o.stop.clone()) {
Some(s) => Some(vec![s]), None => None,
};
Self {
model: base.model,
messages: OpenRouterClient::map_messages(&messages), temperature: params.temperature,
top_p: params.top_p,
top_k: params.top_k,
top_a: None,
min_p: None,
repetition_penalty: None,
seed: None,
max_tokens: params.max_tokens,
presence_penalty: params.presence_penalty,
frequency_penalty: params.frequency_penalty,
logit_bias: None,
logprobs: None,
top_logprobs: None,
stop: stop_vec,
stream: base.stream,
tools: tools.and_then(|t| serde_json::to_value(t).ok()),
tool_choice: None,
parallel_tool_calls: None,
response_format: base.format,
structured_outputs: None,
verbosity: None,
}
}
}
#[derive(Serialize, Default)]
struct OrParams {
temperature: Option<f32>,
top_p: Option<f32>,
top_k: Option<u32>,
max_tokens: Option<i32>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
stop: Option<String>,
}
#[derive(Serialize, Deserialize, Clone)]
struct OrMessage {
role: String,
content: String,
}
#[derive(Deserialize)]
struct OrChoice {
message: OrMessage,
finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct OrChatResponse {
id: String,
provider: String,
model: String,
object: String,
created: u64,
choices: Vec<OrChoice>,
}
#[derive(Deserialize)]
struct OrDelta {
_role: Option<String>,
content: Option<String>,
}
#[derive(Deserialize)]
struct OrDeltaChoice {
delta: OrDelta,
_finish_reason: Option<String>,
}
#[derive(Deserialize)]
struct OrStreamChunk {
_id: String,
created: u64,
model: String,
choices: Vec<OrDeltaChoice>,
}
#[derive(Deserialize, Debug)]
struct OrErrorEnvelope {
error: OrErrorBody,
#[allow(dead_code)]
user_id: Option<String>,
}
#[derive(Deserialize, Debug)]
struct OrErrorBody {
message: String,
code: serde_json::Value,
}
fn parse_oopen_router_error(text: &str) -> Option<InferenceClientError> {
let s = text.trim_start();
if !s.starts_with('{') || !s.contains("\"error\"") {
return None;
}
match serde_json::from_str::<OrErrorEnvelope>(s) {
Ok(env) => {
let code = env.error.code;
let msg = env.error.message;
Some(InferenceClientError::Api(format!(
"OpenRouter error {code}: {msg}"
)))
}
Err(_) => None,
}
}
impl StructuredOuputFormat for OpenRouterClient {
fn format(spec: &crate::services::llm::SchemaSpec) -> serde_json::Value {
serde_json::json!({
"type": "json_schema",
"json_schema": {
"name": spec.name.clone().unwrap_or_else(|| "schema".to_string()),
"strict": spec.strict.unwrap_or(false),
"schema": spec.schema
}
})
}
}