#![allow(dead_code)]
#![allow(unused_imports)]
#![allow(unused_variables)]
#![allow(clippy::needless_return)]
#![allow(clippy::format_push_string)]
#![allow(clippy::map_unwrap_or)]
#![allow(clippy::disallowed_methods)]
use crate::error::{CliError, Result};
use colored::Colorize;
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::time::Instant;
/// Where a model should be loaded from, as produced by `ModelSource::parse`.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum ModelSource {
/// A filesystem path. `parse` yields this for any string that is neither `hf://` nor
/// `http://`/`https://`; `resolve_model` returns it unchanged.
Local(PathBuf),
/// An `hf://org/repo[/file]` reference.
HuggingFace {
/// First `/`-separated segment after `hf://`.
org: String,
/// Second `/`-separated segment after `hf://`.
repo: String,
/// File path inside the repo. `parse` sets it only when the third segment contains a
/// `.`, and then joins all remaining segments with `/`.
file: Option<String>,
},
/// An `http://` or `https://` URL, stored verbatim.
Url(String),
}
impl ModelSource {
    /// Classifies a model reference string.
    ///
    /// - `hf://org/repo[/path/to/file]` → [`ModelSource::HuggingFace`]. Everything from the
    ///   third `/`-separated segment onward becomes `file`, but only when that third segment
    ///   contains a `.`; otherwise `file` is `None`.
    /// - `http://…` or `https://…` → [`ModelSource::Url`].
    /// - Anything else → [`ModelSource::Local`] (existence is not checked here).
    ///
    /// # Errors
    ///
    /// [`CliError::InvalidFormat`] when an `hf://` reference lacks a non-empty organisation
    /// and repository segment (e.g. `hf://org`, `hf://org/`, `hf:///repo`).
    pub(crate) fn parse(source: &str) -> Result<Self> {
        if let Some(path) = source.strip_prefix("hf://") {
            let parts: Vec<&str> = path.split('/').collect();
            match parts.as_slice() {
                [org, repo, rest @ ..] if !org.is_empty() && !repo.is_empty() => {
                    // Only a dotted third segment (e.g. `model.gguf`, `v1.0/...`) is read
                    // as a file path; anything else is ignored.
                    let file = rest
                        .first()
                        .filter(|first| first.contains('.'))
                        .map(|_| rest.join("/"));
                    Ok(Self::HuggingFace {
                        org: (*org).to_string(),
                        repo: (*repo).to_string(),
                        file,
                    })
                }
                _ => Err(CliError::InvalidFormat(format!(
                    "Invalid HuggingFace source: {source}. Expected hf://org/repo"
                ))),
            }
        } else if source.starts_with("http://") || source.starts_with("https://") {
            Ok(Self::Url(source.to_string()))
        } else {
            Ok(Self::Local(PathBuf::from(source)))
        }
    }

    /// Location associated with this source in the local cache.
    ///
    /// - `Local` → the path itself.
    /// - `HuggingFace` → `<home>/.apr/cache/hf/<org>/<repo>`.
    /// - `Url` → `<home>/.apr/cache/urls/<16 zero-padded hex digits of md5_hash(url)>`.
    ///
    /// `<home>` falls back to `.` when the home directory cannot be determined.
    pub(crate) fn cache_path(&self) -> PathBuf {
        let cache_root = || {
            dirs::home_dir()
                .unwrap_or_else(|| PathBuf::from("."))
                .join(".apr")
                .join("cache")
        };
        match self {
            Self::Local(path) => path.clone(),
            Self::HuggingFace { org, repo, .. } => cache_root().join("hf").join(org).join(repo),
            Self::Url(url) => {
                // `{:016x}` always yields exactly 16 digits. Plain `{:x}` drops leading
                // zeros, so the previous `&hash[..16]` slice panicked for ~1 in 16 URLs.
                // Hashes without a leading zero nibble render identically, so existing
                // cache directory names are unchanged.
                let key = format!("{:016x}", md5_hash(url.as_bytes()));
                cache_root().join("urls").join(key)
            }
        }
    }
}
/// 64-bit FNV-1a hash of `data`.
///
/// Despite the name this is not MD5. It starts from the FNV offset basis and, for each
/// byte, XORs the byte in and then multiplies by the FNV prime (wrapping). In this section
/// it is only used to derive URL cache-directory names, not for anything security-related.
fn md5_hash(data: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    data.iter().fold(FNV_OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Options for a single `run_model` invocation.
///
/// `run_model` reads `input`, `force`, and `offline` directly and passes the whole struct
/// to `execute_inference`. That function is defined outside this section and presumably
/// consumes the remaining fields. NOTE(review): confirm.
/// The defaults listed below come from the `Default` impl.
#[derive(Debug, Clone)]
pub(crate) struct RunOptions {
/// Optional input file, forwarded to `execute_inference` by `run_model`. Default: `None`.
pub input: Option<PathBuf>,
/// Default: `None`.
pub prompt: Option<String>,
/// Default: `32`.
pub max_tokens: usize,
/// Default: `"text"`.
pub output_format: String,
/// When `true`, `resolve_model` skips the local-cache lookup for remote sources and
/// downloads again. Default: `false`.
pub force: bool,
/// Default: `false`.
pub no_gpu: bool,
/// When `true`, `resolve_model` returns `CliError::ValidationFailed` instead of
/// downloading an uncached remote model. Default: `false`.
pub offline: bool,
/// Default: `false`.
pub benchmark: bool,
/// Default: `false`.
pub verbose: bool,
/// Default: `false`.
pub trace: bool,
/// Default: `None`.
pub trace_steps: Option<Vec<String>>,
/// Default: `false`.
pub trace_verbose: bool,
/// Default: `None`.
pub trace_output: Option<PathBuf>,
/// Default: `"basic"`.
pub trace_level: String,
/// Default: `false`.
pub profile: bool,
/// Default: `0.0`.
pub temperature: f32,
/// Default: `1`.
pub top_k: usize,
/// Default: `None`.
pub top_p: Option<f32>,
/// Default: `299_792_458`.
pub seed: u64,
/// Default: `1.0`.
pub repeat_penalty: f32,
/// Default: `64`.
pub repeat_last_n: usize,
/// Default: `false`.
pub split_prompt: bool,
}
impl Default for RunOptions {
    /// Baseline options for `run_model`: no input or prompt, 32 max tokens, text output,
    /// every boolean flag off, and the sampling values listed at the end.
    fn default() -> Self {
        Self {
            // Inputs and run flags.
            input: None,
            prompt: None,
            output_format: String::from("text"),
            max_tokens: 32,
            force: false,
            no_gpu: false,
            offline: false,
            benchmark: false,
            verbose: false,
            profile: false,
            split_prompt: false,
            // trace_* fields.
            trace: false,
            trace_steps: None,
            trace_verbose: false,
            trace_output: None,
            trace_level: String::from("basic"),
            // Sampling-related fields. temperature 0.0 with top_k 1 presumably means
            // deterministic decoding; confirm against the sampler.
            temperature: 0.0,
            top_k: 1,
            top_p: None,
            seed: 299_792_458,
            repeat_penalty: 1.0,
            repeat_last_n: 64,
        }
    }
}
/// Outcome of `run_model`.
#[derive(Debug, Clone)]
pub(crate) struct RunResult {
/// Text produced by `execute_inference`.
pub text: String,
/// Wall-clock seconds from the start of `run_model` through inference. This includes
/// source resolution and any download.
pub duration_secs: f64,
/// Whether the model was served from local storage; set by `run_model` (see there).
pub cached: bool,
/// Token count reported by inference. When inference reports none, this is the number of
/// whitespace-separated words in `text`, so it is always `Some` after `run_model`.
pub tokens_generated: Option<usize>,
/// Passed through unchanged from the inference output.
pub tok_per_sec: Option<f64>,
/// Passed through unchanged from the inference output.
pub used_gpu: Option<bool>,
/// Passed through unchanged from the inference output.
pub generated_tokens: Option<Vec<u32>>,
}
#[provable_contracts_macros::contract("apr-cli-operations-v1", equation = "long_running_graceful")]
pub(crate) fn run_model(source: &str, options: &RunOptions) -> Result<RunResult> {
let start = Instant::now();
let resolved_source =
crate::commands::aliases::resolve_short_name(source).unwrap_or_else(|| source.to_string());
let hf_uri = if !resolved_source.contains("://")
&& resolved_source.contains('/')
&& !std::path::Path::new(&resolved_source).exists()
&& !resolved_source.starts_with('/')
{
format!("hf://{resolved_source}")
} else {
resolved_source
};
let fully_resolved_source = match crate::commands::pull::resolve_hf_model(&hf_uri) {
Ok(crate::commands::pull::ResolvedModel::SingleFile(uri)) => uri,
_ => hf_uri,
};
let model_source = ModelSource::parse(&fully_resolved_source)?;
let model_path = resolve_model(&model_source, options.force, options.offline)?;
if !model_path.exists() {
return Err(CliError::FileNotFound(model_path));
}
let input_path = options.input.as_ref();
let output = execute_inference(&model_path, input_path, options)?;
let duration = start.elapsed();
let tokens_generated = output
.tokens_generated
.or_else(|| Some(output.text.split_whitespace().count()));
Ok(RunResult {
text: output.text,
duration_secs: duration.as_secs_f64(),
cached: matches!(model_source, ModelSource::Local(_)) || model_source.cache_path().exists(),
tokens_generated,
tok_per_sec: output.tok_per_sec,
used_gpu: output.used_gpu,
generated_tokens: output.generated_tokens,
})
}
/// Turns a parsed [`ModelSource`] into a local filesystem path, downloading if allowed.
///
/// - `Local` paths are returned unchanged (existence is not checked here).
/// - `HuggingFace` sources are looked up with `find_cached_model` unless `force` is set.
///   On a miss the model is downloaded, or, when `offline` is set, an error is returned.
/// - `Url` sources are served from `cache_path()` when it exists and `force` is unset.
///   Otherwise they are downloaded, or rejected when `offline` is set.
///
/// # Errors
///
/// `CliError::ValidationFailed` in offline mode when nothing usable is cached, plus any
/// error from the cache lookup or download helpers.
pub(crate) fn resolve_model(source: &ModelSource, force: bool, offline: bool) -> Result<PathBuf> {
    match source {
        ModelSource::Local(local) => Ok(local.clone()),
        ModelSource::HuggingFace { org, repo, file } => {
            let cache_hit = if force {
                None
            } else {
                find_cached_model(org, repo, file.as_deref())
            };
            if let Some(hit) = cache_hit {
                return Ok(hit);
            }
            if offline {
                return Err(CliError::ValidationFailed(format!(
                    "OFFLINE MODE: Model hf://{org}/{repo} not cached. \
                     Network access is disabled. Cache the model first with: \
                     apr import hf://{org}/{repo}"
                )));
            }
            eprintln!("{}", format!("Downloading hf://{org}/{repo}...").yellow());
            download_hf_model(org, repo, file.as_deref())
        }
        ModelSource::Url(url) => {
            let url_cache = source.cache_path();
            let serve_from_cache = !force && url_cache.exists();
            match (serve_from_cache, offline) {
                (true, _) => find_model_in_dir(&url_cache),
                (false, true) => Err(CliError::ValidationFailed(format!(
                    "OFFLINE MODE: URL {url} not cached. \
                     Network access is disabled. Download and cache the model first."
                ))),
                (false, false) => {
                    eprintln!("{}", format!("Downloading {url}...").yellow());
                    download_url_model(url)
                }
            }
        }
    }
}
/// Returns the first existing model file in `dir`.
///
/// With an explicit `file`, only `dir/<file>` is considered. Otherwise the candidates are
/// tried in priority order: the sharded index, `model.safetensors`, `pytorch_model.bin`,
/// and then `model.apr`.
fn find_model_file_in_dir(dir: &Path, file: Option<&str>) -> Option<PathBuf> {
    const DEFAULT_CANDIDATES: [&str; 4] = [
        "model.safetensors.index.json",
        "model.safetensors",
        "pytorch_model.bin",
        "model.apr",
    ];
    match file {
        Some(name) => Some(dir.join(name)).filter(|candidate| candidate.exists()),
        None => DEFAULT_CANDIDATES
            .iter()
            .map(|name| dir.join(name))
            .find(|candidate| candidate.exists()),
    }
}
/// Looks for a model in the HuggingFace hub cache.
///
/// It searches `<home>/.cache/huggingface/hub/models--<org>--<repo>/snapshots/*` and returns
/// the first match from `find_model_file_in_dir`. Snapshots are visited in `read_dir` order.
/// It returns `None` when the home or snapshots directory is unavailable.
fn find_in_hf_cache(org: &str, repo: &str, file: Option<&str>) -> Option<PathBuf> {
    let repo_folder = format!("models--{org}--{repo}");
    let snapshots_dir = [".cache", "huggingface", "hub", repo_folder.as_str(), "snapshots"]
        .iter()
        .fold(dirs::home_dir()?, |acc, segment| acc.join(segment));
    std::fs::read_dir(&snapshots_dir)
        .ok()?
        .flatten()
        .find_map(|snapshot| find_model_file_in_dir(&snapshot.path(), file))
}
/// Looks for a model in apr's own cache at `<home>/.apr/cache/hf/<org>/<repo>`.
///
/// With an explicit `file`, only that file is considered. Otherwise the candidates are
/// tried in order: the sharded index, `model.apr`, `model.safetensors`, then `model.gguf`.
fn find_in_apr_cache(org: &str, repo: &str, file: Option<&str>) -> Option<PathBuf> {
    let repo_cache = [".apr", "cache", "hf", org, repo]
        .iter()
        .fold(dirs::home_dir()?, |acc, segment| acc.join(segment));
    if !repo_cache.exists() {
        return None;
    }
    match file {
        Some(name) => Some(repo_cache.join(name)).filter(|candidate| candidate.exists()),
        None => std::iter::once("model.safetensors.index.json".to_string())
            .chain(
                ["apr", "safetensors", "gguf"]
                    .iter()
                    .map(|ext| format!("model.{ext}")),
            )
            .map(|name| repo_cache.join(name))
            .find(|candidate| candidate.exists()),
    }
}
/// Searches every known local cache for `hf://<org>/<repo>[/<file>]`.
///
/// Order: the HuggingFace hub cache, then apr's cache, then (only when `file` is given)
/// the pacha cache path built by `pull::build_single_cache_path`.
fn find_cached_model(org: &str, repo: &str, file: Option<&str>) -> Option<PathBuf> {
    find_in_hf_cache(org, repo, file)
        .or_else(|| find_in_apr_cache(org, repo, file))
        .or_else(|| {
            let filename = file?;
            let pacha_dir = crate::commands::pull::get_pacha_cache_dir().ok()?;
            let model_ref = format!("hf://{org}/{repo}/{filename}");
            let (_, pacha_path) =
                crate::commands::pull::build_single_cache_path(&pacha_dir, &model_ref, filename);
            pacha_path.exists().then_some(pacha_path)
        })
}
/// Downloads `hf://<org>/<repo>[/<file>]` into apr's cache and returns the model path.
///
/// With an explicit `file`, only that file is fetched. Otherwise the weights are fetched
/// (sharded or single-file), followed by `config.json` and `tokenizer.json`. Either
/// failure removes the already-downloaded files and errors. `tokenizer_config.json` is
/// fetched last on a best-effort basis.
pub(crate) fn download_hf_model(org: &str, repo: &str, file: Option<&str>) -> Result<PathBuf> {
    let cache_dir = hf_cache_dir(org, repo)?;
    let base_url = format!("https://huggingface.co/{org}/{repo}/resolve/main");
    match file {
        Some(filename) => download_single_hf_file(&cache_dir, &base_url, filename),
        None => {
            let weights_path = download_main_safetensors(&cache_dir, &base_url)?;
            let config_path = download_required_companion(
                &base_url,
                &cache_dir,
                "config.json",
                &[&weights_path],
            )?;
            download_required_companion(
                &base_url,
                &cache_dir,
                "tokenizer.json",
                &[&weights_path, &config_path],
            )?;
            download_optional_companion(&base_url, &cache_dir, "tokenizer_config.json");
            eprintln!("{}", " Download complete!".green());
            Ok(weights_path)
        }
    }
}
/// Ensures `<home>/.apr/cache/hf/<org>/<repo>` exists and returns it.
///
/// # Errors
///
/// `CliError::ValidationFailed` when the home directory is unknown, or the I/O error from
/// creating the directory.
fn hf_cache_dir(org: &str, repo: &str) -> Result<PathBuf> {
    let home = match dirs::home_dir() {
        Some(home) => home,
        None => {
            return Err(CliError::ValidationFailed(
                "Cannot find home directory".to_string(),
            ))
        }
    };
    let repo_cache = [".apr", "cache", "hf", org, repo]
        .iter()
        .fold(home, |acc, segment| acc.join(segment));
    std::fs::create_dir_all(&repo_cache)?;
    Ok(repo_cache)
}
fn download_single_hf_file(cache_dir: &Path, base_url: &str, filename: &str) -> Result<PathBuf> {
let model_url = format!("{base_url}/{filename}");
let model_path = cache_dir.join(filename);
eprintln!(" Downloading {}...", filename);
download_file(&model_url, &model_path)?;
eprintln!("{}", " Download complete!".green());
Ok(model_path)
}
/// Downloads the model weights into `cache_dir`, preferring the sharded layout.
///
/// It first tries `model.safetensors.index.json`. If that succeeds, it delegates to
/// `download_sharded_model`. If that download fails for any reason, it falls back to the
/// single `model.safetensors` file and returns its path.
fn download_main_safetensors(cache_dir: &Path, base_url: &str) -> Result<PathBuf> {
    let index_path = cache_dir.join("model.safetensors.index.json");
    let index_url = format!("{base_url}/model.safetensors.index.json");
    match download_file(&index_url, &index_path) {
        Ok(_) => {
            eprintln!(" Detected sharded model (multi-tensor)");
            download_sharded_model(cache_dir, &index_path, base_url)
        }
        Err(_) => {
            let single_path = cache_dir.join("model.safetensors");
            eprintln!(" Downloading model.safetensors...");
            download_file(&format!("{base_url}/model.safetensors"), &single_path)?;
            Ok(single_path)
        }
    }
}
/// Downloads a file that inference cannot work without (e.g. `config.json`).
///
/// On failure it deletes every path in `cleanup_paths` (best-effort, so a half-populated
/// cache is not mistaken for a complete one) and returns `CliError::ValidationFailed`.
fn download_required_companion(
    base_url: &str,
    cache_dir: &Path,
    filename: &str,
    cleanup_paths: &[&Path],
) -> Result<PathBuf> {
    let target = cache_dir.join(filename);
    let source_url = format!("{base_url}/{filename}");
    eprintln!(" Downloading {filename}...");
    if let Err(e) = download_file(&source_url, &target) {
        cleanup_paths.iter().for_each(|stale| {
            let _ = std::fs::remove_file(stale);
        });
        return Err(CliError::ValidationFailed(format!(
            "{filename} is required for inference but download failed: {e}\n\
             Ensure the HuggingFace repo contains {filename}"
        )));
    }
    Ok(target)
}
/// Best-effort download of a non-essential file; a failure is only reported on stderr.
fn download_optional_companion(base_url: &str, cache_dir: &Path, filename: &str) {
    let destination = cache_dir.join(filename);
    let source_url = format!("{base_url}/{filename}");
    eprintln!(" Downloading {filename}...");
    if let Some(e) = download_file(&source_url, &destination).err() {
        eprintln!(" Note: {filename} not available (optional): {e}");
    }
}
// The rest of this module's implementation is spliced in textually from sibling files.
// Several names used above are not defined in this section, e.g. `execute_inference`,
// `download_file`, `download_sharded_model`, `download_url_model`, and
// `find_model_in_dir`. They presumably live in these files. NOTE(review): confirm.
include!("inference_output.rs");
include!("run_resolve_tokenizer.rs");
include!("safetensors.rs");
include!("gguf_generate_result.rs");
include!("run_entry.rs");
include!("run_07.rs");