1use crate::{Provider, RsllmError, RsllmResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10use url::Url;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ClientConfig {
15 pub provider: ProviderConfig,
17
18 pub model: ModelConfig,
20
21 pub http: HttpConfig,
23
24 pub retry: RetryConfig,
26
27 pub headers: HashMap<String, String>,
29}
30
31impl Default for ClientConfig {
32 fn default() -> Self {
33 Self {
34 provider: ProviderConfig::default(),
35 model: ModelConfig::default(),
36 http: HttpConfig::default(),
37 retry: RetryConfig::default(),
38 headers: HashMap::new(),
39 }
40 }
41}
42
43impl ClientConfig {
44 pub fn builder() -> ClientConfigBuilder {
46 ClientConfigBuilder::new()
47 }
48
49 pub fn from_env() -> RsllmResult<Self> {
61 dotenv::dotenv().ok(); let mut config = Self::default();
64
65 if let Ok(provider_str) = std::env::var("RSLLM_PROVIDER") {
67 config.provider.provider = provider_str.parse()?;
68 }
69
70 if let Ok(api_key) = std::env::var("RSLLM_API_KEY") {
71 config.provider.api_key = Some(api_key);
72 }
73
74 let provider_name = config.provider.provider.to_string().to_uppercase();
76 let provider_specific_url_key = format!("RSLLM_{}_BASE_URL", provider_name);
77
78 if let Ok(base_url) = std::env::var(&provider_specific_url_key) {
79 config.provider.base_url = Some(base_url.parse()?);
80 } else if let Ok(base_url) = std::env::var("RSLLM_BASE_URL") {
81 config.provider.base_url = Some(base_url.parse()?);
82 }
83
84 let provider_specific_model_key = format!("RSLLM_{}_MODEL", provider_name);
86
87 if let Ok(model) = std::env::var(&provider_specific_model_key) {
88 config.model.model = model;
89 } else if let Ok(model) = std::env::var("RSLLM_MODEL") {
90 config.model.model = model;
91 }
92
93 if let Ok(temp_str) = std::env::var("RSLLM_TEMPERATURE") {
94 config.model.temperature = Some(
95 temp_str
96 .parse()
97 .map_err(|_| RsllmError::configuration("Invalid temperature value"))?,
98 );
99 }
100
101 if let Ok(max_tokens_str) = std::env::var("RSLLM_MAX_TOKENS") {
102 config.model.max_tokens = Some(
103 max_tokens_str
104 .parse()
105 .map_err(|_| RsllmError::configuration("Invalid max_tokens value"))?,
106 );
107 }
108
109 if let Ok(timeout_str) = std::env::var("RSLLM_TIMEOUT") {
111 let timeout_secs: u64 = timeout_str
112 .parse()
113 .map_err(|_| RsllmError::configuration("Invalid timeout value"))?;
114 config.http.timeout = Duration::from_secs(timeout_secs);
115 }
116
117 Ok(config)
118 }
119
120 pub fn validate(&self) -> RsllmResult<()> {
122 self.provider.validate()?;
124
125 self.model.validate()?;
127
128 self.http.validate()?;
130
131 Ok(())
132 }
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct ProviderConfig {
138 pub provider: Provider,
140
141 pub api_key: Option<String>,
143
144 pub base_url: Option<Url>,
146
147 pub organization_id: Option<String>,
149
150 pub custom_settings: HashMap<String, serde_json::Value>,
152}
153
154impl Default for ProviderConfig {
155 fn default() -> Self {
156 Self {
157 provider: Provider::OpenAI,
158 api_key: None,
159 base_url: None,
160 organization_id: None,
161 custom_settings: HashMap::new(),
162 }
163 }
164}
165
166impl ProviderConfig {
167 pub fn validate(&self) -> RsllmResult<()> {
169 match self.provider {
173 Provider::OpenAI | Provider::Claude => {
174 if self.api_key.is_none() && self.base_url.is_none() {
176 return Err(RsllmError::configuration(format!(
177 "API key required for provider: {:?} (or provide a custom base_url)",
178 self.provider
179 )));
180 }
181 }
182 Provider::Ollama => {
183 }
185 }
186
187 if let Some(url) = &self.base_url {
189 if url.scheme() != "http" && url.scheme() != "https" {
190 return Err(RsllmError::configuration(
191 "Base URL must use HTTP or HTTPS scheme",
192 ));
193 }
194 }
195
196 Ok(())
197 }
198
199 pub fn effective_base_url(&self) -> RsllmResult<Url> {
201 if let Some(url) = &self.base_url {
202 Ok(url.clone())
203 } else {
204 Ok(self.provider.default_base_url())
205 }
206 }
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct ModelConfig {
212 pub model: String,
214
215 pub temperature: Option<f32>,
217
218 pub max_tokens: Option<u32>,
220
221 pub top_p: Option<f32>,
223
224 pub frequency_penalty: Option<f32>,
226
227 pub presence_penalty: Option<f32>,
229
230 pub stop: Option<Vec<String>>,
232
233 pub stream: bool,
235}
236
237impl Default for ModelConfig {
238 fn default() -> Self {
239 Self {
240 model: "gpt-3.5-turbo".to_string(),
241 temperature: Some(0.7),
242 max_tokens: None,
243 top_p: None,
244 frequency_penalty: None,
245 presence_penalty: None,
246 stop: None,
247 stream: false,
248 }
249 }
250}
251
252impl ModelConfig {
253 pub fn validate(&self) -> RsllmResult<()> {
262 if self.model.is_empty() {
263 return Err(RsllmError::validation(
264 "model",
265 "Model name cannot be empty",
266 ));
267 }
268
269 if let Some(temp) = self.temperature {
270 if !(0.0..=2.0).contains(&temp) {
271 return Err(RsllmError::validation(
272 "temperature",
273 "Temperature must be between 0.0 and 2.0",
274 ));
275 }
276 }
277
278 if let Some(top_p) = self.top_p {
279 if !(0.0..=1.0).contains(&top_p) {
280 return Err(RsllmError::validation(
281 "top_p",
282 "Top-p must be between 0.0 and 1.0",
283 ));
284 }
285 }
286
287 if let Some(freq_penalty) = self.frequency_penalty {
288 if !(-2.0..=2.0).contains(&freq_penalty) {
289 return Err(RsllmError::validation(
290 "frequency_penalty",
291 "Frequency penalty must be between -2.0 and 2.0",
292 ));
293 }
294 }
295
296 if let Some(pres_penalty) = self.presence_penalty {
297 if !(-2.0..=2.0).contains(&pres_penalty) {
298 return Err(RsllmError::validation(
299 "presence_penalty",
300 "Presence penalty must be between -2.0 and 2.0",
301 ));
302 }
303 }
304
305 Ok(())
306 }
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct HttpConfig {
312 pub timeout: Duration,
314
315 pub connect_timeout: Duration,
317
318 pub max_redirects: u32,
320
321 pub user_agent: String,
323
324 pub verify_tls: bool,
326}
327
328impl Default for HttpConfig {
329 fn default() -> Self {
330 Self {
331 timeout: Duration::from_secs(30),
332 connect_timeout: Duration::from_secs(10),
333 max_redirects: 5,
334 user_agent: format!("rsllm/{}", crate::VERSION),
335 verify_tls: true,
336 }
337 }
338}
339
340impl HttpConfig {
341 pub fn validate(&self) -> RsllmResult<()> {
343 if self.timeout.as_secs() == 0 {
344 return Err(RsllmError::validation(
345 "timeout",
346 "Timeout must be greater than 0",
347 ));
348 }
349
350 if self.connect_timeout.as_secs() == 0 {
351 return Err(RsllmError::validation(
352 "connect_timeout",
353 "Connect timeout must be greater than 0",
354 ));
355 }
356
357 Ok(())
358 }
359}
360
361#[derive(Debug, Clone, Serialize, Deserialize)]
363pub struct RetryConfig {
364 pub max_retries: u32,
366
367 pub base_delay: Duration,
369
370 pub max_delay: Duration,
372
373 pub backoff_multiplier: f32,
375
376 pub jitter: bool,
378}
379
380impl Default for RetryConfig {
381 fn default() -> Self {
382 Self {
383 max_retries: 3,
384 base_delay: Duration::from_millis(500),
385 max_delay: Duration::from_secs(30),
386 backoff_multiplier: 2.0,
387 jitter: true,
388 }
389 }
390}
391
392pub struct ClientConfigBuilder {
394 config: ClientConfig,
395}
396
397impl ClientConfigBuilder {
398 pub fn new() -> Self {
400 Self {
401 config: ClientConfig::default(),
402 }
403 }
404
405 pub fn provider(mut self, provider: Provider) -> Self {
407 self.config.provider.provider = provider;
408 self
409 }
410
411 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
413 self.config.provider.api_key = Some(api_key.into());
414 self
415 }
416
417 pub fn base_url(mut self, base_url: impl AsRef<str>) -> RsllmResult<Self> {
419 self.config.provider.base_url = Some(base_url.as_ref().parse()?);
420 Ok(self)
421 }
422
423 pub fn model(mut self, model: impl Into<String>) -> Self {
425 self.config.model.model = model.into();
426 self
427 }
428
429 pub fn temperature(mut self, temperature: f32) -> Self {
431 self.config.model.temperature = Some(temperature);
432 self
433 }
434
435 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
437 self.config.model.max_tokens = Some(max_tokens);
438 self
439 }
440
441 pub fn stream(mut self, stream: bool) -> Self {
443 self.config.model.stream = stream;
444 self
445 }
446
447 pub fn timeout(mut self, timeout: Duration) -> Self {
449 self.config.http.timeout = timeout;
450 self
451 }
452
453 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
455 self.config.headers.insert(key.into(), value.into());
456 self
457 }
458
459 pub fn build(self) -> RsllmResult<ClientConfig> {
461 self.config.validate()?;
462 Ok(self.config)
463 }
464}
465
466impl Default for ClientConfigBuilder {
467 fn default() -> Self {
468 Self::new()
469 }
470}