1use crate::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10
/// Supported LLM backends for generation.
///
/// Serialized/deserialized in lowercase via serde, e.g.
/// `LlmProvider::OpenAICompatible` <-> `"openaicompatible"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI API.
    OpenAI,
    /// Anthropic API.
    Anthropic,
    /// Any endpoint exposing an OpenAI-compatible API surface.
    OpenAICompatible,
    /// Ollama (typically a locally hosted instance).
    Ollama,
}
24
/// Supported backends for computing embeddings.
///
/// Serialized/deserialized in lowercase via serde (e.g. `"openai"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI embedding API.
    OpenAI,
    /// Any endpoint exposing an OpenAI-compatible embedding API.
    OpenAICompatible,
}
34
/// Configuration for retrieval-augmented generation (RAG).
///
/// Construct via [`Default`], [`RagConfig::new`] plus the `with_*` helpers,
/// or [`RagConfigBuilder`]. Call [`RagConfig::validate`] before use to catch
/// out-of-range or inconsistent values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM backend used for generation.
    pub provider: LlmProvider,
    /// Base URL of the provider API (must be non-empty).
    pub api_endpoint: String,
    /// API key, if the provider requires one.
    pub api_key: Option<String>,
    /// Model identifier sent to the provider (must be non-empty).
    pub model: String,
    /// Maximum number of tokens to generate per response.
    pub max_tokens: usize,
    /// Sampling temperature; validated to the range 0.0..=2.0.
    pub temperature: f32,
    /// Nucleus-sampling probability mass; validated to the range 0.0..=1.0.
    pub top_p: f32,
    /// Request timeout in seconds (see [`RagConfig::timeout_duration`]).
    pub timeout_secs: u64,
    /// Maximum number of request retries.
    pub max_retries: u32,
    /// Backend used to compute embeddings.
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model identifier.
    pub embedding_model: String,
    /// Dimensionality of the embedding vectors.
    pub embedding_dimensions: usize,
    /// Document chunk size — unit (chars vs. tokens) not fixed here; TODO confirm.
    pub chunk_size: usize,
    /// Overlap between consecutive chunks; must be less than `chunk_size`.
    pub chunk_overlap: usize,
    /// Number of chunks retrieved per query.
    pub top_k: usize,
    /// Minimum similarity score for retrieved chunks; validated to 0.0..=1.0.
    pub similarity_threshold: f32,
    /// Whether to combine semantic and keyword search.
    pub hybrid_search: bool,
    /// Weight of the semantic score in hybrid search.
    pub semantic_weight: f32,
    /// Weight of the keyword score; together with `semantic_weight` must
    /// sum to 1.0 when `hybrid_search` is enabled.
    pub keyword_weight: f32,
    /// Whether to expand queries before retrieval.
    pub query_expansion: bool,
    /// Whether to filter model responses.
    pub response_filtering: bool,
    /// Whether response caching is enabled.
    pub caching: bool,
    /// Cache entry time-to-live in seconds.
    pub cache_ttl_secs: u64,
    /// Client-side rate limiting settings.
    pub rate_limiting: RateLimitConfig,
    /// Retry/backoff settings.
    pub retry_config: RetryConfig,
    /// Extra HTTP headers sent with every request.
    pub custom_headers: HashMap<String, String>,
    /// Enables verbose debugging behavior.
    pub debug_mode: bool,
    /// Maximum context length — presumably in tokens; TODO confirm against consumer.
    pub max_context_length: usize,
    /// Desired format of generated responses.
    pub response_format: ResponseFormat,
    /// Logging settings.
    pub logging: LoggingConfig,
    /// Monitoring/metrics settings.
    pub monitoring: MonitoringConfig,
}
101
102impl Default for RagConfig {
103 fn default() -> Self {
104 Self {
105 provider: LlmProvider::OpenAI,
106 api_endpoint: "https://api.openai.com/v1".to_string(),
107 api_key: None,
108 model: "gpt-3.5-turbo".to_string(),
109 max_tokens: 1024,
110 temperature: 0.7,
111 top_p: 0.9,
112 timeout_secs: 30,
113 max_retries: 3,
114 embedding_provider: EmbeddingProvider::OpenAI,
115 embedding_model: "text-embedding-ada-002".to_string(),
116 embedding_dimensions: 1536,
117 chunk_size: 1000,
118 chunk_overlap: 200,
119 top_k: 5,
120 similarity_threshold: 0.7,
121 hybrid_search: true,
122 semantic_weight: 0.7,
123 keyword_weight: 0.3,
124 query_expansion: false,
125 response_filtering: true,
126 caching: true,
127 cache_ttl_secs: 3600,
128 rate_limiting: RateLimitConfig::default(),
129 retry_config: RetryConfig::default(),
130 custom_headers: HashMap::new(),
131 debug_mode: false,
132 max_context_length: 4096,
133 response_format: ResponseFormat::Json,
134 logging: LoggingConfig::default(),
135 monitoring: MonitoringConfig::default(),
136 }
137 }
138}
139
140impl RagConfig {
141 pub fn new(provider: LlmProvider, model: String) -> Self {
143 Self {
144 provider,
145 model,
146 ..Default::default()
147 }
148 }
149
150 pub fn with_api_key(mut self, api_key: String) -> Self {
152 self.api_key = Some(api_key);
153 self
154 }
155
156 pub fn with_endpoint(mut self, endpoint: String) -> Self {
158 self.api_endpoint = endpoint;
159 self
160 }
161
162 pub fn with_model_params(mut self, max_tokens: usize, temperature: f32, top_p: f32) -> Self {
164 self.max_tokens = max_tokens;
165 self.temperature = temperature;
166 self.top_p = top_p;
167 self
168 }
169
170 pub fn with_embedding(
172 mut self,
173 provider: EmbeddingProvider,
174 model: String,
175 dimensions: usize,
176 ) -> Self {
177 self.embedding_provider = provider;
178 self.embedding_model = model;
179 self.embedding_dimensions = dimensions;
180 self
181 }
182
183 pub fn with_chunking(mut self, chunk_size: usize, chunk_overlap: usize) -> Self {
185 self.chunk_size = chunk_size;
186 self.chunk_overlap = chunk_overlap;
187 self
188 }
189
190 pub fn with_retrieval(mut self, top_k: usize, similarity_threshold: f32) -> Self {
192 self.top_k = top_k;
193 self.similarity_threshold = similarity_threshold;
194 self
195 }
196
197 pub fn with_hybrid_search(mut self, semantic_weight: f32, keyword_weight: f32) -> Self {
199 self.hybrid_search = true;
200 self.semantic_weight = semantic_weight;
201 self.keyword_weight = keyword_weight;
202 self
203 }
204
205 pub fn with_caching(mut self, enabled: bool, ttl_secs: u64) -> Self {
207 self.caching = enabled;
208 self.cache_ttl_secs = ttl_secs;
209 self
210 }
211
212 pub fn with_rate_limit(mut self, requests_per_minute: u32, burst_size: u32) -> Self {
214 self.rate_limiting = RateLimitConfig {
215 requests_per_minute,
216 burst_size,
217 enabled: true,
218 };
219 self
220 }
221
222 pub fn with_retry(mut self, max_attempts: u32, backoff_secs: u64) -> Self {
224 self.retry_config = RetryConfig {
225 max_attempts,
226 backoff_secs,
227 exponential_backoff: true,
228 };
229 self
230 }
231
232 pub fn with_header(mut self, key: String, value: String) -> Self {
234 self.custom_headers.insert(key, value);
235 self
236 }
237
238 pub fn with_debug_mode(mut self, debug: bool) -> Self {
240 self.debug_mode = debug;
241 self
242 }
243
244 pub fn validate(&self) -> Result<()> {
246 if self.api_endpoint.is_empty() {
247 return Err(mockforge_core::Error::generic("API endpoint cannot be empty"));
248 }
249
250 if self.model.is_empty() {
251 return Err(mockforge_core::Error::generic("Model name cannot be empty"));
252 }
253
254 if !(0.0..=2.0).contains(&self.temperature) {
255 return Err(mockforge_core::Error::generic("Temperature must be between 0.0 and 2.0"));
256 }
257
258 if !(0.0..=1.0).contains(&self.top_p) {
259 return Err(mockforge_core::Error::generic("Top-p must be between 0.0 and 1.0"));
260 }
261
262 if self.chunk_size == 0 {
263 return Err(mockforge_core::Error::generic("Chunk size must be greater than 0"));
264 }
265
266 if self.chunk_overlap >= self.chunk_size {
267 return Err(mockforge_core::Error::generic(
268 "Chunk overlap must be less than chunk size",
269 ));
270 }
271
272 if !(0.0..=1.0).contains(&self.similarity_threshold) {
273 return Err(mockforge_core::Error::generic(
274 "Similarity threshold must be between 0.0 and 1.0",
275 ));
276 }
277
278 if self.hybrid_search {
279 let total_weight = self.semantic_weight + self.keyword_weight;
280 if (total_weight - 1.0).abs() > f32::EPSILON {
281 return Err(mockforge_core::Error::generic(
282 "Hybrid search weights must sum to 1.0",
283 ));
284 }
285 }
286
287 Ok(())
288 }
289
290 pub fn timeout_duration(&self) -> Duration {
292 Duration::from_secs(self.timeout_secs)
293 }
294
295 pub fn cache_ttl_duration(&self) -> Duration {
297 Duration::from_secs(self.cache_ttl_secs)
298 }
299
300 pub fn is_caching_enabled(&self) -> bool {
302 self.caching
303 }
304
305 pub fn is_rate_limited(&self) -> bool {
307 self.rate_limiting.enabled
308 }
309
310 pub fn requests_per_minute(&self) -> u32 {
312 self.rate_limiting.requests_per_minute
313 }
314
315 pub fn burst_size(&self) -> u32 {
317 self.rate_limiting.burst_size
318 }
319
320 pub fn max_retry_attempts(&self) -> u32 {
322 self.retry_config.max_attempts
323 }
324
325 pub fn backoff_duration(&self) -> Duration {
327 Duration::from_secs(self.retry_config.backoff_secs)
328 }
329
330 pub fn is_exponential_backoff(&self) -> bool {
332 self.retry_config.exponential_backoff
333 }
334
335 pub fn response_format(&self) -> &ResponseFormat {
337 &self.response_format
338 }
339
340 pub fn logging_config(&self) -> &LoggingConfig {
342 &self.logging
343 }
344
345 pub fn monitoring_config(&self) -> &MonitoringConfig {
347 &self.monitoring
348 }
349}
350
/// Client-side rate-limiting settings for provider requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Sustained request budget per minute.
    pub requests_per_minute: u32,
    /// Number of requests allowed above the sustained rate in a burst.
    pub burst_size: u32,
    /// Master switch; exposed via [`RagConfig::is_rate_limited`].
    pub enabled: bool,
}
361
362impl Default for RateLimitConfig {
363 fn default() -> Self {
364 Self {
365 requests_per_minute: 60,
366 burst_size: 10,
367 enabled: true,
368 }
369 }
370}
371
/// Retry behavior for failed provider requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of attempts before giving up.
    pub max_attempts: u32,
    /// Base delay between attempts, in seconds.
    pub backoff_secs: u64,
    /// When true, the delay grows exponentially between attempts.
    pub exponential_backoff: bool,
}
382
383impl Default for RetryConfig {
384 fn default() -> Self {
385 Self {
386 max_attempts: 3,
387 backoff_secs: 1,
388 exponential_backoff: true,
389 }
390 }
391}
392
/// Format of generated responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResponseFormat {
    /// Plain text.
    Text,
    /// JSON (the default, see the `Default` impl below).
    Json,
    /// Markdown.
    Markdown,
    /// Caller-defined format, identified by the contained string.
    Custom(String),
}
405
406impl Default for ResponseFormat {
407 fn default() -> Self {
408 Self::Json
409 }
410}
411
/// Logging settings for RAG operations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log level name (e.g. `"info"` per the default).
    pub log_level: String,
    /// Whether to log individual requests.
    pub log_requests: bool,
    /// Whether to log performance information.
    pub log_performance: bool,
    /// Optional log file path; `None` presumably means no file logging — TODO confirm consumer.
    pub log_file: Option<String>,
    /// Maximum log size in megabytes.
    pub max_log_size_mb: u64,
}
426
427impl Default for LoggingConfig {
428 fn default() -> Self {
429 Self {
430 log_level: "info".to_string(),
431 log_requests: false,
432 log_performance: true,
433 log_file: None,
434 max_log_size_mb: 100,
435 }
436 }
437}
438
/// Monitoring and metrics settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Whether metrics collection is enabled.
    pub enable_metrics: bool,
    /// Interval between metrics reports, in seconds.
    pub metrics_interval_secs: u64,
    /// Whether distributed tracing is enabled.
    pub enable_tracing: bool,
    /// Fraction of requests to trace — assumed 0.0..=1.0; not validated here.
    pub trace_sample_rate: f32,
    /// Alerting thresholds for performance metrics.
    pub thresholds: PerformanceThresholds,
}
453
454impl Default for MonitoringConfig {
455 fn default() -> Self {
456 Self {
457 enable_metrics: true,
458 metrics_interval_secs: 60,
459 enable_tracing: false,
460 trace_sample_rate: 0.1,
461 thresholds: PerformanceThresholds::default(),
462 }
463 }
464}
465
/// Thresholds used by monitoring to flag degraded performance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceThresholds {
    /// Maximum acceptable response time in seconds.
    pub max_response_time_secs: f64,
    /// Minimum acceptable similarity score for retrieval results.
    pub min_similarity_score: f32,
    /// Maximum acceptable memory usage in megabytes.
    pub max_memory_usage_mb: u64,
    /// Maximum acceptable CPU usage, in percent.
    pub max_cpu_usage_percent: f32,
}
478
479impl Default for PerformanceThresholds {
480 fn default() -> Self {
481 Self {
482 max_response_time_secs: 30.0,
483 min_similarity_score: 0.7,
484 max_memory_usage_mb: 1024,
485 max_cpu_usage_percent: 80.0,
486 }
487 }
488}
489
/// Fluent builder for [`RagConfig`].
///
/// Starts from [`RagConfig::default`] and validates on [`RagConfigBuilder::build`].
#[derive(Debug)]
pub struct RagConfigBuilder {
    // Work-in-progress configuration mutated by the builder methods.
    config: RagConfig,
}
495
496impl RagConfigBuilder {
497 pub fn new() -> Self {
499 Self {
500 config: RagConfig::default(),
501 }
502 }
503
504 pub fn build(self) -> Result<RagConfig> {
506 self.config.validate()?;
507 Ok(self.config)
508 }
509
510 pub fn provider(mut self, provider: LlmProvider) -> Self {
512 self.config.provider = provider;
513 self
514 }
515
516 pub fn model(mut self, model: String) -> Self {
518 self.config.model = model;
519 self
520 }
521
522 pub fn api_key(mut self, api_key: String) -> Self {
524 self.config.api_key = Some(api_key);
525 self
526 }
527
528 pub fn endpoint(mut self, endpoint: String) -> Self {
530 self.config.api_endpoint = endpoint;
531 self
532 }
533
534 pub fn model_params(mut self, max_tokens: usize, temperature: f32) -> Self {
536 self.config.max_tokens = max_tokens;
537 self.config.temperature = temperature;
538 self
539 }
540
541 pub fn embedding(mut self, model: String, dimensions: usize) -> Self {
543 self.config.embedding_model = model;
544 self.config.embedding_dimensions = dimensions;
545 self
546 }
547
548 pub fn chunking(mut self, size: usize, overlap: usize) -> Self {
550 self.config.chunk_size = size;
551 self.config.chunk_overlap = overlap;
552 self
553 }
554
555 pub fn retrieval(mut self, top_k: usize, threshold: f32) -> Self {
557 self.config.top_k = top_k;
558 self.config.similarity_threshold = threshold;
559 self
560 }
561
562 pub fn hybrid_search(mut self, semantic_weight: f32) -> Self {
564 self.config.hybrid_search = true;
565 self.config.semantic_weight = semantic_weight;
566 self.config.keyword_weight = 1.0 - semantic_weight;
567 self
568 }
569
570 pub fn caching(mut self, enabled: bool) -> Self {
572 self.config.caching = enabled;
573 self
574 }
575
576 pub fn rate_limit(mut self, requests_per_minute: u32) -> Self {
578 self.config.rate_limiting = RateLimitConfig {
579 requests_per_minute,
580 burst_size: requests_per_minute / 6, enabled: true,
582 };
583 self
584 }
585
586 pub fn debug(mut self, debug: bool) -> Self {
588 self.config.debug_mode = debug;
589 self
590 }
591}
592
593impl Default for RagConfigBuilder {
594 fn default() -> Self {
595 Self::new()
596 }
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    // The default configuration must always pass its own validation.
    #[test]
    fn default_config_is_valid() {
        assert!(RagConfig::default().validate().is_ok());
    }

    #[test]
    fn builder_sets_fields_and_validates() {
        let config = RagConfigBuilder::new()
            .provider(LlmProvider::Anthropic)
            .model("claude-3".to_string())
            .api_key("secret".to_string())
            .build()
            .expect("builder output should validate");
        assert_eq!(config.provider, LlmProvider::Anthropic);
        assert_eq!(config.model, "claude-3");
        assert_eq!(config.api_key.as_deref(), Some("secret"));
    }

    #[test]
    fn out_of_range_temperature_is_rejected() {
        let mut config = RagConfig::default();
        config.temperature = 3.0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn chunk_overlap_must_be_smaller_than_chunk_size() {
        assert!(RagConfigBuilder::new().chunking(100, 100).build().is_err());
    }

    #[test]
    fn hybrid_weights_must_sum_to_one() {
        let mut config = RagConfig::default();
        // keyword_weight stays at the default 0.3, so the sum is 1.2.
        config.semantic_weight = 0.9;
        assert!(config.validate().is_err());
    }
}