#![allow(clippy::doc_markdown)]
#![allow(clippy::missing_docs_in_private_items)]
#![allow(clippy::unnecessary_wraps)]
#![allow(clippy::cast_precision_loss)]
use candle_mi::{
GenerationResult, HookSpec, MIModel, MITokenizer, SUPPORTED_MODEL_TYPES, sample_token,
};
#[cfg(feature = "memory")]
use candle_mi::{MemoryReport, MemorySnapshot};
use std::path::{Path, PathBuf};
use std::time::Instant;
/// Program entry point.
///
/// Delegates to [`run`]; when it fails, the error is written to stderr and
/// the process terminates with exit status 1.
fn main() {
    match run() {
        Ok(()) => {}
        Err(err) => {
            eprintln!("Error: {err}");
            std::process::exit(1);
        }
    }
}
/// Drives the example.
///
/// * With a first command-line argument: treats it as a model id and runs
///   only that model via [`run_single_model`].
/// * Without arguments: runs every supported model returned by
///   [`discover_cached_models`]. A model that fails is reported as skipped
///   and does not abort the remaining ones.
///
/// # Errors
///
/// Propagates the error from [`run_single_model`] in single-model mode.
/// The discovery mode itself always returns `Ok(())`.
fn run() -> candle_mi::Result<()> {
    // Fixed inputs shared by both modes.
    const PROMPT: &str = "The capital of France is";
    const MAX_NEW_TOKENS: usize = 20;

    let cli_args: Vec<String> = std::env::args().collect();
    if let Some(model_id) = cli_args.get(1) {
        return run_single_model(model_id, PROMPT, MAX_NEW_TOKENS);
    }

    let cached = discover_cached_models();
    if cached.is_empty() {
        println!("No cached transformer models found in the HuggingFace Hub cache.");
        println!("Download one first, e.g.:");
        println!(" cargo run --example fast_download -- meta-llama/Llama-3.2-1B");
        return Ok(());
    }

    println!(
        "Found {} supported transformer(s) in HF cache:\n",
        cached.len()
    );
    for (model_id, model_type, snapshot) in &cached {
        println!("=== {model_id} (model_type: {model_type}) ===");
        match run_model(model_id, snapshot, PROMPT, MAX_NEW_TOKENS) {
            Ok(()) => {}
            // A failing model is reported and the loop moves on to the next one.
            Err(e) => println!(" Skipped: {e}\n"),
        }
    }
    Ok(())
}
/// Loads `model_id`, prints its basic dimensions and load time, then
/// generates up to `max_new_tokens` tokens from `prompt` using the tokenizer
/// exposed by the model itself.
///
/// When the `memory` feature is enabled, a memory snapshot is taken before
/// and after loading and the difference is printed.
///
/// # Errors
///
/// Returns an error if the model cannot be loaded, if `model.tokenizer()`
/// yields `None`, or if encoding, generation, or (with the `memory` feature)
/// memory snapshotting fails.
fn run_single_model(model_id: &str, prompt: &str, max_new_tokens: usize) -> candle_mi::Result<()> {
    println!("=== {model_id} ===");

    // The model (and its device) does not exist yet, so the baseline is taken
    // on CUDA device 0 when available, falling back to the CPU.
    #[cfg(feature = "memory")]
    let mem_baseline = {
        let probe_device =
            candle_core::Device::cuda_if_available(0).unwrap_or(candle_core::Device::Cpu);
        MemorySnapshot::now(&probe_device)?
    };

    let load_started = Instant::now();
    let model = MIModel::from_pretrained(model_id)?;
    let load_time = load_started.elapsed();

    let (n_layers, n_heads, hidden) = (
        model.num_layers(),
        model.num_heads(),
        model.hidden_size(),
    );
    let weight_mb = estimate_weight_mb(n_layers, hidden);
    println!(
        " Layers: {n_layers}, heads: {n_heads}, hidden: {hidden}, device: {:?}",
        model.device()
    );
    println!(" Estimated F32 weight size: {weight_mb:.0} MB");
    println!(" Load time: {load_time:.2?}");

    #[cfg(feature = "memory")]
    {
        let mem_loaded = MemorySnapshot::now(model.device())?;
        MemoryReport::new(mem_baseline, mem_loaded).print_before_after("Model load");
    }

    let tokenizer = model.tokenizer().ok_or_else(|| {
        candle_mi::MIError::Tokenizer("model has no embedded tokenizer".into())
    })?;
    let prompt_tokens = tokenizer.encode(prompt)?;

    let generation_started = Instant::now();
    let outcome = generate(&model, tokenizer, &prompt_tokens, max_new_tokens)?;
    let gen_time = generation_started.elapsed();

    println!("\n --- Generation Result ---");
    println!(" Prompt tokens : {}", outcome.prompt_tokens.len());
    println!(" Generated : {}", outcome.generated_tokens.len());
    println!(" Generation time: {gen_time:.2?}");
    println!(" Full text : \"{}\"", outcome.full_text);
    Ok(())
}
/// Resolves the HuggingFace Hub cache directory from the environment.
///
/// Resolution order:
/// 1. `HF_HOME`, if set: returns `$HF_HOME/hub` unconditionally (the path is
///    not checked for existence).
/// 2. `USERPROFILE`, then `HOME`: returns `<home>/.cache/huggingface/hub` for
///    the first of the two whose candidate path is an existing directory.
///
/// Returns `None` when none of the above yields a directory.
fn hf_cache_dir() -> Option<PathBuf> {
    if let Ok(hf_home) = std::env::var("HF_HOME") {
        return Some(PathBuf::from(hf_home).join("hub"));
    }

    // The chain is lazy: `HOME` is only consulted when the `USERPROFILE`
    // candidate is missing or not a directory.
    ["USERPROFILE", "HOME"]
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .map(|home| {
            PathBuf::from(home)
                .join(".cache")
                .join("huggingface")
                .join("hub")
        })
        .find(|candidate| candidate.is_dir())
}
/// Locates a snapshot directory for `model_id` inside the Hub cache rooted at
/// `cache_dir`.
///
/// The repository directory is `models--<id with '/' replaced by "--">`.
/// Selection order:
/// 1. If `<repo>/refs/main` is readable and names a hexadecimal revision whose
///    `snapshots/<revision>` directory exists, that directory is returned.
/// 2. Otherwise the lexicographically smallest directory under `snapshots/`
///    is returned, so the choice is deterministic when several revisions are
///    cached (directory iteration order is platform-dependent).
///
/// Returns `None` when the repository has no snapshot directory at all.
fn find_snapshot(cache_dir: &Path, model_id: &str) -> Option<PathBuf> {
    let repo_dir = cache_dir.join(format!("models--{}", model_id.replace('/', "--")));
    let snapshots = repo_dir.join("snapshots");

    // Prefer the revision the `main` ref points at.
    if let Ok(contents) = std::fs::read_to_string(repo_dir.join("refs").join("main")) {
        let revision = contents.trim();
        // Only plain hex revisions are accepted, so the ref content can never
        // act as a path (separators, `..`, absolute paths).
        let is_hex_revision =
            !revision.is_empty() && revision.chars().all(|c| c.is_ascii_hexdigit());
        if is_hex_revision {
            let pinned = snapshots.join(revision);
            if pinned.is_dir() {
                return Some(pinned);
            }
        }
    }

    // Fallback: any cached revision, chosen deterministically. Stray files
    // and unreadable entries are ignored.
    std::fs::read_dir(snapshots)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .min()
}
/// Reads the `model_type` string from `config.json` inside `snapshot`.
///
/// Returns `None` if the file cannot be read, is not valid JSON, has no
/// `model_type` key, or the value under that key is not a string.
fn read_model_type(snapshot: &Path) -> Option<String> {
    let raw = std::fs::read_to_string(snapshot.join("config.json")).ok()?;
    let config = serde_json::from_str::<serde_json::Value>(&raw).ok()?;
    let model_type = config.get("model_type")?.as_str()?;
    Some(model_type.to_owned())
}
/// Scans the HuggingFace Hub cache for models whose `model_type` is listed in
/// `SUPPORTED_MODEL_TYPES`.
///
/// Each returned tuple is `(model_id, model_type, snapshot_path)`; the list is
/// sorted by `model_id`. Cache entries are silently skipped when their name
/// is not valid UTF-8, does not start with `models--`, has no snapshot, or
/// has no readable `model_type`. An empty vector is returned when the cache
/// directory cannot be resolved or listed.
fn discover_cached_models() -> Vec<(String, String, PathBuf)> {
    let Some(hub_dir) = hf_cache_dir() else {
        return Vec::new();
    };
    let Ok(listing) = std::fs::read_dir(&hub_dir) else {
        return Vec::new();
    };

    let mut found: Vec<(String, String, PathBuf)> = listing
        .flatten()
        .filter_map(|entry| {
            let file_name = entry.file_name();
            let repo = file_name.to_str()?.strip_prefix("models--")?;
            // Only the first "--" is turned back into '/', so any later "--"
            // stays part of the model name.
            let model_id = repo.replacen("--", "/", 1);
            let snapshot = find_snapshot(&hub_dir, &model_id)?;
            let model_type = read_model_type(&snapshot)?;
            if SUPPORTED_MODEL_TYPES.contains(&model_type.as_str()) {
                Some((model_id, model_type, snapshot))
            } else {
                None
            }
        })
        .collect();

    found.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
    found
}
/// Loads the cached model `model_id`, prints its basic dimensions and load
/// time, then generates up to `max_new_tokens` tokens from `prompt` using the
/// `tokenizer.json` found in `snapshot`.
///
/// When the `memory` feature is enabled, a memory snapshot is taken before
/// and after loading and the difference is printed.
///
/// # Errors
///
/// Returns an error if the model cannot be loaded, if `tokenizer.json` is
/// absent from `snapshot` or cannot be loaded, or if encoding, generation, or
/// (with the `memory` feature) memory snapshotting fails.
fn run_model(
    model_id: &str,
    snapshot: &Path,
    prompt: &str,
    max_new_tokens: usize,
) -> candle_mi::Result<()> {
    // The model (and its device) does not exist yet, so the baseline is taken
    // on CUDA device 0 when available, falling back to the CPU.
    #[cfg(feature = "memory")]
    let mem_baseline = {
        let probe_device =
            candle_core::Device::cuda_if_available(0).unwrap_or(candle_core::Device::Cpu);
        MemorySnapshot::now(&probe_device)?
    };

    let load_started = Instant::now();
    let model = MIModel::from_pretrained(model_id)?;
    let load_time = load_started.elapsed();

    let (n_layers, n_heads, hidden) = (
        model.num_layers(),
        model.num_heads(),
        model.hidden_size(),
    );
    let weight_mb = estimate_weight_mb(n_layers, hidden);
    println!(
        " {} layers, {} heads, {} hidden, device: {:?}",
        n_layers,
        n_heads,
        hidden,
        model.device()
    );
    println!(" Estimated F32 weight size: {weight_mb:.0} MB | Load: {load_time:.2?}");

    #[cfg(feature = "memory")]
    {
        let mem_loaded = MemorySnapshot::now(model.device())?;
        MemoryReport::new(mem_baseline, mem_loaded).print_before_after("Memory");
    }

    // The tokenizer is read from the snapshot directory rather than taken
    // from the model object.
    let tokenizer_path = snapshot.join("tokenizer.json");
    let tokenizer = if tokenizer_path.exists() {
        MITokenizer::from_hf_path(tokenizer_path)?
    } else {
        return Err(candle_mi::MIError::Tokenizer(
            "tokenizer.json not found in snapshot".into(),
        ));
    };
    let prompt_tokens = tokenizer.encode(prompt)?;

    let generation_started = Instant::now();
    let outcome = generate(&model, &tokenizer, &prompt_tokens, max_new_tokens)?;
    let gen_time = generation_started.elapsed();

    println!(" --- Result (generation: {gen_time:.2?}) ---");
    println!(
        " {} prompt + {} generated = {} total tokens",
        outcome.prompt_tokens.len(),
        outcome.generated_tokens.len(),
        outcome.total_tokens
    );
    println!(" Full text: \"{}\"", outcome.full_text);
    println!();
    Ok(())
}
/// Generates exactly `max_new_tokens` tokens after `prompt_tokens`, echoing
/// the prompt and each new token to stdout as it is produced.
///
/// Every step runs `model.forward` over the entire token sequence so far
/// (with an empty `HookSpec`), takes the output row at the last position of
/// batch index 0, and picks the next token with `sample_token(.., 0.0)`.
/// NOTE(review): the `0.0` argument is presumably a temperature selecting
/// greedy decoding; confirm against `sample_token`'s documentation.
///
/// There is no early stop: the loop always runs `max_new_tokens` times.
///
/// # Errors
///
/// Propagates any error from decoding, tensor construction, the forward
/// pass, logit indexing, or sampling.
fn generate(
    model: &MIModel,
    tokenizer: &MITokenizer,
    prompt_tokens: &[u32],
    max_new_tokens: usize,
) -> candle_mi::Result<GenerationResult> {
    let mut tokens = prompt_tokens.to_vec();
    let hooks = HookSpec::new();

    let prompt_text = tokenizer.decode(prompt_tokens)?;
    print!(" {prompt_text}");
    // stdout is line-buffered, so without an explicit flush nothing would be
    // shown until the final newline. A flush failure only affects the live
    // display, so it is deliberately ignored.
    let _ = std::io::Write::flush(&mut std::io::stdout());

    for _ in 0..max_new_tokens {
        let input = candle_core::Tensor::new(tokens.as_slice(), model.device())?.unsqueeze(0)?;
        let cache = model.forward(&input, &hooks)?;
        let logits = cache.output();

        // `saturating_sub` avoids an arithmetic-underflow panic should
        // `tokens` be empty (empty prompt); the lookup below then decides
        // whether position 0 exists.
        let last_position = tokens.len().saturating_sub(1);
        let last_logits = logits.get(0)?.get(last_position)?;
        let next_token = sample_token(&last_logits, 0.0)?;

        let token_text = tokenizer.decode(&[next_token])?;
        print!("{token_text}");
        let _ = std::io::Write::flush(&mut std::io::stdout());

        tokens.push(next_token);
    }
    println!();

    let prompt_len = prompt_tokens.len();
    let generated_tokens = tokens.get(prompt_len..).unwrap_or_default().to_vec();
    let full_text = tokenizer.decode(&tokens)?;
    let generated_text = tokenizer.decode(&generated_tokens)?;

    Ok(GenerationResult {
        prompt: prompt_text,
        full_text,
        generated_text,
        prompt_tokens: prompt_tokens.to_vec(),
        generated_tokens,
        total_tokens: tokens.len(),
    })
}
/// Rough size, in megabytes (10^6 bytes), of a model's weights at 4 bytes per
/// parameter.
///
/// The parameter count is approximated as `n_layers * 12 * hidden^2`; nothing
/// else is added to that total.
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
fn estimate_weight_mb(n_layers: usize, hidden: usize) -> f64 {
    const BYTES_PER_PARAM: f64 = 4.0;
    const BYTES_PER_MB: f64 = 1_000_000.0;

    let width = hidden as f64;
    let depth = n_layers as f64;

    let per_layer = 12.0 * width * width;
    let parameter_count = depth * per_layer;
    parameter_count * BYTES_PER_PARAM / BYTES_PER_MB
}