use std::collections::HashMap;
use derive_builder::Builder;
use futures_util::{AsyncBufReadExt, StreamExt, stream::BoxStream};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use surf::http::headers::AUTHORIZATION;
use crate::{
error::OpenRouterError,
strip_option_map_setter, strip_option_vec_setter,
types::{
ProviderPreferences, ReasoningConfig, ResponseFormat, Role, completion::CompletionsResponse,
},
utils::handle_error,
};
/// Image reference carried by an image content part.
///
/// Serializes as an object with a `url` key and, when present, a `detail` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    /// Image location exactly as supplied by the caller.
    pub url: String,
    /// Optional detail setting; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
impl ImageUrl {
    /// Creates an image reference with no `detail` setting.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self { url, detail: None }
    }

    /// Creates an image reference that also carries a `detail` setting.
    pub fn with_detail(url: impl Into<String>, detail: impl Into<String>) -> Self {
        let mut image = Self::new(url);
        image.detail = Some(detail.into());
        image
    }
}
/// Audio payload carried by an audio content part.
///
/// Both fields are plain strings and are always serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputAudio {
    /// Audio data as a string (presumably encoded, e.g. base64 — confirm against the API docs).
    pub data: String,
    /// Format label describing `data`.
    pub format: String,
}
impl InputAudio {
    /// Creates an audio payload from its data string and format label.
    pub fn new(data: impl Into<String>, format: impl Into<String>) -> Self {
        let (data, format) = (data.into(), format.into());
        Self { data, format }
    }
}
/// Video reference carried by the video content parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VideoUrl {
    /// Video location exactly as supplied by the caller.
    pub url: String,
}
impl VideoUrl {
    /// Creates a video reference from anything convertible into a `String`.
    pub fn new(url: impl Into<String>) -> Self {
        let url: String = url.into();
        Self { url }
    }
}
/// File attachment carried by a file content part.
///
/// Every field is optional and is omitted from the serialized form when `None`;
/// the default value has all three unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FileInput {
    /// Inline file contents, when the file is passed by value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_data: Option<String>,
    /// Identifier of a file, when the file is passed by reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_id: Option<String>,
    /// Optional file name accompanying either form.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filename: Option<String>,
}
impl FileInput {
    /// Creates a file input that carries its contents inline; `file_id` and
    /// `filename` stay unset.
    pub fn from_data(file_data: impl Into<String>) -> Self {
        Self {
            file_data: Some(file_data.into()),
            ..Self::default()
        }
    }

    /// Creates a file input that refers to a file by id; `file_data` and
    /// `filename` stay unset.
    pub fn from_id(file_id: impl Into<String>) -> Self {
        Self {
            file_id: Some(file_id.into()),
            ..Self::default()
        }
    }

    /// Returns `self` with the file name set, replacing any previous one.
    pub fn filename(self, filename: impl Into<String>) -> Self {
        Self {
            filename: Some(filename.into()),
            ..self
        }
    }
}
/// Kind of cache-control marker; serialized in lowercase (`"ephemeral"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    /// The only kind currently modelled.
    Ephemeral,
}
/// Cache-control marker that can be attached to a text content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheControl {
    /// Marker kind; appears under the JSON key `type`.
    #[serde(rename = "type")]
    pub kind: CacheControlType,
    /// Optional time-to-live string; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ttl: Option<String>,
}
impl CacheControl {
    /// Creates an ephemeral marker without a TTL.
    pub fn ephemeral() -> Self {
        Self {
            kind: CacheControlType::Ephemeral,
            ttl: None,
        }
    }

    /// Creates an ephemeral marker that carries the given TTL string.
    pub fn ephemeral_with_ttl(ttl: impl Into<String>) -> Self {
        Self {
            ttl: Some(ttl.into()),
            ..Self::ephemeral()
        }
    }
}
/// One element of a multi-part message body.
///
/// Internally tagged: each variant serializes as an object whose `type` key is
/// the snake_case variant name (`text`, `image_url`, `input_audio`,
/// `video_url`, `input_video`, `file`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text, optionally annotated with a cache-control marker.
    Text {
        text: String,
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Image reference.
    ImageUrl { image_url: ImageUrl },
    /// Audio payload.
    InputAudio { input_audio: InputAudio },
    /// Video reference tagged `video_url`.
    VideoUrl { video_url: VideoUrl },
    /// Video reference tagged `input_video`; the payload key is still `video_url`.
    InputVideo { video_url: VideoUrl },
    /// File attachment.
    File { file: FileInput },
}
impl ContentPart {
    /// Text part without a cache-control marker.
    pub fn text(text: impl Into<String>) -> Self {
        let text = text.into();
        Self::Text {
            text,
            cache_control: None,
        }
    }

    /// Text part annotated with the given cache-control marker.
    pub fn text_with_cache_control(text: impl Into<String>, cache_control: CacheControl) -> Self {
        let text = text.into();
        let cache_control = Some(cache_control);
        Self::Text {
            text,
            cache_control,
        }
    }

    /// Text part marked with an ephemeral cache-control marker (no TTL).
    pub fn cacheable_text(text: impl Into<String>) -> Self {
        let marker = CacheControl::ephemeral();
        Self::text_with_cache_control(text, marker)
    }

    /// Text part marked with an ephemeral cache-control marker carrying `ttl`.
    pub fn cacheable_text_with_ttl(text: impl Into<String>, ttl: impl Into<String>) -> Self {
        let marker = CacheControl::ephemeral_with_ttl(ttl);
        Self::text_with_cache_control(text, marker)
    }

    /// Image part without a `detail` setting.
    pub fn image_url(url: impl Into<String>) -> Self {
        let image_url = ImageUrl::new(url);
        Self::ImageUrl { image_url }
    }

    /// Image part with a `detail` setting.
    pub fn image_url_with_detail(url: impl Into<String>, detail: impl Into<String>) -> Self {
        let image_url = ImageUrl::with_detail(url, detail);
        Self::ImageUrl { image_url }
    }

    /// Audio part built from a data string and a format label.
    pub fn input_audio(data: impl Into<String>, format: impl Into<String>) -> Self {
        let input_audio = InputAudio::new(data, format);
        Self::InputAudio { input_audio }
    }

    /// Video part using the `video_url` tag.
    pub fn video_url(url: impl Into<String>) -> Self {
        let video_url = VideoUrl::new(url);
        Self::VideoUrl { video_url }
    }

    /// Video part using the `input_video` tag.
    pub fn input_video(url: impl Into<String>) -> Self {
        let video_url = VideoUrl::new(url);
        Self::InputVideo { video_url }
    }

    /// File part carrying inline file data.
    pub fn file_data(file_data: impl Into<String>) -> Self {
        let file = FileInput::from_data(file_data);
        Self::File { file }
    }

    /// File part carrying inline file data together with a file name.
    pub fn file_data_with_filename(
        file_data: impl Into<String>,
        filename: impl Into<String>,
    ) -> Self {
        let file = FileInput::from_data(file_data).filename(filename);
        Self::File { file }
    }

    /// File part referring to a file by id.
    pub fn file_id(file_id: impl Into<String>) -> Self {
        let file = FileInput::from_id(file_id);
        Self::File { file }
    }

    /// File part referring to a file by id together with a file name.
    pub fn file_id_with_filename(file_id: impl Into<String>, filename: impl Into<String>) -> Self {
        let file = FileInput::from_id(file_id).filename(filename);
        Self::File { file }
    }
}
/// Message body: either a single string or a list of content parts.
///
/// Untagged, so it serializes as a bare JSON string or a bare JSON array; on
/// deserialization the variants are tried in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Content {
    /// Body given as one string.
    Text(String),
    /// Body given as a list of parts.
    Parts(Vec<ContentPart>),
}
impl From<String> for Content {
fn from(s: String) -> Self {
Self::Text(s)
}
}
impl From<&str> for Content {
fn from(s: &str) -> Self {
Self::Text(s.to_string())
}
}
impl From<Vec<ContentPart>> for Content {
fn from(parts: Vec<ContentPart>) -> Self {
Self::Parts(parts)
}
}
/// A single chat message.
///
/// `role` and `content` are always serialized; the remaining fields are
/// omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author role of the message.
    pub role: Role,
    /// Message body (plain text or a list of parts).
    pub content: Content,
    /// Optional name; the tool-response constructor stores the tool name here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Id of the tool call this message responds to (set by the tool-response constructors).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls attached to the message (set by `assistant_with_tool_calls`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<crate::types::ToolCall>>,
}
impl Message {
    /// Creates a message with the given role and body; `name`, `tool_call_id`
    /// and `tool_calls` are left unset.
    pub fn new(role: Role, content: impl Into<Content>) -> Self {
        let content = content.into();
        Self {
            role,
            content,
            name: None,
            tool_call_id: None,
            tool_calls: None,
        }
    }

    /// Creates a message whose body is the given list of content parts.
    pub fn with_parts(role: Role, parts: Vec<ContentPart>) -> Self {
        Self::new(role, Content::Parts(parts))
    }

    /// Creates a `Role::Tool` message answering the tool call `tool_call_id`.
    pub fn tool_response(tool_call_id: &str, content: impl Into<Content>) -> Self {
        let mut message = Self::new(Role::Tool, content);
        message.tool_call_id = Some(tool_call_id.to_owned());
        message
    }

    /// Like [`Message::tool_response`], additionally recording the tool's name
    /// in the `name` field.
    pub fn tool_response_named(
        tool_call_id: &str,
        tool_name: &str,
        content: impl Into<Content>,
    ) -> Self {
        let mut message = Self::new(Role::Tool, content);
        message.name = Some(tool_name.to_owned());
        message.tool_call_id = Some(tool_call_id.to_owned());
        message
    }

    /// Creates a message with the given role and body and sets its `name`.
    pub fn named(role: Role, name: &str, content: impl Into<Content>) -> Self {
        let mut message = Self::new(role, content);
        message.name = Some(name.to_owned());
        message
    }

    /// Creates a `Role::Assistant` message that carries the given tool calls.
    pub fn assistant_with_tool_calls(
        content: impl Into<Content>,
        tool_calls: Vec<crate::types::ToolCall>,
    ) -> Self {
        let mut message = Self::new(Role::Assistant, content);
        message.tool_calls = Some(tool_calls);
        message
    }
}
/// Modality label used in the request's `modalities` list.
///
/// Serialized in lowercase: `text`, `image`, `audio`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Modality {
    Text,
    Image,
    Audio,
}
/// Options sent under the request's `debug` key.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DebugOptions {
    /// Optional flag; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo_upstream_body: Option<bool>,
}
/// Options sent under the request's `stream_options` key.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StreamOptions {
    /// Optional flag; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
/// Options sent under the request's `trace` key.
///
/// The named fields are omitted from the serialized form when `None`. Entries
/// of `extra` are flattened into the same JSON object, and keys that match no
/// named field are collected into `extra` on deserialization.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TraceOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trace_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trace_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub span_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_span_id: Option<String>,
    /// Additional free-form keys, serialized alongside the named fields.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}
/// Plugin entry for the request's `plugins` list.
///
/// Serializes as one JSON object holding `id` plus every entry of `config`
/// flattened next to it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Plugin {
    /// Plugin identifier.
    pub id: String,
    /// Plugin-specific settings, flattened into the plugin object.
    #[serde(flatten)]
    pub config: HashMap<String, Value>,
}
impl Plugin {
pub fn new(id: impl Into<String>) -> Self {
Self {
id: id.into(),
config: HashMap::new(),
}
}
pub fn option(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
self.config.insert(key.into(), value.into());
self
}
}
/// Value of the request's `stop` field: one string or a list of strings.
///
/// Untagged, so it serializes as a bare JSON string or a bare JSON array.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum StopSequence {
    /// A single stop string.
    Single(String),
    /// Several stop strings.
    Multiple(Vec<String>),
}
impl From<String> for StopSequence {
fn from(value: String) -> Self {
Self::Single(value)
}
}
impl From<&str> for StopSequence {
fn from(value: &str) -> Self {
Self::Single(value.to_string())
}
}
impl From<Vec<String>> for StopSequence {
fn from(value: Vec<String>) -> Self {
Self::Multiple(value)
}
}
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
#[builder(build_fn(error = "OpenRouterError"))]
pub struct ChatCompletionRequest {
#[builder(setter(into))]
model: String,
messages: Vec<Message>,
#[builder(setter(skip), default)]
#[serde(skip_serializing_if = "Option::is_none")]
stream: Option<bool>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
max_completion_tokens: Option<u32>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
seed: Option<u32>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
top_k: Option<u32>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f64>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f64>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
repetition_penalty: Option<f64>,
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
logit_bias: Option<HashMap<String, f64>>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
logprobs: Option<bool>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
top_logprobs: Option<u32>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
min_p: Option<f64>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
top_a: Option<f64>,
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
transforms: Option<Vec<String>>,
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
models: Option<Vec<String>>,
#[builder(setter(into, strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
route: Option<String>,
#[builder(setter(into, strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
user: Option<String>,
#[builder(setter(into, strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
trace: Option<TraceOptions>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
provider: Option<ProviderPreferences>,
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
metadata: Option<HashMap<String, String>>,
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
plugins: Option<Vec<Plugin>>,
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
modalities: Option<Vec<Modality>>,
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
image_config: Option<HashMap<String, Value>>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<ResponseFormat>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
reasoning: Option<ReasoningConfig>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
include_reasoning: Option<bool>,
#[builder(setter(into, strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
stop: Option<StopSequence>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
stream_options: Option<StreamOptions>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
debug: Option<DebugOptions>,
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<crate::types::Tool>>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<crate::types::ToolChoice>,
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
parallel_tool_calls: Option<bool>,
}
impl ChatCompletionRequestBuilder {
    // Setters for the fields declared with `setter(custom)`; their bodies are
    // generated by the crate's helper macros.
    strip_option_vec_setter!(models, String);
    strip_option_map_setter!(logit_bias, String, f64);
    strip_option_vec_setter!(transforms, String);
    strip_option_map_setter!(metadata, String, String);
    strip_option_map_setter!(image_config, String, Value);
    strip_option_vec_setter!(plugins, Plugin);
    strip_option_vec_setter!(modalities, Modality);
    strip_option_vec_setter!(tools, crate::types::Tool);

    /// Sets `reasoning` to `ReasoningConfig::enabled()`, replacing any
    /// previously configured reasoning settings.
    pub fn enable_reasoning(&mut self) -> &mut Self {
        let config = crate::types::ReasoningConfig::enabled();
        self.reasoning = Some(Some(config));
        self
    }

    /// Sets `reasoning` to `ReasoningConfig::with_effort(effort)`, replacing
    /// any previously configured reasoning settings.
    pub fn reasoning_effort(&mut self, effort: crate::types::Effort) -> &mut Self {
        let config = crate::types::ReasoningConfig::with_effort(effort);
        self.reasoning = Some(Some(config));
        self
    }

    /// Sets `reasoning` to `ReasoningConfig::with_max_tokens(max_tokens)`,
    /// replacing any previously configured reasoning settings.
    pub fn reasoning_max_tokens(&mut self, max_tokens: u32) -> &mut Self {
        let config = crate::types::ReasoningConfig::with_max_tokens(max_tokens);
        self.reasoning = Some(Some(config));
        self
    }

    /// Sets `reasoning` to `ReasoningConfig::excluded()`, replacing any
    /// previously configured reasoning settings.
    pub fn exclude_reasoning(&mut self) -> &mut Self {
        let config = crate::types::ReasoningConfig::excluded();
        self.reasoning = Some(Some(config));
        self
    }

    /// Appends one tool, creating the tool list first if none has been set.
    pub fn tool(&mut self, tool: crate::types::Tool) -> &mut Self {
        // The builder field is `Option<Option<Vec<_>>>`: the outer layer tracks
        // "was the field set", the inner layer is the field's own `Option`.
        self.tools
            .get_or_insert(None)
            .get_or_insert_with(Vec::new)
            .push(tool);
        self
    }

    /// Sets `tool_choice` to `ToolChoice::auto()`.
    pub fn tool_choice_auto(&mut self) -> &mut Self {
        let choice = crate::types::ToolChoice::auto();
        self.tool_choice = Some(Some(choice));
        self
    }

    /// Sets `tool_choice` to `ToolChoice::none()`.
    pub fn tool_choice_none(&mut self) -> &mut Self {
        let choice = crate::types::ToolChoice::none();
        self.tool_choice = Some(Some(choice));
        self
    }

    /// Sets `tool_choice` to `ToolChoice::required()`.
    pub fn tool_choice_required(&mut self) -> &mut Self {
        let choice = crate::types::ToolChoice::required();
        self.tool_choice = Some(Some(choice));
        self
    }

    /// Sets `tool_choice` to `ToolChoice::force_tool(tool_name)`.
    pub fn force_tool(&mut self, tool_name: &str) -> &mut Self {
        let choice = crate::types::ToolChoice::force_tool(tool_name);
        self.tool_choice = Some(Some(choice));
        self
    }

    /// Appends the tool produced by `T::create_tool()`.
    pub fn typed_tool<T: crate::types::TypedTool>(&mut self) -> &mut Self {
        self.tool(T::create_tool())
    }

    /// Appends a clone of every tool in `tools`, in order. An empty slice
    /// leaves the builder untouched.
    pub fn typed_tools_batch(&mut self, tools: &[crate::types::Tool]) -> &mut Self {
        tools.iter().cloned().for_each(|tool| {
            self.tool(tool);
        });
        self
    }

    /// Appends the tool produced by `T::create_tool()` and sets `tool_choice`
    /// to force the tool named `T::name()`.
    pub fn force_typed_tool<T: crate::types::TypedTool>(&mut self) -> &mut Self {
        let forced_name = T::name();
        self.tool(T::create_tool()).force_tool(forced_name)
    }
}
impl ChatCompletionRequest {
    /// Returns an empty builder for a request.
    pub fn builder() -> ChatCompletionRequestBuilder {
        ChatCompletionRequestBuilder::default()
    }

    /// Builds a request that sets only `model` and `messages`.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to build ChatCompletionRequest" if the builder
    /// reports an error.
    pub fn new(model: &str, messages: Vec<Message>) -> Self {
        let built = Self::builder().model(model).messages(messages).build();
        built.expect("Failed to build ChatCompletionRequest")
    }

    /// Tools attached to the request, if any.
    pub fn tools(&self) -> Option<&Vec<crate::types::Tool>> {
        Option::as_ref(&self.tools)
    }

    /// Tool-choice setting, if any.
    pub fn tool_choice(&self) -> Option<&crate::types::ToolChoice> {
        Option::as_ref(&self.tool_choice)
    }

    /// Value of the `parallel_tool_calls` flag, if set.
    pub fn parallel_tool_calls(&self) -> Option<bool> {
        self.parallel_tool_calls
    }

    /// Messages of the request, in order.
    pub fn messages(&self) -> &Vec<Message> {
        &self.messages
    }

    /// Returns a clone of the request whose `stream` flag is `Some(stream)`;
    /// `self` is left unchanged.
    fn stream(&self, stream: bool) -> Self {
        Self {
            stream: Some(stream),
            ..self.clone()
        }
    }
}
/// Sends a non-streaming chat completion request and returns the parsed response.
///
/// The request is cloned with `stream` set to `false` and posted as JSON to
/// `{base_url}/chat/completions` with a bearer `Authorization` header. The
/// `X-Title` and `HTTP-Referer` headers are added only when the matching
/// argument is `Some`.
///
/// # Errors
///
/// * Failures while encoding the body, performing the request or reading the
///   response body are propagated with `?`.
/// * A non-success status is passed to `handle_error`, whose error is returned.
/// * A body that does not deserialize into `CompletionsResponse` yields
///   `OpenRouterError::Serialization`; the parse error and raw body are also
///   written to stderr.
///
/// # Panics
///
/// Hits `unreachable!()` if `handle_error` returns `Ok` for a non-success
/// response. NOTE(review): this assumes `handle_error` always returns `Err` —
/// confirm against its implementation.
pub async fn send_chat_completion(
    base_url: &str,
    api_key: &str,
    x_title: &Option<String>,
    http_referer: &Option<String>,
    request: &ChatCompletionRequest,
) -> Result<CompletionsResponse, OpenRouterError> {
    let endpoint = format!("{base_url}/chat/completions");
    let payload = request.stream(false);

    let mut outgoing = surf::post(endpoint)
        .header(AUTHORIZATION, format!("Bearer {api_key}"))
        .body_json(&payload)?;
    if let Some(title) = x_title {
        outgoing = outgoing.header("X-Title", title);
    }
    if let Some(referer) = http_referer {
        outgoing = outgoing.header("HTTP-Referer", referer);
    }

    let mut response = outgoing.await?;

    // Failure path first: hand the response to the shared error handler.
    if !response.status().is_success() {
        handle_error(response).await?;
        unreachable!()
    }

    let body_text = response.body_string().await?;
    serde_json::from_str::<CompletionsResponse>(&body_text).map_err(|e| {
        eprintln!("Failed to deserialize response: {e}\nBody: {body_text}");
        OpenRouterError::Serialization(e)
    })
}
pub async fn stream_chat_completion(
base_url: &str,
api_key: &str,
request: &ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<CompletionsResponse, OpenRouterError>>, OpenRouterError> {
let url = format!("{base_url}/chat/completions");
let request = request.stream(true);
let response = surf::post(url)
.header(AUTHORIZATION, format!("Bearer {api_key}"))
.body_json(&request)?
.await?;
if response.status().is_success() {
let lines = response
.lines()
.filter_map(async |line| match line {
Ok(line) => line
.strip_prefix("data: ")
.filter(|line| *line != "[DONE]")
.map(serde_json::from_str::<CompletionsResponse>)
.map(|event| event.map_err(OpenRouterError::Serialization)),
Err(error) => Some(Err(OpenRouterError::Io(error))),
})
.boxed();
Ok(lines)
} else {
handle_error(response).await?;
unreachable!()
}
}