use oxibonsai_core::config::Qwen3Config;
use oxibonsai_core::gguf::reader::GgufFile;
use oxibonsai_core::gguf::tensor_info::tensor_names;
use oxibonsai_core::gguf::types::GgufTensorType;
use oxibonsai_core::tensor::{BlockQ1_0G128, QK1_0_G128};
use crate::block::TransformerBlock;
use crate::error::{ModelError, ModelResult};
use crate::layers::linear::{Linear1Bit, LinearTernary};
use crate::layers::rms_norm::RmsNorm;
use super::types::OutputWeight;
/// Reads the tensor called `name` from `gguf` and expands it into a flat `Vec<f32>`.
///
/// Supported storage types:
/// - `F32`: each 4-byte little-endian group becomes one value.
/// - `F16`: each 2-byte little-endian group is widened to `f32`.
/// - `Q1_0_g128`: every block contributes `QK1_0_G128` values; a set bit yields
///   `+d` and a clear bit yields `-d`, where `d` is the block's scale.
/// - `TQ2_0_g128`: delegated to `BlockTQ2_0_g128::dequant`.
///
/// # Errors
///
/// Lookup, block-reinterpretation and dequantization failures are wrapped in
/// [`ModelError::Core`]. Any other tensor type yields
/// [`ModelError::MissingTensor`] whose `name` field describes the mismatch.
pub(super) fn load_f32_tensor(gguf: &GgufFile<'_>, name: &str) -> ModelResult<Vec<f32>> {
    let meta = gguf.tensors.require(name).map_err(ModelError::Core)?;
    let bytes = gguf.tensor_data(name).map_err(ModelError::Core)?;

    match meta.tensor_type {
        GgufTensorType::F32 => Ok(bytes
            .chunks_exact(4)
            .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
            .collect::<Vec<f32>>()),
        GgufTensorType::F16 => Ok(bytes
            .chunks_exact(2)
            .map(|pair| half::f16::from_bits(u16::from_le_bytes([pair[0], pair[1]])).to_f32())
            .collect::<Vec<f32>>()),
        GgufTensorType::Q1_0_g128 => {
            let packed = BlockQ1_0G128::slice_from_bytes(bytes).map_err(ModelError::Core)?;
            let mut values = vec![0.0f32; packed.len() * QK1_0_G128];
            // Pair each block with the output window it expands into.
            for (block, window) in packed.iter().zip(values.chunks_exact_mut(QK1_0_G128)) {
                let scale = block.d.to_f32();
                for (pos, slot) in window.iter_mut().enumerate() {
                    // Element `pos` lives in byte `pos / 8`, at bit `pos % 8`
                    // counted from the least-significant bit.
                    let is_set = ((block.qs[pos / 8] >> (pos % 8)) & 1) != 0;
                    *slot = if is_set { scale } else { -scale };
                }
            }
            Ok(values)
        }
        GgufTensorType::TQ2_0_g128 => {
            let packed = oxibonsai_core::BlockTQ2_0_g128::slice_from_bytes(bytes)
                .map_err(ModelError::Core)?;
            let mut values = vec![0.0f32; packed.len() * oxibonsai_core::QK_TQ2_0_G128];
            oxibonsai_core::BlockTQ2_0_g128::dequant(packed, &mut values)
                .map_err(ModelError::Core)?;
            Ok(values)
        }
        other => Err(ModelError::MissingTensor {
            name: format!("{name}: expected F32, F16, Q1_0_g128, or TQ2_0_g128, got {other}"),
        }),
    }
}
/// Borrows the raw data of tensor `name` and views it as a slice of
/// [`BlockQ1_0G128`] tied to the lifetime of `gguf`.
///
/// This function does not itself inspect the tensor's declared type; it only
/// hands the bytes to `BlockQ1_0G128::slice_from_bytes`.
///
/// # Errors
///
/// Both the data lookup and the byte-to-block reinterpretation report failures
/// as [`ModelError::Core`].
pub(super) fn load_1bit_blocks<'a>(
    gguf: &'a GgufFile<'a>,
    name: &str,
) -> ModelResult<&'a [BlockQ1_0G128]> {
    gguf.tensor_data(name)
        .and_then(|bytes| BlockQ1_0G128::slice_from_bytes(bytes))
        .map_err(ModelError::Core)
}
/// Borrows the raw data of tensor `name` and views it as a slice of
/// `BlockTQ2_0_g128` tied to the lifetime of `gguf`.
///
/// This function does not itself inspect the tensor's declared type; it only
/// hands the bytes to `BlockTQ2_0_g128::slice_from_bytes`.
///
/// # Errors
///
/// Both the data lookup and the byte-to-block reinterpretation report failures
/// as [`ModelError::Core`].
pub(super) fn load_ternary_blocks<'a>(
    gguf: &'a GgufFile<'a>,
    name: &str,
) -> ModelResult<&'a [oxibonsai_core::BlockTQ2_0_g128]> {
    gguf.tensor_data(name)
        .and_then(|bytes| oxibonsai_core::BlockTQ2_0_g128::slice_from_bytes(bytes))
        .map_err(ModelError::Core)
}
/// Dequantizes the `TQ2_0_g128` tensor `name` into an owned `Vec<f32>`.
///
/// The output holds `QK_TQ2_0_G128` values per block found in the tensor data.
///
/// # Errors
///
/// Data lookup, block reinterpretation and dequantization failures are all
/// surfaced as [`ModelError::Core`].
// NOTE(review): `dead_code` is allowed here, presumably because no caller is
// wired up yet — confirm before removing the attribute.
#[allow(dead_code)]
pub(super) fn load_ternary_embedding(gguf: &GgufFile<'_>, name: &str) -> ModelResult<Vec<f32>> {
    let packed = gguf
        .tensor_data(name)
        .and_then(|bytes| oxibonsai_core::BlockTQ2_0_g128::slice_from_bytes(bytes))
        .map_err(ModelError::Core)?;

    let mut weights = vec![0.0f32; packed.len() * oxibonsai_core::QK_TQ2_0_G128];
    oxibonsai_core::BlockTQ2_0_g128::dequant(packed, &mut weights).map_err(ModelError::Core)?;
    Ok(weights)
}
pub(super) fn load_transformer_block<'a>(
gguf: &'a GgufFile<'a>,
config: &Qwen3Config,
layer_idx: usize,
kernel: &std::sync::Arc<oxibonsai_kernels::KernelDispatcher>,
) -> ModelResult<TransformerBlock<'a>> {
let h = config.hidden_size;
let nq = config.num_attention_heads;
let nkv = config.num_kv_heads;
let hd = config.head_dim;
let inter = config.intermediate_size;
let blk = |suffix: &str| -> String { tensor_names::block_tensor(layer_idx, suffix) };
let sample_name = blk(tensor_names::ATTN_Q);
let sample_info = gguf
.tensors
.require(&sample_name)
.map_err(ModelError::Core)?;
let is_ternary = sample_info.tensor_type == GgufTensorType::TQ2_0_g128;
let attn_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_NORM))?;
let ffn_norm_w = load_f32_tensor(gguf, &blk(tensor_names::FFN_NORM))?;
let q_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_Q_NORM))?;
let k_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_K_NORM))?;
if is_ternary {
let q_blocks = load_ternary_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_ternary_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_ternary_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_ternary_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_ternary_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_ternary_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_ternary_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
LinearTernary::new(q_blocks, nq * hd, h, kernel.clone())?.into(),
LinearTernary::new(k_blocks, nkv * hd, h, kernel.clone())?.into(),
LinearTernary::new(v_blocks, nkv * hd, h, kernel.clone())?.into(),
LinearTernary::new(o_blocks, h, nq * hd, kernel.clone())?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
LinearTernary::new(gate_blocks, inter, h, kernel.clone())?.into(),
LinearTernary::new(up_blocks, inter, h, kernel.clone())?.into(),
LinearTernary::new(down_blocks, h, inter, kernel.clone())?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded ternary transformer block");
Ok(block)
} else {
let q_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_Q))?;
let k_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_K))?;
let v_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_V))?;
let o_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;
let gate_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
let up_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_UP))?;
let down_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;
let block = TransformerBlock::new(
layer_idx,
RmsNorm::new(attn_norm_w, config.rms_norm_eps),
Linear1Bit::new(q_blocks, nq * hd, h, kernel.clone())?.into(),
Linear1Bit::new(k_blocks, nkv * hd, h, kernel.clone())?.into(),
Linear1Bit::new(v_blocks, nkv * hd, h, kernel.clone())?.into(),
Linear1Bit::new(o_blocks, h, nq * hd, kernel.clone())?.into(),
RmsNorm::new(q_norm_w, config.rms_norm_eps),
RmsNorm::new(k_norm_w, config.rms_norm_eps),
RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
Linear1Bit::new(gate_blocks, inter, h, kernel.clone())?.into(),
Linear1Bit::new(up_blocks, inter, h, kernel.clone())?.into(),
Linear1Bit::new(down_blocks, h, inter, kernel.clone())?.into(),
nq,
nkv,
hd,
h,
);
tracing::trace!(layer = layer_idx, "loaded transformer block");
Ok(block)
}
}
/// Loads the output projection tensor (`tensor_names::OUTPUT`) as an [`OutputWeight`].
///
/// Dimensions come from the tensor's recorded shape: `shape[0]` is used as
/// `in_features` and `shape[1]` as `out_features`. A missing dimension falls
/// back to `config.hidden_size` and `config.vocab_size` respectively.
///
/// The declared tensor type picks the variant:
/// - `Q1_0_g128` gives `OutputWeight::OneBit`
/// - `TQ2_0_g128` gives `OutputWeight::Ternary`
/// - `F32` / `F16` give `OutputWeight::Fp32`, with weights expanded by [`load_f32_tensor`]
///
/// # Errors
///
/// Lookup failures are wrapped in [`ModelError::Core`]; linear-layer
/// construction errors are propagated; any other tensor type yields
/// [`ModelError::MissingTensor`] with a message naming the unsupported type.
pub(super) fn load_output_weight<'a>(
    gguf: &'a GgufFile<'a>,
    config: &Qwen3Config,
    kernel: &std::sync::Arc<oxibonsai_kernels::KernelDispatcher>,
) -> ModelResult<OutputWeight<'a>> {
    let meta = gguf
        .tensors
        .require(tensor_names::OUTPUT)
        .map_err(ModelError::Core)?;

    // Resolve both dimensions in one place, substituting config values for
    // whichever shape entries are absent.
    let (in_features, out_features) = match meta.shape.len() {
        0 => (config.hidden_size, config.vocab_size),
        1 => (meta.shape[0] as usize, config.vocab_size),
        _ => (meta.shape[0] as usize, meta.shape[1] as usize),
    };

    match meta.tensor_type {
        GgufTensorType::Q1_0_g128 => {
            let packed = load_1bit_blocks(gguf, tensor_names::OUTPUT)?;
            Ok(OutputWeight::OneBit(Linear1Bit::new(
                packed,
                out_features,
                in_features,
                std::sync::Arc::clone(kernel),
            )?))
        }
        GgufTensorType::TQ2_0_g128 => {
            let packed = load_ternary_blocks(gguf, tensor_names::OUTPUT)?;
            Ok(OutputWeight::Ternary(LinearTernary::new(
                packed,
                out_features,
                in_features,
                std::sync::Arc::clone(kernel),
            )?))
        }
        GgufTensorType::F32 | GgufTensorType::F16 => {
            load_f32_tensor(gguf, tensor_names::OUTPUT).map(|weights| OutputWeight::Fp32 {
                weights,
                out_features,
                in_features,
            })
        }
        other => Err(ModelError::MissingTensor {
            name: format!("output.weight: unsupported type {other}"),
        }),
    }
}