use super::weight_loaders::{load_f32_tensor, load_output_weight, load_transformer_block};
use crate::block::TransformerBlock;
use crate::error::{ModelError, ModelResult};
use crate::kv_cache::KvCache;
use crate::layers::linear::{Linear1Bit, LinearFP8E4M3, LinearFP8E5M2, LinearTernary};
use crate::layers::linear_kquant_ext::{LinearQ5K, LinearQ6K};
use crate::layers::linear_kquant_full::{LinearQ2K, LinearQ3K, LinearQ4K, LinearQ8K};
use crate::layers::linear_standard::{LinearQ4_0, LinearQ8_0};
use crate::layers::rms_norm::RmsNorm;
use crate::layers::rope::RopeTable;
use crate::model_registry::ModelVariant;
use oxibonsai_core::config::Qwen3Config;
use oxibonsai_core::gguf::reader::GgufFile;
use oxibonsai_core::gguf::tensor_info::tensor_names;
use oxibonsai_kernels::traits::OneBitKernel;
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
mod forward_cuda;
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
mod forward_cuda_fp8;
#[cfg(all(feature = "metal", target_os = "macos"))]
mod forward_metal;
#[cfg(all(feature = "metal", target_os = "macos"))]
mod gpu_cache;
/// Qwen3-architecture decoder model: embedding table, transformer blocks,
/// final RMS norm, LM head, RoPE table and KV cache.
///
/// `'a` is the lifetime of data the transformer blocks and the LM head borrow
/// (when built with `from_gguf`, the borrowed `GgufFile`).
pub struct BonsaiModel<'a> {
    /// Model hyper-parameters (hidden size, head layout, layer count, vocabulary size, ...).
    config: Qwen3Config,
    /// Row-major token embedding table: token `t` occupies
    /// `t * hidden_size .. (t + 1) * hidden_size`.
    token_embd: Vec<f32>,
    /// Decoder layers, applied in order.
    pub(crate) blocks: Vec<TransformerBlock<'a>>,
    /// Final RMS norm applied to the hidden state before the LM head.
    output_norm: RmsNorm,
    /// LM head mapping the normalized hidden state to logits.
    output_weight: OutputWeight<'a>,
    /// Rotary position embedding table.
    rope: RopeTable,
    /// Key/value cache handed to every block during `forward`.
    kv_cache: KvCache,
    /// Tensor type used for variant detection (the most common tensor type in
    /// the GGUF file when loaded through `from_gguf`).
    dominant_quant_type: oxibonsai_core::GgufTensorType,
    /// Metal-only GPU weight cache; starts out empty (`None`).
    #[cfg(all(feature = "metal", target_os = "macos"))]
    gpu_weight_cache: std::sync::Mutex<Option<oxibonsai_kernels::CachedModelWeights>>,
    /// CUDA-only cache of byte buffers; starts out empty (`None`).
    /// NOTE(review): presumably fused QKV weights per layer, going by the name — confirm.
    #[cfg(all(feature = "native-cuda", any(target_os = "linux", target_os = "windows")))]
    cuda_qkv_cache: std::sync::Mutex<Option<std::sync::Arc<Vec<Vec<u8>>>>>,
}
impl<'a> BonsaiModel<'a> {
    /// KV-cache and RoPE capacity, in positions, used by [`BonsaiModel::new`].
    const DEFAULT_MAX_SEQ_LEN: usize = 4096;

    /// Loads a model from a parsed GGUF file.
    ///
    /// The [`Qwen3Config`] is read from the GGUF metadata. If the second
    /// dimension of the `token_embd` tensor disagrees with the metadata
    /// `vocab_size`, a warning is logged and the tensor dimension wins. The
    /// embedding table, output norm, LM head and every transformer block are
    /// then loaded, and a KV cache plus RoPE table are allocated for
    /// `max_seq_len` positions.
    ///
    /// The most frequent tensor type in the file (falling back to `Q1_0_g128`
    /// when no counts are available) is remembered for [`Self::variant`].
    ///
    /// # Errors
    ///
    /// Propagates metadata-parsing and tensor-loading errors.
    pub fn from_gguf(gguf: &'a GgufFile<'a>, max_seq_len: usize) -> ModelResult<Self> {
        let mut config = Qwen3Config::from_metadata(&gguf.metadata)?;
        if let Some(embd_info) = gguf.tensors.get(tensor_names::TOKEN_EMBD) {
            if embd_info.shape.len() >= 2 {
                let tensor_vocab = embd_info.shape[1] as usize;
                if tensor_vocab != config.vocab_size {
                    tracing::warn!(
                        metadata_vocab = config.vocab_size, tensor_vocab,
                        "vocab_size mismatch: GGUF metadata says {} but token_embd tensor has {} rows; using tensor dimension",
                        config.vocab_size, tensor_vocab,
                    );
                    config.vocab_size = tensor_vocab;
                }
            }
        }
        let dominant_quant_type = {
            let counts = gguf.tensors.count_by_type();
            counts
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(ty, _)| ty)
                .unwrap_or(oxibonsai_core::GgufTensorType::Q1_0_g128)
        };
        tracing::info!(
            layers = config.num_layers,
            hidden = config.hidden_size,
            heads = config.num_attention_heads,
            kv_heads = config.num_kv_heads,
            vocab = config.vocab_size,
            "loading BonsaiModel from GGUF"
        );
        let token_embd = load_f32_tensor(gguf, tensor_names::TOKEN_EMBD)?;
        let output_norm_w = load_f32_tensor(gguf, tensor_names::OUTPUT_NORM)?;
        let output_norm = RmsNorm::new(output_norm_w, config.rms_norm_eps);
        let kernel = std::sync::Arc::new(oxibonsai_kernels::KernelDispatcher::auto_detect());
        let output_weight = load_output_weight(gguf, &config, &kernel)?;
        let mut blocks = Vec::with_capacity(config.num_layers);
        for layer_idx in 0..config.num_layers {
            let block = load_transformer_block(gguf, &config, layer_idx, &kernel)?;
            blocks.push(block);
        }
        let rope = RopeTable::new(config.head_dim, max_seq_len, config.rope_freq_base);
        let kv_cache = KvCache::new(
            config.num_layers,
            config.num_kv_heads,
            config.head_dim,
            max_seq_len,
        );
        tracing::info!(
            blocks = blocks.len(),
            embd_size = token_embd.len(),
            max_seq_len,
            "model loaded successfully"
        );
        Ok(Self {
            config,
            token_embd,
            blocks,
            output_norm,
            output_weight,
            rope,
            kv_cache,
            dominant_quant_type,
            #[cfg(all(feature = "metal", target_os = "macos"))]
            gpu_weight_cache: std::sync::Mutex::new(None),
            #[cfg(all(
                feature = "native-cuda",
                any(target_os = "linux", target_os = "windows")
            ))]
            cuda_qkv_cache: std::sync::Mutex::new(None),
        })
    }

    /// Creates a placeholder model from `config` without loading any weights.
    ///
    /// The model has no transformer blocks, a zero-filled embedding table,
    /// unit output-norm weights and an all-zero FP32 LM head. The KV cache
    /// and RoPE table hold 4096 positions (`DEFAULT_MAX_SEQ_LEN`).
    pub fn new(config: Qwen3Config) -> Self {
        let h = config.hidden_size;
        let kv_cache = KvCache::new(
            config.num_layers,
            config.num_kv_heads,
            config.head_dim,
            Self::DEFAULT_MAX_SEQ_LEN,
        );
        let rope = RopeTable::new(
            config.head_dim,
            Self::DEFAULT_MAX_SEQ_LEN,
            config.rope_freq_base,
        );
        Self {
            token_embd: vec![0.0; config.vocab_size * h],
            blocks: Vec::new(),
            output_norm: RmsNorm::new(vec![1.0; h], config.rms_norm_eps),
            output_weight: OutputWeight::Fp32 {
                weights: vec![0.0; config.vocab_size * h],
                out_features: config.vocab_size,
                in_features: h,
            },
            rope,
            kv_cache,
            dominant_quant_type: oxibonsai_core::GgufTensorType::Q1_0_g128,
            #[cfg(all(feature = "metal", target_os = "macos"))]
            gpu_weight_cache: std::sync::Mutex::new(None),
            #[cfg(all(
                feature = "native-cuda",
                any(target_os = "linux", target_os = "windows")
            ))]
            cuda_qkv_cache: std::sync::Mutex::new(None),
            config,
        }
    }

    /// Model hyper-parameters.
    pub fn config(&self) -> &Qwen3Config {
        &self.config
    }

    /// Mutable access to the KV cache.
    pub fn kv_cache_mut(&mut self) -> &mut KvCache {
        &mut self.kv_cache
    }

    /// Shared access to the KV cache.
    pub fn kv_cache(&self) -> &KvCache {
        &self.kv_cache
    }

    /// Clears the KV cache (same effect as [`Self::reset_cache`]).
    pub fn reset(&mut self) {
        self.kv_cache.clear();
    }

    /// Clears the KV cache (same effect as [`Self::reset`]).
    pub fn reset_cache(&mut self) {
        self.kv_cache.clear();
    }

    /// Uploads the weights of every transformer block to the GPU via `kernel`,
    /// plus the LM head when it is 1-bit or ternary.
    ///
    /// Does nothing when the model has no blocks. The other LM-head formats
    /// are left untouched.
    pub fn upload_weights_to_gpu(&mut self, kernel: &dyn OneBitKernel) {
        let n_blocks = self.blocks.len();
        if n_blocks == 0 {
            return;
        }
        tracing::info!(blocks = n_blocks, "uploading model weights to GPU");
        for block in &mut self.blocks {
            block.upload_to_gpu(kernel);
        }
        match self.output_weight {
            OutputWeight::OneBit(ref mut linear) => linear.upload_to_gpu(),
            OutputWeight::Ternary(ref mut linear) => linear.upload_to_gpu(),
            OutputWeight::FP8E4M3(_)
            | OutputWeight::FP8E5M2(_)
            | OutputWeight::Q4_0(_)
            | OutputWeight::Q8_0(_)
            | OutputWeight::Q5K(_)
            | OutputWeight::Q6K(_)
            | OutputWeight::Q2K(_)
            | OutputWeight::Q3K(_)
            | OutputWeight::Q4K(_)
            | OutputWeight::Q8K(_) => {}
            OutputWeight::Fp32 { .. } => {}
        }
        tracing::info!("GPU weight upload complete");
    }

    /// Infers the model variant from the config and the remembered dominant
    /// tensor type.
    pub fn variant(&self) -> ModelVariant {
        ModelVariant::from_config_and_sample_tensor_type(&self.config, self.dominant_quant_type)
    }

    /// Parameter count reported by the inferred [`ModelVariant`] (not counted
    /// from the loaded tensors).
    pub fn num_parameters(&self) -> u64 {
        self.variant().param_count()
    }

    /// Expected model size in bytes reported by the inferred [`ModelVariant`].
    pub fn model_size_bytes(&self) -> u64 {
        self.variant().expected_model_size_bytes()
    }

    /// Maximum context length from the config.
    ///
    /// This can differ from the KV-cache capacity, which is set by the
    /// `max_seq_len` passed at construction.
    pub fn context_length(&self) -> usize {
        self.config.max_context_length
    }

    /// Number of transformer layers in the config.
    pub fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    /// Hidden size from the config.
    pub fn hidden_size(&self) -> usize {
        self.config.hidden_size
    }

    /// Memory footprint of the KV cache in bytes, as reported by the cache.
    pub fn kv_cache_memory_bytes(&self) -> usize {
        self.kv_cache.memory_bytes()
    }

    /// Same as [`Self::from_gguf`], additionally logging the detected variant.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::from_gguf`].
    pub fn from_gguf_auto(gguf: &'a GgufFile<'a>, max_seq_len: usize) -> ModelResult<Self> {
        let model = Self::from_gguf(gguf, max_seq_len)?;
        let variant = model.variant();
        tracing::info!(
            variant = variant.name(),
            params = variant.param_count(),
            "auto-detected model variant"
        );
        Ok(model)
    }

    /// Processes the prompt `token_ids` starting at position `pos_start` and
    /// returns the logits of the **last** token.
    ///
    /// A single token goes straight to [`Self::forward`]. For longer inputs with
    /// a GPU-accelerated `kernel`, a backend-specific batched path is tried
    /// first: Metal, or on CUDA a path chosen by the LM-head format. A failure
    /// is logged, and the tokens then run one at a time through `forward`. On
    /// CUDA builds, prompts of at most 16 tokens always run sequentially.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::MissingTensor`] when `token_ids` is empty, and
    /// propagates errors from the sequential `forward` calls.
    pub fn forward_prefill(
        &mut self,
        token_ids: &[u32],
        pos_start: usize,
        kernel: &dyn OneBitKernel,
    ) -> ModelResult<Vec<f32>> {
        if token_ids.is_empty() {
            return Err(ModelError::MissingTensor {
                name: "forward_prefill: empty token_ids".into(),
            });
        }
        if token_ids.len() == 1 {
            return self.forward(token_ids[0], pos_start, kernel);
        }
        let _gpu_kernel = kernel.is_gpu_accelerated();
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel && token_ids.len() <= 16 {
            return self.forward_sequential(token_ids, pos_start, kernel);
        }
        #[cfg(all(feature = "metal", target_os = "macos"))]
        if _gpu_kernel
            && !matches!(
                &self.output_weight,
                OutputWeight::FP8E4M3(_) | OutputWeight::FP8E5M2(_)
            )
        {
            match self.try_metal_prefill_with_lm_head(token_ids, pos_start) {
                Ok(logits) => return Ok(logits),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        "metal batch prefill failed, falling back to sequential"
                    );
                }
            }
        }
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel
            && matches!(
                &self.output_weight,
                OutputWeight::FP8E4M3(_) | OutputWeight::FP8E5M2(_)
            )
            && oxibonsai_kernels::CudaGraph::global().is_ok()
        {
            let is_e4m3 = matches!(&self.output_weight, OutputWeight::FP8E4M3(_));
            match self.try_cuda_prefill_with_lm_head_fp8(token_ids, pos_start, is_e4m3) {
                Ok(logits) => return Ok(logits),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        "cuda FP8 batch prefill failed, falling back to sequential"
                    );
                }
            }
            return self.forward_sequential(token_ids, pos_start, kernel);
        }
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel
            && matches!(
                &self.output_weight,
                OutputWeight::Q4_0(_) | OutputWeight::Q8_0(_)
            )
            && oxibonsai_kernels::CudaGraph::global().is_ok()
        {
            let q4_0 = matches!(&self.output_weight, OutputWeight::Q4_0(_));
            match self.try_cuda_prefill_with_lm_head_q_std(token_ids, pos_start, q4_0) {
                Ok(logits) => return Ok(logits),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        "cuda Q4_0/Q8_0 batch prefill failed, falling back to sequential"
                    );
                }
            }
            return self.forward_sequential(token_ids, pos_start, kernel);
        }
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel
            && matches!(
                &self.output_weight,
                OutputWeight::Q2K(_)
                    | OutputWeight::Q3K(_)
                    | OutputWeight::Q4K(_)
                    | OutputWeight::Q5K(_)
                    | OutputWeight::Q6K(_)
                    | OutputWeight::Q8K(_)
            )
            && oxibonsai_kernels::CudaGraph::global().is_ok()
        {
            let fmt = match &self.output_weight {
                OutputWeight::Q2K(_) => oxibonsai_kernels::KQuantFormat::Q2K,
                OutputWeight::Q3K(_) => oxibonsai_kernels::KQuantFormat::Q3K,
                OutputWeight::Q4K(_) => oxibonsai_kernels::KQuantFormat::Q4K,
                OutputWeight::Q5K(_) => oxibonsai_kernels::KQuantFormat::Q5K,
                OutputWeight::Q6K(_) => oxibonsai_kernels::KQuantFormat::Q6K,
                OutputWeight::Q8K(_) => oxibonsai_kernels::KQuantFormat::Q8K,
                _ => unreachable!(),
            };
            match self.try_cuda_prefill_with_lm_head_k_quant(token_ids, pos_start, fmt) {
                Ok(logits) => return Ok(logits),
                Err(e) => {
                    tracing::warn!(error = %e,
                        "cuda K-quant batch prefill failed, falling back to sequential");
                }
            }
            return self.forward_sequential(token_ids, pos_start, kernel);
        }
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel {
            match self.try_cuda_prefill_with_lm_head(token_ids, pos_start) {
                Ok(logits) => return Ok(logits),
                Err(e) => {
                    let msg = e.to_string();
                    if msg.contains("LM head not supported on CUDA prefill path") {
                        tracing::debug!(
                            error = %e,
                            "cuda batch prefill skipped (LM head dtype not supported), using sequential"
                        );
                    } else {
                        tracing::warn!(
                            error = %e,
                            "cuda batch prefill failed, falling back to sequential"
                        );
                    }
                }
            }
        }
        self.forward_sequential(token_ids, pos_start, kernel)
    }

    /// Runs `token_ids` (starting at `pos_start`) through the model and returns,
    /// for every position, the greedy (arg-max) token prediction.
    ///
    /// A batched GPU path is tried first where the build provides one. Every
    /// sequential fallback uses the same arg-max rule: the first index of the
    /// maximum wins, and NaN logits never win. Ties therefore resolve the same
    /// way whichever fallback ran.
    ///
    /// Returns an empty vector for empty input.
    ///
    /// # Errors
    ///
    /// Propagates errors from the sequential [`Self::forward`] calls.
    pub fn forward_prefill_verify(
        &mut self,
        token_ids: &[u32],
        pos_start: usize,
        kernel: &dyn OneBitKernel,
    ) -> ModelResult<Vec<u32>> {
        if token_ids.is_empty() {
            return Ok(vec![]);
        }
        let _gpu_kernel = kernel.is_gpu_accelerated();
        #[cfg(all(feature = "metal", target_os = "macos"))]
        if _gpu_kernel
            && !matches!(
                &self.output_weight,
                OutputWeight::FP8E4M3(_) | OutputWeight::FP8E5M2(_)
            )
        {
            match self.try_metal_prefill_verify(token_ids, pos_start) {
                Ok(ids) => return Ok(ids),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        "metal batch prefill verify failed, falling back to sequential"
                    );
                }
            }
        }
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel
            && matches!(
                &self.output_weight,
                OutputWeight::FP8E4M3(_) | OutputWeight::FP8E5M2(_)
            )
            && oxibonsai_kernels::CudaGraph::global().is_ok()
        {
            let is_e4m3 = matches!(&self.output_weight, OutputWeight::FP8E4M3(_));
            match self.try_cuda_prefill_verify_fp8(token_ids, pos_start, is_e4m3) {
                Ok(ids) => return Ok(ids),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        "cuda FP8 batch prefill verify failed, falling back to sequential"
                    );
                }
            }
            return self.verify_sequential(token_ids, pos_start, kernel);
        }
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel
            && matches!(
                &self.output_weight,
                OutputWeight::Q2K(_)
                    | OutputWeight::Q3K(_)
                    | OutputWeight::Q4K(_)
                    | OutputWeight::Q5K(_)
                    | OutputWeight::Q6K(_)
                    | OutputWeight::Q8K(_)
            )
            && oxibonsai_kernels::CudaGraph::global().is_ok()
        {
            let fmt = match &self.output_weight {
                OutputWeight::Q2K(_) => oxibonsai_kernels::KQuantFormat::Q2K,
                OutputWeight::Q3K(_) => oxibonsai_kernels::KQuantFormat::Q3K,
                OutputWeight::Q4K(_) => oxibonsai_kernels::KQuantFormat::Q4K,
                OutputWeight::Q5K(_) => oxibonsai_kernels::KQuantFormat::Q5K,
                OutputWeight::Q6K(_) => oxibonsai_kernels::KQuantFormat::Q6K,
                OutputWeight::Q8K(_) => oxibonsai_kernels::KQuantFormat::Q8K,
                _ => unreachable!(),
            };
            match self.try_cuda_prefill_verify_k_quant(token_ids, pos_start, fmt) {
                Ok(ids) => return Ok(ids),
                Err(e) => {
                    tracing::warn!(error = %e,
                        "cuda K-quant batch prefill verify failed, falling back to sequential");
                }
            }
            return self.verify_sequential(token_ids, pos_start, kernel);
        }
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel {
            match self.try_cuda_prefill_verify(token_ids, pos_start) {
                Ok(ids) => return Ok(ids),
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        "cuda batch prefill verify failed, falling back to sequential"
                    );
                }
            }
        }
        self.verify_sequential(token_ids, pos_start, kernel)
    }

    /// Runs one decoding step for `token_id` at position `pos` and returns the
    /// logits.
    ///
    /// With a GPU-accelerated `kernel`, backend-specific fused paths are tried
    /// first (Metal or CUDA, depending on the build). If none applies or all
    /// fail, the transformer blocks run one by one through `kernel`, followed
    /// by the output norm and the LM head. That path produces `vocab_size`
    /// logits.
    ///
    /// # Errors
    ///
    /// * [`ModelError::SequenceTooLong`] if `pos` is not below the KV-cache capacity.
    /// * [`ModelError::MissingTensor`] if `token_id` has no row in the embedding table.
    /// * Any error from a block, the output norm or the LM head.
    // `instrument` records `token_id` and `pos` by itself. Listing them in
    // `fields(...)` without values would declare them as empty fields and
    // suppress that automatic recording.
    #[tracing::instrument(skip(self, kernel))]
    pub fn forward(
        &mut self,
        token_id: u32,
        pos: usize,
        kernel: &dyn OneBitKernel,
    ) -> ModelResult<Vec<f32>> {
        let h = self.config.hidden_size;
        let vocab = self.config.vocab_size;
        if pos >= self.kv_cache.max_seq_len() {
            return Err(ModelError::SequenceTooLong {
                seq_len: pos + 1,
                max_ctx: self.kv_cache.max_seq_len(),
            });
        }
        let embd_start = token_id as usize * h;
        let embd_end = embd_start + h;
        if embd_end > self.token_embd.len() {
            return Err(ModelError::MissingTensor {
                name: format!(
                    "token_id {} out of range (vocab={})",
                    token_id,
                    self.token_embd.len() / h
                ),
            });
        }
        let mut hidden = self.token_embd[embd_start..embd_end].to_vec();
        let t_blocks_start = std::time::Instant::now();
        let _gpu_kernel = kernel.is_gpu_accelerated();

        // Metal: fused layers + output norm + LM head.
        #[cfg(all(feature = "metal", target_os = "macos"))]
        if _gpu_kernel {
            let mut fused_logits = vec![0.0f32; vocab];
            if self
                .try_metal_full_forward_with_lm_head(&mut hidden, pos, &mut fused_logits)
                .is_ok()
            {
                let t_elapsed = t_blocks_start.elapsed();
                tracing::debug!(
                    target: "fwd_profile",
                    "pos={pos} fused_gpu={:.1}ms (metal layers+norm+lm_head)",
                    t_elapsed.as_secs_f64() * 1000.0,
                );
                return Ok(fused_logits);
            }
            // The failed attempt received `&mut hidden`; restart the fallbacks
            // from the untouched embedding row in case it wrote partial results.
            hidden.clear();
            hidden.extend_from_slice(&self.token_embd[embd_start..embd_end]);
        }

        // CUDA with an FP8 / Q4_0 / Q8_0 / K-quant LM head: run the blocks
        // through the regular per-block dispatch, then the LM head.
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel {
            let dispatch_label = match &self.output_weight {
                OutputWeight::FP8E4M3(_) | OutputWeight::FP8E5M2(_) => Some("fp8_cuda_dispatch"),
                OutputWeight::Q4_0(_) | OutputWeight::Q8_0(_) => Some("q4q8_cuda_dispatch"),
                OutputWeight::Q2K(_)
                | OutputWeight::Q3K(_)
                | OutputWeight::Q4K(_)
                | OutputWeight::Q5K(_)
                | OutputWeight::Q6K(_)
                | OutputWeight::Q8K(_) => Some("kquant_cuda_dispatch"),
                OutputWeight::OneBit(_) | OutputWeight::Ternary(_) | OutputWeight::Fp32 { .. } => {
                    None
                }
            };
            // The CUDA graph is only queried for the formats above, as before.
            let dispatch_label =
                dispatch_label.filter(|_| oxibonsai_kernels::CudaGraph::global().is_ok());
            if let Some(label) = dispatch_label {
                for block in &self.blocks {
                    block.forward(&mut hidden, pos, &mut self.kv_cache, &self.rope, kernel)?;
                }
                let t_blocks_elapsed = t_blocks_start.elapsed();
                tracing::debug!(
                    target: "fwd_profile",
                    "pos={pos} {label}={:.1}ms (cuda gemv via block dispatch)",
                    t_blocks_elapsed.as_secs_f64() * 1000.0,
                );
                let t_norm_start = std::time::Instant::now();
                let mut normed = vec![0.0f32; h];
                self.output_norm.forward(&hidden, &mut normed)?;
                let t_norm_elapsed = t_norm_start.elapsed();
                let t_lm_start = std::time::Instant::now();
                let mut logits = vec![0.0f32; vocab];
                self.lm_head_forward(&normed, &mut logits)?;
                let t_lm_elapsed = t_lm_start.elapsed();
                tracing::debug!(
                    target: "fwd_profile",
                    "pos={pos} norm={:.2}ms lm_head={:.2}ms",
                    t_norm_elapsed.as_secs_f64() * 1000.0,
                    t_lm_elapsed.as_secs_f64() * 1000.0,
                );
                return Ok(logits);
            }
        }

        // CUDA: fused layers + LM head for the remaining formats.
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        if _gpu_kernel {
            if let Ok(fused_logits) = self.try_cuda_full_forward_with_lm_head(&hidden, pos) {
                return Ok(fused_logits);
            }
        }

        // GPU layers only (norm + LM head below); falls back to per-block dispatch.
        #[cfg(all(feature = "metal", target_os = "macos"))]
        let did_full_forward = if _gpu_kernel {
            if self.try_metal_full_forward_inner(&mut hidden, pos).is_ok() {
                true
            } else {
                // Reset `hidden` before each retry in case the failed attempt
                // already wrote into it.
                hidden.clear();
                hidden.extend_from_slice(&self.token_embd[embd_start..embd_end]);
                let ternary_ok = self
                    .try_metal_full_forward_ternary_inner(&mut hidden, pos)
                    .is_ok();
                if !ternary_ok {
                    hidden.clear();
                    hidden.extend_from_slice(&self.token_embd[embd_start..embd_end]);
                }
                ternary_ok
            }
        } else {
            false
        };
        #[cfg(all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        ))]
        let did_full_forward = if _gpu_kernel {
            match self.try_cuda_full_forward_inner(&hidden, pos) {
                Ok(new_hidden) => {
                    hidden = new_hidden;
                    true
                }
                Err(_) => false,
            }
        } else {
            false
        };
        #[cfg(not(any(
            all(feature = "metal", target_os = "macos"),
            all(
                feature = "native-cuda",
                not(all(feature = "metal", target_os = "macos")),
                any(target_os = "linux", target_os = "windows")
            )
        )))]
        let did_full_forward = false;
        if !did_full_forward {
            for block in &self.blocks {
                block.forward(&mut hidden, pos, &mut self.kv_cache, &self.rope, kernel)?;
            }
        }
        let t_blocks_elapsed = t_blocks_start.elapsed();
        let t_norm_start = std::time::Instant::now();
        let mut normed = vec![0.0f32; h];
        self.output_norm.forward(&hidden, &mut normed)?;
        let t_norm_elapsed = t_norm_start.elapsed();
        let t_lm_start = std::time::Instant::now();
        let mut logits = vec![0.0f32; vocab];
        self.lm_head_forward(&normed, &mut logits)?;
        let t_lm_elapsed = t_lm_start.elapsed();
        tracing::debug!(
            target: "fwd_profile",
            "pos={pos} blocks={:.1}ms norm={:.1}ms lm_head={:.1}ms gpu={}",
            t_blocks_elapsed.as_secs_f64() * 1000.0,
            t_norm_elapsed.as_secs_f64() * 1000.0,
            t_lm_elapsed.as_secs_f64() * 1000.0,
            did_full_forward,
        );
        Ok(logits)
    }

    /// Applies the LM head (`output_weight`) to the normalized hidden state
    /// `normed`, writing into `logits`.
    ///
    /// For the `Fp32` head, only the first `out_features` logits are written.
    /// Each is the dot product of a weight row with `normed[..in_features]`,
    /// accumulated left to right from `0.0`.
    fn lm_head_forward(&self, normed: &[f32], logits: &mut [f32]) -> ModelResult<()> {
        match &self.output_weight {
            OutputWeight::OneBit(linear) => {
                linear.forward_vec(normed, logits)?;
            }
            OutputWeight::Ternary(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::FP8E4M3(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::FP8E5M2(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Q4_0(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Q8_0(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Q5K(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Q6K(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Q2K(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Q3K(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Q4K(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Q8K(linear) => {
                linear.forward(normed, logits)?;
            }
            OutputWeight::Fp32 {
                weights,
                out_features,
                in_features,
            } => {
                let in_features = *in_features;
                let inputs = &normed[..in_features];
                for (row_idx, logit) in logits.iter_mut().enumerate().take(*out_features) {
                    let row = &weights[row_idx * in_features..(row_idx + 1) * in_features];
                    *logit = row
                        .iter()
                        .zip(inputs)
                        .fold(0.0f32, |acc, (&w, &x)| acc + w * x);
                }
            }
        }
        Ok(())
    }

    /// Feeds `token_ids` one at a time through [`Self::forward`], starting at
    /// `pos_start`, and returns the logits of the last token.
    fn forward_sequential(
        &mut self,
        token_ids: &[u32],
        pos_start: usize,
        kernel: &dyn OneBitKernel,
    ) -> ModelResult<Vec<f32>> {
        let mut last_logits = Vec::new();
        for (offset, &token_id) in token_ids.iter().enumerate() {
            last_logits = self.forward(token_id, pos_start + offset, kernel)?;
        }
        Ok(last_logits)
    }

    /// Feeds `token_ids` one at a time through [`Self::forward`], starting at
    /// `pos_start`, and collects the greedy prediction for every position.
    fn verify_sequential(
        &mut self,
        token_ids: &[u32],
        pos_start: usize,
        kernel: &dyn OneBitKernel,
    ) -> ModelResult<Vec<u32>> {
        let mut predicted = Vec::with_capacity(token_ids.len());
        for (offset, &token_id) in token_ids.iter().enumerate() {
            let logits = self.forward(token_id, pos_start + offset, kernel)?;
            predicted.push(Self::greedy_argmax(&logits));
        }
        Ok(predicted)
    }

    /// Index of the largest logit.
    ///
    /// The first index wins on ties, and NaN never compares greater.
    /// Returns `0` for an empty slice or when no value exceeds `-inf`.
    fn greedy_argmax(logits: &[f32]) -> u32 {
        let mut best_idx = 0u32;
        let mut best_val = f32::NEG_INFINITY;
        for (idx, &value) in logits.iter().enumerate() {
            if value > best_val {
                best_val = value;
                best_idx = idx as u32;
            }
        }
        best_idx
    }
}
/// LM-head ("output") weight, with one variant per supported storage format.
pub(super) enum OutputWeight<'a> {
    /// 1-bit quantized weights.
    OneBit(Linear1Bit<'a>),
    /// Ternary quantized weights.
    Ternary(LinearTernary<'a>),
    /// FP8 weights in E4M3 layout.
    FP8E4M3(LinearFP8E4M3<'a>),
    /// FP8 weights in E5M2 layout.
    FP8E5M2(LinearFP8E5M2<'a>),
    /// `Q4_0` block-quantized weights.
    Q4_0(LinearQ4_0<'a>),
    /// `Q8_0` block-quantized weights.
    Q8_0(LinearQ8_0<'a>),
    /// `Q5_K` K-quant weights.
    Q5K(LinearQ5K<'a>),
    /// `Q6_K` K-quant weights.
    Q6K(LinearQ6K<'a>),
    /// `Q2_K` K-quant weights.
    Q2K(LinearQ2K<'a>),
    /// `Q3_K` K-quant weights.
    Q3K(LinearQ3K<'a>),
    /// `Q4_K` K-quant weights.
    Q4K(LinearQ4K<'a>),
    /// `Q8_K` K-quant weights.
    Q8K(LinearQ8K<'a>),
    /// Dense, owned, row-major `f32` matrix of `out_features x in_features`.
    Fp32 {
        /// Row `i` is `weights[i * in_features .. (i + 1) * in_features]`.
        weights: Vec<f32>,
        /// Number of rows (typically the vocabulary size).
        out_features: usize,
        /// Row length (typically the hidden size).
        in_features: usize,
    },
}
impl BonsaiModel<'static> {
    /// Builds a test model with real 1-bit transformer blocks over synthetic weights.
    ///
    /// Each of the `config.num_layers` layers gets Q/K/V/O and gate/up/down
    /// [`Linear1Bit`] projections. Their `BlockQ1_0G128` blocks all use scale
    /// `0.01` and a per-projection byte pattern offset by the block index, and
    /// every RMS norm weight is `1.0`. The embedding table is filled with
    /// `0.01` and the LM head is an all-zero FP32 matrix. The KV cache and RoPE
    /// table hold 4096 positions, and the kernels use `KernelTier::Reference`.
    ///
    /// The synthetic weights are leaked with [`Box::leak`] to obtain `'static`
    /// slices, so every call permanently allocates a fresh set.
    ///
    /// # Panics
    ///
    /// Panics if `hidden_size` or `intermediate_size` is not a multiple of 128,
    /// or if constructing any projection fails.
    pub fn new_for_testing_with_blocks(config: Qwen3Config) -> Self {
        use crate::block::TransformerBlock;
        use crate::layers::linear::{Linear1Bit, LinearLayer};
        use half::f16;
        use oxibonsai_core::tensor::BlockQ1_0G128;
        use oxibonsai_kernels::{KernelDispatcher, KernelTier};
        use std::sync::Arc;

        /// Leaks `count` synthetic 1-bit blocks with scale `scale`. Every byte
        /// of block `i` is `pattern` plus the low byte of `i` (wrapping).
        fn leak_q1_blocks(count: usize, scale: f32, pattern: u8) -> &'static [BlockQ1_0G128] {
            let blocks: Box<[BlockQ1_0G128]> = (0..count)
                .map(|idx| BlockQ1_0G128 {
                    d: f16::from_f32(scale),
                    qs: [pattern.wrapping_add((idx & 0xff) as u8); 16],
                })
                .collect();
            Box::leak(blocks)
        }

        let hidden = config.hidden_size;
        let head_dim = config.head_dim;
        let n_heads = config.num_attention_heads;
        let n_kv_heads = config.num_kv_heads;
        let ffn_dim = config.intermediate_size;
        assert!(
            hidden % 128 == 0,
            "test fixture requires hidden_size to be a multiple of 128"
        );
        assert!(
            ffn_dim % 128 == 0,
            "test fixture requires intermediate_size to be a multiple of 128"
        );
        let hidden_bpr = hidden / 128;
        let ffn_bpr = ffn_dim / 128;
        let q_dim = n_heads * head_dim;
        let kv_dim = n_kv_heads * head_dim;
        // NOTE(review): the output projection's blocks-per-row floors
        // `q_dim / 128` (minimum 1) instead of asserting divisibility.
        let o_bpr = (q_dim / 128).max(1);

        let dispatcher = Arc::new(KernelDispatcher::with_tier(KernelTier::Reference));
        let kv_cache = KvCache::new(
            config.num_layers,
            config.num_kv_heads,
            config.head_dim,
            4096,
        );
        let rope = RopeTable::new(config.head_dim, 4096, config.rope_freq_base);

        // Wraps `Linear1Bit::new(...).expect(what)` into a `LinearLayer`.
        let one_bit = |blocks: &'static [BlockQ1_0G128],
                       out_features: usize,
                       in_features: usize,
                       what: &str|
         -> LinearLayer<'static> {
            Linear1Bit::new(blocks, out_features, in_features, dispatcher.clone())
                .expect(what)
                .into()
        };

        let blocks: Vec<_> = (0..config.num_layers)
            .map(|layer_idx| {
                let attn_q = one_bit(
                    leak_q1_blocks(q_dim * hidden_bpr, 0.01, 0xA5),
                    q_dim,
                    hidden,
                    "q proj",
                );
                let attn_k = one_bit(
                    leak_q1_blocks(kv_dim * hidden_bpr, 0.01, 0x5A),
                    kv_dim,
                    hidden,
                    "k proj",
                );
                let attn_v = one_bit(
                    leak_q1_blocks(kv_dim * hidden_bpr, 0.01, 0x33),
                    kv_dim,
                    hidden,
                    "v proj",
                );
                let attn_out = one_bit(
                    leak_q1_blocks(hidden * o_bpr, 0.01, 0xCC),
                    hidden,
                    q_dim,
                    "o proj",
                );
                let ffn_gate = one_bit(
                    leak_q1_blocks(ffn_dim * hidden_bpr, 0.01, 0x77),
                    ffn_dim,
                    hidden,
                    "gate proj",
                );
                let ffn_up = one_bit(
                    leak_q1_blocks(ffn_dim * hidden_bpr, 0.01, 0x88),
                    ffn_dim,
                    hidden,
                    "up proj",
                );
                let ffn_down = one_bit(
                    leak_q1_blocks(hidden * ffn_bpr, 0.01, 0x99),
                    hidden,
                    ffn_dim,
                    "down proj",
                );
                TransformerBlock::new(
                    layer_idx,
                    RmsNorm::new(vec![1.0; hidden], config.rms_norm_eps),
                    attn_q,
                    attn_k,
                    attn_v,
                    attn_out,
                    RmsNorm::new(vec![1.0; head_dim], config.rms_norm_eps),
                    RmsNorm::new(vec![1.0; head_dim], config.rms_norm_eps),
                    RmsNorm::new(vec![1.0; hidden], config.rms_norm_eps),
                    ffn_gate,
                    ffn_up,
                    ffn_down,
                    n_heads,
                    n_kv_heads,
                    head_dim,
                    hidden,
                )
            })
            .collect();

        Self {
            token_embd: vec![0.01; config.vocab_size * hidden],
            blocks,
            output_norm: RmsNorm::new(vec![1.0; hidden], config.rms_norm_eps),
            output_weight: OutputWeight::Fp32 {
                weights: vec![0.0; config.vocab_size * hidden],
                out_features: config.vocab_size,
                in_features: hidden,
            },
            rope,
            kv_cache,
            dominant_quant_type: oxibonsai_core::GgufTensorType::Q1_0_g128,
            #[cfg(all(feature = "metal", target_os = "macos"))]
            gpu_weight_cache: std::sync::Mutex::new(None),
            #[cfg(all(
                feature = "native-cuda",
                any(target_os = "linux", target_os = "windows")
            ))]
            cuda_qkv_cache: std::sync::Mutex::new(None),
            config,
        }
    }
}