use sha2::{Digest, Sha256};
use std::path::PathBuf;
use crate::error::Error;
/// Upstream location of the ONNX model; `model_url` returns it when
/// `JAILGUARD_MODEL_URL` does not supply a replacement.
const DEFAULT_ONNX_URL: &str =
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx";
/// File name under which the model is stored inside the cache directory.
const ONNX_FILENAME: &str = "all-MiniLM-L6-v2.onnx";
/// Expected SHA-256 of the model file as lowercase hex; `download_model`
/// rejects a download whose digest differs.
const ONNX_SHA256: &str = "6fd5d72fe4589f189f8ebc006442dbb529bb7ce38f8082112682524616046452";
/// Size in bytes used in place of the response's content length when the
/// server does not report one; it only drives the progress percentage.
const EXPECTED_SIZE: u64 = 90_000_000;
/// Returns the URL the ONNX model is downloaded from.
///
/// `JAILGUARD_MODEL_URL` overrides the built-in default. An unset, non-Unicode,
/// empty or whitespace-only value falls back to `DEFAULT_ONNX_URL`, so that
/// `JAILGUARD_MODEL_URL=` does not turn into a request for an empty URL.
fn model_url() -> String {
    std::env::var("JAILGUARD_MODEL_URL")
        .ok()
        .filter(|custom| !custom.trim().is_empty())
        .unwrap_or_else(|| DEFAULT_ONNX_URL.to_string())
}
fn cache_dir() -> Result<PathBuf, Error> {
if let Ok(dir) = std::env::var("JAILGUARD_MODEL_DIR") {
return Ok(PathBuf::from(dir));
}
#[cfg(windows)]
{
if let Ok(profile) = std::env::var("USERPROFILE") {
return Ok(PathBuf::from(profile).join(".cache").join("jailguard"));
}
if let Ok(appdata) = std::env::var("LOCALAPPDATA") {
return Ok(PathBuf::from(appdata).join("jailguard"));
}
}
if let Ok(home) = std::env::var("HOME") {
return Ok(PathBuf::from(home).join(".cache").join("jailguard"));
}
if let Ok(profile) = std::env::var("USERPROFILE") {
return Ok(PathBuf::from(profile).join(".cache").join("jailguard"));
}
if let Ok(appdata) = std::env::var("LOCALAPPDATA") {
return Ok(PathBuf::from(appdata).join("jailguard"));
}
Err(Error::Config(
"no cache directory: set JAILGUARD_MODEL_DIR, HOME, USERPROFILE, or LOCALAPPDATA".into(),
))
}
pub fn cache_dir_string() -> Result<String, Error> {
let p = cache_dir()?;
p.into_os_string()
.into_string()
.map_err(|_| Error::Config("cache directory contains non-UTF-8 bytes".into()))
}
#[allow(clippy::print_stderr)]
pub fn download_model() -> Result<PathBuf, Error> {
let dir = cache_dir()?;
let model_path = dir.join(ONNX_FILENAME);
if model_path.exists() {
return Ok(model_path);
}
std::fs::create_dir_all(&dir).map_err(|e| {
Error::Io(std::io::Error::new(
e.kind(),
format!("Failed to create cache dir {}: {}", dir.display(), e),
))
})?;
let url = model_url();
eprintln!(
"jailguard: downloading ONNX model (~90 MB) to {} ...",
model_path.display()
);
if url != DEFAULT_ONNX_URL {
eprintln!("jailguard: using custom URL: {url}");
}
let resp = ureq::get(&url)
.call()
.map_err(|e| Error::Model(format!("Failed to download ONNX model: {e}")))?;
let body = resp.into_body();
let content_length = body.content_length().unwrap_or(EXPECTED_SIZE);
let tmp_path = dir.join(format!("{ONNX_FILENAME}.download"));
let mut file = std::fs::File::create(&tmp_path)?;
let mut reader = body.into_reader();
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
let mut downloaded: u64 = 0;
let mut last_pct: u64 = 0;
loop {
let n = std::io::Read::read(&mut reader, &mut buf)?;
if n == 0 {
break;
}
std::io::Write::write_all(&mut file, &buf[..n])?;
hasher.update(&buf[..n]);
downloaded += n as u64;
let pct = (downloaded * 100) / content_length;
if pct >= last_pct + 10 {
eprintln!("jailguard: downloaded {pct}%");
last_pct = pct;
}
}
drop(file);
let actual: String = hasher
.finalize()
.iter()
.map(|b| format!("{b:02x}"))
.collect();
if actual != ONNX_SHA256 {
let _ = std::fs::remove_file(&tmp_path);
return Err(Error::Model(format!(
"ONNX model checksum mismatch — expected {ONNX_SHA256}, got {actual}. \
The download may be incomplete or the file at the URL has changed. \
Delete {} and retry, or set JAILGUARD_MODEL_URL to a known-good mirror.",
tmp_path.display()
)));
}
std::fs::rename(&tmp_path, &model_path)?;
eprintln!(
"jailguard: ONNX model ready ({:.1} MB)",
downloaded as f64 / 1_000_000.0
);
Ok(model_path)
}