oxibonsai-model 0.1.0

Qwen3-8B Transformer implementation for OxiBonsai 1-bit inference
Documentation
// Weight-loading helper functions extracted from model.rs.
// These are `pub(super)` so only `model.rs` can call them.

use oxibonsai_core::config::Qwen3Config;
use oxibonsai_core::gguf::reader::GgufFile;
use oxibonsai_core::gguf::tensor_info::tensor_names;
use oxibonsai_core::gguf::types::GgufTensorType;
use oxibonsai_core::tensor::{BlockQ1_0G128, QK1_0_G128};

use crate::block::TransformerBlock;
use crate::error::{ModelError, ModelResult};
use crate::layers::linear::Linear1Bit;
use crate::layers::rms_norm::RmsNorm;

use super::OutputWeight;

/// Load an FP32 tensor from GGUF by name.
pub(super) fn load_f32_tensor(gguf: &GgufFile<'_>, name: &str) -> ModelResult<Vec<f32>> {
    let info = gguf.tensors.require(name).map_err(ModelError::Core)?;

    let data = gguf.tensor_data(name).map_err(ModelError::Core)?;

    match info.tensor_type {
        GgufTensorType::F32 => {
            let count = data.len() / 4;
            let mut out = vec![0.0f32; count];
            for (i, chunk) in data.chunks_exact(4).enumerate() {
                out[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
            }
            Ok(out)
        }
        GgufTensorType::F16 => {
            let count = data.len() / 2;
            let mut out = vec![0.0f32; count];
            for (i, chunk) in data.chunks_exact(2).enumerate() {
                let raw = u16::from_le_bytes([chunk[0], chunk[1]]);
                out[i] = half::f16::from_bits(raw).to_f32();
            }
            Ok(out)
        }
        GgufTensorType::Q1_0_g128 => {
            let blocks = BlockQ1_0G128::slice_from_bytes(data).map_err(ModelError::Core)?;
            let n = blocks.len() * QK1_0_G128;
            let mut out = vec![0.0f32; n];
            for (i, block) in blocks.iter().enumerate() {
                let d = block.d.to_f32();
                let base = i * QK1_0_G128;
                for j in 0..QK1_0_G128 {
                    let byte_index = j / 8;
                    let bit_offset = j % 8;
                    let bit = (block.qs[byte_index] >> bit_offset) & 1;
                    out[base + j] = if bit != 0 { d } else { -d };
                }
            }
            Ok(out)
        }
        other => Err(ModelError::MissingTensor {
            name: format!("{name}: expected F32, F16, or Q1_0_g128, got {other}"),
        }),
    }
}

/// Load Q1_0_g128 weight blocks from GGUF (zero-copy).
pub(super) fn load_1bit_blocks<'a>(
    gguf: &'a GgufFile<'a>,
    name: &str,
) -> ModelResult<&'a [BlockQ1_0G128]> {
    let data = gguf.tensor_data(name).map_err(ModelError::Core)?;
    BlockQ1_0G128::slice_from_bytes(data).map_err(ModelError::Core)
}

/// Load a single Transformer block's weights from GGUF.
pub(super) fn load_transformer_block<'a>(
    gguf: &'a GgufFile<'a>,
    config: &Qwen3Config,
    layer_idx: usize,
) -> ModelResult<TransformerBlock<'a>> {
    let h = config.hidden_size;
    let nq = config.num_attention_heads;
    let nkv = config.num_kv_heads;
    let hd = config.head_dim;
    let inter = config.intermediate_size;

    let blk = |suffix: &str| -> String { tensor_names::block_tensor(layer_idx, suffix) };

    // RMSNorm weights (FP32)
    let attn_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_NORM))?;
    let ffn_norm_w = load_f32_tensor(gguf, &blk(tensor_names::FFN_NORM))?;
    let q_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_Q_NORM))?;
    let k_norm_w = load_f32_tensor(gguf, &blk(tensor_names::ATTN_K_NORM))?;

    // Attention projection weights (Q1_0_g128)
    let q_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_Q))?;

    let k_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_K))?;
    let v_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_V))?;
    let o_blocks = load_1bit_blocks(gguf, &blk(tensor_names::ATTN_OUTPUT))?;

    // FFN weights (Q1_0_g128)
    let gate_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_GATE))?;
    let up_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_UP))?;
    let down_blocks = load_1bit_blocks(gguf, &blk(tensor_names::FFN_DOWN))?;

    let block = TransformerBlock::new(
        layer_idx,
        RmsNorm::new(attn_norm_w, config.rms_norm_eps),
        Linear1Bit::new(q_blocks, nq * hd, h),
        Linear1Bit::new(k_blocks, nkv * hd, h),
        Linear1Bit::new(v_blocks, nkv * hd, h),
        Linear1Bit::new(o_blocks, h, nq * hd),
        RmsNorm::new(q_norm_w, config.rms_norm_eps),
        RmsNorm::new(k_norm_w, config.rms_norm_eps),
        RmsNorm::new(ffn_norm_w, config.rms_norm_eps),
        Linear1Bit::new(gate_blocks, inter, h),
        Linear1Bit::new(up_blocks, inter, h),
        Linear1Bit::new(down_blocks, h, inter),
        nq,
        nkv,
        hd,
        h,
    );

    tracing::trace!(layer = layer_idx, "loaded transformer block");
    Ok(block)
}

/// Load the output (LM head) weight — may be Q1_0_g128 or FP32.
pub(super) fn load_output_weight<'a>(
    gguf: &'a GgufFile<'a>,
    config: &Qwen3Config,
) -> ModelResult<OutputWeight<'a>> {
    let info = gguf
        .tensors
        .require(tensor_names::OUTPUT)
        .map_err(ModelError::Core)?;

    // Derive actual output dimensions from the tensor shape rather than
    // config.vocab_size, which may reflect the tokenizer vocabulary rather
    // than the model's actual output projection size.
    let out_features = if info.shape.len() >= 2 {
        info.shape[1] as usize
    } else {
        config.vocab_size
    };
    let in_features = if !info.shape.is_empty() {
        info.shape[0] as usize
    } else {
        config.hidden_size
    };

    match info.tensor_type {
        GgufTensorType::Q1_0_g128 => {
            let blocks = load_1bit_blocks(gguf, tensor_names::OUTPUT)?;
            Ok(OutputWeight::OneBit(Linear1Bit::new(
                blocks,
                out_features,
                in_features,
            )))
        }
        GgufTensorType::F32 | GgufTensorType::F16 => {
            let weights = load_f32_tensor(gguf, tensor_names::OUTPUT)?;
            Ok(OutputWeight::Fp32 {
                weights,
                out_features,
                in_features,
            })
        }
        other => Err(ModelError::MissingTensor {
            name: format!("output.weight: unsupported type {other}"),
        }),
    }
}