// mockforge_data/rag/config.rs

1//! RAG configuration and settings management
2//!
3//! This module handles all configuration aspects of the RAG system,
4//! including provider settings, model configurations, and operational parameters.
5
6use crate::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10
11/// Supported LLM providers
/// Supported LLM providers
///
/// Serialized/deserialized in lowercase via serde (e.g. `openai`,
/// `anthropic`, `openaicompatible`, `ollama`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI GPT models
    OpenAI,
    /// Anthropic Claude models
    Anthropic,
    /// Generic OpenAI-compatible API
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}
24
25/// Supported embedding providers
/// Supported embedding providers
///
/// Serialized/deserialized in lowercase via serde (e.g. `openai`,
/// `openaicompatible`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI text-embedding-ada-002
    OpenAI,
    /// Generic OpenAI-compatible embeddings API
    OpenAICompatible,
}
34
35/// RAG configuration
/// RAG configuration
///
/// Central settings bag for the RAG pipeline: LLM provider/model selection,
/// embedding setup, chunking/retrieval tuning, and operational concerns
/// (caching, rate limiting, retries, logging, monitoring). Construct via
/// [`RagConfig::default`], the `with_*` methods, or `RagConfigBuilder`,
/// then check invariants with `validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM provider
    pub provider: LlmProvider,
    /// LLM API endpoint (base URL, e.g. `https://api.openai.com/v1`)
    pub api_endpoint: String,
    /// API key for authentication; `None` for endpoints that need no auth
    pub api_key: Option<String>,
    /// Model name to use
    pub model: String,
    /// Maximum tokens for generation
    pub max_tokens: usize,
    /// Temperature for generation (0.0 to 2.0)
    pub temperature: f32,
    /// Top-p sampling (0.0 to 1.0)
    pub top_p: f32,
    /// Request timeout in seconds (see `timeout_duration`)
    pub timeout_secs: u64,
    /// Maximum retry attempts
    /// NOTE(review): the `max_retry_attempts()` accessor reads
    /// `retry_config.max_attempts`, not this field — confirm whether this
    /// top-level field is still consumed anywhere before relying on it.
    pub max_retries: u32,
    /// Embedding provider
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model name
    pub embedding_model: String,
    /// Embedding dimensions (must match the embedding model's output size)
    pub embedding_dimensions: usize,
    /// Chunk size for document splitting (must be > 0)
    pub chunk_size: usize,
    /// Chunk overlap for document splitting (must be < `chunk_size`)
    pub chunk_overlap: usize,
    /// Top-k similar chunks to retrieve
    pub top_k: usize,
    /// Similarity threshold (0.0 to 1.0)
    pub similarity_threshold: f32,
    /// Enable hybrid search (combines semantic and keyword search)
    pub hybrid_search: bool,
    /// Weight for semantic search in hybrid mode (0.0 to 1.0);
    /// together with `keyword_weight` must sum to 1.0
    pub semantic_weight: f32,
    /// Weight for keyword search in hybrid mode (0.0 to 1.0)
    pub keyword_weight: f32,
    /// Enable query expansion
    pub query_expansion: bool,
    /// Enable response filtering
    pub response_filtering: bool,
    /// Enable caching
    pub caching: bool,
    /// Cache TTL in seconds (see `cache_ttl_duration`)
    pub cache_ttl_secs: u64,
    /// Rate limiting configuration
    pub rate_limiting: RateLimitConfig,
    /// Retry configuration
    pub retry_config: RetryConfig,
    /// Custom headers for API requests (sent verbatim with each request)
    pub custom_headers: HashMap<String, String>,
    /// Enable debug mode
    pub debug_mode: bool,
    /// Maximum context length
    pub max_context_length: usize,
    /// Response format preferences
    pub response_format: ResponseFormat,
    /// Logging configuration
    pub logging: LoggingConfig,
    /// Performance monitoring
    pub monitoring: MonitoringConfig,
}
101
102impl Default for RagConfig {
103    fn default() -> Self {
104        Self {
105            provider: LlmProvider::OpenAI,
106            api_endpoint: "https://api.openai.com/v1".to_string(),
107            api_key: None,
108            model: "gpt-3.5-turbo".to_string(),
109            max_tokens: 1024,
110            temperature: 0.7,
111            top_p: 0.9,
112            timeout_secs: 30,
113            max_retries: 3,
114            embedding_provider: EmbeddingProvider::OpenAI,
115            embedding_model: "text-embedding-ada-002".to_string(),
116            embedding_dimensions: 1536,
117            chunk_size: 1000,
118            chunk_overlap: 200,
119            top_k: 5,
120            similarity_threshold: 0.7,
121            hybrid_search: true,
122            semantic_weight: 0.7,
123            keyword_weight: 0.3,
124            query_expansion: false,
125            response_filtering: true,
126            caching: true,
127            cache_ttl_secs: 3600,
128            rate_limiting: RateLimitConfig::default(),
129            retry_config: RetryConfig::default(),
130            custom_headers: HashMap::new(),
131            debug_mode: false,
132            max_context_length: 4096,
133            response_format: ResponseFormat::Json,
134            logging: LoggingConfig::default(),
135            monitoring: MonitoringConfig::default(),
136        }
137    }
138}
139
140impl RagConfig {
141    /// Create a new RAG configuration
142    pub fn new(provider: LlmProvider, model: String) -> Self {
143        Self {
144            provider,
145            model,
146            ..Default::default()
147        }
148    }
149
150    /// Set API key
151    pub fn with_api_key(mut self, api_key: String) -> Self {
152        self.api_key = Some(api_key);
153        self
154    }
155
156    /// Set API endpoint
157    pub fn with_endpoint(mut self, endpoint: String) -> Self {
158        self.api_endpoint = endpoint;
159        self
160    }
161
162    /// Set model parameters
163    pub fn with_model_params(mut self, max_tokens: usize, temperature: f32, top_p: f32) -> Self {
164        self.max_tokens = max_tokens;
165        self.temperature = temperature;
166        self.top_p = top_p;
167        self
168    }
169
170    /// Set embedding configuration
171    pub fn with_embedding(
172        mut self,
173        provider: EmbeddingProvider,
174        model: String,
175        dimensions: usize,
176    ) -> Self {
177        self.embedding_provider = provider;
178        self.embedding_model = model;
179        self.embedding_dimensions = dimensions;
180        self
181    }
182
183    /// Set chunking parameters
184    pub fn with_chunking(mut self, chunk_size: usize, chunk_overlap: usize) -> Self {
185        self.chunk_size = chunk_size;
186        self.chunk_overlap = chunk_overlap;
187        self
188    }
189
190    /// Set retrieval parameters
191    pub fn with_retrieval(mut self, top_k: usize, similarity_threshold: f32) -> Self {
192        self.top_k = top_k;
193        self.similarity_threshold = similarity_threshold;
194        self
195    }
196
197    /// Enable hybrid search
198    pub fn with_hybrid_search(mut self, semantic_weight: f32, keyword_weight: f32) -> Self {
199        self.hybrid_search = true;
200        self.semantic_weight = semantic_weight;
201        self.keyword_weight = keyword_weight;
202        self
203    }
204
205    /// Set caching configuration
206    pub fn with_caching(mut self, enabled: bool, ttl_secs: u64) -> Self {
207        self.caching = enabled;
208        self.cache_ttl_secs = ttl_secs;
209        self
210    }
211
212    /// Set rate limiting
213    pub fn with_rate_limit(mut self, requests_per_minute: u32, burst_size: u32) -> Self {
214        self.rate_limiting = RateLimitConfig {
215            requests_per_minute,
216            burst_size,
217            enabled: true,
218        };
219        self
220    }
221
222    /// Set retry configuration
223    pub fn with_retry(mut self, max_attempts: u32, backoff_secs: u64) -> Self {
224        self.retry_config = RetryConfig {
225            max_attempts,
226            backoff_secs,
227            exponential_backoff: true,
228        };
229        self
230    }
231
232    /// Add custom header
233    pub fn with_header(mut self, key: String, value: String) -> Self {
234        self.custom_headers.insert(key, value);
235        self
236    }
237
238    /// Enable debug mode
239    pub fn with_debug_mode(mut self, debug: bool) -> Self {
240        self.debug_mode = debug;
241        self
242    }
243
244    /// Validate configuration
245    pub fn validate(&self) -> Result<()> {
246        if self.api_endpoint.is_empty() {
247            return Err(mockforge_core::Error::generic("API endpoint cannot be empty"));
248        }
249
250        if self.model.is_empty() {
251            return Err(mockforge_core::Error::generic("Model name cannot be empty"));
252        }
253
254        if !(0.0..=2.0).contains(&self.temperature) {
255            return Err(mockforge_core::Error::generic("Temperature must be between 0.0 and 2.0"));
256        }
257
258        if !(0.0..=1.0).contains(&self.top_p) {
259            return Err(mockforge_core::Error::generic("Top-p must be between 0.0 and 1.0"));
260        }
261
262        if self.chunk_size == 0 {
263            return Err(mockforge_core::Error::generic("Chunk size must be greater than 0"));
264        }
265
266        if self.chunk_overlap >= self.chunk_size {
267            return Err(mockforge_core::Error::generic(
268                "Chunk overlap must be less than chunk size",
269            ));
270        }
271
272        if !(0.0..=1.0).contains(&self.similarity_threshold) {
273            return Err(mockforge_core::Error::generic(
274                "Similarity threshold must be between 0.0 and 1.0",
275            ));
276        }
277
278        if self.hybrid_search {
279            let total_weight = self.semantic_weight + self.keyword_weight;
280            if (total_weight - 1.0).abs() > f32::EPSILON {
281                return Err(mockforge_core::Error::generic(
282                    "Hybrid search weights must sum to 1.0",
283                ));
284            }
285        }
286
287        Ok(())
288    }
289
290    /// Get timeout duration
291    pub fn timeout_duration(&self) -> Duration {
292        Duration::from_secs(self.timeout_secs)
293    }
294
295    /// Get cache TTL duration
296    pub fn cache_ttl_duration(&self) -> Duration {
297        Duration::from_secs(self.cache_ttl_secs)
298    }
299
300    /// Check if caching is enabled
301    pub fn is_caching_enabled(&self) -> bool {
302        self.caching
303    }
304
305    /// Check if rate limiting is enabled
306    pub fn is_rate_limited(&self) -> bool {
307        self.rate_limiting.enabled
308    }
309
310    /// Get requests per minute limit
311    pub fn requests_per_minute(&self) -> u32 {
312        self.rate_limiting.requests_per_minute
313    }
314
315    /// Get burst size for rate limiting
316    pub fn burst_size(&self) -> u32 {
317        self.rate_limiting.burst_size
318    }
319
320    /// Get maximum retry attempts
321    pub fn max_retry_attempts(&self) -> u32 {
322        self.retry_config.max_attempts
323    }
324
325    /// Get backoff duration for retries
326    pub fn backoff_duration(&self) -> Duration {
327        Duration::from_secs(self.retry_config.backoff_secs)
328    }
329
330    /// Check if exponential backoff is enabled
331    pub fn is_exponential_backoff(&self) -> bool {
332        self.retry_config.exponential_backoff
333    }
334
335    /// Get response format
336    pub fn response_format(&self) -> &ResponseFormat {
337        &self.response_format
338    }
339
340    /// Get logging configuration
341    pub fn logging_config(&self) -> &LoggingConfig {
342        &self.logging
343    }
344
345    /// Get monitoring configuration
346    pub fn monitoring_config(&self) -> &MonitoringConfig {
347        &self.monitoring
348    }
349}
350
351/// Rate limiting configuration
352#[derive(Debug, Clone, Serialize, Deserialize)]
353pub struct RateLimitConfig {
354    /// Number of requests allowed per minute
355    pub requests_per_minute: u32,
356    /// Burst size for rate limiting
357    pub burst_size: u32,
358    /// Whether rate limiting is enabled
359    pub enabled: bool,
360}
361
362impl Default for RateLimitConfig {
363    fn default() -> Self {
364        Self {
365            requests_per_minute: 60,
366            burst_size: 10,
367            enabled: true,
368        }
369    }
370}
371
372/// Retry configuration
/// Retry configuration
///
/// Controls how failed requests are retried: attempt count, base delay,
/// and whether the delay grows exponentially between attempts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_attempts: u32,
    /// Base backoff time in seconds (first-retry delay)
    pub backoff_secs: u64,
    /// Whether to use exponential backoff between attempts
    pub exponential_backoff: bool,
}
382
383impl Default for RetryConfig {
384    fn default() -> Self {
385        Self {
386            max_attempts: 3,
387            backoff_secs: 1,
388            exponential_backoff: true,
389        }
390    }
391}
392
393/// Response format preferences
/// Response format preferences
///
/// How generated responses should be shaped; `Custom` carries a
/// caller-supplied template string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponseFormat {
    /// Plain text response
    Text,
    /// JSON structured response
    Json,
    /// Markdown formatted response
    Markdown,
    /// Custom format with template
    Custom(String),
}
405
406impl Default for ResponseFormat {
407    fn default() -> Self {
408        Self::Json
409    }
410}
411
412/// Logging configuration
413#[derive(Debug, Clone, Serialize, Deserialize)]
414pub struct LoggingConfig {
415    /// Log level for RAG operations
416    pub log_level: String,
417    /// Enable request/response logging
418    pub log_requests: bool,
419    /// Enable performance logging
420    pub log_performance: bool,
421    /// Log file path (if any)
422    pub log_file: Option<String>,
423    /// Maximum log file size in MB
424    pub max_log_size_mb: u64,
425}
426
427impl Default for LoggingConfig {
428    fn default() -> Self {
429        Self {
430            log_level: "info".to_string(),
431            log_requests: false,
432            log_performance: true,
433            log_file: None,
434            max_log_size_mb: 100,
435        }
436    }
437}
438
439/// Performance monitoring configuration
440#[derive(Debug, Clone, Serialize, Deserialize)]
441pub struct MonitoringConfig {
442    /// Enable metrics collection
443    pub enable_metrics: bool,
444    /// Metrics collection interval in seconds
445    pub metrics_interval_secs: u64,
446    /// Enable tracing
447    pub enable_tracing: bool,
448    /// Tracing sample rate (0.0 to 1.0)
449    pub trace_sample_rate: f32,
450    /// Performance thresholds
451    pub thresholds: PerformanceThresholds,
452}
453
454impl Default for MonitoringConfig {
455    fn default() -> Self {
456        Self {
457            enable_metrics: true,
458            metrics_interval_secs: 60,
459            enable_tracing: false,
460            trace_sample_rate: 0.1,
461            thresholds: PerformanceThresholds::default(),
462        }
463    }
464}
465
466/// Performance thresholds for monitoring
467#[derive(Debug, Clone, Serialize, Deserialize)]
468pub struct PerformanceThresholds {
469    /// Maximum response time in seconds
470    pub max_response_time_secs: f64,
471    /// Minimum similarity score threshold
472    pub min_similarity_score: f32,
473    /// Maximum memory usage in MB
474    pub max_memory_usage_mb: u64,
475    /// Maximum CPU usage percentage
476    pub max_cpu_usage_percent: f32,
477}
478
479impl Default for PerformanceThresholds {
480    fn default() -> Self {
481        Self {
482            max_response_time_secs: 30.0,
483            min_similarity_score: 0.7,
484            max_memory_usage_mb: 1024,
485            max_cpu_usage_percent: 80.0,
486        }
487    }
488}
489
490/// Configuration builder for RAG
491#[derive(Debug)]
492pub struct RagConfigBuilder {
493    config: RagConfig,
494}
495
496impl RagConfigBuilder {
497    /// Create a new configuration builder
498    pub fn new() -> Self {
499        Self {
500            config: RagConfig::default(),
501        }
502    }
503
504    /// Build the configuration
505    pub fn build(self) -> Result<RagConfig> {
506        self.config.validate()?;
507        Ok(self.config)
508    }
509
510    /// Set the LLM provider
511    pub fn provider(mut self, provider: LlmProvider) -> Self {
512        self.config.provider = provider;
513        self
514    }
515
516    /// Set the model name
517    pub fn model(mut self, model: String) -> Self {
518        self.config.model = model;
519        self
520    }
521
522    /// Set the API key
523    pub fn api_key(mut self, api_key: String) -> Self {
524        self.config.api_key = Some(api_key);
525        self
526    }
527
528    /// Set the API endpoint
529    pub fn endpoint(mut self, endpoint: String) -> Self {
530        self.config.api_endpoint = endpoint;
531        self
532    }
533
534    /// Set model parameters
535    pub fn model_params(mut self, max_tokens: usize, temperature: f32) -> Self {
536        self.config.max_tokens = max_tokens;
537        self.config.temperature = temperature;
538        self
539    }
540
541    /// Set embedding configuration
542    pub fn embedding(mut self, model: String, dimensions: usize) -> Self {
543        self.config.embedding_model = model;
544        self.config.embedding_dimensions = dimensions;
545        self
546    }
547
548    /// Set chunking parameters
549    pub fn chunking(mut self, size: usize, overlap: usize) -> Self {
550        self.config.chunk_size = size;
551        self.config.chunk_overlap = overlap;
552        self
553    }
554
555    /// Set retrieval parameters
556    pub fn retrieval(mut self, top_k: usize, threshold: f32) -> Self {
557        self.config.top_k = top_k;
558        self.config.similarity_threshold = threshold;
559        self
560    }
561
562    /// Enable hybrid search
563    pub fn hybrid_search(mut self, semantic_weight: f32) -> Self {
564        self.config.hybrid_search = true;
565        self.config.semantic_weight = semantic_weight;
566        self.config.keyword_weight = 1.0 - semantic_weight;
567        self
568    }
569
570    /// Enable caching
571    pub fn caching(mut self, enabled: bool) -> Self {
572        self.config.caching = enabled;
573        self
574    }
575
576    /// Set rate limiting
577    pub fn rate_limit(mut self, requests_per_minute: u32) -> Self {
578        self.config.rate_limiting = RateLimitConfig {
579            requests_per_minute,
580            burst_size: requests_per_minute / 6, // 10-second burst
581            enabled: true,
582        };
583        self
584    }
585
586    /// Enable debug mode
587    pub fn debug(mut self, debug: bool) -> Self {
588        self.config.debug_mode = debug;
589        self
590    }
591}
592
593impl Default for RagConfigBuilder {
594    fn default() -> Self {
595        Self::new()
596    }
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    /// The out-of-the-box configuration must pass its own validation.
    #[test]
    fn default_config_is_valid() {
        assert!(RagConfig::default().validate().is_ok());
    }

    /// The fluent builder applies settings and produces a valid config.
    #[test]
    fn builder_produces_valid_config() {
        let config = RagConfigBuilder::new()
            .provider(LlmProvider::Anthropic)
            .model("claude-3".to_string())
            .hybrid_search(0.6)
            .rate_limit(60)
            .build()
            .expect("builder config should validate");
        assert_eq!(config.provider, LlmProvider::Anthropic);
        assert_eq!(config.model, "claude-3");
        assert!(config.rate_limiting.enabled);
        assert_eq!(config.rate_limiting.requests_per_minute, 60);
        assert!(config.rate_limiting.burst_size >= 1);
    }

    /// Temperature outside [0, 2] is rejected.
    #[test]
    fn validate_rejects_bad_temperature() {
        let mut config = RagConfig::default();
        config.temperature = 3.0;
        assert!(config.validate().is_err());
    }

    /// Chunk overlap equal to (or above) chunk size is rejected.
    #[test]
    fn validate_rejects_overlap_ge_chunk_size() {
        let config = RagConfig::default().with_chunking(100, 100);
        assert!(config.validate().is_err());
    }
}