use crate::builder::LlmBuilder;
use crate::error::LlmError;
use crate::params::{OpenAiParams, ResponseFormat, ToolChoice};
use crate::retry_api::RetryOptions;
use crate::types::*;
use super::OpenAiClient;
/// Fluent builder that collects OpenAI-specific settings and produces an
/// `OpenAiClient` via `build`.
///
/// Every setter consumes the builder and returns it, so calls can be chained.
pub struct OpenAiBuilder {
/// Shared builder state; `build` reads its `http_client` and `timeout`.
pub(crate) base: LlmBuilder,
/// Explicit API key. When `None`, `build` falls back to the
/// `OPENAI_API_KEY` environment variable.
api_key: Option<String>,
/// Explicit API base URL. When `None`, `build` uses
/// `https://api.openai.com/v1`.
base_url: Option<String>,
/// Optional organization value handed to the client constructor.
organization: Option<String>,
/// Optional project value handed to the client constructor.
project: Option<String>,
/// Copy of the model name set through `model()`; the same value is also
/// written to `common_params.model`.
model: Option<String>,
/// Provider-independent request parameters (model, temperature,
/// max tokens, top-p, stop sequences, seed).
common_params: CommonParams,
/// OpenAI-specific request parameters (response format, tool choice,
/// penalties, user, parallel tool calls).
openai_params: OpenAiParams,
/// HTTP configuration passed to the client; its `timeout`, when set, takes
/// precedence over `base.timeout` for the HTTP client created by `build`.
http_config: HttpConfig,
/// Tracing configuration; tracing is only initialized by `build` when this
/// is `Some`.
tracing_config: Option<crate::tracing::TracingConfig>,
/// Retry options installed on the client by `build`.
retry_options: Option<RetryOptions>,
}
#[cfg(feature = "openai")]
impl OpenAiBuilder {
    /// Creates a builder on top of the shared [`LlmBuilder`].
    ///
    /// All optional settings start as `None` and every parameter group starts
    /// at its `Default` value.
    pub fn new(base: LlmBuilder) -> Self {
        Self {
            base,
            api_key: None,
            base_url: None,
            organization: None,
            project: None,
            model: None,
            common_params: CommonParams::default(),
            openai_params: OpenAiParams::default(),
            http_config: HttpConfig::default(),
            tracing_config: None,
            retry_options: None,
        }
    }

    /// Sets the API key explicitly.
    ///
    /// When no key is set, [`build`](Self::build) falls back to the
    /// `OPENAI_API_KEY` environment variable.
    pub fn api_key<S: Into<String>>(mut self, key: S) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Overrides the API base URL (default: `https://api.openai.com/v1`).
    pub fn base_url<S: Into<String>>(mut self, url: S) -> Self {
        self.base_url = Some(url.into());
        self
    }

    /// Sets the organization value passed to the client.
    pub fn organization<S: Into<String>>(mut self, org: S) -> Self {
        self.organization = Some(org.into());
        self
    }

    /// Sets the project value passed to the client.
    pub fn project<S: Into<String>>(mut self, project: S) -> Self {
        self.project = Some(project.into());
        self
    }

    /// Sets the model name.
    ///
    /// The name is stored both on the builder itself and in the common
    /// parameters that are handed to the client.
    pub fn model<S: Into<String>>(mut self, model: S) -> Self {
        let model_str = model.into();
        self.model = Some(model_str.clone());
        self.common_params.model = model_str;
        self
    }

    /// Sets the `temperature` common parameter.
    pub const fn temperature(mut self, temp: f32) -> Self {
        self.common_params.temperature = Some(temp);
        self
    }

    /// Sets the `max_tokens` common parameter.
    pub const fn max_tokens(mut self, tokens: u32) -> Self {
        self.common_params.max_tokens = Some(tokens);
        self
    }

    /// Sets the `top_p` common parameter.
    pub const fn top_p(mut self, top_p: f32) -> Self {
        self.common_params.top_p = Some(top_p);
        self
    }

    /// Sets the `stop_sequences` common parameter, replacing any previous list.
    pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
        self.common_params.stop_sequences = Some(sequences);
        self
    }

    /// Sets the `seed` common parameter.
    pub const fn seed(mut self, seed: u64) -> Self {
        self.common_params.seed = Some(seed);
        self
    }

    /// Sets the OpenAI `response_format` parameter.
    pub fn response_format(mut self, format: ResponseFormat) -> Self {
        self.openai_params.response_format = Some(format);
        self
    }

    /// Sets the OpenAI `tool_choice` parameter.
    pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
        self.openai_params.tool_choice = Some(choice);
        self
    }

    /// Sets the OpenAI `frequency_penalty` parameter.
    pub const fn frequency_penalty(mut self, penalty: f32) -> Self {
        self.openai_params.frequency_penalty = Some(penalty);
        self
    }

    /// Sets the OpenAI `presence_penalty` parameter.
    pub const fn presence_penalty(mut self, penalty: f32) -> Self {
        self.openai_params.presence_penalty = Some(penalty);
        self
    }

    /// Sets the OpenAI `user` parameter.
    pub fn user<S: Into<String>>(mut self, user: S) -> Self {
        self.openai_params.user = Some(user.into());
        self
    }

    /// Sets the OpenAI `parallel_tool_calls` parameter.
    pub const fn parallel_tool_calls(mut self, enabled: bool) -> Self {
        self.openai_params.parallel_tool_calls = Some(enabled);
        self
    }

    /// Replaces the whole HTTP configuration.
    ///
    /// If the configuration carries a `timeout`, it takes precedence over the
    /// base builder's timeout when [`build`](Self::build) creates its own
    /// HTTP client.
    pub fn with_http_config(mut self, config: HttpConfig) -> Self {
        self.http_config = config;
        self
    }

    /// Sets the tracing configuration, replacing any previous one.
    pub fn tracing(mut self, config: crate::tracing::TracingConfig) -> Self {
        self.tracing_config = Some(config);
        self
    }

    /// Uses the `TracingConfig::development()` preset.
    pub fn debug_tracing(self) -> Self {
        self.tracing(crate::tracing::TracingConfig::development())
    }

    /// Uses the `TracingConfig::minimal()` preset.
    pub fn minimal_tracing(self) -> Self {
        self.tracing(crate::tracing::TracingConfig::minimal())
    }

    /// Uses the `TracingConfig::json_production()` preset.
    pub fn json_tracing(self) -> Self {
        self.tracing(crate::tracing::TracingConfig::json_production())
    }

    /// Enables tracing; equivalent to [`debug_tracing`](Self::debug_tracing).
    pub fn enable_tracing(self) -> Self {
        self.debug_tracing()
    }

    /// Uses the `TracingConfig::disabled()` preset.
    pub fn disable_tracing(self) -> Self {
        self.tracing(crate::tracing::TracingConfig::disabled())
    }

    /// Sets the `pretty_json` flag on the tracing configuration.
    ///
    /// If no tracing configuration has been set yet, the development preset is
    /// used as the starting point, so calling this always leaves a tracing
    /// configuration on the builder.
    pub fn pretty_json(mut self, pretty: bool) -> Self {
        let config = self.tracing_config_or_development();
        self.tracing_config = Some(
            crate::tracing::TracingConfigBuilder::from_config(config)
                .pretty_json(pretty)
                .build(),
        );
        self
    }

    /// Sets the `mask_sensitive_values` flag on the tracing configuration.
    ///
    /// If no tracing configuration has been set yet, the development preset is
    /// used as the starting point, so calling this always leaves a tracing
    /// configuration on the builder.
    pub fn mask_sensitive_values(mut self, mask: bool) -> Self {
        let config = self.tracing_config_or_development();
        self.tracing_config = Some(
            crate::tracing::TracingConfigBuilder::from_config(config)
                .mask_sensitive_values(mask)
                .build(),
        );
        self
    }

    /// Sets the retry options that `build` installs on the client.
    pub fn with_retry(mut self, options: RetryOptions) -> Self {
        self.retry_options = Some(options);
        self
    }

    /// Consumes the builder and constructs an [`OpenAiClient`].
    ///
    /// Resolution rules:
    /// - **API key**: the explicit key, otherwise the `OPENAI_API_KEY`
    ///   environment variable.
    /// - **Base URL**: the explicit URL, otherwise `https://api.openai.com/v1`.
    /// - **Tracing**: initialized only when a tracing configuration was set;
    ///   the resulting guard is handed to the client.
    /// - **HTTP client**: a client supplied on the base builder is used as is.
    ///   Otherwise a new `reqwest::Client` is created whose timeout is
    ///   `http_config.timeout`, else the base builder's timeout, else
    ///   `crate::defaults::http::REQUEST_TIMEOUT`.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::MissingApiKey`] when no API key is available, and
    /// propagates any error from `crate::tracing::init_tracing`.
    ///
    /// # Panics
    ///
    /// Panics if no HTTP client was supplied and the internal
    /// `reqwest::Client` cannot be constructed.
    pub async fn build(self) -> Result<OpenAiClient, LlmError> {
        // Build the error lazily so the message is only allocated on failure.
        let api_key = self
            .api_key
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .ok_or_else(|| LlmError::MissingApiKey("OpenAI API key not provided".to_string()))?;

        let base_url = self
            .base_url
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());

        let tracing_guard = match self.tracing_config {
            Some(tracing_config) => crate::tracing::init_tracing(tracing_config)?,
            None => None,
        };

        // Resolve the timeout once: the HTTP config wins over the base builder,
        // which wins over the crate-wide default.
        let timeout = self
            .http_config
            .timeout
            .or(self.base.timeout)
            .unwrap_or(crate::defaults::http::REQUEST_TIMEOUT);

        let http_client = match self.base.http_client {
            Some(client) => client,
            // NOTE(review): this still panics on failure, as the original did;
            // consider returning an `LlmError` once a suitable variant is
            // confirmed.
            None => reqwest::Client::builder()
                .timeout(timeout)
                .build()
                .expect("failed to construct the default reqwest HTTP client"),
        };

        let mut client = OpenAiClient::new_legacy(
            api_key,
            base_url,
            http_client,
            self.common_params,
            self.openai_params,
            self.http_config,
            self.organization,
            self.project,
        );
        client.set_tracing_guard(tracing_guard);
        // The builder is consumed here, so the options can be moved rather
        // than cloned.
        client.set_retry_options(self.retry_options);
        Ok(client)
    }

    /// Takes the current tracing configuration out of the builder, falling
    /// back to the development preset when none has been set.
    fn tracing_config_or_development(&mut self) -> crate::tracing::TracingConfig {
        self.tracing_config
            .take()
            .unwrap_or_else(crate::tracing::TracingConfig::development)
    }
}