use std::collections::HashMap;
use derive_builder::Builder;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use crate::{
error::OpenRouterError,
strip_option_map_setter, strip_option_vec_setter,
transport::{request as transport_request, response as transport_response},
types::{
ProviderPreferences, ReasoningConfig, ResponseFormat, completion::CompletionsResponse,
},
};
/// Request body for the legacy, prompt-based completions endpoint.
///
/// Serialized as JSON and POSTed to the `/completions` path under the caller's
/// base URL by `send_completion_request`. Every `Option` field carries
/// `skip_serializing_if = "Option::is_none"`, so unset parameters are left out
/// of the JSON body entirely instead of being sent as `null`.
///
/// Each field is sent under a JSON key equal to its Rust name. Its meaning is
/// presumably that of the upstream API parameter of the same name; consult the
/// provider's completions API reference rather than this type for semantics.
///
/// Use [`CompletionRequest::new`] when only `model` and `prompt` are needed, or
/// [`CompletionRequest::builder`] to set optional parameters. `model` and
/// `prompt` have no builder default, so `build()` returns an
/// `OpenRouterError` if either is missing.
#[derive(Serialize, Deserialize, Debug, Builder)]
#[builder(build_fn(error = "OpenRouterError"))]
pub struct CompletionRequest {
/// Model identifier, sent as `model`. Required; the setter accepts any
/// `Into<String>`.
#[builder(setter(into))]
model: String,
/// Prompt text, sent as `prompt`. Required; the setter accepts any
/// `Into<String>`.
#[builder(setter(into))]
prompt: String,
/// Sent as `stream` when set.
///
/// NOTE(review): `send_completion_request_with_client` decodes the reply as a
/// single JSON document, so `Some(true)` presumably cannot be consumed by this
/// module's send functions — confirm before enabling.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
stream: Option<bool>,
/// Sent as `max_tokens` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
/// Sent as `temperature` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
/// Sent as `seed` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
seed: Option<u32>,
/// Sent as `top_p` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
/// Sent as `top_k` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
top_k: Option<u32>,
/// Sent as `frequency_penalty` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f64>,
/// Sent as `presence_penalty` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f64>,
/// Sent as `repetition_penalty` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
repetition_penalty: Option<f64>,
/// Sent as `logit_bias` when set: a JSON object mapping string keys to `f64`
/// values. The builder setter is custom (`setter(custom)`) and comes from
/// `strip_option_map_setter!` on `CompletionRequestBuilder`.
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
logit_bias: Option<HashMap<String, f64>>,
/// Sent as `top_logprobs` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
top_logprobs: Option<u32>,
/// Sent as `min_p` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
min_p: Option<f64>,
/// Sent as `top_a` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
top_a: Option<f64>,
/// Sent as `transforms` when set: a JSON array of strings. The builder setter
/// is custom and comes from `strip_option_vec_setter!`.
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
transforms: Option<Vec<String>>,
/// Sent as `models` when set: a JSON array of strings. The builder setter is
/// custom and comes from `strip_option_vec_setter!`.
#[builder(setter(custom), default)]
#[serde(skip_serializing_if = "Option::is_none")]
models: Option<Vec<String>>,
/// Sent as `route` when set; the setter accepts any `Into<String>`.
#[builder(setter(into, strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
route: Option<String>,
/// Provider preferences, sent as `provider` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
provider: Option<ProviderPreferences>,
/// Sent as `response_format` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<ResponseFormat>,
/// Reasoning configuration, sent as `reasoning` when set.
#[builder(setter(strip_option), default)]
#[serde(skip_serializing_if = "Option::is_none")]
reasoning: Option<ReasoningConfig>,
}
impl CompletionRequestBuilder {
    // `logit_bias`, `transforms` and `models` are declared with
    // `#[builder(setter(custom))]`, so `derive_builder` emits no setter for
    // them. The crate's strip-option macros presumably generate a setter named
    // after each field that takes the bare collection and wraps it in `Some`.
    // Listed here in struct-field order.
    strip_option_map_setter!(logit_bias, String, f64);
    strip_option_vec_setter!(transforms, String);
    strip_option_vec_setter!(models, String);
}
impl CompletionRequest {
pub fn builder() -> CompletionRequestBuilder {
CompletionRequestBuilder::default()
}
pub fn new(model: &str, prompt: &str) -> Self {
Self::builder()
.model(model)
.prompt(prompt)
.build()
.expect("Failed to build CompletionRequest")
}
}
/// Sends `request` to the completions endpoint under `base_url` using a newly
/// created HTTP client, and returns the decoded response.
///
/// A fresh client is built via `crate::transport::new_client` on every call.
/// The actual request/response handling is delegated to
/// `send_completion_request_with_client`.
///
/// # Errors
///
/// Returns an error if the HTTP client cannot be created, or any error produced
/// while sending the request or decoding its response.
pub async fn send_completion_request(
    base_url: &str,
    api_key: &str,
    x_title: &Option<String>,
    http_referer: &Option<String>,
    app_categories: &Option<Vec<String>>,
    request: &CompletionRequest,
) -> Result<CompletionsResponse, OpenRouterError> {
    let client = crate::transport::new_client()?;
    let pending = send_completion_request_with_client(
        &client, base_url, api_key, x_title, http_referer, app_categories, request,
    );
    pending.await
}
/// POSTs `request` as JSON to the `/completions` endpoint under `base_url`
/// using the supplied HTTP client, and decodes the reply.
///
/// Trailing `/` characters on `base_url` are ignored. For example, both
/// `https://host/api/v1` and `https://host/api/v1/` target
/// `https://host/api/v1/completions` rather than `…//completions`.
///
/// # Errors
///
/// Returns an error if:
/// - the request headers cannot be built;
/// - the HTTP request fails;
/// - the server replies with a non-success status (handled by
///   `transport_response::handle_error`);
/// - the success body cannot be parsed as a `CompletionsResponse`.
pub(crate) async fn send_completion_request_with_client(
    http_client: &HttpClient,
    base_url: &str,
    api_key: &str,
    x_title: &Option<String>,
    http_referer: &Option<String>,
    app_categories: &Option<Vec<String>>,
    request: &CompletionRequest,
) -> Result<CompletionsResponse, OpenRouterError> {
    // Strip trailing slashes so a base URL like ".../v1/" does not produce a
    // double slash in the request path.
    let url = format!("{}/completions", base_url.trim_end_matches('/'));
    let response = transport_request::with_client_request_headers(
        transport_request::post(http_client, &url),
        api_key,
        x_title,
        http_referer,
        app_categories,
    )?
    .json(request)
    .send()
    .await?;

    if !response.status().is_success() {
        // NOTE(review): assumes `handle_error` always returns `Err` for a
        // non-success response, so the `?` exits this function here.
        transport_response::handle_error(response).await?;
        unreachable!("handle_error returned Ok for a non-success response");
    }

    let completion_response =
        transport_response::parse_json_response(response, "legacy completion").await?;
    Ok(completion_response)
}