ferrous_llm_openai/
config.rs

1//! OpenAI provider configuration.
2
3use ferrous_llm_core::{ConfigError, HttpConfig, ProviderConfig, SecretString, validation};
4use serde::{Deserialize, Serialize};
5use std::time::Duration;
6use url::Url;
7
/// Configuration for the OpenAI provider.
///
/// Construct via [`OpenAIConfig::new`], [`OpenAIConfig::builder`], or
/// [`OpenAIConfig::from_env`]. Call `validate()` (on the `ProviderConfig`
/// impl) before building a provider from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// OpenAI API key. Held as a `SecretString`; the tests in this file
    /// expect it to render as "[REDACTED]" in both `Debug` output and
    /// serde serialization (redaction is implemented by `SecretString`
    /// itself — NOTE(review): confirm in ferrous_llm_core).
    pub api_key: SecretString,

    /// Model to use (e.g., "gpt-4", "gpt-3.5-turbo")
    pub model: String,

    /// Base URL for the OpenAI API. When `None`, requests go to
    /// https://api.openai.com/v1 (see `base_url()`). `validate()` requires
    /// HTTPS when a custom URL is supplied.
    pub base_url: Option<Url>,

    /// Organization ID (optional)
    pub organization: Option<String>,

    /// Project ID (optional)
    pub project: Option<String>,

    /// HTTP client configuration (timeout, retries, extra headers)
    pub http: HttpConfig,

    /// Embedding model to use (e.g., "text-embedding-ada-002")
    pub embedding_model: Option<String>,
}
32
33impl Default for OpenAIConfig {
34    fn default() -> Self {
35        Self {
36            api_key: SecretString::new(""),
37            model: "gpt-3.5-turbo".to_string(),
38            base_url: None,
39            organization: None,
40            project: None,
41            http: HttpConfig::default(),
42            embedding_model: None,
43        }
44    }
45}
46
47impl ProviderConfig for OpenAIConfig {
48    type Provider = crate::provider::OpenAIProvider;
49
50    fn build(self) -> Result<Self::Provider, ConfigError> {
51        self.validate()?;
52        crate::provider::OpenAIProvider::new(self).map_err(|e| match e {
53            crate::error::OpenAIError::Config { source } => source,
54            _ => ConfigError::validation_failed("Failed to create provider"),
55        })
56    }
57
58    fn validate(&self) -> Result<(), ConfigError> {
59        // Validate API key
60        validation::validate_api_key(&self.api_key, "api_key")?;
61
62        // Validate model name
63        validation::validate_model_name(&self.model, "model")?;
64
65        // Validate base URL if provided
66        if let Some(ref url) = self.base_url {
67            validation::validate_https_url(url, "base_url")?;
68        }
69
70        // Validate HTTP configuration
71        validation::validate_positive_duration(self.http.timeout, "http.timeout")?;
72        validation::validate_range(self.http.max_retries, 0, 10, "http.max_retries")?;
73
74        Ok(())
75    }
76}
77
78impl OpenAIConfig {
79    /// Create a new OpenAI configuration with the given API key and model.
80    pub fn new(api_key: impl Into<SecretString>, model: impl Into<String>) -> Self {
81        Self {
82            api_key: api_key.into(),
83            model: model.into(),
84            ..Default::default()
85        }
86    }
87
88    /// Create a configuration builder.
89    pub fn builder() -> OpenAIConfigBuilder {
90        OpenAIConfigBuilder::new()
91    }
92
93    /// Get the base URL for API requests.
94    pub fn base_url(&self) -> &str {
95        self.base_url
96            .as_ref()
97            .map(|u| u.as_str())
98            .unwrap_or("https://api.openai.com/v1")
99    }
100
101    /// Get the chat completions endpoint URL.
102    pub fn chat_url(&self) -> String {
103        format!("{}/chat/completions", self.base_url())
104    }
105
106    /// Get the completions endpoint URL.
107    pub fn completions_url(&self) -> String {
108        format!("{}/completions", self.base_url())
109    }
110
111    /// Get the embeddings endpoint URL.
112    pub fn embeddings_url(&self) -> String {
113        format!("{}/embeddings", self.base_url())
114    }
115
116    /// Get the images endpoint URL.
117    pub fn images_url(&self) -> String {
118        format!("{}/images/generations", self.base_url())
119    }
120
121    /// Get the audio transcriptions endpoint URL.
122    pub fn transcriptions_url(&self) -> String {
123        format!("{}/audio/transcriptions", self.base_url())
124    }
125
126    /// Get the audio speech endpoint URL.
127    pub fn speech_url(&self) -> String {
128        format!("{}/audio/speech", self.base_url())
129    }
130
131    /// Load configuration from environment variables.
132    pub fn from_env() -> Result<Self, ConfigError> {
133        use ferrous_llm_core::env;
134
135        let api_key = env::required_secret("OPENAI_API_KEY")?;
136        let model = env::with_default("OPENAI_MODEL", "gpt-3.5-turbo");
137        let organization = env::optional("OPENAI_ORGANIZATION");
138        let project = env::optional("OPENAI_PROJECT");
139
140        let base_url = if let Some(url_str) = env::optional("OPENAI_BASE_URL") {
141            Some(validation::validate_url(&url_str, "OPENAI_BASE_URL")?)
142        } else {
143            None
144        };
145
146        Ok(Self {
147            api_key,
148            model,
149            base_url,
150            organization,
151            project,
152            http: HttpConfig::default(),
153            embedding_model: None,
154        })
155    }
156}
157
/// Builder for OpenAI configuration.
///
/// Obtain via [`OpenAIConfig::builder`]; chain setters and finish with
/// `build()`. `build()` does not validate — call `validate()` on the result.
pub struct OpenAIConfigBuilder {
    // Configuration being assembled; starts from `OpenAIConfig::default()`.
    config: OpenAIConfig,
}
162
163impl OpenAIConfigBuilder {
164    /// Create a new builder.
165    pub fn new() -> Self {
166        Self {
167            config: OpenAIConfig::default(),
168        }
169    }
170
171    /// Set the API key.
172    pub fn api_key(mut self, api_key: impl Into<SecretString>) -> Self {
173        self.config.api_key = api_key.into();
174        self
175    }
176
177    /// Set the model.
178    pub fn model(mut self, model: impl Into<String>) -> Self {
179        self.config.model = model.into();
180        self
181    }
182
183    /// Set the base URL.
184    pub fn base_url(mut self, base_url: impl Into<String>) -> Result<Self, ConfigError> {
185        let url = validation::validate_url(&base_url.into(), "base_url")?;
186        self.config.base_url = Some(url);
187        Ok(self)
188    }
189
190    /// Set the organization.
191    pub fn organization(mut self, organization: impl Into<String>) -> Self {
192        self.config.organization = Some(organization.into());
193        self
194    }
195
196    /// Set the project.
197    pub fn project(mut self, project: impl Into<String>) -> Self {
198        self.config.project = Some(project.into());
199        self
200    }
201
202    /// Set the request timeout.
203    pub fn timeout(mut self, timeout: Duration) -> Self {
204        self.config.http.timeout = timeout;
205        self
206    }
207
208    /// Set the maximum number of retries.
209    pub fn max_retries(mut self, max_retries: u32) -> Self {
210        self.config.http.max_retries = max_retries;
211        self
212    }
213
214    /// Set a custom HTTP header.
215    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
216        self.config.http.headers.insert(key.into(), value.into());
217        self
218    }
219
220    /// Build the configuration.
221    pub fn build(self) -> OpenAIConfig {
222        self.config
223    }
224}
225
226impl Default for OpenAIConfigBuilder {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_validation() {
        // A plausible key and model should pass validation.
        assert!(OpenAIConfig::new("sk-test123456789", "gpt-4").validate().is_ok());
    }

    #[test]
    fn test_config_validation_empty_api_key() {
        // An empty key must be rejected.
        assert!(OpenAIConfig::new("", "gpt-4").validate().is_err());
    }

    #[test]
    fn test_config_builder() {
        let built = OpenAIConfig::builder()
            .api_key("sk-test123456789")
            .model("gpt-4")
            .organization("org-123")
            .timeout(Duration::from_secs(60))
            .build();

        assert_eq!(built.model, "gpt-4");
        assert_eq!(built.organization, Some("org-123".to_string()));
        assert_eq!(built.http.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_urls() {
        let cfg = OpenAIConfig::new("sk-test", "gpt-4");
        // With no base_url set, endpoints resolve against the public API.
        assert_eq!(
            cfg.chat_url(),
            "https://api.openai.com/v1/chat/completions"
        );
        assert_eq!(
            cfg.embeddings_url(),
            "https://api.openai.com/v1/embeddings"
        );
    }

    #[test]
    fn test_custom_base_url() {
        let mut cfg = OpenAIConfig::new("sk-test", "gpt-4");
        cfg.base_url = Some("https://custom.openai.com/v1".parse().unwrap());
        assert_eq!(
            cfg.chat_url(),
            "https://custom.openai.com/v1/chat/completions"
        );
    }

    #[test]
    fn test_api_key_serialization_redaction() {
        // Build a realistic-looking key at runtime to avoid secret scanners.
        let secret = format!("sk-{}", "a".repeat(32));
        let cfg = OpenAIConfig::new(SecretString::from(secret.clone()), "gpt-4");

        // Serialize, then re-parse so we can inspect the api_key field.
        let json = serde_json::to_string(&cfg).expect("Failed to serialize config");
        let parsed: serde_json::Value =
            serde_json::from_str(&json).expect("Failed to parse JSON");
        let field = parsed.get("api_key").expect("api_key field not found");

        // The field must hold the placeholder, never the raw key.
        assert_eq!(field, "[REDACTED]");
        assert_ne!(field.as_str().unwrap(), secret);
        // The raw key must not appear anywhere in the serialized payload.
        assert!(!json.contains(&secret));
        // The in-memory value remains accessible via expose_secret.
        assert_eq!(cfg.api_key.expose_secret(), &secret);
    }

    #[test]
    fn test_config_debug_redaction() {
        let cfg = OpenAIConfig::new("sk-supersecrettestkey123", "gpt-4");
        let rendered = format!("{cfg:?}");

        // The key renders as the placeholder...
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("sk-supersecrettestkey123"));
        // ...while non-secret fields stay readable.
        assert!(rendered.contains("gpt-4"));
        assert!(rendered.contains("api_key"));
    }

    #[test]
    fn test_secret_string_serialization_various_keys() {
        // Several key shapes, built at runtime to avoid secret scanners.
        let tail = "c".repeat(32);
        for secret in [
            format!("sk-{tail}"),
            format!("sk-proj-{tail}"),
            format!("sk-org-{tail}"),
            "sk-short".to_string(),
        ] {
            let cfg = OpenAIConfig::new(secret.clone(), "gpt-4");
            let json = serde_json::to_string(&cfg).expect("Failed to serialize config");
            let parsed: serde_json::Value =
                serde_json::from_str(&json).expect("Failed to parse JSON");
            let field = parsed.get("api_key").expect("api_key field not found");

            // Every variant is redacted in the output...
            assert_eq!(field, "[REDACTED]");
            assert_ne!(field.as_str().unwrap(), secret);
            assert!(!json.contains(&secret));
            // ...but still readable in memory.
            assert_eq!(cfg.api_key.expose_secret(), secret);
        }
    }
}