use crate::aliases::{AliasEntry, AliasRegistry, ResolvedAlias};
use crate::cache::{
format_bytes, CacheConfig, CacheEntry, CacheManager, CacheStats, DownloadProgress,
EvictionPolicy,
};
use crate::error::{PachaError, Result};
use crate::format::{detect_format, ModelFormat, QuantType};
use crate::resolver::ModelResolver;
use crate::uri::ModelUri;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// Configuration for a [`ModelFetcher`]: cache behavior, quantization
/// preference, and download policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FetchConfig {
    /// Cache sizing/behavior settings forwarded to the `CacheManager`.
    pub cache: CacheConfig,
    /// Preferred quantization when a reference does not specify one.
    // NOTE(review): not consulted anywhere in this file — presumably used by
    // the resolver or callers; confirm.
    pub default_quant: Option<QuantType>,
    /// Whether missing models may be downloaded automatically.
    // NOTE(review): not read in this file — TODO confirm call sites.
    pub auto_pull: bool,
    /// Maximum number of concurrent downloads (not enforced in this file).
    pub max_concurrent: usize,
    /// Whether to verify model integrity after download (not enforced here).
    pub verify_integrity: bool,
    /// Eviction policy applied to the cache (see `CacheManager::with_policy`).
    pub eviction_policy: EvictionPolicy,
}
impl Default for FetchConfig {
    /// Sensible defaults: Q4_K_M quantization, auto-pull on, two concurrent
    /// downloads, integrity verification enabled, LRU eviction.
    fn default() -> Self {
        Self {
            eviction_policy: EvictionPolicy::LRU,
            verify_integrity: true,
            max_concurrent: 2,
            auto_pull: true,
            default_quant: Some(QuantType::Q4_K_M),
            cache: CacheConfig::default(),
        }
    }
}
impl FetchConfig {
    /// Equivalent to [`FetchConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy with the cache configuration replaced.
    #[must_use]
    pub fn with_cache(self, cache: CacheConfig) -> Self {
        Self { cache, ..self }
    }

    /// Returns a copy with the default quantization set to `quant`.
    #[must_use]
    pub fn with_default_quant(self, quant: QuantType) -> Self {
        Self {
            default_quant: Some(quant),
            ..self
        }
    }

    /// Returns a copy with auto-pull toggled to `enabled`.
    #[must_use]
    pub fn with_auto_pull(self, enabled: bool) -> Self {
        Self {
            auto_pull: enabled,
            ..self
        }
    }

    /// Returns a copy with the cache eviction policy set to `policy`.
    #[must_use]
    pub fn with_eviction_policy(self, policy: EvictionPolicy) -> Self {
        Self {
            eviction_policy: policy,
            ..self
        }
    }
}
/// Outcome of a successful [`ModelFetcher::pull`].
#[derive(Debug)]
pub struct FetchResult {
    /// Path of the cached model file.
    pub path: PathBuf,
    /// Format detected from the model bytes (or, on cache hits, inferred
    /// from the cached file's extension).
    pub format: ModelFormat,
    /// Size of the model payload in bytes.
    pub size_bytes: u64,
    /// `true` when the model was served from the local cache.
    pub cache_hit: bool,
    /// The reference the caller asked for (alias or URI).
    pub reference: String,
    /// The URI that the reference resolved to.
    pub resolved_uri: String,
    /// BLAKE3 hash (hex) of the model contents.
    pub hash: String,
}
impl FetchResult {
    /// Human-readable size string (e.g. "4.0 GB").
    #[must_use]
    pub fn size_human(&self) -> String {
        format_bytes(self.size_bytes)
    }

    /// `true` when the model is a GGUF file carrying quantization metadata.
    #[must_use]
    pub fn is_quantized(&self) -> bool {
        matches!(&self.format, ModelFormat::Gguf(info) if info.quantization.is_some())
    }

    /// Parses the GGUF quantization label, when present, into a [`QuantType`].
    #[must_use]
    pub fn quant_type(&self) -> Option<QuantType> {
        if let ModelFormat::Gguf(info) = &self.format {
            info.quantization.as_deref().and_then(|q| QuantType::from_str(q))
        } else {
            None
        }
    }
}
/// High-level entry point that resolves model references (aliases or URIs),
/// downloads models, and manages the local on-disk cache.
pub struct ModelFetcher {
    /// Fetch/caching configuration.
    config: FetchConfig,
    /// Alias registry, seeded with defaults (e.g. "llama3", "mistral").
    aliases: AliasRegistry,
    /// In-memory cache index; persisted as `manifest.json` in `cache_dir`.
    cache: CacheManager,
    /// URI resolver; `None` when default construction failed, in which case
    /// `pull` returns `NotInitialized`.
    resolver: Option<ModelResolver>,
    /// Directory holding cached model files and the manifest.
    cache_dir: PathBuf,
}
impl ModelFetcher {
    /// Creates a fetcher with default configuration and cache directory.
    ///
    /// # Errors
    /// Returns an error if the cache directory cannot be created.
    pub fn new() -> Result<Self> {
        Self::with_config(FetchConfig::default())
    }

    /// Creates a fetcher with `config`, rooted at the default cache directory.
    ///
    /// # Errors
    /// Returns an error if the cache directory cannot be created.
    pub fn with_config(config: FetchConfig) -> Result<Self> {
        Self::build(get_default_cache_dir(), config)
    }

    /// Creates a fetcher rooted at an explicit `cache_dir`.
    ///
    /// # Errors
    /// Returns an error if `cache_dir` cannot be created.
    pub fn with_cache_dir(cache_dir: PathBuf, config: FetchConfig) -> Result<Self> {
        // BUGFIX: this constructor previously skipped manifest loading, so
        // models cached in an earlier session were invisible when a custom
        // cache dir was used. `build` restores them, matching `with_config`.
        Self::build(cache_dir, config)
    }

    /// Shared constructor: ensures `cache_dir` exists, restores the on-disk
    /// manifest into the in-memory index, and wires up the resolver
    /// (resolver construction is best-effort; failure leaves it `None`).
    fn build(cache_dir: PathBuf, config: FetchConfig) -> Result<Self> {
        std::fs::create_dir_all(&cache_dir).map_err(|e| {
            PachaError::Io(std::io::Error::new(
                e.kind(),
                format!("Failed to create cache dir: {}", cache_dir.display()),
            ))
        })?;
        let mut cache = CacheManager::new(config.cache.clone()).with_policy(config.eviction_policy);
        Self::load_manifest(&cache_dir, &mut cache);
        let resolver = ModelResolver::new_default().ok();
        Ok(Self {
            config,
            aliases: AliasRegistry::with_defaults(),
            cache,
            resolver,
            cache_dir,
        })
    }

    /// Loads `manifest.json` from `cache_dir`, re-adding only entries whose
    /// backing file still exists. Read/parse failures are treated as an
    /// empty manifest (best effort).
    fn load_manifest(cache_dir: &Path, cache: &mut CacheManager) {
        let manifest_path = cache_dir.join("manifest.json");
        if let Ok(data) = std::fs::read_to_string(&manifest_path) {
            if let Ok(entries) = serde_json::from_str::<Vec<CacheEntry>>(&data) {
                for entry in entries {
                    if entry.path.exists() {
                        cache.add(entry);
                    }
                }
            }
        }
    }

    /// Persists the in-memory cache index to `manifest.json` (best effort:
    /// serialization or write failures are silently ignored).
    fn save_manifest(&self) {
        let manifest_path = self.cache_dir.join("manifest.json");
        let entries: Vec<&CacheEntry> = self.cache.list();
        if let Ok(data) = serde_json::to_string_pretty(&entries) {
            let _ = std::fs::write(manifest_path, data);
        }
    }

    /// Returns the active fetch configuration.
    #[must_use]
    pub fn config(&self) -> &FetchConfig {
        &self.config
    }

    /// Returns the alias registry (defaults plus any aliases added at runtime).
    #[must_use]
    pub fn aliases(&self) -> &AliasRegistry {
        &self.aliases
    }

    /// Registers `alias` as a shorthand for `uri`.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for forward compatibility.
    pub fn add_alias(&mut self, alias: &str, uri: &str) -> Result<()> {
        self.aliases.add(AliasEntry::new(alias, uri));
        Ok(())
    }

    /// Resolves `model_ref` to a concrete URI via the alias registry.
    ///
    /// # Errors
    /// Returns `NotFound` when `model_ref` is neither a known alias nor an
    /// explicit URI (i.e. it lacks a `://` scheme separator).
    pub fn resolve_ref(&self, model_ref: &str) -> Result<ResolvedAlias> {
        let resolved = self.aliases.resolve(model_ref);
        if resolved.is_alias || model_ref.contains("://") {
            Ok(resolved)
        } else {
            Err(PachaError::NotFound {
                kind: "alias".to_string(),
                name: model_ref.to_string(),
                version: "N/A".to_string(),
            })
        }
    }

    /// Fetches `model_ref`, serving from cache when possible and otherwise
    /// downloading, hashing, and caching the model. `progress_fn` receives
    /// progress updates.
    ///
    /// # Errors
    /// Fails when the URI cannot be parsed, no resolver is available, the
    /// download fails, or the model cannot be written to the cache.
    pub fn pull<F>(&mut self, model_ref: &str, progress_fn: F) -> Result<FetchResult>
    where
        F: Fn(&DownloadProgress),
    {
        let resolved = self.aliases.resolve(model_ref);
        let uri_str = resolved.uri;
        let cache_key = Self::cache_key(&uri_str);
        // NOTE: the cache currently uses a fixed "1.0" version for every
        // entry; per-model versioning is not implemented yet.
        if let Some(entry) = self.cache.get(&cache_key, "1.0") {
            let format = format_from_path(&entry.path);
            return Ok(FetchResult {
                path: entry.path.clone(),
                format,
                size_bytes: entry.size_bytes,
                cache_hit: true,
                reference: model_ref.to_string(),
                resolved_uri: uri_str,
                hash: entry.hash.clone(),
            });
        }
        let uri = ModelUri::parse(&uri_str)?;
        let resolver = self
            .resolver
            .as_ref()
            .ok_or_else(|| PachaError::NotInitialized(PathBuf::from("~/.pacha")))?;
        // Signal the start of the download (total size unknown until resolved).
        let mut progress = DownloadProgress::new(0);
        progress_fn(&progress);
        let resolved_model = resolver.resolve(&uri)?;
        // The resolver returns the full payload, so report completion directly.
        progress = DownloadProgress::new(resolved_model.data.len() as u64);
        progress.update(resolved_model.data.len() as u64);
        progress_fn(&progress);
        let format = detect_format(&resolved_model.data);
        let hash = blake3::hash(&resolved_model.data).to_hex().to_string();
        let extension = match &format {
            ModelFormat::Gguf(_) => "gguf",
            ModelFormat::SafeTensors(_) => "safetensors",
            ModelFormat::Apr(_) => "apr",
            ModelFormat::Onnx(_) => "onnx",
            ModelFormat::PyTorch => "pt",
            ModelFormat::Unknown => "bin",
        };
        // Content-addressed filename: first 16 hex chars of the BLAKE3 hash.
        let filename = format!("{}.{}", &hash[..16], extension);
        let cache_path = self.cache_dir.join(&filename);
        std::fs::write(&cache_path, &resolved_model.data).map_err(|e| {
            PachaError::Io(std::io::Error::new(
                e.kind(),
                format!("Failed to write to cache: {}", cache_path.display()),
            ))
        })?;
        let entry = CacheEntry::new(
            &cache_key,
            "1.0",
            resolved_model.data.len() as u64,
            &hash,
            cache_path.clone(),
        );
        self.cache.add(entry);
        self.save_manifest();
        Ok(FetchResult {
            path: cache_path,
            format,
            size_bytes: resolved_model.data.len() as u64,
            cache_hit: false,
            reference: model_ref.to_string(),
            resolved_uri: uri_str,
            hash,
        })
    }

    /// Like [`ModelFetcher::pull`], but with no progress reporting.
    ///
    /// # Errors
    /// Same failure modes as [`ModelFetcher::pull`].
    pub fn pull_quiet(&mut self, model_ref: &str) -> Result<FetchResult> {
        self.pull(model_ref, |_| {})
    }

    /// Returns `true` when `model_ref` (alias or URI) is already cached.
    #[must_use]
    pub fn is_cached(&self, model_ref: &str) -> bool {
        let resolved = self.aliases.resolve(model_ref);
        let key = Self::cache_key(&resolved.uri);
        self.cache.contains(&key, "1.0")
    }

    /// Removes a cached model and deletes its backing file.
    ///
    /// Returns `Ok(true)` if an entry was removed, `Ok(false)` otherwise.
    ///
    /// # Errors
    /// Currently infallible; returns `Result` for forward compatibility.
    pub fn remove(&mut self, model_ref: &str) -> Result<bool> {
        let resolved = self.aliases.resolve(model_ref);
        let key = Self::cache_key(&resolved.uri);
        if let Some(entry) = self.cache.remove(&key, "1.0") {
            if entry.path.exists() {
                std::fs::remove_file(&entry.path).ok();
            }
            // Keep the on-disk manifest in sync with the in-memory index.
            self.save_manifest();
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Lists all cached models as read-only views.
    #[must_use]
    pub fn list(&self) -> Vec<CachedModel> {
        self.cache
            .list()
            .iter()
            .map(|e| {
                let format = format_from_path(&e.path);
                CachedModel {
                    name: e.name.clone(),
                    version: e.version.clone(),
                    size_bytes: e.size_bytes,
                    format,
                    path: e.path.clone(),
                    last_accessed: e.last_accessed,
                    access_count: e.access_count,
                    pinned: e.pinned,
                }
            })
            .collect()
    }

    /// Returns aggregate cache statistics.
    #[must_use]
    pub fn stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Evicts entries down to the configured target size; returns bytes freed.
    pub fn cleanup(&mut self) -> u64 {
        self.cache.cleanup_to_target()
    }

    /// Evicts stale entries; returns bytes freed.
    pub fn cleanup_old(&mut self) -> u64 {
        self.cache.cleanup_old_entries()
    }

    /// Deletes every cached model file and clears the index; returns the
    /// number of bytes freed.
    pub fn clear(&mut self) -> u64 {
        for entry in self.cache.list() {
            if entry.path.exists() {
                std::fs::remove_file(&entry.path).ok();
            }
        }
        let freed = self.cache.clear();
        // Persist the now-empty index so a reload doesn't resurrect entries.
        self.save_manifest();
        freed
    }

    /// Pins a cached model so eviction skips it. Returns `false` when the
    /// model is not cached.
    pub fn pin(&mut self, model_ref: &str) -> bool {
        // BUGFIX: resolve aliases first so `pin("llama3")` addresses the same
        // cache key that `pull("llama3")` created (the raw ref was used before,
        // which never matched for aliased references).
        let resolved = self.aliases.resolve(model_ref);
        let key = Self::cache_key(&resolved.uri);
        let changed = self.cache.pin(&key, "1.0");
        if changed {
            self.save_manifest();
        }
        changed
    }

    /// Unpins a cached model. Returns `false` when the model is not cached.
    pub fn unpin(&mut self, model_ref: &str) -> bool {
        // See `pin`: aliases must be resolved before deriving the cache key.
        let resolved = self.aliases.resolve(model_ref);
        let key = Self::cache_key(&resolved.uri);
        let changed = self.cache.unpin(&key, "1.0");
        if changed {
            self.save_manifest();
        }
        changed
    }

    /// Returns the cache directory root.
    #[must_use]
    pub fn cache_dir(&self) -> &PathBuf {
        &self.cache_dir
    }

    /// Normalizes a URI into a key-safe string by replacing scheme and path
    /// separators with underscores.
    fn cache_key(uri: &str) -> String {
        uri.replace("://", "_").replace('/', "_").replace(':', "_")
    }
}
/// Read-only view of one cached model, as returned by [`ModelFetcher::list`].
#[derive(Debug, Clone)]
pub struct CachedModel {
    /// Cache entry name (derived from the model URI).
    pub name: String,
    /// Cache entry version (currently always "1.0" in this module).
    pub version: String,
    /// File size in bytes.
    pub size_bytes: u64,
    /// Format inferred from the cached file's extension.
    pub format: ModelFormat,
    /// Path of the cached model file.
    pub path: PathBuf,
    /// Timestamp of the most recent access.
    pub last_accessed: std::time::SystemTime,
    /// Number of times the entry has been accessed.
    pub access_count: u64,
    /// Whether the entry is pinned (protected from eviction).
    pub pinned: bool,
}
impl CachedModel {
    /// Human-readable size string (e.g. "4.0 GB").
    #[must_use]
    pub fn size_human(&self) -> String {
        format_bytes(self.size_bytes)
    }

    /// Parses the GGUF quantization label, when present, into a [`QuantType`].
    #[must_use]
    pub fn quant_type(&self) -> Option<QuantType> {
        if let ModelFormat::Gguf(info) = &self.format {
            info.quantization.as_deref().and_then(|q| QuantType::from_str(q))
        } else {
            None
        }
    }
}
/// Infers a [`ModelFormat`] from the file extension alone; format metadata is
/// defaulted because the file contents are never inspected here.
fn format_from_path(path: &Path) -> ModelFormat {
    // Non-UTF-8 or missing extensions map to Unknown.
    let ext = match path.extension().and_then(|e| e.to_str()) {
        Some(e) => e.to_lowercase(),
        None => return ModelFormat::Unknown,
    };
    match ext.as_str() {
        "gguf" => ModelFormat::Gguf(Default::default()),
        "safetensors" => ModelFormat::SafeTensors(Default::default()),
        "apr" => ModelFormat::Apr(Default::default()),
        "onnx" => ModelFormat::Onnx(Default::default()),
        "pt" | "pth" => ModelFormat::PyTorch,
        _ => ModelFormat::Unknown,
    }
}
/// Picks the on-disk cache root, in order of preference:
/// `$XDG_CACHE_HOME/pacha/models`, `$HOME/.cache/pacha/models`,
/// `%LOCALAPPDATA%\pacha\cache\models`, then a relative
/// `.cache/pacha/models` as a last resort.
///
/// Empty environment values are ignored: the XDG Base Directory spec requires
/// an empty `XDG_CACHE_HOME` to be treated as unset (previously an empty
/// value produced a bogus root like `/pacha/models`).
fn get_default_cache_dir() -> PathBuf {
    // Treat unset *and* empty variables as absent.
    let non_empty = |name: &str| std::env::var(name).ok().filter(|v| !v.is_empty());
    if let Some(cache_home) = non_empty("XDG_CACHE_HOME") {
        return PathBuf::from(cache_home).join("pacha").join("models");
    }
    if let Some(home) = non_empty("HOME") {
        return PathBuf::from(home)
            .join(".cache")
            .join("pacha")
            .join("models");
    }
    // Windows fallback.
    if let Some(local_app_data) = non_empty("LOCALAPPDATA") {
        return PathBuf::from(local_app_data)
            .join("pacha")
            .join("cache")
            .join("models");
    }
    PathBuf::from(".cache").join("pacha").join("models")
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_fetch_config_default() {
        let config = FetchConfig::default();
        assert!(config.auto_pull);
        assert_eq!(config.max_concurrent, 2);
        assert!(config.verify_integrity);
    }

    #[test]
    fn test_fetch_config_builder() {
        let config = FetchConfig::new()
            .with_default_quant(QuantType::Q8_0)
            .with_auto_pull(false)
            .with_eviction_policy(EvictionPolicy::LFU);
        assert_eq!(config.default_quant, Some(QuantType::Q8_0));
        assert!(!config.auto_pull);
        assert_eq!(config.eviction_policy, EvictionPolicy::LFU);
    }

    #[test]
    fn test_fetch_config_with_cache() {
        let cache_config = CacheConfig::new().with_max_size_gb(100.0);
        let config = FetchConfig::new().with_cache(cache_config.clone());
        assert_eq!(config.cache.max_size_bytes, cache_config.max_size_bytes);
    }

    #[test]
    fn test_fetcher_with_cache_dir() {
        let dir = TempDir::new().unwrap();
        let result = ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default());
        assert!(result.is_ok());
    }

    // The constructor must create the cache directory if it doesn't exist.
    #[test]
    fn test_fetcher_cache_dir_created() {
        let dir = TempDir::new().unwrap();
        let cache_dir = dir.path().join("models");
        let _ = ModelFetcher::with_cache_dir(cache_dir.clone(), FetchConfig::default()).unwrap();
        assert!(cache_dir.exists());
    }

    #[test]
    fn test_fetcher_config_access() {
        let dir = TempDir::new().unwrap();
        let config = FetchConfig::new().with_auto_pull(false);
        let fetcher = ModelFetcher::with_cache_dir(dir.path().to_path_buf(), config).unwrap();
        assert!(!fetcher.config().auto_pull);
    }

    // A fresh fetcher is seeded with the built-in alias set.
    #[test]
    fn test_fetcher_has_default_aliases() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        let aliases = fetcher.aliases();
        assert!(aliases.get("llama3").is_some());
        assert!(aliases.get("mistral").is_some());
    }

    #[test]
    fn test_fetcher_add_alias() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        fetcher
            .add_alias("mymodel", "hf://my-org/my-model")
            .unwrap();
        assert!(fetcher.aliases().get("mymodel").is_some());
    }

    #[test]
    fn test_fetcher_resolve_ref() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        let resolved = fetcher.resolve_ref("llama3");
        assert!(resolved.is_ok());
        let uri = resolved.unwrap().uri;
        assert!(uri.starts_with("hf://"), "Expected hf:// URI, got: {}", uri);
    }

    // A ref that is neither an alias nor a URI must be rejected.
    #[test]
    fn test_fetcher_resolve_ref_not_found() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        let resolved = fetcher.resolve_ref("nonexistent-model-xyz");
        assert!(resolved.is_err());
    }

    #[test]
    fn test_fetcher_is_cached_empty() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        assert!(!fetcher.is_cached("llama3"));
    }

    #[test]
    fn test_fetcher_stats_empty() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        let stats = fetcher.stats();
        assert_eq!(stats.model_count, 0);
        assert_eq!(stats.total_size_bytes, 0);
    }

    #[test]
    fn test_fetcher_list_empty() {
        let dir = TempDir::new().unwrap();
        let fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        assert!(fetcher.list().is_empty());
    }

    #[test]
    fn test_fetcher_clear() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        let freed = fetcher.clear();
        assert_eq!(freed, 0);
    }

    #[test]
    fn test_fetcher_cleanup() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        let freed = fetcher.cleanup();
        assert_eq!(freed, 0);
    }

    // Cache keys must be filesystem-safe: no scheme or path separators.
    #[test]
    fn test_cache_key_generation() {
        let key1 = ModelFetcher::cache_key("hf://meta-llama/Llama-3-8B");
        let key2 = ModelFetcher::cache_key("pacha://model:1.0.0");
        assert!(!key1.contains("://"));
        assert!(!key2.contains("://"));
    }

    #[test]
    fn test_cache_key_unique() {
        let key1 = ModelFetcher::cache_key("hf://org/model1");
        let key2 = ModelFetcher::cache_key("hf://org/model2");
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_fetch_result_size_human() {
        let result = FetchResult {
            path: PathBuf::from("/cache/model.gguf"),
            format: ModelFormat::Unknown,
            size_bytes: 4 * 1024 * 1024 * 1024,
            cache_hit: true,
            reference: "llama3".to_string(),
            resolved_uri: "hf://meta-llama/Llama-3-8B".to_string(),
            hash: "abc123".to_string(),
        };
        assert!(result.size_human().contains("GB"));
    }

    // Non-GGUF formats never report a quantization.
    #[test]
    fn test_fetch_result_not_quantized() {
        let result = FetchResult {
            path: PathBuf::from("/cache/model.safetensors"),
            format: ModelFormat::SafeTensors(Default::default()),
            size_bytes: 1000,
            cache_hit: false,
            reference: "test".to_string(),
            resolved_uri: "test".to_string(),
            hash: "hash".to_string(),
        };
        assert!(!result.is_quantized());
        assert!(result.quant_type().is_none());
    }

    #[test]
    fn test_fetch_result_quantized_gguf() {
        use crate::format::GgufInfo;
        let result = FetchResult {
            path: PathBuf::from("/cache/model.gguf"),
            format: ModelFormat::Gguf(GgufInfo {
                version: 3,
                tensor_count: 100,
                metadata_count: 10,
                quantization: Some("Q4_K_M".to_string()),
                ..Default::default()
            }),
            size_bytes: 4_000_000_000,
            cache_hit: true,
            reference: "llama3:8b-q4_k_m".to_string(),
            resolved_uri: "hf://...".to_string(),
            hash: "hash".to_string(),
        };
        assert!(result.is_quantized());
        assert_eq!(result.quant_type(), Some(QuantType::Q4_K_M));
    }

    #[test]
    fn test_cached_model_size_human() {
        let model = CachedModel {
            name: "llama3".to_string(),
            version: "8b".to_string(),
            size_bytes: 4 * 1024 * 1024 * 1024,
            format: ModelFormat::Unknown,
            path: PathBuf::from("/cache"),
            last_accessed: std::time::SystemTime::now(),
            access_count: 5,
            pinned: false,
        };
        assert!(model.size_human().contains("GB"));
    }

    #[test]
    fn test_cached_model_quant_type() {
        use crate::format::GgufInfo;
        let model = CachedModel {
            name: "llama3".to_string(),
            version: "8b".to_string(),
            size_bytes: 4_000_000_000,
            format: ModelFormat::Gguf(GgufInfo {
                version: 3,
                tensor_count: 100,
                metadata_count: 10,
                quantization: Some("Q8_0".to_string()),
                ..Default::default()
            }),
            path: PathBuf::from("/cache/model.gguf"),
            last_accessed: std::time::SystemTime::now(),
            access_count: 1,
            pinned: true,
        };
        assert_eq!(model.quant_type(), Some(QuantType::Q8_0));
    }

    // Pin/unpin on a model that was never cached must report failure.
    #[test]
    fn test_fetcher_pin_nonexistent() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        assert!(!fetcher.pin("nonexistent"));
    }

    #[test]
    fn test_fetcher_unpin_nonexistent() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        assert!(!fetcher.unpin("nonexistent"));
    }

    // Removing an uncached model is not an error; it reports Ok(false).
    #[test]
    fn test_fetcher_remove_nonexistent() {
        let dir = TempDir::new().unwrap();
        let mut fetcher =
            ModelFetcher::with_cache_dir(dir.path().to_path_buf(), FetchConfig::default()).unwrap();
        let result = fetcher.remove("nonexistent");
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    // Config must survive a JSON round-trip (used by the on-disk manifest).
    #[test]
    fn test_fetch_config_serialization() {
        let config = FetchConfig::new()
            .with_default_quant(QuantType::Q4_K_M)
            .with_auto_pull(false);
        let json = serde_json::to_string(&config).unwrap();
        let parsed: FetchConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.default_quant, config.default_quant);
        assert_eq!(parsed.auto_pull, config.auto_pull);
    }
}