use std::collections::HashMap;
use burn::prelude::*;
use burn::module::{Param, ParamId};
use half::bf16;
use safetensors::SafeTensors;
use crate::config::ModelConfig;
use crate::model::privacy_filter::PrivacyFilterModel;
pub struct WeightMap {
tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}
impl WeightMap {
pub fn from_file(path: &str) -> anyhow::Result<Self> {
let bytes = std::fs::read(path)?;
let st = SafeTensors::deserialize(&bytes)?;
let mut tensors = HashMap::with_capacity(st.len());
for (raw_key, view) in st.tensors() {
let key = raw_key.to_string();
let shape: Vec<usize> = view.shape().to_vec();
let data = view.data();
let f32s: Vec<f32> = match view.dtype() {
safetensors::Dtype::BF16 => data
.chunks_exact(2)
.map(|b| bf16::from_le_bytes([b[0], b[1]]).to_f32())
.collect(),
safetensors::Dtype::F32 => data
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect(),
other => anyhow::bail!("unsupported dtype {:?} for key {key}", other),
};
tensors.insert(key, (f32s, shape));
}
Ok(Self { tensors })
}
pub fn take<B: Backend, const N: usize>(
&mut self,
key: &str,
device: &B::Device,
) -> anyhow::Result<Tensor<B, N>> {
let (data, shape) = self.tensors.remove(key)
.ok_or_else(|| anyhow::anyhow!("weight key not found: {key}"))?;
if shape.len() != N {
anyhow::bail!("rank mismatch for {key}: expected {N}, got {}", shape.len());
}
Ok(Tensor::<B, N>::from_data(
TensorData::new(data, shape),
device,
))
}
pub fn print_keys(&self) {
let mut keys: Vec<&str> = self.tensors.keys().map(String::as_str).collect();
keys.sort();
for k in keys {
let (_, s) = &self.tensors[k];
println!(" {k:70} {s:?}");
}
}
}
/// Installs a checkpoint weight/bias pair into a burn `Linear` layer.
///
/// `w` is transposed before it is stored. Presumably the checkpoint uses the
/// PyTorch `[out, in]` layout while burn's `Linear` keeps `[in, out]`; confirm
/// against the exporter. Both parameters receive fresh ids.
fn set_linear_wb<B: Backend>(
    linear: &mut burn::nn::Linear<B>,
    w: Tensor<B, 2>,
    b: Tensor<B, 1>,
) {
    let weight = Param::initialized(ParamId::new(), w.transpose());
    let bias = Param::initialized(ParamId::new(), b);
    linear.weight = weight;
    linear.bias = Some(bias);
}
/// Installs a checkpoint weight (no bias) into a burn `Linear` layer.
///
/// Like `set_linear_wb`, `w` is transposed before being stored and gets a
/// fresh parameter id. The layer's existing bias is left untouched.
// Not called by `load_model` today; retained for bias-free linear layers,
// hence the lint allowance.
#[allow(dead_code)]
fn set_linear_w<B: Backend>(
    linear: &mut burn::nn::Linear<B>,
    w: Tensor<B, 2>,
) {
    let id = ParamId::new();
    linear.weight = Param::initialized(id, w.transpose());
}
pub fn load_model<B: Backend>(
config: &ModelConfig,
weights_path: &str,
device: &B::Device,
) -> anyhow::Result<PrivacyFilterModel<B>> {
eprintln!("Loading weights from {weights_path} ...");
let mut wm = WeightMap::from_file(weights_path)?;
eprintln!("Loaded {} tensors from safetensors file.", wm.tensors.len());
let mut model = PrivacyFilterModel::new(config, device);
let emb_w: Tensor<B, 2> = wm.take("model.embed_tokens.weight", device)?;
model.embed_tokens = Param::initialized(ParamId::new(), emb_w);
for (i, layer) in model.layers.iter_mut().enumerate() {
let p = format!("model.layers.{i}");
let ln_w: Tensor<B, 1> = wm.take(&format!("{p}.input_layernorm.weight"), device)?;
layer.input_layernorm.weight = Param::initialized(ParamId::new(), ln_w);
set_linear_wb(
&mut layer.self_attn.q_proj,
wm.take(&format!("{p}.self_attn.q_proj.weight"), device)?,
wm.take(&format!("{p}.self_attn.q_proj.bias"), device)?,
);
set_linear_wb(
&mut layer.self_attn.k_proj,
wm.take(&format!("{p}.self_attn.k_proj.weight"), device)?,
wm.take(&format!("{p}.self_attn.k_proj.bias"), device)?,
);
set_linear_wb(
&mut layer.self_attn.v_proj,
wm.take(&format!("{p}.self_attn.v_proj.weight"), device)?,
wm.take(&format!("{p}.self_attn.v_proj.bias"), device)?,
);
set_linear_wb(
&mut layer.self_attn.o_proj,
wm.take(&format!("{p}.self_attn.o_proj.weight"), device)?,
wm.take(&format!("{p}.self_attn.o_proj.bias"), device)?,
);
let sinks: Tensor<B, 1> = wm.take(&format!("{p}.self_attn.sinks"), device)?;
layer.self_attn.sinks = Param::initialized(ParamId::new(), sinks);
let pln_w: Tensor<B, 1> = wm.take(&format!("{p}.post_attention_layernorm.weight"), device)?;
layer.post_attention_layernorm.weight = Param::initialized(ParamId::new(), pln_w);
let router_w: Tensor<B, 2> = wm.take(&format!("{p}.mlp.router.weight"), device)?;
layer.mlp.router_weight = Param::initialized(ParamId::new(), router_w.transpose());
let router_b: Tensor<B, 1> = wm.take(&format!("{p}.mlp.router.bias"), device)?;
layer.mlp.router_bias = Param::initialized(ParamId::new(), router_b);
let gu: Tensor<B, 3> = wm.take(&format!("{p}.mlp.experts.gate_up_proj"), device)?;
layer.mlp.gate_up_proj = Param::initialized(ParamId::new(), gu);
let gu_b: Tensor<B, 2> = wm.take(&format!("{p}.mlp.experts.gate_up_proj_bias"), device)?;
layer.mlp.gate_up_proj_bias = Param::initialized(ParamId::new(), gu_b);
let dp: Tensor<B, 3> = wm.take(&format!("{p}.mlp.experts.down_proj"), device)?;
layer.mlp.down_proj = Param::initialized(ParamId::new(), dp);
let dp_b: Tensor<B, 2> = wm.take(&format!("{p}.mlp.experts.down_proj_bias"), device)?;
layer.mlp.down_proj_bias = Param::initialized(ParamId::new(), dp_b);
layer.mlp.cache_weights();
}
let norm_w: Tensor<B, 1> = wm.take("model.norm.weight", device)?;
model.norm.weight = Param::initialized(ParamId::new(), norm_w);
let score_w: Tensor<B, 2> = wm.take("score.weight", device)?;
model.score_weight = Param::initialized(ParamId::new(), score_w.transpose());
let score_b: Tensor<B, 1> = wm.take("score.bias", device)?;
model.score_bias = Param::initialized(ParamId::new(), score_b);
if !wm.tensors.is_empty() {
eprintln!("Warning: {} unused weight keys remain:", wm.tensors.len());
wm.print_keys();
}
eprintln!("Model loaded successfully.");
Ok(model)
}