litellm_rs/config/
builder.rs

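//! Configuration builder for type-safe configuration construction
//!
//! This module provides ergonomic builder types for assembling gateway,
//! server, and provider configurations. Constraints are enforced partly by
//! validated wrapper types accepted by the setters and partly by runtime
//! validation when a configuration is built.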
5
6#![allow(dead_code)] // Builder module - functions may be used in the future
7
8use super::{AuthConfig, Config, GatewayConfig, ProviderConfig, ServerConfig, StorageConfig};
9use crate::utils::data::type_utils::{Builder, NonEmptyString, PositiveF64};
10use crate::utils::error::{GatewayError, Result};
11use std::collections::HashMap;
12use std::time::Duration;
13
14/// Builder for creating gateway configurations
15#[derive(Debug, Clone)]
16pub struct ConfigBuilder {
17    server: Option<ServerConfig>,
18    auth: Option<AuthConfig>,
19    storage: Option<StorageConfig>,
20    providers: Vec<ProviderConfig>,
21    features: HashMap<String, bool>,
22}
23
24impl ConfigBuilder {
25    /// Create a new configuration builder
26    pub fn new() -> Self {
27        Self {
28            server: None,
29            auth: None,
30            storage: None,
31            providers: Vec::new(),
32            features: HashMap::new(),
33        }
34    }
35
36    /// Set the server configuration
37    pub fn with_server(mut self, config: ServerConfig) -> Self {
38        self.server = Some(config);
39        self
40    }
41
42    /// Set the authentication configuration
43    pub fn with_auth(mut self, config: AuthConfig) -> Self {
44        self.auth = Some(config);
45        self
46    }
47
48    /// Set the storage configuration
49    pub fn with_storage(mut self, config: StorageConfig) -> Self {
50        self.storage = Some(config);
51        self
52    }
53
54    /// Add a provider configuration
55    pub fn add_provider(mut self, config: ProviderConfig) -> Self {
56        self.providers.push(config);
57        self
58    }
59
60    /// Add multiple provider configurations
61    pub fn add_providers(mut self, configs: Vec<ProviderConfig>) -> Self {
62        self.providers.extend(configs);
63        self
64    }
65
66    /// Enable a feature
67    pub fn enable_feature(mut self, feature: impl Into<String>) -> Self {
68        self.features.insert(feature.into(), true);
69        self
70    }
71
72    /// Disable a feature
73    pub fn disable_feature(mut self, feature: impl Into<String>) -> Self {
74        self.features.insert(feature.into(), false);
75        self
76    }
77
78    /// Build the configuration with validation
79    pub fn build(self) -> Result<Config> {
80        let gateway = GatewayConfig {
81            server: self.server.unwrap_or_default(),
82            auth: self.auth.unwrap_or_default(),
83            storage: self.storage.unwrap_or_default(),
84            providers: self.providers,
85            router: super::RouterConfig::default(),
86            monitoring: super::MonitoringConfig::default(),
87            cache: super::CacheConfig::default(),
88            rate_limit: super::RateLimitConfig::default(),
89            enterprise: super::EnterpriseConfig::default(),
90        };
91
92        let config = Config { gateway };
93
94        // Validate the configuration
95        if let Err(e) = config.gateway.validate() {
96            return Err(GatewayError::Config(e));
97        }
98
99        Ok(config)
100    }
101
102    /// Build the configuration or panic with a descriptive message
103    pub fn build_or_panic(self) -> Config {
104        self.build().unwrap_or_else(|e| {
105            panic!("Failed to build configuration: {}", e);
106        })
107    }
108}
109
110impl Default for ConfigBuilder {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116impl Builder<Config> for ConfigBuilder {
117    fn build(self) -> Config {
118        self.build().expect("Configuration validation failed")
119    }
120}
121
/// Collects optional server settings before a [`ServerConfig`] is produced.
#[derive(Clone, Debug)]
pub struct ServerConfigBuilder {
    /// Address to bind to, if overridden.
    host: Option<String>,
    /// Port to listen on, if overridden.
    port: Option<u16>,
    /// Number of workers, if specified.
    workers: Option<usize>,
    /// Request timeout, if overridden.
    timeout: Option<Duration>,
    /// Maximum number of connections, if specified.
    max_connections: Option<usize>,
    /// Whether CORS has been switched on.
    enable_cors: bool,
    /// Explicitly allowed CORS origins, in insertion order.
    cors_origins: Vec<String>,
}
133
134impl ServerConfigBuilder {
135    /// Create a new server configuration builder
136    pub fn new() -> Self {
137        Self {
138            host: None,
139            port: None,
140            workers: None,
141            timeout: None,
142            max_connections: None,
143            enable_cors: false,
144            cors_origins: Vec::new(),
145        }
146    }
147
148    /// Set the host
149    pub fn host(mut self, host: impl Into<String>) -> Self {
150        self.host = Some(host.into());
151        self
152    }
153
154    /// Set the port
155    pub fn port(mut self, port: u16) -> Self {
156        self.port = Some(port);
157        self
158    }
159
160    /// Set the number of workers
161    pub fn workers(mut self, workers: usize) -> Self {
162        self.workers = Some(workers);
163        self
164    }
165
166    /// Set the request timeout
167    pub fn timeout(mut self, timeout: Duration) -> Self {
168        self.timeout = Some(timeout);
169        self
170    }
171
172    /// Set the maximum number of connections
173    pub fn max_connections(mut self, max_connections: usize) -> Self {
174        self.max_connections = Some(max_connections);
175        self
176    }
177
178    /// Enable CORS
179    pub fn enable_cors(mut self) -> Self {
180        self.enable_cors = true;
181        self
182    }
183
184    /// Add CORS origin
185    pub fn add_cors_origin(mut self, origin: impl Into<String>) -> Self {
186        self.cors_origins.push(origin.into());
187        self
188    }
189
190    /// Build the server configuration
191    pub fn build(self) -> ServerConfig {
192        ServerConfig {
193            host: self.host.unwrap_or_else(|| "127.0.0.1".to_string()),
194            port: self.port.unwrap_or(8080),
195            workers: self.workers,
196            timeout: self.timeout.map(|d| d.as_secs()).unwrap_or(30),
197            max_body_size: 1024 * 1024, // 1MB default
198            dev_mode: false,
199            tls: None,
200            cors: super::CorsConfig {
201                enabled: self.enable_cors,
202                allowed_origins: if self.cors_origins.is_empty() {
203                    vec!["*".to_string()]
204                } else {
205                    self.cors_origins
206                },
207                allowed_methods: vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()],
208                allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
209                max_age: 3600,
210                allow_credentials: false,
211            },
212        }
213    }
214}
215
216impl Default for ServerConfigBuilder {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
222impl Builder<ServerConfig> for ServerConfigBuilder {
223    fn build(self) -> ServerConfig {
224        self.build()
225    }
226}
227
228/// Builder for provider configuration
229#[derive(Debug, Clone)]
230pub struct ProviderConfigBuilder {
231    name: Option<NonEmptyString>,
232    provider_type: Option<NonEmptyString>,
233    api_key: Option<String>,
234    base_url: Option<String>,
235    models: Vec<String>,
236    max_requests_per_minute: Option<u32>,
237    timeout: Option<Duration>,
238    enabled: bool,
239    weight: Option<PositiveF64>,
240}
241
242impl ProviderConfigBuilder {
243    /// Create a new provider configuration builder
244    pub fn new() -> Self {
245        Self {
246            name: None,
247            provider_type: None,
248            api_key: None,
249            base_url: None,
250            models: Vec::new(),
251            max_requests_per_minute: None,
252            timeout: None,
253            enabled: true,
254            weight: None,
255        }
256    }
257
258    /// Set the provider name
259    pub fn name(mut self, name: impl TryInto<NonEmptyString>) -> Result<Self> {
260        self.name = Some(
261            name.try_into()
262                .map_err(|_| GatewayError::Config("Provider name cannot be empty".to_string()))?,
263        );
264        Ok(self)
265    }
266
267    /// Set the provider type
268    pub fn provider_type(mut self, provider_type: impl TryInto<NonEmptyString>) -> Result<Self> {
269        self.provider_type = Some(
270            provider_type
271                .try_into()
272                .map_err(|_| GatewayError::Config("Provider type cannot be empty".to_string()))?,
273        );
274        Ok(self)
275    }
276
277    /// Set the API key
278    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
279        self.api_key = Some(api_key.into());
280        self
281    }
282
283    /// Set the base URL
284    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
285        self.base_url = Some(base_url.into());
286        self
287    }
288
289    /// Add a supported model
290    pub fn add_model(mut self, model: impl Into<String>) -> Self {
291        self.models.push(model.into());
292        self
293    }
294
295    /// Set the rate limit
296    pub fn rate_limit(mut self, requests_per_minute: u32) -> Self {
297        self.max_requests_per_minute = Some(requests_per_minute);
298        self
299    }
300
301    /// Set the timeout
302    pub fn timeout(mut self, timeout: Duration) -> Self {
303        self.timeout = Some(timeout);
304        self
305    }
306
307    /// Enable the provider
308    pub fn enable(mut self) -> Self {
309        self.enabled = true;
310        self
311    }
312
313    /// Disable the provider
314    pub fn disable(mut self) -> Self {
315        self.enabled = false;
316        self
317    }
318
319    /// Set the provider weight for load balancing
320    pub fn weight(mut self, weight: f64) -> Result<Self> {
321        self.weight =
322            Some(PositiveF64::new(weight).map_err(|_| {
323                GatewayError::Config("Provider weight must be positive".to_string())
324            })?);
325        Ok(self)
326    }
327
328    /// Build the provider configuration
329    pub fn build(self) -> Result<ProviderConfig> {
330        let name = self
331            .name
332            .ok_or_else(|| GatewayError::Config("Provider name is required".to_string()))?;
333
334        let provider_type = self
335            .provider_type
336            .ok_or_else(|| GatewayError::Config("Provider type is required".to_string()))?;
337
338        Ok(ProviderConfig {
339            name: name.into_string(),
340            provider_type: provider_type.into_string(),
341            api_key: self.api_key.unwrap_or_default(),
342            base_url: self.base_url,
343            api_version: None,
344            organization: None,
345            project: None,
346            weight: self.weight.map(|w| w.get() as f32).unwrap_or(1.0),
347            rpm: self.max_requests_per_minute.unwrap_or(1000),
348            tpm: 100000, // Default TPM
349            max_concurrent_requests: 10,
350            timeout: self.timeout.map(|d| d.as_secs()).unwrap_or(30),
351            max_retries: 3,
352            retry: super::RetryConfig::default(),
353            health_check: super::HealthCheckConfig::default(),
354            settings: std::collections::HashMap::new(),
355            models: self.models,
356            enabled: self.enabled,
357            tags: Vec::new(),
358        })
359    }
360}
361
362impl Default for ProviderConfigBuilder {
363    fn default() -> Self {
364        Self::new()
365    }
366}
367
368/// Convenience functions for common configurations
369pub mod presets {
370    use super::*;
371
372    /// Create a development server configuration
373    pub fn dev_server() -> ServerConfigBuilder {
374        ServerConfigBuilder::new()
375            .host("127.0.0.1")
376            .port(8080)
377            .workers(1)
378            .enable_cors()
379            .add_cors_origin("*")
380    }
381
382    /// Create a production server configuration
383    pub fn prod_server() -> ServerConfigBuilder {
384        ServerConfigBuilder::new()
385            .host("0.0.0.0")
386            .port(8080)
387            .workers(num_cpus::get())
388            .max_connections(10000)
389            .timeout(Duration::from_secs(60))
390    }
391
392    /// Create an OpenAI provider configuration
393    pub fn openai_provider(name: &str, api_key: &str) -> Result<ProviderConfigBuilder> {
394        Ok(ProviderConfigBuilder::new()
395            .name(name)?
396            .provider_type("openai")?
397            .api_key(api_key)
398            .add_model("gpt-3.5-turbo")
399            .add_model("gpt-4")
400            .rate_limit(3000))
401    }
402
403    /// Create an Anthropic provider configuration
404    pub fn anthropic_provider(name: &str, api_key: &str) -> Result<ProviderConfigBuilder> {
405        Ok(ProviderConfigBuilder::new()
406            .name(name)?
407            .provider_type("anthropic")?
408            .api_key(api_key)
409            .add_model("claude-3-sonnet")
410            .add_model("claude-3-haiku")
411            .rate_limit(1000))
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// A dev-server preset plus one OpenAI provider builds successfully and
    /// keeps the configured port and provider count.
    #[test]
    fn test_config_builder() {
        let server = presets::dev_server().build();
        let provider = presets::openai_provider("openai", "test-key")
            .unwrap()
            .build()
            .unwrap();

        let outcome = ConfigBuilder::new()
            .with_server(server)
            .add_provider(provider)
            .enable_feature("metrics")
            .build();

        assert!(outcome.is_ok());
        let gateway = outcome.unwrap().gateway;
        assert_eq!(gateway.server.port, 8080);
        assert_eq!(gateway.providers.len(), 1);
    }

    /// A fully specified provider builder yields the given name and weight.
    #[test]
    fn test_provider_builder() {
        let named = ProviderConfigBuilder::new()
            .name("test-provider")
            .unwrap()
            .provider_type("openai")
            .unwrap();

        let weighted = named
            .api_key("test-key")
            .add_model("gpt-4")
            .weight(2.0)
            .unwrap();

        let outcome = weighted.build();

        assert!(outcome.is_ok());
        let built = outcome.unwrap();
        assert_eq!(built.name, "test-provider");
        assert_eq!(built.weight, 2.0);
    }
}