//! Loads Reve model weights from a safetensors checkpoint into Burn modules.

use std::collections::HashMap;
use burn::prelude::*;
use half::bf16;
use safetensors::SafeTensors;
use crate::model::reve::Reve;
use crate::config::ModelConfig;
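/// All tensors from a safetensors checkpoint, eagerly converted to `f32` and
/// keyed by name with any leading `model.` prefix stripped.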
pub struct WeightMap {
    pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}
impl WeightMap {
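    /// Reads a safetensors file and converts every tensor to `f32`,
    /// accepting BF16, F16, and F32 source dtypes.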
    pub fn from_file(path: &str) -> anyhow::Result<Self> {
        let bytes = std::fs::read(path)
            .map_err(|e| anyhow::anyhow!("failed to read weights file {path}: {e}"))?;
        let st = SafeTensors::deserialize(&bytes)?;
        let mut tensors = HashMap::with_capacity(st.len());
        for (raw_key, view) in st.tensors() {
            // Strip an optional "model." prefix so the rest of the loader can
            // use bare layer names.
            let key = raw_key
                .strip_prefix("model.")
                .unwrap_or(raw_key.as_str())
                .to_string();
            let shape: Vec<usize> = view.shape().to_vec();
            let data = view.data();
            // Convert the raw little-endian bytes to f32, whatever the stored dtype.
            let f32s: Vec<f32> = match view.dtype() {
                safetensors::Dtype::BF16 => data
                    .chunks_exact(2)
                    .map(|b| bf16::from_le_bytes([b[0], b[1]]).to_f32())
                    .collect(),
                safetensors::Dtype::F32 => data
                    .chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                    .collect(),
                safetensors::Dtype::F16 => data
                    .chunks_exact(2)
                    .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
                    .collect(),
                other => anyhow::bail!("unsupported dtype {other:?} for key {key}"),
            };
            tensors.insert(key, (f32s, shape));
        }
        Ok(Self { tensors })
    }
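    /// Removes `key` from the map, so each weight is consumed at most once,
    /// and builds a rank-`N` tensor from it on `device`.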
    pub fn take<B: Backend, const N: usize>(
        &mut self,
        key: &str,
        device: &B::Device,
    ) -> anyhow::Result<Tensor<B, N>> {
        let (data, shape) = self
            .tensors
            .remove(key)
            .ok_or_else(|| anyhow::anyhow!("weight key not found: {key}"))?;
        if shape.len() != N {
            anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len());
        }
        Ok(Tensor::<B, N>::from_data(TensorData::new(data, shape), device))
    }
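    /// Returns `true` if `key` is still present (i.e. has not been taken yet).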
    pub fn has(&self, key: &str) -> bool {
        self.tensors.contains_key(key)
    }
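    /// Prints every remaining key and its shape in sorted order; useful for
    /// debugging mismatches between the checkpoint and the model definition.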
    pub fn print_keys(&self) {
        let mut keys: Vec<&str> = self.tensors.keys().map(String::as_str).collect();
        keys.sort();
        for k in keys {
            let (_, s) = &self.tensors[k];
            println!(" {k:80} {s:?}");
        }
    }
}
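/// Overwrites a `Linear` weight, transposing from the checkpoint's PyTorch-style
/// `[out_features, in_features]` layout to Burn's `[in_features, out_features]`.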
fn set_linear_w<B: Backend>(linear: &mut burn::nn::Linear<B>, w: Tensor<B, 2>) {
    linear.weight = linear.weight.clone().map(|_| w.transpose());
}
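/// Same as `set_linear_w`, but also overwrites the bias when the layer has one.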
fn set_linear_wb<B: Backend>(linear: &mut burn::nn::Linear<B>, w: Tensor<B, 2>, b: Tensor<B, 1>) {
    linear.weight = linear.weight.clone().map(|_| w.transpose());
    if let Some(ref bias) = linear.bias {
        linear.bias = Some(bias.clone().map(|_| b));
    }
}
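/// Overwrites a `LayerNorm`'s gamma (scale) and, if present, beta (shift) parameters.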
fn set_layernorm<B: Backend>(norm: &mut burn::nn::LayerNorm<B>, w: Tensor<B, 1>, b: Tensor<B, 1>) {
    norm.gamma = norm.gamma.clone().map(|_| w);
    if let Some(ref beta) = norm.beta {
        norm.beta = Some(beta.clone().map(|_| b));
    }
}
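/// Overwrites an `RmsNorm`'s scale weight.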
fn set_rmsnorm<B: Backend>(norm: &mut crate::model::rms_norm::RmsNorm<B>, w: Tensor<B, 1>) {
    norm.weight = norm.weight.clone().map(|_| w);
}
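/// Builds a `Reve` model from `cfg` and fills it with weights read from
/// `weights_path` (a safetensors file).
///
/// Usage sketch (not compiled here; `MyBackend`, `load_my_config`, and the file
/// name are placeholders that depend on the surrounding crate):
///
/// ```ignore
/// let cfg: ModelConfig = load_my_config()?;   // however the crate builds its config
/// let device = Default::default();            // device for the chosen backend
/// let model = load_model::<MyBackend>(&cfg, "reve.safetensors", &device)?;
/// ```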
pub fn load_model<B: Backend>(
    cfg: &ModelConfig,
    weights_path: &str,
    device: &B::Device,
) -> anyhow::Result<Reve<B>> {
    let mut wm = WeightMap::from_file(weights_path)?;
    eprintln!("Loading {} weight tensors...", wm.tensors.len());
    load_model_from_wm(cfg, &mut wm, device)
}
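/// Builds a `Reve` model from `cfg` and fills it from an already-parsed `WeightMap`.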
pub fn load_model_from_wm<B: Backend>(
    cfg: &ModelConfig,
    wm: &mut WeightMap,
    device: &B::Device,
) -> anyhow::Result<Reve<B>> {
    let mut model = Reve::new(
        cfg.n_outputs,
        cfg.n_chans,
        cfg.n_times,
        cfg.embed_dim,
        cfg.depth,
        cfg.heads,
        cfg.head_dim,
        cfg.mlp_dim_ratio,
        cfg.use_geglu,
        cfg.freqs,
        cfg.patch_size,
        cfg.patch_overlap,
        cfg.attention_pooling,
        device,
    );
    load_reve_weights(wm, &mut model, cfg, device)?;
    Ok(model)
}
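/// Copies weights from `wm` into `model`, matching the checkpoint's key layout.
/// Every lookup is optional (`if let Ok(..)`), so missing keys are skipped
/// silently rather than treated as errors.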
fn load_reve_weights<B: Backend>(
    wm: &mut WeightMap,
    model: &mut Reve<B>,
    cfg: &ModelConfig,
    device: &B::Device,
) -> anyhow::Result<()> {
    // Patch embedding projection.
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 2>("to_patch_embedding.0.weight", device),
        wm.take::<B, 1>("to_patch_embedding.0.bias", device),
    ) {
        set_linear_wb(&mut model.patch_embed, w, b);
    }
    // mlp4d: a linear (checkpoint index 0) followed by a layer norm (index 2).
    if let Ok(w) = wm.take::<B, 2>("mlp4d.0.weight", device) {
        set_linear_w(&mut model.mlp4d_linear, w);
    }
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 1>("mlp4d.2.weight", device),
        wm.take::<B, 1>("mlp4d.2.bias", device),
    ) {
        set_layernorm(&mut model.mlp4d_ln, w, b);
    }
    // Standalone layer norm ("ln" in the checkpoint), loaded into pos_ln.
    if let (Ok(w), Ok(b)) = (
        wm.take::<B, 1>("ln.weight", device),
        wm.take::<B, 1>("ln.bias", device),
    ) {
        set_layernorm(&mut model.pos_ln, w, b);
    }
    // Transformer blocks: in the checkpoint's key layout, sub-module 0 holds the
    // attention weights and sub-module 1 the feed-forward net.
    for i in 0..cfg.depth {
        let block = &mut model.transformer.layers[i];
        if let Ok(w) = wm.take::<B, 1>(
            &format!("transformer.layers.{i}.0.norm.weight"), device,
        ) {
            set_rmsnorm(&mut block.attn.norm, w);
        }
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.0.to_qkv.weight"), device,
        ) {
            set_linear_w(&mut block.attn.to_qkv, w);
        }
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.0.to_out.weight"), device,
        ) {
            set_linear_w(&mut block.attn.to_out, w);
        }
        if let Ok(w) = wm.take::<B, 1>(
            &format!("transformer.layers.{i}.1.net.0.weight"), device,
        ) {
            set_rmsnorm(&mut block.ff.norm, w);
        }
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.1.net.1.weight"), device,
        ) {
            set_linear_w(&mut block.ff.linear1, w);
        }
        if let Ok(w) = wm.take::<B, 2>(
            &format!("transformer.layers.{i}.1.net.3.weight"), device,
        ) {
            set_linear_w(&mut block.ff.linear2, w);
        }
    }
    // Output head. With attention pooling the checkpoint has an extra
    // cls_query_token and the final norm/linear sit at indices 0/1;
    // without pooling they sit at indices 1/2.
    if cfg.attention_pooling {
        if let Ok(t) = wm.take::<B, 3>("cls_query_token", device) {
            if let Some(ref mut q) = model.cls_query_token {
                *q = q.clone().map(|_| t);
            }
        }
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 1>("final_layer.0.weight", device),
            wm.take::<B, 1>("final_layer.0.bias", device),
        ) {
            set_layernorm(&mut model.final_ln, w, b);
        }
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>("final_layer.1.weight", device),
            wm.take::<B, 1>("final_layer.1.bias", device),
        ) {
            set_linear_wb(&mut model.final_linear, w, b);
        }
    } else {
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 1>("final_layer.1.weight", device),
            wm.take::<B, 1>("final_layer.1.bias", device),
        ) {
            set_layernorm(&mut model.final_ln, w, b);
        }
        if let (Ok(w), Ok(b)) = (
            wm.take::<B, 2>("final_layer.2.weight", device),
            wm.take::<B, 1>("final_layer.2.bias", device),
        ) {
            set_linear_wb(&mut model.final_linear, w, b);
        }
    }
    Ok(())
}