brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
//! HuggingFace Hub weight resolution and download.
//!
//! Resolves model weights in priority order:
//! 1. Explicit paths (`--weights`, `--gradient` and `--geoh`, all three together)
//! 2. Scan `~/.cache/huggingface/hub/` for existing snapshots
//! 3. Download via hf-hub (requires feature `hf-download`)
use std::path::{Path, PathBuf};

/// Default HuggingFace repo for Brain-Harmony weights.
pub const DEFAULT_REPO: &str = "eugenehp/BrainHarmony";

/// File name of the safetensors weights file, as looked up inside a cached
/// snapshot directory and as requested from the Hub repo.
const WEIGHTS_FILE: &str = "brainharmony.safetensors";
/// File name of the gradient-mapping CSV that must sit next to the weights.
const GRADIENT_FILE: &str = "gradient_mapping_400.csv";
/// File name of the ROI eigenmodes CSV (the "geoh" input) that must sit next
/// to the weights.
const GEOH_FILE: &str = "schaefer400_roi_eigenmodes.csv";

/// Resolved weight paths.
///
/// Groups the locations of the three files the model needs. The paths are
/// stored as given or as found; nothing here checks that they exist.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedWeights {
    /// Location of the model weights file.
    pub weights_path: PathBuf,
    /// Location of the gradient-mapping file.
    pub gradient_path: PathBuf,
    /// Location of the geoh (eigenmodes) file.
    pub geoh_path: PathBuf,
}

/// Scan the HuggingFace cache directory for an existing snapshot.
///
/// Looks in `<cache_root>/models--<repo with '/' replaced by "--">/snapshots/`,
/// where `cache_root` is `hf_cache` when given and otherwise
/// `.cache/huggingface/hub` under the directory returned by `dirs_fallback`.
///
/// Snapshot directories are tried newest-first by modification time
/// (directories whose mtime cannot be read are tried last). The first one that
/// contains the weights, gradient and geoh files is returned.
///
/// Returns `None` when no cache root can be determined, when the snapshots
/// directory is missing or unreadable, or when no snapshot holds all three
/// files.
pub fn scan_cache(
    repo: &str,
    hf_cache: Option<&Path>,
) -> Option<ResolvedWeights> {
    let cache_root = match hf_cache {
        Some(dir) => dir.to_path_buf(),
        None => dirs_fallback()?.join(".cache/huggingface/hub"),
    };

    let snapshots_dir = cache_root
        .join(format!("models--{}", repo.replace('/', "--")))
        .join("snapshots");

    // A missing or unreadable directory simply means "nothing cached", so the
    // `read_dir` error doubles as the existence check.
    let mut entries: Vec<_> = std::fs::read_dir(&snapshots_dir)
        .ok()?
        .filter_map(Result::ok)
        .filter(|entry| entry.path().is_dir())
        .collect();

    // Newest first. The key requires a filesystem stat, so it is computed once
    // per entry: this avoids repeated syscalls during the sort and keeps the
    // ordering consistent even if an mtime changes while sorting.
    entries.sort_by_cached_key(|entry| {
        std::cmp::Reverse(entry.metadata().ok().and_then(|meta| meta.modified().ok()))
    });

    entries.into_iter().find_map(|entry| {
        let dir = entry.path();
        let candidate = ResolvedWeights {
            weights_path: dir.join(WEIGHTS_FILE),
            gradient_path: dir.join(GRADIENT_FILE),
            geoh_path: dir.join(GEOH_FILE),
        };
        // Only a snapshot holding all three files is usable.
        let complete = candidate.weights_path.exists()
            && candidate.gradient_path.exists()
            && candidate.geoh_path.exists();
        complete.then_some(candidate)
    })
}

/// Download weights from HuggingFace Hub.
///
/// Requests the weights, gradient and geoh files from `repo` through
/// `hf-hub`'s blocking API, in that order, printing a progress line to stdout
/// before each request. When `hf_cache` is given it is passed to the API
/// builder as the cache directory; otherwise the builder's default is used.
///
/// # Errors
///
/// Returns an error if the API client cannot be built or if any of the three
/// files cannot be fetched.
#[cfg(feature = "hf-download")]
pub fn download(
    repo: &str,
    hf_cache: Option<&Path>,
) -> anyhow::Result<ResolvedWeights> {
    use hf_hub::api::sync::ApiBuilder;

    let mut builder = ApiBuilder::new();
    if let Some(cache) = hf_cache {
        builder = builder.with_cache_dir(cache.to_path_buf());
    }
    // Keep the handle under its own name so `repo` still refers to the repo id.
    let api_repo = builder.build()?.model(repo.to_string());

    // Log the repo id itself rather than a `Debug` dump of the API handle,
    // which is noisy and exposes the client's internal state on stdout.
    let fetch = |file: &str| -> anyhow::Result<PathBuf> {
        println!("Downloading {file} from {repo} ...");
        Ok(api_repo.get(file)?)
    };

    // Struct fields are evaluated in source order, so the files are fetched
    // as weights, then gradient, then geoh.
    Ok(ResolvedWeights {
        weights_path: fetch(WEIGHTS_FILE)?,
        gradient_path: fetch(GRADIENT_FILE)?,
        geoh_path: fetch(GEOH_FILE)?,
    })
}

/// Stand-in for the Hub download when the `hf-download` feature is disabled.
///
/// Performs no I/O and always fails.
///
/// # Errors
///
/// Always returns an error explaining that the feature is required and
/// pointing at the Hub page of the requested `repo` for a manual download.
#[cfg(not(feature = "hf-download"))]
pub fn download(
    repo: &str,
    _hf_cache: Option<&Path>,
) -> anyhow::Result<ResolvedWeights> {
    // Point at the repo the caller asked for, not unconditionally at the
    // default one.
    anyhow::bail!(
        "HuggingFace download requires --features hf-download.\n\
         Alternatively, download manually from https://huggingface.co/{repo}"
    )
}

/// Resolve weights: try explicit paths, then cache, then download.
///
/// 1. If `weights`, `gradient` and `geoh` are all given, they are returned
///    as-is (no existence check is made here).
/// 2. Otherwise the local cache is searched via `scan_cache`.
/// 3. Otherwise the files are fetched via `download`.
///
/// Explicit paths are only honoured as a complete set. If just one or two are
/// given they are ignored, and a warning is printed to stderr so this does not
/// happen silently.
///
/// # Errors
///
/// Returns whatever error `download` reports when neither explicit paths nor
/// a cached snapshot are available.
pub fn resolve(
    repo: &str,
    weights: Option<&str>,
    gradient: Option<&str>,
    geoh: Option<&str>,
    hf_cache: Option<&Path>,
) -> anyhow::Result<ResolvedWeights> {
    match (weights, gradient, geoh) {
        (Some(w), Some(g), Some(gh)) => {
            return Ok(ResolvedWeights {
                weights_path: PathBuf::from(w),
                gradient_path: PathBuf::from(g),
                geoh_path: PathBuf::from(gh),
            });
        }
        (None, None, None) => {}
        // One or two of the three paths: fall through as before, but say so.
        _ => eprintln!(
            "warning: explicit weights, gradient and geoh paths must all be given together; \
             ignoring the partial set and falling back to cache/download"
        ),
    }

    if let Some(resolved) = scan_cache(repo, hf_cache) {
        println!(
            "Found cached weights: {}",
            resolved.weights_path.display()
        );
        return Ok(resolved);
    }

    download(repo, hf_cache)
}

/// Best-effort home-directory lookup from environment variables.
///
/// Returns the value of `HOME` when it is set and non-empty, otherwise the
/// value of `USERPROFILE` (commonly set on Windows, where `HOME` usually is
/// not), otherwise `None`.
///
/// Empty values are skipped: joining a relative cache path onto an empty
/// home would silently produce a path relative to the current directory.
fn dirs_fallback() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|name| std::env::var_os(name))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}