use trueno_gpu::driver::GpuBuffer;
/// Raw per-layer weight table: a `(ptr, len)` handle for every tensor of one transformer
/// layer, plus the storage format of each of the seven projection matrices.
///
/// All fields are plain integers/enums, so this struct does not own the memory it refers to.
/// NOTE(review): `*_ptr` values are presumably device addresses (this file imports
/// `trueno_gpu`), and the unit of `*_len` (bytes vs. elements) is not established here —
/// confirm against the loader.
///
/// A `(0, 0)` pair marks an absent tensor: `ValidatedLayerWeights::validate` rejects exactly
/// that pair for every role the architecture requires.
#[derive(Debug, Clone)]
pub struct IndexedLayerWeights {
/// Q projection handle (`WeightRole::QProj`).
pub attn_q_ptr: u64,
/// Q projection length.
pub attn_q_len: usize,
/// Q projection storage format.
pub attn_q_qtype: WeightQuantType,
/// K projection handle (`WeightRole::KProj`).
pub attn_k_ptr: u64,
/// K projection length.
pub attn_k_len: usize,
/// K projection storage format.
pub attn_k_qtype: WeightQuantType,
/// V projection handle (`WeightRole::VProj`).
pub attn_v_ptr: u64,
/// V projection length.
pub attn_v_len: usize,
/// V projection storage format.
pub attn_v_qtype: WeightQuantType,
/// Attention output projection handle (`WeightRole::OProj`).
pub attn_output_ptr: u64,
/// Attention output projection length.
pub attn_output_len: usize,
/// Attention output projection storage format.
pub attn_output_qtype: WeightQuantType,
/// FFN gate projection handle (`WeightRole::FfnGate`).
pub ffn_gate_ptr: u64,
/// FFN gate projection length.
pub ffn_gate_len: usize,
/// FFN gate projection storage format.
pub ffn_gate_qtype: WeightQuantType,
/// FFN up projection handle (`WeightRole::FfnUp`).
pub ffn_up_ptr: u64,
/// FFN up projection length.
pub ffn_up_len: usize,
/// FFN up projection storage format.
pub ffn_up_qtype: WeightQuantType,
/// FFN down projection handle (`WeightRole::FfnDown`).
pub ffn_down_ptr: u64,
/// FFN down projection length.
pub ffn_down_len: usize,
/// FFN down projection storage format.
pub ffn_down_qtype: WeightQuantType,
/// Attention norm tensor handle (`WeightRole::AttnNorm`).
pub attn_norm_ptr: u64,
/// Attention norm tensor length.
pub attn_norm_len: usize,
/// FFN norm tensor handle (`WeightRole::FfnNorm`).
pub ffn_norm_ptr: u64,
/// FFN norm tensor length.
pub ffn_norm_len: usize,
/// Q bias handle (`WeightRole::AttnQBias`); `(0, 0)` when absent.
pub attn_q_bias_ptr: u64,
/// Q bias length.
pub attn_q_bias_len: usize,
/// K bias handle (`WeightRole::AttnKBias`); `(0, 0)` when absent.
pub attn_k_bias_ptr: u64,
/// K bias length.
pub attn_k_bias_len: usize,
/// V bias handle (`WeightRole::AttnVBias`); `(0, 0)` when absent.
pub attn_v_bias_ptr: u64,
/// V bias length.
pub attn_v_bias_len: usize,
/// Q norm tensor handle (`WeightRole::AttnQNorm`); `(0, 0)` when absent.
pub attn_q_norm_ptr: u64,
/// Q norm tensor length.
pub attn_q_norm_len: usize,
/// K norm tensor handle (`WeightRole::AttnKNorm`); `(0, 0)` when absent.
pub attn_k_norm_ptr: u64,
/// K norm tensor length.
pub attn_k_norm_len: usize,
}
/// Storage format of a projection weight matrix.
///
/// Each variant corresponds to one GGML tensor type id (see `WeightQuantType::from_ggml_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightQuantType {
/// GGML type id 12: 256-weight super-blocks of 144 bytes.
Q4K,
/// GGML type id 13: 256-weight super-blocks of 176 bytes.
Q5K,
/// GGML type id 14: 256-weight super-blocks of 210 bytes.
Q6K,
/// GGML type id 8: 32-weight blocks of 34 bytes.
Q8_0,
/// GGML type id 6: 32-weight blocks of 22 bytes.
Q5_0,
/// GGML type id 2: 32-weight blocks of 18 bytes.
Q4_0,
/// GGML type id 3: 32-weight blocks of 20 bytes.
Q4_1,
/// GGML type id 0: unquantised 32-bit floats, 4 bytes per weight.
F32,
}
impl WeightQuantType {
    /// Number of bytes that 256 consecutive weights occupy in this format.
    ///
    /// For the K-quants (`Q4K`, `Q5K`, `Q6K`) this is one native 256-weight super-block.
    /// For the 32-weight block formats (`Q4_0`, `Q4_1`, `Q5_0`, `Q8_0`) it is eight blocks,
    /// and for `F32` it is `256 * 4` bytes.
    pub const fn bytes_per_superblock(&self) -> usize {
        match self {
            Self::Q4K => 144,
            Self::Q5K => 176,
            Self::Q6K => 210,
            Self::Q8_0 => 34 * 8,
            Self::Q5_0 => 22 * 8,
            Self::Q4_0 => 18 * 8,
            Self::Q4_1 => 20 * 8,
            Self::F32 => 256 * 4,
        }
    }

    /// Number of bytes that 32 consecutive weights occupy in this format.
    ///
    /// Exact for the 32-weight block formats and for `F32` (`32 * 4`). For the K-quants it
    /// is the super-block size divided by 8. For `Q6K` that quotient is 26.25 and is
    /// truncated to 26, so K-quant tensor sizes must be derived from
    /// [`Self::bytes_per_superblock`] instead (as [`Self::matches_size`] does).
    pub const fn bytes_per_block(&self) -> usize {
        match self {
            Self::Q4K => 18,
            Self::Q5K => 22,
            Self::Q6K => 26,
            Self::Q8_0 => 34,
            Self::Q5_0 => 22,
            Self::Q4_0 => 18,
            Self::Q4_1 => 20,
            Self::F32 => 128,
        }
    }

    /// Maps a GGML tensor type id to the corresponding format.
    ///
    /// Returns `None` for any id this enum does not model (e.g. `1`).
    pub fn from_ggml_type(type_id: u32) -> Option<Self> {
        match type_id {
            0 => Some(Self::F32),
            2 => Some(Self::Q4_0),
            3 => Some(Self::Q4_1),
            6 => Some(Self::Q5_0),
            8 => Some(Self::Q8_0),
            12 => Some(Self::Q4K),
            13 => Some(Self::Q5K),
            14 => Some(Self::Q6K),
            _ => None,
        }
    }

    /// Returns `true` if an `n_rows` x `n_cols` tensor stored in this format occupies
    /// exactly `size_bytes` bytes.
    ///
    /// Each row is padded up to a whole number of 256-weight super-blocks (K-quants) or
    /// 32-weight blocks (legacy block formats). Dimensions whose byte size would overflow
    /// `usize` never match. Previously they panicked in debug builds and could wrap to a
    /// false match in release.
    pub fn matches_size(&self, size_bytes: usize, n_rows: usize, n_cols: usize) -> bool {
        self.expected_size(n_rows, n_cols) == Some(size_bytes)
    }

    /// Infers the format of an `n_rows` x `n_cols` tensor from its byte size alone.
    ///
    /// Candidates are probed in a fixed order — `F32`, `Q6K`, `Q5K`, `Q4K`, `Q4_0`, `Q4_1`,
    /// `Q5_0`, `Q8_0` — and the first one whose [`Self::matches_size`] holds is returned.
    ///
    /// The inference is inherently ambiguous. When `n_cols` is a multiple of 256, `Q4_0` and
    /// `Q4K` have the same total size (144 bytes per 256 weights), as do `Q5_0` and `Q5K`
    /// (176). The K-quant is reported in both cases. Prefer [`Self::from_ggml_type`] whenever
    /// the type id is available.
    pub fn from_size(size_bytes: usize, n_rows: usize, n_cols: usize) -> Option<Self> {
        let probe_order = [
            Self::F32,
            Self::Q6K,
            Self::Q5K,
            Self::Q4K,
            Self::Q4_0,
            Self::Q4_1,
            Self::Q5_0,
            Self::Q8_0,
        ];
        probe_order
            .iter()
            .copied()
            .find(|candidate| candidate.matches_size(size_bytes, n_rows, n_cols))
    }

    /// Exact byte size of an `n_rows` x `n_cols` tensor in this format, or `None` if the
    /// computation overflows `usize`.
    ///
    /// This is the single place that turns dimensions into a byte count, so the per-format
    /// sizes always come from `bytes_per_superblock` / `bytes_per_block`.
    fn expected_size(&self, n_rows: usize, n_cols: usize) -> Option<usize> {
        match self {
            Self::F32 => n_rows.checked_mul(n_cols)?.checked_mul(4),
            Self::Q4K | Self::Q5K | Self::Q6K => n_rows
                .checked_mul(Self::ceil_div(n_cols, 256))?
                .checked_mul(self.bytes_per_superblock()),
            Self::Q4_0 | Self::Q4_1 | Self::Q5_0 | Self::Q8_0 => n_rows
                .checked_mul(Self::ceil_div(n_cols, 32))?
                .checked_mul(self.bytes_per_block()),
        }
    }

    /// `ceil(n / d)` for `d > 0`, computed without the `n + d - 1` overflow.
    const fn ceil_div(n: usize, d: usize) -> usize {
        if n % d == 0 {
            n / d
        } else {
            n / d + 1
        }
    }
}
/// A projection weight paired with its GEMV kernel and matrix dimensions.
///
/// Built by `BoundWeight::bind`, which selects `kernel` from the weight's `WeightQuantType`.
#[derive(Debug, Clone)]
pub struct BoundWeight {
/// Weight tensor handle, stored exactly as passed to `bind`.
pub ptr: u64,
/// Weight tensor length, stored exactly as passed to `bind`.
pub len: usize,
/// Output dimension of the projection (e.g. `q_dim` for the Q projection in
/// `BoundLayerWeights::bind`).
pub out_dim: u32,
/// Input dimension of the projection (e.g. `hidden_dim` for the Q projection in
/// `BoundLayerWeights::bind`).
pub in_dim: u32,
// Private: outside this module it can only be set through `bind` (keeping it in step with
// the quantisation type) and read through `kernel()`.
kernel: GemvKernel,
}
/// GEMV kernel variant used to multiply by a `BoundWeight`.
///
/// There is one variant per `WeightQuantType`; `BoundWeight::bind` maps each format to the
/// same-named kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemvKernel {
Q4K,
Q5K,
Q6K,
Q8_0,
Q4_0,
Q5_0,
Q4_1,
F32,
}
impl BoundWeight {
    /// Pairs a weight tensor with the GEMV kernel for its quantisation format.
    ///
    /// `ptr` and `len` are stored unchanged. `out_dim`/`in_dim` are the projection's output
    /// and input dimensions. The kernel is derived from `qtype` by a one-to-one mapping.
    pub fn bind(ptr: u64, len: usize, qtype: WeightQuantType, out_dim: u32, in_dim: u32) -> Self {
        Self {
            kernel: Self::kernel_for(qtype),
            ptr,
            len,
            out_dim,
            in_dim,
        }
    }

    /// The GEMV kernel selected for this weight's quantisation format.
    pub fn kernel(&self) -> GemvKernel {
        self.kernel
    }

    /// Maps each storage format to the same-named GEMV kernel.
    fn kernel_for(qtype: WeightQuantType) -> GemvKernel {
        match qtype {
            WeightQuantType::F32 => GemvKernel::F32,
            WeightQuantType::Q4_0 => GemvKernel::Q4_0,
            WeightQuantType::Q4_1 => GemvKernel::Q4_1,
            WeightQuantType::Q5_0 => GemvKernel::Q5_0,
            WeightQuantType::Q8_0 => GemvKernel::Q8_0,
            WeightQuantType::Q4K => GemvKernel::Q4K,
            WeightQuantType::Q5K => GemvKernel::Q5K,
            WeightQuantType::Q6K => GemvKernel::Q6K,
        }
    }
}
/// A transformer layer's weights with every projection bound to its GEMV kernel and
/// dimensions.
///
/// Produced by `BoundLayerWeights::bind` from a `ValidatedLayerWeights`; the norm and bias
/// `(ptr, len)` pairs are copied over unchanged.
#[derive(Debug, Clone)]
pub struct BoundLayerWeights {
/// Q projection, bound as `(out, in) = (q_dim, hidden_dim)`.
pub q_proj: BoundWeight,
/// K projection, bound as `(kv_dim, hidden_dim)`.
pub k_proj: BoundWeight,
/// V projection, bound as `(kv_dim, hidden_dim)`.
pub v_proj: BoundWeight,
/// Attention output projection, bound as `(hidden_dim, q_dim)`.
pub o_proj: BoundWeight,
/// FFN gate projection, bound as `(intermediate_dim, hidden_dim)`.
pub ffn_gate: BoundWeight,
/// FFN up projection, bound as `(intermediate_dim, hidden_dim)`.
pub ffn_up: BoundWeight,
/// FFN down projection, bound as `(hidden_dim, intermediate_dim)`.
pub ffn_down: BoundWeight,
pub attn_norm_ptr: u64,
pub attn_norm_len: usize,
pub ffn_norm_ptr: u64,
pub ffn_norm_len: usize,
// Optional tensors below: `(0, 0)` when the source layer did not provide them.
pub attn_q_bias_ptr: u64,
pub attn_q_bias_len: usize,
pub attn_k_bias_ptr: u64,
pub attn_k_bias_len: usize,
pub attn_v_bias_ptr: u64,
pub attn_v_bias_len: usize,
pub attn_q_norm_ptr: u64,
pub attn_q_norm_len: usize,
pub attn_k_norm_ptr: u64,
pub attn_k_norm_len: usize,
}
impl BoundLayerWeights {
    /// Binds every projection of a validated layer to its GEMV kernel and dimensions.
    ///
    /// Each projection is bound with the following `(out_dim, in_dim)`:
    /// - `q_proj`: `(q_dim, hidden_dim)`
    /// - `k_proj`, `v_proj`: `(kv_dim, hidden_dim)`
    /// - `o_proj`: `(hidden_dim, q_dim)`
    /// - `ffn_gate`, `ffn_up`: `(intermediate_dim, hidden_dim)`
    /// - `ffn_down`: `(hidden_dim, intermediate_dim)`
    ///
    /// Norm and bias `(ptr, len)` pairs are copied across unchanged.
    pub fn bind(
        src: &ValidatedLayerWeights,
        hidden_dim: u32,
        q_dim: u32,
        kv_dim: u32,
        intermediate_dim: u32,
    ) -> Self {
        let w = src.inner();
        let bind_weight = BoundWeight::bind;

        // Attention block: Q/K/V read the hidden state, O maps back to it.
        let q_proj = bind_weight(w.attn_q_ptr, w.attn_q_len, w.attn_q_qtype, q_dim, hidden_dim);
        let k_proj = bind_weight(w.attn_k_ptr, w.attn_k_len, w.attn_k_qtype, kv_dim, hidden_dim);
        let v_proj = bind_weight(w.attn_v_ptr, w.attn_v_len, w.attn_v_qtype, kv_dim, hidden_dim);
        let o_proj = bind_weight(
            w.attn_output_ptr,
            w.attn_output_len,
            w.attn_output_qtype,
            hidden_dim,
            q_dim,
        );

        // Feed-forward block: gate/up expand to the intermediate width, down projects back.
        let ffn_gate = bind_weight(
            w.ffn_gate_ptr,
            w.ffn_gate_len,
            w.ffn_gate_qtype,
            intermediate_dim,
            hidden_dim,
        );
        let ffn_up = bind_weight(
            w.ffn_up_ptr,
            w.ffn_up_len,
            w.ffn_up_qtype,
            intermediate_dim,
            hidden_dim,
        );
        let ffn_down = bind_weight(
            w.ffn_down_ptr,
            w.ffn_down_len,
            w.ffn_down_qtype,
            hidden_dim,
            intermediate_dim,
        );

        Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            ffn_gate,
            ffn_up,
            ffn_down,
            attn_norm_ptr: w.attn_norm_ptr,
            attn_norm_len: w.attn_norm_len,
            ffn_norm_ptr: w.ffn_norm_ptr,
            ffn_norm_len: w.ffn_norm_len,
            attn_q_bias_ptr: w.attn_q_bias_ptr,
            attn_q_bias_len: w.attn_q_bias_len,
            attn_k_bias_ptr: w.attn_k_bias_ptr,
            attn_k_bias_len: w.attn_k_bias_len,
            attn_v_bias_ptr: w.attn_v_bias_ptr,
            attn_v_bias_len: w.attn_v_bias_len,
            attn_q_norm_ptr: w.attn_q_norm_ptr,
            attn_q_norm_len: w.attn_q_norm_len,
            attn_k_norm_ptr: w.attn_k_norm_ptr,
            attn_k_norm_len: w.attn_k_norm_len,
        }
    }
}
use crate::arch_requirements::{required_roles, WeightRole};
use crate::gguf::ArchConstraints;
use std::fmt;
/// Returned by `ValidatedLayerWeights::validate` when a weight required by the architecture
/// is absent, i.e. its `(ptr, len)` pair is `(0, 0)`.
#[derive(Debug, Clone)]
pub struct WeightValidationError {
    /// The required role that was found missing.
    pub role: WeightRole,
    /// Field name reported for `role` (`WeightRole::field_name`).
    pub field: &'static str,
    /// Human-readable architecture label shown in the message.
    pub arch_name: String,
    /// Index of the layer that failed validation.
    pub layer_idx: usize,
}

impl fmt::Display for WeightValidationError {
    /// Writes a single-line `GH-279` message naming the missing field, the architecture
    /// label and the layer index.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let WeightValidationError {
            field,
            arch_name,
            layer_idx,
            ..
        } = self;
        write!(
            f,
            "GH-279: Missing required weight '{}' for architecture '{}' at layer {} \
             (ValidatedLayerWeights Poka-Yoke: (0, 0) is not allowed for required roles)",
            field, arch_name, layer_idx
        )
    }
}

/// The error has no underlying cause; its fields describe the failure completely.
impl std::error::Error for WeightValidationError {}
/// An `IndexedLayerWeights` that has passed `ValidatedLayerWeights::validate`: none of the
/// roles required by its architecture is the `(0, 0)` "absent" pair.
///
/// The wrapped table is private. In the code visible here, instances are created only by
/// `validate` (and the test-only `new_unchecked`).
#[derive(Debug, Clone)]
pub struct ValidatedLayerWeights {
inner: IndexedLayerWeights,
}
/// Gives read-only access to the validated table's fields (e.g. `src.attn_q_ptr`).
impl std::ops::Deref for ValidatedLayerWeights {
type Target = IndexedLayerWeights;
fn deref(&self) -> &IndexedLayerWeights {
&self.inner
}
}
impl ValidatedLayerWeights {
    /// Wraps `raw` after checking that it carries every weight `arch` requires.
    ///
    /// Required roles come from `required_roles(arch)` and are checked in that order.
    ///
    /// # Errors
    ///
    /// Returns [`WeightValidationError`] for the first required role whose `(ptr, len)` pair
    /// is exactly `(0, 0)`. Only that exact pair counts as missing. A zero pointer with a
    /// non-zero length, or a non-zero pointer with zero length, passes this check.
    pub fn validate(
        raw: IndexedLayerWeights,
        arch: &ArchConstraints,
        layer_idx: usize,
    ) -> Result<Self, WeightValidationError> {
        for &role in required_roles(arch) {
            if Self::get_field(&raw, role) != (0, 0) {
                continue;
            }
            return Err(WeightValidationError {
                role,
                field: role.field_name(),
                arch_name: Self::arch_display_name(arch),
                layer_idx,
            });
        }
        Ok(Self { inner: raw })
    }

    /// Borrows the validated weight table.
    #[must_use]
    pub fn inner(&self) -> &IndexedLayerWeights {
        &self.inner
    }

    /// Wraps `raw` without any validation (test builds only).
    #[cfg(test)]
    #[must_use]
    pub fn new_unchecked(raw: IndexedLayerWeights) -> Self {
        Self { inner: raw }
    }

    /// Mutable access to the wrapped table, bypassing validation (test builds only).
    #[cfg(test)]
    #[must_use]
    pub fn inner_mut(&mut self) -> &mut IndexedLayerWeights {
        &mut self.inner
    }

    /// Returns the `(ptr, len)` pair stored in `weights` for `role`.
    fn get_field(weights: &IndexedLayerWeights, role: WeightRole) -> (u64, usize) {
        match role {
            WeightRole::AttnNorm => (weights.attn_norm_ptr, weights.attn_norm_len),
            WeightRole::FfnNorm => (weights.ffn_norm_ptr, weights.ffn_norm_len),
            WeightRole::AttnQNorm => (weights.attn_q_norm_ptr, weights.attn_q_norm_len),
            WeightRole::AttnKNorm => (weights.attn_k_norm_ptr, weights.attn_k_norm_len),
            WeightRole::AttnQBias => (weights.attn_q_bias_ptr, weights.attn_q_bias_len),
            WeightRole::AttnKBias => (weights.attn_k_bias_ptr, weights.attn_k_bias_len),
            WeightRole::AttnVBias => (weights.attn_v_bias_ptr, weights.attn_v_bias_len),
            WeightRole::QProj => (weights.attn_q_ptr, weights.attn_q_len),
            WeightRole::KProj => (weights.attn_k_ptr, weights.attn_k_len),
            WeightRole::VProj => (weights.attn_v_ptr, weights.attn_v_len),
            WeightRole::OProj => (weights.attn_output_ptr, weights.attn_output_len),
            WeightRole::FfnGate => (weights.ffn_gate_ptr, weights.ffn_gate_len),
            WeightRole::FfnUp => (weights.ffn_up_ptr, weights.ffn_up_len),
            WeightRole::FfnDown => (weights.ffn_down_ptr, weights.ffn_down_len),
        }
    }

    /// Architecture label used in error messages, chosen from the `(has_qk_norm, has_bias)`
    /// flags alone.
    fn arch_display_name(arch: &ArchConstraints) -> String {
        let label = match (arch.has_qk_norm, arch.has_bias) {
            (true, false) => "qwen3",
            (false, true) => "qwen2/phi (has_bias)",
            (true, true) => "unknown (has_qk_norm + has_bias)",
            (false, false) => "llama/mistral/gemma (base)",
        };
        label.to_string()
    }
}
// Textually splices `transformer_workspace.rs` into this module, so its items share this
// module's scope (including the `use` declarations above).
include!("transformer_workspace.rs");