shaha 0.2.0

Hash database builder and reverse lookup tool
Documentation
use std::path::PathBuf;

use anyhow::Result;
use serde::Deserialize;

use crate::storage::R2Config;

#[derive(Debug, Default, Deserialize)]
pub struct Config {
    #[serde(default)]
    pub storage: StorageSection,
    #[serde(default)]
    pub defaults: DefaultsSection,
}

#[derive(Debug, Default, Deserialize)]
pub struct StorageSection {
    #[serde(default)]
    pub r2: R2Section,
}

#[derive(Debug, Default, Deserialize)]
pub struct R2Section {
    pub endpoint: Option<String>,
    pub bucket: Option<String>,
    pub access_key_id: Option<String>,
    pub secret_access_key: Option<String>,
    pub region: Option<String>,
    pub path: Option<String>,
}

#[derive(Debug, Default, Deserialize)]
pub struct DefaultsSection {
    pub algorithms: Option<Vec<String>>,
    pub output: Option<String>,
}

#[derive(Default)]
pub struct R2Overrides<'a> {
    pub endpoint: Option<&'a str>,
    pub bucket: Option<&'a str>,
    pub access_key_id: Option<&'a str>,
    pub secret_access_key: Option<&'a str>,
    pub path: Option<&'a str>,
    pub region: &'a str,
    pub default_path: &'a str,
}

impl<'a> R2Overrides<'a> {
    pub fn new(region: &'a str, default_path: &'a str) -> Self {
        Self {
            region,
            default_path,
            ..Default::default()
        }
    }
}

impl Config {
    pub fn load() -> Result<Self> {
        let paths = config_paths();
        
        for path in paths {
            if path.exists() {
                let content = std::fs::read_to_string(&path)?;
                let config: Config = toml::from_str(&content)?;
                return Ok(config);
            }
        }
        
        Ok(Config::default())
    }

    pub fn to_r2_config(&self) -> Option<R2Config> {
        let r2 = &self.storage.r2;
        
        let endpoint = r2.endpoint.clone()
            .or_else(|| std::env::var("SHAHA_R2_ENDPOINT").ok())?;
        let bucket = r2.bucket.clone()
            .or_else(|| std::env::var("SHAHA_R2_BUCKET").ok())?;
        let access_key_id = r2.access_key_id.clone()
            .or_else(|| std::env::var("SHAHA_R2_ACCESS_KEY_ID").ok())
            .or_else(|| std::env::var("AWS_ACCESS_KEY_ID").ok())?;
        let secret_access_key = r2.secret_access_key.clone()
            .or_else(|| std::env::var("SHAHA_R2_SECRET_ACCESS_KEY").ok())
            .or_else(|| std::env::var("AWS_SECRET_ACCESS_KEY").ok())?;
        let path = r2.path.clone().unwrap_or_else(|| "hashes.parquet".to_string());
        
        let mut config = R2Config::new(endpoint, access_key_id, secret_access_key, bucket, path);
        if let Some(ref region) = r2.region {
            config.region = region.clone();
        }
        
        Some(config)
    }

    pub fn build_r2_config(&self, overrides: R2Overrides) -> Result<R2Config> {
        let r2 = &self.storage.r2;

        let endpoint = overrides.endpoint.map(String::from)
            .or_else(|| std::env::var("SHAHA_R2_ENDPOINT").ok())
            .or_else(|| r2.endpoint.clone())
            .ok_or_else(|| anyhow::anyhow!(
                "R2 endpoint required: use --endpoint, SHAHA_R2_ENDPOINT env var, or config file"
            ))?;

        let bucket = overrides.bucket.map(String::from)
            .or_else(|| std::env::var("SHAHA_R2_BUCKET").ok())
            .or_else(|| r2.bucket.clone())
            .ok_or_else(|| anyhow::anyhow!(
                "R2 bucket required: use --bucket, SHAHA_R2_BUCKET env var, or config file"
            ))?;

        let access_key_id = overrides.access_key_id.map(String::from)
            .or_else(|| std::env::var("SHAHA_R2_ACCESS_KEY_ID").ok())
            .or_else(|| std::env::var("AWS_ACCESS_KEY_ID").ok())
            .or_else(|| r2.access_key_id.clone())
            .ok_or_else(|| anyhow::anyhow!(
                "R2 access key required: use --access-key-id, env var, or config file"
            ))?;

        let secret_access_key = overrides.secret_access_key.map(String::from)
            .or_else(|| std::env::var("SHAHA_R2_SECRET_ACCESS_KEY").ok())
            .or_else(|| std::env::var("AWS_SECRET_ACCESS_KEY").ok())
            .or_else(|| r2.secret_access_key.clone())
            .ok_or_else(|| anyhow::anyhow!(
                "R2 secret key required: use --secret-access-key, env var, or config file"
            ))?;

        let path = overrides.path.map(String::from)
            .or_else(|| r2.path.clone())
            .unwrap_or_else(|| overrides.default_path.to_string());

        let region = if overrides.region != "auto" {
            overrides.region.to_string()
        } else {
            r2.region.clone().unwrap_or_else(|| "auto".to_string())
        };

        let mut config = R2Config::new(endpoint, access_key_id, secret_access_key, bucket, path);
        config.region = region;

        Ok(config)
    }
}

fn config_paths() -> Vec<PathBuf> {
    let mut paths = Vec::new();
    
    if let Ok(cwd) = std::env::current_dir() {
        paths.push(cwd.join(".shaha.toml"));
    }
    
    if let Some(config_dir) = dirs::config_dir() {
        paths.push(config_dir.join("shaha").join("config.toml"));
    }
    
    paths
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_empty_config() {
        let config: Config = toml::from_str("").unwrap();
        assert!(config.storage.r2.endpoint.is_none());
        assert!(config.defaults.algorithms.is_none());
    }

    #[test]
    fn test_parse_r2_config() {
        let toml = r#"
[storage.r2]
endpoint = "https://example.r2.cloudflarestorage.com"
bucket = "my-bucket"
access_key_id = "key123"
secret_access_key = "secret456"
region = "auto"
path = "hashes.parquet"

[defaults]
algorithms = ["sha256", "md5"]
output = "custom.parquet"
"#;
        let config: Config = toml::from_str(toml).unwrap();
        
        assert_eq!(config.storage.r2.endpoint, Some("https://example.r2.cloudflarestorage.com".to_string()));
        assert_eq!(config.storage.r2.bucket, Some("my-bucket".to_string()));
        assert_eq!(config.defaults.algorithms, Some(vec!["sha256".to_string(), "md5".to_string()]));
    }

    #[test]
    fn test_to_r2_config_complete() {
        let toml = r#"
[storage.r2]
endpoint = "https://example.r2.cloudflarestorage.com"
bucket = "my-bucket"
access_key_id = "key123"
secret_access_key = "secret456"
"#;
        let config: Config = toml::from_str(toml).unwrap();
        let r2_config = config.to_r2_config().unwrap();
        
        assert_eq!(r2_config.s3_url(), "s3://my-bucket/hashes.parquet");
    }

    #[test]
    fn test_to_r2_config_incomplete() {
        let toml = r#"
[storage.r2]
endpoint = "https://example.r2.cloudflarestorage.com"
"#;
        let config: Config = toml::from_str(toml).unwrap();
        assert!(config.to_r2_config().is_none());
    }
}