use std::path::{Path, PathBuf};
/// Hugging Face repository the model files are fetched from by default.
pub const DEFAULT_REPO: &str = "eugenehp/BrainHarmony";
/// Model weights file name inside the repository.
const WEIGHTS_FILE: &str = "brainharmony.safetensors";
/// Gradient-mapping CSV file name inside the repository.
const GRADIENT_FILE: &str = "gradient_mapping_400.csv";
/// ROI eigenmode CSV file name inside the repository.
const GEOH_FILE: &str = "schaefer400_roi_eigenmodes.csv";

/// Local filesystem locations of the three files needed to load the model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedWeights {
    /// Path to the weights file (`brainharmony.safetensors` when resolved from the hub).
    pub weights_path: PathBuf,
    /// Path to the gradient-mapping CSV.
    pub gradient_path: PathBuf,
    /// Path to the ROI eigenmode CSV.
    pub geoh_path: PathBuf,
}
pub fn scan_cache(
repo: &str,
hf_cache: Option<&Path>,
) -> Option<ResolvedWeights> {
let cache_root = hf_cache
.map(|p| p.to_path_buf())
.or_else(|| {
dirs_fallback().map(|home| home.join(".cache/huggingface/hub"))
})?;
let repo_dir_name = format!("models--{}", repo.replace('/', "--"));
let snapshots_dir = cache_root.join(&repo_dir_name).join("snapshots");
if !snapshots_dir.exists() {
return None;
}
let mut entries: Vec<_> = std::fs::read_dir(&snapshots_dir)
.ok()?
.filter_map(|e| e.ok())
.filter(|e| e.path().is_dir())
.collect();
entries.sort_by_key(|e| std::cmp::Reverse(e.metadata().ok().and_then(|m| m.modified().ok())));
for entry in entries {
let dir = entry.path();
let weights = dir.join(WEIGHTS_FILE);
let gradient = dir.join(GRADIENT_FILE);
let geoh = dir.join(GEOH_FILE);
if weights.exists() && gradient.exists() && geoh.exists() {
return Some(ResolvedWeights {
weights_path: weights,
gradient_path: gradient,
geoh_path: geoh,
});
}
}
None
}
/// Downloads the three model files from the Hugging Face hub (or reuses
/// hf-hub's own cached copies) and returns their local paths.
///
/// When `hf_cache` is given it is used as the hub cache directory. Otherwise
/// hf-hub's default cache location is used.
///
/// # Errors
/// Fails if the hub client cannot be built or any of the three files cannot be
/// fetched. The error names the file and repo that failed.
#[cfg(feature = "hf-download")]
pub fn download(
    repo: &str,
    hf_cache: Option<&Path>,
) -> anyhow::Result<ResolvedWeights> {
    use anyhow::Context;
    use hf_hub::api::sync::ApiBuilder;

    let mut builder = ApiBuilder::new();
    if let Some(cache) = hf_cache {
        builder = builder.with_cache_dir(cache.to_path_buf());
    }
    let api = builder
        .build()
        .context("failed to initialise the Hugging Face hub client")?;
    // Keep the plain repo id (`repo`) separate from the API handle. Messages
    // print the id, not the handle's Debug form, which exposes client
    // internals (NOTE(review): possibly including auth headers).
    let api_repo = api.model(repo.to_string());

    let fetch = |file: &str| -> anyhow::Result<PathBuf> {
        println!("Downloading {file} from {repo} ...");
        api_repo
            .get(file)
            .with_context(|| format!("failed to download {file} from {repo}"))
    };

    // Struct-literal fields are evaluated in source order, so downloads happen
    // weights -> gradient -> geoh and stop at the first failure.
    Ok(ResolvedWeights {
        weights_path: fetch(WEIGHTS_FILE)?,
        gradient_path: fetch(GRADIENT_FILE)?,
        geoh_path: fetch(GEOH_FILE)?,
    })
}
/// Stub used when the crate is built without the `hf-download` feature.
///
/// # Errors
/// Always fails. The message explains how to enable downloading and links to
/// the requested repository for a manual download.
#[cfg(not(feature = "hf-download"))]
pub fn download(
    repo: &str,
    _hf_cache: Option<&Path>,
) -> anyhow::Result<ResolvedWeights> {
    // Point at the repo the caller actually asked for, not the crate default.
    anyhow::bail!(
        "HuggingFace download requires --features hf-download.\n\
         Alternatively, download manually from https://huggingface.co/{repo}"
    )
}
/// Resolves the local paths of the three model files.
///
/// Precedence:
/// 1. If `weights`, `gradient` and `geoh` are all given, they are used as-is;
///    no cache lookup or download happens.
/// 2. Otherwise the full set comes from [`scan_cache`], or from [`download`]
///    when nothing complete is cached. Any paths that *were* supplied then
///    replace the matching entries, so a partial set of explicit paths is
///    never silently ignored.
///
/// Explicit paths are not checked for existence here.
///
/// # Errors
/// Propagates the error from [`download`] when the files are neither fully
/// supplied nor found in the cache.
pub fn resolve(
    repo: &str,
    weights: Option<&str>,
    gradient: Option<&str>,
    geoh: Option<&str>,
    hf_cache: Option<&Path>,
) -> anyhow::Result<ResolvedWeights> {
    if let (Some(w), Some(g), Some(gh)) = (weights, gradient, geoh) {
        return Ok(ResolvedWeights {
            weights_path: PathBuf::from(w),
            gradient_path: PathBuf::from(g),
            geoh_path: PathBuf::from(gh),
        });
    }

    let mut resolved = match scan_cache(repo, hf_cache) {
        Some(cached) => {
            println!("Found cached weights: {}", cached.weights_path.display());
            cached
        }
        None => download(repo, hf_cache)?,
    };

    // User-supplied paths take priority over cached/downloaded ones.
    if let Some(w) = weights {
        resolved.weights_path = PathBuf::from(w);
    }
    if let Some(g) = gradient {
        resolved.gradient_path = PathBuf::from(g);
    }
    if let Some(gh) = geoh {
        resolved.geoh_path = PathBuf::from(gh);
    }
    Ok(resolved)
}
/// Returns the current user's home directory, read from the environment.
///
/// Tries `HOME` first, then `USERPROFILE`, which covers Windows where `HOME`
/// is usually unset. Empty values are treated as unset. Returns `None` when
/// neither variable yields a non-empty value.
fn dirs_fallback() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|var| std::env::var_os(var))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}