use std::rc::Rc;
use burn::module::ParamId;
use burn::tensor::backend::Backend;
use burn::tensor::{DType, Shape, TensorData};
use burn_store::{
ModuleAdapter, ModuleSnapshot, ModuleStore, SafetensorsStore, TensorSnapshot,
TensorSnapshotError,
};
use super::gemma::{GemmaConfig, GemmaModel};
/// Number of parameters `load_gemma` requires to have been applied before it reports success.
///
/// Composition: 2 top-level tensors (`embed.weight`, `norm.gamma`) plus 13 mapped tensors for
/// each layer, matching the names produced by `hf_to_burn_path`.
/// NOTE(review): the factor 18 is presumably the layer count of the targeted model — confirm
/// against `GemmaConfig`.
const EXPECTED_PARAMS: usize = 2 + 18 * 13;
/// Translates a Hugging Face checkpoint tensor name into the dotted Burn module path.
///
/// An optional leading `model.` is ignored. Returns `None` for the LM head and for any
/// name that has no entry in the mapping, so callers can simply skip those tensors.
fn hf_to_burn_path(name: &str) -> Option<String> {
    // Per-layer tensors: (name after `layers.<idx>.` in the checkpoint, Burn-side name).
    const LAYER_TENSORS: &[(&str, &str)] = &[
        ("input_layernorm.weight", "input_layernorm.gamma"),
        ("post_attention_layernorm.weight", "post_attention_layernorm.gamma"),
        ("pre_feedforward_layernorm.weight", "pre_feedforward_layernorm.gamma"),
        ("post_feedforward_layernorm.weight", "post_feedforward_layernorm.gamma"),
        ("self_attn.q_proj.weight", "self_attn.q_proj.weight"),
        ("self_attn.k_proj.weight", "self_attn.k_proj.weight"),
        ("self_attn.v_proj.weight", "self_attn.v_proj.weight"),
        ("self_attn.o_proj.weight", "self_attn.o_proj.weight"),
        ("self_attn.q_norm.weight", "self_attn.q_norm.gamma"),
        ("self_attn.k_norm.weight", "self_attn.k_norm.gamma"),
        ("mlp.gate_proj.weight", "mlp.gate_proj.weight"),
        ("mlp.up_proj.weight", "mlp.up_proj.weight"),
        ("mlp.down_proj.weight", "mlp.down_proj.weight"),
    ];

    // The LM head is never mapped.
    if matches!(name, "lm_head.weight" | "model.lm_head.weight") {
        return None;
    }

    let stripped = name.strip_prefix("model.").unwrap_or(name);
    match stripped {
        "embed_tokens.weight" => return Some("embed.weight".to_string()),
        "norm.weight" => return Some("norm.gamma".to_string()),
        _ => {}
    }

    // Everything else must look like `layers.<idx>.<known suffix>`.
    let (layer, suffix) = stripped.strip_prefix("layers.")?.split_once('.')?;
    LAYER_TENSORS
        .iter()
        .find(|(hf, _)| *hf == suffix)
        .map(|(_, burn)| format!("layers.{layer}.{burn}"))
}
/// Builds the channel permutation that interleaves the two halves of a head.
///
/// Output slot `2 * j` reads source channel `j` and slot `2 * j + 1` reads `j + head_dim / 2`.
/// For an odd `head_dim` the final slot is left at `0`.
fn rope_perm(head_dim: usize) -> Vec<usize> {
    let half = head_dim / 2;
    let mut interleaved: Vec<usize> = (0..half).flat_map(|j| [j, j + half]).collect();
    // Pad to the requested length; only has an effect when `head_dim` is odd.
    interleaved.resize(head_dim, 0);
    interleaved
}
/// Reorders the rows of a row-major `[out, in_dim]` matrix independently within each head.
///
/// Rows are grouped into `out / head_dim` consecutive heads. Inside a head, destination row
/// `c` receives source row `perm[c]`. Rows beyond the last complete head are left as zeros.
fn permute_proj_rows(data: &[f32], out: usize, in_dim: usize, head_dim: usize, perm: &[usize]) -> Vec<f32> {
    let heads = out / head_dim;
    let mut permuted = vec![0f32; data.len()];
    for head in 0..heads {
        let first_row = head * head_dim;
        for (offset, &hf_offset) in perm[..head_dim].iter().enumerate() {
            let to = (first_row + offset) * in_dim;
            let from = (first_row + hf_offset) * in_dim;
            permuted[to..to + in_dim].copy_from_slice(&data[from..from + in_dim]);
        }
    }
    permuted
}
/// Gathers `data` through `perm`: element `i` of the result is `data[perm[i]]`.
///
/// The result has `perm.len()` elements; an out-of-range index panics.
fn permute_vec(data: &[f32], perm: &[usize]) -> Vec<f32> {
    let mut gathered = Vec::with_capacity(perm.len());
    for &index in perm {
        gathered.push(data[index]);
    }
    gathered
}
/// Value transformation that `GemmaValueAdapter::adapt` applies to a tensor, as chosen by
/// `transform_for`.
enum Transform {
/// Convert the data to f32; values and shape are otherwise untouched.
F32Only,
/// Convert to f32 and transpose a 2-D `[out, in]` matrix into `[in, out]`.
Transpose,
/// Convert to f32 and add 1.0 to every element.
RmsAddOne,
/// Permute rows within each head using `rope_perm`, then transpose like `Transpose`.
QkProj,
/// Add 1.0 to every element, then permute the vector using `rope_perm`.
QkNorm,
}
/// Picks the value transformation for the tensor located at `path`.
///
/// Query/key norm scales and query/key projection weights are recognised from the path
/// itself and take priority (norm first). Anything else is decided by `module_type`:
/// every `Struct:RmsNorm` tensor gets `RmsAddOne`, the `weight` of a `Struct:Linear` gets
/// `Transpose`, and the rest fall back to `F32Only`.
fn transform_for(path: &[String], module_type: Option<&str>) -> Transform {
    let leaf = path.last().map_or("", String::as_str);
    // True when any path segment equals one of the two given names.
    let mentions = |a: &str, b: &str| path.iter().any(|segment| segment == a || segment == b);

    if leaf == "gamma" && mentions("q_norm", "k_norm") {
        return Transform::QkNorm;
    }
    if leaf == "weight" && mentions("q_proj", "k_proj") {
        return Transform::QkProj;
    }

    match (module_type, leaf) {
        (Some("Struct:RmsNorm"), _) => Transform::RmsAddOne,
        (Some("Struct:Linear"), "weight") => Transform::Transpose,
        _ => Transform::F32Only,
    }
}
/// `ModuleAdapter` that rewrites checkpoint tensor values while they are applied to the model.
///
/// It is `Copy`, which lets `clone_box` box a plain copy of itself.
#[derive(Clone, Copy)]
struct GemmaValueAdapter {
/// Head width passed to `rope_perm` when permuting query/key projections and norms.
head_dim: usize,
}
impl ModuleAdapter for GemmaValueAdapter {
    /// Wraps `s` in a lazy snapshot that yields the tensor as f32 with the transformation
    /// chosen by `transform_for` applied.
    ///
    /// The returned snapshot keeps the source path, container stack and tensor id (a fresh
    /// `ParamId` is created when the source has none). For `Transpose` and `QkProj` the
    /// declared shape is swapped from `[out, in]` to `[in, out]`; all other transforms keep
    /// the source shape.
    ///
    /// Shapes that the selected transform cannot handle never panic here: the returned
    /// snapshot's loader reports `TensorSnapshotError::DataError` instead. This covers
    /// * a `Transpose`/`QkProj` tensor that is not rank 2,
    /// * a `QkProj` tensor whose row count is not a multiple of `head_dim`,
    /// * a `QkNorm` tensor whose element count differs from `head_dim`.
    fn adapt(&self, s: &TensorSnapshot) -> TensorSnapshot {
        type DataFn = Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>;

        let path = s.path_stack.clone().unwrap_or_default();
        let container = s.container_stack.clone().unwrap_or_default();
        let id = s.tensor_id.unwrap_or_else(ParamId::new);
        let src = s.clone_data_fn();
        let head_dim = self.head_dim;
        let transform = transform_for(&path, s.module_type().as_deref());
        let src_dims: Vec<usize> = s.shape.iter().copied().collect();

        // Validate the shape up front so that a malformed checkpoint surfaces as a load
        // error rather than an out-of-bounds panic or silently zeroed/truncated weights.
        let problem: Option<String> = match &transform {
            Transform::Transpose | Transform::QkProj if src_dims.len() != 2 => Some(format!(
                "expected a rank-2 weight, got shape {src_dims:?}"
            )),
            // Rank 2 is guaranteed here because the previous arm did not match.
            Transform::QkProj if head_dim == 0 || src_dims[0] % head_dim != 0 => Some(format!(
                "output dimension {} is not a multiple of head_dim {head_dim}",
                src_dims[0]
            )),
            Transform::QkNorm if src_dims.iter().product::<usize>() != head_dim => Some(format!(
                "expected {head_dim} elements, got shape {src_dims:?}"
            )),
            _ => None,
        };
        if let Some(problem) = problem {
            let message = format!("cannot adapt tensor '{}': {problem}", path.join("."));
            let data_fn: DataFn = Rc::new(move || -> Result<TensorData, TensorSnapshotError> {
                Err(TensorSnapshotError::DataError(message.clone()))
            });
            return TensorSnapshot::from_closure(
                data_fn,
                DType::F32,
                s.shape.clone(),
                path,
                container,
                id,
            );
        }

        match transform {
            Transform::F32Only => {
                let shape = s.shape.clone();
                let data_fn: DataFn = Rc::new(move || -> Result<TensorData, TensorSnapshotError> {
                    Ok(src()?.convert::<f32>())
                });
                TensorSnapshot::from_closure(data_fn, DType::F32, shape, path, container, id)
            }
            Transform::RmsAddOne => {
                let shape = s.shape.clone();
                let data_fn: DataFn = Rc::new(move || -> Result<TensorData, TensorSnapshotError> {
                    let mut v = to_f32_vec(&src()?)?;
                    for x in &mut v {
                        *x += 1.0;
                    }
                    Ok(TensorData::new(v, src_dims.clone()))
                });
                TensorSnapshot::from_closure(data_fn, DType::F32, shape, path, container, id)
            }
            Transform::Transpose => {
                let (out, in_dim) = (src_dims[0], src_dims[1]);
                let new_shape: Shape = [in_dim, out].into();
                let data_fn: DataFn = Rc::new(move || -> Result<TensorData, TensorSnapshotError> {
                    let v = to_f32_vec(&src()?)?;
                    Ok(TensorData::new(transpose(&v, out, in_dim), [in_dim, out]))
                });
                TensorSnapshot::from_closure(data_fn, DType::F32, new_shape, path, container, id)
            }
            Transform::QkProj => {
                let (out, in_dim) = (src_dims[0], src_dims[1]);
                let new_shape: Shape = [in_dim, out].into();
                // The permutation depends only on `head_dim`, so build it once.
                let perm = rope_perm(head_dim);
                let data_fn: DataFn = Rc::new(move || -> Result<TensorData, TensorSnapshotError> {
                    let v = to_f32_vec(&src()?)?;
                    let permuted = permute_proj_rows(&v, out, in_dim, head_dim, &perm);
                    Ok(TensorData::new(transpose(&permuted, out, in_dim), [in_dim, out]))
                });
                TensorSnapshot::from_closure(data_fn, DType::F32, new_shape, path, container, id)
            }
            Transform::QkNorm => {
                let shape = s.shape.clone();
                let perm = rope_perm(head_dim);
                let data_fn: DataFn = Rc::new(move || -> Result<TensorData, TensorSnapshotError> {
                    let mut v = to_f32_vec(&src()?)?;
                    for x in &mut v {
                        *x += 1.0;
                    }
                    let permuted = permute_vec(&v, &perm);
                    let len = permuted.len();
                    Ok(TensorData::new(permuted, [len]))
                });
                TensorSnapshot::from_closure(data_fn, DType::F32, shape, path, container, id)
            }
        }
    }

    /// Returns a boxed copy of this adapter.
    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(*self)
    }
}
/// Copies `d`, converts the copy to f32 and returns its elements as a flat `Vec<f32>`.
///
/// A failure while extracting the elements is reported as
/// `TensorSnapshotError::DataError` carrying the debug rendering of the underlying error.
fn to_f32_vec(d: &TensorData) -> Result<Vec<f32>, TensorSnapshotError> {
    let as_f32 = d.clone().convert::<f32>();
    match as_f32.to_vec::<f32>() {
        Ok(values) => Ok(values),
        Err(e) => Err(TensorSnapshotError::DataError(format!("{e:?}"))),
    }
}
/// Transposes a row-major `r x c` matrix stored in `v` into a row-major `c x r` matrix.
///
/// The result has the same length as `v`; panics if `v` holds fewer than `r * c` elements.
fn transpose(v: &[f32], r: usize, c: usize) -> Vec<f32> {
    let mut flipped = vec![0f32; v.len()];
    for col in 0..c {
        for row in 0..r {
            flipped[col * r + row] = v[row * c + col];
        }
    }
    flipped
}
/// Loads a Hugging Face Gemma safetensors checkpoint into `model` and returns the model.
///
/// Each checkpoint tensor is renamed through `hf_to_burn_path` (tensors without a mapping
/// are skipped) and then applied with a `GemmaValueAdapter`, which converts and reorders
/// the values.
///
/// NOTE(review): `head_dim` is always taken from `GemmaConfig::gemma_3_270m()`, regardless
/// of how `model` was configured, and `_device` is not used — confirm both are intended.
///
/// # Errors
///
/// Returns a message when the bytes cannot be parsed, when applying reports errors or
/// missing tensors, or when the number of applied parameters differs from
/// `EXPECTED_PARAMS`.
pub fn load_gemma<B: Backend>(
    model: GemmaModel<B>,
    safetensors_bytes: &[u8],
    _device: &B::Device,
) -> Result<GemmaModel<B>, String> {
    let mut store = SafetensorsStore::from_bytes(Some(safetensors_bytes.to_vec()));
    let snapshots = match store.get_all_snapshots() {
        Ok(all) => all,
        Err(e) => return Err(format!("safetensors parse failed: {e}")),
    };
    let head_dim = GemmaConfig::gemma_3_270m().head_dim;

    // Re-home every mappable tensor under its Burn path; data stays lazily loaded.
    let remapped: Vec<TensorSnapshot> = snapshots
        .values()
        .filter_map(|snap| {
            let burn_path = hf_to_burn_path(&snap.full_path())?;
            let segments: Vec<String> = burn_path.split('.').map(String::from).collect();
            Some(TensorSnapshot::from_closure(
                snap.clone_data_fn(),
                snap.dtype,
                snap.shape.clone(),
                segments,
                snap.container_stack.clone().unwrap_or_default(),
                snap.tensor_id.unwrap_or_else(ParamId::new),
            ))
        })
        .collect();

    let mut loaded = model;
    let outcome = loaded.apply(
        remapped,
        None,
        Some(Box::new(GemmaValueAdapter { head_dim })),
        false,
    );

    if !outcome.errors.is_empty() {
        Err(format!("load_gemma: apply errors: {:?}", outcome.errors))
    } else if !outcome.missing.is_empty() {
        let names: Vec<&String> = outcome.missing.iter().map(|(p, _)| p).collect();
        Err(format!("load_gemma: missing tensors: {names:?}"))
    } else if outcome.applied.len() != EXPECTED_PARAMS {
        // A silent shortfall would leave some parameter at its random initial value.
        Err(format!(
            "load_gemma: applied {} params, expected {EXPECTED_PARAMS} (a checkpoint tensor was \
             likely mis-named or unmapped, leaving a parameter at random init)",
            outcome.applied.len()
        ))
    } else {
        Ok(loaded)
    }
}