use flodl::nn::Module;
use flodl::{DType, Variable};
use flodl_hf::models::bert::build_extended_attention_mask;
use flodl_hf::models::distilbert::DistilBertModel;
use flodl_hf::tokenizer::HfTokenizer;
/// Embeds two example sentences with a pretrained DistilBERT model and
/// prints, for each input, the hidden dimension, the L2 norm of its CLS
/// vector, and the first five CLS components.
fn main() -> flodl::Result<()> {
    let repo = "distilbert/distilbert-base-uncased";

    // Pull tokenizer and weights from the same hub repo, then switch the
    // model to inference mode.
    let tokenizer = HfTokenizer::from_pretrained(repo)?;
    let model = DistilBertModel::from_pretrained(repo)?;
    model.eval();

    let texts = &["hello world", "flodl brings libtorch to Rust"];
    let encoded = tokenizer.encode(texts)?;

    // Convert the 0/1 attention mask to f32 and expand it into the additive
    // form the attention layers consume; wrapped as a non-trainable Variable.
    let mask_f32 = encoded.attention_mask.data().to_dtype(DType::Float32)?;
    let extended = build_extended_attention_mask(&mask_f32)?;
    let mask = Variable::new(extended, false);

    let hidden = model.forward_multi(&[encoded.input_ids, mask])?;

    // Hidden states come back as (batch, seq, dim); flatten to a host-side
    // f32 buffer for inspection.
    let dims = hidden.shape();
    let seq = dims[1] as usize;
    let dim = dims[2] as usize;
    let flat = hidden.data().to_f32_vec()?;

    for (idx, text) in texts.iter().enumerate() {
        // CLS is the first token of each sequence in the flattened buffer.
        let start = idx * seq * dim;
        let cls = &flat[start..start + dim];
        let l2 = cls.iter().map(|x| x * x).sum::<f32>().sqrt();
        println!(
            "{text:?} | dim={dim} L2={l2:.3} head=[{:.3}, {:.3}, {:.3}, {:.3}, {:.3}]",
            cls[0], cls[1], cls[2], cls[3], cls[4],
        );
    }

    Ok(())
}