use secrecy::{ExposeSecret, SecretString};
use std::collections::HashMap;
use crate::params::OpenAiParams;
use crate::types::{CommonParams, HttpConfig, WebSearchConfig};
/// Configuration for an OpenAI client: credentials, endpoint, request
/// parameters and feature switches.
///
/// Build one with [`OpenAiConfig::new`] and the `with_*` builder methods, then
/// check it with [`OpenAiConfig::validate`].
#[derive(Debug, Clone)]
pub struct OpenAiConfig {
/// Secret API key; `auth_header` exposes it to build the `Bearer` value.
pub api_key: SecretString,
/// Base URL of the API. `new` and `Default` use `https://api.openai.com/v1`;
/// `validate` requires an `http://` or `https://` prefix.
pub base_url: String,
/// Optional organization ID, sent as the `OpenAI-Organization` header.
pub organization: Option<String>,
/// Optional project ID, sent as the `OpenAI-Project` header.
pub project: Option<String>,
/// Shared request parameters; this file touches its `model`, `temperature`,
/// `max_tokens` and `top_p` fields.
pub common_params: CommonParams,
/// OpenAI-specific parameters; `validate` range-checks its
/// `frequency_penalty` and `presence_penalty` fields.
pub openai_params: OpenAiParams,
/// HTTP settings. Stored only; nothing in this file interprets them.
pub http_config: HttpConfig,
/// Web-search settings; `enable_web_search` sets its `enabled` flag.
pub web_search_config: WebSearchConfig,
/// When `true`, `get_headers` adds the `OpenAI-Beta` header.
pub use_responses_api: bool,
/// Optional ID of an earlier response. Only stored here — presumably used to
/// chain Responses API calls; confirm against the request-building code.
pub previous_response_id: Option<String>,
/// Built-in tools collected via `with_built_in_tool` / `with_built_in_tools`.
pub built_in_tools: Vec<crate::types::OpenAiBuiltInTool>,
}
impl OpenAiConfig {
    /// Creates a configuration for `api_key` with default settings.
    ///
    /// The base URL is `https://api.openai.com/v1`, no organization or project
    /// is set, the Responses API is off, and every parameter group takes its
    /// `Default` value.
    pub fn new<S: Into<String>>(api_key: S) -> Self {
        Self {
            api_key: SecretString::from(api_key.into()),
            base_url: "https://api.openai.com/v1".to_string(),
            organization: None,
            project: None,
            common_params: CommonParams::default(),
            openai_params: OpenAiParams::default(),
            http_config: HttpConfig::default(),
            web_search_config: WebSearchConfig::default(),
            use_responses_api: false,
            previous_response_id: None,
            built_in_tools: Vec::new(),
        }
    }

    /// Replaces the base URL. The value is stored as given; call
    /// [`validate`](Self::validate) to check its scheme.
    #[must_use]
    pub fn with_base_url<S: Into<String>>(mut self, url: S) -> Self {
        self.base_url = url.into();
        self
    }

    /// Sets the organization ID sent in the `OpenAI-Organization` header.
    #[must_use]
    pub fn with_organization<S: Into<String>>(mut self, org: S) -> Self {
        self.organization = Some(org.into());
        self
    }

    /// Sets the project ID sent in the `OpenAI-Project` header.
    #[must_use]
    pub fn with_project<S: Into<String>>(mut self, project: S) -> Self {
        self.project = Some(project.into());
        self
    }

    /// Sets the model name in the common parameters.
    #[must_use]
    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
        self.common_params.model = model.into();
        self
    }

    /// Sets the sampling temperature. The range (0.0 to 2.0) is only enforced
    /// by [`validate`](Self::validate), not here.
    #[must_use]
    pub const fn with_temperature(mut self, temperature: f32) -> Self {
        self.common_params.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens in the common parameters.
    #[must_use]
    pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.common_params.max_tokens = Some(max_tokens);
        self
    }

    /// Replaces the web-search configuration.
    ///
    /// With `None`, a default configuration with `enabled` set to `true` is
    /// installed. With `Some(config)`, `config` is stored unchanged, so its own
    /// `enabled` flag decides whether web search is active.
    #[must_use]
    pub fn with_web_search(mut self, config: Option<WebSearchConfig>) -> Self {
        self.web_search_config = config.unwrap_or_else(|| WebSearchConfig {
            enabled: true,
            ..Default::default()
        });
        self
    }

    /// Turns web search on while keeping the rest of the current web-search
    /// configuration.
    #[must_use]
    pub const fn enable_web_search(mut self) -> Self {
        self.web_search_config.enabled = true;
        self
    }

    /// Enables or disables the Responses API switch.
    #[must_use]
    pub const fn with_responses_api(mut self, use_responses: bool) -> Self {
        self.use_responses_api = use_responses;
        self
    }

    /// Stores the ID of a previous response.
    #[must_use]
    pub fn with_previous_response_id<S: Into<String>>(mut self, response_id: S) -> Self {
        self.previous_response_id = Some(response_id.into());
        self
    }

    /// Appends one built-in tool to the existing list.
    #[must_use]
    pub fn with_built_in_tool(mut self, tool: crate::types::OpenAiBuiltInTool) -> Self {
        self.built_in_tools.push(tool);
        self
    }

    /// Appends several built-in tools to the existing list (does not replace it).
    #[must_use]
    pub fn with_built_in_tools(mut self, tools: Vec<crate::types::OpenAiBuiltInTool>) -> Self {
        self.built_in_tools.extend(tools);
        self
    }

    /// Returns the `Authorization` header value, `Bearer <api key>`.
    ///
    /// The returned `String` contains the secret in clear text; avoid logging it.
    pub fn auth_header(&self) -> String {
        format!("Bearer {}", self.api_key.expose_secret())
    }

    /// Returns a copy of the organization ID, if one is set.
    pub fn organization_header(&self) -> Option<String> {
        self.organization.clone()
    }

    /// Returns a copy of the project ID, if one is set.
    pub fn project_header(&self) -> Option<String> {
        self.project.clone()
    }

    /// Builds the HTTP headers implied by this configuration.
    ///
    /// `Authorization` and `Content-Type` are always present.
    /// `OpenAI-Organization` and `OpenAI-Project` are added when the
    /// corresponding field is set, and `OpenAI-Beta` when the Responses API
    /// switch is on.
    pub fn get_headers(&self) -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), self.auth_header());
        headers.insert("Content-Type".to_string(), "application/json".to_string());

        if let Some(org) = &self.organization {
            headers.insert("OpenAI-Organization".to_string(), org.clone());
        }
        if let Some(project) = &self.project {
            headers.insert("OpenAI-Project".to_string(), project.clone());
        }
        if self.use_responses_api {
            headers.insert(
                "OpenAI-Beta".to_string(),
                "responses-2024-12-17".to_string(),
            );
        }

        headers
    }

    /// Checks the configuration and reports the first problem found.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - the API key is empty or consists only of whitespace;
    /// - the base URL is empty or does not start with `http://` / `https://`;
    /// - the temperature is set and outside `0.0..=2.0`;
    /// - top-p is set and outside `0.0..=1.0`;
    /// - the frequency or presence penalty is set and outside `-2.0..=2.0`.
    ///
    /// NaN lies in none of these ranges, so a NaN parameter is rejected too.
    pub fn validate(&self) -> Result<(), String> {
        // A whitespace-only key would yield a header with no usable credential,
        // so it is treated the same as an empty key.
        if self.api_key.expose_secret().trim().is_empty() {
            return Err("API key cannot be empty".to_string());
        }
        if self.base_url.is_empty() {
            return Err("Base URL cannot be empty".to_string());
        }
        if !self.base_url.starts_with("http://") && !self.base_url.starts_with("https://") {
            return Err("Base URL must start with http:// or https://".to_string());
        }

        // Unset (`None`) parameters are always accepted; only explicit values
        // are range-checked.
        if self
            .common_params
            .temperature
            .is_some_and(|temp| !(0.0..=2.0).contains(&temp))
        {
            return Err("Temperature must be between 0.0 and 2.0".to_string());
        }
        if self
            .common_params
            .top_p
            .is_some_and(|top_p| !(0.0..=1.0).contains(&top_p))
        {
            return Err("Top-p must be between 0.0 and 1.0".to_string());
        }
        if self
            .openai_params
            .frequency_penalty
            .is_some_and(|penalty| !(-2.0..=2.0).contains(&penalty))
        {
            return Err("Frequency penalty must be between -2.0 and 2.0".to_string());
        }
        if self
            .openai_params
            .presence_penalty
            .is_some_and(|penalty| !(-2.0..=2.0).contains(&penalty))
        {
            return Err("Presence penalty must be between -2.0 and 2.0".to_string());
        }

        Ok(())
    }
}
impl Default for OpenAiConfig {
fn default() -> Self {
Self {
api_key: SecretString::from(String::new()),
base_url: "https://api.openai.com/v1".to_string(),
organization: None,
project: None,
common_params: CommonParams::default(),
openai_params: OpenAiParams::default(),
http_config: HttpConfig::default(),
web_search_config: WebSearchConfig::default(),
use_responses_api: false,
previous_response_id: None,
built_in_tools: Vec::new(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` keeps the supplied key and applies the default base URL.
    #[test]
    fn test_config_creation() {
        let cfg = OpenAiConfig::new("test-key");

        assert_eq!(cfg.api_key.expose_secret(), "test-key");
        assert_eq!(cfg.base_url, "https://api.openai.com/v1");
    }

    /// A fresh config validates; the same config with an empty key does not.
    #[test]
    fn test_config_validation() {
        let valid = OpenAiConfig::new("test-key");
        assert!(valid.validate().is_ok());

        let without_key = OpenAiConfig {
            api_key: SecretString::from(String::new()),
            ..valid
        };
        assert!(without_key.validate().is_err());
    }

    /// Auth, organization and project headers carry the configured values.
    #[test]
    fn test_headers() {
        let headers = OpenAiConfig::new("test-key")
            .with_organization("org-123")
            .with_project("proj-456")
            .get_headers();
        let header = |name: &str| headers.get(name).map(String::as_str);

        assert_eq!(header("Authorization"), Some("Bearer test-key"));
        assert_eq!(header("OpenAI-Organization"), Some("org-123"));
        assert_eq!(header("OpenAI-Project"), Some("proj-456"));
    }
}