use crate::gguf::ArchConstraints;
// NOTE(review): variant order is not alphabetical (`FfnGate` precedes `KProj`,
// `FfnUp` follows `QProj`). It is deliberately left as-is: reordering would
// change discriminant values, which code outside this view might observe via
// `as` casts.
/// Identifies a named weight that an architecture may be required to provide.
///
/// [`required_roles`] returns the subset demanded by a given `ArchConstraints`;
/// [`WeightRole::field_name`] gives each variant's display name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WeightRole {
/// Field name `attn_k_bias`; listed by [`required_roles`] only when `has_bias` is set.
AttnKBias,
/// Field name `attn_k_norm`; listed by [`required_roles`] only when `has_qk_norm` is set.
AttnKNorm,
/// Field name `attn_norm`; always listed by [`required_roles`].
AttnNorm,
/// Field name `attn_q_bias`; listed by [`required_roles`] only when `has_bias` is set.
AttnQBias,
/// Field name `attn_q_norm`; listed by [`required_roles`] only when `has_qk_norm` is set.
AttnQNorm,
/// Field name `attn_v_bias`; listed by [`required_roles`] only when `has_bias` is set.
AttnVBias,
/// Field name `ffn_down`; always listed by [`required_roles`].
FfnDown,
/// Field name `ffn_norm`; always listed by [`required_roles`].
FfnNorm,
/// Field name `ffn_gate`; always listed by [`required_roles`].
FfnGate,
/// Field name `attn_k` (alternate name `k_proj`); always listed by [`required_roles`].
KProj,
/// Field name `attn_output` (alternate name `o_proj`); always listed by [`required_roles`].
OProj,
/// Field name `attn_q` (alternate name `q_proj`); always listed by [`required_roles`].
QProj,
/// Field name `ffn_up`; always listed by [`required_roles`].
FfnUp,
/// Field name `attn_v` (alternate name `v_proj`); always listed by [`required_roles`].
VProj,
}
impl WeightRole {
#[must_use]
pub const fn field_name(&self) -> &'static str {
match self {
Self::AttnKBias => "attn_k_bias",
Self::AttnKNorm => "attn_k_norm",
Self::AttnNorm => "attn_norm",
Self::AttnQBias => "attn_q_bias",
Self::AttnQNorm => "attn_q_norm",
Self::AttnVBias => "attn_v_bias",
Self::FfnDown => "ffn_down",
Self::FfnNorm => "ffn_norm",
Self::FfnGate => "ffn_gate",
Self::KProj => "attn_k (k_proj)",
Self::OProj => "attn_output (o_proj)",
Self::QProj => "attn_q (q_proj)",
Self::FfnUp => "ffn_up",
Self::VProj => "attn_v (v_proj)",
}
}
}
const ROLES_NO_QK_NORM_NO_BIAS: &[WeightRole] = &[
WeightRole::AttnNorm,
WeightRole::FfnNorm,
WeightRole::QProj,
WeightRole::KProj,
WeightRole::VProj,
WeightRole::OProj,
WeightRole::FfnGate,
WeightRole::FfnUp,
WeightRole::FfnDown,
];
const _: () = assert!(ROLES_NO_QK_NORM_NO_BIAS.len() == 9, "YAML declares 9 roles");
const ROLES_NO_QK_NORM_BIAS: &[WeightRole] = &[
WeightRole::AttnNorm,
WeightRole::FfnNorm,
WeightRole::QProj,
WeightRole::KProj,
WeightRole::VProj,
WeightRole::OProj,
WeightRole::FfnGate,
WeightRole::FfnUp,
WeightRole::FfnDown,
WeightRole::AttnQBias,
WeightRole::AttnKBias,
WeightRole::AttnVBias,
];
const _: () = assert!(ROLES_NO_QK_NORM_BIAS.len() == 12, "YAML declares 12 roles");
const ROLES_QK_NORM_NO_BIAS: &[WeightRole] = &[
WeightRole::AttnNorm,
WeightRole::FfnNorm,
WeightRole::QProj,
WeightRole::KProj,
WeightRole::VProj,
WeightRole::OProj,
WeightRole::FfnGate,
WeightRole::FfnUp,
WeightRole::FfnDown,
WeightRole::AttnQNorm,
WeightRole::AttnKNorm,
];
const _: () = assert!(ROLES_QK_NORM_NO_BIAS.len() == 11, "YAML declares 11 roles");
const ROLES_QK_NORM_AND_BIAS: &[WeightRole] = &[
WeightRole::AttnNorm,
WeightRole::FfnNorm,
WeightRole::QProj,
WeightRole::KProj,
WeightRole::VProj,
WeightRole::OProj,
WeightRole::FfnGate,
WeightRole::FfnUp,
WeightRole::FfnDown,
WeightRole::AttnQNorm,
WeightRole::AttnKNorm,
WeightRole::AttnQBias,
WeightRole::AttnKBias,
WeightRole::AttnVBias,
];
const _: () = assert!(ROLES_QK_NORM_AND_BIAS.len() == 14, "YAML declares 14 roles");
/// Returns the weight roles an architecture must provide.
///
/// The table is chosen by the two flags on `arch`. Every table contains the
/// nine base roles. `has_qk_norm` adds the Q/K norms, and `has_bias` adds the
/// Q/K/V attention biases.
#[must_use]
pub fn required_roles(arch: &ArchConstraints) -> &'static [WeightRole] {
    if arch.has_qk_norm {
        if arch.has_bias {
            ROLES_QK_NORM_AND_BIAS
        } else {
            ROLES_QK_NORM_NO_BIAS
        }
    } else if arch.has_bias {
        ROLES_NO_QK_NORM_BIAS
    } else {
        ROLES_NO_QK_NORM_NO_BIAS
    }
}