use std::env;
use std::fs;
use std::process::ExitCode;
use std::sync::Arc;
use rullama::gguf::GgufReader;
use rullama::reference::embed::EmbedModel;
use rullama::tokenizer::SpmTokenizer;
/// Reference values for the leading eight embedding components of the input
/// `"hello world"`.
///
/// Used as the right-hand side of the cosine check that is printed when the
/// input text is exactly `"hello world"`.
// NOTE(review): the model / run these numbers were captured from is not
// recorded in this file — confirm the source before relying on them.
const REF_HELLO_WORLD: [f32; 8] = [
    -0.21395,
    0.02636,
    0.06661,
    -0.01639,
    0.00745,
    0.01082,
    -0.01431,
    -0.00245,
];
/// CLI entry point: `embed_parity <gguf-path> [text]`.
///
/// Embeds `text` (default `"hello world"`) with the CPU reference model loaded
/// from `<gguf-path>` and prints the dimension, L2 norm and leading components.
///
/// Environment flags (presence is what matters, the value is ignored):
/// - `RULLAMA_EMBED_GPU`: also run the GPU path and print cosine / max-abs
///   difference against the CPU result.
/// - `RULLAMA_EMBED_DUMP`: write the full vector to stderr, prefixed `DUMP `.
///
/// Exit codes: `2` when the GGUF path argument is missing, `1` on any
/// load / tokenize / embed failure, success otherwise.
fn main() -> ExitCode {
    let mut args = env::args().skip(1);
    let path = match args.next() {
        Some(p) => p,
        None => {
            eprintln!("usage: embed_parity <gguf-path> [text]");
            return ExitCode::from(2);
        }
    };
    let text = args.next().unwrap_or_else(|| "hello world".to_string());

    match run(&path, &text) {
        Ok(()) => ExitCode::SUCCESS,
        Err(msg) => {
            eprintln!("{msg}");
            ExitCode::from(1)
        }
    }
}

/// Loads the model at `path`, embeds `text`, and prints all diagnostics.
///
/// Returns the ready-to-print error message for the first step that fails.
fn run(path: &str, text: &str) -> Result<(), String> {
    // Number of leading components shown in the `first8` line.
    const PREVIEW_LEN: usize = 8;
    // NOTE(review): BOS/EOS ids are hard-coded here rather than read from the
    // GGUF metadata — confirm they match the tokenizer of the model in use.
    const BOS: u32 = 2;
    const EOS: u32 = 1;

    let bytes = fs::read(path).map_err(|e| format!("read {path}: {e}"))?;
    let reader = GgufReader::new(bytes)
        .map(Arc::new)
        .map_err(|e| format!("parse gguf: {e}"))?;
    let tok = SpmTokenizer::from_gguf(&reader).map_err(|e| format!("tokenizer: {e}"))?;
    let model = EmbedModel::new(reader).map_err(|e| format!("model: {e}"))?;

    eprintln!(
        "[cfg] layers={} d_model={} heads={}/{} head_dim={} ffn={} ctx={} pool={:?} causal={}",
        model.cfg.n_layers,
        model.cfg.d_model,
        model.cfg.n_heads,
        model.cfg.n_kv_heads,
        model.cfg.head_dim,
        model.cfg.ffn,
        model.cfg.context_length,
        model.cfg.pooling,
        model.cfg.causal,
    );

    // Wrap the encoded text as BOS + tokens + EOS before embedding.
    let mut ids = vec![BOS];
    ids.extend(tok.encode(text));
    ids.push(EOS);
    eprintln!("[tok] {} ids: {:?}", ids.len(), &ids[..ids.len().min(16)]);

    let v = model
        .embed_ids(&ids, 0)
        .map_err(|e| format!("embed: {e}"))?;

    // `var_os` so that a flag set to a non-Unicode value still counts as set.
    if env::var_os("RULLAMA_EMBED_GPU").is_some() {
        let gpu = pollster::block_on(async {
            let ctx = rullama::backend::WgpuCtx::new().await?;
            let pipes = rullama::backend::Pipelines::new(&ctx.device);
            let wcache = rullama::backend::WeightCache::new(
                model.weights.reader_arc(),
                ctx.device.clone(),
                ctx.queue.clone(),
                Arc::clone(&ctx.bind_cache),
            );
            model.embed_ids_gpu(&ctx, &pipes, &wcache, &ids, 0).await
        });
        let g = gpu.map_err(|e| format!("GPU embed: {e}"))?;
        report_gpu_parity(&v, &g);
    }

    println!("dim: {}", v.len());
    println!("L2: {:.6}", l2_norm(&v));

    // Clamp the preview so an embedding with fewer than eight components does
    // not panic on the slice.
    let head = &v[..v.len().min(PREVIEW_LEN)];
    println!(
        "first8: [{}]",
        head.iter()
            .map(|x| format!("{x:.5}"))
            .collect::<Vec<_>>()
            .join(", ")
    );

    if env::var_os("RULLAMA_EMBED_DUMP").is_some() {
        let line = v
            .iter()
            .map(|x| format!("{x:.6}"))
            .collect::<Vec<_>>()
            .join(" ");
        eprintln!("DUMP {line}");
    }

    // The stored reference only applies to the default input text.
    if text == "hello world" {
        if head.len() == REF_HELLO_WORLD.len() {
            let cos = cosine_similarity(head, &REF_HELLO_WORLD);
            println!(
                "ref8: [{}]",
                REF_HELLO_WORLD.map(|x| format!("{x:.5}")).join(", ")
            );
            println!("cos(first8 vs ref): {cos:.4}");
        } else {
            eprintln!(
                "[ref] embedding has only {} components; skipping reference comparison",
                v.len()
            );
        }
    }

    Ok(())
}

/// Prints cosine similarity and maximum absolute difference between the CPU
/// and GPU embeddings, warning on stderr first if their lengths differ.
fn report_gpu_parity(cpu: &[f32], gpu: &[f32]) {
    if cpu.len() != gpu.len() {
        // `zip` below stops at the shorter vector, so make the mismatch visible.
        eprintln!(
            "GPU-vs-CPU: length mismatch (cpu={} gpu={})",
            cpu.len(),
            gpu.len()
        );
    }
    let cos = cosine_similarity(cpu, gpu);
    let maxabs = cpu
        .iter()
        .zip(gpu.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, f32::max);
    println!("GPU-vs-CPU: cosine={cos:.6} max_abs_diff={maxabs:.6}");
}

/// Euclidean (L2) norm of `xs`.
fn l2_norm(xs: &[f32]) -> f32 {
    xs.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Cosine similarity of `a` and `b`; the `1e-9` term keeps the division
/// finite when either vector has zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    dot / (l2_norm(a) * l2_norm(b) + 1e-9)
}