use crate::cache::CacheConfig;
use crate::config::FeatureFlags;
use crate::search::fusion::FusionWeights;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use thiserror::Error;
use tracing::{debug, info, warn};
/// Errors produced while locating, parsing, or validating the search configuration.
#[derive(Error, Debug)]
pub enum SearchConfigError {
/// The requested configuration file does not exist; the payload is the displayed path.
#[error("Configuration file not found: {0}")]
FileNotFound(String),
/// The file contents could not be deserialized as YAML; the payload is the parser's message.
#[error("Invalid YAML syntax: {0}")]
InvalidYaml(String),
/// A configuration value failed one of the `validate` checks; the payload names the rule.
#[error("Configuration validation failed: {0}")]
ValidationError(String),
/// An environment variable override could not be interpreted.
/// NOTE(review): this variant is not constructed anywhere in the visible code; confirm
/// whether other modules rely on it.
#[error("Environment variable parsing error: {0}")]
EnvVarError(String),
/// Wraps an I/O failure; `#[from]` provides the `From<std::io::Error>` conversion.
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
}
/// Root of the search configuration tree, deserialized from YAML.
///
/// Fields without `#[serde(default)]` must be present in the input document; fields that
/// carry the attribute fall back to their type's `Default` when the key is missing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SearchConfig {
/// Embedding provider/model settings (required key).
pub embedding: EmbeddingConfig,
/// Result-fusion method and weights (required key).
pub fusion: FusionConfig,
/// Candidate limits, timeout and parallelism (required key).
pub performance: PerformanceConfig,
/// IVFFlat index parameters (required key).
pub index: IndexConfig,
/// Feature toggles (required key); the type is declared in `crate::config`.
pub feature_flags: FeatureFlags,
/// Cache settings; optional, the type is declared in `crate::cache`.
#[serde(default)]
pub cache: CacheConfig,
/// Indexing worker and batch settings; optional.
#[serde(default)]
pub indexing: IndexingConfig,
/// Database pool and timeout settings; optional.
#[serde(default)]
pub database: DatabaseConfig,
/// Thread settings for the runtime; optional.
#[serde(default)]
pub runtime: RuntimeConfig,
/// Buffer sizes; optional.
#[serde(default)]
pub buffers: BufferConfig,
/// Graph-importance quality scoring settings; optional.
#[serde(default)]
pub graph_importance: GraphImportanceConfig,
}
impl SearchConfig {
/// Loads the configuration from the first default location that exists.
///
/// Candidate paths are probed in this order:
/// `config/maproom-search.yml`, `../config/maproom-search.yml`,
/// `/etc/maproom/maproom-search.yml`.
///
/// When none of them exists the built-in defaults are used. Environment overrides are
/// still applied and the result is still validated in that case, so `MAPROOM_SEARCH_*`
/// variables behave the same whether or not a configuration file is present.
///
/// # Errors
///
/// Returns any error produced by [`Self::load_from_file`], or — on the fallback path —
/// an error if an override cannot be parsed or the resulting configuration is invalid.
pub async fn load_default() -> Result<Self> {
    let candidates = [
        PathBuf::from("config/maproom-search.yml"),
        PathBuf::from("../config/maproom-search.yml"),
        PathBuf::from("/etc/maproom/maproom-search.yml"),
    ];
    for path in candidates.iter() {
        if path.exists() {
            info!("Loading configuration from: {}", path.display());
            return Self::load_from_file(path).await;
        }
    }

    warn!("No configuration file found, using defaults");
    // Previously the raw defaults were returned here, silently ignoring every
    // MAPROOM_SEARCH_* override and skipping validation.
    let mut config = Self::default();
    config.apply_env_overrides()?;
    config.validate()?;
    Ok(config)
}
/// Reads, parses, overrides and validates the configuration stored at `path`.
///
/// The file is read asynchronously and deserialized as YAML; afterwards the
/// `MAPROOM_SEARCH_*` environment overrides are applied and the result is validated.
///
/// # Errors
///
/// * [`SearchConfigError::FileNotFound`] when the read reports that the file is missing.
/// * A contextualised I/O error for any other read failure.
/// * [`SearchConfigError::InvalidYaml`] when the contents cannot be deserialized.
/// * Any error from applying overrides or from validation.
pub async fn load_from_file(path: &Path) -> Result<Self> {
    // Let the read itself report a missing file: this avoids a blocking `exists()` call on
    // the async executor and the window between checking and reading.
    let contents = match tokio::fs::read_to_string(path).await {
        Ok(text) => text,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return Err(SearchConfigError::FileNotFound(path.display().to_string()).into());
        }
        Err(err) => return Err(err).context("Failed to read configuration file"),
    };

    let mut config: SearchConfig = serde_yaml::from_str(&contents)
        .map_err(|e| SearchConfigError::InvalidYaml(e.to_string()))?;
    config.apply_env_overrides()?;
    config.validate()?;

    info!("Configuration loaded successfully from: {}", path.display());
    debug!("Active configuration: {:#?}", config);
    Ok(config)
}
fn apply_env_overrides(&mut self) -> Result<()> {
if let Ok(provider) = std::env::var("MAPROOM_SEARCH_EMBEDDING_PROVIDER") {
self.embedding.provider = provider;
debug!("Override: embedding.provider = {}", self.embedding.provider);
}
if let Ok(model) = std::env::var("MAPROOM_SEARCH_EMBEDDING_MODEL_NAME") {
self.embedding.model_name = model;
debug!(
"Override: embedding.model_name = {}",
self.embedding.model_name
);
}
if let Ok(dim) = std::env::var("MAPROOM_SEARCH_EMBEDDING_DIMENSION") {
self.embedding.dimension = dim
.parse()
.context("Failed to parse MAPROOM_SEARCH_EMBEDDING_DIMENSION")?;
debug!(
"Override: embedding.dimension = {}",
self.embedding.dimension
);
}
if let Ok(size) = std::env::var("MAPROOM_SEARCH_EMBEDDING_CACHE_SIZE") {
self.embedding.cache_size = size
.parse()
.context("Failed to parse MAPROOM_SEARCH_EMBEDDING_CACHE_SIZE")?;
debug!(
"Override: embedding.cache_size = {}",
self.embedding.cache_size
);
}
if let Ok(ttl) = std::env::var("MAPROOM_SEARCH_EMBEDDING_CACHE_TTL_SECONDS") {
self.embedding.cache_ttl_seconds = ttl
.parse()
.context("Failed to parse MAPROOM_SEARCH_EMBEDDING_CACHE_TTL_SECONDS")?;
debug!(
"Override: embedding.cache_ttl_seconds = {}",
self.embedding.cache_ttl_seconds
);
}
if let Ok(method) = std::env::var("MAPROOM_SEARCH_FUSION_METHOD") {
self.fusion.method = FusionMethod::from_str(&method)?;
debug!("Override: fusion.method = {:?}", self.fusion.method);
}
if let Ok(k) = std::env::var("MAPROOM_SEARCH_FUSION_RRF_K") {
self.fusion.rrf_k = k
.parse()
.context("Failed to parse MAPROOM_SEARCH_FUSION_RRF_K")?;
debug!("Override: fusion.rrf_k = {}", self.fusion.rrf_k);
}
if let Ok(fts) = std::env::var("MAPROOM_SEARCH_FUSION_WEIGHTS_FTS") {
self.fusion.weights.fts = fts
.parse()
.context("Failed to parse MAPROOM_SEARCH_FUSION_WEIGHTS_FTS")?;
debug!("Override: fusion.weights.fts = {}", self.fusion.weights.fts);
}
if let Ok(vector) = std::env::var("MAPROOM_SEARCH_FUSION_WEIGHTS_VECTOR") {
self.fusion.weights.vector = vector
.parse()
.context("Failed to parse MAPROOM_SEARCH_FUSION_WEIGHTS_VECTOR")?;
debug!(
"Override: fusion.weights.vector = {}",
self.fusion.weights.vector
);
}
if let Ok(graph) = std::env::var("MAPROOM_SEARCH_FUSION_WEIGHTS_GRAPH") {
self.fusion.weights.graph = graph
.parse()
.context("Failed to parse MAPROOM_SEARCH_FUSION_WEIGHTS_GRAPH")?;
debug!(
"Override: fusion.weights.graph = {}",
self.fusion.weights.graph
);
}
if let Ok(recency) = std::env::var("MAPROOM_SEARCH_FUSION_WEIGHTS_RECENCY") {
self.fusion.weights.recency = recency
.parse()
.context("Failed to parse MAPROOM_SEARCH_FUSION_WEIGHTS_RECENCY")?;
debug!(
"Override: fusion.weights.recency = {}",
self.fusion.weights.recency
);
}
if let Ok(churn) = std::env::var("MAPROOM_SEARCH_FUSION_WEIGHTS_CHURN") {
self.fusion.weights.churn = churn
.parse()
.context("Failed to parse MAPROOM_SEARCH_FUSION_WEIGHTS_CHURN")?;
debug!(
"Override: fusion.weights.churn = {}",
self.fusion.weights.churn
);
}
if let Ok(max_candidates) =
std::env::var("MAPROOM_SEARCH_PERFORMANCE_MAX_CANDIDATES_PER_METHOD")
{
self.performance.max_candidates_per_method = max_candidates
.parse()
.context("Failed to parse MAPROOM_SEARCH_PERFORMANCE_MAX_CANDIDATES_PER_METHOD")?;
debug!(
"Override: performance.max_candidates_per_method = {}",
self.performance.max_candidates_per_method
);
}
if let Ok(final_limit) = std::env::var("MAPROOM_SEARCH_PERFORMANCE_FINAL_RESULT_LIMIT") {
self.performance.final_result_limit = final_limit
.parse()
.context("Failed to parse MAPROOM_SEARCH_PERFORMANCE_FINAL_RESULT_LIMIT")?;
debug!(
"Override: performance.final_result_limit = {}",
self.performance.final_result_limit
);
}
if let Ok(timeout) = std::env::var("MAPROOM_SEARCH_PERFORMANCE_TIMEOUT_MS") {
self.performance.timeout_ms = timeout
.parse()
.context("Failed to parse MAPROOM_SEARCH_PERFORMANCE_TIMEOUT_MS")?;
debug!(
"Override: performance.timeout_ms = {}",
self.performance.timeout_ms
);
}
if let Ok(parallel) = std::env::var("MAPROOM_SEARCH_PERFORMANCE_PARALLEL_EXECUTION") {
self.performance.parallel_execution = parallel
.parse()
.context("Failed to parse MAPROOM_SEARCH_PERFORMANCE_PARALLEL_EXECUTION")?;
debug!(
"Override: performance.parallel_execution = {}",
self.performance.parallel_execution
);
}
if let Ok(lists) = std::env::var("MAPROOM_SEARCH_INDEX_IVFFLAT_LISTS") {
self.index.ivfflat_lists = lists
.parse()
.context("Failed to parse MAPROOM_SEARCH_INDEX_IVFFLAT_LISTS")?;
debug!(
"Override: index.ivfflat_lists = {}",
self.index.ivfflat_lists
);
}
if let Ok(probes) = std::env::var("MAPROOM_SEARCH_INDEX_IVFFLAT_PROBES") {
self.index.ivfflat_probes = probes
.parse()
.context("Failed to parse MAPROOM_SEARCH_INDEX_IVFFLAT_PROBES")?;
debug!(
"Override: index.ivfflat_probes = {}",
self.index.ivfflat_probes
);
}
if let Ok(refresh) = std::env::var("MAPROOM_SEARCH_INDEX_REFRESH_INTERVAL_SECONDS") {
self.index.refresh_interval_seconds = refresh
.parse()
.context("Failed to parse MAPROOM_SEARCH_INDEX_REFRESH_INTERVAL_SECONDS")?;
debug!(
"Override: index.refresh_interval_seconds = {}",
self.index.refresh_interval_seconds
);
}
if let Ok(vector) = std::env::var("MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_VECTOR_SEARCH") {
self.feature_flags.enable_vector_search = vector
.parse()
.context("Failed to parse MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_VECTOR_SEARCH")?;
debug!(
"Override: feature_flags.enable_vector_search = {}",
self.feature_flags.enable_vector_search
);
}
if let Ok(hybrid) = std::env::var("MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_HYBRID_FUSION") {
self.feature_flags.enable_hybrid_fusion = hybrid
.parse()
.context("Failed to parse MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_HYBRID_FUSION")?;
debug!(
"Override: feature_flags.enable_hybrid_fusion = {}",
self.feature_flags.enable_hybrid_fusion
);
}
if let Ok(graph) = std::env::var("MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_GRAPH_SIGNALS") {
self.feature_flags.enable_graph_signals = graph
.parse()
.context("Failed to parse MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_GRAPH_SIGNALS")?;
debug!(
"Override: feature_flags.enable_graph_signals = {}",
self.feature_flags.enable_graph_signals
);
}
if let Ok(temporal) = std::env::var("MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_TEMPORAL_SIGNALS")
{
self.feature_flags.enable_temporal_signals = temporal
.parse()
.context("Failed to parse MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_TEMPORAL_SIGNALS")?;
debug!(
"Override: feature_flags.enable_temporal_signals = {}",
self.feature_flags.enable_temporal_signals
);
}
if let Ok(cache) = std::env::var("MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_QUERY_CACHE") {
self.feature_flags.enable_query_cache = cache
.parse()
.context("Failed to parse MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_QUERY_CACHE")?;
debug!(
"Override: feature_flags.enable_query_cache = {}",
self.feature_flags.enable_query_cache
);
}
if let Ok(hot_reload) = std::env::var("MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_HOT_RELOAD") {
self.feature_flags.enable_hot_reload = hot_reload
.parse()
.context("Failed to parse MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_HOT_RELOAD")?;
debug!(
"Override: feature_flags.enable_hot_reload = {}",
self.feature_flags.enable_hot_reload
);
}
if let Ok(quality_graph) =
std::env::var("MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_QUALITY_WEIGHTED_GRAPH")
{
self.feature_flags.enable_quality_weighted_graph = quality_graph.parse().context(
"Failed to parse MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_QUALITY_WEIGHTED_GRAPH",
)?;
debug!(
"Override: feature_flags.enable_quality_weighted_graph = {}",
self.feature_flags.enable_quality_weighted_graph
);
}
if let Ok(enable_quality) =
std::env::var("MAPROOM_SEARCH_GRAPH_IMPORTANCE_ENABLE_QUALITY_SCORING")
{
self.graph_importance.enable_quality_scoring = enable_quality.parse().context(
"Failed to parse MAPROOM_SEARCH_GRAPH_IMPORTANCE_ENABLE_QUALITY_SCORING",
)?;
debug!(
"Override: graph_importance.enable_quality_scoring = {}",
self.graph_importance.enable_quality_scoring
);
}
if let Ok(prod_weight) =
std::env::var("MAPROOM_SEARCH_GRAPH_IMPORTANCE_PRODUCTION_CODE_WEIGHT")
{
self.graph_importance.edge_quality_weights.production_code =
prod_weight.parse().context(
"Failed to parse MAPROOM_SEARCH_GRAPH_IMPORTANCE_PRODUCTION_CODE_WEIGHT",
)?;
debug!(
"Override: graph_importance.edge_quality_weights.production_code = {}",
self.graph_importance.edge_quality_weights.production_code
);
}
if let Ok(test_weight) = std::env::var("MAPROOM_SEARCH_GRAPH_IMPORTANCE_TEST_CODE_WEIGHT") {
self.graph_importance.edge_quality_weights.test_code = test_weight
.parse()
.context("Failed to parse MAPROOM_SEARCH_GRAPH_IMPORTANCE_TEST_CODE_WEIGHT")?;
debug!(
"Override: graph_importance.edge_quality_weights.test_code = {}",
self.graph_importance.edge_quality_weights.test_code
);
}
if let Ok(calls_weight) = std::env::var("MAPROOM_SEARCH_GRAPH_IMPORTANCE_CALLS_WEIGHT") {
self.graph_importance.edge_quality_weights.calls = calls_weight
.parse()
.context("Failed to parse MAPROOM_SEARCH_GRAPH_IMPORTANCE_CALLS_WEIGHT")?;
debug!(
"Override: graph_importance.edge_quality_weights.calls = {}",
self.graph_importance.edge_quality_weights.calls
);
}
if let Ok(fusion_override) =
std::env::var("MAPROOM_SEARCH_GRAPH_IMPORTANCE_FUSION_WEIGHT_OVERRIDE")
{
self.graph_importance.fusion_weight_override = Some(fusion_override.parse().context(
"Failed to parse MAPROOM_SEARCH_GRAPH_IMPORTANCE_FUSION_WEIGHT_OVERRIDE",
)?);
debug!(
"Override: graph_importance.fusion_weight_override = {:?}",
self.graph_importance.fusion_weight_override
);
}
Ok(())
}
/// Validates every section that exposes a `validate` method, stopping at the first failure.
///
/// Sections are checked in this order: embedding, fusion, performance, index, indexing,
/// database, runtime, buffers, graph importance.
///
/// # Errors
///
/// Returns the first error reported by a section's validator.
pub fn validate(&self) -> Result<()> {
    self.embedding
        .validate()
        .and_then(|()| self.fusion.validate())
        .and_then(|()| self.performance.validate())
        .and_then(|()| self.index.validate())
        .and_then(|()| self.indexing.validate())
        .and_then(|()| self.database.validate())
        .and_then(|()| self.runtime.validate())
        .and_then(|()| self.buffers.validate())
        .and_then(|()| self.graph_importance.validate())
}
/// Returns every environment variable whose name starts with `MAPROOM_SEARCH_`.
///
/// The result is a list of `(name, value)` pairs in the order the platform reports them.
/// Variables whose name or value is not valid Unicode are skipped rather than causing a
/// panic: `std::env::vars()` panics on such entries even when they are unrelated to this
/// application, so the OS-string iterator is used and undecodable entries are dropped.
pub fn get_env_overrides() -> Vec<(String, String)> {
    std::env::vars_os()
        .filter_map(|(name, value)| {
            let name = name.into_string().ok()?;
            let value = value.into_string().ok()?;
            Some((name, value))
        })
        .filter(|(name, _)| name.starts_with("MAPROOM_SEARCH_"))
        .collect()
}
}
/// Settings for the embedding provider and its cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
/// Provider identifier; validation rejects an empty string. Defaults to `"openai"`.
pub provider: String,
/// Model identifier; validation rejects an empty string.
pub model_name: String,
/// Embedding vector dimension; validation requires a value greater than 0.
pub dimension: usize,
/// Cache capacity; a value of 0 only triggers a "caching is disabled" warning.
pub cache_size: usize,
/// Cache entry time-to-live in seconds (per the field name); not checked by validation.
pub cache_ttl_seconds: u64,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
provider: "openai".to_string(),
model_name: "text-embedding-3-small".to_string(),
dimension: 1536,
cache_size: 10000,
cache_ttl_seconds: 3600,
}
}
}
impl EmbeddingConfig {
    /// Checks the embedding settings.
    ///
    /// A `cache_size` of 0 is accepted but logged as a warning.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigError::ValidationError`] when the provider or model name is
    /// empty, or when the dimension is 0 (checked in that order).
    pub fn validate(&self) -> Result<()> {
        let hard_rules = [
            (
                self.provider.is_empty(),
                "Embedding provider cannot be empty",
            ),
            (
                self.model_name.is_empty(),
                "Embedding model name cannot be empty",
            ),
            (
                self.dimension == 0,
                "Embedding dimension must be greater than 0",
            ),
        ];
        if let Some(&(_, message)) = hard_rules.iter().find(|rule| rule.0) {
            return Err(SearchConfigError::ValidationError(message.to_string()).into());
        }

        if self.cache_size == 0 {
            warn!("Embedding cache size is 0, caching is disabled");
        }
        Ok(())
    }
}
/// Settings that control how results from the individual search methods are fused.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
/// Fusion strategy to use.
pub method: FusionMethod,
/// The RRF `k` parameter; validation requires a value greater than 0.
pub rrf_k: u32,
/// Per-signal weights; the type is declared in `crate::search::fusion`.
pub weights: FusionWeights,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
method: FusionMethod::RRF,
rrf_k: 60,
weights: FusionWeights::default(),
}
}
}
impl FusionConfig {
    /// Checks the fusion settings.
    ///
    /// The weights are validated first; weights that are valid but not normalized only
    /// produce a warning. Finally `rrf_k` must be non-zero.
    ///
    /// # Errors
    ///
    /// Returns the weights' validation error (with the context `Invalid fusion weights`),
    /// or [`SearchConfigError::ValidationError`] when `rrf_k` is 0.
    pub fn validate(&self) -> Result<()> {
        let weights = &self.weights;
        weights.validate().context("Invalid fusion weights")?;

        if !weights.is_normalized() {
            warn!(
                "Fusion weights are not normalized (sum = {}), consider normalizing for predictable behavior",
                weights.sum()
            );
        }

        match self.rrf_k {
            0 => Err(SearchConfigError::ValidationError(
                "RRF k parameter must be greater than 0".to_string(),
            )
            .into()),
            _ => Ok(()),
        }
    }
}
/// Strategy used to combine results from the individual search methods.
///
/// Because of `rename_all = "lowercase"`, the variants are (de)serialized as `rrf`,
/// `weighted` and `learned`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum FusionMethod {
/// Selected by the name `rrf`; the default method of `FusionConfig`.
RRF,
/// Selected by the name `weighted`.
Weighted,
/// Selected by the name `learned`.
Learned,
}
impl FusionMethod {
#[allow(clippy::should_implement_trait)] pub fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"rrf" => Ok(Self::RRF),
"weighted" => Ok(Self::Weighted),
"learned" => Ok(Self::Learned),
_ => Err(SearchConfigError::ValidationError(format!(
"Invalid fusion method: {}. Valid options: rrf, weighted, learned",
s
))
.into()),
}
}
}
/// Limits and execution settings for a search query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
/// Candidate cap per search method; validation requires a value greater than 0.
pub max_candidates_per_method: usize,
/// Cap on the final result list; validation requires a value greater than 0.
pub final_result_limit: usize,
/// Query timeout in milliseconds; 0 only triggers a "queries will not timeout" warning.
pub timeout_ms: u64,
/// Parallel-execution toggle; not checked by validation.
pub parallel_execution: bool,
}
impl Default for PerformanceConfig {
fn default() -> Self {
Self {
max_candidates_per_method: 100,
final_result_limit: 20,
timeout_ms: 1000,
parallel_execution: true,
}
}
}
impl PerformanceConfig {
    /// Checks the performance settings.
    ///
    /// A `timeout_ms` of 0 is accepted but logged as a warning.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigError::ValidationError`] when `max_candidates_per_method` or
    /// `final_result_limit` is 0 (checked in that order).
    pub fn validate(&self) -> Result<()> {
        let must_be_positive = [
            (
                self.max_candidates_per_method,
                "max_candidates_per_method must be greater than 0",
            ),
            (
                self.final_result_limit,
                "final_result_limit must be greater than 0",
            ),
        ];
        if let Some(&(_, message)) = must_be_positive.iter().find(|entry| entry.0 == 0) {
            return Err(SearchConfigError::ValidationError(message.to_string()).into());
        }

        if self.timeout_ms == 0 {
            warn!("Query timeout is 0, queries will not timeout");
        }
        Ok(())
    }
}
/// Parameters for the IVFFlat vector index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfig {
/// Number of IVFFlat lists; validation requires a value greater than 0.
pub ivfflat_lists: u32,
/// Number of IVFFlat probes; validation requires a value greater than 0 and warns when
/// it exceeds `ivfflat_lists`.
pub ivfflat_probes: u32,
/// Refresh interval in seconds (per the field name); not checked by validation.
pub refresh_interval_seconds: u64,
}
impl Default for IndexConfig {
fn default() -> Self {
Self {
ivfflat_lists: 100,
ivfflat_probes: 10,
refresh_interval_seconds: 3600,
}
}
}
impl IndexConfig {
    /// Checks the index settings.
    ///
    /// Having more probes than lists is accepted but logged as a warning.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigError::ValidationError`] when `ivfflat_lists` or
    /// `ivfflat_probes` is 0 (checked in that order).
    pub fn validate(&self) -> Result<()> {
        let reject = |reason: &str| -> Result<()> {
            Err(SearchConfigError::ValidationError(reason.to_string()).into())
        };
        let (lists, probes) = (self.ivfflat_lists, self.ivfflat_probes);

        if lists == 0 {
            return reject("ivfflat_lists must be greater than 0");
        }
        if probes == 0 {
            return reject("ivfflat_probes must be greater than 0");
        }
        if probes > lists {
            warn!(
                "ivfflat_probes ({}) is greater than ivfflat_lists ({}), this is inefficient",
                probes, lists
            );
        }
        Ok(())
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexingConfig {
pub parallel_workers: usize,
pub batch_size: usize,
pub max_file_size: usize,
pub chunk_insert_batch_size: usize,
pub edge_insert_batch_size: usize,
}
impl Default for IndexingConfig {
fn default() -> Self {
Self {
parallel_workers: 8, batch_size: 50, max_file_size: 10 * 1024 * 1024, chunk_insert_batch_size: 100, edge_insert_batch_size: 500, }
}
}
impl IndexingConfig {
    /// Checks the indexing settings.
    ///
    /// A `max_file_size` of 0 is accepted but logged as a warning; that warning is emitted
    /// before the insert batch sizes are checked.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigError::ValidationError`] when `parallel_workers`,
    /// `batch_size`, `chunk_insert_batch_size` or `edge_insert_batch_size` is 0
    /// (checked in that order).
    pub fn validate(&self) -> Result<()> {
        let reject = |reason: &str| -> Result<()> {
            Err(SearchConfigError::ValidationError(reason.to_string()).into())
        };

        if self.parallel_workers == 0 {
            return reject("parallel_workers must be greater than 0");
        }
        if self.batch_size == 0 {
            return reject("batch_size must be greater than 0");
        }
        if self.max_file_size == 0 {
            warn!("max_file_size is 0, no files will be indexed");
        }
        if self.chunk_insert_batch_size == 0 {
            return reject("chunk_insert_batch_size must be greater than 0");
        }
        if self.edge_insert_batch_size == 0 {
            return reject("edge_insert_batch_size must be greater than 0");
        }
        Ok(())
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub pool_size: usize,
pub connection_timeout_ms: u64,
pub statement_timeout_ms: u64,
pub lock_timeout_ms: u64,
pub idle_in_transaction_timeout_ms: u64,
pub work_mem: String,
pub max_connection_lifetime_secs: u64,
pub idle_connection_timeout_secs: u64,
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
pool_size: 20, connection_timeout_ms: 5000, statement_timeout_ms: 5000, lock_timeout_ms: 1000, idle_in_transaction_timeout_ms: 30000, work_mem: "256MB".to_string(), max_connection_lifetime_secs: 1800, idle_connection_timeout_secs: 600, }
}
}
impl DatabaseConfig {
    /// Checks the database settings.
    ///
    /// A pool larger than 100 connections and a `statement_timeout_ms` of 0 are accepted
    /// but logged as warnings.
    ///
    /// `work_mem` must be a positive number followed by `MB` or `GB`. Whitespace between
    /// the number and the unit is tolerated. Previously only the suffix was inspected, so
    /// values such as `"MB"` or `"abcMB"` passed validation.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigError::ValidationError`] when `pool_size` is 0, when
    /// `work_mem` lacks an `MB`/`GB` suffix, or when the part before the suffix is not a
    /// positive finite number.
    pub fn validate(&self) -> Result<()> {
        if self.pool_size == 0 {
            return Err(SearchConfigError::ValidationError(
                "pool_size must be greater than 0".to_string(),
            )
            .into());
        }
        if self.pool_size > 100 {
            warn!(
                "pool_size ({}) is very large, this may cause PostgreSQL overhead",
                self.pool_size
            );
        }
        if self.statement_timeout_ms == 0 {
            warn!("statement_timeout_ms is 0, queries will not timeout");
        }

        let amount = match self
            .work_mem
            .strip_suffix("MB")
            .or_else(|| self.work_mem.strip_suffix("GB"))
        {
            Some(amount) => amount.trim(),
            None => {
                return Err(SearchConfigError::ValidationError(
                    "work_mem must end with 'MB' or 'GB' (e.g., '256MB')".to_string(),
                )
                .into());
            }
        };
        // The suffix alone is not enough: require an actual positive quantity in front of it.
        let has_positive_amount = amount
            .parse::<f64>()
            .map_or(false, |value| value.is_finite() && value > 0.0);
        if !has_positive_amount {
            return Err(SearchConfigError::ValidationError(
                "work_mem must be a positive number followed by 'MB' or 'GB' (e.g., '256MB')"
                    .to_string(),
            )
            .into());
        }
        Ok(())
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeConfig {
pub worker_threads: usize,
pub max_blocking_threads: usize,
pub thread_stack_size: usize,
pub enable_thread_names: bool,
}
impl Default for RuntimeConfig {
fn default() -> Self {
Self {
worker_threads: 8, max_blocking_threads: 16, thread_stack_size: 2 * 1024 * 1024, enable_thread_names: true,
}
}
}
impl RuntimeConfig {
    /// Checks the runtime settings.
    ///
    /// A `thread_stack_size` below 256 KiB is accepted but logged as a warning.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigError::ValidationError`] when `worker_threads` or
    /// `max_blocking_threads` is 0 (checked in that order).
    pub fn validate(&self) -> Result<()> {
        const MIN_COMFORTABLE_STACK: usize = 256 * 1024;

        let must_be_positive = [
            (
                self.worker_threads,
                "worker_threads must be greater than 0",
            ),
            (
                self.max_blocking_threads,
                "max_blocking_threads must be greater than 0",
            ),
        ];
        if let Some(&(_, message)) = must_be_positive.iter().find(|entry| entry.0 == 0) {
            return Err(SearchConfigError::ValidationError(message.to_string()).into());
        }

        if self.thread_stack_size < MIN_COMFORTABLE_STACK {
            warn!(
                "thread_stack_size ({} bytes) is very small, this may cause stack overflows",
                self.thread_stack_size
            );
        }
        Ok(())
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BufferConfig {
pub file_read_buffer: usize,
pub db_buffer: usize,
pub parse_buffer: usize,
pub buffer_pool_size: usize,
}
impl Default for BufferConfig {
fn default() -> Self {
Self {
file_read_buffer: 64 * 1024, db_buffer: 32 * 1024, parse_buffer: 1024 * 1024, buffer_pool_size: 100, }
}
}
impl BufferConfig {
    /// Checks the buffer settings.
    ///
    /// A `buffer_pool_size` of 0 is accepted but logged as a warning.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigError::ValidationError`] when `file_read_buffer`, `db_buffer`
    /// or `parse_buffer` is 0 (checked in that order).
    pub fn validate(&self) -> Result<()> {
        let must_be_positive = [
            (
                self.file_read_buffer,
                "file_read_buffer must be greater than 0",
            ),
            (self.db_buffer, "db_buffer must be greater than 0"),
            (self.parse_buffer, "parse_buffer must be greater than 0"),
        ];
        for &(size, message) in must_be_positive.iter() {
            if size == 0 {
                return Err(SearchConfigError::ValidationError(message.to_string()).into());
            }
        }

        if self.buffer_pool_size == 0 {
            warn!("buffer_pool_size is 0, buffer pooling is disabled");
        }
        Ok(())
    }
}
/// Settings for quality-aware graph importance scoring.
///
/// Every field carries `#[serde(default)]`, so any subset of keys may be given.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GraphImportanceConfig {
/// Quality-scoring toggle; `false` when omitted.
#[serde(default)]
pub enable_quality_scoring: bool,
/// Per-edge-kind quality weights; falls back to `EdgeQualityWeights::default()`.
#[serde(default)]
pub edge_quality_weights: EdgeQualityWeights,
/// Optional fusion weight override; when set, validation requires it to lie in
/// `0.0..=1.0`. `None` when omitted.
#[serde(default)]
pub fusion_weight_override: Option<f32>,
}
impl GraphImportanceConfig {
    /// Checks the graph-importance settings.
    ///
    /// The edge quality weights are validated first; the optional fusion weight override is
    /// then required to lie within `0.0..=1.0`.
    ///
    /// # Errors
    ///
    /// Returns the weights' validation error, or [`SearchConfigError::ValidationError`]
    /// when the override is outside the allowed range.
    pub fn validate(&self) -> Result<()> {
        self.edge_quality_weights.validate()?;

        match self.fusion_weight_override {
            Some(weight) if !(0.0..=1.0).contains(&weight) => {
                Err(SearchConfigError::ValidationError(
                    "fusion_weight_override must be between 0.0 and 1.0".to_string(),
                )
                .into())
            }
            _ => Ok(()),
        }
    }
}
/// Quality weights used for graph importance scoring.
///
/// Each field has its own serde default function, so any subset of keys may be given.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeQualityWeights {
/// Weight for production code; defaults to 1.0 when omitted.
#[serde(default = "default_production_code_weight")]
pub production_code: f32,
/// Weight for test code; defaults to 0.5 when omitted.
#[serde(default = "default_test_code_weight")]
pub test_code: f32,
/// Weight for calls; defaults to 1.0 when omitted.
#[serde(default = "default_calls_weight")]
pub calls: f32,
}
/// Serde default for `EdgeQualityWeights::production_code`.
fn default_production_code_weight() -> f32 {
    const PRODUCTION_CODE_WEIGHT: f32 = 1.0;
    PRODUCTION_CODE_WEIGHT
}
/// Serde default for `EdgeQualityWeights::test_code`.
fn default_test_code_weight() -> f32 {
    const TEST_CODE_WEIGHT: f32 = 0.5;
    TEST_CODE_WEIGHT
}
/// Serde default for `EdgeQualityWeights::calls`.
fn default_calls_weight() -> f32 {
    const CALLS_WEIGHT: f32 = 1.0;
    CALLS_WEIGHT
}
impl Default for EdgeQualityWeights {
fn default() -> Self {
Self {
production_code: 1.0,
test_code: 0.5,
calls: 1.0,
}
}
}
impl EdgeQualityWeights {
    /// Checks that every weight lies within `0.0..=10.0`.
    ///
    /// The range test uses `contains`, which is false for NaN. The previous
    /// `w < 0.0 || w > 10.0` form let NaN through, because every comparison with NaN is
    /// false; NaN can arrive from YAML (`.nan`) or from an environment override such as
    /// `NaN`.
    ///
    /// # Errors
    ///
    /// Returns [`SearchConfigError::ValidationError`] for the first weight (production
    /// code, test code, calls) that is NaN or outside the range.
    pub fn validate(&self) -> Result<()> {
        let weights = [
            (
                self.production_code,
                "production_code weight must be between 0.0 and 10.0",
            ),
            (
                self.test_code,
                "test_code weight must be between 0.0 and 10.0",
            ),
            (self.calls, "calls weight must be between 0.0 and 10.0"),
        ];
        for &(weight, message) in weights.iter() {
            if !(0.0..=10.0).contains(&weight) {
                return Err(SearchConfigError::ValidationError(message.to_string()).into());
            }
        }
        Ok(())
    }

    /// Reports whether all three weights equal the built-in defaults (1.0, 0.5, 1.0),
    /// compared with an `f32::EPSILON` tolerance.
    pub fn is_default(&self) -> bool {
        (self.production_code - 1.0).abs() < f32::EPSILON
            && (self.test_code - 0.5).abs() < f32::EPSILON
            && (self.calls - 1.0).abs() < f32::EPSILON
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test that touches `MAPROOM_SEARCH_*` variables.
    ///
    /// The environment is process-global and the test harness runs tests on several
    /// threads, so unsynchronized `set_var`/`remove_var` calls made these tests race with
    /// each other (several of them use the very same variable).
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    const QUALITY_WEIGHTED_GRAPH_VAR: &str =
        "MAPROOM_SEARCH_FEATURE_FLAGS_ENABLE_QUALITY_WEIGHTED_GRAPH";
    const ENABLE_QUALITY_SCORING_VAR: &str =
        "MAPROOM_SEARCH_GRAPH_IMPORTANCE_ENABLE_QUALITY_SCORING";
    const PRODUCTION_CODE_WEIGHT_VAR: &str =
        "MAPROOM_SEARCH_GRAPH_IMPORTANCE_PRODUCTION_CODE_WEIGHT";
    const TEST_CODE_WEIGHT_VAR: &str = "MAPROOM_SEARCH_GRAPH_IMPORTANCE_TEST_CODE_WEIGHT";
    const CALLS_WEIGHT_VAR: &str = "MAPROOM_SEARCH_GRAPH_IMPORTANCE_CALLS_WEIGHT";
    const FUSION_WEIGHT_OVERRIDE_VAR: &str =
        "MAPROOM_SEARCH_GRAPH_IMPORTANCE_FUSION_WEIGHT_OVERRIDE";

    /// Holds the environment lock and removes every variable it set when dropped, so a
    /// failing assertion cannot leak state into later tests.
    struct EnvGuard {
        touched: Vec<&'static str>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Takes the lock; a lock poisoned by a failed test is still usable.
        fn acquire() -> Self {
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            EnvGuard {
                touched: Vec::new(),
                _lock: lock,
            }
        }

        /// Sets `key` and remembers it for cleanup.
        fn set(&mut self, key: &'static str, value: &str) {
            std::env::set_var(key, value);
            self.touched.push(key);
        }

        /// Removes `key` immediately.
        fn unset(&self, key: &'static str) {
            std::env::remove_var(key);
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before the lock field is released, so cleanup is still exclusive.
            for key in self.touched.drain(..) {
                std::env::remove_var(key);
            }
        }
    }

    #[test]
    fn test_default_config() {
        let config = SearchConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.embedding.provider, "openai");
        assert_eq!(config.fusion.method, FusionMethod::RRF);
        assert!(config.feature_flags.enable_vector_search);
    }

    #[test]
    fn test_fusion_method_parsing() {
        assert_eq!(FusionMethod::from_str("rrf").unwrap(), FusionMethod::RRF);
        assert_eq!(
            FusionMethod::from_str("weighted").unwrap(),
            FusionMethod::Weighted
        );
        assert_eq!(
            FusionMethod::from_str("learned").unwrap(),
            FusionMethod::Learned
        );
        assert_eq!(FusionMethod::from_str("RRF").unwrap(), FusionMethod::RRF);
        assert!(FusionMethod::from_str("invalid").is_err());
    }

    #[test]
    fn test_embedding_config_validation() {
        let mut config = EmbeddingConfig::default();
        assert!(config.validate().is_ok());
        config.provider = "".to_string();
        assert!(config.validate().is_err());
        config = EmbeddingConfig::default();
        config.dimension = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_fusion_config_validation() {
        let mut config = FusionConfig::default();
        assert!(config.validate().is_ok());
        config.rrf_k = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_performance_config_validation() {
        let mut config = PerformanceConfig::default();
        assert!(config.validate().is_ok());
        config.max_candidates_per_method = 0;
        assert!(config.validate().is_err());
        config = PerformanceConfig::default();
        config.final_result_limit = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_index_config_validation() {
        let mut config = IndexConfig::default();
        assert!(config.validate().is_ok());
        config.ivfflat_lists = 0;
        assert!(config.validate().is_err());
        config = IndexConfig::default();
        config.ivfflat_probes = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_quality_weighted_graph_env_override() {
        let mut env = EnvGuard::acquire();
        env.set(QUALITY_WEIGHTED_GRAPH_VAR, "true");
        let mut config = SearchConfig::default();
        assert!(!config.feature_flags.enable_quality_weighted_graph);
        config.apply_env_overrides().unwrap();
        assert!(config.feature_flags.enable_quality_weighted_graph);
    }

    #[test]
    fn test_quality_weighted_graph_env_override_false() {
        let mut env = EnvGuard::acquire();
        env.set(QUALITY_WEIGHTED_GRAPH_VAR, "false");
        let mut config = SearchConfig::default();
        assert!(!config.feature_flags.enable_quality_weighted_graph);
        config.apply_env_overrides().unwrap();
        assert!(!config.feature_flags.enable_quality_weighted_graph);
    }

    #[test]
    fn test_quality_weighted_graph_invalid_env_value() {
        let mut env = EnvGuard::acquire();
        env.set(QUALITY_WEIGHTED_GRAPH_VAR, "invalid");
        let mut config = SearchConfig::default();
        let result = config.apply_env_overrides();
        assert!(result.is_err());
    }

    #[test]
    fn test_quality_weighted_graph_no_env_uses_default() {
        let env = EnvGuard::acquire();
        env.unset(QUALITY_WEIGHTED_GRAPH_VAR);
        let mut config = SearchConfig::default();
        config.apply_env_overrides().unwrap();
        assert!(!config.feature_flags.enable_quality_weighted_graph);
    }

    #[test]
    fn test_graph_importance_default_values() {
        let config = GraphImportanceConfig::default();
        assert!(!config.enable_quality_scoring);
        assert!(config.fusion_weight_override.is_none());
        assert!((config.edge_quality_weights.production_code - 1.0).abs() < f32::EPSILON);
        assert!((config.edge_quality_weights.test_code - 0.5).abs() < f32::EPSILON);
        assert!((config.edge_quality_weights.calls - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_graph_importance_validation_success() {
        let config = GraphImportanceConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_graph_importance_fusion_weight_override_validation() {
        let mut config = GraphImportanceConfig::default();
        config.fusion_weight_override = Some(0.15);
        assert!(config.validate().is_ok());
        config.fusion_weight_override = Some(0.0);
        assert!(config.validate().is_ok());
        config.fusion_weight_override = Some(1.0);
        assert!(config.validate().is_ok());
        config.fusion_weight_override = Some(-0.1);
        assert!(config.validate().is_err());
        config.fusion_weight_override = Some(1.5);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_edge_quality_weights_default_values() {
        let weights = EdgeQualityWeights::default();
        assert!((weights.production_code - 1.0).abs() < f32::EPSILON);
        assert!((weights.test_code - 0.5).abs() < f32::EPSILON);
        assert!((weights.calls - 1.0).abs() < f32::EPSILON);
        assert!(weights.is_default());
    }

    #[test]
    fn test_edge_quality_weights_validation() {
        let mut weights = EdgeQualityWeights::default();
        assert!(weights.validate().is_ok());
        weights.production_code = 5.0;
        weights.test_code = 0.3;
        weights.calls = 2.0;
        assert!(weights.validate().is_ok());
        weights.production_code = 0.0;
        assert!(weights.validate().is_ok());
        weights.production_code = 10.0;
        assert!(weights.validate().is_ok());
        weights.production_code = -0.1;
        assert!(weights.validate().is_err());
        weights = EdgeQualityWeights::default();
        weights.test_code = -0.1;
        assert!(weights.validate().is_err());
        weights.test_code = 10.1;
        assert!(weights.validate().is_err());
        weights = EdgeQualityWeights::default();
        weights.calls = -0.1;
        assert!(weights.validate().is_err());
        weights.calls = 10.1;
        assert!(weights.validate().is_err());
    }

    #[test]
    fn test_edge_quality_weights_is_default() {
        let default = EdgeQualityWeights::default();
        assert!(default.is_default());
        let mut modified = EdgeQualityWeights::default();
        modified.test_code = 0.6;
        assert!(!modified.is_default());
        modified = EdgeQualityWeights::default();
        modified.production_code = 1.1;
        assert!(!modified.is_default());
    }

    #[test]
    fn test_graph_importance_yaml_deserialization() {
        // Nested keys must be indented under their parent mapping.
        let yaml = r#"
enable_quality_scoring: true
edge_quality_weights:
  production_code: 1.0
  test_code: 0.5
  calls: 1.0
fusion_weight_override: 0.15
"#;
        let config: GraphImportanceConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.enable_quality_scoring);
        assert!((config.edge_quality_weights.production_code - 1.0).abs() < f32::EPSILON);
        assert!((config.edge_quality_weights.test_code - 0.5).abs() < f32::EPSILON);
        assert_eq!(config.fusion_weight_override, Some(0.15));
    }

    #[test]
    fn test_graph_importance_yaml_backward_compat() {
        let yaml = r#"
embedding:
  provider: openai
  model_name: text-embedding-3-small
  dimension: 1536
  cache_size: 10000
  cache_ttl_seconds: 3600
fusion:
  method: rrf
  rrf_k: 60
  weights:
    fts: 0.4
    vector: 0.3
    graph: 0.1
    recency: 0.1
    churn: 0.1
performance:
  max_candidates_per_method: 100
  final_result_limit: 20
  timeout_ms: 1000
  parallel_execution: true
index:
  ivfflat_lists: 100
  ivfflat_probes: 10
  refresh_interval_seconds: 3600
feature_flags:
  enable_vector_search: true
  enable_hybrid_fusion: true
  enable_graph_signals: true
  enable_temporal_signals: true
  enable_query_cache: true
  enable_hot_reload: true
# Note: No graph_importance section - should use defaults
"#;
        let config: SearchConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(!config.graph_importance.enable_quality_scoring);
        assert!(config.graph_importance.fusion_weight_override.is_none());
        assert!(config.graph_importance.edge_quality_weights.is_default());
    }

    #[test]
    fn test_graph_importance_partial_yaml() {
        let yaml = r#"
enable_quality_scoring: true
"#;
        let config: GraphImportanceConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(config.enable_quality_scoring);
        assert!(config.fusion_weight_override.is_none());
        assert!(config.edge_quality_weights.is_default());
    }

    #[test]
    fn test_graph_importance_env_overrides() {
        let mut env = EnvGuard::acquire();
        env.set(ENABLE_QUALITY_SCORING_VAR, "true");
        env.set(PRODUCTION_CODE_WEIGHT_VAR, "1.5");
        env.set(TEST_CODE_WEIGHT_VAR, "0.3");
        env.set(CALLS_WEIGHT_VAR, "2.0");
        env.set(FUSION_WEIGHT_OVERRIDE_VAR, "0.2");

        let mut config = SearchConfig::default();
        config.apply_env_overrides().unwrap();

        assert!(config.graph_importance.enable_quality_scoring);
        assert!(
            (config.graph_importance.edge_quality_weights.production_code - 1.5).abs()
                < f32::EPSILON
        );
        assert!(
            (config.graph_importance.edge_quality_weights.test_code - 0.3).abs() < f32::EPSILON
        );
        assert!((config.graph_importance.edge_quality_weights.calls - 2.0).abs() < f32::EPSILON);
        assert_eq!(config.graph_importance.fusion_weight_override, Some(0.2));
    }

    #[test]
    fn test_search_config_includes_graph_importance() {
        let config = SearchConfig::default();
        assert!(config.validate().is_ok());
        assert!(!config.graph_importance.enable_quality_scoring);
        assert!(config.graph_importance.edge_quality_weights.is_default());
    }

    #[test]
    fn test_graph_importance_invalid_env_override() {
        let mut env = EnvGuard::acquire();
        env.set(PRODUCTION_CODE_WEIGHT_VAR, "not_a_number");
        let mut config = SearchConfig::default();
        let result = config.apply_env_overrides();
        assert!(result.is_err());
    }

    #[test]
    fn test_search_config_invalid_graph_weights_rejected() {
        let mut config = SearchConfig::default();
        config.graph_importance.edge_quality_weights.production_code = -1.0;
        let result = config.validate();
        assert!(result.is_err());
    }
}