mod weight_loaders;
use weight_loaders::{load_f32_tensor, load_output_weight, load_transformer_block};
use oxibonsai_core::config::Qwen3Config;
use oxibonsai_core::gguf::reader::GgufFile;
use oxibonsai_core::gguf::tensor_info::tensor_names;
use oxibonsai_kernels::traits::OneBitKernel;
#[cfg(any(
all(feature = "metal", target_os = "macos"),
all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
)
))]
use crate::block::blocks_as_bytes;
use crate::block::TransformerBlock;
use crate::error::{ModelError, ModelResult};
use crate::kv_cache::KvCache;
use crate::layers::linear::Linear1Bit;
use crate::layers::rms_norm::RmsNorm;
use crate::layers::rope::RopeTable;
use crate::model_registry::ModelVariant;
/// A Bonsai model held in memory: configuration, weights, RoPE table and KV cache.
///
/// `'a` is the lifetime carried by the `GgufFile` handed to `from_gguf`; the
/// transformer blocks and the 1-bit output head are parameterised by it.
pub struct BonsaiModel<'a> {
/// Model hyper-parameters, read from GGUF metadata or supplied to `new`.
config: Qwen3Config,
/// Flattened token-embedding table; the row for token `t` is
/// `token_embd[t * hidden_size .. (t + 1) * hidden_size]`.
token_embd: Vec<f32>,
/// Transformer blocks, pushed in ascending layer order by `from_gguf`.
blocks: Vec<TransformerBlock<'a>>,
/// RMS norm whose weights and epsilon are passed as the "final norm" to the fused GPU paths.
output_norm: RmsNorm,
/// LM-head weights (1-bit or FP32).
output_weight: OutputWeight<'a>,
/// RoPE cos/sin table, queried per position via `cos_at` / `sin_at`.
rope: RopeTable,
/// Key/value cache; its `max_seq_len()` is forwarded to the GPU dispatch calls.
kv_cache: KvCache,
/// Metal only: lazily built weight cache, populated by `get_or_create_gpu_cache`.
#[cfg(all(feature = "metal", target_os = "macos"))]
gpu_weight_cache: std::sync::Mutex<Option<oxibonsai_kernels::CachedModelWeights>>,
/// CUDA only: lazily built per-layer concatenation of the Q, K and V weight bytes.
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
cuda_qkv_cache: std::sync::Mutex<Option<std::sync::Arc<Vec<Vec<u8>>>>>,
}
/// Storage for the LM-head (output projection) weights.
enum OutputWeight<'a> {
/// 1-bit quantised head; the only variant the fused GPU paths in this file accept.
OneBit(Linear1Bit<'a>),
/// Full-precision head.
Fp32 {
/// Flattened weight values; `new` allocates `out_features * in_features` of them.
/// NOTE(review): element layout (row-major by output feature?) is not visible here — confirm.
weights: Vec<f32>,
/// Number of output features (`new` sets this to the vocabulary size).
out_features: usize,
/// Number of input features (`new` sets this to the hidden size).
in_features: usize,
},
}
impl<'a> BonsaiModel<'a> {
/// Builds a model from a parsed GGUF file.
///
/// The configuration comes from the GGUF metadata. If the token-embedding
/// tensor has at least two dimensions and its second dimension disagrees
/// with the metadata vocabulary size, the tensor dimension wins and a
/// warning is logged. The RoPE table and KV cache are sized for
/// `max_seq_len` positions.
///
/// # Errors
///
/// Propagates any error from reading the metadata or from loading the
/// embedding, output norm, output head or a transformer block.
pub fn from_gguf(gguf: &'a GgufFile<'a>, max_seq_len: usize) -> ModelResult<Self> {
    let mut config = Qwen3Config::from_metadata(&gguf.metadata)?;

    // Vocabulary size as recorded by the embedding tensor itself, when available.
    let tensor_vocab = gguf
        .tensors
        .get(tensor_names::TOKEN_EMBD)
        .filter(|info| info.shape.len() >= 2)
        .map(|info| info.shape[1] as usize);
    match tensor_vocab {
        Some(tensor_vocab) if tensor_vocab != config.vocab_size => {
            tracing::warn!(
                metadata_vocab = config.vocab_size,
                tensor_vocab,
                "vocab_size mismatch: GGUF metadata says {} but token_embd tensor has {} rows; using tensor dimension",
                config.vocab_size,
                tensor_vocab,
            );
            config.vocab_size = tensor_vocab;
        }
        _ => {}
    }

    tracing::info!(
        layers = config.num_layers,
        hidden = config.hidden_size,
        heads = config.num_attention_heads,
        kv_heads = config.num_kv_heads,
        vocab = config.vocab_size,
        "loading BonsaiModel from GGUF"
    );

    let token_embd = load_f32_tensor(gguf, tensor_names::TOKEN_EMBD)?;
    let output_norm = RmsNorm::new(
        load_f32_tensor(gguf, tensor_names::OUTPUT_NORM)?,
        config.rms_norm_eps,
    );
    let output_weight = load_output_weight(gguf, &config)?;

    // Load the layers in ascending order, stopping at the first failure.
    let blocks = (0..config.num_layers)
        .map(|layer_idx| load_transformer_block(gguf, &config, layer_idx))
        .collect::<ModelResult<Vec<_>>>()?;

    let rope = RopeTable::new(config.head_dim, max_seq_len, config.rope_freq_base);
    let kv_cache = KvCache::new(
        config.num_layers,
        config.num_kv_heads,
        config.head_dim,
        max_seq_len,
    );

    tracing::info!(
        blocks = blocks.len(),
        embd_size = token_embd.len(),
        max_seq_len,
        "model loaded successfully"
    );

    Ok(Self {
        config,
        token_embd,
        blocks,
        output_norm,
        output_weight,
        rope,
        kv_cache,
        #[cfg(all(feature = "metal", target_os = "macos"))]
        gpu_weight_cache: std::sync::Mutex::new(None),
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        cuda_qkv_cache: std::sync::Mutex::new(None),
    })
}
/// Creates a model with no transformer blocks and placeholder weights.
///
/// The embedding table and the FP32 output head are zero-filled, the output
/// norm weights are all ones, and the RoPE table and KV cache are sized for
/// 4096 positions.
pub fn new(config: Qwen3Config) -> Self {
    // Sequence capacity used for both the RoPE table and the KV cache.
    const DEFAULT_MAX_SEQ_LEN: usize = 4096;

    let hidden = config.hidden_size;
    let table_len = config.vocab_size * hidden;

    let kv_cache = KvCache::new(
        config.num_layers,
        config.num_kv_heads,
        config.head_dim,
        DEFAULT_MAX_SEQ_LEN,
    );
    let rope = RopeTable::new(config.head_dim, DEFAULT_MAX_SEQ_LEN, config.rope_freq_base);
    let output_norm = RmsNorm::new(vec![1.0; hidden], config.rms_norm_eps);
    let output_weight = OutputWeight::Fp32 {
        weights: vec![0.0; table_len],
        out_features: config.vocab_size,
        in_features: hidden,
    };

    Self {
        config,
        token_embd: vec![0.0; table_len],
        blocks: Vec::new(),
        output_norm,
        output_weight,
        rope,
        kv_cache,
        #[cfg(all(feature = "metal", target_os = "macos"))]
        gpu_weight_cache: std::sync::Mutex::new(None),
        #[cfg(all(
            feature = "native-cuda",
            any(target_os = "linux", target_os = "windows")
        ))]
        cuda_qkv_cache: std::sync::Mutex::new(None),
    }
}
pub fn config(&self) -> &Qwen3Config {
&self.config
}
pub fn kv_cache_mut(&mut self) -> &mut KvCache {
&mut self.kv_cache
}
/// Clears the KV cache.
pub fn reset(&mut self) {
    self.kv_cache_mut().clear();
}
/// Clears the KV cache; does the same thing as `reset`.
pub fn reset_cache(&mut self) {
    self.reset();
}
/// Uploads every transformer block, and the 1-bit output head if present, through `kernel`.
///
/// Does nothing (and logs nothing) when the model has no blocks. An FP32
/// output head is left untouched.
pub fn upload_weights_to_gpu(&mut self, kernel: &dyn OneBitKernel) {
    if self.blocks.is_empty() {
        return;
    }
    tracing::info!(blocks = self.blocks.len(), "uploading model weights to GPU");

    self.blocks.iter_mut().for_each(|block| {
        block.upload_to_gpu(kernel);
    });

    match &mut self.output_weight {
        OutputWeight::OneBit(linear) => {
            linear.upload_to_gpu(kernel);
        }
        OutputWeight::Fp32 { .. } => {}
    }

    tracing::info!("GPU weight upload complete");
}
/// Derives the model variant from the current configuration.
pub fn variant(&self) -> ModelVariant {
    let cfg = self.config();
    ModelVariant::from_config(cfg)
}
pub fn num_parameters(&self) -> u64 {
self.variant().param_count()
}
pub fn model_size_bytes(&self) -> u64 {
self.variant().expected_model_size_bytes()
}
/// Returns the configuration's `max_context_length`.
pub fn context_length(&self) -> usize {
    self.config().max_context_length
}
/// Returns the configured number of transformer layers.
pub fn num_layers(&self) -> usize {
    self.config().num_layers
}
/// Returns the configured hidden size.
pub fn hidden_size(&self) -> usize {
    self.config().hidden_size
}
pub fn kv_cache_memory_bytes(&self) -> usize {
self.kv_cache.memory_bytes()
}
/// Loads a model with `from_gguf` and logs the variant derived from its configuration.
///
/// # Errors
///
/// Returns whatever error `from_gguf` produces; nothing is logged in that case.
pub fn from_gguf_auto(gguf: &'a GgufFile<'a>, max_seq_len: usize) -> ModelResult<Self> {
    Self::from_gguf(gguf, max_seq_len).map(|model| {
        let detected = model.variant();
        tracing::info!(
            variant = detected.name(),
            params = detected.param_count(),
            "auto-detected model variant"
        );
        model
    })
}
/// Runs all transformer layers for one position through the fused Metal path.
///
/// `hidden` is handed to the kernel as a mutable buffer. No final norm and
/// no LM head are requested (those arguments are passed as `None` / `0`).
///
/// # Errors
///
/// Fails when the model has no blocks, when any block lacks one of its four
/// GPU handles, or when the kernel dispatch itself reports an error.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn try_metal_full_forward_inner(
    &self,
    hidden: &mut [f32],
    pos: usize,
) -> Result<(), Box<dyn std::error::Error>> {
    use oxibonsai_kernels::FullForwardLayerParams;

    let first_block = self.blocks.first().ok_or("no blocks")?;
    let layer_count = self.blocks.len();
    // The first block's attention-norm epsilon is used for every layer.
    let norm_eps = first_block.attn_norm_eps();

    let cfg = &self.config;
    let hidden_dim = cfg.hidden_size;
    let ffn_dim = cfg.intermediate_size;
    let q_heads = cfg.num_attention_heads;
    let kv_heads = cfg.num_kv_heads;
    let head_dim = cfg.head_dim;
    let max_seq_len = self.kv_cache.max_seq_len();

    // Every block must already own all four GPU handles.
    let handles_ready = self.blocks.iter().all(|block| {
        block.fused_qkv_gpu_handle().is_some()
            && block.attn_output_gpu_handle().is_some()
            && block.fused_gate_up_gpu_handle().is_some()
            && block.ffn_down_gpu_handle().is_some()
    });
    if !handles_ready {
        return Err("missing GPU handle".into());
    }

    // Per layer: Q, K and V weight bytes laid end to end.
    let qkv_concats: Vec<Vec<u8>> = self
        .blocks
        .iter()
        .map(|block| {
            [
                blocks_as_bytes(block.attn_q_blocks()),
                blocks_as_bytes(block.attn_k_blocks()),
                blocks_as_bytes(block.attn_v_blocks()),
            ]
            .concat()
        })
        .collect();

    let layer_params: Vec<FullForwardLayerParams<'_>> = self
        .blocks
        .iter()
        .zip(qkv_concats.iter())
        .map(|(block, qkv)| {
            // Norm weights use synthetic handle ids derived from the layer index.
            let norm_handle_base = 1_000_000u64 + (block.layer_index() as u64) * 10;
            FullForwardLayerParams {
                attn_norm_handle: norm_handle_base,
                attn_norm_bytes: block.attn_norm_weight(),
                fused_qkv_handle: block.fused_qkv_gpu_handle().map_or(0, |hnd| hnd.id()),
                fused_qkv_bytes: qkv,
                q_norm_handle: norm_handle_base + 1,
                q_norm_bytes: block.q_norm_weight(),
                k_norm_handle: norm_handle_base + 2,
                k_norm_bytes: block.k_norm_weight(),
                attn_proj_handle: block.attn_output_gpu_handle().map_or(0, |hnd| hnd.id()),
                attn_proj_bytes: blocks_as_bytes(block.attn_output_blocks()),
                ffn_norm_handle: norm_handle_base + 3,
                ffn_norm_bytes: block.ffn_norm_weight(),
                gate_up_handle: block.fused_gate_up_gpu_handle().map_or(0, |hnd| hnd.id()),
                gate_bytes: blocks_as_bytes(block.ffn_gate_blocks()),
                up_bytes: blocks_as_bytes(block.ffn_up_blocks()),
                down_handle: block.ffn_down_gpu_handle().map_or(0, |hnd| hnd.id()),
                down_bytes: blocks_as_bytes(block.ffn_down_blocks()),
            }
        })
        .collect();

    let rope_cos = self.rope.cos_at(pos);
    let rope_sin = self.rope.sin_at(pos);

    let outcome = oxibonsai_kernels::try_metal_full_forward(
        hidden,
        pos,
        layer_count,
        &layer_params,
        rope_cos,
        rope_sin,
        hidden_dim,
        ffn_dim,
        q_heads,
        kv_heads,
        head_dim,
        norm_eps,
        max_seq_len,
        None,
        None,
        norm_eps,
        None,
        None,
        0,
        None,
        None,
    );
    outcome.map_err(|e| {
        tracing::warn!(error = %e, "full-forward GPU dispatch failed, falling back");
        Box::new(e) as Box<dyn std::error::Error>
    })
}
/// Runs all layers plus the final norm and 1-bit LM head for one position on the fused Metal path.
///
/// `hidden` is handed to the kernel as a mutable buffer and `logits` as the
/// output buffer for the LM head.
///
/// # Errors
///
/// Fails when the model has no blocks, when the output head is FP32, when
/// any block lacks one of its four GPU handles, or when the kernel dispatch
/// reports an error.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn try_metal_full_forward_with_lm_head(
    &self,
    hidden: &mut [f32],
    pos: usize,
    logits: &mut Vec<f32>,
) -> Result<(), Box<dyn std::error::Error>> {
    use oxibonsai_kernels::FullForwardLayerParams;

    let first_block = self.blocks.first().ok_or("no blocks")?;
    let layer_count = self.blocks.len();

    let lm_head_linear = match &self.output_weight {
        OutputWeight::Fp32 { .. } => {
            return Err("FP32 LM head not supported on fused GPU path".into());
        }
        OutputWeight::OneBit(linear) => linear,
    };

    // The first block's attention-norm epsilon is used for every layer.
    let norm_eps = first_block.attn_norm_eps();
    let cfg = &self.config;
    let hidden_dim = cfg.hidden_size;
    let ffn_dim = cfg.intermediate_size;
    let q_heads = cfg.num_attention_heads;
    let kv_heads = cfg.num_kv_heads;
    let head_dim = cfg.head_dim;
    let max_seq_len = self.kv_cache.max_seq_len();

    // Every block must already own all four GPU handles.
    let handles_ready = self.blocks.iter().all(|block| {
        block.fused_qkv_gpu_handle().is_some()
            && block.attn_output_gpu_handle().is_some()
            && block.fused_gate_up_gpu_handle().is_some()
            && block.ffn_down_gpu_handle().is_some()
    });
    if !handles_ready {
        return Err("missing GPU handle".into());
    }

    // Per layer: Q, K and V weight bytes laid end to end.
    let qkv_concats: Vec<Vec<u8>> = self
        .blocks
        .iter()
        .map(|block| {
            [
                blocks_as_bytes(block.attn_q_blocks()),
                blocks_as_bytes(block.attn_k_blocks()),
                blocks_as_bytes(block.attn_v_blocks()),
            ]
            .concat()
        })
        .collect();

    let layer_params: Vec<FullForwardLayerParams<'_>> = self
        .blocks
        .iter()
        .zip(qkv_concats.iter())
        .map(|(block, qkv)| {
            // Norm weights use synthetic handle ids derived from the layer index.
            let norm_handle_base = 1_000_000u64 + (block.layer_index() as u64) * 10;
            FullForwardLayerParams {
                attn_norm_handle: norm_handle_base,
                attn_norm_bytes: block.attn_norm_weight(),
                fused_qkv_handle: block.fused_qkv_gpu_handle().map_or(0, |hnd| hnd.id()),
                fused_qkv_bytes: qkv,
                q_norm_handle: norm_handle_base + 1,
                q_norm_bytes: block.q_norm_weight(),
                k_norm_handle: norm_handle_base + 2,
                k_norm_bytes: block.k_norm_weight(),
                attn_proj_handle: block.attn_output_gpu_handle().map_or(0, |hnd| hnd.id()),
                attn_proj_bytes: blocks_as_bytes(block.attn_output_blocks()),
                ffn_norm_handle: norm_handle_base + 3,
                ffn_norm_bytes: block.ffn_norm_weight(),
                gate_up_handle: block.fused_gate_up_gpu_handle().map_or(0, |hnd| hnd.id()),
                gate_bytes: blocks_as_bytes(block.ffn_gate_blocks()),
                up_bytes: blocks_as_bytes(block.ffn_up_blocks()),
                down_handle: block.ffn_down_gpu_handle().map_or(0, |hnd| hnd.id()),
                down_bytes: blocks_as_bytes(block.ffn_down_blocks()),
            }
        })
        .collect();

    let rope_cos = self.rope.cos_at(pos);
    let rope_sin = self.rope.sin_at(pos);

    // Fixed synthetic handle ids for the final norm and the LM head.
    let final_norm_handle = 2_000_000u64;
    let lm_head_handle = 3_000_000u64;
    let final_norm_bytes = self.output_norm.weight();
    let final_norm_eps = self.output_norm.eps();
    let lm_head_bytes = blocks_as_bytes(lm_head_linear.blocks());
    let lm_head_out_features = lm_head_linear.out_features();

    let outcome = oxibonsai_kernels::try_metal_full_forward(
        hidden,
        pos,
        layer_count,
        &layer_params,
        rope_cos,
        rope_sin,
        hidden_dim,
        ffn_dim,
        q_heads,
        kv_heads,
        head_dim,
        norm_eps,
        max_seq_len,
        Some(final_norm_handle),
        Some(final_norm_bytes),
        final_norm_eps,
        Some(lm_head_handle),
        Some(lm_head_bytes),
        lm_head_out_features,
        Some(logits),
        None,
    );
    outcome.map_err(|e| {
        tracing::warn!(error = %e, "full-forward+lm_head GPU dispatch failed, falling back");
        Box::new(e) as Box<dyn std::error::Error>
    })
}
/// Feeds a run of prompt tokens starting at `pos_start` and returns the logits of the last one.
///
/// A single token is delegated straight to `forward`. Longer inputs try the
/// batched GPU path for the compiled backend (Metal, or CUDA when Metal is
/// not active; CUDA builds handle inputs of at most 16 tokens sequentially).
/// If the batched path fails, a warning is logged and the tokens are run one
/// by one through `forward`.
///
/// # Errors
///
/// Returns `ModelError::MissingTensor` for an empty `token_ids`, or the first
/// error raised by `forward` on the sequential path.
pub fn forward_prefill(
    &mut self,
    token_ids: &[u32],
    pos_start: usize,
    kernel: &dyn OneBitKernel,
) -> ModelResult<Vec<f32>> {
    match token_ids {
        [] => {
            return Err(ModelError::MissingTensor {
                name: "forward_prefill: empty token_ids".into(),
            });
        }
        [only] => return self.forward(*only, pos_start, kernel),
        _ => {}
    }

    // CUDA builds: short prompts go token by token.
    #[cfg(all(
        feature = "native-cuda",
        not(all(feature = "metal", target_os = "macos")),
        any(target_os = "linux", target_os = "windows")
    ))]
    if token_ids.len() <= 16 {
        return token_ids
            .iter()
            .enumerate()
            .try_fold(Vec::new(), |_, (offset, &token_id)| {
                self.forward(token_id, pos_start + offset, kernel)
            });
    }

    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        let attempt = self.try_metal_prefill_with_lm_head(token_ids, pos_start);
        match attempt {
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "metal batch prefill failed, falling back to sequential"
                );
            }
            Ok(logits) => return Ok(logits),
        }
    }

    #[cfg(all(
        feature = "native-cuda",
        not(all(feature = "metal", target_os = "macos")),
        any(target_os = "linux", target_os = "windows")
    ))]
    {
        let attempt = self.try_cuda_prefill_with_lm_head(token_ids, pos_start);
        match attempt {
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "cuda batch prefill failed, falling back to sequential"
                );
            }
            Ok(logits) => return Ok(logits),
        }
    }

    // Sequential fallback: keep only the logits of the most recent token.
    token_ids
        .iter()
        .enumerate()
        .try_fold(Vec::new(), |_, (offset, &token_id)| {
            self.forward(token_id, pos_start + offset, kernel)
        })
}
/// Feeds a run of tokens starting at `pos_start` and returns one predicted token id per input token.
///
/// The batched GPU path for the compiled backend is tried first. If it
/// fails, a warning is logged and each token is run through `forward`, taking
/// the index of the largest logit (the first one on ties; index 0 when no
/// logit exceeds negative infinity).
///
/// An empty `token_ids` yields an empty vector.
///
/// # Errors
///
/// Returns the first error raised by `forward` on the sequential path.
pub fn forward_prefill_verify(
    &mut self,
    token_ids: &[u32],
    pos_start: usize,
    kernel: &dyn OneBitKernel,
) -> ModelResult<Vec<u32>> {
    if token_ids.is_empty() {
        return Ok(Vec::new());
    }

    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        let attempt = self.try_metal_prefill_verify(token_ids, pos_start);
        match attempt {
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "metal batch prefill verify failed, falling back to sequential"
                );
            }
            Ok(ids) => return Ok(ids),
        }
    }

    #[cfg(all(
        feature = "native-cuda",
        not(all(feature = "metal", target_os = "macos")),
        any(target_os = "linux", target_os = "windows")
    ))]
    {
        let attempt = self.try_cuda_prefill_verify(token_ids, pos_start);
        match attempt {
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    "cuda batch prefill verify failed, falling back to sequential"
                );
            }
            Ok(ids) => return Ok(ids),
        }
    }

    // Sequential fallback: greedy argmax over each token's logits.
    token_ids
        .iter()
        .enumerate()
        .map(|(offset, &token_id)| -> ModelResult<u32> {
            let logits = self.forward(token_id, pos_start + offset, kernel)?;
            let (argmax, _) = logits.iter().enumerate().fold(
                (0u32, f32::NEG_INFINITY),
                |best, (idx, &value)| {
                    // Strict `>` keeps the earliest maximum and skips NaN.
                    if value > best.1 {
                        (idx as u32, value)
                    } else {
                        best
                    }
                },
            );
            Ok(argmax)
        })
        .collect()
}
/// Batched Metal prefill: embeds `token_ids`, runs every layer plus final norm and LM head, returns logits.
///
/// The returned vector has `out_features` entries of the 1-bit LM head.
/// NOTE(review): presumably these are the logits of the last token in the batch — confirm against the kernel.
///
/// # Errors
///
/// Fails when the model has no blocks, when the output head is FP32, when
/// any block lacks one of its four GPU handles, when a token id falls
/// outside the embedding table, or when the kernel dispatch reports an error.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn try_metal_prefill_with_lm_head(
    &self,
    token_ids: &[u32],
    pos_start: usize,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    use oxibonsai_kernels::FullForwardLayerParams;

    let batch_size = token_ids.len();
    let first_block = self.blocks.first().ok_or("no blocks")?;
    let layer_count = self.blocks.len();

    let lm_head_linear = match &self.output_weight {
        OutputWeight::Fp32 { .. } => {
            return Err("FP32 LM head not supported on GPU prefill path".into());
        }
        OutputWeight::OneBit(linear) => linear,
    };

    // The first block's attention-norm epsilon is used for every layer.
    let norm_eps = first_block.attn_norm_eps();
    let cfg = &self.config;
    let hidden_dim = cfg.hidden_size;
    let ffn_dim = cfg.intermediate_size;
    let q_heads = cfg.num_attention_heads;
    let kv_heads = cfg.num_kv_heads;
    let head_dim = cfg.head_dim;
    let half_dim = head_dim / 2;
    let max_seq_len = self.kv_cache.max_seq_len();

    // Every block must already own all four GPU handles.
    let handles_ready = self.blocks.iter().all(|block| {
        block.fused_qkv_gpu_handle().is_some()
            && block.attn_output_gpu_handle().is_some()
            && block.fused_gate_up_gpu_handle().is_some()
            && block.ffn_down_gpu_handle().is_some()
    });
    if !handles_ready {
        return Err("missing GPU handle".into());
    }

    // Gather one embedding row per token into a contiguous batch buffer.
    let mut hidden_batch: Vec<f32> = Vec::with_capacity(batch_size * hidden_dim);
    for &token_id in token_ids {
        let row_start = token_id as usize * hidden_dim;
        let row_end = row_start + hidden_dim;
        let row = self.token_embd.get(row_start..row_end).ok_or_else(|| {
            format!(
                "token_id {} out of range (vocab={})",
                token_id,
                self.token_embd.len() / hidden_dim
            )
        })?;
        hidden_batch.extend_from_slice(row);
    }

    // RoPE cos/sin rows for positions pos_start .. pos_start + batch_size.
    let mut cos_table = vec![0.0f32; batch_size * half_dim];
    let mut sin_table = vec![0.0f32; batch_size * half_dim];
    for step in 0..batch_size {
        let cos_row = self.rope.cos_at(pos_start + step);
        let sin_row = self.rope.sin_at(pos_start + step);
        let lo = step * half_dim;
        let hi = lo + half_dim;
        cos_table[lo..hi].copy_from_slice(cos_row);
        sin_table[lo..hi].copy_from_slice(sin_row);
    }

    // Per layer: Q, K and V weight bytes laid end to end.
    let qkv_concats: Vec<Vec<u8>> = self
        .blocks
        .iter()
        .map(|block| {
            [
                blocks_as_bytes(block.attn_q_blocks()),
                blocks_as_bytes(block.attn_k_blocks()),
                blocks_as_bytes(block.attn_v_blocks()),
            ]
            .concat()
        })
        .collect();

    let layer_params: Vec<FullForwardLayerParams<'_>> = self
        .blocks
        .iter()
        .zip(qkv_concats.iter())
        .map(|(block, qkv)| {
            // Norm weights use synthetic handle ids derived from the layer index.
            let norm_handle_base = 1_000_000u64 + (block.layer_index() as u64) * 10;
            FullForwardLayerParams {
                attn_norm_handle: norm_handle_base,
                attn_norm_bytes: block.attn_norm_weight(),
                fused_qkv_handle: block.fused_qkv_gpu_handle().map_or(0, |hnd| hnd.id()),
                fused_qkv_bytes: qkv,
                q_norm_handle: norm_handle_base + 1,
                q_norm_bytes: block.q_norm_weight(),
                k_norm_handle: norm_handle_base + 2,
                k_norm_bytes: block.k_norm_weight(),
                attn_proj_handle: block.attn_output_gpu_handle().map_or(0, |hnd| hnd.id()),
                attn_proj_bytes: blocks_as_bytes(block.attn_output_blocks()),
                ffn_norm_handle: norm_handle_base + 3,
                ffn_norm_bytes: block.ffn_norm_weight(),
                gate_up_handle: block.fused_gate_up_gpu_handle().map_or(0, |hnd| hnd.id()),
                gate_bytes: blocks_as_bytes(block.ffn_gate_blocks()),
                up_bytes: blocks_as_bytes(block.ffn_up_blocks()),
                down_handle: block.ffn_down_gpu_handle().map_or(0, |hnd| hnd.id()),
                down_bytes: blocks_as_bytes(block.ffn_down_blocks()),
            }
        })
        .collect();

    // Fixed synthetic handle ids for the final norm and the LM head.
    let final_norm_handle = 2_000_000u64;
    let lm_head_handle = 3_000_000u64;
    let final_norm_bytes = self.output_norm.weight();
    let final_norm_eps = self.output_norm.eps();
    let lm_head_bytes = blocks_as_bytes(lm_head_linear.blocks());
    let lm_head_out_features = lm_head_linear.out_features();

    let mut logits = vec![0.0f32; lm_head_out_features];
    let outcome = oxibonsai_kernels::try_metal_full_forward_prefill(
        &hidden_batch,
        batch_size,
        pos_start,
        layer_count,
        &layer_params,
        &cos_table,
        &sin_table,
        hidden_dim,
        ffn_dim,
        q_heads,
        kv_heads,
        head_dim,
        norm_eps,
        max_seq_len,
        Some(final_norm_handle),
        Some(final_norm_bytes),
        final_norm_eps,
        Some(lm_head_handle),
        Some(lm_head_bytes),
        lm_head_out_features,
        Some(&mut logits),
        None,
    );
    match outcome {
        Ok(_) => Ok(logits),
        Err(e) => {
            tracing::warn!(error = %e, "batch prefill GPU dispatch failed");
            Err(Box::new(e) as Box<dyn std::error::Error>)
        }
    }
}
/// Batched Metal prefill that returns token ids instead of logits.
///
/// Embeds `token_ids`, runs every layer plus final norm and LM head on the
/// GPU, and returns the vector the kernel fills in.
/// NOTE(review): presumably one greedy token id per input position — confirm against the kernel.
///
/// # Errors
///
/// Fails when the model has no blocks, when the output head is FP32, when
/// any block lacks one of its four GPU handles, when a token id falls
/// outside the embedding table, or when the kernel dispatch reports an error.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn try_metal_prefill_verify(
    &self,
    token_ids: &[u32],
    pos_start: usize,
) -> Result<Vec<u32>, Box<dyn std::error::Error>> {
    use oxibonsai_kernels::FullForwardLayerParams;

    let batch_size = token_ids.len();
    let first_block = self.blocks.first().ok_or("no blocks")?;
    let layer_count = self.blocks.len();

    let lm_head_linear = match &self.output_weight {
        OutputWeight::Fp32 { .. } => {
            return Err("FP32 LM head not supported on GPU prefill verify path".into());
        }
        OutputWeight::OneBit(linear) => linear,
    };

    // The first block's attention-norm epsilon is used for every layer.
    let norm_eps = first_block.attn_norm_eps();
    let cfg = &self.config;
    let hidden_dim = cfg.hidden_size;
    let ffn_dim = cfg.intermediate_size;
    let q_heads = cfg.num_attention_heads;
    let kv_heads = cfg.num_kv_heads;
    let head_dim = cfg.head_dim;
    let half_dim = head_dim / 2;
    let max_seq_len = self.kv_cache.max_seq_len();

    // Every block must already own all four GPU handles.
    let handles_ready = self.blocks.iter().all(|block| {
        block.fused_qkv_gpu_handle().is_some()
            && block.attn_output_gpu_handle().is_some()
            && block.fused_gate_up_gpu_handle().is_some()
            && block.ffn_down_gpu_handle().is_some()
    });
    if !handles_ready {
        return Err("missing GPU handle".into());
    }

    // Gather one embedding row per token into a contiguous batch buffer.
    let mut hidden_batch: Vec<f32> = Vec::with_capacity(batch_size * hidden_dim);
    for &token_id in token_ids {
        let row_start = token_id as usize * hidden_dim;
        let row_end = row_start + hidden_dim;
        let row = self.token_embd.get(row_start..row_end).ok_or_else(|| {
            format!(
                "token_id {} out of range (vocab={})",
                token_id,
                self.token_embd.len() / hidden_dim
            )
        })?;
        hidden_batch.extend_from_slice(row);
    }

    // RoPE cos/sin rows for positions pos_start .. pos_start + batch_size.
    let mut cos_table = vec![0.0f32; batch_size * half_dim];
    let mut sin_table = vec![0.0f32; batch_size * half_dim];
    for step in 0..batch_size {
        let cos_row = self.rope.cos_at(pos_start + step);
        let sin_row = self.rope.sin_at(pos_start + step);
        let lo = step * half_dim;
        let hi = lo + half_dim;
        cos_table[lo..hi].copy_from_slice(cos_row);
        sin_table[lo..hi].copy_from_slice(sin_row);
    }

    // Per layer: Q, K and V weight bytes laid end to end.
    let qkv_concats: Vec<Vec<u8>> = self
        .blocks
        .iter()
        .map(|block| {
            [
                blocks_as_bytes(block.attn_q_blocks()),
                blocks_as_bytes(block.attn_k_blocks()),
                blocks_as_bytes(block.attn_v_blocks()),
            ]
            .concat()
        })
        .collect();

    let layer_params: Vec<FullForwardLayerParams<'_>> = self
        .blocks
        .iter()
        .zip(qkv_concats.iter())
        .map(|(block, qkv)| {
            // Norm weights use synthetic handle ids derived from the layer index.
            let norm_handle_base = 1_000_000u64 + (block.layer_index() as u64) * 10;
            FullForwardLayerParams {
                attn_norm_handle: norm_handle_base,
                attn_norm_bytes: block.attn_norm_weight(),
                fused_qkv_handle: block.fused_qkv_gpu_handle().map_or(0, |hnd| hnd.id()),
                fused_qkv_bytes: qkv,
                q_norm_handle: norm_handle_base + 1,
                q_norm_bytes: block.q_norm_weight(),
                k_norm_handle: norm_handle_base + 2,
                k_norm_bytes: block.k_norm_weight(),
                attn_proj_handle: block.attn_output_gpu_handle().map_or(0, |hnd| hnd.id()),
                attn_proj_bytes: blocks_as_bytes(block.attn_output_blocks()),
                ffn_norm_handle: norm_handle_base + 3,
                ffn_norm_bytes: block.ffn_norm_weight(),
                gate_up_handle: block.fused_gate_up_gpu_handle().map_or(0, |hnd| hnd.id()),
                gate_bytes: blocks_as_bytes(block.ffn_gate_blocks()),
                up_bytes: blocks_as_bytes(block.ffn_up_blocks()),
                down_handle: block.ffn_down_gpu_handle().map_or(0, |hnd| hnd.id()),
                down_bytes: blocks_as_bytes(block.ffn_down_blocks()),
            }
        })
        .collect();

    // Fixed synthetic handle ids for the final norm and the LM head.
    let final_norm_handle = 2_000_000u64;
    let lm_head_handle = 3_000_000u64;
    let final_norm_bytes = self.output_norm.weight();
    let final_norm_eps = self.output_norm.eps();
    let lm_head_bytes = blocks_as_bytes(lm_head_linear.blocks());
    let lm_head_out_features = lm_head_linear.out_features();

    let mut batch_token_ids: Vec<u32> = Vec::with_capacity(batch_size);
    let outcome = oxibonsai_kernels::try_metal_full_forward_prefill_verify(
        &hidden_batch,
        batch_size,
        pos_start,
        layer_count,
        &layer_params,
        &cos_table,
        &sin_table,
        hidden_dim,
        ffn_dim,
        q_heads,
        kv_heads,
        head_dim,
        norm_eps,
        max_seq_len,
        Some(final_norm_handle),
        Some(final_norm_bytes),
        final_norm_eps,
        Some(lm_head_handle),
        Some(lm_head_bytes),
        lm_head_out_features,
        &mut batch_token_ids,
    );
    match outcome {
        Ok(_) => Ok(batch_token_ids),
        Err(e) => {
            tracing::warn!(error = %e, "batch prefill verify GPU dispatch failed");
            Err(Box::new(e) as Box<dyn std::error::Error>)
        }
    }
}
/// Runs one token at position `pos` through the cached Metal path and returns the token id the kernel selects.
///
/// The GPU weight cache is built on first use via `get_or_create_gpu_cache`.
/// The kernel is asked for a token id only (no logits buffer is supplied).
///
/// # Errors
///
/// Fails when the model has no blocks, when the output head is FP32, when
/// the weight cache cannot be built or locked, when `token_id` falls outside
/// the embedding table, or when the kernel dispatch reports an error.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub fn forward_greedy_gpu(
    &self,
    token_id: u32,
    pos: usize,
) -> Result<u32, Box<dyn std::error::Error>> {
    let cfg = &self.config;
    let hidden_dim = cfg.hidden_size;
    let ffn_dim = cfg.intermediate_size;
    let q_heads = cfg.num_attention_heads;
    let kv_heads = cfg.num_kv_heads;
    let head_dim = cfg.head_dim;
    let max_seq_len = self.kv_cache.max_seq_len();

    // The first block's attention-norm epsilon is used for every layer.
    let norm_eps = self
        .blocks
        .first()
        .ok_or("no blocks")?
        .attn_norm_eps();
    let final_norm_eps = self.output_norm.eps();

    let lm_head_out_features = match &self.output_weight {
        OutputWeight::Fp32 { .. } => {
            return Err("FP32 LM head not supported on greedy GPU path".into());
        }
        OutputWeight::OneBit(linear) => linear.out_features(),
    };

    self.get_or_create_gpu_cache()?;

    // Start from this token's embedding row.
    let row_start = token_id as usize * hidden_dim;
    let row_end = row_start + hidden_dim;
    let embedding = self.token_embd.get(row_start..row_end).ok_or_else(|| {
        format!(
            "token_id {} out of range (vocab={})",
            token_id,
            self.token_embd.len() / hidden_dim
        )
    })?;
    let mut hidden = embedding.to_vec();

    let rope_cos = self.rope.cos_at(pos);
    let rope_sin = self.rope.sin_at(pos);

    let mut greedy_token_id: u32 = 0;
    // The guard stays alive for the whole dispatch because `cached` borrows from it.
    let guard = self
        .gpu_weight_cache
        .lock()
        .map_err(|e| format!("gpu_weight_cache lock: {e}"))?;
    let cached = guard.as_ref().ok_or("GPU weight cache not populated")?;

    let outcome = oxibonsai_kernels::try_metal_full_forward_cached(
        &mut hidden,
        pos,
        cached,
        rope_cos,
        rope_sin,
        hidden_dim,
        ffn_dim,
        q_heads,
        kv_heads,
        head_dim,
        norm_eps,
        max_seq_len,
        final_norm_eps,
        lm_head_out_features,
        None,
        Some(&mut greedy_token_id),
    );
    match outcome {
        Ok(_) => Ok(greedy_token_id),
        Err(e) => {
            tracing::warn!(error = %e, "cached greedy GPU forward failed");
            Err(Box::new(e) as Box<dyn std::error::Error>)
        }
    }
}
/// Ensures the Metal GPU weight cache is populated, building it on first use.
///
/// The mutex guard is held from the "already populated?" check until the
/// freshly built cache is stored. Releasing it in between would let two
/// concurrent first callers both run `build_cached_weights`, with the later
/// one overwriting the earlier result.
///
/// Handles that a block does not have are passed to the builder as `0`.
///
/// # Errors
///
/// Fails when the mutex is poisoned, when the output head is FP32, or when
/// `build_cached_weights` reports an error. The cache stays empty in all of
/// these cases.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub fn get_or_create_gpu_cache(&self) -> Result<(), Box<dyn std::error::Error>> {
    use oxibonsai_kernels::FullForwardLayerParams;

    let mut guard = self
        .gpu_weight_cache
        .lock()
        .map_err(|e| format!("gpu_weight_cache lock: {e}"))?;
    if guard.is_some() {
        return Ok(());
    }

    let lm_head_linear = match &self.output_weight {
        OutputWeight::OneBit(linear) => linear,
        OutputWeight::Fp32 { .. } => return Err("FP32 LM head not supported".into()),
    };

    // Per layer: Q, K and V weight bytes laid end to end.
    let qkv_concats: Vec<Vec<u8>> = self
        .blocks
        .iter()
        .map(|block| {
            [
                blocks_as_bytes(block.attn_q_blocks()),
                blocks_as_bytes(block.attn_k_blocks()),
                blocks_as_bytes(block.attn_v_blocks()),
            ]
            .concat()
        })
        .collect();

    let layer_params: Vec<FullForwardLayerParams<'_>> = self
        .blocks
        .iter()
        .zip(qkv_concats.iter())
        .map(|(block, qkv)| {
            // Norm weights use synthetic handle ids derived from the layer index.
            let norm_handle_base = 1_000_000u64 + (block.layer_index() as u64) * 10;
            FullForwardLayerParams {
                attn_norm_handle: norm_handle_base,
                attn_norm_bytes: block.attn_norm_weight(),
                fused_qkv_handle: block.fused_qkv_gpu_handle().map_or(0, |hnd| hnd.id()),
                fused_qkv_bytes: qkv,
                q_norm_handle: norm_handle_base + 1,
                q_norm_bytes: block.q_norm_weight(),
                k_norm_handle: norm_handle_base + 2,
                k_norm_bytes: block.k_norm_weight(),
                attn_proj_handle: block.attn_output_gpu_handle().map_or(0, |hnd| hnd.id()),
                attn_proj_bytes: blocks_as_bytes(block.attn_output_blocks()),
                ffn_norm_handle: norm_handle_base + 3,
                ffn_norm_bytes: block.ffn_norm_weight(),
                gate_up_handle: block.fused_gate_up_gpu_handle().map_or(0, |hnd| hnd.id()),
                gate_bytes: blocks_as_bytes(block.ffn_gate_blocks()),
                up_bytes: blocks_as_bytes(block.ffn_up_blocks()),
                down_handle: block.ffn_down_gpu_handle().map_or(0, |hnd| hnd.id()),
                down_bytes: blocks_as_bytes(block.ffn_down_blocks()),
            }
        })
        .collect();

    // Fixed synthetic handle ids for the final norm and the LM head.
    let final_norm_handle = 2_000_000u64;
    let final_norm_bytes = self.output_norm.weight();
    let lm_head_handle = 3_000_000u64;
    let lm_head_bytes = blocks_as_bytes(lm_head_linear.blocks());

    let cached = oxibonsai_kernels::build_cached_weights(
        &layer_params,
        final_norm_handle,
        final_norm_bytes,
        lm_head_handle,
        lm_head_bytes,
    )
    .map_err(|e| format!("build_cached_weights: {e}"))?;

    *guard = Some(cached);
    tracing::info!("GPU weight cache populated (all subsequent tokens use cached handles)");
    Ok(())
}
/// Returns the shared per-layer Q/K/V byte concatenations, building them on first use.
///
/// Each entry is one layer's Q, K and V weight bytes laid end to end, in
/// block order. The mutex guard is held from the lookup until the new value
/// is stored. Releasing it in between would let concurrent first callers
/// each copy every layer's weights and walk away with different `Arc`s.
///
/// # Errors
///
/// Fails only when the mutex is poisoned.
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
fn get_or_build_cuda_qkv_cache(
    &self,
) -> Result<std::sync::Arc<Vec<Vec<u8>>>, Box<dyn std::error::Error>> {
    let mut guard = self
        .cuda_qkv_cache
        .lock()
        .map_err(|e| format!("cuda_qkv_cache lock: {e}"))?;
    if let Some(cache) = guard.as_ref() {
        return Ok(std::sync::Arc::clone(cache));
    }

    let qkv_concats: Vec<Vec<u8>> = self
        .blocks
        .iter()
        .map(|block| {
            [
                blocks_as_bytes(block.attn_q_blocks()),
                blocks_as_bytes(block.attn_k_blocks()),
                blocks_as_bytes(block.attn_v_blocks()),
            ]
            .concat()
        })
        .collect();

    let arc = std::sync::Arc::new(qkv_concats);
    *guard = Some(std::sync::Arc::clone(&arc));
    Ok(arc)
}
/// Builds one `CudaFullForwardLayerParams` per transformer block.
///
/// `qkv_concats[i]` supplies the fused Q/K/V bytes for the `i`-th block;
/// indexing panics if the slice is shorter than the block list. Handles a
/// block does not have are passed as `0`.
///
/// # Errors
///
/// Fails when the model has no blocks.
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
fn build_cuda_layer_params<'b>(
    &'b self,
    qkv_concats: &'b [Vec<u8>],
) -> Result<Vec<oxibonsai_kernels::CudaFullForwardLayerParams<'b>>, Box<dyn std::error::Error>>
{
    if self.blocks.is_empty() {
        return Err("no blocks".into());
    }

    let layer_params: Vec<oxibonsai_kernels::CudaFullForwardLayerParams<'b>> = self
        .blocks
        .iter()
        .enumerate()
        .map(|(layer_slot, block)| {
            // Norm weights use synthetic handle ids derived from the layer index.
            let norm_handle_base = 1_000_000u64 + (block.layer_index() as u64) * 10;
            oxibonsai_kernels::CudaFullForwardLayerParams {
                attn_norm_handle: norm_handle_base,
                attn_norm_bytes: block.attn_norm_weight(),
                fused_qkv_handle: block.fused_qkv_gpu_handle().map_or(0, |hnd| hnd.id()),
                fused_qkv_bytes: &qkv_concats[layer_slot],
                q_norm_handle: norm_handle_base + 1,
                q_norm_bytes: block.q_norm_weight(),
                k_norm_handle: norm_handle_base + 2,
                k_norm_bytes: block.k_norm_weight(),
                attn_proj_handle: block.attn_output_gpu_handle().map_or(0, |hnd| hnd.id()),
                attn_proj_bytes: blocks_as_bytes(block.attn_output_blocks()),
                ffn_norm_handle: norm_handle_base + 3,
                ffn_norm_bytes: block.ffn_norm_weight(),
                gate_up_handle: block.fused_gate_up_gpu_handle().map_or(0, |hnd| hnd.id()),
                gate_bytes: blocks_as_bytes(block.ffn_gate_blocks()),
                up_bytes: blocks_as_bytes(block.ffn_up_blocks()),
                down_handle: block.ffn_down_gpu_handle().map_or(0, |hnd| hnd.id()),
                down_bytes: blocks_as_bytes(block.ffn_down_blocks()),
            }
        })
        .collect();

    Ok(layer_params)
}
/// Runs all transformer layers for one position through the fused CUDA path.
///
/// No final norm is supplied (the last two kernel arguments are `None` and
/// `0`). Returns the vector produced by the kernel.
///
/// # Errors
///
/// Fails when the model has no blocks, when the Q/K/V cache cannot be
/// locked, or when the kernel returns `None`.
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
fn try_cuda_full_forward_inner(
    &self,
    hidden: &[f32],
    pos: usize,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    let first_block = self.blocks.first().ok_or("no blocks")?;
    // The first block's attention-norm epsilon is used for every layer.
    let norm_eps = first_block.attn_norm_eps();

    let cfg = &self.config;
    let hidden_dim = cfg.hidden_size;
    let ffn_dim = cfg.intermediate_size;
    let q_heads = cfg.num_attention_heads;
    let kv_heads = cfg.num_kv_heads;
    let head_dim = cfg.head_dim;
    // Query heads per KV head; falls back to 1 when there are no KV heads.
    let heads_per_group = match kv_heads {
        0 => 1,
        groups => q_heads / groups,
    };
    let max_seq_len = self.kv_cache.max_seq_len();

    let qkv_concats = self.get_or_build_cuda_qkv_cache()?;
    let layer_params = self.build_cuda_layer_params(&*qkv_concats)?;

    let rope_cos = self.rope.cos_at(pos);
    let rope_sin = self.rope.sin_at(pos);

    let outcome = oxibonsai_kernels::try_cuda_full_forward(
        hidden,
        &layer_params,
        rope_cos,
        rope_sin,
        pos,
        q_heads,
        kv_heads,
        head_dim,
        heads_per_group,
        norm_eps,
        hidden_dim,
        ffn_dim,
        max_seq_len,
        None,
        0,
    );
    match outcome {
        Some(output) => Ok(output),
        None => {
            tracing::warn!("CUDA full-forward (layers only) returned None, falling back");
            Err(Box::<dyn std::error::Error>::from(
                "CUDA layers-only forward returned None",
            ))
        }
    }
}
/// Runs all layers plus the final norm and 1-bit LM head for one position on the fused CUDA path.
///
/// Returns the logits vector produced by the kernel.
///
/// # Errors
///
/// Fails when the model has no blocks, when the output head is FP32, when
/// the Q/K/V cache cannot be locked, or when the kernel returns `None`.
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
fn try_cuda_full_forward_with_lm_head(
    &self,
    hidden: &[f32],
    pos: usize,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    let first_block = self.blocks.first().ok_or("no blocks")?;

    let lm_head_linear = match &self.output_weight {
        OutputWeight::Fp32 { .. } => {
            return Err("FP32 LM head not supported on CUDA fused GPU path".into());
        }
        OutputWeight::OneBit(linear) => linear,
    };

    // The first block's attention-norm epsilon is used for every layer.
    let norm_eps = first_block.attn_norm_eps();
    let cfg = &self.config;
    let hidden_dim = cfg.hidden_size;
    let ffn_dim = cfg.intermediate_size;
    let q_heads = cfg.num_attention_heads;
    let kv_heads = cfg.num_kv_heads;
    let head_dim = cfg.head_dim;
    // Query heads per KV head; falls back to 1 when there are no KV heads.
    let heads_per_group = match kv_heads {
        0 => 1,
        groups => q_heads / groups,
    };
    let max_seq_len = self.kv_cache.max_seq_len();

    // Fixed synthetic handle ids for the final norm and the LM head.
    let final_norm_handle = 2_000_000u64;
    let lm_head_handle = 4_000_000u64;
    let final_norm_bytes = self.output_norm.weight();
    let lm_head_bytes = blocks_as_bytes(lm_head_linear.blocks());
    let vocab_size = lm_head_linear.out_features();

    let qkv_concats = self.get_or_build_cuda_qkv_cache()?;
    let layer_params = self.build_cuda_layer_params(&*qkv_concats)?;

    let rope_cos = self.rope.cos_at(pos);
    let rope_sin = self.rope.sin_at(pos);

    oxibonsai_kernels::try_cuda_full_forward_with_gpu_lm_head(
        hidden,
        &layer_params,
        rope_cos,
        rope_sin,
        pos,
        q_heads,
        kv_heads,
        head_dim,
        heads_per_group,
        norm_eps,
        hidden_dim,
        ffn_dim,
        max_seq_len,
        Some(final_norm_bytes),
        final_norm_handle,
        lm_head_handle,
        &lm_head_bytes,
        vocab_size,
    )
    .ok_or_else(|| {
        tracing::warn!("CUDA full-forward+gpu_lm_head returned None, falling back");
        Box::<dyn std::error::Error>::from("CUDA full-forward+gpu_lm_head returned None")
    })
}
/// Embeds a batch of tokens, builds their RoPE tables and dispatches one
/// batched CUDA prefill that also applies the final norm and the 1-bit LM
/// head, returning a logits vector of length `out_features`.
///
/// Token `t` of `token_ids` is placed at position `pos_start + t`.
///
/// NOTE(review): a single `out_features`-sized buffer is handed to the kernel,
/// so presumably only the last token's logits are produced — confirm.
///
/// # Errors
///
/// Returns an error when the model has no blocks, when the output weight is
/// stored as FP32, when a token id falls outside the embedding table, when the
/// QKV cache or per-layer parameters cannot be built, or when the kernel
/// dispatch fails (logged as a warning).
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
fn try_cuda_prefill_with_lm_head(
    &self,
    token_ids: &[u32],
    pos_start: usize,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    let batch_size = token_ids.len();
    let n_layers = self.blocks.len();
    let first_block = match self.blocks.first() {
        Some(block) => block,
        None => return Err("no blocks".into()),
    };
    let lm_head = match &self.output_weight {
        OutputWeight::Fp32 { .. } => {
            return Err("FP32 LM head not supported on CUDA prefill path".into());
        }
        OutputWeight::OneBit(linear) => linear,
    };

    let norm_eps = first_block.attn_norm_eps();
    let cfg = &self.config;
    let hidden_size = cfg.hidden_size;
    let intermediate_size = cfg.intermediate_size;
    let num_q_heads = cfg.num_attention_heads;
    let num_kv_heads = cfg.num_kv_heads;
    let head_dim = cfg.head_dim;
    let half_dim = head_dim / 2;
    // Query heads per KV head; falls back to 1 when no KV heads are configured.
    let heads_per_group = if num_kv_heads > 0 {
        num_q_heads / num_kv_heads
    } else {
        1
    };
    let max_seq_len = self.kv_cache.max_seq_len();

    // Gather the embedding row of every token into one contiguous batch buffer.
    let mut hidden_batch: Vec<f32> = Vec::with_capacity(batch_size * hidden_size);
    for &token_id in token_ids {
        let row_start = token_id as usize * hidden_size;
        let row_end = row_start + hidden_size;
        match self.token_embd.get(row_start..row_end) {
            Some(row) => hidden_batch.extend_from_slice(row),
            None => {
                return Err(format!(
                    "token_id {} out of range (vocab={})",
                    token_id,
                    self.token_embd.len() / hidden_size
                )
                .into());
            }
        }
    }

    // One `half_dim`-wide cos/sin row per token, at consecutive positions.
    let mut cos_table = vec![0.0f32; batch_size * half_dim];
    let mut sin_table = vec![0.0f32; batch_size * half_dim];
    for t in 0..batch_size {
        let token_pos = pos_start + t;
        let (cos_row, sin_row) = (self.rope.cos_at(token_pos), self.rope.sin_at(token_pos));
        let row_begin = t * half_dim;
        let row_finish = row_begin + half_dim;
        cos_table[row_begin..row_finish].copy_from_slice(cos_row);
        sin_table[row_begin..row_finish].copy_from_slice(sin_row);
    }

    let final_norm_handle = 2_000_000u64;
    let final_norm_bytes = self.output_norm.weight();
    let final_norm_eps = self.output_norm.eps();
    let lm_head_handle = 3_000_000u64;
    let lm_head_bytes = blocks_as_bytes(lm_head.blocks());
    let lm_head_out_features = lm_head.out_features();

    let qkv_concats = self.get_or_build_cuda_qkv_cache()?;
    let layer_params = self.build_cuda_layer_params(&*qkv_concats)?;

    let mut logits = vec![0.0f32; lm_head_out_features];
    let dispatch = oxibonsai_kernels::try_cuda_prefill(
        &hidden_batch,
        batch_size,
        pos_start,
        n_layers,
        &layer_params,
        &cos_table,
        &sin_table,
        hidden_size,
        intermediate_size,
        num_q_heads,
        num_kv_heads,
        head_dim,
        heads_per_group,
        norm_eps,
        max_seq_len,
        Some(final_norm_handle),
        Some(final_norm_bytes),
        final_norm_eps,
        Some(lm_head_handle),
        Some(lm_head_bytes),
        lm_head_out_features,
        Some(&mut logits),
        None,
    );
    if let Err(e) = dispatch {
        tracing::warn!(error = %e, "CUDA batch prefill dispatch failed");
        return Err(Box::new(e));
    }
    Ok(logits)
}
/// Feeds `token_ids` through the CUDA prefill kernel one token at a time
/// (batch size 1, position `pos_start + t`) and collects the greedy token id
/// the kernel reports for each position.
///
/// Every token id is validated against the embedding table before any kernel
/// dispatch, so a batch containing an out-of-range id is rejected up front.
///
/// # Errors
///
/// Returns an error when the model has no blocks, when the output weight is
/// stored as FP32, when a token id falls outside the embedding table, when the
/// QKV cache or per-layer parameters cannot be built, or when any per-token
/// kernel dispatch fails (logged as a warning with the failing position).
#[cfg(all(
    feature = "native-cuda",
    any(target_os = "linux", target_os = "windows")
))]
fn try_cuda_prefill_verify(
    &self,
    token_ids: &[u32],
    pos_start: usize,
) -> Result<Vec<u32>, Box<dyn std::error::Error>> {
    let batch_size = token_ids.len();
    let n_layers = self.blocks.len();
    if n_layers == 0 {
        return Err("no blocks".into());
    }
    let lm_head_linear = match &self.output_weight {
        OutputWeight::OneBit(linear) => linear,
        OutputWeight::Fp32 { .. } => {
            return Err("FP32 LM head not supported on CUDA prefill verify path".into());
        }
    };
    let eps = self.blocks[0].attn_norm_eps();
    let h = self.config.hidden_size;
    let inter = self.config.intermediate_size;
    let nq = self.config.num_attention_heads;
    let nkv = self.config.num_kv_heads;
    let hd = self.config.head_dim;
    let half_dim = hd / 2;
    // Query heads per KV head; falls back to 1 when no KV heads are configured.
    let heads_per_group = if nkv > 0 { nq / nkv } else { 1 };
    let max_seq_len = self.kv_cache.max_seq_len();

    // Validate all ids first; the embedding rows themselves are borrowed
    // directly from `token_embd` in the dispatch loop, so no batch copy is needed.
    for &token_id in token_ids {
        let embd_start = token_id as usize * h;
        let embd_end = embd_start + h;
        if embd_end > self.token_embd.len() {
            return Err(format!(
                "token_id {} out of range (vocab={})",
                token_id,
                self.token_embd.len() / h
            )
            .into());
        }
    }

    // One `half_dim`-wide cos/sin row per token, at consecutive positions.
    let mut cos_table = vec![0.0f32; batch_size * half_dim];
    let mut sin_table = vec![0.0f32; batch_size * half_dim];
    for t in 0..batch_size {
        let pos = pos_start + t;
        let cos_vals = self.rope.cos_at(pos);
        let sin_vals = self.rope.sin_at(pos);
        cos_table[t * half_dim..(t + 1) * half_dim].copy_from_slice(cos_vals);
        sin_table[t * half_dim..(t + 1) * half_dim].copy_from_slice(sin_vals);
    }

    let final_norm_handle = 2_000_000u64;
    let final_norm_bytes = self.output_norm.weight();
    let final_norm_eps = self.output_norm.eps();
    let lm_head_handle = 3_000_000u64;
    let lm_head_bytes = blocks_as_bytes(lm_head_linear.blocks());
    let lm_head_out_features = lm_head_linear.out_features();

    let qkv_concats = self.get_or_build_cuda_qkv_cache()?;
    let layer_params = self.build_cuda_layer_params(&*qkv_concats)?;

    let mut token_ids_out: Vec<u32> = Vec::with_capacity(batch_size);
    for (t, &token_id) in token_ids.iter().enumerate() {
        let pos = pos_start + t;
        // In range: checked by the validation loop above.
        let embd_start = token_id as usize * h;
        let single_hidden = &self.token_embd[embd_start..embd_start + h];
        let cos_single = &cos_table[t * half_dim..(t + 1) * half_dim];
        let sin_single = &sin_table[t * half_dim..(t + 1) * half_dim];
        let mut greedy_id: u32 = 0;
        oxibonsai_kernels::try_cuda_prefill(
            single_hidden,
            1,
            pos,
            n_layers,
            &layer_params,
            cos_single,
            sin_single,
            h,
            inter,
            nq,
            nkv,
            hd,
            heads_per_group,
            eps,
            max_seq_len,
            Some(final_norm_handle),
            Some(final_norm_bytes),
            final_norm_eps,
            Some(lm_head_handle),
            Some(lm_head_bytes),
            lm_head_out_features,
            None,
            Some(&mut greedy_id),
        )
        .map_err(|e| {
            tracing::warn!(error = %e, "CUDA prefill verify dispatch failed at pos {pos}");
            Box::new(e) as Box<dyn std::error::Error>
        })?;
        token_ids_out.push(greedy_id);
    }
    Ok(token_ids_out)
}
/// Runs one decoding step: embeds `token_id`, applies the transformer blocks
/// at position `pos`, then the output norm and LM head, and returns the logits
/// (length `vocab_size`).
///
/// Depending on enabled features, a fused GPU path (Metal or CUDA) is tried
/// first. If the fully fused path fails, a layers-only GPU path is tried, and
/// if that fails too the blocks run through `kernel` on the CPU path.
///
/// # Errors
///
/// * [`ModelError::SequenceTooLong`] when `pos` is not below the KV cache's
///   maximum sequence length.
/// * [`ModelError::MissingTensor`] when `token_id` lies outside the embedding
///   table.
/// * Any error propagated from a block, the output norm, or the 1-bit LM head.
#[tracing::instrument(skip(self, kernel), fields(token_id, pos))]
pub fn forward(
    &mut self,
    token_id: u32,
    pos: usize,
    kernel: &dyn OneBitKernel,
) -> ModelResult<Vec<f32>> {
    let h = self.config.hidden_size;
    let vocab = self.config.vocab_size;

    // The position must fit inside the KV cache.
    let max_ctx = self.kv_cache.max_seq_len();
    if pos >= max_ctx {
        return Err(ModelError::SequenceTooLong {
            seq_len: pos + 1,
            max_ctx,
        });
    }

    // Look up the token's embedding row; it becomes the initial hidden state.
    let row_start = token_id as usize * h;
    let mut hidden = match self.token_embd.get(row_start..row_start + h) {
        Some(row) => row.to_vec(),
        None => {
            return Err(ModelError::MissingTensor {
                name: format!(
                    "token_id {} out of range (vocab={})",
                    token_id,
                    self.token_embd.len() / h
                ),
            });
        }
    };

    let blocks_timer = std::time::Instant::now();

    // Metal: try layers + norm + LM head in a single fused GPU pass.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    {
        let mut fused_logits = vec![0.0f32; vocab];
        let fused = self.try_metal_full_forward_with_lm_head(&mut hidden, pos, &mut fused_logits);
        if fused.is_ok() {
            let fused_elapsed = blocks_timer.elapsed();
            tracing::debug!(
                target: "fwd_profile",
                "pos={pos} fused_gpu={:.1}ms (metal layers+norm+lm_head)",
                fused_elapsed.as_secs_f64() * 1000.0,
            );
            return Ok(fused_logits);
        }
    }

    // CUDA: try layers + norm + LM head in a single fused GPU pass.
    #[cfg(all(
        feature = "native-cuda",
        not(all(feature = "metal", target_os = "macos")),
        any(target_os = "linux", target_os = "windows")
    ))]
    {
        if let Ok(fused_logits) = self.try_cuda_full_forward_with_lm_head(&hidden, pos) {
            return Ok(fused_logits);
        }
    }

    // Layers-only GPU pass; exactly one of the three definitions below is compiled.
    #[cfg(all(feature = "metal", target_os = "macos"))]
    let did_full_forward = self.try_metal_full_forward_inner(&mut hidden, pos).is_ok();
    #[cfg(all(
        feature = "native-cuda",
        not(all(feature = "metal", target_os = "macos")),
        any(target_os = "linux", target_os = "windows")
    ))]
    let did_full_forward =
        if let Ok(new_hidden) = self.try_cuda_full_forward_inner(&hidden, pos) {
            hidden = new_hidden;
            true
        } else {
            false
        };
    #[cfg(not(any(
        all(feature = "metal", target_os = "macos"),
        all(
            feature = "native-cuda",
            not(all(feature = "metal", target_os = "macos")),
            any(target_os = "linux", target_os = "windows")
        )
    )))]
    let did_full_forward = false;

    // No GPU result: run every block in order through the supplied kernel.
    if !did_full_forward {
        for block in &self.blocks {
            block.forward(&mut hidden, pos, &mut self.kv_cache, &self.rope, kernel)?;
        }
    }
    let blocks_elapsed = blocks_timer.elapsed();

    let norm_timer = std::time::Instant::now();
    let mut normed = vec![0.0f32; h];
    self.output_norm.forward(&hidden, &mut normed)?;
    let norm_elapsed = norm_timer.elapsed();

    let lm_timer = std::time::Instant::now();
    let mut logits = vec![0.0f32; vocab];
    match &self.output_weight {
        OutputWeight::OneBit(linear) => {
            linear.forward_vec(&normed, &mut logits, kernel)?;
        }
        OutputWeight::Fp32 {
            weights,
            out_features,
            in_features,
        } => {
            // Dense row-major mat-vec: logit[i] = dot(weights row i, normed).
            let (rows, cols) = (*out_features, *in_features);
            let input = &normed[..cols];
            for (row_idx, logit) in logits.iter_mut().take(rows).enumerate() {
                let row = &weights[row_idx * cols..(row_idx + 1) * cols];
                *logit = row
                    .iter()
                    .zip(input)
                    .fold(0.0f32, |acc, (w, x)| acc + w * x);
            }
        }
    }
    let lm_elapsed = lm_timer.elapsed();

    tracing::debug!(
        target: "fwd_profile",
        "pos={pos} blocks={:.1}ms norm={:.1}ms lm_head={:.1}ms gpu={}",
        blocks_elapsed.as_secs_f64() * 1000.0,
        norm_elapsed.as_secs_f64() * 1000.0,
        lm_elapsed.as_secs_f64() * 1000.0,
        did_full_forward,
    );
    Ok(logits)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A model built from the 8B preset exposes that preset's dimensions.
    #[test]
    fn model_creation() {
        let model = BonsaiModel::new(Qwen3Config::bonsai_8b());
        let cfg = model.config();
        assert_eq!(cfg.num_layers, 36);
        assert_eq!(cfg.hidden_size, 4096);
    }

    /// `new` builds a model with no transformer blocks.
    #[test]
    fn model_new_has_empty_blocks() {
        let model = BonsaiModel::new(Qwen3Config::bonsai_8b());
        assert_eq!(model.blocks.len(), 0);
    }

    /// Each preset config maps to its matching `ModelVariant`.
    #[test]
    fn model_variant_detection() {
        let eight_b = BonsaiModel::new(Qwen3Config::bonsai_8b());
        let four_b = BonsaiModel::new(Qwen3Config::bonsai_4b());
        let one_point_seven_b = BonsaiModel::new(Qwen3Config::bonsai_1_7b());

        assert_eq!(eight_b.variant(), ModelVariant::Bonsai8B);
        assert_eq!(four_b.variant(), ModelVariant::Bonsai4B);
        assert_eq!(one_point_seven_b.variant(), ModelVariant::Bonsai1_7B);
    }

    /// The informational accessors report the 8B preset's values.
    #[test]
    fn model_info_methods() {
        let model = BonsaiModel::new(Qwen3Config::bonsai_8b());

        assert_eq!(model.num_layers(), 36);
        assert_eq!(model.hidden_size(), 4096);
        assert_eq!(model.context_length(), 65536);

        assert!(model.num_parameters() > 0);
        assert!(model.model_size_bytes() > 0);
    }

    /// After `reset_cache` the KV cache reports a sequence length of zero.
    #[test]
    fn model_reset_cache() {
        let mut model = BonsaiModel::new(Qwen3Config::bonsai_8b());
        model.reset_cache();
        let seq_len = model.kv_cache_mut().seq_len();
        assert_eq!(seq_len, 0);
    }

    /// The KV cache of a freshly built model reports a non-zero memory footprint.
    #[test]
    fn model_kv_cache_memory() {
        let model = BonsaiModel::new(Qwen3Config::bonsai_8b());
        let bytes = model.kv_cache_memory_bytes();
        assert!(bytes > 0);
    }
}