use std::path::{Path, PathBuf};
/// Static description of one downloadable model tier.
///
/// Instances are defined as `static`s in this module and looked up by
/// [`get_tier`]. All SHA-256 fields are lowercase/uppercase hex strings; a
/// value consisting only of `'0'` characters is treated as "no digest
/// configured" and skips the corresponding integrity check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelTier {
    /// Short identifier used for lookup (e.g. `"standard"`).
    pub name: &'static str,
    /// Human-readable name shown in progress messages.
    pub display_name: &'static str,
    /// URL of the `.tar.gz` archive containing the model files.
    pub url: &'static str,
    /// Expected SHA-256 of the downloaded archive.
    pub archive_sha256: &'static str,
    /// Expected SHA-256 of the extracted `model.safetensors`.
    pub safetensors_sha256: &'static str,
    /// Expected SHA-256 of the extracted `tokenizer.json`.
    pub tokenizer_sha256: &'static str,
    // Not read within this module; the lint is silenced so the value can be
    // kept as tier metadata — presumably consumed elsewhere, TODO confirm.
    #[allow(dead_code)]
    pub dim: usize,
    /// Sub-directory (under the models base dir) the tier is unpacked into.
    pub dir_name: &'static str,
}
/// The `standard` model tier (CodeSage-M2V-256, `dim` = 256).
///
/// NOTE(review): every SHA-256 field below is the all-zero placeholder, so
/// `is_placeholder_hash` causes both the archive check and the per-file
/// checks to be skipped for this tier. Replace with real digests before
/// relying on integrity verification.
pub static STANDARD: ModelTier = ModelTier {
    name: "standard",
    dir_name: "codesage-m2v-256",
    display_name: "CodeSage-M2V-256",
    dim: 256,
    url: "https://github.com/civitas-io/prx/releases/download/models-v1/codesage-m2v-256.tar.gz",
    archive_sha256: "0000000000000000000000000000000000000000000000000000000000000000",
    safetensors_sha256: "0000000000000000000000000000000000000000000000000000000000000000",
    tokenizer_sha256: "0000000000000000000000000000000000000000000000000000000000000000",
};
/// The `large` model tier (Jina-Code-M2V-512, `dim` = 512).
///
/// NOTE(review): every SHA-256 field below is the all-zero placeholder, so
/// `is_placeholder_hash` causes both the archive check and the per-file
/// checks to be skipped for this tier. Replace with real digests before
/// relying on integrity verification.
pub static LARGE: ModelTier = ModelTier {
    name: "large",
    dir_name: "jina-code-m2v-512",
    display_name: "Jina-Code-M2V-512",
    dim: 512,
    url: "https://github.com/civitas-io/prx/releases/download/models-v1/jina-code-m2v-512.tar.gz",
    archive_sha256: "0000000000000000000000000000000000000000000000000000000000000000",
    safetensors_sha256: "0000000000000000000000000000000000000000000000000000000000000000",
    tokenizer_sha256: "0000000000000000000000000000000000000000000000000000000000000000",
};
/// Looks up a built-in tier by its `name` field.
///
/// Recognises `"standard"` and `"large"`; any other string (including the
/// empty string) yields `None`.
pub fn get_tier(name: &str) -> Option<&'static ModelTier> {
    // Keying on each tier's own `name` keeps the lookup table and the tier
    // definitions from drifting apart.
    let known: [&'static ModelTier; 2] = [&STANDARD, &LARGE];
    known.into_iter().find(|tier| tier.name == name)
}
/// Returns the directory that holds (or will hold) the files for `tier`.
///
/// The base directory comes from the `PRX_MODELS_DIR` environment variable
/// when it is set to a non-empty value. Otherwise it is `~/.prx/models`, or
/// `./.prx/models` if the home directory cannot be determined. The tier's
/// `dir_name` is appended to that base.
pub fn model_dir(tier: &ModelTier) -> PathBuf {
    // `var_os` rather than `var`: a non-UTF-8 path is a valid override and
    // must not silently fall back to the home directory. An empty value is
    // treated as unset instead of resolving relative to the cwd.
    let base = match std::env::var_os("PRX_MODELS_DIR") {
        Some(dir) if !dir.is_empty() => PathBuf::from(dir),
        _ => dirs_next::home_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".prx")
            .join("models"),
    };
    base.join(tier.dir_name)
}
/// Reports whether `tier`'s files are present and pass integrity checks.
///
/// Both `model.safetensors` and `tokenizer.json` must exist inside
/// [`model_dir`]. Each file is then hashed and compared against its expected
/// digest, unless that digest is a placeholder, in which case existence
/// alone is enough.
pub fn is_model_ready(tier: &ModelTier) -> bool {
    let root = model_dir(tier);
    // Order matters only for short-circuiting: safetensors is checked first.
    let expectations = [
        (root.join("model.safetensors"), tier.safetensors_sha256),
        (root.join("tokenizer.json"), tier.tokenizer_sha256),
    ];

    if expectations.iter().any(|(path, _)| !path.exists()) {
        return false;
    }

    expectations
        .iter()
        .all(|(path, digest)| is_placeholder_hash(digest) || verify_sha256(path, digest))
}
pub fn download_model(tier: &ModelTier) -> Result<PathBuf, String> {
let dir = model_dir(tier);
std::fs::create_dir_all(&dir).map_err(|e| format!("Failed to create model dir: {e}"))?;
let archive_path = dir.join("model.tar.gz");
eprintln!("Downloading {} model ({})...", tier.display_name, tier.name);
eprintln!(" From: {}", tier.url);
let status = std::process::Command::new("curl")
.args(["-fSL", "--progress-bar", "-o"])
.arg(&archive_path)
.arg(tier.url)
.status()
.map_err(|e| format!("curl not found: {e}"))?;
if !status.success() {
let _ = std::fs::remove_file(&archive_path);
return Err("Download failed. Check your network connection.".to_string());
}
if !is_placeholder_hash(tier.archive_sha256) {
let data =
std::fs::read(&archive_path).map_err(|e| format!("Failed to read archive: {e}"))?;
let hash = sha256_hex(&data);
if hash != tier.archive_sha256 {
let _ = std::fs::remove_file(&archive_path);
return Err(format!(
"Archive hash mismatch: expected {}, got {hash}",
tier.archive_sha256
));
}
}
let status = std::process::Command::new("tar")
.args(["xzf"])
.arg(&archive_path)
.arg("-C")
.arg(&dir)
.status()
.map_err(|e| format!("tar not found: {e}"))?;
if !status.success() {
return Err("Extraction failed.".to_string());
}
let _ = std::fs::remove_file(&archive_path);
eprintln!(" Saved to {}", dir.display());
Ok(dir)
}
/// Returns the model directory for `tier`, downloading the model first when
/// [`is_model_ready`] reports that the local copy is missing or invalid.
///
/// # Errors
///
/// Propagates any error from [`download_model`].
pub fn ensure_model(tier: &ModelTier) -> Result<PathBuf, String> {
    if is_model_ready(tier) {
        Ok(model_dir(tier))
    } else {
        download_model(tier)
    }
}
/// Returns `true` if `hex` is the "no digest configured" placeholder, i.e. a
/// non-empty string made up solely of `'0'` characters.
///
/// An empty string is deliberately *not* a placeholder. `all` is vacuously
/// true on empty input, so a missing hash would otherwise silently disable
/// verification instead of failing closed.
fn is_placeholder_hash(hex: &str) -> bool {
    !hex.is_empty() && hex.bytes().all(|b| b == b'0')
}
fn verify_sha256(path: &Path, expected: &str) -> bool {
let Ok(data) = std::fs::read(path) else {
return false;
};
sha256_hex(&data) == expected
}
/// Computes the SHA-256 of `data` and returns it as a lowercase hex string.
fn sha256_hex(data: &[u8]) -> String {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(data);
    format!("{digest:x}")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises every test in this module that touches `PRX_MODELS_DIR`.
    /// The test harness runs tests on parallel threads and the environment is
    /// process-global, so unsynchronised overrides race with each other.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Overrides `PRX_MODELS_DIR` for as long as the guard lives, then
    /// restores the previous value (or unsets it), even if the test panics.
    struct ModelsDirOverride {
        previous: Option<OsString>,
        // Declared last so it is released only after `Drop::drop` restores
        // the variable.
        _lock: MutexGuard<'static, ()>,
    }

    impl ModelsDirOverride {
        fn set(value: &str) -> Self {
            // A panicking env test poisons the lock; the guarded data is `()`,
            // so recovering the guard is always fine.
            let lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
            let previous = std::env::var_os("PRX_MODELS_DIR");
            // SAFETY: environment mutation is unsound if another thread
            // accesses the environment concurrently. Every env mutation in this
            // module happens while holding ENV_LOCK; tests elsewhere in the
            // crate are assumed not to touch the environment concurrently.
            unsafe {
                std::env::set_var("PRX_MODELS_DIR", value);
            }
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for ModelsDirOverride {
        fn drop(&mut self) {
            // SAFETY: same invariant as in `set`; ENV_LOCK is still held
            // because `_lock` is dropped after this method returns.
            unsafe {
                match &self.previous {
                    Some(prev) => std::env::set_var("PRX_MODELS_DIR", prev),
                    None => std::env::remove_var("PRX_MODELS_DIR"),
                }
            }
        }
    }

    #[test]
    fn get_tier_known_names() {
        assert_eq!(get_tier("standard").map(|t| t.name), Some("standard"));
        assert_eq!(get_tier("large").map(|t| t.name), Some("large"));
    }

    #[test]
    fn get_tier_unknown_returns_none() {
        assert!(get_tier("nope").is_none());
        assert!(get_tier("builtin").is_none());
        assert!(get_tier("").is_none());
    }

    #[test]
    fn placeholder_detection() {
        assert!(is_placeholder_hash(
            "0000000000000000000000000000000000000000000000000000000000000000"
        ));
        assert!(!is_placeholder_hash(
            "ca6159081a6e96cebe4ad878e5e8437bfccc761e8db16223370149cd2faa6c0b"
        ));
    }

    #[test]
    fn sha256_hex_matches_known_value() {
        assert_eq!(
            sha256_hex(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn model_dir_respects_env_override() {
        let tier = &STANDARD;
        let _env = ModelsDirOverride::set("/tmp/prx-test-models");
        assert_eq!(
            model_dir(tier),
            PathBuf::from("/tmp/prx-test-models").join(tier.dir_name)
        );
    }

    #[test]
    fn is_model_ready_false_for_missing_dir() {
        let tier = &STANDARD;
        let _env =
            ModelsDirOverride::set("/tmp/prx-test-models-definitely-does-not-exist-xyz123");
        assert!(!is_model_ready(tier));
    }
}