use candle_mi::{HookSpec, MIModel, MITokenizer, SUPPORTED_MODEL_TYPES};
use std::path::{Path, PathBuf};
/// Entry point: finds supported models already present in the local HF Hub cache and
/// runs a fixed prompt through each one, printing the top predictions.
///
/// When nothing usable is cached, prints instructions for downloading a model and exits.
/// A model that fails to load or run is reported as skipped; the loop continues.
fn main() {
    const PROMPT: &str = "The capital of France is";

    let models = discover_cached_models();

    if models.is_empty() {
        // Nothing to demo: tell the user how to populate the cache, then bail out.
        let hint = [
            "No cached transformer models found in the HuggingFace Hub cache.",
            "Download one first, e.g.:",
            " python -c \"from huggingface_hub import snapshot_download; \
             snapshot_download('meta-llama/Llama-3.2-1B')\"",
            "",
            "Or with Rust:",
            " cargo run --example fast_download -- meta-llama/Llama-3.2-1B",
        ];
        for line in hint {
            println!("{line}");
        }
        return;
    }

    println!("Found {} supported transformer(s) in HF cache:\n", models.len());

    for (model_id, model_type, snapshot) in &models {
        println!("--- {model_id} (model_type: {model_type}) ---");
        // Per-model failures are non-fatal: report and move on to the next model.
        let outcome = run_model(model_id, snapshot, PROMPT);
        if let Err(reason) = outcome {
            println!(" Skipped: {reason}\n");
        }
    }
}
/// Locate the Hugging Face Hub cache directory.
///
/// If `HF_HOME` is set, `$HF_HOME/hub` is returned as-is, without checking that it exists.
/// Otherwise `<home>/.cache/huggingface/hub` is tried, first under `USERPROFILE`, then under
/// `HOME`. The first candidate that is an existing directory wins.
///
/// Returns `None` when no candidate applies.
fn hf_cache_dir() -> Option<PathBuf> {
    if let Ok(hf_home) = std::env::var("HF_HOME") {
        return Some(Path::new(&hf_home).join("hub"));
    }

    // Lazily probe each home-like variable in priority order.
    ["USERPROFILE", "HOME"]
        .iter()
        .filter_map(|&var| std::env::var(var).ok())
        .map(|home| {
            PathBuf::from(home)
                .join(".cache")
                .join("huggingface")
                .join("hub")
        })
        .find(|candidate| candidate.is_dir())
}
/// Locate the snapshot directory for `model_id` inside the HF Hub cache at `cache_dir`.
///
/// A repo `org/name` lives in `models--org--name/`. Each downloaded revision is a
/// directory under `snapshots/`, and `refs/<branch>` files record commit hashes.
/// The revision named by `refs/main` is preferred. If that ref is missing, empty, or points
/// at a snapshot that is not present, the lexicographically first snapshot *directory* is
/// returned. The fallback keeps the result deterministic across platforms.
///
/// Returns `None` if the repo has no `snapshots/` directory or it contains no directories.
fn find_snapshot(cache_dir: &Path, model_id: &str) -> Option<PathBuf> {
    let repo_dir = cache_dir.join(format!("models--{}", model_id.replace('/', "--")));
    let snapshots = repo_dir.join("snapshots");

    // `refs/main` holds the commit hash of the default branch that was downloaded.
    if let Ok(commit) = std::fs::read_to_string(repo_dir.join("refs").join("main")) {
        let commit = commit.trim();
        // An empty ref would make `join` return `snapshots` itself, so reject it explicitly.
        if !commit.is_empty() {
            let pinned = snapshots.join(commit);
            if pinned.is_dir() {
                return Some(pinned);
            }
        }
    }

    // `read_dir` order is platform/filesystem dependent. Pick the smallest name for
    // determinism, and skip stray files that are not snapshot directories.
    std::fs::read_dir(&snapshots)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .min()
}
/// Read the `model_type` string from `<snapshot>/config.json`.
///
/// Returns `None` if the file cannot be read, is not valid JSON, or has no string-valued
/// top-level `model_type` key.
fn read_model_type(snapshot: &Path) -> Option<String> {
    let raw = std::fs::read_to_string(snapshot.join("config.json")).ok()?;
    let config = serde_json::from_str::<serde_json::Value>(&raw).ok()?;
    config
        .get("model_type")
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
}
/// Scan the HF Hub cache for model repos whose `config.json` declares a supported
/// `model_type`.
///
/// Returns `(model_id, model_type, snapshot_dir)` triples sorted by `model_id`. An empty
/// vector is returned when the cache directory cannot be found or read. Entries that are
/// not `models--…` directories, have no snapshot, or lack a readable `model_type` are
/// silently skipped.
fn discover_cached_models() -> Vec<(String, String, PathBuf)> {
    let Some(hub) = hf_cache_dir() else {
        return Vec::new();
    };
    let Ok(dir_iter) = std::fs::read_dir(&hub) else {
        return Vec::new();
    };

    let mut found: Vec<(String, String, PathBuf)> = dir_iter
        .flatten()
        .filter_map(|entry| {
            let file_name = entry.file_name();
            // Cache dirs look like `models--org--name`; only the first `--` after the
            // prefix separates org from name.
            let repo = file_name.to_str()?.strip_prefix("models--")?;
            let model_id = repo.replacen("--", "/", 1);
            let snapshot = find_snapshot(&hub, &model_id)?;
            let model_type = read_model_type(&snapshot)?;
            let supported = SUPPORTED_MODEL_TYPES.contains(&model_type.as_str());
            supported.then_some((model_id, model_type, snapshot))
        })
        .collect();

    found.sort_by(|left, right| left.0.cmp(&right.0));
    found
}
/// Load `model_id`, run a single forward pass over `prompt`, and print the top-5
/// next-token predictions for the last prompt position.
///
/// The tokenizer is read from `<snapshot>/tokenizer.json`. The model itself is loaded via
/// `MIModel::from_pretrained(model_id)`.
///
/// # Errors
///
/// Returns `MIError::Tokenizer` in two cases:
/// - `tokenizer.json` is missing from `snapshot` (checked before the model is loaded);
/// - `prompt` encodes to zero tokens.
///
/// Otherwise propagates any error from model loading, tokenization, tensor construction,
/// the forward pass, or printing the predictions.
fn run_model(model_id: &str, snapshot: &Path, prompt: &str) -> candle_mi::Result<()> {
    // Check for the tokenizer first so a model we cannot tokenize for is skipped
    // without paying for loading its weights.
    let tokenizer_path = snapshot.join("tokenizer.json");
    if !tokenizer_path.exists() {
        return Err(candle_mi::MIError::Tokenizer(
            "tokenizer.json not found in snapshot".into(),
        ));
    }

    let model = MIModel::from_pretrained(model_id)?;
    println!(
        " {} layers, {} hidden, device: {:?}",
        model.num_layers(),
        model.hidden_size(),
        model.device()
    );

    let tokenizer = MITokenizer::from_hf_path(tokenizer_path)?;
    let token_ids = tokenizer.encode(prompt)?;
    // The last-position lookup below needs at least one token; `len() - 1` would underflow.
    let Some(last_pos) = token_ids.len().checked_sub(1) else {
        return Err(candle_mi::MIError::Tokenizer(
            "prompt encoded to zero tokens".into(),
        ));
    };

    // `unsqueeze(0)` adds a leading batch dimension of size 1.
    let input = candle_core::Tensor::new(&token_ids[..], model.device())?.unsqueeze(0)?;
    println!(" Prompt: \"{prompt}\" ({} tokens)", token_ids.len());

    let hooks = HookSpec::new();
    let cache = model.forward(&input, &hooks)?;
    let logits = cache.output();
    // NOTE(review): assumes output is (batch, seq, vocab): take batch 0, last position.
    let last_logits = logits.get(0)?.get(last_pos)?;
    print_top_k(&last_logits, &tokenizer, 5)?;
    println!();
    Ok(())
}
/// Print the `k` highest-scoring entries of `logits`, best first, with their decoded text.
///
/// `logits` is cast to `f32` and flattened, and each element's flat index is decoded as a
/// token id. NaN scores are excluded from the ranking. Equal scores are ordered by lower
/// index first, so the output is deterministic. Fewer than `k` lines are printed if there
/// are fewer than `k` non-NaN scores.
///
/// # Errors
///
/// - Propagates tensor conversion and tokenizer decode errors.
/// - Returns `MIError::Tokenizer` if an index does not fit in `u32`.
fn print_top_k(
    logits: &candle_core::Tensor,
    tokenizer: &MITokenizer,
    k: usize,
) -> candle_mi::Result<()> {
    let scores: Vec<f32> = logits
        .to_dtype(candle_core::DType::F32)?
        .flatten_all()?
        .to_vec1()?;

    // NaN has no meaningful rank; dropping it lets the comparator below be a true total
    // order. std sorts may panic on a comparator that is not a total order, which the old
    // `partial_cmp(..).unwrap_or(Equal)` was in the presence of NaN.
    let mut ranked: Vec<(usize, f32)> = scores
        .into_iter()
        .enumerate()
        .filter(|(_, score)| !score.is_nan())
        .collect();
    ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

    println!(" Top-{k} predictions:");
    for (rank, &(idx, score)) in ranked.iter().take(k).enumerate() {
        let token_id = u32::try_from(idx).map_err(|_| {
            candle_mi::MIError::Tokenizer("token index does not fit in u32".into())
        })?;
        let token_text = tokenizer.decode(&[token_id])?;
        println!(
            " #{}: {:>8.3} \"{}\"",
            rank + 1,
            score,
            token_text.trim()
        );
    }
    Ok(())
}