use oxibonsai_core::config::Qwen3Config;
use oxibonsai_core::gguf::reader::GgufFile;
use oxibonsai_core::gguf::tensor_info::tensor_names;
use oxibonsai_core::gguf::types::GgufTensorType;
use oxibonsai_core::tensor::{BlockQ1_0G128, QK1_0_G128};
use oxibonsai_core::{
BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0,
};
use crate::block::TransformerBlock;
use crate::error::{ModelError, ModelResult};
use crate::layers::linear::{Linear1Bit, LinearFP8E4M3, LinearFP8E5M2, LinearTernary};
use crate::layers::linear_kquant_ext::{LinearQ5K, LinearQ6K};
use crate::layers::linear_kquant_full::{LinearQ2K, LinearQ3K, LinearQ4K, LinearQ8K};
use crate::layers::linear_standard::{LinearQ4_0, LinearQ8_0};
use crate::layers::rms_norm::RmsNorm;
use super::types::OutputWeight;
pub(super) fn load_f32_tensor(gguf: &GgufFile<'_>, name: &str) -> ModelResult<Vec<f32>> {
let info = gguf.tensors.require(name).map_err(ModelError::Core)?;
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
match info.tensor_type {
GgufTensorType::F32 => {
let count = data.len() / 4;
let mut out = vec![0.0f32; count];
for (i, chunk) in data.chunks_exact(4).enumerate() {
out[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
}
Ok(out)
}
GgufTensorType::F16 => {
let count = data.len() / 2;
let mut out = vec![0.0f32; count];
for (i, chunk) in data.chunks_exact(2).enumerate() {
let raw = u16::from_le_bytes([chunk[0], chunk[1]]);
out[i] = half::f16::from_bits(raw).to_f32();
}
Ok(out)
}
GgufTensorType::Q1_0_g128 => {
let blocks = BlockQ1_0G128::slice_from_bytes(data).map_err(ModelError::Core)?;
let n = blocks.len() * QK1_0_G128;
let mut out = vec![0.0f32; n];
for (i, block) in blocks.iter().enumerate() {
let d = block.d.to_f32();
let base = i * QK1_0_G128;
for j in 0..QK1_0_G128 {
let byte_index = j / 8;
let bit_offset = j % 8;
let bit = (block.qs[byte_index] >> bit_offset) & 1;
out[base + j] = if bit != 0 { d } else { -d };
}
}
Ok(out)
}
GgufTensorType::TQ2_0_g128 => {
let blocks = oxibonsai_core::BlockTQ2_0_g128::slice_from_bytes(data)
.map_err(ModelError::Core)?;
let n = blocks.len() * oxibonsai_core::QK_TQ2_0_G128;
let mut out = vec![0.0f32; n];
oxibonsai_core::BlockTQ2_0_g128::dequant(blocks, &mut out).map_err(ModelError::Core)?;
Ok(out)
}
GgufTensorType::Q4_0 => {
let blocks = BlockQ4_0::slice_from_bytes(data).map_err(ModelError::Core)?;
let n = blocks.len() * oxibonsai_core::QK_Q4_0;
let mut out = vec![0.0f32; n];
BlockQ4_0::dequant(blocks, &mut out).map_err(ModelError::Core)?;
Ok(out)
}
GgufTensorType::Q8_0 => {
let blocks = BlockQ8_0::slice_from_bytes(data).map_err(ModelError::Core)?;
let n = blocks.len() * oxibonsai_core::QK_Q8_0;
let mut out = vec![0.0f32; n];
BlockQ8_0::dequant(blocks, &mut out).map_err(ModelError::Core)?;
Ok(out)
}
GgufTensorType::Q5_K => {
let blocks = BlockQ5K::slice_from_bytes(data).map_err(ModelError::Core)?;
let n = blocks.len() * 256;
let mut out = vec![0.0f32; n];
BlockQ5K::dequant(blocks, &mut out).map_err(ModelError::Core)?;
Ok(out)
}
GgufTensorType::Q6_K => {
let blocks = BlockQ6K::slice_from_bytes(data).map_err(ModelError::Core)?;
let n = blocks.len() * 256;
let mut out = vec![0.0f32; n];
BlockQ6K::dequant(blocks, &mut out).map_err(ModelError::Core)?;
Ok(out)
}
other => Err(ModelError::MissingTensor {
name: format!(
"{name}: expected F32, F16, Q1_0_g128, TQ2_0_g128, Q4_0, Q8_0, Q5_K, or Q6_K, got {other}"
),
}),
}
}
pub(super) fn load_1bit_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ1_0G128]> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
BlockQ1_0G128::slice_from_bytes(data).map_err(ModelError::Core)
}
pub(super) fn load_ternary_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [oxibonsai_core::BlockTQ2_0_g128]> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
oxibonsai_core::BlockTQ2_0_g128::slice_from_bytes(data).map_err(ModelError::Core)
}
#[allow(dead_code)]
pub(super) fn load_ternary_embedding(gguf: &GgufFile<'_>, name: &str) -> ModelResult<Vec<f32>> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
let blocks =
oxibonsai_core::BlockTQ2_0_g128::slice_from_bytes(data).map_err(ModelError::Core)?;
let n = blocks.len() * oxibonsai_core::QK_TQ2_0_G128;
let mut out = vec![0.0f32; n];
oxibonsai_core::BlockTQ2_0_g128::dequant(blocks, &mut out).map_err(ModelError::Core)?;
Ok(out)
}
pub(super) fn load_fp8_e4m3_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [oxibonsai_core::BlockFP8E4M3]> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
oxibonsai_core::BlockFP8E4M3::slice_from_bytes(data).map_err(ModelError::Core)
}
pub(super) fn load_fp8_e5m2_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [oxibonsai_core::BlockFP8E5M2]> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
oxibonsai_core::BlockFP8E5M2::slice_from_bytes(data).map_err(ModelError::Core)
}
pub(super) fn load_q4_0_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ4_0]> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
BlockQ4_0::slice_from_bytes(data).map_err(ModelError::Core)
}
pub(super) fn load_q8_0_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ8_0]> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
BlockQ8_0::slice_from_bytes(data).map_err(ModelError::Core)
}
pub(super) fn load_q5k_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ5K]> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
BlockQ5K::slice_from_bytes(data).map_err(ModelError::Core)
}
pub(super) fn load_q6k_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ6K]> {
let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
BlockQ6K::slice_from_bytes(data).map_err(ModelError::Core)
}
pub(super) fn load_q2k_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ2K]> {
BlockQ2K::slice_from_bytes(gguf.tensor_data(name).map_err(ModelError::Core)?)
.map_err(ModelError::Core)
}
pub(super) fn load_q3k_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ3K]> {
BlockQ3K::slice_from_bytes(gguf.tensor_data(name).map_err(ModelError::Core)?)
.map_err(ModelError::Core)
}
pub(super) fn load_q4k_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ4K]> {
BlockQ4K::slice_from_bytes(gguf.tensor_data(name).map_err(ModelError::Core)?)
.map_err(ModelError::Core)
}
pub(super) fn load_q8k_blocks<'a>(
gguf: &'a GgufFile<'a>,
name: &str,
) -> ModelResult<&'a [BlockQ8K]> {
BlockQ8K::slice_from_bytes(gguf.tensor_data(name).map_err(ModelError::Core)?)
.map_err(ModelError::Core)
}
pub(super) fn load_transformer_block<'a>(
gguf: &'a GgufFile<'a>,
config: &Qwen3Config,
layer_idx: usize,
kernel: &std::sync::Arc<oxibonsai_kernels::KernelDispatcher>,
) -> ModelResult<TransformerBlock<'a>> {
let h = config.hidden_size;
let nq = config.num_attention_heads;
let nkv = config.num_kv_heads;
let hd = config.head_dim;
let inter = config.intermediate_size;
let blk = |suffix: &str| -> String { tensor_names::block_tensor(layer_idx, suffix) };
let sample_name = blk(tensor_names::ATTN_Q);
let sample_info = gguf
.tensors
.require(&sample_name)
.map_err(ModelError::Core)?;
let is_ternary = sample_info.tensor_type == GgufTensorType::TQ2_0_g128;
let is_fp8_e4m3 = sample_info.tensor_type == GgufTensorType::F8_E4M3;
let is_fp8_e5m2 = sample_info.tensor_type == GgufTensorType::F8_E5M2;
let is_q4_0 = sample_info.tensor_type == GgufTensorType::Q4_0;
let is_q8_0 = sample_info.tensor_type == GgufTensorType::Q8_0;
let is_q5k = sample_info.tensor_type == GgufTensorType::Q5_K;
let is_q6k = sample_info.tensor_type == GgufTensorType::Q6_K;
let is_q2k = sample_info.tensor_type == GgufTensorType::Q2_K;
let is_q3k = sample_info.tensor_type == GgufTensorType::Q3_K;
let is_q4k = sample_info.tensor_type == GgufTensorType::Q4_K;
let is_q8k = sample_info.tensor_type == GgufTensorType::Q8_K;
let attn_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_NORM))?;
let ffn_norm_w = load_f32_tensor(gguf, &blk(tensor_names::FFN_NORM))?;
let q_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_Q_NORM))?;
let k_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_K_NORM))?;
if is_ternary {
let q_blocks = load_ternary_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_ternary_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_ternary_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_ternary_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_ternary_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_ternary_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_ternary_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearTernary::new(q_blocks, nq * hd, h, kernel.clone())?.into(),
LinearTernary::new(k_blocks, nkv * hd, h, kernel.clone())?.into(),
LinearTernary::new(v_blocks, nkv * hd, h, kernel.clone())?.into(),
LinearTernary::new(o_blocks, h, nq * hd, kernel.clone())?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearTernary::new(gate_blocks, inter, h, kernel.clone())?.into(),
LinearTernary::new(up_blocks, inter, h, kernel.clone())?.into(),
LinearTernary::new(down_blocks, h, inter, kernel.clone())?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded ternary transformer block");
Ok(block)
} else if is_fp8_e4m3 {
let q_blocks = load_fp8_e4m3_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_fp8_e4m3_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_fp8_e4m3_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_fp8_e4m3_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_fp8_e4m3_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_fp8_e4m3_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_fp8_e4m3_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearFP8E4M3::new(q_blocks, nq * hd, h, kernel.clone())?.into(),
LinearFP8E4M3::new(k_blocks, nkv * hd, h, kernel.clone())?.into(),
LinearFP8E4M3::new(v_blocks, nkv * hd, h, kernel.clone())?.into(),
LinearFP8E4M3::new(o_blocks, h, nq * hd, kernel.clone())?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearFP8E4M3::new(gate_blocks, inter, h, kernel.clone())?.into(),
LinearFP8E4M3::new(up_blocks, inter, h, kernel.clone())?.into(),
LinearFP8E4M3::new(down_blocks, h, inter, kernel.clone())?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded FP8 E4M3FN transformer block");
Ok(block)
} else if is_fp8_e5m2 {
let q_blocks = load_fp8_e5m2_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_fp8_e5m2_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_fp8_e5m2_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_fp8_e5m2_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_fp8_e5m2_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_fp8_e5m2_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_fp8_e5m2_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearFP8E5M2::new(q_blocks, nq * hd, h, kernel.clone())?.into(),
LinearFP8E5M2::new(k_blocks, nkv * hd, h, kernel.clone())?.into(),
LinearFP8E5M2::new(v_blocks, nkv * hd, h, kernel.clone())?.into(),
LinearFP8E5M2::new(o_blocks, h, nq * hd, kernel.clone())?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearFP8E5M2::new(gate_blocks, inter, h, kernel.clone())?.into(),
LinearFP8E5M2::new(up_blocks, inter, h, kernel.clone())?.into(),
LinearFP8E5M2::new(down_blocks, h, inter, kernel.clone())?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded FP8 E5M2 transformer block");
Ok(block)
} else if is_q4_0 {
let q_blocks = load_q4_0_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_q4_0_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_q4_0_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_q4_0_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_q4_0_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_q4_0_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_q4_0_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearQ4_0::new(q_blocks, nq * hd, h)?.into(),
LinearQ4_0::new(k_blocks, nkv * hd, h)?.into(),
LinearQ4_0::new(v_blocks, nkv * hd, h)?.into(),
LinearQ4_0::new(o_blocks, h, nq * hd)?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearQ4_0::new(gate_blocks, inter, h)?.into(),
LinearQ4_0::new(up_blocks, inter, h)?.into(),
LinearQ4_0::new(down_blocks, h, inter)?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded Q4_0 transformer block");
Ok(block)
} else if is_q8_0 {
let q_blocks = load_q8_0_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_q8_0_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_q8_0_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_q8_0_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_q8_0_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_q8_0_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_q8_0_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearQ8_0::new(q_blocks, nq * hd, h)?.into(),
LinearQ8_0::new(k_blocks, nkv * hd, h)?.into(),
LinearQ8_0::new(v_blocks, nkv * hd, h)?.into(),
LinearQ8_0::new(o_blocks, h, nq * hd)?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearQ8_0::new(gate_blocks, inter, h)?.into(),
LinearQ8_0::new(up_blocks, inter, h)?.into(),
LinearQ8_0::new(down_blocks, h, inter)?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded Q8_0 transformer block");
Ok(block)
} else if is_q5k {
let q_blocks = load_q5k_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_q5k_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_q5k_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_q5k_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_q5k_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_q5k_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_q5k_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearQ5K::new(q_blocks, nq * hd, h)?.into(),
LinearQ5K::new(k_blocks, nkv * hd, h)?.into(),
LinearQ5K::new(v_blocks, nkv * hd, h)?.into(),
LinearQ5K::new(o_blocks, h, nq * hd)?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearQ5K::new(gate_blocks, inter, h)?.into(),
LinearQ5K::new(up_blocks, inter, h)?.into(),
LinearQ5K::new(down_blocks, h, inter)?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded Q5_K transformer block");
Ok(block)
} else if is_q6k {
let q_blocks = load_q6k_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_q6k_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_q6k_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_q6k_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_q6k_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_q6k_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_q6k_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearQ6K::new(q_blocks, nq * hd, h)?.into(),
LinearQ6K::new(k_blocks, nkv * hd, h)?.into(),
LinearQ6K::new(v_blocks, nkv * hd, h)?.into(),
LinearQ6K::new(o_blocks, h, nq * hd)?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearQ6K::new(gate_blocks, inter, h)?.into(),
LinearQ6K::new(up_blocks, inter, h)?.into(),
LinearQ6K::new(down_blocks, h, inter)?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded Q6_K transformer block");
Ok(block)
} else if is_q2k {
let q_b = load_q2k_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_b = load_q2k_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_b = load_q2k_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_b = load_q2k_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_b = load_q2k_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_b = load_q2k_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_b = load_q2k_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearQ2K::new(q_b, nq * hd, h)?.into(),
LinearQ2K::new(k_b, nkv * hd, h)?.into(),
LinearQ2K::new(v_b, nkv * hd, h)?.into(),
LinearQ2K::new(o_b, h, nq * hd)?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearQ2K::new(gate_b, inter, h)?.into(),
LinearQ2K::new(up_b, inter, h)?.into(),
LinearQ2K::new(down_b, h, inter)?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded Q2_K transformer block");
Ok(block)
} else if is_q3k {
let q_b = load_q3k_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_b = load_q3k_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_b = load_q3k_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_b = load_q3k_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_b = load_q3k_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_b = load_q3k_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_b = load_q3k_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearQ3K::new(q_b, nq * hd, h)?.into(),
LinearQ3K::new(k_b, nkv * hd, h)?.into(),
LinearQ3K::new(v_b, nkv * hd, h)?.into(),
LinearQ3K::new(o_b, h, nq * hd)?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearQ3K::new(gate_b, inter, h)?.into(),
LinearQ3K::new(up_b, inter, h)?.into(),
LinearQ3K::new(down_b, h, inter)?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded Q3_K transformer block");
Ok(block)
} else if is_q4k {
let q_b = load_q4k_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_b = load_q4k_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_b = load_q4k_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_b = load_q4k_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_b = load_q4k_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_b = load_q4k_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_b = load_q4k_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearQ4K::new(q_b, nq * hd, h)?.into(),
LinearQ4K::new(k_b, nkv * hd, h)?.into(),
LinearQ4K::new(v_b, nkv * hd, h)?.into(),
LinearQ4K::new(o_b, h, nq * hd)?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearQ4K::new(gate_b, inter, h)?.into(),
LinearQ4K::new(up_b, inter, h)?.into(),
LinearQ4K::new(down_b, h, inter)?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded Q4_K transformer block");
Ok(block)
} else if is_q8k {
let q_b = load_q8k_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_b = load_q8k_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_b = load_q8k_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_b = load_q8k_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_b = load_q8k_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_b = load_q8k_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_b = load_q8k_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearQ8K::new(q_b, nq * hd, h)?.into(),
LinearQ8K::new(k_b, nkv * hd, h)?.into(),
LinearQ8K::new(v_b, nkv * hd, h)?.into(),
LinearQ8K::new(o_b, h, nq * hd)?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearQ8K::new(gate_b, inter, h)?.into(),
LinearQ8K::new(up_b, inter, h)?.into(),
LinearQ8K::new(down_b, h, inter)?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded Q8_K transformer block");
Ok(block)
} else {
let q_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
Linear1Bit::new(q_blocks, nq * hd, h, kernel.clone())?.into(),
Linear1Bit::new(k_blocks, nkv * hd, h, kernel.clone())?.into(),
Linear1Bit::new(v_blocks, nkv * hd, h, kernel.clone())?.into(),
Linear1Bit::new(o_blocks, h, nq * hd, kernel.clone())?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
Linear1Bit::new(gate_blocks, inter, h, kernel.clone())?.into(),
Linear1Bit::new(up_blocks, inter, h, kernel.clone())?.into(),
Linear1Bit::new(down_blocks, h, inter, kernel.clone())?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded transformer block");
Ok(block)
}
}
pub(super) fn load_output_weight<'a>(
gguf: &'a GgufFile<'a>,
config: &Qwen3Config,
kernel: &std::sync::Arc<oxibonsai_kernels::KernelDispatcher>,
) -> ModelResult<OutputWeight<'a>> {
let info = gguf
.tensors
.require(tensor_names::OUTPUT)
.map_err(ModelError::Core)?;
let out_features = if info.shape.len() >= 2 {
info.shape[1] as usize
} else {
config.vocab_size
};
let in_features = if !info.shape.is_empty() {
info.shape[0] as usize
} else {
config.hidden_size
};
match info.tensor_type {
GgufTensorType::Q1_0_g128 => {
let blocks = load_1bit_blocks(gguf, tensor_names::OUTPUT)?;
let linear = Linear1Bit::new(blocks, out_features, in_features, kernel.clone())?;
Ok(OutputWeight::OneBit(linear))
}
GgufTensorType::TQ2_0_g128 => {
let blocks = load_ternary_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearTernary::new(blocks, out_features, in_features, kernel.clone())?;
Ok(OutputWeight::Ternary(linear))
}
GgufTensorType::F8_E4M3 => {
let blocks = load_fp8_e4m3_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearFP8E4M3::new(blocks, out_features, in_features, kernel.clone())?;
Ok(OutputWeight::FP8E4M3(linear))
}
GgufTensorType::F8_E5M2 => {
let blocks = load_fp8_e5m2_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearFP8E5M2::new(blocks, out_features, in_features, kernel.clone())?;
Ok(OutputWeight::FP8E5M2(linear))
}
GgufTensorType::Q4_0 => {
let blocks = load_q4_0_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearQ4_0::new(blocks, out_features, in_features)?;
Ok(OutputWeight::Q4_0(linear))
}
GgufTensorType::Q8_0 => {
let blocks = load_q8_0_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearQ8_0::new(blocks, out_features, in_features)?;
Ok(OutputWeight::Q8_0(linear))
}
GgufTensorType::Q5_K => {
let blocks = load_q5k_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearQ5K::new(blocks, out_features, in_features)?;
Ok(OutputWeight::Q5K(linear))
}
GgufTensorType::Q6_K => {
let blocks = load_q6k_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearQ6K::new(blocks, out_features, in_features)?;
Ok(OutputWeight::Q6K(linear))
}
GgufTensorType::Q2_K => {
let blocks = load_q2k_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearQ2K::new(blocks, out_features, in_features)?;
Ok(OutputWeight::Q2K(linear))
}
GgufTensorType::Q3_K => {
let blocks = load_q3k_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearQ3K::new(blocks, out_features, in_features)?;
Ok(OutputWeight::Q3K(linear))
}
GgufTensorType::Q4_K => {
let blocks = load_q4k_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearQ4K::new(blocks, out_features, in_features)?;
Ok(OutputWeight::Q4K(linear))
}
GgufTensorType::Q8_K => {
let blocks = load_q8k_blocks(gguf, tensor_names::OUTPUT)?;
let linear = LinearQ8K::new(blocks, out_features, in_features)?;
Ok(OutputWeight::Q8K(linear))
}
GgufTensorType::F32 | GgufTensorType::F16 => {
let weights = load_f32_tensor(gguf, tensor_names::OUTPUT)?;
Ok(OutputWeight::Fp32 {
weights,
out_features,
in_features,
})
}
other => Err(ModelError::MissingTensor {
name: format!("output.weight: unsupported type {other}"),
}),
}
}