rexis_llm/
config.rs

1//! # RSLLM Configuration
2//!
3//! Configuration types and utilities for the RSLLM client library.
4//! Supports environment variables, config files, and programmatic configuration.
5
6use crate::{Provider, RsllmError, RsllmResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10use url::Url;
11
12/// Main client configuration
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ClientConfig {
15    /// Provider configuration
16    pub provider: ProviderConfig,
17
18    /// Model configuration
19    pub model: ModelConfig,
20
21    /// HTTP configuration
22    pub http: HttpConfig,
23
24    /// Retry configuration
25    pub retry: RetryConfig,
26
27    /// Custom headers
28    pub headers: HashMap<String, String>,
29}
30
31impl Default for ClientConfig {
32    fn default() -> Self {
33        Self {
34            provider: ProviderConfig::default(),
35            model: ModelConfig::default(),
36            http: HttpConfig::default(),
37            retry: RetryConfig::default(),
38            headers: HashMap::new(),
39        }
40    }
41}
42
43impl ClientConfig {
44    /// Create a new configuration builder
45    pub fn builder() -> ClientConfigBuilder {
46        ClientConfigBuilder::new()
47    }
48
49    /// Load configuration from environment variables
50    ///
51    /// Supports both generic and provider-specific environment variables:
52    /// - RSLLM_BASE_URL: Generic base URL (overridden by provider-specific)
53    /// - RSLLM_OPENAI_BASE_URL: OpenAI-specific base URL
54    /// - RSLLM_OLLAMA_BASE_URL: Ollama-specific base URL
55    /// - RSLLM_CLAUDE_BASE_URL: Claude-specific base URL
56    /// - RSLLM_MODEL: Generic model name
57    /// - RSLLM_OPENAI_MODEL: OpenAI-specific model
58    /// - RSLLM_OLLAMA_MODEL: Ollama-specific model
59    /// - RSLLM_CLAUDE_MODEL: Claude-specific model
60    pub fn from_env() -> RsllmResult<Self> {
61        dotenv::dotenv().ok(); // Load .env file if present
62
63        let mut config = Self::default();
64
65        // Provider configuration
66        if let Ok(provider_str) = std::env::var("RSLLM_PROVIDER") {
67            config.provider.provider = provider_str.parse()?;
68        }
69
70        if let Ok(api_key) = std::env::var("RSLLM_API_KEY") {
71            config.provider.api_key = Some(api_key);
72        }
73
74        // Base URL: Try provider-specific first, then generic
75        let provider_name = config.provider.provider.to_string().to_uppercase();
76        let provider_specific_url_key = format!("RSLLM_{}_BASE_URL", provider_name);
77
78        if let Ok(base_url) = std::env::var(&provider_specific_url_key) {
79            config.provider.base_url = Some(base_url.parse()?);
80        } else if let Ok(base_url) = std::env::var("RSLLM_BASE_URL") {
81            config.provider.base_url = Some(base_url.parse()?);
82        }
83
84        // Model configuration: Try provider-specific first, then generic
85        let provider_specific_model_key = format!("RSLLM_{}_MODEL", provider_name);
86
87        if let Ok(model) = std::env::var(&provider_specific_model_key) {
88            config.model.model = model;
89        } else if let Ok(model) = std::env::var("RSLLM_MODEL") {
90            config.model.model = model;
91        }
92
93        if let Ok(temp_str) = std::env::var("RSLLM_TEMPERATURE") {
94            config.model.temperature = Some(
95                temp_str
96                    .parse()
97                    .map_err(|_| RsllmError::configuration("Invalid temperature value"))?,
98            );
99        }
100
101        if let Ok(max_tokens_str) = std::env::var("RSLLM_MAX_TOKENS") {
102            config.model.max_tokens = Some(
103                max_tokens_str
104                    .parse()
105                    .map_err(|_| RsllmError::configuration("Invalid max_tokens value"))?,
106            );
107        }
108
109        // HTTP configuration
110        if let Ok(timeout_str) = std::env::var("RSLLM_TIMEOUT") {
111            let timeout_secs: u64 = timeout_str
112                .parse()
113                .map_err(|_| RsllmError::configuration("Invalid timeout value"))?;
114            config.http.timeout = Duration::from_secs(timeout_secs);
115        }
116
117        Ok(config)
118    }
119
120    /// Validate the configuration
121    pub fn validate(&self) -> RsllmResult<()> {
122        // Validate provider
123        self.provider.validate()?;
124
125        // Validate model
126        self.model.validate()?;
127
128        // Validate HTTP config
129        self.http.validate()?;
130
131        Ok(())
132    }
133}
134
/// Provider-specific configuration
///
/// Identifies which backend to talk to and how to authenticate against it.
/// See [`ProviderConfig::validate`] for which combinations are accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// LLM provider type
    pub provider: Provider,

    /// API key for the provider; optional because custom endpoints and
    /// local Ollama instances may not require one (enforced in `validate`)
    pub api_key: Option<String>,

    /// Base URL for the provider (if custom); when `None`, the provider's
    /// default endpoint is used (see `effective_base_url`)
    pub base_url: Option<Url>,

    /// Organization ID (for providers that support it)
    pub organization_id: Option<String>,

    /// Custom provider-specific settings as free-form JSON values keyed by name
    pub custom_settings: HashMap<String, serde_json::Value>,
}
153
154impl Default for ProviderConfig {
155    fn default() -> Self {
156        Self {
157            provider: Provider::OpenAI,
158            api_key: None,
159            base_url: None,
160            organization_id: None,
161            custom_settings: HashMap::new(),
162        }
163    }
164}
165
166impl ProviderConfig {
167    /// Validate provider configuration
168    pub fn validate(&self) -> RsllmResult<()> {
169        // Check if API key is required and present
170        // For custom base URLs, we allow flexibility - the user may have
171        // their own authentication mechanism
172        match self.provider {
173            Provider::OpenAI | Provider::Claude => {
174                // Only require API key if using default endpoints
175                if self.api_key.is_none() && self.base_url.is_none() {
176                    return Err(RsllmError::configuration(format!(
177                        "API key required for provider: {:?} (or provide a custom base_url)",
178                        self.provider
179                    )));
180                }
181            }
182            Provider::Ollama => {
183                // Ollama typically doesn't require an API key for local instances
184            }
185        }
186
187        // Validate base URL if provided
188        if let Some(url) = &self.base_url {
189            if url.scheme() != "http" && url.scheme() != "https" {
190                return Err(RsllmError::configuration(
191                    "Base URL must use HTTP or HTTPS scheme",
192                ));
193            }
194        }
195
196        Ok(())
197    }
198
199    /// Get the effective base URL for the provider
200    pub fn effective_base_url(&self) -> RsllmResult<Url> {
201        if let Some(url) = &self.base_url {
202            Ok(url.clone())
203        } else {
204            Ok(self.provider.default_base_url())
205        }
206    }
207}
208
/// Model configuration
///
/// Generation parameters for requests. The numeric ranges noted below are
/// the ones enforced by [`ModelConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model name/identifier; must be non-empty, but otherwise arbitrary
    /// (custom/fine-tuned model names are deliberately allowed)
    pub model: String,

    /// Temperature for sampling (0.0 to 2.0)
    pub temperature: Option<f32>,

    /// Maximum tokens to generate; `None` sends no limit from this config
    pub max_tokens: Option<u32>,

    /// Top-p (nucleus) sampling parameter (0.0 to 1.0)
    pub top_p: Option<f32>,

    /// Frequency penalty (-2.0 to 2.0)
    pub frequency_penalty: Option<f32>,

    /// Presence penalty (-2.0 to 2.0)
    pub presence_penalty: Option<f32>,

    /// Stop sequences
    pub stop: Option<Vec<String>>,

    /// Whether to stream responses
    pub stream: bool,
}
236
237impl Default for ModelConfig {
238    fn default() -> Self {
239        Self {
240            model: "gpt-3.5-turbo".to_string(),
241            temperature: Some(0.7),
242            max_tokens: None,
243            top_p: None,
244            frequency_penalty: None,
245            presence_penalty: None,
246            stop: None,
247            stream: false,
248        }
249    }
250}
251
252impl ModelConfig {
253    /// Validate model configuration
254    ///
255    /// Note: This validation intentionally does NOT restrict model names to a predefined list.
256    /// Users can specify custom models for:
257    /// - Custom fine-tuned models
258    /// - Self-hosted LLM endpoints
259    /// - Future models not yet in the library
260    /// - Alternative model naming schemes
261    pub fn validate(&self) -> RsllmResult<()> {
262        if self.model.is_empty() {
263            return Err(RsllmError::validation(
264                "model",
265                "Model name cannot be empty",
266            ));
267        }
268
269        if let Some(temp) = self.temperature {
270            if !(0.0..=2.0).contains(&temp) {
271                return Err(RsllmError::validation(
272                    "temperature",
273                    "Temperature must be between 0.0 and 2.0",
274                ));
275            }
276        }
277
278        if let Some(top_p) = self.top_p {
279            if !(0.0..=1.0).contains(&top_p) {
280                return Err(RsllmError::validation(
281                    "top_p",
282                    "Top-p must be between 0.0 and 1.0",
283                ));
284            }
285        }
286
287        if let Some(freq_penalty) = self.frequency_penalty {
288            if !(-2.0..=2.0).contains(&freq_penalty) {
289                return Err(RsllmError::validation(
290                    "frequency_penalty",
291                    "Frequency penalty must be between -2.0 and 2.0",
292                ));
293            }
294        }
295
296        if let Some(pres_penalty) = self.presence_penalty {
297            if !(-2.0..=2.0).contains(&pres_penalty) {
298                return Err(RsllmError::validation(
299                    "presence_penalty",
300                    "Presence penalty must be between -2.0 and 2.0",
301                ));
302            }
303        }
304
305        Ok(())
306    }
307}
308
/// HTTP configuration
///
/// Transport-level settings. [`HttpConfig::validate`] rejects zero timeouts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpConfig {
    /// Request timeout (default: 30 s)
    pub timeout: Duration,

    /// Connection timeout (default: 10 s)
    pub connect_timeout: Duration,

    /// Maximum number of redirects to follow (default: 5)
    pub max_redirects: u32,

    /// User agent string (default: `rsllm/<crate version>`)
    pub user_agent: String,

    /// Whether to use TLS verification (default: true)
    pub verify_tls: bool,
}
327
328impl Default for HttpConfig {
329    fn default() -> Self {
330        Self {
331            timeout: Duration::from_secs(30),
332            connect_timeout: Duration::from_secs(10),
333            max_redirects: 5,
334            user_agent: format!("rsllm/{}", crate::VERSION),
335            verify_tls: true,
336        }
337    }
338}
339
340impl HttpConfig {
341    /// Validate HTTP configuration
342    pub fn validate(&self) -> RsllmResult<()> {
343        if self.timeout.as_secs() == 0 {
344            return Err(RsllmError::validation(
345                "timeout",
346                "Timeout must be greater than 0",
347            ));
348        }
349
350        if self.connect_timeout.as_secs() == 0 {
351            return Err(RsllmError::validation(
352                "connect_timeout",
353                "Connect timeout must be greater than 0",
354            ));
355        }
356
357        Ok(())
358    }
359}
360
/// Retry configuration
///
/// Parameters for an exponential-backoff retry policy. Defaults (see the
/// `Default` impl): 3 retries, 500 ms base delay, 30 s cap, 2.0 multiplier,
/// jitter enabled. NOTE(review): the retry loop that consumes these values
/// lives outside this file — confirm semantics (e.g. whether `max_retries`
/// counts the initial attempt) against the client implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retries
    pub max_retries: u32,

    /// Base delay between retries
    pub base_delay: Duration,

    /// Maximum delay between retries (caps the backoff growth)
    pub max_delay: Duration,

    /// Backoff multiplier applied between successive retry delays
    pub backoff_multiplier: f32,

    /// Whether to add jitter to retry delays
    pub jitter: bool,
}
379
380impl Default for RetryConfig {
381    fn default() -> Self {
382        Self {
383            max_retries: 3,
384            base_delay: Duration::from_millis(500),
385            max_delay: Duration::from_secs(30),
386            backoff_multiplier: 2.0,
387            jitter: true,
388        }
389    }
390}
391
/// Builder for client configuration
///
/// Obtain via [`ClientConfig::builder`] (or `new`/`Default`), chain the
/// setter methods, and finish with `build`, which validates the result.
pub struct ClientConfigBuilder {
    // Configuration being assembled; starts from `ClientConfig::default()`.
    config: ClientConfig,
}
396
397impl ClientConfigBuilder {
398    /// Create a new builder
399    pub fn new() -> Self {
400        Self {
401            config: ClientConfig::default(),
402        }
403    }
404
405    /// Set the provider
406    pub fn provider(mut self, provider: Provider) -> Self {
407        self.config.provider.provider = provider;
408        self
409    }
410
411    /// Set the API key
412    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
413        self.config.provider.api_key = Some(api_key.into());
414        self
415    }
416
417    /// Set the base URL
418    pub fn base_url(mut self, base_url: impl AsRef<str>) -> RsllmResult<Self> {
419        self.config.provider.base_url = Some(base_url.as_ref().parse()?);
420        Ok(self)
421    }
422
423    /// Set the model
424    pub fn model(mut self, model: impl Into<String>) -> Self {
425        self.config.model.model = model.into();
426        self
427    }
428
429    /// Set the temperature
430    pub fn temperature(mut self, temperature: f32) -> Self {
431        self.config.model.temperature = Some(temperature);
432        self
433    }
434
435    /// Set max tokens
436    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
437        self.config.model.max_tokens = Some(max_tokens);
438        self
439    }
440
441    /// Enable streaming
442    pub fn stream(mut self, stream: bool) -> Self {
443        self.config.model.stream = stream;
444        self
445    }
446
447    /// Set timeout
448    pub fn timeout(mut self, timeout: Duration) -> Self {
449        self.config.http.timeout = timeout;
450        self
451    }
452
453    /// Add a custom header
454    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
455        self.config.headers.insert(key.into(), value.into());
456        self
457    }
458
459    /// Build the configuration
460    pub fn build(self) -> RsllmResult<ClientConfig> {
461        self.config.validate()?;
462        Ok(self.config)
463    }
464}
465
impl Default for ClientConfigBuilder {
    /// Equivalent to [`ClientConfigBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}