use crate::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Supported large-language-model backends.
///
/// Serialized in lowercase via `rename_all` (e.g. `OpenAI` -> "openai",
/// `OpenAICompatible` -> "openaicompatible").
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// The official OpenAI API.
    OpenAI,
    /// The Anthropic API.
    Anthropic,
    /// Any server exposing an OpenAI-compatible API surface.
    OpenAICompatible,
    /// An Ollama instance.
    Ollama,
}
/// Supported embedding backends.
///
/// Serialized in lowercase via `rename_all`, like [`LlmProvider`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// The official OpenAI embeddings API.
    OpenAI,
    /// Any server exposing an OpenAI-compatible embeddings API.
    OpenAICompatible,
    /// An Ollama instance.
    Ollama,
}
/// Top-level configuration for a retrieval-augmented-generation (RAG)
/// pipeline: LLM access, embeddings, document chunking, retrieval tuning,
/// caching, rate limiting, retries, logging and monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM backend used for generation.
    pub provider: LlmProvider,
    /// Base URL of the provider's API; must be non-empty.
    pub api_endpoint: String,
    /// Optional API key; `None` when the endpoint requires no auth
    /// (e.g. a local server).
    pub api_key: Option<String>,
    /// Model identifier sent to the provider; must be non-empty.
    pub model: String,
    /// Maximum tokens to generate per response.
    pub max_tokens: usize,
    /// Sampling temperature; validated to the range 0.0..=2.0.
    pub temperature: f32,
    /// Nucleus-sampling cutoff; validated to the range 0.0..=1.0.
    pub top_p: f32,
    /// Request timeout in seconds (see `timeout_duration`).
    pub timeout_secs: u64,
    /// Maximum number of request retries.
    pub max_retries: u32,
    /// Backend used to compute embeddings.
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model identifier.
    pub embedding_model: String,
    /// Dimensionality of the embedding vectors.
    pub embedding_dimensions: usize,
    /// Document chunk size (units depend on the chunker — presumably
    /// characters or tokens; confirm against the chunking code).
    pub chunk_size: usize,
    /// Overlap between consecutive chunks; validated to be < `chunk_size`.
    pub chunk_overlap: usize,
    /// Number of chunks to retrieve per query.
    pub top_k: usize,
    /// Minimum similarity for a retrieved chunk; validated to 0.0..=1.0.
    pub similarity_threshold: f32,
    /// Whether to combine semantic and keyword search.
    pub hybrid_search: bool,
    /// Weight of the semantic score in hybrid search; together with
    /// `keyword_weight` must sum to 1.0 when `hybrid_search` is on.
    pub semantic_weight: f32,
    /// Weight of the keyword score in hybrid search.
    pub keyword_weight: f32,
    /// Whether to expand queries before retrieval.
    pub query_expansion: bool,
    /// Whether to filter generated responses.
    pub response_filtering: bool,
    /// Whether response caching is enabled.
    pub caching: bool,
    /// Cache entry time-to-live in seconds (see `cache_ttl_duration`).
    pub cache_ttl_secs: u64,
    /// Outbound request rate limiting.
    pub rate_limiting: RateLimitConfig,
    /// Retry/backoff behaviour for failed requests.
    pub retry_config: RetryConfig,
    /// Extra HTTP headers added to every request.
    pub custom_headers: HashMap<String, String>,
    /// Enables debug behaviour.
    pub debug_mode: bool,
    /// Maximum context length assembled for the LLM (presumably tokens —
    /// confirm against the context assembly code).
    pub max_context_length: usize,
    /// Desired format of generated responses.
    pub response_format: ResponseFormat,
    /// Logging settings.
    pub logging: LoggingConfig,
    /// Monitoring/metrics settings.
    pub monitoring: MonitoringConfig,
}
impl Default for RagConfig {
fn default() -> Self {
Self {
provider: LlmProvider::OpenAI,
api_endpoint: "https://api.openai.com/v1".to_string(),
api_key: None,
model: "gpt-3.5-turbo".to_string(),
max_tokens: 1024,
temperature: 0.7,
top_p: 0.9,
timeout_secs: 30,
max_retries: 3,
embedding_provider: EmbeddingProvider::OpenAI,
embedding_model: "text-embedding-ada-002".to_string(),
embedding_dimensions: 1536,
chunk_size: 1000,
chunk_overlap: 200,
top_k: 5,
similarity_threshold: 0.7,
hybrid_search: true,
semantic_weight: 0.7,
keyword_weight: 0.3,
query_expansion: false,
response_filtering: true,
caching: true,
cache_ttl_secs: 3600,
rate_limiting: RateLimitConfig::default(),
retry_config: RetryConfig::default(),
custom_headers: HashMap::new(),
debug_mode: false,
max_context_length: 4096,
response_format: ResponseFormat::Json,
logging: LoggingConfig::default(),
monitoring: MonitoringConfig::default(),
}
}
}
impl RagConfig {
pub fn new(provider: LlmProvider, model: String) -> Self {
Self {
provider,
model,
..Default::default()
}
}
pub fn with_api_key(mut self, api_key: String) -> Self {
self.api_key = Some(api_key);
self
}
pub fn with_endpoint(mut self, endpoint: String) -> Self {
self.api_endpoint = endpoint;
self
}
pub fn with_model_params(mut self, max_tokens: usize, temperature: f32, top_p: f32) -> Self {
self.max_tokens = max_tokens;
self.temperature = temperature;
self.top_p = top_p;
self
}
pub fn with_embedding(
mut self,
provider: EmbeddingProvider,
model: String,
dimensions: usize,
) -> Self {
self.embedding_provider = provider;
self.embedding_model = model;
self.embedding_dimensions = dimensions;
self
}
pub fn with_chunking(mut self, chunk_size: usize, chunk_overlap: usize) -> Self {
self.chunk_size = chunk_size;
self.chunk_overlap = chunk_overlap;
self
}
pub fn with_retrieval(mut self, top_k: usize, similarity_threshold: f32) -> Self {
self.top_k = top_k;
self.similarity_threshold = similarity_threshold;
self
}
pub fn with_hybrid_search(mut self, semantic_weight: f32, keyword_weight: f32) -> Self {
self.hybrid_search = true;
self.semantic_weight = semantic_weight;
self.keyword_weight = keyword_weight;
self
}
pub fn with_caching(mut self, enabled: bool, ttl_secs: u64) -> Self {
self.caching = enabled;
self.cache_ttl_secs = ttl_secs;
self
}
pub fn with_rate_limit(mut self, requests_per_minute: u32, burst_size: u32) -> Self {
self.rate_limiting = RateLimitConfig {
requests_per_minute,
burst_size,
enabled: true,
};
self
}
pub fn with_retry(mut self, max_attempts: u32, backoff_secs: u64) -> Self {
self.retry_config = RetryConfig {
max_attempts,
backoff_secs,
exponential_backoff: true,
};
self
}
pub fn with_header(mut self, key: String, value: String) -> Self {
self.custom_headers.insert(key, value);
self
}
pub fn with_debug_mode(mut self, debug: bool) -> Self {
self.debug_mode = debug;
self
}
pub fn validate(&self) -> Result<()> {
if self.api_endpoint.is_empty() {
return Err(crate::Error::generic("API endpoint cannot be empty"));
}
if self.model.is_empty() {
return Err(crate::Error::generic("Model name cannot be empty"));
}
if !(0.0..=2.0).contains(&self.temperature) {
return Err(crate::Error::generic("Temperature must be between 0.0 and 2.0"));
}
if !(0.0..=1.0).contains(&self.top_p) {
return Err(crate::Error::generic("Top-p must be between 0.0 and 1.0"));
}
if self.chunk_size == 0 {
return Err(crate::Error::generic("Chunk size must be greater than 0"));
}
if self.chunk_overlap >= self.chunk_size {
return Err(crate::Error::generic("Chunk overlap must be less than chunk size"));
}
if !(0.0..=1.0).contains(&self.similarity_threshold) {
return Err(crate::Error::generic("Similarity threshold must be between 0.0 and 1.0"));
}
if self.hybrid_search {
let total_weight = self.semantic_weight + self.keyword_weight;
if (total_weight - 1.0).abs() > f32::EPSILON {
return Err(crate::Error::generic("Hybrid search weights must sum to 1.0"));
}
}
Ok(())
}
pub fn timeout_duration(&self) -> Duration {
Duration::from_secs(self.timeout_secs)
}
pub fn cache_ttl_duration(&self) -> Duration {
Duration::from_secs(self.cache_ttl_secs)
}
pub fn is_caching_enabled(&self) -> bool {
self.caching
}
pub fn is_rate_limited(&self) -> bool {
self.rate_limiting.enabled
}
pub fn requests_per_minute(&self) -> u32 {
self.rate_limiting.requests_per_minute
}
pub fn burst_size(&self) -> u32 {
self.rate_limiting.burst_size
}
pub fn max_retry_attempts(&self) -> u32 {
self.retry_config.max_attempts
}
pub fn backoff_duration(&self) -> Duration {
Duration::from_secs(self.retry_config.backoff_secs)
}
pub fn is_exponential_backoff(&self) -> bool {
self.retry_config.exponential_backoff
}
pub fn response_format(&self) -> &ResponseFormat {
&self.response_format
}
pub fn logging_config(&self) -> &LoggingConfig {
&self.logging
}
pub fn monitoring_config(&self) -> &MonitoringConfig {
&self.monitoring
}
}
/// Client-side rate-limiting settings for outbound API requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Sustained request budget per minute.
    pub requests_per_minute: u32,
    /// Number of requests allowed above the sustained rate in a burst.
    pub burst_size: u32,
    /// Whether rate limiting is applied at all.
    pub enabled: bool,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
requests_per_minute: 60,
burst_size: 10,
enabled: true,
}
}
}
/// Retry behaviour for failed API requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retry attempts.
    pub max_attempts: u32,
    /// Base backoff between attempts, in seconds.
    pub backoff_secs: u64,
    /// Whether the backoff grows exponentially between attempts.
    pub exponential_backoff: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 3,
backoff_secs: 1,
exponential_backoff: true,
}
}
}
/// Output format requested for generated responses.
///
/// NOTE(review): unlike the provider enums, this has no
/// `#[serde(rename_all = "lowercase")]`, so variants serialize in
/// UpperCamelCase — confirm whether that asymmetry is intentional
/// before changing it (changing it would break stored configs).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum ResponseFormat {
    /// Plain text.
    Text,
    /// JSON (the default).
    #[default]
    Json,
    /// Markdown.
    Markdown,
    /// A caller-defined format identifier.
    Custom(String),
}
/// Logging settings for the RAG pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoggingConfig {
    /// Log verbosity level as a string (e.g. "info").
    pub log_level: String,
    /// Whether to log individual requests.
    pub log_requests: bool,
    /// Whether to log performance information.
    pub log_performance: bool,
    /// Optional log file path; `None` means no file logging configured.
    pub log_file: Option<String>,
    /// Maximum log size in megabytes.
    pub max_log_size_mb: u64,
}
impl Default for LoggingConfig {
fn default() -> Self {
Self {
log_level: "info".to_string(),
log_requests: false,
log_performance: true,
log_file: None,
max_log_size_mb: 100,
}
}
}
/// Monitoring, metrics and tracing settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Whether metrics collection is enabled.
    pub enable_metrics: bool,
    /// Interval between metrics reports, in seconds.
    pub metrics_interval_secs: u64,
    /// Whether distributed tracing is enabled.
    pub enable_tracing: bool,
    /// Fraction of requests to trace (0.0..=1.0 presumably — not
    /// validated here; confirm at the point of use).
    pub trace_sample_rate: f32,
    /// Alerting thresholds for performance metrics.
    pub thresholds: PerformanceThresholds,
}
impl Default for MonitoringConfig {
fn default() -> Self {
Self {
enable_metrics: true,
metrics_interval_secs: 60,
enable_tracing: false,
trace_sample_rate: 0.1,
thresholds: PerformanceThresholds::default(),
}
}
}
/// Performance thresholds consumed by monitoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceThresholds {
    /// Maximum acceptable response time, in seconds.
    pub max_response_time_secs: f64,
    /// Minimum acceptable similarity score.
    pub min_similarity_score: f32,
    /// Maximum acceptable memory usage, in megabytes.
    pub max_memory_usage_mb: u64,
    /// Maximum acceptable CPU usage, in percent.
    pub max_cpu_usage_percent: f32,
}
impl Default for PerformanceThresholds {
fn default() -> Self {
Self {
max_response_time_secs: 30.0,
min_similarity_score: 0.7,
max_memory_usage_mb: 1024,
max_cpu_usage_percent: 80.0,
}
}
}
/// Fluent builder for [`RagConfig`] that validates on `build`.
#[derive(Debug)]
pub struct RagConfigBuilder {
    // Configuration under construction; seeded from RagConfig::default().
    config: RagConfig,
}
impl RagConfigBuilder {
    /// Creates a builder seeded with [`RagConfig::default`].
    pub fn new() -> Self {
        Self {
            config: RagConfig::default(),
        }
    }

    /// Validates the assembled configuration and returns it.
    ///
    /// # Errors
    /// Returns the first validation error reported by `RagConfig::validate`.
    pub fn build(self) -> Result<RagConfig> {
        self.config.validate()?;
        Ok(self.config)
    }

    /// Sets the LLM provider.
    pub fn provider(mut self, provider: LlmProvider) -> Self {
        self.config.provider = provider;
        self
    }

    /// Sets the generation model identifier.
    pub fn model(mut self, model: String) -> Self {
        self.config.model = model;
        self
    }

    /// Sets the API key.
    pub fn api_key(mut self, api_key: String) -> Self {
        self.config.api_key = Some(api_key);
        self
    }

    /// Sets the API endpoint URL.
    pub fn endpoint(mut self, endpoint: String) -> Self {
        self.config.api_endpoint = endpoint;
        self
    }

    /// Sets the token budget and sampling temperature.
    pub fn model_params(mut self, max_tokens: usize, temperature: f32) -> Self {
        self.config.max_tokens = max_tokens;
        self.config.temperature = temperature;
        self
    }

    /// Sets the embedding model and its vector dimensionality.
    pub fn embedding(mut self, model: String, dimensions: usize) -> Self {
        self.config.embedding_model = model;
        self.config.embedding_dimensions = dimensions;
        self
    }

    /// Sets chunk size and overlap (overlap must stay below size to
    /// pass validation at `build` time).
    pub fn chunking(mut self, size: usize, overlap: usize) -> Self {
        self.config.chunk_size = size;
        self.config.chunk_overlap = overlap;
        self
    }

    /// Sets retrieval depth and minimum similarity threshold.
    pub fn retrieval(mut self, top_k: usize, threshold: f32) -> Self {
        self.config.top_k = top_k;
        self.config.similarity_threshold = threshold;
        self
    }

    /// Enables hybrid search; the keyword weight is derived as the
    /// complement of `semantic_weight` so the two always sum to 1.0.
    pub fn hybrid_search(mut self, semantic_weight: f32) -> Self {
        // Clamp so an out-of-range input cannot yield a negative
        // keyword weight (e.g. semantic_weight = 1.5 -> keyword -0.5).
        let semantic_weight = semantic_weight.clamp(0.0, 1.0);
        self.config.hybrid_search = true;
        self.config.semantic_weight = semantic_weight;
        self.config.keyword_weight = 1.0 - semantic_weight;
        self
    }

    /// Enables or disables response caching.
    pub fn caching(mut self, enabled: bool) -> Self {
        self.config.caching = enabled;
        self
    }

    /// Enables rate limiting at the given per-minute budget, deriving a
    /// burst size of roughly one sixth of that budget.
    pub fn rate_limit(mut self, requests_per_minute: u32) -> Self {
        self.config.rate_limiting = RateLimitConfig {
            requests_per_minute,
            // Integer division gives 0 for any rate below 6, which would
            // forbid all bursts; keep the burst size at least 1.
            burst_size: (requests_per_minute / 6).max(1),
            enabled: true,
        };
        self
    }

    /// Enables or disables debug mode.
    pub fn debug(mut self, debug: bool) -> Self {
        self.config.debug_mode = debug;
        self
    }
}
impl Default for RagConfigBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must pass its own validation.
    #[test]
    fn default_config_is_valid() {
        assert!(RagConfig::default().validate().is_ok());
    }

    /// Builder settings must land in the built config.
    #[test]
    fn builder_sets_fields() {
        let cfg = RagConfigBuilder::new()
            .provider(LlmProvider::Ollama)
            .model("llama3".to_string())
            .build()
            .expect("default-based config should validate");
        assert_eq!(cfg.provider, LlmProvider::Ollama);
        assert_eq!(cfg.model, "llama3");
    }

    /// Out-of-range temperature is rejected.
    #[test]
    fn invalid_temperature_rejected() {
        let mut cfg = RagConfig::default();
        cfg.temperature = 5.0;
        assert!(cfg.validate().is_err());
    }

    /// Chunk overlap must stay strictly below chunk size.
    #[test]
    fn invalid_chunk_overlap_rejected() {
        let mut cfg = RagConfig::default();
        cfg.chunk_overlap = cfg.chunk_size;
        assert!(cfg.validate().is_err());
    }
}