1use crate::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10
/// Supported large-language-model backends.
///
/// Serialized in lowercase (`"openai"`, `"anthropic"`, `"openaicompatible"`,
/// `"ollama"`); note `rename_all = "lowercase"` inserts no separator between
/// words, so `OpenAICompatible` becomes `"openaicompatible"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI's hosted API.
    OpenAI,
    /// Anthropic's hosted API.
    Anthropic,
    /// Any third-party endpoint implementing the OpenAI wire protocol.
    OpenAICompatible,
    /// A local or remote Ollama server.
    Ollama,
}
24
/// Supported embedding backends.
///
/// Serialized in lowercase, mirroring [`LlmProvider`]. Anthropic is absent
/// here because no variant for it is declared (it offers no embedding API
/// in this configuration's model — TODO confirm that is intentional).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI's embedding API.
    OpenAI,
    /// Any endpoint implementing the OpenAI embedding wire protocol.
    OpenAICompatible,
    /// A local or remote Ollama server.
    Ollama,
}
36
/// Complete configuration for a RAG (retrieval-augmented generation) pipeline:
/// LLM backend, embeddings, chunking, retrieval, caching, throttling, and
/// observability settings.
///
/// Construct via [`Default`], the `with_*` chainable setters, or
/// `RagConfigBuilder`; check invariants with `validate()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM backend to send completion requests to.
    pub provider: LlmProvider,
    /// Base URL of the provider's API.
    pub api_endpoint: String,
    /// API key, if the provider requires authentication.
    pub api_key: Option<String>,
    /// Model identifier passed to the provider.
    pub model: String,
    /// Maximum tokens to generate per response.
    pub max_tokens: usize,
    /// Sampling temperature; validated to lie in 0.0..=2.0.
    pub temperature: f32,
    /// Nucleus-sampling probability mass; validated to lie in 0.0..=1.0.
    pub top_p: f32,
    /// Request timeout in seconds (see `timeout_duration()`).
    pub timeout_secs: u64,
    /// Maximum retry count.
    /// NOTE(review): overlaps with `retry_config.max_attempts` — confirm
    /// which one callers actually honor.
    pub max_retries: u32,
    /// Backend used to compute embeddings.
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model identifier.
    pub embedding_model: String,
    /// Dimensionality of the embedding vectors.
    pub embedding_dimensions: usize,
    /// Document chunk size; must be > 0.
    pub chunk_size: usize,
    /// Overlap between consecutive chunks; must be < `chunk_size`.
    pub chunk_overlap: usize,
    /// Number of chunks to retrieve per query.
    pub top_k: usize,
    /// Minimum similarity score for retrieved chunks; in 0.0..=1.0.
    pub similarity_threshold: f32,
    /// Whether to combine semantic and keyword search.
    pub hybrid_search: bool,
    /// Weight of the semantic score when `hybrid_search` is on.
    pub semantic_weight: f32,
    /// Weight of the keyword score; must sum with `semantic_weight` to 1.0.
    pub keyword_weight: f32,
    /// Whether to expand queries before retrieval.
    pub query_expansion: bool,
    /// Whether to filter responses after generation.
    pub response_filtering: bool,
    /// Whether response caching is enabled.
    pub caching: bool,
    /// Cache time-to-live in seconds (see `cache_ttl_duration()`).
    pub cache_ttl_secs: u64,
    /// Rate-limiting settings for outbound requests.
    pub rate_limiting: RateLimitConfig,
    /// Retry/backoff settings for failed requests.
    pub retry_config: RetryConfig,
    /// Extra HTTP headers attached to every request.
    pub custom_headers: HashMap<String, String>,
    /// Enables verbose debugging behavior.
    pub debug_mode: bool,
    /// Maximum context window length, in tokens.
    pub max_context_length: usize,
    /// Desired format of generated responses.
    pub response_format: ResponseFormat,
    /// Logging settings.
    pub logging: LoggingConfig,
    /// Metrics/tracing settings.
    pub monitoring: MonitoringConfig,
}
103
104impl Default for RagConfig {
105 fn default() -> Self {
106 Self {
107 provider: LlmProvider::OpenAI,
108 api_endpoint: "https://api.openai.com/v1".to_string(),
109 api_key: None,
110 model: "gpt-3.5-turbo".to_string(),
111 max_tokens: 1024,
112 temperature: 0.7,
113 top_p: 0.9,
114 timeout_secs: 30,
115 max_retries: 3,
116 embedding_provider: EmbeddingProvider::OpenAI,
117 embedding_model: "text-embedding-ada-002".to_string(),
118 embedding_dimensions: 1536,
119 chunk_size: 1000,
120 chunk_overlap: 200,
121 top_k: 5,
122 similarity_threshold: 0.7,
123 hybrid_search: true,
124 semantic_weight: 0.7,
125 keyword_weight: 0.3,
126 query_expansion: false,
127 response_filtering: true,
128 caching: true,
129 cache_ttl_secs: 3600,
130 rate_limiting: RateLimitConfig::default(),
131 retry_config: RetryConfig::default(),
132 custom_headers: HashMap::new(),
133 debug_mode: false,
134 max_context_length: 4096,
135 response_format: ResponseFormat::Json,
136 logging: LoggingConfig::default(),
137 monitoring: MonitoringConfig::default(),
138 }
139 }
140}
141
142impl RagConfig {
143 pub fn new(provider: LlmProvider, model: String) -> Self {
145 Self {
146 provider,
147 model,
148 ..Default::default()
149 }
150 }
151
152 pub fn with_api_key(mut self, api_key: String) -> Self {
154 self.api_key = Some(api_key);
155 self
156 }
157
158 pub fn with_endpoint(mut self, endpoint: String) -> Self {
160 self.api_endpoint = endpoint;
161 self
162 }
163
164 pub fn with_model_params(mut self, max_tokens: usize, temperature: f32, top_p: f32) -> Self {
166 self.max_tokens = max_tokens;
167 self.temperature = temperature;
168 self.top_p = top_p;
169 self
170 }
171
172 pub fn with_embedding(
174 mut self,
175 provider: EmbeddingProvider,
176 model: String,
177 dimensions: usize,
178 ) -> Self {
179 self.embedding_provider = provider;
180 self.embedding_model = model;
181 self.embedding_dimensions = dimensions;
182 self
183 }
184
185 pub fn with_chunking(mut self, chunk_size: usize, chunk_overlap: usize) -> Self {
187 self.chunk_size = chunk_size;
188 self.chunk_overlap = chunk_overlap;
189 self
190 }
191
192 pub fn with_retrieval(mut self, top_k: usize, similarity_threshold: f32) -> Self {
194 self.top_k = top_k;
195 self.similarity_threshold = similarity_threshold;
196 self
197 }
198
199 pub fn with_hybrid_search(mut self, semantic_weight: f32, keyword_weight: f32) -> Self {
201 self.hybrid_search = true;
202 self.semantic_weight = semantic_weight;
203 self.keyword_weight = keyword_weight;
204 self
205 }
206
207 pub fn with_caching(mut self, enabled: bool, ttl_secs: u64) -> Self {
209 self.caching = enabled;
210 self.cache_ttl_secs = ttl_secs;
211 self
212 }
213
214 pub fn with_rate_limit(mut self, requests_per_minute: u32, burst_size: u32) -> Self {
216 self.rate_limiting = RateLimitConfig {
217 requests_per_minute,
218 burst_size,
219 enabled: true,
220 };
221 self
222 }
223
224 pub fn with_retry(mut self, max_attempts: u32, backoff_secs: u64) -> Self {
226 self.retry_config = RetryConfig {
227 max_attempts,
228 backoff_secs,
229 exponential_backoff: true,
230 };
231 self
232 }
233
234 pub fn with_header(mut self, key: String, value: String) -> Self {
236 self.custom_headers.insert(key, value);
237 self
238 }
239
240 pub fn with_debug_mode(mut self, debug: bool) -> Self {
242 self.debug_mode = debug;
243 self
244 }
245
246 pub fn validate(&self) -> Result<()> {
248 if self.api_endpoint.is_empty() {
249 return Err(crate::Error::generic("API endpoint cannot be empty"));
250 }
251
252 if self.model.is_empty() {
253 return Err(crate::Error::generic("Model name cannot be empty"));
254 }
255
256 if !(0.0..=2.0).contains(&self.temperature) {
257 return Err(crate::Error::generic("Temperature must be between 0.0 and 2.0"));
258 }
259
260 if !(0.0..=1.0).contains(&self.top_p) {
261 return Err(crate::Error::generic("Top-p must be between 0.0 and 1.0"));
262 }
263
264 if self.chunk_size == 0 {
265 return Err(crate::Error::generic("Chunk size must be greater than 0"));
266 }
267
268 if self.chunk_overlap >= self.chunk_size {
269 return Err(crate::Error::generic("Chunk overlap must be less than chunk size"));
270 }
271
272 if !(0.0..=1.0).contains(&self.similarity_threshold) {
273 return Err(crate::Error::generic("Similarity threshold must be between 0.0 and 1.0"));
274 }
275
276 if self.hybrid_search {
277 let total_weight = self.semantic_weight + self.keyword_weight;
278 if (total_weight - 1.0).abs() > f32::EPSILON {
279 return Err(crate::Error::generic("Hybrid search weights must sum to 1.0"));
280 }
281 }
282
283 Ok(())
284 }
285
286 pub fn timeout_duration(&self) -> Duration {
288 Duration::from_secs(self.timeout_secs)
289 }
290
291 pub fn cache_ttl_duration(&self) -> Duration {
293 Duration::from_secs(self.cache_ttl_secs)
294 }
295
296 pub fn is_caching_enabled(&self) -> bool {
298 self.caching
299 }
300
301 pub fn is_rate_limited(&self) -> bool {
303 self.rate_limiting.enabled
304 }
305
306 pub fn requests_per_minute(&self) -> u32 {
308 self.rate_limiting.requests_per_minute
309 }
310
311 pub fn burst_size(&self) -> u32 {
313 self.rate_limiting.burst_size
314 }
315
316 pub fn max_retry_attempts(&self) -> u32 {
318 self.retry_config.max_attempts
319 }
320
321 pub fn backoff_duration(&self) -> Duration {
323 Duration::from_secs(self.retry_config.backoff_secs)
324 }
325
326 pub fn is_exponential_backoff(&self) -> bool {
328 self.retry_config.exponential_backoff
329 }
330
331 pub fn response_format(&self) -> &ResponseFormat {
333 &self.response_format
334 }
335
336 pub fn logging_config(&self) -> &LoggingConfig {
338 &self.logging
339 }
340
341 pub fn monitoring_config(&self) -> &MonitoringConfig {
343 &self.monitoring
344 }
345}
346
/// Rate-limiting settings for outbound provider requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Sustained request rate allowed per minute.
    pub requests_per_minute: u32,
    /// Maximum number of requests permitted in a short burst.
    pub burst_size: u32,
    /// Whether rate limiting is applied at all.
    pub enabled: bool,
}
357
358impl Default for RateLimitConfig {
359 fn default() -> Self {
360 Self {
361 requests_per_minute: 60,
362 burst_size: 10,
363 enabled: true,
364 }
365 }
366}
367
/// Retry/backoff settings for failed provider requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of attempts before giving up.
    pub max_attempts: u32,
    /// Base delay between attempts, in seconds.
    pub backoff_secs: u64,
    /// Whether the delay grows exponentially between attempts.
    pub exponential_backoff: bool,
}
378
379impl Default for RetryConfig {
380 fn default() -> Self {
381 Self {
382 max_attempts: 3,
383 backoff_secs: 1,
384 exponential_backoff: true,
385 }
386 }
387}
388
/// Desired format of generated responses.
///
/// NOTE(review): unlike the provider enums, this enum has no
/// `#[serde(rename_all = "lowercase")]`, so variants serialize with their
/// Rust names (`"Text"`, `"Json"`, …) — confirm that asymmetry is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponseFormat {
    /// Plain text.
    Text,
    /// JSON (the default — see the `Default` impl below).
    Json,
    /// Markdown.
    Markdown,
    /// A caller-defined format identified by the contained string.
    Custom(String),
}
401
402impl Default for ResponseFormat {
403 fn default() -> Self {
404 Self::Json
405 }
406}
407
/// Logging settings for the RAG pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log level name (e.g. "info"); stringly-typed — presumably matched
    /// against the logger's level names by the consumer. TODO confirm.
    pub log_level: String,
    /// Whether to log each outbound request.
    pub log_requests: bool,
    /// Whether to log performance measurements.
    pub log_performance: bool,
    /// Optional path of a log file; `None` presumably means stderr/stdout.
    pub log_file: Option<String>,
    /// Maximum log file size in megabytes.
    pub max_log_size_mb: u64,
}
422
423impl Default for LoggingConfig {
424 fn default() -> Self {
425 Self {
426 log_level: "info".to_string(),
427 log_requests: false,
428 log_performance: true,
429 log_file: None,
430 max_log_size_mb: 100,
431 }
432 }
433}
434
/// Metrics and tracing settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Whether metrics collection is enabled.
    pub enable_metrics: bool,
    /// Interval between metrics reports, in seconds.
    pub metrics_interval_secs: u64,
    /// Whether distributed tracing is enabled.
    pub enable_tracing: bool,
    /// Fraction of requests to trace when tracing is enabled
    /// (presumably 0.0..=1.0 — not validated here).
    pub trace_sample_rate: f32,
    /// Alerting/health thresholds.
    pub thresholds: PerformanceThresholds,
}
449
450impl Default for MonitoringConfig {
451 fn default() -> Self {
452 Self {
453 enable_metrics: true,
454 metrics_interval_secs: 60,
455 enable_tracing: false,
456 trace_sample_rate: 0.1,
457 thresholds: PerformanceThresholds::default(),
458 }
459 }
460}
461
/// Performance thresholds used by monitoring (presumably for alerting —
/// the consumer is not visible here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceThresholds {
    /// Maximum acceptable response time, in seconds.
    pub max_response_time_secs: f64,
    /// Minimum acceptable similarity score for retrieved results.
    pub min_similarity_score: f32,
    /// Maximum acceptable memory usage, in megabytes.
    pub max_memory_usage_mb: u64,
    /// Maximum acceptable CPU usage, in percent (0–100).
    pub max_cpu_usage_percent: f32,
}
474
475impl Default for PerformanceThresholds {
476 fn default() -> Self {
477 Self {
478 max_response_time_secs: 30.0,
479 min_similarity_score: 0.7,
480 max_memory_usage_mb: 1024,
481 max_cpu_usage_percent: 80.0,
482 }
483 }
484}
485
/// Fluent builder for [`RagConfig`]; `build()` validates before returning.
#[derive(Debug)]
pub struct RagConfigBuilder {
    // Accumulated configuration, starting from `RagConfig::default()`.
    config: RagConfig,
}
491
492impl RagConfigBuilder {
493 pub fn new() -> Self {
495 Self {
496 config: RagConfig::default(),
497 }
498 }
499
500 pub fn build(self) -> Result<RagConfig> {
502 self.config.validate()?;
503 Ok(self.config)
504 }
505
506 pub fn provider(mut self, provider: LlmProvider) -> Self {
508 self.config.provider = provider;
509 self
510 }
511
512 pub fn model(mut self, model: String) -> Self {
514 self.config.model = model;
515 self
516 }
517
518 pub fn api_key(mut self, api_key: String) -> Self {
520 self.config.api_key = Some(api_key);
521 self
522 }
523
524 pub fn endpoint(mut self, endpoint: String) -> Self {
526 self.config.api_endpoint = endpoint;
527 self
528 }
529
530 pub fn model_params(mut self, max_tokens: usize, temperature: f32) -> Self {
532 self.config.max_tokens = max_tokens;
533 self.config.temperature = temperature;
534 self
535 }
536
537 pub fn embedding(mut self, model: String, dimensions: usize) -> Self {
539 self.config.embedding_model = model;
540 self.config.embedding_dimensions = dimensions;
541 self
542 }
543
544 pub fn chunking(mut self, size: usize, overlap: usize) -> Self {
546 self.config.chunk_size = size;
547 self.config.chunk_overlap = overlap;
548 self
549 }
550
551 pub fn retrieval(mut self, top_k: usize, threshold: f32) -> Self {
553 self.config.top_k = top_k;
554 self.config.similarity_threshold = threshold;
555 self
556 }
557
558 pub fn hybrid_search(mut self, semantic_weight: f32) -> Self {
560 self.config.hybrid_search = true;
561 self.config.semantic_weight = semantic_weight;
562 self.config.keyword_weight = 1.0 - semantic_weight;
563 self
564 }
565
566 pub fn caching(mut self, enabled: bool) -> Self {
568 self.config.caching = enabled;
569 self
570 }
571
572 pub fn rate_limit(mut self, requests_per_minute: u32) -> Self {
574 self.config.rate_limiting = RateLimitConfig {
575 requests_per_minute,
576 burst_size: requests_per_minute / 6, enabled: true,
578 };
579 self
580 }
581
582 pub fn debug(mut self, debug: bool) -> Self {
584 self.config.debug_mode = debug;
585 self
586 }
587}
588
589impl Default for RagConfigBuilder {
590 fn default() -> Self {
591 Self::new()
592 }
593}
594
#[cfg(test)]
mod tests {
    // Placeholder smoke test: asserts only that the module compiles.
    // TODO(review): add real coverage — e.g. `RagConfig::default().validate()`
    // succeeds, `RagConfigBuilder` round-trips its setters, and invalid
    // temperature/chunking values are rejected.

    #[test]
    fn test_module_compiles() {
    }
}