use super::mtl_weight_buf::{MtlWeightBuf, MtlWeightBufError};
use super::variants::{LayerKind, VARIANT};
use super::weight_file::WeightFile;
/// Tensor offsets for one linear-attention layer.
///
/// `LayerWeightCache::build` fills every field from a tensor named
/// `model.layers.{layer_idx}.linear_attn.<name>`, storing the value returned by
/// `MtlWeightBuf::tensor_offset`. For each projection, the `_w`, `_s` and `_b`
/// fields come from the `.weight`, `.scales` and `.biases` tensors respectively.
#[derive(Clone, Debug)]
pub struct LinearAttnW {
    /// `in_proj_qkv.weight`
    pub qkv_w: u64,
    /// `in_proj_qkv.scales`
    pub qkv_s: u64,
    /// `in_proj_qkv.biases`
    pub qkv_b: u64,
    /// `in_proj_z.weight`
    pub z_w: u64,
    /// `in_proj_z.scales`
    pub z_s: u64,
    /// `in_proj_z.biases`
    pub z_b: u64,
    /// `in_proj_a.weight`
    pub alpha_w: u64,
    /// `in_proj_a.scales`
    pub alpha_s: u64,
    /// `in_proj_a.biases`
    pub alpha_b: u64,
    /// `in_proj_b.weight`
    pub beta_w: u64,
    /// `in_proj_b.scales`
    pub beta_s: u64,
    /// `in_proj_b.biases`
    pub beta_b: u64,
    /// `conv1d.weight`
    pub conv1d_w: u64,
    /// `A_log`
    pub a_log: u64,
    /// `dt_bias`
    pub dt_bias: u64,
    /// `norm.weight`
    pub gated_norm_w: u64,
    /// `out_proj.weight`
    pub o_proj_w: u64,
    /// `out_proj.scales`
    pub o_proj_s: u64,
    /// `out_proj.biases`
    pub o_proj_b: u64,
}
/// Tensor offsets for one full (self-)attention layer.
///
/// `LayerWeightCache::build` fills every field from a tensor named
/// `model.layers.{layer_idx}.self_attn.<name>`, storing the value returned by
/// `MtlWeightBuf::tensor_offset`. For each projection, the `_w`, `_s` and `_b`
/// fields come from the `.weight`, `.scales` and `.biases` tensors respectively.
#[derive(Clone, Debug)]
pub struct FullAttnW {
    /// `q_proj.weight`
    pub q_proj_w: u64,
    /// `q_proj.scales`
    pub q_proj_s: u64,
    /// `q_proj.biases`
    pub q_proj_b: u64,
    /// `k_proj.weight`
    pub k_proj_w: u64,
    /// `k_proj.scales`
    pub k_proj_s: u64,
    /// `k_proj.biases`
    pub k_proj_b: u64,
    /// `v_proj.weight`
    pub v_proj_w: u64,
    /// `v_proj.scales`
    pub v_proj_s: u64,
    /// `v_proj.biases`
    pub v_proj_b: u64,
    /// `q_norm.weight`
    pub q_norm_w: u64,
    /// `k_norm.weight`
    pub k_norm_w: u64,
    /// `o_proj.weight`
    pub o_proj_w: u64,
    /// `o_proj.scales`
    pub o_proj_s: u64,
    /// `o_proj.biases`
    pub o_proj_b: u64,
}
/// Attention tensor offsets for one layer.
///
/// `LayerWeightCache::build` picks the variant from `VARIANT.layer_kind(layer_idx)`:
/// `LayerKind::LinearAttn` yields `LinearAttn`, `LayerKind::FullAttn` yields `FullAttn`.
#[derive(Clone, Debug)]
pub enum LayerAttnW {
    /// Offsets of the `linear_attn.*` tensors.
    LinearAttn(LinearAttnW),
    /// Offsets of the `self_attn.*` tensors.
    FullAttn(FullAttnW),
}
impl LayerAttnW {
    /// Borrows the linear-attention offsets; `None` when this is a full-attention layer.
    pub fn linear(&self) -> Option<&LinearAttnW> {
        if let Self::LinearAttn(weights) = self {
            Some(weights)
        } else {
            None
        }
    }

    /// Borrows the full-attention offsets; `None` when this is a linear-attention layer.
    pub fn full(&self) -> Option<&FullAttnW> {
        if let Self::FullAttn(weights) = self {
            Some(weights)
        } else {
            None
        }
    }
}
/// Offsets of a layer's `mlp.gate` tensors.
///
/// `w`, `s` and `b` come from `model.layers.{layer_idx}.mlp.gate.weight`,
/// `.scales` and `.biases`. `bias` is an optional extra offset;
/// `LayerWeightCache::build` never looks it up and always sets it to `None`.
#[derive(Clone, Debug)]
pub struct GateW {
    /// `mlp.gate.weight`
    pub w: u64,
    /// `mlp.gate.scales`
    pub s: u64,
    /// `mlp.gate.biases`
    pub b: u64,
    /// Optional extra offset; left as `None` by `LayerWeightCache::build`.
    pub bias: Option<u64>,
}
/// Offsets of a layer's shared-expert tensors under `model.layers.{layer_idx}.mlp`.
///
/// As filled by `LayerWeightCache::build`:
/// - `seg_*` come from `shared_expert_gate`;
/// - `gate_*`, `up_*` and `down_*` come from `shared_expert.gate_proj`,
///   `shared_expert.up_proj` and `shared_expert.down_proj`.
///
/// In each group the `_w`, `_s` and `_b` fields map to the `.weight`, `.scales`
/// and `.biases` tensors respectively.
#[derive(Clone, Debug)]
pub struct SharedExpertW {
    /// `shared_expert_gate.weight`
    pub seg_w: u64,
    /// `shared_expert_gate.scales`
    pub seg_s: u64,
    /// `shared_expert_gate.biases`
    pub seg_b: u64,
    /// `shared_expert.gate_proj.weight`
    pub gate_w: u64,
    /// `shared_expert.gate_proj.scales`
    pub gate_s: u64,
    /// `shared_expert.gate_proj.biases`
    pub gate_b: u64,
    /// `shared_expert.up_proj.weight`
    pub up_w: u64,
    /// `shared_expert.up_proj.scales`
    pub up_s: u64,
    /// `shared_expert.up_proj.biases`
    pub up_b: u64,
    /// `shared_expert.down_proj.weight`
    pub down_w: u64,
    /// `shared_expert.down_proj.scales`
    pub down_s: u64,
    /// `shared_expert.down_proj.biases`
    pub down_b: u64,
}
/// Every tensor offset resolved for one `model.layers.{layer_idx}` layer.
///
/// Built by `LayerWeightCache::build`; the values are those returned by
/// `MtlWeightBuf::tensor_offset`.
#[derive(Clone, Debug)]
pub struct LayerWeightCache {
    /// `input_layernorm.weight`
    pub input_layernorm_w: u64,
    /// `post_attention_layernorm.weight`
    pub post_attention_layernorm_w: u64,
    /// Attention offsets; the variant follows `VARIANT.layer_kind(layer_idx)`.
    pub attn: LayerAttnW,
    /// `mlp.gate` offsets.
    pub gate: GateW,
    /// `mlp.shared_expert_gate` and `mlp.shared_expert.*` offsets.
    pub shared: SharedExpertW,
}
impl LayerWeightCache {
    /// Resolves every tensor offset needed by layer `layer_idx`.
    ///
    /// Each tensor is looked up by its full name, `model.layers.{layer_idx}.<suffix>`,
    /// through `MtlWeightBuf::tensor_offset`. The attention tensors depend on
    /// `VARIANT.layer_kind(layer_idx)`:
    /// - `LayerKind::LinearAttn` reads `linear_attn.*`;
    /// - `LayerKind::FullAttn` reads `self_attn.*`.
    ///
    /// The `bias` of the `mlp.gate` entry is never looked up and is always `None`.
    ///
    /// # Errors
    ///
    /// Lookups run in a fixed order and stop at the first failure:
    /// - any error returned by `tensor_offset` is propagated unchanged;
    /// - a tensor reported as absent (`Ok(None)`) produces
    ///   `MtlWeightBufError::MissingTensor` carrying that tensor's full name.
    pub fn build(
        layer_idx: usize,
        wf: &WeightFile,
        wf_buf: &MtlWeightBuf,
    ) -> Result<Self, MtlWeightBufError> {
        let root = format!("model.layers.{layer_idx}");

        // One tensor by full name; an absent tensor becomes `MissingTensor`.
        let need = |name: String| -> Result<u64, MtlWeightBufError> {
            let found = wf_buf.tensor_offset(wf, &name)?;
            match found {
                Some(offset) => Ok(offset),
                None => Err(MtlWeightBufError::MissingTensor { name }),
            }
        };
        // `<stem>.weight`, `<stem>.scales`, `<stem>.biases`, fetched in exactly that
        // order so the first missing member is the one reported.
        let triple = |stem: String| -> Result<(u64, u64, u64), MtlWeightBufError> {
            let weight = need(format!("{stem}.weight"))?;
            let scales = need(format!("{stem}.scales"))?;
            let biases = need(format!("{stem}.biases"))?;
            Ok((weight, scales, biases))
        };

        let input_layernorm_w = need(format!("{root}.input_layernorm.weight"))?;
        let post_attention_layernorm_w =
            need(format!("{root}.post_attention_layernorm.weight"))?;

        let attn = match VARIANT.layer_kind(layer_idx) {
            LayerKind::LinearAttn => {
                let base = format!("{root}.linear_attn");
                let (qkv_w, qkv_s, qkv_b) = triple(format!("{base}.in_proj_qkv"))?;
                let (z_w, z_s, z_b) = triple(format!("{base}.in_proj_z"))?;
                let (alpha_w, alpha_s, alpha_b) = triple(format!("{base}.in_proj_a"))?;
                let (beta_w, beta_s, beta_b) = triple(format!("{base}.in_proj_b"))?;
                let conv1d_w = need(format!("{base}.conv1d.weight"))?;
                let a_log = need(format!("{base}.A_log"))?;
                let dt_bias = need(format!("{base}.dt_bias"))?;
                let gated_norm_w = need(format!("{base}.norm.weight"))?;
                let (o_proj_w, o_proj_s, o_proj_b) = triple(format!("{base}.out_proj"))?;
                LayerAttnW::LinearAttn(LinearAttnW {
                    qkv_w,
                    qkv_s,
                    qkv_b,
                    z_w,
                    z_s,
                    z_b,
                    alpha_w,
                    alpha_s,
                    alpha_b,
                    beta_w,
                    beta_s,
                    beta_b,
                    conv1d_w,
                    a_log,
                    dt_bias,
                    gated_norm_w,
                    o_proj_w,
                    o_proj_s,
                    o_proj_b,
                })
            }
            LayerKind::FullAttn => {
                let base = format!("{root}.self_attn");
                let (q_proj_w, q_proj_s, q_proj_b) = triple(format!("{base}.q_proj"))?;
                let (k_proj_w, k_proj_s, k_proj_b) = triple(format!("{base}.k_proj"))?;
                let (v_proj_w, v_proj_s, v_proj_b) = triple(format!("{base}.v_proj"))?;
                let q_norm_w = need(format!("{base}.q_norm.weight"))?;
                let k_norm_w = need(format!("{base}.k_norm.weight"))?;
                let (o_proj_w, o_proj_s, o_proj_b) = triple(format!("{base}.o_proj"))?;
                LayerAttnW::FullAttn(FullAttnW {
                    q_proj_w,
                    q_proj_s,
                    q_proj_b,
                    k_proj_w,
                    k_proj_s,
                    k_proj_b,
                    v_proj_w,
                    v_proj_s,
                    v_proj_b,
                    q_norm_w,
                    k_norm_w,
                    o_proj_w,
                    o_proj_s,
                    o_proj_b,
                })
            }
        };

        let mlp = format!("{root}.mlp");
        let (w, s, b) = triple(format!("{mlp}.gate"))?;
        // The optional gate bias is not resolved here.
        let gate = GateW { w, s, b, bias: None };

        let (seg_w, seg_s, seg_b) = triple(format!("{mlp}.shared_expert_gate"))?;
        let (gate_w, gate_s, gate_b) = triple(format!("{mlp}.shared_expert.gate_proj"))?;
        let (up_w, up_s, up_b) = triple(format!("{mlp}.shared_expert.up_proj"))?;
        let (down_w, down_s, down_b) = triple(format!("{mlp}.shared_expert.down_proj"))?;
        let shared = SharedExpertW {
            seg_w,
            seg_s,
            seg_b,
            gate_w,
            gate_s,
            gate_b,
            up_w,
            up_s,
            up_b,
            down_w,
            down_s,
            down_b,
        };

        Ok(Self {
            input_layernorm_w,
            post_attention_layernorm_w,
            attn,
            gate,
            shared,
        })
    }

    /// Builds the cache for every layer `0..VARIANT.num_layers`, in index order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `build`; layers after the failing one
    /// are not visited.
    pub fn build_all(
        wf: &WeightFile,
        wf_buf: &MtlWeightBuf,
    ) -> Result<Vec<Self>, MtlWeightBufError> {
        let mut layers = Vec::with_capacity(VARIANT.num_layers);
        for layer_idx in 0..VARIANT.num_layers {
            layers.push(Self::build(layer_idx, wf, wf_buf)?);
        }
        Ok(layers)
    }
}