use std::collections::HashMap;
use crate::tensor::Tensor;
use crate::model::tribe::TribeV2;
use crate::model::encoder::{LayerBlock, XTransformerEncoder};
use crate::model::projector::Projector;
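/// An in-memory checkpoint: tensor name mapped to raw `f32` data plus shape.
/// Loaders remove entries as they consume them, so whatever remains after
/// loading is a weight the model never asked for (see `remaining_keys`).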
pub struct WeightMap {
pub tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}
impl WeightMap {
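    /// Reads a `.safetensors` file, upcasting every tensor to `f32` and
    /// stripping a leading `model.` prefix from each key. BF16, F16, and
    /// F32 payloads are supported; any other dtype is an error.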
pub fn from_safetensors(path: &str) -> anyhow::Result<Self> {
let bytes = std::fs::read(path)?;
let st = safetensors::SafeTensors::deserialize(&bytes)?;
let mut tensors = HashMap::with_capacity(st.len());
for (key, view) in st.tensors() {
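            // Normalize keys: some exporters prefix every tensor with
            // "model."; strip it so lookups use bare module paths.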
let key = key.strip_prefix("model.").unwrap_or(&key).to_string();
let shape: Vec<usize> = view.shape().to_vec();
let data = view.data();
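            // The map stores plain f32, so BF16/F16 payloads are widened
            // once here instead of on every use.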
let f32s: Vec<f32> = match view.dtype() {
safetensors::Dtype::BF16 => {
data.chunks_exact(2)
.map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
.collect()
}
safetensors::Dtype::F16 => {
data.chunks_exact(2)
.map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
.collect()
}
safetensors::Dtype::F32 => {
data.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect()
}
other => anyhow::bail!("unsupported dtype {:?}", other),
};
tensors.insert(key, (f32s, shape));
}
Ok(Self { tensors })
}
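    /// Removes a tensor from the map, erroring if the key is absent.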
pub fn take(&mut self, key: &str) -> anyhow::Result<Tensor> {
let (data, shape) = self.tensors.remove(key)
.ok_or_else(|| anyhow::anyhow!("weight key not found: {key}"))?;
Ok(Tensor::from_vec(data, shape))
}
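    /// Like [`take`](Self::take), but returns `None` instead of an error
    /// when the key is absent.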
pub fn try_take(&mut self, key: &str) -> Option<Tensor> {
self.tensors.remove(key).map(|(data, shape)| Tensor::from_vec(data, shape))
}
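    /// The keys still in the map, sorted; useful for reporting checkpoint
    /// tensors that no loader consumed.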
pub fn remaining_keys(&self) -> Vec<String> {
let mut keys: Vec<_> = self.tensors.keys().cloned().collect();
keys.sort();
keys
}
}
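/// Convenience wrapper around [`WeightMap::from_safetensors`].
///
/// A minimal sketch of the intended flow; the path and `build_model` are
/// illustrative assumptions, not part of this module:
///
/// ```ignore
/// let mut wm = load_checkpoint("checkpoints/tribe_v2.safetensors")?;
/// let mut model = build_model(&config)?; // hypothetical constructor for TribeV2
/// load_weights(&mut wm, &mut model)?;
/// for key in wm.remaining_keys() {
///     eprintln!("unused checkpoint tensor: {key}");
/// }
/// ```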
pub fn load_checkpoint(path: &str) -> anyhow::Result<WeightMap> {
WeightMap::from_safetensors(path)
}
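/// Copies checkpoint tensors into an already-constructed [`TribeV2`],
/// consuming them from `wm`. Keys mirror the PyTorch module tree of the
/// training code (`projectors.<name>.*`, `combiner.*`, `encoder.layers.{i}.*`).
/// PyTorch stores `nn.Linear` weights as (out_features, in_features), hence
/// the `transpose_last2()` on every linear weight. Optional weights whose
/// keys are missing are skipped silently.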
pub fn load_weights(wm: &mut WeightMap, model: &mut TribeV2) -> anyhow::Result<()> {
for np in &mut model.projectors {
let name = &np.name;
let projector = &mut np.projector;
match projector {
Projector::SubjectLayers(ref mut sl) => {
if let Ok(w) = wm.take(&format!("projectors.{name}.weights")) {
sl.weights = w;
}
if let Ok(b) = wm.take(&format!("projectors.{name}.bias")) {
sl.bias = Some(b);
}
}
Projector::Mlp(ref mut mlp) => {
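                // Checkpoints store a single-layer MLP either under a
                // Sequential index (`projectors.<name>.0.weight`) or bare
                // (`projectors.<name>.weight`); try the indexed form first.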
if mlp.layers.len() == 1 {
if let Ok(w) = wm.take(&format!("projectors.{name}.0.weight")) {
mlp.layers[0].weight = w.transpose_last2();
} else if let Ok(w) = wm.take(&format!("projectors.{name}.weight")) {
mlp.layers[0].weight = w.transpose_last2();
}
if let Ok(b) = wm.take(&format!("projectors.{name}.0.bias")) {
mlp.layers[0].bias = b;
} else if let Ok(b) = wm.take(&format!("projectors.{name}.bias")) {
mlp.layers[0].bias = b;
}
} else {
let n_layers = mlp.layers.len();
for (li, layer) in mlp.layers.iter_mut().enumerate() {
                        // Each layer occupies four consecutive slots in the
                        // exporting nn.Sequential (the slots in between
                        // presumably hold parameterless modules such as
                        // activation or dropout), so layer `li`'s block
                        // always sits at index `li * 4`.
                        let pytorch_idx = li * 4;
if let Ok(w) = wm.take(&format!("projectors.{name}.{pytorch_idx}.0.weight")) {
layer.weight = w.transpose_last2();
} else if let Ok(w) = wm.take(&format!("projectors.{name}.{pytorch_idx}.weight")) {
layer.weight = w.transpose_last2();
}
if let Ok(b) = wm.take(&format!("projectors.{name}.{pytorch_idx}.0.bias")) {
layer.bias = b;
} else if let Ok(b) = wm.take(&format!("projectors.{name}.{pytorch_idx}.bias")) {
layer.bias = b;
}
if let Some(ref mut ln_w) = layer.ln_weight {
if let Ok(w) = wm.take(&format!("projectors.{name}.{pytorch_idx}.1.weight")) {
*ln_w = w;
}
}
if let Some(ref mut ln_b) = layer.ln_bias {
if let Ok(b) = wm.take(&format!("projectors.{name}.{pytorch_idx}.1.bias")) {
*ln_b = b;
}
}
}
}
}
}
}
if let Some(ref mut combiner) = model.combiner {
if let Some(mlp) = combiner.as_mlp_mut() {
if let Ok(w) = wm.take("combiner.0.weight") {
mlp.layers[0].weight = w.transpose_last2();
} else if let Ok(w) = wm.take("combiner.weight") {
mlp.layers[0].weight = w.transpose_last2();
}
if let Ok(b) = wm.take("combiner.0.bias") {
mlp.layers[0].bias = b;
} else if let Ok(b) = wm.take("combiner.bias") {
mlp.layers[0].bias = b;
}
}
}
if let Some(ref mut tpe) = model.time_pos_embed {
if let Ok(t) = wm.take("time_pos_embed") {
*tpe = t;
}
}
if let Some(ref mut se) = model.subject_embed {
if let Ok(t) = wm.take("subject_embed.weight") {
*se = t;
}
}
if let Some(ref mut encoder) = model.encoder {
load_encoder_weights(wm, encoder)?;
}
if let Some(ref mut lr) = model.low_rank_head {
if let Ok(w) = wm.take("low_rank_head.weight") {
*lr = w.transpose_last2();
}
}
if let Ok(w) = wm.take("predictor.weights") {
model.predictor.weights = w;
}
if let Ok(b) = wm.take("predictor.bias") {
model.predictor.bias = Some(b);
}
if let Some(ref mut ts) = model.temporal_smoothing {
if let Ok(k) = wm.take("temporal_smoothing.weight") {
ts.kernel = k;
}
}
Ok(())
}
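/// Loads x-transformers-style encoder weights. Each layer `i` is serialized
/// as a triple: `layers.{i}.0.0` (pre-norm), `layers.{i}.1` (the attention
/// or feed-forward block), and `layers.{i}.2` (the residual wrapper).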
fn load_encoder_weights(wm: &mut WeightMap, encoder: &mut XTransformerEncoder) -> anyhow::Result<()> {
for (i, layer) in encoder.layers.iter_mut().enumerate() {
let prefix = format!("encoder.layers.{i}");
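        // The pre-norm gain is stored as a single-element tensor; unwrap
        // it to a scalar.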
if let Ok(g) = wm.take(&format!("{prefix}.0.0.g")) {
layer.pre_norm.g = g.data[0];
}
match &mut layer.block {
LayerBlock::Attn(attn) => {
if let Ok(w) = wm.take(&format!("{prefix}.1.to_q.weight")) {
attn.w_q = w.transpose_last2();
}
if let Ok(w) = wm.take(&format!("{prefix}.1.to_k.weight")) {
attn.w_k = w.transpose_last2();
}
if let Ok(w) = wm.take(&format!("{prefix}.1.to_v.weight")) {
attn.w_v = w.transpose_last2();
}
if let Ok(w) = wm.take(&format!("{prefix}.1.to_out.weight")) {
attn.w_out = w.transpose_last2();
}
}
LayerBlock::FF(ff) => {
if let Ok(w) = wm.take(&format!("{prefix}.1.ff.0.0.weight")) {
ff.w1 = w.transpose_last2(); }
if let Ok(b) = wm.take(&format!("{prefix}.1.ff.0.0.bias")) {
ff.b1 = b;
}
if let Ok(w) = wm.take(&format!("{prefix}.1.ff.2.weight")) {
ff.w2 = w.transpose_last2(); }
if let Ok(b) = wm.take(&format!("{prefix}.1.ff.2.bias")) {
ff.b2 = b;
}
}
}
if let Some(ref mut rs) = layer.residual.residual_scale {
if let Ok(s) = wm.take(&format!("{prefix}.2.residual_scale")) {
*rs = s;
}
}
}
if let Ok(g) = wm.take("encoder.final_norm.g") {
encoder.final_norm.g = g.data[0];
}
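    // The rotary `inv_freq` buffer is deterministic and presumably rebuilt
    // at runtime; consume it here only so it does not show up in
    // `remaining_keys()`.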
wm.try_take("encoder.rotary_pos_emb.inv_freq");
Ok(())
}