use candle_core::{DType, Device, IndexOp};
use candle_mi::sae::SparseAutoencoder;
use candle_mi::{
GenericTransformer, HookPoint, HookSpec, MIBackend, MITokenizer, TransformerConfig,
};
use std::path::PathBuf;
/// Program entry point: runs the example and turns any error into a
/// message on stderr plus exit status 1.
fn main() {
    match run() {
        Ok(()) => {}
        Err(e) => {
            eprintln!("Error: {e}");
            std::process::exit(1);
        }
    }
}
/// Loads `google/gemma-2-2b` from the local Hugging Face cache, loads a
/// pretrained sparse autoencoder, runs a single prompt through the model while
/// capturing `ResidPost(0)`, and prints a short report about the SAE features
/// found at the last token position.
///
/// # Errors
///
/// Propagates any `candle_mi` error raised by the `?`-terminated calls below.
///
/// # Panics
///
/// Panics if the model snapshot is missing from the cache, or if its
/// `config.json` cannot be read or parsed.
fn run() -> candle_mi::Result<()> {
    let device = Device::cuda_if_available(0)?;
    println!("Device: {device:?}");

    // ---- Base model: config, weights and tokenizer from the cached snapshot ----
    let snapshot_dir =
        find_snapshot("google/gemma-2-2b").expect("google/gemma-2-2b not found in HF cache");
    let raw_config = std::fs::read_to_string(snapshot_dir.join("config.json"))
        .expect("cannot read config.json");
    let parsed_config: serde_json::Value =
        serde_json::from_str(&raw_config).expect("cannot parse config.json");
    let model_config = TransformerConfig::from_hf_config(&parsed_config)?;

    let dtype = DType::F32;
    let weight_files = safetensors_paths(&snapshot_dir);
    // SAFETY: the loader memory-maps the weight files; presumably the contract is
    // that they are not modified while mapped (TODO confirm against the candle
    // docs). This function only reads them.
    let var_builder =
        unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&weight_files, dtype, &device)? };
    let model = GenericTransformer::load(model_config, &device, dtype, var_builder)?;
    let tokenizer = MITokenizer::from_hf_path(snapshot_dir.join("tokenizer.json"))?;

    let layer_count = model.num_layers();
    let hidden_width = model.hidden_size();
    println!("Model: {} layers, {} hidden", layer_count, hidden_width);

    // ---- Sparse autoencoder ----
    let sae = SparseAutoencoder::from_pretrained_npz(
        "google/gemma-scope-2b-pt-res",
        "layer_0/width_16k/average_l0_105/params.npz",
        0,
        &device,
    )?;
    println!(
        "SAE: d_in={}, d_sae={}, arch={:?}, hook={}",
        sae.d_in(),
        sae.d_sae(),
        sae.config().architecture,
        sae.config().hook_name,
    );

    // ---- Tokenize the prompt and add a leading axis via `unsqueeze(0)` ----
    let prompt = "The capital of France is";
    let ids = tokenizer.encode(prompt)?;
    let seq_len = ids.len();
    println!("\nPrompt: \"{prompt}\" ({seq_len} tokens)");
    let input_ids = candle_core::Tensor::new(&ids[..], &device)?.unsqueeze(0)?;

    // ---- Forward pass, capturing only the hook point the report needs ----
    let mut capture_spec = HookSpec::new();
    capture_spec.capture(HookPoint::ResidPost(0));
    let outputs = model.forward(&input_ids, &capture_spec)?;
    let resid = outputs.require(&HookPoint::ResidPost(0))?;

    // Select index 0 on the first axis and the final position on the second.
    let last_position = resid.i((0, seq_len - 1))?;
    let active = sae.encode_sparse(&last_position)?;

    println!("\nActive features: {}", active.len());
    println!("Top-10 features:");
    // Prints the first ten entries in the order `features` stores them;
    // presumably that is sorted by activation — verify in `encode_sparse`.
    for (position, (feature, val)) in active.features.iter().take(10).enumerate() {
        println!(" #{}: feature {} = {val:.4}", position + 1, feature.index);
    }

    // Reconstruction error is computed over the whole captured tensor, not
    // just the last position.
    let mse = sae.reconstruction_error(resid)?;
    println!("\nReconstruction MSE: {mse:.6}");

    // L2 norm of the decoder vector belonging to the first listed feature.
    if let Some((leading, _)) = active.features.first() {
        let direction = sae.decoder_vector(leading.index)?;
        let squared_sum = direction.sqr()?.sum_all()?.to_scalar::<f32>()?;
        let norm = squared_sum.sqrt();
        println!("Top feature {} decoder norm: {norm:.4}", leading.index);
    }

    Ok(())
}
/// Resolves the Hugging Face hub cache directory from the environment.
///
/// Resolution order:
/// 1. `$HF_HOME/hub` — returned as-is whenever `HF_HOME` is set, without
///    checking that the directory exists.
/// 2. `<home>/.cache/huggingface/hub` for `USERPROFILE`, then `HOME` — the
///    first candidate that is an existing directory wins.
///
/// Returns `None` when no candidate qualifies. Variables whose value is not
/// valid Unicode are treated as unset.
fn hf_cache_dir() -> Option<PathBuf> {
    std::env::var("HF_HOME")
        .ok()
        .map(|root| PathBuf::from(root).join("hub"))
        .or_else(|| {
            // The chain is lazy, so `HOME` is only consulted when the
            // `USERPROFILE` candidate is missing or not a directory.
            ["USERPROFILE", "HOME"]
                .iter()
                .filter_map(|name| std::env::var(name).ok())
                .map(|home| {
                    PathBuf::from(home)
                        .join(".cache")
                        .join("huggingface")
                        .join("hub")
                })
                .find(|candidate| candidate.is_dir())
        })
}
/// Locates a cached snapshot directory for `model_id` (e.g. `"org/name"`)
/// inside the Hugging Face hub cache.
///
/// The repository folder is `models--<org>--<name>` under the directory
/// returned by `hf_cache_dir`. Selection is deterministic:
///
/// 1. If `refs/main` names a revision whose directory exists under
///    `snapshots/`, that directory is returned.
/// 2. Otherwise the lexicographically smallest *directory* under `snapshots/`
///    is returned. Stray files and unreadable entries are skipped.
///
/// Returns `None` when the cache directory, the repository, or any snapshot
/// directory cannot be found.
fn find_snapshot(model_id: &str) -> Option<PathBuf> {
    let cache_dir = hf_cache_dir()?;
    let repo_dir = cache_dir.join(format!("models--{}", model_id.replace('/', "--")));
    let snapshots = repo_dir.join("snapshots");

    // `refs/main` holds the revision id the `main` ref points at; prefer it so
    // that a cache holding several revisions resolves to the current one.
    if let Ok(revision) = std::fs::read_to_string(repo_dir.join("refs").join("main")) {
        let revision = revision.trim();
        if !revision.is_empty() {
            let pinned = snapshots.join(revision);
            if pinned.is_dir() {
                return Some(pinned);
            }
        }
    }

    // `read_dir` yields entries in a platform-dependent order, so pick by
    // sorted path instead of taking whichever entry happens to come first.
    std::fs::read_dir(&snapshots)
        .ok()?
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.path())
        .filter(|path| path.is_dir())
        .min()
}
/// Lists the safetensors weight files of a model snapshot directory.
///
/// If `model.safetensors` exists it is returned as the only entry. Otherwise
/// `model.safetensors.index.json` is read and every distinct shard file named
/// in its `weight_map` object is returned, sorted by file name and joined onto
/// `snapshot`.
///
/// # Panics
///
/// Panics if `model.safetensors` is absent and the index file cannot be read,
/// if the index is not valid JSON, if it has no `weight_map` object, or if a
/// `weight_map` value is not a string. Each panic message names the index file
/// or snapshot directory involved.
fn safetensors_paths(snapshot: &std::path::Path) -> Vec<PathBuf> {
    let single = snapshot.join("model.safetensors");
    if single.exists() {
        return vec![single];
    }

    let index_path = snapshot.join("model.safetensors.index.json");
    let index_str = std::fs::read_to_string(&index_path).unwrap_or_else(|err| {
        panic!(
            "no model.safetensors or index.json in {}: {err}",
            snapshot.display()
        )
    });
    let index: serde_json::Value = serde_json::from_str(&index_str)
        .unwrap_or_else(|err| panic!("cannot parse {}: {err}", index_path.display()));
    let weight_map = index["weight_map"]
        .as_object()
        .unwrap_or_else(|| panic!("{} has no `weight_map` object", index_path.display()));

    // Many tensors map to the same shard file; borrow the names so that
    // de-duplication does not allocate one `String` per tensor.
    let mut shard_names: Vec<&str> = weight_map
        .iter()
        .map(|(tensor, file)| {
            file.as_str().unwrap_or_else(|| {
                panic!(
                    "`weight_map` entry `{tensor}` in {} is not a string",
                    index_path.display()
                )
            })
        })
        .collect();
    shard_names.sort_unstable();
    shard_names.dedup();

    shard_names
        .into_iter()
        .map(|name| snapshot.join(name))
        .collect()
}