mod download;
mod internal;
pub use download::{download_and_verify, download_file, sha256_file};
pub use internal::{LockedModel, ModelLock};
use crate::paths;
use anyhow::Result;
use std::path::{Path, PathBuf};
/// Returns the cache directory for `name` when it holds a usable model.
///
/// A cached copy counts as usable when the tokenizer file reported by
/// `paths::models::model_tokenizer` exists and the cache directory contains
/// either `model.onnx` or `model_quantized.onnx`. Otherwise returns `None`.
pub fn cached_model_path(name: &str) -> Option<PathBuf> {
    let dir = paths::models::model_dir(name);
    let tokenizer_present = paths::models::model_tokenizer(name).exists();
    if !tokenizer_present {
        return None;
    }
    // Either the full-precision or the quantized export is acceptable.
    let weights_present = ["model.onnx", "model_quantized.onnx"]
        .iter()
        .any(|file| dir.join(file).exists());
    weights_present.then_some(dir)
}
/// Reports whether `name` has an entry in the model lock file.
///
/// # Errors
/// Propagates any failure from `ModelLock::load`.
pub fn is_tracked(name: &str) -> Result<bool> {
    let tracked = ModelLock::load()?.get(name).is_some();
    Ok(tracked)
}
/// Returns `true` when `dir` contains a `tokenizer.json` plus at least one
/// ONNX export (`model.onnx` or `model_quantized.onnx`).
///
/// The tokenizer is checked first; if it is missing, no ONNX lookup is done.
fn has_valid_model_files(dir: &Path) -> bool {
    let tokenizer_present = dir.join("tokenizer.json").exists();
    tokenizer_present
        && ["model.onnx", "model_quantized.onnx"]
            .iter()
            .any(|file| dir.join(file).exists())
}
/// Locates the directory holding model `name`.
///
/// The user cache (see `cached_model_path`) takes priority. Otherwise the
/// relative `resources/models/<name>` directory is used, provided it contains
/// valid model files.
///
/// # Errors
/// Returns an error suggesting `patina model add` when neither location has
/// the model.
pub fn resolve_model_path(name: &str) -> Result<PathBuf> {
    cached_model_path(name)
        .or_else(|| {
            // Note: relative to the process's current working directory.
            let bundled = PathBuf::from(format!("resources/models/{}", name));
            has_valid_model_files(&bundled).then_some(bundled)
        })
        .ok_or_else(|| {
            anyhow::anyhow!(
                "Model '{}' not found. Run `patina model add {}` to download it.",
                name,
                name
            )
        })
}
/// Snapshot of where a model is available and how it was obtained, as
/// produced by `model_status`.
#[derive(Debug, Clone)]
pub struct ModelStatus {
    /// The model name that was queried.
    pub name: String,
    /// Whether the per-user model cache holds a usable copy.
    pub in_cache: bool,
    /// Whether `resources/models/<name>` (relative to the working directory)
    /// holds a usable copy.
    pub in_local: bool,
    /// The lock-file entry for this model, if it is tracked.
    pub provenance: Option<LockedModel>,
}
pub fn model_status(name: &str) -> Result<ModelStatus> {
let lock = ModelLock::load()?;
let provenance = lock.get(name).cloned();
let cache_dir = paths::models::model_dir(name);
let in_cache = has_valid_model_files(&cache_dir);
let local_path = PathBuf::from(format!("resources/models/{}", name));
let in_local = has_valid_model_files(&local_path);
Ok(ModelStatus {
name: name.to_string(),
in_cache,
in_local,
provenance,
})
}
pub fn add_model(name: &str) -> Result<()> {
use crate::embeddings::models::ModelRegistry;
let registry = ModelRegistry::load()?;
let model_def = registry.get_model(name)?;
let model_url = model_def
.download_quantized
.as_ref()
.ok_or_else(|| anyhow::anyhow!("No download URL for model '{}'", name))?;
let tokenizer_url = model_def
.download_tokenizer
.as_ref()
.ok_or_else(|| anyhow::anyhow!("No tokenizer URL for model '{}'", name))?;
println!("Downloading {}...", name);
let model_path = paths::models::model_onnx(name);
let model_path = model_path.with_file_name("model_quantized.onnx");
println!(" Model:");
let sha256_model = download_and_verify(model_url, &model_path, None)?;
let tokenizer_path = paths::models::model_tokenizer(name);
println!(" Tokenizer:");
let sha256_tokenizer = download_and_verify(tokenizer_url, &tokenizer_path, None)?;
let model_size = std::fs::metadata(&model_path)?.len();
let tokenizer_size = std::fs::metadata(&tokenizer_path)?.len();
let mut lock = ModelLock::load()?;
lock.insert(
name,
LockedModel {
downloaded: chrono::Utc::now().to_rfc3339(),
source_model: model_url.clone(),
source_tokenizer: tokenizer_url.clone(),
sha256_model,
sha256_tokenizer,
size_bytes: model_size + tokenizer_size,
dimensions: model_def.dimensions,
},
);
lock.save()?;
println!("\n✓ Model '{}' added to cache", name);
println!(" Location: {:?}", paths::models::model_dir(name));
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cached_model_path_missing() {
        let path = cached_model_path("nonexistent-model-xyz");
        assert!(path.is_none());
    }

    /// `has_valid_model_files` needs a tokenizer plus at least one ONNX export.
    #[test]
    fn test_has_valid_model_files() {
        let dir = std::env::temp_dir().join(format!(
            "patina-has-valid-model-files-{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();

        // Empty directory.
        assert!(!has_valid_model_files(&dir));

        // Weights without a tokenizer are not enough.
        std::fs::write(dir.join("model.onnx"), b"").unwrap();
        assert!(!has_valid_model_files(&dir));

        // Tokenizer + full-precision export.
        std::fs::write(dir.join("tokenizer.json"), b"{}").unwrap();
        assert!(has_valid_model_files(&dir));

        // Tokenizer + quantized export only.
        std::fs::remove_file(dir.join("model.onnx")).unwrap();
        std::fs::write(dir.join("model_quantized.onnx"), b"").unwrap();
        assert!(has_valid_model_files(&dir));

        std::fs::remove_dir_all(&dir).unwrap();
    }

    /// Depends on the working directory and cache state, so accepts either outcome.
    /// When the model is not found, the error must still mention that.
    #[test]
    fn test_resolve_model_path_local_fallback() {
        let result = resolve_model_path("all-minilm-l6-v2");
        match result {
            Ok(path) => {
                assert!(path.to_string_lossy().contains("all-minilm-l6-v2"));
            }
            Err(e) => {
                assert!(e.to_string().contains("not found"));
            }
        }
    }

    #[test]
    #[ignore = "downloads model files over the network and writes to the user cache"]
    fn test_add_model_e2e() {
        let result = add_model("bge-small-en-v1-5");
        assert!(result.is_ok(), "Failed: {:?}", result);
        let model_dir = paths::models::model_dir("bge-small-en-v1-5");
        assert!(model_dir.join("model_quantized.onnx").exists());
        assert!(model_dir.join("tokenizer.json").exists());
        let lock = ModelLock::load().unwrap();
        let entry = lock.get("bge-small-en-v1-5");
        assert!(entry.is_some());
        assert_eq!(entry.unwrap().dimensions, 384);
    }
}