use super::registry::get_provider_adapter;
use crate::retry_api::RetryOptions;
use crate::{LlmBuilder, LlmError};
/// Builder for clients that talk to an OpenAI-compatible provider API.
///
/// Wraps a base [`LlmBuilder`] and accumulates provider-level settings
/// (credentials, endpoint, model, sampling parameters, HTTP options,
/// provider-specific parameters and retry policy) until `build` is called.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing unless `.build()` is called"]
pub struct OpenAiCompatibleBuilder {
    /// Base builder; `build` merges its timeout, connect timeout, user agent,
    /// proxy, default headers and optional HTTP client into the final client.
    base: LlmBuilder,
    /// Provider identifier used to look up the adapter and the default model.
    provider_id: String,
    /// API key; `build` fails with a configuration error when it is unset.
    api_key: Option<String>,
    /// Endpoint override; when unset, `build` uses the adapter's base URL.
    base_url: Option<String>,
    /// Model name; starts as the provider's default chat model, if one exists.
    model: Option<String>,
    /// Sampling parameters (temperature, max tokens, top-p, stop sequences, seed).
    common_params: crate::types::CommonParams,
    /// HTTP settings configured directly on this builder.
    http_config: crate::types::HttpConfig,
    /// Provider-specific parameters (thinking/reasoning flags and budgets),
    /// keyed by parameter name.
    provider_specific_config: std::collections::HashMap<String, serde_json::Value>,
    /// Retry policy handed to the built client.
    retry_options: Option<RetryOptions>,
}
impl OpenAiCompatibleBuilder {
pub fn new(base: LlmBuilder, provider_id: &str) -> Self {
let default_model =
crate::providers::openai_compatible::default_models::get_default_chat_model(
provider_id,
)
.map(|model| model.to_string());
Self {
base,
provider_id: provider_id.to_string(),
api_key: None,
base_url: None,
model: default_model,
common_params: crate::types::CommonParams::default(),
http_config: crate::types::HttpConfig::default(),
provider_specific_config: std::collections::HashMap::new(),
retry_options: None,
}
}
pub fn api_key<S: Into<String>>(mut self, api_key: S) -> Self {
self.api_key = Some(api_key.into());
self
}
pub fn base_url<S: Into<String>>(mut self, base_url: S) -> Self {
self.base_url = Some(base_url.into());
self
}
pub fn model<S: Into<String>>(mut self, model: S) -> Self {
self.model = Some(model.into());
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.common_params.temperature = Some(temperature);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.common_params.max_tokens = Some(max_tokens);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.common_params.top_p = Some(top_p);
self
}
pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
self.http_config.timeout = Some(timeout);
self
}
pub fn connect_timeout(mut self, timeout: std::time::Duration) -> Self {
self.http_config.connect_timeout = Some(timeout);
self
}
pub fn user_agent<S: Into<String>>(mut self, user_agent: S) -> Self {
self.http_config.user_agent = Some(user_agent.into());
self
}
pub fn proxy<S: Into<String>>(mut self, proxy: S) -> Self {
self.http_config.proxy = Some(proxy.into());
self
}
pub fn custom_headers(mut self, headers: std::collections::HashMap<String, String>) -> Self {
self.http_config.headers = headers;
self
}
pub fn header<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
self.http_config.headers.insert(key.into(), value.into());
self
}
pub fn with_retry(mut self, options: RetryOptions) -> Self {
self.retry_options = Some(options);
self
}
pub fn stop<S: Into<String>>(mut self, stop: Vec<S>) -> Self {
self.common_params.stop_sequences = Some(stop.into_iter().map(|s| s.into()).collect());
self
}
pub fn seed(mut self, seed: u64) -> Self {
self.common_params.seed = Some(seed);
self
}
pub fn with_http_config(mut self, config: crate::types::HttpConfig) -> Self {
self.http_config = config;
self
}
pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
self.base = self.base.with_http_client(client);
self
}
pub fn with_thinking(mut self, enable: bool) -> Self {
self.provider_specific_config.insert(
"enable_thinking".to_string(),
serde_json::Value::Bool(enable),
);
self
}
pub fn with_thinking_budget(mut self, budget: u32) -> Self {
let clamped_budget = budget.clamp(128, 32768);
self.provider_specific_config.insert(
"thinking_budget".to_string(),
serde_json::Value::Number(serde_json::Number::from(clamped_budget)),
);
self
}
pub fn reasoning(mut self, enable: bool) -> Self {
match self.provider_id.as_str() {
"siliconflow" => {
self.provider_specific_config.insert(
"enable_thinking".to_string(),
serde_json::Value::Bool(enable),
);
}
"deepseek" | "openrouter" => {
self.provider_specific_config.insert(
"enable_reasoning".to_string(),
serde_json::Value::Bool(enable),
);
}
_ => {
self.provider_specific_config.insert(
"enable_reasoning".to_string(),
serde_json::Value::Bool(enable),
);
}
}
self
}
pub fn reasoning_budget(mut self, budget: i32) -> Self {
let clamped_budget = budget.clamp(128, 32768) as u32;
match self.provider_id.as_str() {
"siliconflow" => {
self.provider_specific_config.insert(
"thinking_budget".to_string(),
serde_json::Value::Number(serde_json::Number::from(clamped_budget)),
);
self.provider_specific_config
.insert("enable_thinking".to_string(), serde_json::Value::Bool(true));
}
"deepseek" | "openrouter" => {
self.provider_specific_config.insert(
"reasoning_budget".to_string(),
serde_json::Value::Number(serde_json::Number::from(clamped_budget)),
);
self.provider_specific_config.insert(
"enable_reasoning".to_string(),
serde_json::Value::Bool(true),
);
}
_ => {
self.provider_specific_config.insert(
"reasoning_budget".to_string(),
serde_json::Value::Number(serde_json::Number::from(clamped_budget)),
);
}
}
self
}
pub async fn build(
self,
) -> Result<crate::providers::openai_compatible::OpenAiCompatibleClient, LlmError> {
let api_key = self.api_key.ok_or_else(|| {
LlmError::ConfigurationError(format!("API key is required for {}", self.provider_id))
})?;
let adapter = get_provider_adapter(&self.provider_id)?;
let base_url = self
.base_url
.unwrap_or_else(|| adapter.base_url().to_string());
let mut config = crate::providers::openai_compatible::OpenAiCompatibleConfig::new(
&self.provider_id,
&api_key,
&base_url,
adapter,
);
if let Some(model) = self.model {
config = config.with_model(&model);
}
config = config.with_common_params(self.common_params);
let mut final_http_config = self.http_config;
if let Some(timeout) = self.base.timeout {
final_http_config.timeout = Some(timeout);
}
if let Some(connect_timeout) = self.base.connect_timeout {
final_http_config.connect_timeout = Some(connect_timeout);
}
if let Some(user_agent) = self.base.user_agent {
final_http_config.user_agent = Some(user_agent);
}
if let Some(proxy) = self.base.proxy {
final_http_config.proxy = Some(proxy);
}
for (key, value) in self.base.default_headers {
final_http_config.headers.insert(key, value);
}
config = config.with_http_config(final_http_config);
let mut client = if let Some(http_client) = self.base.http_client {
crate::providers::openai_compatible::OpenAiCompatibleClient::with_http_client(
config,
http_client,
)
.await?
} else {
crate::providers::openai_compatible::OpenAiCompatibleClient::new(config).await?
};
client.set_retry_options(self.retry_options.clone());
Ok(client)
}
}