use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use std::time::Duration;
/// Base URL used for `Source::HuggingFace` and for the reachability probe.
const HF_BASE: &str = "https://huggingface.co";
/// Base URL used for `Source::HfMirror`, the fallback when the probe fails.
const HF_MIRROR_BASE: &str = "https://hf-mirror.com";
/// Timeout handed to `build_client` for the probe request; `build_client`
/// applies it as both the connect timeout and the overall request timeout.
const PROBE_TIMEOUT: Duration = Duration::from_secs(3);
/// Timeout handed to `build_client` for the actual file downloads.
const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(300);
/// File whose `resolve/main` URL is requested with `HEAD` by the probe.
const PROBE_FILE: &str = "config.json";
/// Reference `Content-Length` (in bytes) for the probe response.
/// NOTE(review): presumably the approximate size of the real `config.json` — confirm.
const PROBE_REFERENCE_SIZE: u64 = 600;
/// Relative slack on the reference size: the probe rejects responses larger
/// than `PROBE_REFERENCE_SIZE * (1.0 + PROBE_SIZE_TOLERANCE)` (3600 bytes).
const PROBE_SIZE_TOLERANCE: f64 = 5.0;
/// Describes a downloadable model: its repository id and the files to fetch.
#[derive(Debug, Clone)]
pub struct ModelInfo {
/// Repository id (e.g. `minishlab/potion-code-16M`). It is interpolated into
/// download URLs, and its last `/`-separated segment names the cache directory.
pub id: String,
/// Names of the files that are downloaded and hash-checked for this model.
pub files: Vec<&'static str>,
}
impl ModelInfo {
    /// Returns the descriptor for the `minishlab/potion-code-16M` model,
    /// which consists of a config, a tokenizer and a safetensors weight file.
    pub fn potion_code_16m() -> Self {
        let id = String::from("minishlab/potion-code-16M");
        let files = Vec::from(["config.json", "tokenizer.json", "model.safetensors"]);
        Self { id, files }
    }
}
/// Where model files are downloaded from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Source<'a> {
/// The Hugging Face hub (`HF_BASE`).
HuggingFace,
/// The hf-mirror.com mirror (`HF_MIRROR_BASE`).
HfMirror,
/// A user-supplied base URL, borrowed for `'a`.
Custom(&'a str),
}
impl<'a> Source<'a> {
    /// Base URL that download paths are appended to for this source.
    fn base_url(self) -> &'a str {
        match self {
            Self::Custom(custom) => custom,
            Self::HfMirror => HF_MIRROR_BASE,
            Self::HuggingFace => HF_BASE,
        }
    }

    /// Short identifier for this source; used in log output and stored in
    /// the manifest's `source` field.
    fn label(self) -> &'static str {
        match self {
            Self::Custom(..) => "custom",
            Self::HfMirror => "hf-mirror",
            Self::HuggingFace => "hf",
        }
    }
}
/// Record of a completed download, serialized as JSON to `manifest.json`
/// inside the model directory.
#[derive(Debug, Serialize, Deserialize)]
struct Manifest {
/// Maps each downloaded file name to the lowercase hex SHA-256 of its bytes.
sha256: HashMap<String, String>,
/// Label of the source the files were downloaded from (see `Source::label`).
source: String,
}
/// Returns the directory under which all model directories live.
///
/// Resolution order:
/// 1. `AST_BRO_MODEL_DIR`, then the legacy `AST_OUTLINE_MODEL_DIR`; the first
///    one that is set to a non-empty value is returned verbatim.
/// 2. Otherwise `<user cache dir>/ast-bro/models`.
///
/// In case 2, a legacy `<user cache dir>/ast-outline` directory is renamed to
/// `ast-bro` if the latter does not exist yet. The rename is attempted at most
/// once per process and a failure only produces a warning.
///
/// # Errors
///
/// Returns `io::ErrorKind::NotFound` when no override is set and the platform
/// cache directory cannot be determined.
pub fn cache_root() -> io::Result<PathBuf> {
    // `var_os` instead of `var`: a directory path need not be valid UTF-8, and
    // `var` would report such a value as an error, silently dropping the
    // override. Empty values are treated as unset so that an accidental
    // `AST_BRO_MODEL_DIR=` does not redirect the cache to the working directory.
    let override_dir = ["AST_BRO_MODEL_DIR", "AST_OUTLINE_MODEL_DIR"]
        .iter()
        .filter_map(|name| std::env::var_os(name))
        .find(|value| !value.is_empty());
    if let Some(custom) = override_dir {
        return Ok(PathBuf::from(custom));
    }

    let base = dirs::cache_dir().ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::NotFound,
            "no cache directory found (set AST_BRO_MODEL_DIR)",
        )
    })?;
    let new_dir = base.join("ast-bro");
    let old_dir = base.join("ast-outline");

    // One-shot migration of the legacy cache directory.
    static MIGRATED: OnceLock<()> = OnceLock::new();
    MIGRATED.get_or_init(|| {
        if old_dir.exists() && !new_dir.exists() {
            if let Err(e) = std::fs::rename(&old_dir, &new_dir) {
                eprintln!(
                    "warning: could not rename {} -> {}: {e}",
                    old_dir.display(),
                    new_dir.display()
                );
            } else {
                eprintln!(
                    "info: auto-renamed {} -> {}",
                    old_dir.display(),
                    new_dir.display()
                );
            }
        }
    });

    Ok(new_dir.join("models"))
}
/// Returns the cache directory for `info`: the cache root joined with the
/// last `/`-separated segment of the model id.
///
/// # Errors
///
/// Propagates any error from [`cache_root`].
pub fn model_dir(info: &ModelInfo) -> io::Result<PathBuf> {
    let root = cache_root()?;
    let leaf = match info.id.rsplit('/').next() {
        Some(segment) => segment,
        None => info.id.as_str(),
    };
    Ok(root.join(leaf))
}
/// Makes sure every file of `info` is present in the local cache and returns
/// the model directory.
///
/// If the cached files match the hashes recorded in the manifest, nothing is
/// downloaded. Otherwise all files are fetched from the selected source, in
/// the order listed in `info.files`, and a fresh manifest is written.
///
/// # Errors
///
/// Returns the first I/O or HTTP error encountered; the manifest is only
/// written after every file downloaded successfully.
pub fn ensure_model(info: &ModelInfo) -> io::Result<PathBuf> {
    let target = model_dir(info)?;
    fs::create_dir_all(&target)?;
    if cache_is_valid(&target, info)? {
        return Ok(target);
    }

    warn_about_tls_policy();
    let origin = select_source(info);
    eprintln!(
        "ast-bro: downloading model {} via {} ({} files)",
        info.id,
        origin.label(),
        info.files.len()
    );

    let http = build_client(DOWNLOAD_TIMEOUT)?;
    // Collecting into `io::Result<_>` stops at the first failed download.
    let digests = info
        .files
        .iter()
        .map(|name| -> io::Result<(String, String)> {
            let url = format!("{}/{}/resolve/main/{}", origin.base_url(), info.id, name);
            let hash = download_to(&http, &url, &target.join(name))?;
            Ok((name.to_string(), hash))
        })
        .collect::<io::Result<HashMap<String, String>>>()?;

    write_manifest(
        &target,
        &Manifest {
            sha256: digests,
            source: origin.label().to_string(),
        },
    )?;
    Ok(target)
}
/// Chooses where to download the model from.
///
/// `AST_BRO_MODEL_SOURCE` (or the legacy `AST_OUTLINE_MODEL_SOURCE`) forces the
/// choice: `hf`, `hf-mirror`, or an `http://` / `https://` base URL. Any other
/// value is reported on stderr and treated as `hf`. Without an override,
/// Hugging Face is probed and the mirror is used if the probe fails.
fn select_source<'a>(_info: &'a ModelInfo) -> Source<'a> {
    let forced = std::env::var("AST_BRO_MODEL_SOURCE")
        .or_else(|_| std::env::var("AST_OUTLINE_MODEL_SOURCE"));
    if let Ok(forced) = forced {
        return match forced.as_str() {
            "hf" => Source::HuggingFace,
            "hf-mirror" => Source::HfMirror,
            url if url.starts_with("http://") || url.starts_with("https://") => {
                // Download URLs are built as `{base}/{id}/resolve/main/{file}`,
                // so a trailing slash on the base would yield `//` in the path.
                let base = url.trim_end_matches('/');
                // The env value is owned by this function, so it is leaked to
                // obtain a reference that can outlive it. The allocation is
                // never freed (one small string per call taking this branch).
                Source::Custom(Box::leak(base.to_owned().into_boxed_str()))
            }
            other => {
                eprintln!(
                    "ast-bro: ignoring AST_BRO_MODEL_SOURCE={other:?} (use hf, hf-mirror, or a URL)"
                );
                Source::HuggingFace
            }
        };
    }

    if probe_huggingface(&_info.id) {
        Source::HuggingFace
    } else {
        eprintln!("ast-bro: HuggingFace unreachable, falling back to hf-mirror.com");
        Source::HfMirror
    }
}
/// Checks whether Hugging Face looks reachable for `model_id`.
///
/// Sends a `HEAD` request for the probe file and returns `false` when the
/// client cannot be built, the request fails, the status is not a success, or
/// the reported `Content-Length` exceeds the plausible upper bound. A missing
/// `Content-Length` is accepted.
fn probe_huggingface(model_id: &str) -> bool {
    let client = match build_client(PROBE_TIMEOUT) {
        Ok(client) => client,
        Err(_) => return false,
    };

    let probe_url = format!("{HF_BASE}/{model_id}/resolve/main/{PROBE_FILE}");
    let response = match client.head(&probe_url).send() {
        Ok(response) if response.status().is_success() => response,
        _ => return false,
    };

    // Largest Content-Length still considered a genuine probe-file response.
    let hi = (PROBE_REFERENCE_SIZE as f64 * (1.0 + PROBE_SIZE_TOLERANCE)) as u64;
    match response.content_length() {
        Some(len) if len > hi => {
            eprintln!(
                "ast-bro: HF probe returned implausibly large content-length {len} (expected ≤{hi}); likely a captive portal, falling back"
            );
            false
        }
        _ => true,
    }
}
/// Builds the blocking HTTP client used for probing and downloading.
///
/// `timeout` is applied both as connect timeout and as overall request
/// timeout. If `AST_BRO_CA_BUNDLE` (or the legacy `AST_OUTLINE_CA_BUNDLE`) is
/// set, every certificate in that PEM bundle is added as a trusted root.
/// Unless strict TLS is requested (see `tls_strict`), certificate validation
/// is turned off.
///
/// # Errors
///
/// Fails when the CA bundle cannot be read or parsed, or when the client
/// cannot be constructed.
fn build_client(timeout: Duration) -> io::Result<reqwest::blocking::Client> {
    let mut builder = reqwest::blocking::Client::builder()
        .connect_timeout(timeout)
        .timeout(timeout)
        .user_agent(concat!("ast-bro/", env!("CARGO_PKG_VERSION")));

    let bundle_path = std::env::var("AST_BRO_CA_BUNDLE")
        .or_else(|_| std::env::var("AST_OUTLINE_CA_BUNDLE"));
    if let Ok(bundle) = bundle_path {
        let pem = match fs::read(&bundle) {
            Ok(bytes) => bytes,
            Err(e) => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!("AST_BRO_CA_BUNDLE={bundle}: {e}"),
                ));
            }
        };
        let certs = match reqwest::Certificate::from_pem_bundle(&pem) {
            Ok(certs) => certs,
            Err(e) => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!("invalid CA bundle: {e}"),
                ));
            }
        };
        builder = certs
            .into_iter()
            .fold(builder, |acc, cert| acc.add_root_certificate(cert));
    }

    if !tls_strict() {
        builder = builder.danger_accept_invalid_certs(true);
    }

    match builder.build() {
        Ok(client) => Ok(client),
        Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
    }
}
/// Reports whether strict TLS verification was requested.
///
/// Reads `AST_BRO_TLS_STRICT`, falling back to the legacy
/// `AST_OUTLINE_TLS_STRICT` when the former cannot be read. Strict mode is on
/// for any value other than the empty string, `0`, or `false` (ASCII
/// case-insensitive); it is off when neither variable is readable.
fn tls_strict() -> bool {
    let raw = match std::env::var("AST_BRO_TLS_STRICT") {
        Ok(value) => value,
        Err(_) => match std::env::var("AST_OUTLINE_TLS_STRICT") {
            Ok(value) => value,
            Err(_) => return false,
        },
    };
    let disabled = raw.is_empty() || raw == "0" || raw.eq_ignore_ascii_case("false");
    !disabled
}
/// Prints a one-line notice on stderr when TLS certificate verification is
/// disabled, i.e. when strict mode has not been requested.
fn warn_about_tls_policy() {
    if tls_strict() {
        return;
    }
    eprintln!(
        "ast-bro: TLS certificate verification is DISABLED for model downloads \
         (works through corp MITM proxies). Set AST_BRO_TLS_STRICT=1 to enforce \
         full chain verification. Integrity is checked via SHA-256 on subsequent loads."
    );
}
/// Downloads `url` to `dest` and returns the lowercase hex SHA-256 of the
/// bytes written.
///
/// The body is streamed into a sibling temporary file (`dest` with its
/// extension replaced by `tmp`), synced to disk, and then renamed over `dest`,
/// so `dest` is never observed half-written. If anything fails after the
/// temporary file was created, that file is removed again instead of being
/// left behind.
///
/// # Errors
///
/// Returns an `io::ErrorKind::Other` error when the request fails (the message
/// includes the full error source chain) or the server answers with a
/// non-success status, and any I/O error from writing, syncing or renaming.
fn download_to(client: &reqwest::blocking::Client, url: &str, dest: &Path) -> io::Result<String> {
    let resp = client.get(url).send().map_err(|e| {
        // reqwest's top-level message is often terse; append every underlying
        // cause so TLS/DNS/proxy failures are visible to the user.
        let mut msg = format!("GET {url}: {e}");
        let mut src: Option<&dyn std::error::Error> = std::error::Error::source(&e);
        while let Some(s) = src {
            msg.push_str(&format!(" → {s}"));
            src = s.source();
        }
        io::Error::new(io::ErrorKind::Other, msg)
    })?;
    if !resp.status().is_success() {
        return Err(io::Error::new(
            io::ErrorKind::Other,
            format!("GET {url} returned HTTP {}", resp.status()),
        ));
    }

    let tmp = dest.with_extension("tmp");
    let outcome = stream_to_file(resp, &tmp).and_then(|hash| {
        fs::rename(&tmp, dest)?;
        Ok(hash)
    });
    if outcome.is_err() {
        // Best-effort cleanup of the partial download; the original error is
        // what the caller needs to see, so a failure here is ignored.
        let _ = fs::remove_file(&tmp);
    }
    outcome
}

/// Copies everything from `reader` into a newly created file at `path`,
/// syncs it to disk, and returns the hex SHA-256 of the copied bytes.
fn stream_to_file(mut reader: impl Read, path: &Path) -> io::Result<String> {
    let mut file = fs::File::create(path)?;
    let mut hasher = Sha256::new();
    let mut buf = [0u8; 64 * 1024];
    loop {
        let n = reader
            .read(&mut buf)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
        if n == 0 {
            break;
        }
        hasher.update(&buf[..n]);
        file.write_all(&buf[..n])?;
    }
    file.sync_all()?;
    Ok(hex_digest(hasher.finalize()))
}
/// Encodes `bytes` as a lowercase hexadecimal string, two characters per byte.
fn hex_digest<T: AsRef<[u8]>>(bytes: T) -> String {
    const HEX: &[u8; 16] = b"0123456789abcdef";
    let raw = bytes.as_ref();
    let mut out = String::with_capacity(raw.len() * 2);
    for &byte in raw {
        out.push(HEX[usize::from(byte >> 4)] as char);
        out.push(HEX[usize::from(byte & 0x0f)] as char);
    }
    out
}
/// Location of the manifest file inside the model directory `dir`.
fn manifest_path(dir: &Path) -> PathBuf {
    let mut path = dir.to_path_buf();
    path.push("manifest.json");
    path
}
/// Serializes `manifest` as pretty-printed JSON and writes it to the manifest
/// file in `dir`, replacing any existing one.
///
/// # Errors
///
/// Serialization failures are reported as `io::ErrorKind::Other`; write
/// failures are returned unchanged.
fn write_manifest(dir: &Path, manifest: &Manifest) -> io::Result<()> {
    match serde_json::to_vec_pretty(manifest) {
        Ok(encoded) => fs::write(manifest_path(dir), encoded),
        Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)),
    }
}
/// Loads and parses the manifest file stored in `dir`.
///
/// # Errors
///
/// Read failures are returned unchanged; JSON parse failures are reported as
/// `io::ErrorKind::Other`.
fn read_manifest(dir: &Path) -> io::Result<Manifest> {
    let raw = fs::read(manifest_path(dir))?;
    match serde_json::from_slice::<Manifest>(&raw) {
        Ok(manifest) => Ok(manifest),
        Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)),
    }
}
/// Computes the lowercase hex SHA-256 of the file at `path`, reading it in
/// 64 KiB chunks.
///
/// # Errors
///
/// Returns any error from opening or reading the file.
fn sha256_file(path: &Path) -> io::Result<String> {
    let mut source = fs::File::open(path)?;
    let mut digest = Sha256::new();
    let mut chunk = [0u8; 64 * 1024];
    loop {
        match source.read(&mut chunk)? {
            0 => break,
            filled => digest.update(&chunk[..filled]),
        }
    }
    Ok(hex_digest(digest.finalize()))
}
/// Checks whether `dir` holds an intact copy of every file listed in `info`.
///
/// Returns `Ok(false)` when the manifest is missing or unreadable, when a file
/// is absent, when the manifest has no hash for a file, or when a file's
/// SHA-256 differs from the recorded one (this last case is also logged).
///
/// # Errors
///
/// Returns an error only if an existing cached file cannot be read for hashing.
fn cache_is_valid(dir: &Path, info: &ModelInfo) -> io::Result<bool> {
    let manifest = match read_manifest(dir) {
        Ok(manifest) => manifest,
        Err(_) => return Ok(false),
    };

    for file in &info.files {
        let candidate = dir.join(file);
        if !candidate.exists() {
            return Ok(false);
        }
        let expected = match manifest.sha256.get(*file) {
            Some(hash) => hash,
            None => return Ok(false),
        };
        if sha256_file(&candidate)? != *expected {
            eprintln!(
                "ast-bro: cached {file} failed integrity check, will re-download"
            );
            return Ok(false);
        }
    }

    Ok(true)
}
// Unit tests for the model cache. Everything except `network_real_download`
// works purely on temporary directories and needs no network access.
#[cfg(test)]
mod tests {
use super::*;
// The built-in model descriptor names the expected repository and three files.
#[test]
fn potion_info_lists_three_files() {
let info = ModelInfo::potion_code_16m();
assert_eq!(info.id, "minishlab/potion-code-16M");
assert_eq!(info.files.len(), 3);
assert!(info.files.contains(&"model.safetensors"));
}
// With AST_BRO_MODEL_DIR set, the cache root is exactly that directory and
// the model directory is the id's last path segment underneath it.
// NOTE(review): this mutates the process-wide environment; it may interfere
// with other tests that read AST_BRO_MODEL_DIR when run in parallel — confirm.
#[test]
fn cache_root_and_model_dir_honour_env_override() {
let tmp = tempfile::tempdir().unwrap();
let path = tmp.path().to_path_buf();
std::env::set_var("AST_BRO_MODEL_DIR", &path);
let resolved_root = cache_root().unwrap();
let resolved_model = model_dir(&ModelInfo::potion_code_16m()).unwrap();
// Restore the environment before asserting so a failed assertion does not
// leave the override set.
std::env::remove_var("AST_BRO_MODEL_DIR");
assert_eq!(resolved_root, path);
assert!(resolved_model.starts_with(&path));
assert!(resolved_model.ends_with("potion-code-16M"));
}
// An empty directory has no manifest, so the cache must be reported invalid.
#[test]
fn cache_invalid_when_manifest_missing() {
let tmp = tempfile::tempdir().unwrap();
let info = ModelInfo::potion_code_16m();
assert!(!cache_is_valid(tmp.path(), &info).unwrap());
}
// A manifest whose recorded hash differs from the file's real hash
// invalidates the cache.
#[test]
fn cache_invalid_when_file_hash_mismatches() {
let tmp = tempfile::tempdir().unwrap();
let info = ModelInfo {
id: "fake/model".to_string(),
files: vec!["a.txt"],
};
let dir = tmp.path();
fs::write(dir.join("a.txt"), b"hello").unwrap();
let manifest = Manifest {
sha256: [(
"a.txt".to_string(),
// 64 hex characters that are not the SHA-256 of "hello".
"deadbeef".repeat(8), )]
.into_iter()
.collect(),
source: "hf".to_string(),
};
write_manifest(dir, &manifest).unwrap();
assert!(!cache_is_valid(dir, &info).unwrap());
}
// A manifest recording the file's actual hash makes the cache valid.
#[test]
fn cache_valid_when_hash_matches() {
let tmp = tempfile::tempdir().unwrap();
let info = ModelInfo {
id: "fake/model".to_string(),
files: vec!["a.txt"],
};
let dir = tmp.path();
fs::write(dir.join("a.txt"), b"hello").unwrap();
let actual = sha256_file(&dir.join("a.txt")).unwrap();
let manifest = Manifest {
sha256: [("a.txt".to_string(), actual)].into_iter().collect(),
source: "hf".to_string(),
};
write_manifest(dir, &manifest).unwrap();
assert!(cache_is_valid(dir, &info).unwrap());
}
// `sha256_file` reproduces the well-known SHA-256 digest of "abc".
#[test]
fn sha256_matches_known_vector() {
let tmp = tempfile::tempdir().unwrap();
let path = tmp.path().join("v.txt");
fs::write(&path, b"abc").unwrap();
assert_eq!(
sha256_file(&path).unwrap(),
"ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
);
}
// End-to-end download into a temporary cache directory. Marked `#[ignore]`,
// so it only runs when explicitly requested. The second `ensure_model` call
// exercises the already-cached path.
// NOTE(review): AST_BRO_MODEL_DIR is only removed on the success path; a
// failing assertion leaves it set for the rest of the process.
#[test]
#[ignore]
fn network_real_download() {
let tmp = tempfile::tempdir().unwrap();
std::env::set_var("AST_BRO_MODEL_DIR", tmp.path());
let info = ModelInfo::potion_code_16m();
let dir = ensure_model(&info).expect("download failed");
let dir2 = ensure_model(&info).expect("revalidate failed");
assert_eq!(dir, dir2);
for f in &info.files {
assert!(dir.join(f).exists(), "missing {f}");
}
std::env::remove_var("AST_BRO_MODEL_DIR");
}
}