mockforge_data/rag/
config.rs

1//! RAG configuration and settings management
2//!
3//! This module handles all configuration aspects of the RAG system,
4//! including provider settings, model configurations, and operational parameters.
5
6use crate::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10
/// Supported LLM providers.
///
/// Serialized with `rename_all = "lowercase"`, so variants appear in
/// configuration files as `"openai"`, `"anthropic"`, `"openaicompatible"`,
/// and `"ollama"` (multi-word names are flattened to one lowercase token,
/// not kebab- or snake-case).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI GPT models
    OpenAI,
    /// Anthropic Claude models
    Anthropic,
    /// Generic OpenAI-compatible API
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}
24
/// Supported embedding providers.
///
/// Serialized with `rename_all = "lowercase"`: `"openai"`,
/// `"openaicompatible"`, `"ollama"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI text-embedding-ada-002
    OpenAI,
    /// Generic OpenAI-compatible embeddings API
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}
36
/// RAG configuration.
///
/// Central settings for the RAG pipeline: LLM provider/connection,
/// generation parameters, embeddings, document chunking, retrieval, and
/// operational concerns (caching, rate limiting, retries, logging,
/// monitoring). Construct via [`RagConfig::default`], the `with_*`
/// methods, or [`RagConfigBuilder`]; check invariants with
/// [`RagConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM provider
    pub provider: LlmProvider,
    /// LLM API endpoint (base URL, e.g. "https://api.openai.com/v1")
    pub api_endpoint: String,
    /// API key for authentication (`None` if unset)
    pub api_key: Option<String>,
    /// Model name to use
    pub model: String,
    /// Maximum tokens for generation
    pub max_tokens: usize,
    /// Temperature for generation (0.0 to 2.0)
    pub temperature: f32,
    /// Top-p sampling (0.0 to 1.0)
    pub top_p: f32,
    /// Request timeout in seconds (see `timeout_duration()`)
    pub timeout_secs: u64,
    /// Maximum retry attempts
    // NOTE(review): overlaps with `retry_config.max_attempts`; the
    // `max_retry_attempts()` accessor reads `retry_config`, not this
    // field — confirm which one consumers honor.
    pub max_retries: u32,
    /// Embedding provider
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model name
    pub embedding_model: String,
    /// Embedding dimensions
    pub embedding_dimensions: usize,
    /// Chunk size for document splitting
    pub chunk_size: usize,
    /// Chunk overlap for document splitting (must be < `chunk_size`,
    /// enforced by `validate()`)
    pub chunk_overlap: usize,
    /// Top-k similar chunks to retrieve
    pub top_k: usize,
    /// Similarity threshold (0.0 to 1.0)
    pub similarity_threshold: f32,
    /// Enable hybrid search (combines semantic and keyword search)
    pub hybrid_search: bool,
    /// Weight for semantic search in hybrid mode (0.0 to 1.0); together
    /// with `keyword_weight` must sum to 1.0 when `hybrid_search` is on
    /// (enforced by `validate()`)
    pub semantic_weight: f32,
    /// Weight for keyword search in hybrid mode (0.0 to 1.0)
    pub keyword_weight: f32,
    /// Enable query expansion
    pub query_expansion: bool,
    /// Enable response filtering
    pub response_filtering: bool,
    /// Enable caching
    pub caching: bool,
    /// Cache TTL in seconds (see `cache_ttl_duration()`)
    pub cache_ttl_secs: u64,
    /// Rate limiting configuration
    pub rate_limiting: RateLimitConfig,
    /// Retry configuration
    pub retry_config: RetryConfig,
    /// Custom headers for API requests
    pub custom_headers: HashMap<String, String>,
    /// Enable debug mode
    pub debug_mode: bool,
    /// Maximum context length
    pub max_context_length: usize,
    /// Response format preferences
    pub response_format: ResponseFormat,
    /// Logging configuration
    pub logging: LoggingConfig,
    /// Performance monitoring
    pub monitoring: MonitoringConfig,
}
103
104impl Default for RagConfig {
105    fn default() -> Self {
106        Self {
107            provider: LlmProvider::OpenAI,
108            api_endpoint: "https://api.openai.com/v1".to_string(),
109            api_key: None,
110            model: "gpt-3.5-turbo".to_string(),
111            max_tokens: 1024,
112            temperature: 0.7,
113            top_p: 0.9,
114            timeout_secs: 30,
115            max_retries: 3,
116            embedding_provider: EmbeddingProvider::OpenAI,
117            embedding_model: "text-embedding-ada-002".to_string(),
118            embedding_dimensions: 1536,
119            chunk_size: 1000,
120            chunk_overlap: 200,
121            top_k: 5,
122            similarity_threshold: 0.7,
123            hybrid_search: true,
124            semantic_weight: 0.7,
125            keyword_weight: 0.3,
126            query_expansion: false,
127            response_filtering: true,
128            caching: true,
129            cache_ttl_secs: 3600,
130            rate_limiting: RateLimitConfig::default(),
131            retry_config: RetryConfig::default(),
132            custom_headers: HashMap::new(),
133            debug_mode: false,
134            max_context_length: 4096,
135            response_format: ResponseFormat::Json,
136            logging: LoggingConfig::default(),
137            monitoring: MonitoringConfig::default(),
138        }
139    }
140}
141
142impl RagConfig {
143    /// Create a new RAG configuration
144    pub fn new(provider: LlmProvider, model: String) -> Self {
145        Self {
146            provider,
147            model,
148            ..Default::default()
149        }
150    }
151
152    /// Set API key
153    pub fn with_api_key(mut self, api_key: String) -> Self {
154        self.api_key = Some(api_key);
155        self
156    }
157
158    /// Set API endpoint
159    pub fn with_endpoint(mut self, endpoint: String) -> Self {
160        self.api_endpoint = endpoint;
161        self
162    }
163
164    /// Set model parameters
165    pub fn with_model_params(mut self, max_tokens: usize, temperature: f32, top_p: f32) -> Self {
166        self.max_tokens = max_tokens;
167        self.temperature = temperature;
168        self.top_p = top_p;
169        self
170    }
171
172    /// Set embedding configuration
173    pub fn with_embedding(
174        mut self,
175        provider: EmbeddingProvider,
176        model: String,
177        dimensions: usize,
178    ) -> Self {
179        self.embedding_provider = provider;
180        self.embedding_model = model;
181        self.embedding_dimensions = dimensions;
182        self
183    }
184
185    /// Set chunking parameters
186    pub fn with_chunking(mut self, chunk_size: usize, chunk_overlap: usize) -> Self {
187        self.chunk_size = chunk_size;
188        self.chunk_overlap = chunk_overlap;
189        self
190    }
191
192    /// Set retrieval parameters
193    pub fn with_retrieval(mut self, top_k: usize, similarity_threshold: f32) -> Self {
194        self.top_k = top_k;
195        self.similarity_threshold = similarity_threshold;
196        self
197    }
198
199    /// Enable hybrid search
200    pub fn with_hybrid_search(mut self, semantic_weight: f32, keyword_weight: f32) -> Self {
201        self.hybrid_search = true;
202        self.semantic_weight = semantic_weight;
203        self.keyword_weight = keyword_weight;
204        self
205    }
206
207    /// Set caching configuration
208    pub fn with_caching(mut self, enabled: bool, ttl_secs: u64) -> Self {
209        self.caching = enabled;
210        self.cache_ttl_secs = ttl_secs;
211        self
212    }
213
214    /// Set rate limiting
215    pub fn with_rate_limit(mut self, requests_per_minute: u32, burst_size: u32) -> Self {
216        self.rate_limiting = RateLimitConfig {
217            requests_per_minute,
218            burst_size,
219            enabled: true,
220        };
221        self
222    }
223
224    /// Set retry configuration
225    pub fn with_retry(mut self, max_attempts: u32, backoff_secs: u64) -> Self {
226        self.retry_config = RetryConfig {
227            max_attempts,
228            backoff_secs,
229            exponential_backoff: true,
230        };
231        self
232    }
233
234    /// Add custom header
235    pub fn with_header(mut self, key: String, value: String) -> Self {
236        self.custom_headers.insert(key, value);
237        self
238    }
239
240    /// Enable debug mode
241    pub fn with_debug_mode(mut self, debug: bool) -> Self {
242        self.debug_mode = debug;
243        self
244    }
245
246    /// Validate configuration
247    pub fn validate(&self) -> Result<()> {
248        if self.api_endpoint.is_empty() {
249            return Err(crate::Error::generic("API endpoint cannot be empty"));
250        }
251
252        if self.model.is_empty() {
253            return Err(crate::Error::generic("Model name cannot be empty"));
254        }
255
256        if !(0.0..=2.0).contains(&self.temperature) {
257            return Err(crate::Error::generic("Temperature must be between 0.0 and 2.0"));
258        }
259
260        if !(0.0..=1.0).contains(&self.top_p) {
261            return Err(crate::Error::generic("Top-p must be between 0.0 and 1.0"));
262        }
263
264        if self.chunk_size == 0 {
265            return Err(crate::Error::generic("Chunk size must be greater than 0"));
266        }
267
268        if self.chunk_overlap >= self.chunk_size {
269            return Err(crate::Error::generic("Chunk overlap must be less than chunk size"));
270        }
271
272        if !(0.0..=1.0).contains(&self.similarity_threshold) {
273            return Err(crate::Error::generic("Similarity threshold must be between 0.0 and 1.0"));
274        }
275
276        if self.hybrid_search {
277            let total_weight = self.semantic_weight + self.keyword_weight;
278            if (total_weight - 1.0).abs() > f32::EPSILON {
279                return Err(crate::Error::generic("Hybrid search weights must sum to 1.0"));
280            }
281        }
282
283        Ok(())
284    }
285
286    /// Get timeout duration
287    pub fn timeout_duration(&self) -> Duration {
288        Duration::from_secs(self.timeout_secs)
289    }
290
291    /// Get cache TTL duration
292    pub fn cache_ttl_duration(&self) -> Duration {
293        Duration::from_secs(self.cache_ttl_secs)
294    }
295
296    /// Check if caching is enabled
297    pub fn is_caching_enabled(&self) -> bool {
298        self.caching
299    }
300
301    /// Check if rate limiting is enabled
302    pub fn is_rate_limited(&self) -> bool {
303        self.rate_limiting.enabled
304    }
305
306    /// Get requests per minute limit
307    pub fn requests_per_minute(&self) -> u32 {
308        self.rate_limiting.requests_per_minute
309    }
310
311    /// Get burst size for rate limiting
312    pub fn burst_size(&self) -> u32 {
313        self.rate_limiting.burst_size
314    }
315
316    /// Get maximum retry attempts
317    pub fn max_retry_attempts(&self) -> u32 {
318        self.retry_config.max_attempts
319    }
320
321    /// Get backoff duration for retries
322    pub fn backoff_duration(&self) -> Duration {
323        Duration::from_secs(self.retry_config.backoff_secs)
324    }
325
326    /// Check if exponential backoff is enabled
327    pub fn is_exponential_backoff(&self) -> bool {
328        self.retry_config.exponential_backoff
329    }
330
331    /// Get response format
332    pub fn response_format(&self) -> &ResponseFormat {
333        &self.response_format
334    }
335
336    /// Get logging configuration
337    pub fn logging_config(&self) -> &LoggingConfig {
338        &self.logging
339    }
340
341    /// Get monitoring configuration
342    pub fn monitoring_config(&self) -> &MonitoringConfig {
343        &self.monitoring
344    }
345}
346
/// Rate limiting configuration for outbound provider requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Number of requests allowed per minute (sustained rate)
    pub requests_per_minute: u32,
    /// Burst size for rate limiting — presumably the number of requests
    /// that may be issued back-to-back before the sustained rate applies;
    /// confirm against the enforcing limiter
    pub burst_size: u32,
    /// Whether rate limiting is enabled (exposed via
    /// `RagConfig::is_rate_limited()`)
    pub enabled: bool,
}
357
358impl Default for RateLimitConfig {
359    fn default() -> Self {
360        Self {
361            requests_per_minute: 60,
362            burst_size: 10,
363            enabled: true,
364        }
365    }
366}
367
/// Retry configuration.
///
/// Controls how failed provider requests are retried. Only the settings
/// are stored here; the actual retry loop lives with the client code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_attempts: u32,
    /// Base backoff time in seconds (exposed as a `Duration` via
    /// `RagConfig::backoff_duration()`)
    pub backoff_secs: u64,
    /// Whether to use exponential backoff — presumably the base delay is
    /// grown per attempt by the retry executor; confirm at the call site
    pub exponential_backoff: bool,
}
378
379impl Default for RetryConfig {
380    fn default() -> Self {
381        Self {
382            max_attempts: 3,
383            backoff_secs: 1,
384            exponential_backoff: true,
385        }
386    }
387}
388
/// Response format preferences.
///
/// Note: unlike the provider enums, this derive has no
/// `#[serde(rename_all = "lowercase")]`, so variants serialize with their
/// Rust names (`"Text"`, `"Json"`, `"Markdown"`) and `Custom` as an
/// externally tagged value (`{"Custom": "<template>"}`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponseFormat {
    /// Plain text response
    Text,
    /// JSON structured response
    Json,
    /// Markdown formatted response
    Markdown,
    /// Custom format with template
    Custom(String),
}
401
402impl Default for ResponseFormat {
403    fn default() -> Self {
404        Self::Json
405    }
406}
407
/// Logging configuration for RAG operations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log level for RAG operations
    // free-form string; presumably a standard level name such as
    // "trace"/"debug"/"info"/"warn"/"error" — confirm with the consumer
    pub log_level: String,
    /// Enable request/response logging (off by default; may expose prompt
    /// and response contents in logs)
    pub log_requests: bool,
    /// Enable performance logging
    pub log_performance: bool,
    /// Log file path (if any)
    // NOTE(review): `None` presumably means console-only logging —
    // confirm with the logging setup
    pub log_file: Option<String>,
    /// Maximum log file size in MB
    pub max_log_size_mb: u64,
}
422
423impl Default for LoggingConfig {
424    fn default() -> Self {
425        Self {
426            log_level: "info".to_string(),
427            log_requests: false,
428            log_performance: true,
429            log_file: None,
430            max_log_size_mb: 100,
431        }
432    }
433}
434
/// Performance monitoring configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Enable metrics collection
    pub enable_metrics: bool,
    /// Metrics collection interval in seconds
    pub metrics_interval_secs: u64,
    /// Enable tracing
    pub enable_tracing: bool,
    /// Tracing sample rate (0.0 to 1.0)
    // NOTE(review): not checked by `RagConfig::validate`, so out-of-range
    // values are accepted silently
    pub trace_sample_rate: f32,
    /// Performance thresholds
    pub thresholds: PerformanceThresholds,
}
449
450impl Default for MonitoringConfig {
451    fn default() -> Self {
452        Self {
453            enable_metrics: true,
454            metrics_interval_secs: 60,
455            enable_tracing: false,
456            trace_sample_rate: 0.1,
457            thresholds: PerformanceThresholds::default(),
458        }
459    }
460}
461
/// Performance thresholds for monitoring.
///
/// Advisory limits consumed by the monitoring layer; nothing in this
/// module enforces them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceThresholds {
    /// Maximum response time in seconds
    pub max_response_time_secs: f64,
    /// Minimum similarity score threshold
    pub min_similarity_score: f32,
    /// Maximum memory usage in MB
    pub max_memory_usage_mb: u64,
    /// Maximum CPU usage percentage
    pub max_cpu_usage_percent: f32,
}
474
475impl Default for PerformanceThresholds {
476    fn default() -> Self {
477        Self {
478            max_response_time_secs: 30.0,
479            min_similarity_score: 0.7,
480            max_memory_usage_mb: 1024,
481            max_cpu_usage_percent: 80.0,
482        }
483    }
484}
485
/// Fluent builder for [`RagConfig`].
///
/// Starts from `RagConfig::default()`; `build()` runs
/// `RagConfig::validate()` before handing back the configuration.
#[derive(Debug)]
pub struct RagConfigBuilder {
    // Configuration under construction; validated and consumed by `build()`.
    config: RagConfig,
}
491
492impl RagConfigBuilder {
493    /// Create a new configuration builder
494    pub fn new() -> Self {
495        Self {
496            config: RagConfig::default(),
497        }
498    }
499
500    /// Build the configuration
501    pub fn build(self) -> Result<RagConfig> {
502        self.config.validate()?;
503        Ok(self.config)
504    }
505
506    /// Set the LLM provider
507    pub fn provider(mut self, provider: LlmProvider) -> Self {
508        self.config.provider = provider;
509        self
510    }
511
512    /// Set the model name
513    pub fn model(mut self, model: String) -> Self {
514        self.config.model = model;
515        self
516    }
517
518    /// Set the API key
519    pub fn api_key(mut self, api_key: String) -> Self {
520        self.config.api_key = Some(api_key);
521        self
522    }
523
524    /// Set the API endpoint
525    pub fn endpoint(mut self, endpoint: String) -> Self {
526        self.config.api_endpoint = endpoint;
527        self
528    }
529
530    /// Set model parameters
531    pub fn model_params(mut self, max_tokens: usize, temperature: f32) -> Self {
532        self.config.max_tokens = max_tokens;
533        self.config.temperature = temperature;
534        self
535    }
536
537    /// Set embedding configuration
538    pub fn embedding(mut self, model: String, dimensions: usize) -> Self {
539        self.config.embedding_model = model;
540        self.config.embedding_dimensions = dimensions;
541        self
542    }
543
544    /// Set chunking parameters
545    pub fn chunking(mut self, size: usize, overlap: usize) -> Self {
546        self.config.chunk_size = size;
547        self.config.chunk_overlap = overlap;
548        self
549    }
550
551    /// Set retrieval parameters
552    pub fn retrieval(mut self, top_k: usize, threshold: f32) -> Self {
553        self.config.top_k = top_k;
554        self.config.similarity_threshold = threshold;
555        self
556    }
557
558    /// Enable hybrid search
559    pub fn hybrid_search(mut self, semantic_weight: f32) -> Self {
560        self.config.hybrid_search = true;
561        self.config.semantic_weight = semantic_weight;
562        self.config.keyword_weight = 1.0 - semantic_weight;
563        self
564    }
565
566    /// Enable caching
567    pub fn caching(mut self, enabled: bool) -> Self {
568        self.config.caching = enabled;
569        self
570    }
571
572    /// Set rate limiting
573    pub fn rate_limit(mut self, requests_per_minute: u32) -> Self {
574        self.config.rate_limiting = RateLimitConfig {
575            requests_per_minute,
576            burst_size: requests_per_minute / 6, // 10-second burst
577            enabled: true,
578        };
579        self
580    }
581
582    /// Enable debug mode
583    pub fn debug(mut self, debug: bool) -> Self {
584        self.config.debug_mode = debug;
585        self
586    }
587}
588
589impl Default for RagConfigBuilder {
590    fn default() -> Self {
591        Self::new()
592    }
593}
594
#[cfg(test)]
mod tests {
    use super::*;

    // The stock defaults must always pass their own validation.
    #[test]
    fn default_config_is_valid() {
        assert!(RagConfig::default().validate().is_ok());
    }

    // The builder validates on build() and applies the chained settings.
    #[test]
    fn builder_produces_valid_config() {
        let config = RagConfigBuilder::new()
            .provider(LlmProvider::Ollama)
            .model("llama3".to_string())
            .build()
            .expect("builder config based on defaults should validate");
        assert_eq!(config.provider, LlmProvider::Ollama);
        assert_eq!(config.model, "llama3");
    }

    #[test]
    fn validate_rejects_out_of_range_temperature() {
        let mut config = RagConfig::default();
        config.temperature = 5.0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_rejects_overlap_not_less_than_chunk_size() {
        let mut config = RagConfig::default();
        config.chunk_overlap = config.chunk_size;
        assert!(config.validate().is_err());
    }
}