ruvector-cli 2.0.4

CLI and MCP server for Ruvector
//! Configuration management for Ruvector CLI

use anyhow::{Context, Result};
use ruvector_core::types::{DbOptions, DistanceMetric, HnswConfig, QuantizationConfig};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};

/// Ruvector CLI configuration.
///
/// Top-level document deserialized from TOML by [`Config::from_file`]. Every
/// section is `#[serde(default)]`, so a missing `[database]`, `[cli]` or `[mcp]`
/// table is filled in from that section's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Database options (the `[database]` table).
    #[serde(default)]
    pub database: DatabaseConfig,

    /// CLI options (the `[cli]` table).
    #[serde(default)]
    pub cli: CliConfig,

    /// MCP server options (the `[mcp]` table).
    #[serde(default)]
    pub mcp: McpConfig,
}

/// Database configuration (the `[database]` table).
///
/// Missing scalar keys fall back to the `default_*` helper named in each
/// field's `#[serde(default = "...")]` attribute.
///
/// NOTE(review): `hnsw` and `quantization` use plain `#[serde(default)]`, so they
/// become `None` when a `[database]` table is present but omits them. By
/// contrast, `DatabaseConfig::default()` (used when the whole table is absent)
/// sets both to `Some(..)`. Confirm this asymmetry is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Default storage path (falls back to `./ruvector.db`).
    #[serde(default = "default_storage_path")]
    pub storage_path: String,

    /// Default vector dimensions (falls back to 384).
    #[serde(default = "default_dimensions")]
    pub dimensions: usize,

    /// Distance metric (falls back to `DistanceMetric::Cosine`).
    #[serde(default = "default_distance_metric")]
    pub distance_metric: DistanceMetric,

    /// HNSW configuration; `None` when omitted from a present `[database]` table.
    #[serde(default)]
    pub hnsw: Option<HnswConfig>,

    /// Quantization configuration; `None` when omitted from a present
    /// `[database]` table.
    #[serde(default)]
    pub quantization: Option<QuantizationConfig>,
}

/// CLI configuration (the `[cli]` table).
///
/// Each key that is missing from the table falls back to the helper named in
/// its `#[serde(default = "...")]` attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CliConfig {
    /// Show progress bars (falls back to `true`).
    #[serde(default = "default_true")]
    pub progress: bool,

    /// Use colors in output (falls back to `true`).
    #[serde(default = "default_true")]
    pub colors: bool,

    /// Default batch size for operations (falls back to 1000).
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
}

/// MCP server configuration (the `[mcp]` table).
///
/// Each key that is missing from the table falls back to the helper named in
/// its `#[serde(default = "...")]` attribute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
    /// Server host for SSE transport (falls back to `127.0.0.1`).
    #[serde(default = "default_host")]
    pub host: String,

    /// Server port for SSE transport (falls back to 3000).
    #[serde(default = "default_port")]
    pub port: u16,

    /// Enable CORS (falls back to `true`).
    #[serde(default = "default_true")]
    pub cors: bool,

    /// Allowed data directory for MCP file operations (path confinement).
    /// All db_path and backup_path values must resolve within this directory.
    /// Defaults to the current working directory at the moment the default is
    /// computed, or `"."` if the working directory cannot be determined.
    #[serde(default = "default_data_dir")]
    pub data_dir: String,
}

// Default value functions
/// Fallback for `DatabaseConfig::storage_path`: a database file in the
/// working directory.
fn default_storage_path() -> String {
    String::from("./ruvector.db")
}

/// Fallback for `DatabaseConfig::dimensions`.
fn default_dimensions() -> usize {
    const DIMENSIONS: usize = 384;
    DIMENSIONS
}

/// Serde fallback for `DatabaseConfig::distance_metric`: `DistanceMetric::Cosine`.
fn default_distance_metric() -> DistanceMetric {
    DistanceMetric::Cosine
}

/// Serde fallback for boolean options that are enabled unless configured off.
fn default_true() -> bool {
    !false
}

/// Fallback for `CliConfig::batch_size`.
fn default_batch_size() -> usize {
    1_000
}

/// Fallback for `McpConfig::data_dir`: the current working directory
/// (converted lossily to UTF-8), or `"."` if it cannot be determined.
fn default_data_dir() -> String {
    match std::env::current_dir() {
        Ok(cwd) => cwd.to_string_lossy().into_owned(),
        Err(_) => String::from("."),
    }
}

/// Fallback for `McpConfig::host`: loopback only.
fn default_host() -> String {
    String::from("127.0.0.1")
}

/// Fallback for `McpConfig::port`.
fn default_port() -> u16 {
    3_000
}

impl Default for Config {
    /// Builds a configuration in which every section uses its own `Default`.
    fn default() -> Self {
        let database = DatabaseConfig::default();
        let cli = CliConfig::default();
        let mcp = McpConfig::default();
        Self { database, cli, mcp }
    }
}

impl Default for DatabaseConfig {
    fn default() -> Self {
        Self {
            storage_path: default_storage_path(),
            dimensions: default_dimensions(),
            distance_metric: DistanceMetric::Cosine,
            hnsw: Some(HnswConfig::default()),
            quantization: Some(QuantizationConfig::Scalar),
        }
    }
}

impl Default for CliConfig {
    fn default() -> Self {
        Self {
            progress: true,
            colors: true,
            batch_size: default_batch_size(),
        }
    }
}

impl Default for McpConfig {
    fn default() -> Self {
        Self {
            host: default_host(),
            port: default_port(),
            cors: true,
            data_dir: default_data_dir(),
        }
    }
}

impl Config {
    /// Load configuration from file
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content =
            std::fs::read_to_string(path.as_ref()).context("Failed to read config file")?;
        let config: Config = toml::from_str(&content).context("Failed to parse config file")?;
        Ok(config)
    }

    /// Load configuration with precedence: CLI args > env vars > config file > defaults
    pub fn load(config_path: Option<PathBuf>) -> Result<Self> {
        let mut config = if let Some(path) = config_path {
            Self::from_file(&path).unwrap_or_default()
        } else {
            // Try default locations
            Self::try_default_locations().unwrap_or_default()
        };

        // Override with environment variables
        config.apply_env_vars()?;

        Ok(config)
    }

    /// Try loading from default locations
    fn try_default_locations() -> Option<Self> {
        let paths = vec![
            "ruvector.toml",
            ".ruvector.toml",
            "~/.config/ruvector/config.toml",
            "/etc/ruvector/config.toml",
        ];

        for path in paths {
            let expanded = shellexpand::tilde(path).to_string();
            if let Ok(config) = Self::from_file(&expanded) {
                return Some(config);
            }
        }

        None
    }

    /// Apply environment variable overrides
    fn apply_env_vars(&mut self) -> Result<()> {
        if let Ok(path) = std::env::var("RUVECTOR_STORAGE_PATH") {
            self.database.storage_path = path;
        }

        if let Ok(dims) = std::env::var("RUVECTOR_DIMENSIONS") {
            self.database.dimensions = dims.parse().context("Invalid RUVECTOR_DIMENSIONS")?;
        }

        if let Ok(metric) = std::env::var("RUVECTOR_DISTANCE_METRIC") {
            self.database.distance_metric = match metric.to_lowercase().as_str() {
                "euclidean" => DistanceMetric::Euclidean,
                "cosine" => DistanceMetric::Cosine,
                "dotproduct" => DistanceMetric::DotProduct,
                "manhattan" => DistanceMetric::Manhattan,
                _ => return Err(anyhow::anyhow!("Invalid distance metric: {}", metric)),
            };
        }

        if let Ok(host) = std::env::var("RUVECTOR_MCP_HOST") {
            self.mcp.host = host;
        }

        if let Ok(port) = std::env::var("RUVECTOR_MCP_PORT") {
            self.mcp.port = port.parse().context("Invalid RUVECTOR_MCP_PORT")?;
        }

        if let Ok(data_dir) = std::env::var("RUVECTOR_MCP_DATA_DIR") {
            self.mcp.data_dir = data_dir;
        }

        Ok(())
    }

    /// Convert to DbOptions
    pub fn to_db_options(&self) -> DbOptions {
        DbOptions {
            dimensions: self.database.dimensions,
            distance_metric: self.database.distance_metric,
            storage_path: self.database.storage_path.clone(),
            hnsw_config: self.database.hnsw.clone(),
            quantization: self.database.quantization.clone(),
        }
    }

    /// Save configuration to file
    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self).context("Failed to serialize config")?;
        std::fs::write(path, content).context("Failed to write config file")?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.database.dimensions, 384);
        assert_eq!(config.database.storage_path, "./ruvector.db");
        assert_eq!(config.cli.batch_size, 1000);
        assert!(config.cli.progress);
        assert!(config.cli.colors);
        assert_eq!(config.mcp.host, "127.0.0.1");
        assert_eq!(config.mcp.port, 3000);
        assert!(config.mcp.cors);
    }

    /// Serializing and re-parsing must preserve every field, not just one.
    #[test]
    fn test_config_serialization() {
        let config = Config::default();
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: Config = toml::from_str(&toml_str).unwrap();

        assert_eq!(config.database.dimensions, parsed.database.dimensions);
        assert_eq!(config.database.storage_path, parsed.database.storage_path);
        assert_eq!(
            config.database.hnsw.is_some(),
            parsed.database.hnsw.is_some()
        );
        assert_eq!(
            config.database.quantization.is_some(),
            parsed.database.quantization.is_some()
        );
        assert_eq!(config.cli.progress, parsed.cli.progress);
        assert_eq!(config.cli.colors, parsed.cli.colors);
        assert_eq!(config.cli.batch_size, parsed.cli.batch_size);
        assert_eq!(config.mcp.host, parsed.mcp.host);
        assert_eq!(config.mcp.port, parsed.mcp.port);
        assert_eq!(config.mcp.cors, parsed.mcp.cors);
        assert_eq!(config.mcp.data_dir, parsed.mcp.data_dir);
    }

    /// An empty document must deserialize to the per-section defaults.
    #[test]
    fn test_empty_toml_uses_defaults() {
        let parsed: Config = toml::from_str("").unwrap();
        assert_eq!(parsed.database.dimensions, 384);
        assert_eq!(parsed.database.storage_path, "./ruvector.db");
        assert_eq!(parsed.cli.batch_size, 1000);
        assert_eq!(parsed.mcp.port, 3000);
    }

    /// Keys missing from a present table are filled from the serde fallbacks.
    #[test]
    fn test_partial_section_fills_missing_keys() {
        let parsed: Config = toml::from_str("[mcp]\nport = 8080\n").unwrap();
        assert_eq!(parsed.mcp.port, 8080);
        assert_eq!(parsed.mcp.host, "127.0.0.1");
        assert!(parsed.mcp.cors);
        assert_eq!(parsed.database.dimensions, 384);
        assert_eq!(parsed.cli.batch_size, 1000);
    }
}