use crate::rag::error::{ConfigError, RagError};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Config {
pub vector_db: VectorDbConfig,
pub embedding: EmbeddingConfig,
pub indexing: IndexingConfig,
pub search: SearchConfig,
pub cache: CacheConfig,
}
/// Vector database settings.
///
/// Every field carries a serde default, so any of them may be omitted from the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorDbConfig {
/// Backend name. `Config::validate` accepts only `"lancedb"` or `"qdrant"`.
/// The default is chosen by `default_db_backend` based on the `qdrant-backend` feature.
#[serde(default = "default_db_backend")]
pub backend: String,
/// Defaults to `PlatformPaths::default_lancedb_path()`.
/// NOTE(review): presumably the on-disk location used by the LanceDB backend — confirm with the consumer.
#[serde(default = "default_lancedb_path")]
pub lancedb_path: PathBuf,
/// Defaults to `"http://localhost:6334"`.
/// NOTE(review): presumably the endpoint used by the Qdrant backend — confirm with the consumer.
#[serde(default = "default_qdrant_url")]
pub qdrant_url: String,
/// Collection name; defaults to `"code_embeddings"`.
#[serde(default = "default_collection_name")]
pub collection_name: String,
}
/// Embedding settings.
///
/// Every field carries a serde default, so any of them may be omitted from the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// Embedding model name; defaults to `"all-MiniLM-L6-v2"`.
#[serde(default = "default_model_name")]
pub model_name: String,
/// Batch size; defaults to 8. `Config::validate` rejects 0.
#[serde(default = "default_batch_size")]
pub batch_size: usize,
/// Timeout in seconds (per the field name); defaults to 10.
/// NOTE(review): which operation this bounds is not visible in this file — confirm.
#[serde(default = "default_embedding_timeout")]
pub timeout_secs: u64,
/// Defaults to 4. Not checked by `Config::validate`.
/// NOTE(review): presumably how often cancellation is polled during embedding — confirm units with the consumer.
#[serde(default = "default_cancellation_check_interval")]
pub cancellation_check_interval: usize,
}
/// Indexing settings.
///
/// Every field carries a serde default, so any of them may be omitted from the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexingConfig {
/// Chunk size; defaults to 50. `Config::validate` rejects 0.
/// NOTE(review): the unit (lines, tokens, ...) is not visible in this file — confirm.
#[serde(default = "default_chunk_size")]
pub chunk_size: usize,
/// Maximum file size; defaults to 1_048_576. `Config::validate` rejects 0.
/// NOTE(review): presumably bytes (1 MiB) — confirm with the consumer.
#[serde(default = "default_max_file_size")]
pub max_file_size: usize,
/// Include patterns; an omitted field deserializes to an empty list.
#[serde(default)]
pub include_patterns: Vec<String>,
/// Exclude patterns; defaults to `target`, `node_modules`, `.git`, `dist`, `build`.
#[serde(default = "default_exclude_patterns")]
pub exclude_patterns: Vec<String>,
}
/// Search settings.
///
/// Every field carries a serde default, so any of them may be omitted from the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
/// Minimum score; defaults to 0.7. `Config::validate` requires a value in `0.0..=1.0`.
#[serde(default = "default_min_score")]
pub min_score: f32,
/// Result limit; defaults to 10. `Config::validate` rejects 0.
#[serde(default = "default_result_limit")]
pub limit: usize,
/// Hybrid search flag; defaults to `true`.
/// NOTE(review): what "hybrid" combines is not visible in this file — confirm.
#[serde(default = "default_hybrid_search")]
pub hybrid: bool,
}
/// Cache file locations.
///
/// Both fields carry a serde default, so either may be omitted from the
/// serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Defaults to `PlatformPaths::default_hash_cache_path()`.
#[serde(default = "default_hash_cache_path")]
pub hash_cache_path: PathBuf,
/// Defaults to `PlatformPaths::default_git_cache_path()`.
#[serde(default = "default_git_cache_path")]
pub git_cache_path: PathBuf,
}
/// Default vector database backend name, selected at compile time.
///
/// Yields `"qdrant"` when the crate is built with the `qdrant-backend`
/// feature and `"lancedb"` otherwise.
fn default_db_backend() -> String {
    let backend = if cfg!(feature = "qdrant-backend") {
        "qdrant"
    } else {
        "lancedb"
    };
    String::from(backend)
}
/// Default for `VectorDbConfig::lancedb_path`.
///
/// Delegates to `PlatformPaths::default_lancedb_path()` from the
/// `brainwires_storage` crate; the concrete location is decided there.
fn default_lancedb_path() -> PathBuf {
brainwires_storage::paths::PlatformPaths::default_lancedb_path()
}
/// Default for `VectorDbConfig::qdrant_url`: `http://localhost:6334`.
fn default_qdrant_url() -> String {
    String::from("http://localhost:6334")
}
/// Default for `VectorDbConfig::collection_name`: `code_embeddings`.
fn default_collection_name() -> String {
    String::from("code_embeddings")
}
/// Default for `EmbeddingConfig::model_name`: `all-MiniLM-L6-v2`.
fn default_model_name() -> String {
    String::from("all-MiniLM-L6-v2")
}
/// Default for `EmbeddingConfig::batch_size`.
fn default_batch_size() -> usize {
    const BATCH_SIZE: usize = 8;
    BATCH_SIZE
}
/// Default for `EmbeddingConfig::timeout_secs`.
fn default_embedding_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 10;
    TIMEOUT_SECS
}
/// Default for `EmbeddingConfig::cancellation_check_interval`.
fn default_cancellation_check_interval() -> usize {
    const CHECK_INTERVAL: usize = 4;
    CHECK_INTERVAL
}
/// Default for `IndexingConfig::chunk_size`.
fn default_chunk_size() -> usize {
    const CHUNK_SIZE: usize = 50;
    CHUNK_SIZE
}
/// Default for `IndexingConfig::max_file_size`: 1024 * 1024 (1_048_576).
fn default_max_file_size() -> usize {
    const MAX_FILE_SIZE: usize = 1024 * 1024;
    MAX_FILE_SIZE
}
/// Default for `IndexingConfig::exclude_patterns`.
///
/// Returns `target`, `node_modules`, `.git`, `dist` and `build`, in that order.
fn default_exclude_patterns() -> Vec<String> {
    const EXCLUDED: [&str; 5] = ["target", "node_modules", ".git", "dist", "build"];
    EXCLUDED.iter().map(|name| name.to_string()).collect()
}
/// Default for `SearchConfig::min_score`.
fn default_min_score() -> f32 {
    const MIN_SCORE: f32 = 0.7;
    MIN_SCORE
}
/// Default for `SearchConfig::limit`.
fn default_result_limit() -> usize {
    const RESULT_LIMIT: usize = 10;
    RESULT_LIMIT
}
/// Default for `SearchConfig::hybrid`: enabled.
fn default_hybrid_search() -> bool {
    const HYBRID_ENABLED: bool = true;
    HYBRID_ENABLED
}
/// Default for `CacheConfig::hash_cache_path`.
///
/// Delegates to `PlatformPaths::default_hash_cache_path()` from the
/// `brainwires_storage` crate; the concrete location is decided there.
fn default_hash_cache_path() -> PathBuf {
brainwires_storage::paths::PlatformPaths::default_hash_cache_path()
}
/// Default for `CacheConfig::git_cache_path`.
///
/// Delegates to `PlatformPaths::default_git_cache_path()` from the
/// `brainwires_storage` crate; the concrete location is decided there.
fn default_git_cache_path() -> PathBuf {
brainwires_storage::paths::PlatformPaths::default_git_cache_path()
}
impl Default for VectorDbConfig {
fn default() -> Self {
Self {
backend: default_db_backend(),
lancedb_path: default_lancedb_path(),
qdrant_url: default_qdrant_url(),
collection_name: default_collection_name(),
}
}
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model_name: default_model_name(),
batch_size: default_batch_size(),
timeout_secs: default_embedding_timeout(),
cancellation_check_interval: default_cancellation_check_interval(),
}
}
}
impl Default for IndexingConfig {
fn default() -> Self {
Self {
chunk_size: default_chunk_size(),
max_file_size: default_max_file_size(),
include_patterns: Vec::new(),
exclude_patterns: default_exclude_patterns(),
}
}
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
min_score: default_min_score(),
limit: default_result_limit(),
hybrid: default_hybrid_search(),
}
}
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
hash_cache_path: default_hash_cache_path(),
git_cache_path: default_git_cache_path(),
}
}
}
impl Config {
pub fn from_file(path: &Path) -> Result<Self, RagError> {
if !path.exists() {
return Err(ConfigError::FileNotFound(path.display().to_string()).into());
}
let content = std::fs::read_to_string(path)
.map_err(|e| ConfigError::LoadFailed(format!("Failed to read config file: {}", e)))?;
let config: Config = toml::from_str(&content)
.map_err(|e| ConfigError::ParseFailed(format!("Invalid TOML: {}", e)))?;
config.validate()?;
Ok(config)
}
pub fn load_or_default() -> Result<Self, RagError> {
let config_path = brainwires_storage::paths::PlatformPaths::default_config_path();
if config_path.exists() {
tracing::info!("Loading config from: {}", config_path.display());
Self::from_file(&config_path)
} else {
tracing::info!("No config file found, using defaults");
Ok(Self::default())
}
}
pub fn save(&self, path: &Path) -> Result<(), RagError> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
ConfigError::SaveFailed(format!("Failed to create config directory: {}", e))
})?;
}
let content = toml::to_string_pretty(self)
.map_err(|e| ConfigError::SaveFailed(format!("Failed to serialize config: {}", e)))?;
std::fs::write(path, content)
.map_err(|e| ConfigError::SaveFailed(format!("Failed to write config file: {}", e)))?;
tracing::info!("Saved config to: {}", path.display());
Ok(())
}
pub fn save_default(&self) -> Result<(), RagError> {
let config_path = brainwires_storage::paths::PlatformPaths::default_config_path();
self.save(&config_path)
}
pub fn validate(&self) -> Result<(), RagError> {
if self.vector_db.backend != "lancedb" && self.vector_db.backend != "qdrant" {
return Err(ConfigError::InvalidValue {
key: "vector_db.backend".to_string(),
reason: format!(
"must be 'lancedb' or 'qdrant', got '{}'",
self.vector_db.backend
),
}
.into());
}
if self.embedding.batch_size == 0 {
return Err(ConfigError::InvalidValue {
key: "embedding.batch_size".to_string(),
reason: "must be greater than 0".to_string(),
}
.into());
}
if self.indexing.chunk_size == 0 {
return Err(ConfigError::InvalidValue {
key: "indexing.chunk_size".to_string(),
reason: "must be greater than 0".to_string(),
}
.into());
}
if self.indexing.max_file_size == 0 {
return Err(ConfigError::InvalidValue {
key: "indexing.max_file_size".to_string(),
reason: "must be greater than 0".to_string(),
}
.into());
}
if !(0.0..=1.0).contains(&self.search.min_score) {
return Err(ConfigError::InvalidValue {
key: "search.min_score".to_string(),
reason: format!("must be between 0.0 and 1.0, got {}", self.search.min_score),
}
.into());
}
if self.search.limit == 0 {
return Err(ConfigError::InvalidValue {
key: "search.limit".to_string(),
reason: "must be greater than 0".to_string(),
}
.into());
}
Ok(())
}
pub fn apply_env_overrides(&mut self) {
if let Ok(backend) = std::env::var("PROJECT_RAG_DB_BACKEND") {
self.vector_db.backend = backend;
}
if let Ok(path) = std::env::var("PROJECT_RAG_LANCEDB_PATH") {
self.vector_db.lancedb_path = PathBuf::from(path);
}
if let Ok(url) = std::env::var("PROJECT_RAG_QDRANT_URL") {
self.vector_db.qdrant_url = url;
}
if let Ok(model) = std::env::var("PROJECT_RAG_MODEL") {
self.embedding.model_name = model;
}
if let Ok(batch_size) = std::env::var("PROJECT_RAG_BATCH_SIZE")
&& let Ok(size) = batch_size.parse()
{
self.embedding.batch_size = size;
}
if let Ok(min_score) = std::env::var("PROJECT_RAG_MIN_SCORE")
&& let Ok(score) = min_score.parse()
{
self.search.min_score = score;
}
}
pub fn new() -> Result<Self, RagError> {
let mut config = Self::load_or_default()?;
config.apply_env_overrides();
config.validate()?;
Ok(config)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults must satisfy every rule enforced by
    /// `Config::validate`, not only the handful of fields spot-checked below.
    #[test]
    fn default_config_is_valid() {
        let config = Config::default();
        assert!(
            config.validate().is_ok(),
            "Config::default() must pass Config::validate()"
        );
        assert!(config.embedding.batch_size > 0);
        assert!(config.indexing.chunk_size > 0);
        assert!(config.indexing.max_file_size > 0);
        assert!(config.search.limit > 0);
        assert!((0.0..=1.0).contains(&config.search.min_score));
    }
}