/// Removes a cached model from the local model cache.
///
/// Prints a banner and the model reference, then asks the fetcher to
/// drop the cache entry. Returns `Ok(())` on successful removal, or
/// `CliError::FileNotFound` when the model was not cached at all.
pub fn remove(model_ref: &str) -> Result<()> {
    println!("{}", "=== APR Remove ===".cyan().bold());
    println!();
    println!("Model: {}", model_ref.cyan());

    let mut fetcher = ModelFetcher::new().map_err(|e| {
        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
    })?;

    match fetcher.remove(model_ref) {
        Err(e) => Err(CliError::ValidationFailed(format!(
            "Failed to remove model: {e}"
        ))),
        Ok(true) => {
            println!("{} Model removed from cache", "✓".green());
            Ok(())
        }
        Ok(false) => {
            // Nothing was cached under this reference — surface as not-found.
            println!("{} Model not found in cache", "⚠".yellow());
            Err(CliError::FileNotFound(std::path::PathBuf::from(model_ref)))
        }
    }
}
/// Resolves a model reference to a local filesystem path.
///
/// An existing local file short-circuits; otherwise the reference is
/// pulled through the model registry, drawing a progress bar on stderr.
/// Fails when the fetcher cannot be constructed or the reference cannot
/// be resolved either locally or via the registry.
#[allow(dead_code)]
pub fn resolve_model_path(model_ref: &str) -> Result<std::path::PathBuf> {
    contract_pre_model_path_resolution!();

    // Fast path: the reference is already a file on disk.
    let candidate = std::path::Path::new(model_ref);
    if candidate.exists() && candidate.is_file() {
        return Ok(candidate.to_path_buf());
    }

    let mut fetcher = ModelFetcher::with_config(FetchConfig::default()).map_err(|e| {
        CliError::ValidationFailed(format!("Failed to initialize model fetcher: {e}"))
    })?;

    // Pull through the registry, rendering a 30-column progress bar.
    let pulled = fetcher
        .pull(model_ref, |progress| {
            if progress.total_bytes > 0 {
                let pct = progress.percent();
                eprint!(
                    "\rPulling model... [{:30}] {:5.1}%",
                    "=".repeat((pct / 3.33) as usize),
                    pct
                );
                io::stderr().flush().ok();
            }
        })
        .map_err(|e| {
            CliError::ValidationFailed(format!(
                "Model '{}' not found. Not a local file and could not resolve via registry: {}",
                model_ref, e
            ))
        })?;

    // Terminate the in-place progress line only when a download happened.
    if !pulled.cache_hit {
        eprintln!();
    }
    contract_post_model_path_resolution!(&());
    Ok(pulled.path)
}
/// Formats a byte count as a human-readable string.
///
/// Thin wrapper delegating to the shared `batuta_common` helper so the
/// whole CLI renders sizes consistently.
fn format_bytes(bytes: u64) -> String {
    batuta_common::fmt::format_bytes(bytes)
}
/// Downloads the tokenizer/config companion files a SafeTensors model
/// needs alongside its weights.
///
/// Each companion is cached as `{model_stem}.{filename}` next to the
/// model so multiple models can share one cache directory. Missing files
/// (404) and gated repos (401) are reported but non-fatal; the function
/// only errors if, after the loop, no tokenizer artifact exists at all —
/// the model is unusable without one.
fn fetch_safetensors_companions(model_path: &Path, resolved_uri: &str) -> Result<()> {
    // Companions can only be fetched for HuggingFace-hosted models.
    let Some(repo_id) = extract_hf_repo(resolved_uri) else {
        return Ok(());
    };
    let model_stem = model_path
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("model");
    let companions = [
        "tokenizer.json",
        "config.json",
        "tokenizer_config.json",
        "tokenizer.model",
    ];
    let cache_dir = model_path
        .parent()
        .ok_or_else(|| CliError::ValidationFailed("Model path has no parent directory".into()))?;
    for filename in &companions {
        // FIX: the cache name must include the companion's own filename.
        // A constant placeholder was used before, so every companion
        // collided on one path and the tokenizer check below (which looks
        // for "{model_stem}.{f}") could never succeed.
        let prefixed_filename = format!("{model_stem}.{filename}");
        let sibling_path = cache_dir.join(&prefixed_filename);
        if sibling_path.exists() {
            println!(
                " {} {} (already exists)",
                "✓".green(),
                prefixed_filename.dimmed()
            );
            continue;
        }
        let url = format!(
            "https://huggingface.co/{}/resolve/main/{}",
            repo_id, filename
        );
        match hf_get(&url).call() {
            Ok(response) => {
                let mut body = Vec::new();
                // FIX: name the file that failed instead of a placeholder.
                response.into_reader().read_to_end(&mut body).map_err(|e| {
                    CliError::NetworkError(format!("Failed to read {filename}: {e}"))
                })?;
                std::fs::write(&sibling_path, &body).map_err(|e| {
                    CliError::ValidationFailed(format!(
                        "Failed to write {}: {e}",
                        sibling_path.display()
                    ))
                })?;
                println!(
                    " {} {} ({})",
                    "✓".green(),
                    prefixed_filename,
                    format_bytes(body.len() as u64).dimmed()
                );
            }
            // Missing companions are common (e.g. no tokenizer.model) — warn only.
            Err(ureq::Error::Status(404, _)) => {
                println!(
                    " {} {} (not found in repo)",
                    "⚠".yellow(),
                    prefixed_filename.dimmed()
                );
            }
            // Gated repos answer 401 without a token; point at the remedy.
            Err(ureq::Error::Status(401, _)) => {
                eprintln!(
                    " {} {} (access denied — set HF_TOKEN for gated models)",
                    "⚠".yellow(),
                    prefixed_filename,
                );
            }
            Err(e) => {
                eprintln!(
                    " {} Failed to download {}: {}",
                    "⚠".yellow(),
                    prefixed_filename,
                    e
                );
            }
        }
    }
    // A model is unusable without some tokenizer artifact; fail loudly.
    let tokenizer_prefixes = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"];
    let has_tokenizer = tokenizer_prefixes
        .iter()
        .any(|f| cache_dir.join(format!("{model_stem}.{f}")).exists());
    if !has_tokenizer {
        return Err(CliError::ValidationFailed(format!(
            "No tokenizer found for this model. Tried: {}.\n\
             The model may require a custom tokenizer not hosted in the repository.",
            tokenizer_prefixes.join(", ")
        )));
    }
    Ok(())
}
/// Prints guidance for converting a SafeTensors model to APR/GGUF.
///
/// Purely informational: checks which sibling formats already exist next
/// to `safetensors_path` and suggests `apr convert` commands for any
/// missing ones. Always returns `Ok(())`.
fn convert_safetensors_formats(safetensors_path: &Path) -> Result<()> {
    let apr_path = safetensors_path.with_extension("apr");
    let gguf_path = safetensors_path.with_extension("gguf");
    let apr_exists = apr_path.exists();
    let gguf_exists = gguf_path.exists();

    if apr_exists && gguf_exists {
        // Both formats already present — nothing to suggest.
        println!();
        println!(" {} APR and GGUF formats available", "✓".green());
        return Ok(());
    }

    println!();
    println!(" {} To convert formats, run:", "ℹ".cyan());
    if !apr_exists {
        println!(" apr convert {} --format apr", safetensors_path.display());
    }
    if !gguf_exists {
        println!(" apr convert {} --format gguf", safetensors_path.display());
    }
    Ok(())
}
/// Extracts the `org/repo` portion of an `hf://` URI.
///
/// Returns `None` when the URI lacks the `hf://` scheme or names fewer
/// than two path segments.
fn extract_hf_repo(uri: &str) -> Option<String> {
    let rest = uri.strip_prefix("hf://")?;
    let mut segments = rest.split('/');
    match (segments.next(), segments.next()) {
        (Some(org), Some(repo)) => Some(format!("{org}/{repo}")),
        _ => None,
    }
}
/// Normalizes a bare `org/repo` shorthand into an `hf://org/repo` URI.
///
/// Anything that already carries a scheme, or that looks like an
/// absolute or relative filesystem path, is returned unchanged.
fn normalize_hf_uri(uri: &str) -> String {
    let bare_reference =
        !uri.contains("://") && !uri.starts_with('/') && !uri.starts_with('.');
    if bare_reference {
        let mut segments = uri.split('/');
        if let (Some(org), Some(repo)) = (segments.next(), segments.next()) {
            if !org.is_empty() && !repo.is_empty() {
                return format!("hf://{uri}");
            }
        }
    }
    uri.to_string()
}
/// Picks the most desirable GGUF file from a non-empty candidate list.
///
/// Quantizations are preferred in the order q4_k_m > q4_k_s > q4_0 >
/// q8_0 (case-insensitive substring match); when none match, the first
/// listed file is used. Callers must pass a non-empty slice — indexing
/// panics otherwise.
fn select_best_gguf(gguf_files: &[&str], org: &str, repo: &str) -> ResolvedModel {
    const QUANT_PRIORITY: [&str; 4] = ["q4_k_m", "q4_k_s", "q4_0", "q8_0"];
    let chosen = QUANT_PRIORITY
        .iter()
        .find_map(|quant| {
            gguf_files
                .iter()
                .copied()
                .find(|name| name.to_lowercase().contains(*quant))
        })
        .unwrap_or(gguf_files[0]);
    ResolvedModel::SingleFile(format!("hf://{org}/{repo}/{chosen}"))
}
/// Resolves a sharded SafeTensors model via its index file.
///
/// Downloads `model.safetensors.index.json` from the repository and
/// extracts the shard filenames it references; errors if the index is
/// unreachable, unreadable, or lists no shards.
fn resolve_sharded_safetensors(org: &str, repo: &str) -> Result<ResolvedModel> {
    let index_url =
        format!("https://huggingface.co/{org}/{repo}/resolve/main/model.safetensors.index.json");
    let response = hf_get(&index_url)
        .call()
        .map_err(|e| CliError::NetworkError(format!("Failed to download model index: {e}")))?;

    let mut raw_index = Vec::new();
    response
        .into_reader()
        .read_to_end(&mut raw_index)
        .map_err(|e| CliError::NetworkError(format!("Failed to read model index: {e}")))?;

    // Lossy decode is acceptable: the index is JSON, and any mangled
    // bytes would simply yield no shard entries below.
    let shard_files = extract_shard_files_from_index(&String::from_utf8_lossy(&raw_index));
    if shard_files.is_empty() {
        return Err(CliError::ValidationFailed(format!(
            "Sharded model index for {org}/{repo} contains no shard files"
        )));
    }
    Ok(ResolvedModel::Sharded {
        org: org.to_string(),
        repo: repo.to_string(),
        shard_files,
    })
}
/// Finds a single-file SafeTensors model in a repository file listing.
///
/// Prefers the canonical `model.safetensors` name; otherwise falls back
/// to the first file with a `.safetensors` extension. Returns `None`
/// when the listing has no SafeTensors file at all.
fn find_safetensors_file(filenames: &[&str], org: &str, repo: &str) -> Option<ResolvedModel> {
    let has_canonical = filenames
        .iter()
        .any(|f| f.to_lowercase() == "model.safetensors");
    if has_canonical {
        return Some(ResolvedModel::SingleFile(format!(
            "hf://{org}/{repo}/model.safetensors"
        )));
    }
    let fallback = filenames
        .iter()
        .find(|f| f.to_lowercase().ends_with(".safetensors"))?;
    Some(ResolvedModel::SingleFile(format!(
        "hf://{org}/{repo}/{fallback}"
    )))
}
/// Reports whether a URI already names a concrete model file, judged by
/// its extension (gguf, safetensors, apr, pt; case-insensitive).
fn has_known_model_extension(uri: &str) -> bool {
    const KNOWN: [&str; 4] = ["gguf", "safetensors", "apr", "pt"];
    match std::path::Path::new(uri).extension() {
        Some(ext) => KNOWN.iter().any(|known| ext.eq_ignore_ascii_case(known)),
        None => false,
    }
}
fn resolve_hf_model(uri: &str) -> Result<ResolvedModel> {
let uri = normalize_hf_uri(uri);
let uri = uri.as_str();
if !uri.starts_with("hf://") {
return Ok(ResolvedModel::SingleFile(uri.to_string()));
}
if has_known_model_extension(uri) {
return Ok(ResolvedModel::SingleFile(uri.to_string()));
}
let path = uri.strip_prefix("hf://").unwrap_or(uri);
let parts: Vec<&str> = path.split('/').collect();
if parts.len() < 2 {
return Err(CliError::ValidationFailed(format!(
"Invalid HuggingFace URI: {uri}. Expected hf://org/repo or hf://org/repo/file.gguf"
)));
}
let org = parts[0];
let repo = parts[1];
let api_url = format!("https://huggingface.co/api/models/{org}/{repo}");
let response = hf_get(&api_url).call().map_err(|e| match &e {
ureq::Error::Status(401, _) => {
CliError::NetworkError(format_gated_model_error(&api_url))
}
_ => CliError::NetworkError(format!("Failed to query HuggingFace API: {e}")),
})?;
let body: serde_json::Value = {
let text = response.into_string().map_err(|e| {
CliError::ValidationFailed(format!("Failed to read HuggingFace response: {e}"))
})?;
serde_json::from_str(&text).map_err(|e| {
CliError::ValidationFailed(format!("Failed to parse HuggingFace response: {e}"))
})?
};
let siblings = body["siblings"]
.as_array()
.ok_or_else(|| CliError::ValidationFailed("No files found in repository".to_string()))?;
let filenames: Vec<&str> = siblings
.iter()
.filter_map(|s| s["rfilename"].as_str())
.collect();
let gguf_files: Vec<&str> = filenames
.iter()
.copied()
.filter(|f| f.to_lowercase().ends_with(".gguf"))
.collect();
if !gguf_files.is_empty() {
return Ok(select_best_gguf(&gguf_files, org, repo));
}
if filenames.contains(&"model.safetensors.index.json") {
return resolve_sharded_safetensors(org, repo);
}
if let Some(model) = find_safetensors_file(&filenames, org, repo) {
return Ok(model);
}
resolve_hf_model_fallback(&filenames, org, repo)
}
/// Produces the terminal error when a repository has neither GGUF nor
/// SafeTensors weights.
///
/// Repositories that ship only PyTorch `.bin` weights get a conversion
/// recipe; everything else gets a plain not-found message. Always
/// returns `Err`.
fn resolve_hf_model_fallback(filenames: &[&str], org: &str, repo: &str) -> Result<ResolvedModel> {
    let pytorch_only = filenames
        .iter()
        .any(|name| name.to_lowercase().ends_with(".bin"));
    if pytorch_only {
        // .bin weights exist but cannot be loaded — show how to convert.
        return Err(CliError::ValidationFailed(format!(
            "{org}/{repo} only has PyTorch .bin weights (no SafeTensors or GGUF).\n\
             Convert first with:\n \
             python -c \"from transformers import AutoModelForCausalLM; \
             m = AutoModelForCausalLM.from_pretrained('{org}/{repo}'); \
             m.save_pretrained('{repo}-st', safe_serialization=True)\"\n\
             Or request SafeTensors on the model page."
        )));
    }
    Err(CliError::ValidationFailed(format!(
        "No .gguf or .safetensors files found in {org}/{repo}"
    )))
}
/// Returns the text between the first `{` and its matching `}`,
/// honoring nested braces.
///
/// Returns `None` when there is no `{` or the brace is never closed.
fn find_brace_content(text: &str) -> Option<&str> {
    let open = text.find('{')?;
    let body = &text[open + 1..];
    // Depth of *extra* nesting beyond the opening brace we started at.
    let mut nesting = 0usize;
    for (idx, ch) in body.char_indices() {
        if ch == '{' {
            nesting += 1;
        } else if ch == '}' {
            if nesting == 0 {
                // Matched the opening brace: everything before it is the content.
                return Some(&body[..idx]);
            }
            nesting -= 1;
        }
    }
    None
}