/// Download every shard listed in a sharded-model index file (e.g.
/// `model.safetensors.index.json`), skipping shards already present in
/// `cache_dir`, and return the path to the index itself.
fn download_sharded_model(cache_dir: &Path, index_path: &Path, base_url: &str) -> Result<PathBuf> {
let index_content = std::fs::read_to_string(index_path)
.map_err(|e| CliError::ValidationFailed(format!("Failed to read index file: {e}")))?;
let shard_files: HashSet<String> = extract_shard_files(&index_content);
if shard_files.is_empty() {
return Err(CliError::ValidationFailed(
"Sharded model index contains no shard files".to_string(),
));
}
let total_shards = shard_files.len();
eprintln!(" Found {} shard files to download", total_shards);
    // Sort so download order and the [i/total] progress output are stable
    // (HashSet iteration order is unspecified).
    let mut shard_list: Vec<&String> = shard_files.iter().collect();
    shard_list.sort();
    for (i, shard_file) in shard_list.into_iter().enumerate() {
let shard_url = format!("{base_url}/{shard_file}");
let shard_path = cache_dir.join(shard_file);
        // An existing file is assumed to be a complete download; see the
        // atomic-write sketch after `download_file` below.
        if shard_path.exists() {
eprintln!(" [{}/{}] {} (cached)", i + 1, total_shards, shard_file);
continue;
}
eprintln!(
" [{}/{}] Downloading {}...",
i + 1,
total_shards,
shard_file
);
download_file(&shard_url, &shard_path)?;
}
Ok(index_path.to_path_buf())
}
/// Return the text between the first `{` in `text` and its matching `}`
/// (exclusive), tracking nested braces; `None` if there is no `{` or the
/// braces are unbalanced. Braces inside JSON string values are not
/// special-cased, which is good enough for the index files handled here.
fn find_brace_content(text: &str) -> Option<&str> {
let start = text.find('{')?;
let content = &text[start + 1..];
let mut depth = 1usize;
for (i, c) in content.char_indices() {
match c {
'{' => depth += 1,
'}' if depth == 1 => return Some(&content[..i]),
'}' => depth -= 1,
_ => {}
}
}
None
}
/// Parse a single `"key": "value"` pair from the weight map and return the
/// value if it names a `.safetensors` shard file.
fn extract_shard_filename(kv_pair: &str) -> Option<String> {
let colon_pos = kv_pair.rfind(':')?;
let value = kv_pair[colon_pos + 1..].trim();
let filename = value.trim_matches(|c: char| c == '"' || c.is_whitespace());
    if filename.ends_with(".safetensors") {
Some(filename.to_string())
} else {
None
}
}
/// Collect the set of shard filenames referenced by the `weight_map` section
/// of a sharded-model index, using lightweight string scanning instead of a
/// full JSON parser.
fn extract_shard_files(json: &str) -> HashSet<String> {
let Some(weight_map_start) = json.find("\"weight_map\"") else {
return HashSet::new();
};
let Some(entries) = find_brace_content(&json[weight_map_start..]) else {
return HashSet::new();
};
entries
.split(',')
.filter_map(extract_shard_filename)
.collect()
}
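// A minimal sketch of unit tests for the hand-rolled index parsing above. The
// sample JSON mirrors the `weight_map` layout of a sharded safetensors index;
// the literals are illustrative, not taken from a real model.
#[cfg(test)]
mod shard_index_tests {
    use super::*;

    #[test]
    fn brace_content_handles_nesting() {
        assert_eq!(find_brace_content("x {a {b} c} y"), Some("a {b} c"));
        assert_eq!(find_brace_content("no braces here"), None);
    }

    #[test]
    fn shard_filename_only_accepts_safetensors_values() {
        assert_eq!(
            extract_shard_filename(r#""h.0.attn.weight": "model-00001-of-00002.safetensors""#),
            Some("model-00001-of-00002.safetensors".to_string())
        );
        assert_eq!(extract_shard_filename(r#""total_size": 123"#), None);
    }

    #[test]
    fn weight_map_values_are_deduplicated() {
        let index = r#"{
            "metadata": {"total_size": 2},
            "weight_map": {
                "a.weight": "model-00001-of-00002.safetensors",
                "b.weight": "model-00001-of-00002.safetensors",
                "c.weight": "model-00002-of-00002.safetensors"
            }
        }"#;
        let shards = extract_shard_files(index);
        assert_eq!(shards.len(), 2);
        assert!(shards.contains("model-00002-of-00002.safetensors"));
    }
}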
/// Download a single-file model from an arbitrary URL into
/// `~/.apr/cache/url/<url-hash>/<filename>` and return the cached path.
fn download_url_model(url: &str) -> Result<PathBuf> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
url.hash(&mut hasher);
let url_hash = format!("{:016x}", hasher.finish());
let filename = url
.rsplit('/')
.next()
.filter(|s| !s.is_empty() && s.contains('.'))
.unwrap_or("model.safetensors");
let cache_dir = dirs::home_dir()
.ok_or_else(|| CliError::ValidationFailed("Cannot find home directory".to_string()))?
.join(".apr")
.join("cache")
.join("url")
.join(&url_hash);
std::fs::create_dir_all(&cache_dir)?;
let model_path = cache_dir.join(filename);
eprintln!(" Downloading {}...", filename);
download_file(url, &model_path)?;
eprintln!("{}", " Download complete!".green());
Ok(model_path)
}
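// A small sketch checking the cache-key derivation used above: per the std
// docs, `DefaultHasher::new()` instances produce the same hash for the same
// input (within a given Rust release), and the `{:016x}` formatting always
// yields 16 hex digits, so each URL maps to a single cache directory name.
#[cfg(test)]
mod url_cache_key_tests {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    fn cache_key(url: &str) -> String {
        let mut hasher = DefaultHasher::new();
        url.hash(&mut hasher);
        format!("{:016x}", hasher.finish())
    }

    #[test]
    fn same_url_same_directory() {
        let a = cache_key("https://example.com/model.safetensors");
        let b = cache_key("https://example.com/model.safetensors");
        assert_eq!(a, b);
        assert_eq!(a.len(), 16);
    }
}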
/// Stream the response body for `url` into `path`, failing on transport
/// errors or any non-200 status.
fn download_file(url: &str, path: &Path) -> Result<()> {
use std::io::Write;
let response = ureq::get(url)
.call()
.map_err(|e| CliError::ValidationFailed(format!("Download failed: {e}")))?;
if response.status() != 200 {
return Err(CliError::ValidationFailed(format!(
"Download failed with status {}: {}",
response.status(),
url
)));
}
let mut file = std::fs::File::create(path)?;
let mut reader = response.into_reader();
std::io::copy(&mut reader, &mut file)?;
Ok(())
}
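// The caching check in `download_sharded_model` treats any existing file as a
// complete download, so a shard left behind by an interrupted run would be
// skipped as "(cached)". A minimal sketch of one way to harden this, reusing
// the same `ureq` calls as `download_file`: stream into a `.part` file and
// rename it into place only after the copy succeeds. `download_file_atomic`
// is a hypothetical helper, not wired into the CLI; the explicit status check
// from `download_file` is omitted here for brevity.
#[allow(dead_code)]
fn download_file_atomic(url: &str, path: &Path) -> Result<()> {
    let response = ureq::get(url)
        .call()
        .map_err(|e| CliError::ValidationFailed(format!("Download failed: {e}")))?;
    // Write to a sibling temp file first so `path` only ever holds complete data.
    let tmp_path = path.with_extension("part");
    let mut file = std::fs::File::create(&tmp_path)?;
    std::io::copy(&mut response.into_reader(), &mut file)?;
    std::fs::rename(&tmp_path, path)?;
    Ok(())
}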
/// Find a model file in `dir`, preferring `.apr`, then `.safetensors`, then
/// `.gguf`; if none is found, fall back to returning the directory itself.
#[allow(clippy::unnecessary_wraps)]
fn find_model_in_dir(dir: &Path) -> Result<PathBuf> {
for ext in &["apr", "safetensors", "gguf"] {
let pattern = dir.join(format!("*.{ext}"));
if let Some(path) = glob_first(&pattern) {
return Ok(path);
}
}
Ok(dir.to_path_buf())
}
/// Return the first filesystem entry matching a glob pattern, if any.
fn glob_first(pattern: &Path) -> Option<PathBuf> {
glob::glob(pattern.to_str()?).ok()?.next()?.ok()
}
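// A sketch of how the discovery above behaves, using only std for setup: drop
// a dummy `.safetensors` file into a scratch directory and check that
// `find_model_in_dir` picks it up, while an empty directory falls back to the
// directory path itself. The scratch path under `std::env::temp_dir()` is an
// assumption of this test, not something the CLI creates.
#[cfg(test)]
mod model_discovery_tests {
    use super::*;

    #[test]
    fn finds_safetensors_or_falls_back_to_dir() {
        let dir = std::env::temp_dir().join(format!("apr-model-discovery-{}", std::process::id()));
        std::fs::create_dir_all(&dir).unwrap();
        // No model files yet: the directory itself is returned.
        assert_eq!(find_model_in_dir(&dir).unwrap(), dir);

        let model = dir.join("weights.safetensors");
        std::fs::write(&model, b"").unwrap();
        assert_eq!(find_model_in_dir(&dir).unwrap(), model);

        std::fs::remove_dir_all(&dir).unwrap();
    }
}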
/// Output of a single inference run; the metric fields are `None` when the
/// backing engine does not report them.
struct InferenceOutput {
text: String,
tokens_generated: Option<usize>,
inference_ms: Option<f64>,
tok_per_sec: Option<f64>,
used_gpu: Option<bool>,
generated_tokens: Option<Vec<u32>>,
}
/// Run inference on `model_path`: audio inputs are routed to the whisper
/// backend (when built with `--features whisper`), everything else to the
/// realizar backend (when built with `--features inference`); otherwise a
/// stub message is returned.
fn execute_inference(
model_path: &Path,
input_path: Option<&PathBuf>,
options: &RunOptions,
) -> Result<InferenceOutput> {
let metadata = std::fs::metadata(model_path)?;
    // Models over 50 MB are flagged as mmap candidates; the flag currently
    // only drives the verbose log below (`execute_with_realizar` ignores it).
    let use_mmap = metadata.len() > 50 * 1024 * 1024;
if use_mmap && options.verbose {
eprintln!(
"{}",
format!("Using mmap for {}MB model", metadata.len() / 1024 / 1024).dimmed()
);
}
#[cfg(feature = "whisper")]
if let Some(input) = input_path {
let ext = input.extension().and_then(|e| e.to_str()).unwrap_or("");
let is_audio = matches!(ext, "wav" | "mp3" | "flac" | "ogg" | "m4a");
if is_audio {
return execute_with_whisper(model_path, input, options);
}
}
#[cfg(feature = "inference")]
{
return execute_with_realizar(model_path, input_path, options, use_mmap);
}
#[cfg(not(feature = "inference"))]
{
let input_desc =
input_path.map_or_else(|| "stdin".to_string(), |p| p.display().to_string());
Ok(InferenceOutput {
text: format!(
"[Inference requires --features inference]\nModel: {}\nInput: {}\nFormat: {}\nGPU: {}",
model_path.display(),
input_desc,
options.output_format,
if options.no_gpu { "disabled" } else { "auto" }
),
tokens_generated: None,
inference_ms: None,
tok_per_sec: None,
used_gpu: None,
generated_tokens: None,
})
}
}
#[cfg(feature = "inference")]
fn execute_with_realizar(
model_path: &Path,
input_path: Option<&PathBuf>,
options: &RunOptions,
_use_mmap: bool,
) -> Result<InferenceOutput> {
use realizar::{run_inference, InferenceConfig};
    // Prompt precedence: explicit prompt option, then the input file, then none.
    let prompt = if let Some(ref p) = options.prompt {
Some(p.clone())
} else if let Some(path) = input_path {
Some(std::fs::read_to_string(path)?)
} else {
None
};
let mut config = InferenceConfig::new(model_path);
if let Some(ref p) = prompt {
config = config.with_prompt(p);
}
config = config
.with_max_tokens(options.max_tokens)
.with_verbose(options.verbose);
if options.no_gpu {
config = config.without_gpu();
}
if options.trace {
config = config.with_trace(true);
}
if let Some(ref trace_path) = options.trace_output {
config = config.with_trace_output(trace_path);
}
let result = run_inference(&config)
.map_err(|e| CliError::InferenceFailed(format!("Inference failed: {e}")))?;
if options.benchmark {
eprintln!(
"{}",
format!(
"Generated {} tokens in {:.1}ms ({:.1} tok/s)",
result.generated_token_count, result.inference_ms, result.tok_per_sec
)
.green()
);
}
    // Everything after the prompt/input tokens is what the model generated.
    let generated_tokens = if result.tokens.len() > result.input_token_count {
Some(result.tokens[result.input_token_count..].to_vec())
} else {
Some(Vec::new())
};
Ok(InferenceOutput {
text: result.text,
tokens_generated: Some(result.generated_token_count),
inference_ms: Some(result.inference_ms),
tok_per_sec: Some(result.tok_per_sec),
used_gpu: Some(result.used_gpu),
generated_tokens,
})
}
#[cfg(feature = "whisper")]
fn execute_with_whisper(
model_path: &Path,
audio_path: &Path,
options: &RunOptions,
) -> Result<InferenceOutput> {
use whisper_apr::audio::decode::load_audio_file;
let start = std::time::Instant::now();
if options.verbose {
eprintln!("[WHISPER] Loading model: {}", model_path.display());
eprintln!("[WHISPER] Audio input: {}", audio_path.display());
}
let audio = load_audio_file(audio_path)
.map_err(|e| CliError::InferenceFailed(format!("Audio load failed: {e}")))?;
if options.verbose {
eprintln!("[WHISPER] Audio: {} samples ({:.1}s at 16kHz)", audio.len(), audio.len() as f64 / 16000.0);
}
let model_data = std::fs::read(model_path)?;
let whisper = whisper_apr::WhisperApr::load_from_apr(&model_data)
.map_err(|e| CliError::InferenceFailed(format!("Whisper model load failed: {e}")))?;
if options.verbose {
eprintln!("[WHISPER] Model loaded, transcribing...");
}
let result = whisper.transcribe(&audio, Default::default())
.map_err(|e| CliError::InferenceFailed(format!("Transcription failed: {e}")))?;
let duration = start.elapsed();
let word_count = result.text.split_whitespace().count();
Ok(InferenceOutput {
text: result.text,
tokens_generated: Some(word_count),
inference_ms: Some(duration.as_secs_f64() * 1000.0),
tok_per_sec: Some(word_count as f64 / duration.as_secs_f64()),
used_gpu: Some(false),
generated_tokens: None,
})
}