project_rag/config/
mod.rs

//! Configuration system for project-rag.
//!
//! Supports loading from multiple sources, in descending order of priority:
//! CLI args > Environment variables > Config file > Defaults
5use crate::error::{ConfigError, RagError};
6use serde::{Deserialize, Serialize};
7use std::path::{Path, PathBuf};
8
9/// Main configuration structure
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct Config {
12    /// Vector database configuration
13    pub vector_db: VectorDbConfig,
14
15    /// Embedding model configuration
16    pub embedding: EmbeddingConfig,
17
18    /// Indexing configuration
19    pub indexing: IndexingConfig,
20
21    /// Search configuration
22    pub search: SearchConfig,
23
24    /// Cache configuration
25    pub cache: CacheConfig,
26}
27
28/// Vector database configuration
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct VectorDbConfig {
31    /// Database backend: "lancedb" or "qdrant"
32    #[serde(default = "default_db_backend")]
33    pub backend: String,
34
35    /// LanceDB data directory path
36    #[serde(default = "default_lancedb_path")]
37    pub lancedb_path: PathBuf,
38
39    /// Qdrant server URL
40    #[serde(default = "default_qdrant_url")]
41    pub qdrant_url: String,
42
43    /// Collection name for vector storage
44    #[serde(default = "default_collection_name")]
45    pub collection_name: String,
46}
47
48/// Embedding model configuration
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct EmbeddingConfig {
51    /// Model name (e.g., "all-MiniLM-L6-v2", "BAAI/bge-small-en-v1.5")
52    #[serde(default = "default_model_name")]
53    pub model_name: String,
54
55    /// Batch size for embedding generation
56    /// Smaller values allow faster cancellation response but may be less efficient
57    #[serde(default = "default_batch_size")]
58    pub batch_size: usize,
59
60    /// Timeout in seconds for embedding generation per batch
61    /// This is per-batch, not total - smaller batches mean faster timeout response
62    #[serde(default = "default_embedding_timeout")]
63    pub timeout_secs: u64,
64
65    /// Maximum number of chunks to process before checking for cancellation
66    /// This provides more granular control over cancellation responsiveness
67    /// Set to 0 to use batch_size (check once per batch)
68    #[serde(default = "default_cancellation_check_interval")]
69    pub cancellation_check_interval: usize,
70}
71
72/// Indexing configuration
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct IndexingConfig {
75    /// Default chunk size for FixedLines strategy
76    #[serde(default = "default_chunk_size")]
77    pub chunk_size: usize,
78
79    /// Maximum file size to index (in bytes)
80    #[serde(default = "default_max_file_size")]
81    pub max_file_size: usize,
82
83    /// Default include patterns
84    #[serde(default)]
85    pub include_patterns: Vec<String>,
86
87    /// Default exclude patterns
88    #[serde(default = "default_exclude_patterns")]
89    pub exclude_patterns: Vec<String>,
90}
91
92/// Search configuration
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct SearchConfig {
95    /// Default minimum similarity score (0.0 to 1.0)
96    #[serde(default = "default_min_score")]
97    pub min_score: f32,
98
99    /// Default result limit
100    #[serde(default = "default_result_limit")]
101    pub limit: usize,
102
103    /// Enable hybrid search (vector + BM25) by default
104    #[serde(default = "default_hybrid_search")]
105    pub hybrid: bool,
106}
107
108/// Cache configuration
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct CacheConfig {
111    /// Hash cache file path
112    #[serde(default = "default_hash_cache_path")]
113    pub hash_cache_path: PathBuf,
114
115    /// Git cache file path
116    #[serde(default = "default_git_cache_path")]
117    pub git_cache_path: PathBuf,
118}
119
120// Default value functions
/// Chooses the default vector database backend at compile time.
///
/// Yields `"qdrant"` when the crate is built with the `qdrant-backend`
/// feature and `"lancedb"` otherwise.
fn default_db_backend() -> String {
    let backend = if cfg!(feature = "qdrant-backend") {
        "qdrant"
    } else {
        "lancedb"
    };
    backend.to_string()
}
127
128fn default_lancedb_path() -> PathBuf {
129    crate::paths::PlatformPaths::default_lancedb_path()
130}
131
/// Default for `VectorDbConfig::qdrant_url`: a Qdrant server on localhost.
fn default_qdrant_url() -> String {
    String::from("http://localhost:6334")
}
135
/// Default for `VectorDbConfig::collection_name`.
fn default_collection_name() -> String {
    String::from("code_embeddings")
}
139
/// Default for `EmbeddingConfig::model_name`.
fn default_model_name() -> String {
    String::from("all-MiniLM-L6-v2")
}
143
/// Default for `EmbeddingConfig::batch_size`.
///
/// The value was lowered from 32 to 8 so cancellation is noticed sooner: a
/// batch takes roughly 1-3 seconds, so a cancel request can be honoured
/// within about 3 seconds.
fn default_batch_size() -> usize {
    const BATCH_SIZE: usize = 8;
    BATCH_SIZE
}
149
/// Default for `EmbeddingConfig::timeout_secs`.
///
/// The value was lowered from 30 to 10 seconds so a stalled batch is
/// detected faster.
fn default_embedding_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 10;
    TIMEOUT_SECS
}
154
/// Default for `EmbeddingConfig::cancellation_check_interval`.
///
/// Cancellation is checked every 4 chunks (about every 0.5-1.5 seconds).
/// A configured value of 0 means "use `batch_size` instead".
fn default_cancellation_check_interval() -> usize {
    const CHUNKS_BETWEEN_CHECKS: usize = 4;
    CHUNKS_BETWEEN_CHECKS
}
160
/// Default for `IndexingConfig::chunk_size`.
fn default_chunk_size() -> usize {
    const CHUNK_SIZE: usize = 50;
    CHUNK_SIZE
}
164
/// Default for `IndexingConfig::max_file_size`: 1 MB, expressed in bytes.
fn default_max_file_size() -> usize {
    const BYTES_PER_MB: usize = 1024 * 1024;
    BYTES_PER_MB
}
168
/// Default for `IndexingConfig::exclude_patterns`.
fn default_exclude_patterns() -> Vec<String> {
    ["target", "node_modules", ".git", "dist", "build"]
        .iter()
        .map(|pattern| pattern.to_string())
        .collect()
}
178
/// Default for `SearchConfig::min_score`.
fn default_min_score() -> f32 {
    const MIN_SCORE: f32 = 0.7;
    MIN_SCORE
}
182
/// Default for `SearchConfig::limit`.
fn default_result_limit() -> usize {
    const RESULT_LIMIT: usize = 10;
    RESULT_LIMIT
}
186
/// Default for `SearchConfig::hybrid`: hybrid search is enabled.
fn default_hybrid_search() -> bool {
    const HYBRID_ENABLED: bool = true;
    HYBRID_ENABLED
}
190
191fn default_hash_cache_path() -> PathBuf {
192    crate::paths::PlatformPaths::default_hash_cache_path()
193}
194
195fn default_git_cache_path() -> PathBuf {
196    crate::paths::PlatformPaths::default_git_cache_path()
197}
198
199impl Default for VectorDbConfig {
200    fn default() -> Self {
201        Self {
202            backend: default_db_backend(),
203            lancedb_path: default_lancedb_path(),
204            qdrant_url: default_qdrant_url(),
205            collection_name: default_collection_name(),
206        }
207    }
208}
209
210impl Default for EmbeddingConfig {
211    fn default() -> Self {
212        Self {
213            model_name: default_model_name(),
214            batch_size: default_batch_size(),
215            timeout_secs: default_embedding_timeout(),
216            cancellation_check_interval: default_cancellation_check_interval(),
217        }
218    }
219}
220
221impl Default for IndexingConfig {
222    fn default() -> Self {
223        Self {
224            chunk_size: default_chunk_size(),
225            max_file_size: default_max_file_size(),
226            include_patterns: Vec::new(),
227            exclude_patterns: default_exclude_patterns(),
228        }
229    }
230}
231
232impl Default for SearchConfig {
233    fn default() -> Self {
234        Self {
235            min_score: default_min_score(),
236            limit: default_result_limit(),
237            hybrid: default_hybrid_search(),
238        }
239    }
240}
241
242impl Default for CacheConfig {
243    fn default() -> Self {
244        Self {
245            hash_cache_path: default_hash_cache_path(),
246            git_cache_path: default_git_cache_path(),
247        }
248    }
249}
250
251impl Config {
252    /// Load configuration from file
253    pub fn from_file(path: &Path) -> Result<Self, RagError> {
254        if !path.exists() {
255            return Err(ConfigError::FileNotFound(path.display().to_string()).into());
256        }
257
258        let content = std::fs::read_to_string(path)
259            .map_err(|e| ConfigError::LoadFailed(format!("Failed to read config file: {}", e)))?;
260
261        let config: Config = toml::from_str(&content)
262            .map_err(|e| ConfigError::ParseFailed(format!("Invalid TOML: {}", e)))?;
263
264        config.validate()?;
265        Ok(config)
266    }
267
268    /// Load configuration from default location or create default
269    pub fn load_or_default() -> Result<Self, RagError> {
270        let config_path = crate::paths::PlatformPaths::default_config_path();
271
272        if config_path.exists() {
273            tracing::info!("Loading config from: {}", config_path.display());
274            Self::from_file(&config_path)
275        } else {
276            tracing::info!("No config file found, using defaults");
277            Ok(Self::default())
278        }
279    }
280
281    /// Save configuration to file
282    pub fn save(&self, path: &Path) -> Result<(), RagError> {
283        // Create parent directory if needed
284        if let Some(parent) = path.parent() {
285            std::fs::create_dir_all(parent).map_err(|e| {
286                ConfigError::SaveFailed(format!("Failed to create config directory: {}", e))
287            })?;
288        }
289
290        let content = toml::to_string_pretty(self)
291            .map_err(|e| ConfigError::SaveFailed(format!("Failed to serialize config: {}", e)))?;
292
293        std::fs::write(path, content)
294            .map_err(|e| ConfigError::SaveFailed(format!("Failed to write config file: {}", e)))?;
295
296        tracing::info!("Saved config to: {}", path.display());
297        Ok(())
298    }
299
300    /// Save to default location
301    pub fn save_default(&self) -> Result<(), RagError> {
302        let config_path = crate::paths::PlatformPaths::default_config_path();
303        self.save(&config_path)
304    }
305
306    /// Validate configuration values
307    pub fn validate(&self) -> Result<(), RagError> {
308        // Validate vector DB backend
309        if self.vector_db.backend != "lancedb" && self.vector_db.backend != "qdrant" {
310            return Err(ConfigError::InvalidValue {
311                key: "vector_db.backend".to_string(),
312                reason: format!(
313                    "must be 'lancedb' or 'qdrant', got '{}'",
314                    self.vector_db.backend
315                ),
316            }
317            .into());
318        }
319
320        // Validate batch size
321        if self.embedding.batch_size == 0 {
322            return Err(ConfigError::InvalidValue {
323                key: "embedding.batch_size".to_string(),
324                reason: "must be greater than 0".to_string(),
325            }
326            .into());
327        }
328
329        // Validate chunk size
330        if self.indexing.chunk_size == 0 {
331            return Err(ConfigError::InvalidValue {
332                key: "indexing.chunk_size".to_string(),
333                reason: "must be greater than 0".to_string(),
334            }
335            .into());
336        }
337
338        // Validate max file size
339        if self.indexing.max_file_size == 0 {
340            return Err(ConfigError::InvalidValue {
341                key: "indexing.max_file_size".to_string(),
342                reason: "must be greater than 0".to_string(),
343            }
344            .into());
345        }
346
347        // Validate min_score range
348        if !(0.0..=1.0).contains(&self.search.min_score) {
349            return Err(ConfigError::InvalidValue {
350                key: "search.min_score".to_string(),
351                reason: format!("must be between 0.0 and 1.0, got {}", self.search.min_score),
352            }
353            .into());
354        }
355
356        // Validate limit
357        if self.search.limit == 0 {
358            return Err(ConfigError::InvalidValue {
359                key: "search.limit".to_string(),
360                reason: "must be greater than 0".to_string(),
361            }
362            .into());
363        }
364
365        Ok(())
366    }
367
368    /// Apply environment variable overrides
369    pub fn apply_env_overrides(&mut self) {
370        // Vector DB backend
371        if let Ok(backend) = std::env::var("PROJECT_RAG_DB_BACKEND") {
372            self.vector_db.backend = backend;
373        }
374
375        // LanceDB path
376        if let Ok(path) = std::env::var("PROJECT_RAG_LANCEDB_PATH") {
377            self.vector_db.lancedb_path = PathBuf::from(path);
378        }
379
380        // Qdrant URL
381        if let Ok(url) = std::env::var("PROJECT_RAG_QDRANT_URL") {
382            self.vector_db.qdrant_url = url;
383        }
384
385        // Embedding model
386        if let Ok(model) = std::env::var("PROJECT_RAG_MODEL") {
387            self.embedding.model_name = model;
388        }
389
390        // Batch size
391        if let Ok(batch_size) = std::env::var("PROJECT_RAG_BATCH_SIZE")
392            && let Ok(size) = batch_size.parse()
393        {
394            self.embedding.batch_size = size;
395        }
396
397        // Min score
398        if let Ok(min_score) = std::env::var("PROJECT_RAG_MIN_SCORE")
399            && let Ok(score) = min_score.parse()
400        {
401            self.search.min_score = score;
402        }
403    }
404
405    /// Create a new Config with defaults and environment overrides
406    pub fn new() -> Result<Self, RagError> {
407        let mut config = Self::load_or_default()?;
408        config.apply_env_overrides();
409        config.validate()?;
410        Ok(config)
411    }
412}
413
414// Tests are inline in this module
#[cfg(test)]
mod tests {
    /// Stand-in that keeps the test module in place until real coverage exists.
    #[test]
    fn test_config_placeholder() {
        // Intentionally empty: nothing is asserted yet.
        // TODO: Add comprehensive config tests
    }
}