use std::sync::Mutex;
use std::time::Instant;
use oxibonsai_kernels::traits::OneBitKernel;
use oxibonsai_kernels::GpuWeightHandle;
/// Views a slice of 1-bit weight blocks as the raw bytes that back it.
///
/// The returned slice borrows from `blocks` and spans exactly
/// `size_of_val(blocks)` bytes; nothing is copied. Only compiled when a GPU
/// backend (Metal, or native CUDA on Linux/Windows) is enabled.
#[cfg(any(
    feature = "metal",
    all(
        feature = "native-cuda",
        any(target_os = "linux", target_os = "windows")
    )
))]
pub(crate) fn blocks_as_bytes(blocks: &[oxibonsai_core::BlockQ1_0G128]) -> &[u8] {
    let byte_len = std::mem::size_of_val(blocks);
    let base = blocks.as_ptr().cast::<u8>();
    // SAFETY: `base` is derived from a live slice, so it is non-null and valid
    // for reads of `byte_len` bytes for as long as `blocks` is borrowed, and
    // `u8` has alignment 1. NOTE(review): this also assumes `BlockQ1_0G128`
    // contains no padding bytes — confirm against its definition.
    unsafe { std::slice::from_raw_parts(base, byte_len) }
}
use crate::error::ModelResult;
use crate::kv_cache::KvCache;
use crate::layers::attention::attention_head;
use crate::layers::linear::Linear1Bit;
use crate::layers::rms_norm::RmsNorm;
use crate::layers::rope::RopeTable;
use crate::layers::sliding_window::SlidingWindowConfig;
use crate::layers::swiglu::swiglu as swiglu_fn;
/// Wall-clock timing breakdown for a single transformer layer.
///
/// All durations are in microseconds.
#[derive(Debug, Clone)]
pub struct LayerStats {
    /// Index of the layer these timings belong to.
    pub layer_idx: usize,
    /// Time attributed to the projection phase.
    pub projection_us: u64,
    /// Time attributed to applying rotary position embeddings.
    pub rope_us: u64,
    /// Time attributed to the attention phase.
    pub attention_us: u64,
    /// Time attributed to the feed-forward phase.
    pub ffn_us: u64,
    /// Total time for the whole layer.
    pub total_us: u64,
}

impl LayerStats {
    /// Creates a stats record for `layer_idx` with every timer at zero.
    fn new(layer_idx: usize) -> Self {
        Self {
            total_us: 0,
            ffn_us: 0,
            attention_us: 0,
            rope_us: 0,
            projection_us: 0,
            layer_idx,
        }
    }

    /// Ratio of `part_us` to `total_us`, or `0.0` when no total was recorded.
    fn share_of_total(&self, part_us: u64) -> f64 {
        match self.total_us {
            0 => 0.0,
            total => part_us as f64 / total as f64,
        }
    }

    /// Fraction of the layer's total time attributed to attention.
    ///
    /// Returns `0.0` when `total_us` is zero.
    pub fn attention_fraction(&self) -> f64 {
        self.share_of_total(self.attention_us)
    }

    /// Fraction of the layer's total time attributed to the feed-forward phase.
    ///
    /// Returns `0.0` when `total_us` is zero.
    pub fn ffn_fraction(&self) -> f64 {
        self.share_of_total(self.ffn_us)
    }
}
/// Pre-allocated intermediate buffers reused across forward passes so the
/// per-token path does not allocate.
struct ScratchBuffers {
    /// Normalised hidden state (`h` elements).
    normed: Vec<f32>,
    /// Query projection output (`nq * hd`).
    q_all: Vec<f32>,
    /// Key projection output (`nkv * hd`).
    k_all: Vec<f32>,
    /// Value projection output (`nkv * hd`).
    v_all: Vec<f32>,
    /// Queries after per-head normalisation (`nq * hd`).
    q_normed: Vec<f32>,
    /// Keys after per-head normalisation (`nkv * hd`).
    k_normed: Vec<f32>,
    /// Queries after RoPE (`nq * hd`).
    q_rope: Vec<f32>,
    /// Keys after RoPE (`nkv * hd`).
    k_rope: Vec<f32>,
    /// Concatenated attention head outputs (`nq * hd`).
    attn_out: Vec<f32>,
    /// Output of a single attention head (`hd`).
    head_output: Vec<f32>,
    /// Attention output projection (`h`).
    attn_proj: Vec<f32>,
    /// FFN gate projection (`inter`).
    gate_out: Vec<f32>,
    /// FFN up projection (`inter`).
    up_out: Vec<f32>,
    /// SwiGLU activation output (`inter`).
    swiglu_out: Vec<f32>,
    /// FFN down projection (`h`).
    down_out: Vec<f32>,
    /// Destination of the fused Q/K/V matvec (`nq * hd + 2 * nkv * hd`).
    fused_qkv: Vec<f32>,
    /// Destination of the fused gate/up matvec (`2 * inter`).
    fused_gate_up: Vec<f32>,
}

impl ScratchBuffers {
    /// Allocates zero-filled buffers for hidden size `h`, `nq` query heads,
    /// `nkv` key/value heads, head dimension `hd` and FFN width `inter`.
    fn new(h: usize, nq: usize, nkv: usize, hd: usize, inter: usize) -> Self {
        let zeros = |len: usize| vec![0.0f32; len];
        let q_dim = nq * hd;
        let kv_dim = nkv * hd;
        Self {
            normed: zeros(h),
            q_all: zeros(q_dim),
            k_all: zeros(kv_dim),
            v_all: zeros(kv_dim),
            q_normed: zeros(q_dim),
            k_normed: zeros(kv_dim),
            q_rope: zeros(q_dim),
            k_rope: zeros(kv_dim),
            attn_out: zeros(q_dim),
            head_output: zeros(hd),
            attn_proj: zeros(h),
            gate_out: zeros(inter),
            up_out: zeros(inter),
            swiglu_out: zeros(inter),
            down_out: zeros(h),
            fused_qkv: zeros(q_dim + kv_dim + kv_dim),
            fused_gate_up: zeros(inter * 2),
        }
    }

    /// Resets every buffer to all zeros without changing any length.
    fn clear(&mut self) {
        for buffer in [
            &mut self.normed,
            &mut self.q_all,
            &mut self.k_all,
            &mut self.v_all,
            &mut self.q_normed,
            &mut self.k_normed,
            &mut self.q_rope,
            &mut self.k_rope,
            &mut self.attn_out,
            &mut self.head_output,
            &mut self.attn_proj,
            &mut self.gate_out,
            &mut self.up_out,
            &mut self.swiglu_out,
            &mut self.down_out,
            &mut self.fused_qkv,
            &mut self.fused_gate_up,
        ] {
            buffer.fill(0.0);
        }
    }
}
/// One transformer layer: pre-norm attention with per-head Q/K normalisation
/// and RoPE, followed by a SwiGLU feed-forward network, each wrapped in a
/// residual connection. All linear projections are 1-bit (`Linear1Bit`).
pub struct TransformerBlock<'a> {
/// Index of this layer; used as the layer key for KV-cache reads/writes and
/// in tracing fields.
layer_idx: usize,
/// RMS norm applied to the hidden state before the Q/K/V projections.
attn_norm: RmsNorm,
/// Query projection.
attn_q: Linear1Bit<'a>,
/// Key projection.
attn_k: Linear1Bit<'a>,
/// Value projection.
attn_v: Linear1Bit<'a>,
/// Projection applied to the concatenated head outputs before the first
/// residual add.
attn_output: Linear1Bit<'a>,
/// RMS norm applied to each query head before RoPE.
attn_q_norm: RmsNorm,
/// RMS norm applied to each key head before RoPE.
attn_k_norm: RmsNorm,
/// RMS norm applied to the hidden state before the feed-forward network.
ffn_norm: RmsNorm,
/// FFN gate projection (its output width sizes the FFN scratch buffers).
ffn_gate: Linear1Bit<'a>,
/// FFN up projection.
ffn_up: Linear1Bit<'a>,
/// FFN down projection.
ffn_down: Linear1Bit<'a>,
/// Number of query heads.
num_heads: usize,
/// Number of key/value heads; query heads are mapped onto them in groups of
/// `num_heads / num_kv_heads`.
num_kv_heads: usize,
/// Width of a single attention head.
head_dim: usize,
/// Width of the hidden state.
hidden_size: usize,
/// Handle for the Q, K and V weights uploaded as one concatenated buffer;
/// `None` until `upload_to_gpu` stores a handle.
fused_qkv_handle: Option<GpuWeightHandle>,
/// Handle for the gate and up weights uploaded as one concatenated buffer;
/// `None` until `upload_to_gpu` stores a handle.
fused_gate_up_handle: Option<GpuWeightHandle>,
/// Reusable intermediate buffers; locked for the duration of each forward
/// call so the methods can take `&self`.
scratch: Mutex<ScratchBuffers>,
}
impl<'a> TransformerBlock<'a> {
/// Assembles a transformer block from its already-constructed sub-layers.
///
/// Scratch buffers are sized from `hidden_size`, `num_heads`, `num_kv_heads`,
/// `head_dim` and the output width of `ffn_gate`. Both fused GPU handles
/// start out as `None`; call `upload_to_gpu` to populate them.
// The constructor mirrors the struct one-to-one, hence the long argument list.
#[allow(clippy::too_many_arguments)]
pub fn new(
    layer_idx: usize,
    attn_norm: RmsNorm,
    attn_q: Linear1Bit<'a>,
    attn_k: Linear1Bit<'a>,
    attn_v: Linear1Bit<'a>,
    attn_output: Linear1Bit<'a>,
    attn_q_norm: RmsNorm,
    attn_k_norm: RmsNorm,
    ffn_norm: RmsNorm,
    ffn_gate: Linear1Bit<'a>,
    ffn_up: Linear1Bit<'a>,
    ffn_down: Linear1Bit<'a>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    hidden_size: usize,
) -> Self {
    // Read the FFN width before `ffn_gate` is moved into the struct below.
    let intermediate_size = ffn_gate.out_features();

    Self {
        layer_idx,
        attn_norm,
        attn_q,
        attn_k,
        attn_v,
        attn_output,
        attn_q_norm,
        attn_k_norm,
        ffn_norm,
        ffn_gate,
        ffn_up,
        ffn_down,
        num_heads,
        num_kv_heads,
        head_dim,
        hidden_size,
        fused_qkv_handle: None,
        fused_gate_up_handle: None,
        scratch: Mutex::new(ScratchBuffers::new(
            hidden_size,
            num_heads,
            num_kv_heads,
            head_dim,
            intermediate_size,
        )),
    }
}
/// Uploads this block's weights through `kernel`.
///
/// Each of the seven projections is uploaded individually, then two fused
/// buffers are uploaded as well: Q‖K‖V and gate‖up (blocks concatenated in
/// that order). The handles returned for the fused buffers are stored on the
/// block; they stay `None` if `upload_weights` returns `None`.
pub fn upload_to_gpu(&mut self, kernel: &dyn OneBitKernel) {
    for projection in [
        &mut self.attn_q,
        &mut self.attn_k,
        &mut self.attn_v,
        &mut self.attn_output,
        &mut self.ffn_gate,
        &mut self.ffn_up,
        &mut self.ffn_down,
    ] {
        projection.upload_to_gpu(kernel);
    }

    // Q, K and V rows back to back so one matvec can produce all three.
    let fused_qkv = [
        self.attn_q.blocks(),
        self.attn_k.blocks(),
        self.attn_v.blocks(),
    ]
    .concat();
    self.fused_qkv_handle = kernel.upload_weights(&fused_qkv);

    // Same idea for the FFN: gate rows followed by up rows.
    let fused_gate_up = [self.ffn_gate.blocks(), self.ffn_up.blocks()].concat();
    self.fused_gate_up_handle = kernel.upload_weights(&fused_gate_up);
}
#[allow(clippy::needless_late_init)]
#[tracing::instrument(skip_all, fields(layer = self.layer_idx))]
pub fn forward(
&self,
hidden: &mut [f32],
pos: usize,
kv_cache: &mut KvCache,
rope: &RopeTable,
kernel: &dyn OneBitKernel,
) -> ModelResult<()> {
#[cfg(all(feature = "metal", target_os = "macos"))]
{
if let Some(Ok(())) = self.try_full_layer_gpu(hidden, pos, rope, kv_cache) {
return Ok(());
}
}
#[cfg(all(
feature = "native-cuda",
not(all(feature = "metal", target_os = "macos")),
any(target_os = "linux", target_os = "windows")
))]
{
if let Some(Ok(())) = self.try_full_layer_cuda(hidden, pos, rope, kv_cache) {
return Ok(());
}
}
let h = self.hidden_size;
let hd = self.head_dim;
let nq = self.num_heads;
let nkv = self.num_kv_heads;
let heads_per_group = nq / nkv;
let total_start = Instant::now();
let mut scratch = self.scratch.lock().map_err(|e| {
crate::error::ModelError::Internal(format!("scratch lock poisoned: {e}"))
})?;
scratch.clear();
let ScratchBuffers {
normed,
q_all,
k_all,
v_all,
q_normed,
k_normed,
q_rope,
k_rope,
attn_out,
head_output,
attn_proj,
gate_out,
up_out,
swiglu_out,
down_out,
fused_qkv,
fused_gate_up,
} = &mut *scratch;
let norm_us: u128;
let qkv_us: u128;
let qknorm_us: u128;
let rope_us: u128;
let cache_us: u128;
let attn_us: u128;
let ffn_us: u128;
{
let norm_start = Instant::now();
self.attn_norm.forward(hidden, normed)?;
norm_us = norm_start.elapsed().as_micros();
let qkv_start = Instant::now();
if let Some(fused_handle) = self.fused_qkv_handle {
let q_rows = nq * hd;
let k_rows = nkv * hd;
let total_rows = q_rows + k_rows + k_rows;
#[cfg(all(feature = "metal", target_os = "macos"))]
let metal_ok = {
let q_bytes = blocks_as_bytes(self.attn_q.blocks());
let k_bytes = blocks_as_bytes(self.attn_k.blocks());
let v_bytes = blocks_as_bytes(self.attn_v.blocks());
oxibonsai_kernels::try_metal_qkv(
normed,
fused_qkv,
fused_handle.id(),
q_bytes,
k_bytes,
v_bytes,
total_rows,
h,
)
.is_ok()
};
#[cfg(not(all(feature = "metal", target_os = "macos")))]
let metal_ok = false;
#[cfg(all(
feature = "native-cuda",
not(all(feature = "metal", target_os = "macos")),
any(target_os = "linux", target_os = "windows")
))]
let cuda_ok = if !metal_ok {
let q_bytes = blocks_as_bytes(self.attn_q.blocks());
let k_bytes = blocks_as_bytes(self.attn_k.blocks());
let v_bytes = blocks_as_bytes(self.attn_v.blocks());
oxibonsai_kernels::try_cuda_qkv(
normed,
fused_qkv,
fused_handle.id(),
q_bytes,
k_bytes,
v_bytes,
total_rows,
h,
)
.is_ok()
} else {
false
};
#[cfg(not(all(
feature = "native-cuda",
not(all(feature = "metal", target_os = "macos")),
any(target_os = "linux", target_os = "windows")
)))]
let cuda_ok = false;
if !metal_ok && !cuda_ok {
kernel.gemv_cached(fused_handle, normed, fused_qkv, total_rows, h)?;
}
q_all[..q_rows].copy_from_slice(&fused_qkv[..q_rows]);
k_all[..k_rows].copy_from_slice(&fused_qkv[q_rows..q_rows + k_rows]);
v_all[..k_rows].copy_from_slice(&fused_qkv[q_rows + k_rows..total_rows]);
} else {
self.attn_q.forward_vec(normed, q_all, kernel)?;
self.attn_k.forward_vec(normed, k_all, kernel)?;
self.attn_v.forward_vec(normed, v_all, kernel)?;
}
qkv_us = qkv_start.elapsed().as_micros();
}
let qknorm_start = Instant::now();
for head in 0..nq {
let start = head * hd;
self.attn_q_norm
.forward(&q_all[start..start + hd], &mut q_normed[start..start + hd])?;
}
for head in 0..nkv {
let start = head * hd;
self.attn_k_norm
.forward(&k_all[start..start + hd], &mut k_normed[start..start + hd])?;
}
qknorm_us = qknorm_start.elapsed().as_micros();
let rope_start = Instant::now();
for head in 0..nq {
let start = head * hd;
rope.apply(
&q_normed[start..start + hd],
&mut q_rope[start..start + hd],
pos,
)?;
}
for head in 0..nkv {
let start = head * hd;
rope.apply(
&k_normed[start..start + hd],
&mut k_rope[start..start + hd],
pos,
)?;
}
rope_us = rope_start.elapsed().as_micros();
let cache_start = Instant::now();
for head in 0..nkv {
let start = head * hd;
kv_cache.store_key(self.layer_idx, head, pos, &k_rope[start..start + hd]);
kv_cache.store_value(self.layer_idx, head, pos, &v_all[start..start + hd]);
}
cache_us = cache_start.elapsed().as_micros();
let seq_len = pos + 1;
let attn_start = Instant::now();
for q_head in 0..nq {
let kv_head = q_head / heads_per_group;
let q_start = q_head * hd;
let keys = kv_cache.keys_for(self.layer_idx, kv_head, seq_len);
let values = kv_cache.values_for(self.layer_idx, kv_head, seq_len);
attention_head(
&q_rope[q_start..q_start + hd],
keys,
values,
head_output,
seq_len,
hd,
)?;
attn_out[q_start..q_start + hd].copy_from_slice(head_output);
}
attn_us = attn_start.elapsed().as_micros();
let ffn_start = Instant::now();
let did_batch_ffn = if let (
Some(attn_proj_handle),
Some(gate_up_handle),
Some(down_handle),
) = (
self.attn_output.gpu_handle(),
self.fused_gate_up_handle,
self.ffn_down.gpu_handle(),
) {
let inter = self.ffn_gate.out_features();
#[cfg(all(feature = "metal", target_os = "macos"))]
{
let attn_proj_blocks = self.attn_output.blocks();
let gate_blocks = self.ffn_gate.blocks();
let up_blocks = self.ffn_up.blocks();
let down_blocks = self.ffn_down.blocks();
let attn_proj_bytes = blocks_as_bytes(attn_proj_blocks);
let gate_bytes = blocks_as_bytes(gate_blocks);
let up_bytes = blocks_as_bytes(up_blocks);
let down_bytes = blocks_as_bytes(down_blocks);
let metal_result = oxibonsai_kernels::try_metal_ffn(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle.id(),
attn_proj_bytes,
gate_up_handle.id(),
gate_bytes,
up_bytes,
down_handle.id(),
down_bytes,
h,
inter,
);
if metal_result.is_ok() {
true
} else {
tracing::warn!(error = ?metal_result.err(), "MetalGraph FFN failed, falling back");
kernel.batch_ffn_phase(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle,
gate_up_handle,
down_handle,
h,
inter,
nq * hd,
)?
}
}
#[cfg(all(
feature = "native-cuda",
not(all(feature = "metal", target_os = "macos")),
any(target_os = "linux", target_os = "windows")
))]
{
let attn_proj_blocks = self.attn_output.blocks();
let gate_blocks = self.ffn_gate.blocks();
let up_blocks = self.ffn_up.blocks();
let down_blocks = self.ffn_down.blocks();
let attn_proj_bytes = blocks_as_bytes(attn_proj_blocks);
let gate_bytes = blocks_as_bytes(gate_blocks);
let up_bytes = blocks_as_bytes(up_blocks);
let down_bytes = blocks_as_bytes(down_blocks);
let cuda_result = oxibonsai_kernels::try_cuda_ffn(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle.id(),
attn_proj_bytes,
gate_up_handle.id(),
gate_bytes,
up_bytes,
down_handle.id(),
down_bytes,
h,
inter,
);
if cuda_result.is_ok() {
true
} else {
tracing::warn!(error = ?cuda_result.err(), "CudaGraph FFN failed, falling back");
kernel.batch_ffn_phase(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle,
gate_up_handle,
down_handle,
h,
inter,
nq * hd,
)?
}
}
#[cfg(not(any(
all(feature = "metal", target_os = "macos"),
all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
)
)))]
{
kernel.batch_ffn_phase(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle,
gate_up_handle,
down_handle,
h,
inter,
nq * hd,
)?
}
} else {
false
};
if !did_batch_ffn {
self.attn_output.forward_vec(attn_out, attn_proj, kernel)?;
for i in 0..h {
hidden[i] += attn_proj[i];
}
self.ffn_norm.forward(hidden, normed)?;
if let Some(fused_handle) = self.fused_gate_up_handle {
let inter = gate_out.len();
let total_rows = inter * 2;
kernel.gemv_cached(fused_handle, normed, fused_gate_up, total_rows, h)?;
gate_out[..inter].copy_from_slice(&fused_gate_up[..inter]);
up_out[..inter].copy_from_slice(&fused_gate_up[inter..total_rows]);
} else {
self.ffn_gate.forward_vec(normed, gate_out, kernel)?;
self.ffn_up.forward_vec(normed, up_out, kernel)?;
}
swiglu_fn(gate_out, up_out, swiglu_out);
self.ffn_down.forward_vec(swiglu_out, down_out, kernel)?;
for i in 0..h {
hidden[i] += down_out[i];
}
}
ffn_us = ffn_start.elapsed().as_micros();
let total_us = total_start.elapsed().as_micros();
tracing::debug!(
target: "block_profile",
"L{layer}: norm={norm_us}µs qkv={qkv_us}µs qknorm={qknorm_us}µs rope={rope_us}µs cache={cache_us}µs attn={attn_us}µs ffn={ffn_us}µs total={total_us}µs",
layer = self.layer_idx,
);
Ok(())
}
#[tracing::instrument(skip_all, fields(layer = self.layer_idx))]
pub fn forward_with_stats(
&self,
hidden: &mut [f32],
pos: usize,
kv_cache: &mut KvCache,
rope: &RopeTable,
kernel: &dyn OneBitKernel,
) -> ModelResult<LayerStats> {
let total_start = Instant::now();
let mut stats = LayerStats::new(self.layer_idx);
let h = self.hidden_size;
let hd = self.head_dim;
let nq = self.num_heads;
let nkv = self.num_kv_heads;
let heads_per_group = nq / nkv;
let mut scratch = self.scratch.lock().map_err(|e| {
crate::error::ModelError::Internal(format!("scratch lock poisoned: {e}"))
})?;
scratch.clear();
let ScratchBuffers {
normed,
q_all,
k_all,
v_all,
q_normed,
k_normed,
q_rope,
k_rope,
attn_out,
head_output,
attn_proj,
gate_out,
up_out,
swiglu_out,
down_out,
fused_qkv,
fused_gate_up,
} = &mut *scratch;
let proj_start = Instant::now();
let batch_qkv = if let Some(fused_handle) = self.fused_qkv_handle {
kernel.batch_attn_phase(
hidden,
self.attn_norm.weight(),
self.attn_norm.eps(),
fused_handle,
nq * hd,
nkv * hd,
h,
)?
} else {
None
};
if let Some((q_data, k_data, v_data)) = batch_qkv {
q_all[..nq * hd].copy_from_slice(&q_data);
k_all[..nkv * hd].copy_from_slice(&k_data);
v_all[..nkv * hd].copy_from_slice(&v_data);
} else {
self.attn_norm.forward(hidden, normed)?;
if let Some(fused_handle) = self.fused_qkv_handle {
let q_rows = nq * hd;
let k_rows = nkv * hd;
let total_rows = q_rows + k_rows + k_rows;
#[cfg(all(feature = "metal", target_os = "macos"))]
let metal_ok = {
let q_bytes = blocks_as_bytes(self.attn_q.blocks());
let k_bytes = blocks_as_bytes(self.attn_k.blocks());
let v_bytes = blocks_as_bytes(self.attn_v.blocks());
oxibonsai_kernels::try_metal_qkv(
normed,
fused_qkv,
fused_handle.id(),
q_bytes,
k_bytes,
v_bytes,
total_rows,
h,
)
.is_ok()
};
#[cfg(not(all(feature = "metal", target_os = "macos")))]
let metal_ok = false;
if !metal_ok {
kernel.gemv_cached(fused_handle, normed, fused_qkv, total_rows, h)?;
}
q_all[..q_rows].copy_from_slice(&fused_qkv[..q_rows]);
k_all[..k_rows].copy_from_slice(&fused_qkv[q_rows..q_rows + k_rows]);
v_all[..k_rows].copy_from_slice(&fused_qkv[q_rows + k_rows..total_rows]);
} else {
self.attn_q.forward_vec(normed, q_all, kernel)?;
self.attn_k.forward_vec(normed, k_all, kernel)?;
self.attn_v.forward_vec(normed, v_all, kernel)?;
}
}
for head in 0..nq {
let start = head * hd;
self.attn_q_norm
.forward(&q_all[start..start + hd], &mut q_normed[start..start + hd])?;
}
for head in 0..nkv {
let start = head * hd;
self.attn_k_norm
.forward(&k_all[start..start + hd], &mut k_normed[start..start + hd])?;
}
stats.projection_us = proj_start.elapsed().as_micros() as u64;
let rope_start = Instant::now();
for head in 0..nq {
let start = head * hd;
rope.apply(
&q_normed[start..start + hd],
&mut q_rope[start..start + hd],
pos,
)?;
}
for head in 0..nkv {
let start = head * hd;
rope.apply(
&k_normed[start..start + hd],
&mut k_rope[start..start + hd],
pos,
)?;
}
stats.rope_us = rope_start.elapsed().as_micros() as u64;
let attn_start = Instant::now();
for head in 0..nkv {
let start = head * hd;
kv_cache.store_key(self.layer_idx, head, pos, &k_rope[start..start + hd]);
kv_cache.store_value(self.layer_idx, head, pos, &v_all[start..start + hd]);
}
let seq_len = pos + 1;
for q_head in 0..nq {
let kv_head = q_head / heads_per_group;
let q_start = q_head * hd;
let keys = kv_cache.keys_for(self.layer_idx, kv_head, seq_len);
let values = kv_cache.values_for(self.layer_idx, kv_head, seq_len);
attention_head(
&q_rope[q_start..q_start + hd],
keys,
values,
head_output,
seq_len,
hd,
)?;
attn_out[q_start..q_start + hd].copy_from_slice(head_output);
}
let did_batch_ffn =
if let (Some(attn_proj_handle), Some(gate_up_handle), Some(down_handle)) = (
self.attn_output.gpu_handle(),
self.fused_gate_up_handle,
self.ffn_down.gpu_handle(),
) {
let inter = self.ffn_gate.out_features();
#[cfg(all(feature = "metal", target_os = "macos"))]
{
let attn_proj_blocks = self.attn_output.blocks();
let gate_blocks = self.ffn_gate.blocks();
let up_blocks = self.ffn_up.blocks();
let down_blocks = self.ffn_down.blocks();
let attn_proj_bytes = blocks_as_bytes(attn_proj_blocks);
let gate_bytes = blocks_as_bytes(gate_blocks);
let up_bytes = blocks_as_bytes(up_blocks);
let down_bytes = blocks_as_bytes(down_blocks);
let metal_result = oxibonsai_kernels::try_metal_ffn(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle.id(),
attn_proj_bytes,
gate_up_handle.id(),
gate_bytes,
up_bytes,
down_handle.id(),
down_bytes,
h,
inter,
);
if metal_result.is_ok() {
true
} else {
kernel.batch_ffn_phase(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle,
gate_up_handle,
down_handle,
h,
inter,
nq * hd,
)?
}
}
#[cfg(not(all(feature = "metal", target_os = "macos")))]
{
kernel.batch_ffn_phase(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle,
gate_up_handle,
down_handle,
h,
inter,
nq * hd,
)?
}
} else {
false
};
if !did_batch_ffn {
self.attn_output.forward_vec(attn_out, attn_proj, kernel)?;
for i in 0..h {
hidden[i] += attn_proj[i];
}
}
stats.attention_us = attn_start.elapsed().as_micros() as u64;
let ffn_start = Instant::now();
if !did_batch_ffn {
self.ffn_norm.forward(hidden, normed)?;
if let Some(fused_handle) = self.fused_gate_up_handle {
let inter = gate_out.len();
let total_rows = inter * 2;
kernel.gemv_cached(fused_handle, normed, fused_gate_up, total_rows, h)?;
gate_out[..inter].copy_from_slice(&fused_gate_up[..inter]);
up_out[..inter].copy_from_slice(&fused_gate_up[inter..total_rows]);
} else {
self.ffn_gate.forward_vec(normed, gate_out, kernel)?;
self.ffn_up.forward_vec(normed, up_out, kernel)?;
}
swiglu_fn(gate_out, up_out, swiglu_out);
self.ffn_down.forward_vec(swiglu_out, down_out, kernel)?;
for i in 0..h {
hidden[i] += down_out[i];
}
}
stats.ffn_us = ffn_start.elapsed().as_micros() as u64;
stats.total_us = total_start.elapsed().as_micros() as u64;
Ok(stats)
}
#[tracing::instrument(skip_all, fields(layer = self.layer_idx))]
pub fn forward_with_sliding_window(
&self,
hidden: &mut [f32],
pos: usize,
kv_cache: &mut KvCache,
rope: &RopeTable,
kernel: &dyn OneBitKernel,
sliding_window: Option<&SlidingWindowConfig>,
) -> ModelResult<()> {
let h = self.hidden_size;
let hd = self.head_dim;
let nq = self.num_heads;
let nkv = self.num_kv_heads;
let heads_per_group = nq / nkv;
let mut scratch = self.scratch.lock().map_err(|e| {
crate::error::ModelError::Internal(format!("scratch lock poisoned: {e}"))
})?;
scratch.clear();
let ScratchBuffers {
normed,
q_all,
k_all,
v_all,
q_normed,
k_normed,
q_rope,
k_rope,
attn_out,
head_output,
attn_proj,
gate_out,
up_out,
swiglu_out,
down_out,
fused_qkv,
fused_gate_up,
} = &mut *scratch;
let batch_qkv = if let Some(fused_handle) = self.fused_qkv_handle {
kernel.batch_attn_phase(
hidden,
self.attn_norm.weight(),
self.attn_norm.eps(),
fused_handle,
nq * hd,
nkv * hd,
h,
)?
} else {
None
};
if let Some((q_data, k_data, v_data)) = batch_qkv {
q_all[..nq * hd].copy_from_slice(&q_data);
k_all[..nkv * hd].copy_from_slice(&k_data);
v_all[..nkv * hd].copy_from_slice(&v_data);
} else {
self.attn_norm.forward(hidden, normed)?;
if let Some(fused_handle) = self.fused_qkv_handle {
let q_rows = nq * hd;
let k_rows = nkv * hd;
let total_rows = q_rows + k_rows + k_rows;
#[cfg(all(feature = "metal", target_os = "macos"))]
let metal_ok = {
let q_bytes = blocks_as_bytes(self.attn_q.blocks());
let k_bytes = blocks_as_bytes(self.attn_k.blocks());
let v_bytes = blocks_as_bytes(self.attn_v.blocks());
oxibonsai_kernels::try_metal_qkv(
normed,
fused_qkv,
fused_handle.id(),
q_bytes,
k_bytes,
v_bytes,
total_rows,
h,
)
.is_ok()
};
#[cfg(not(all(feature = "metal", target_os = "macos")))]
let metal_ok = false;
if !metal_ok {
kernel.gemv_cached(fused_handle, normed, fused_qkv, total_rows, h)?;
}
q_all[..q_rows].copy_from_slice(&fused_qkv[..q_rows]);
k_all[..k_rows].copy_from_slice(&fused_qkv[q_rows..q_rows + k_rows]);
v_all[..k_rows].copy_from_slice(&fused_qkv[q_rows + k_rows..total_rows]);
} else {
self.attn_q.forward_vec(normed, q_all, kernel)?;
self.attn_k.forward_vec(normed, k_all, kernel)?;
self.attn_v.forward_vec(normed, v_all, kernel)?;
}
}
for head in 0..nq {
let start = head * hd;
self.attn_q_norm
.forward(&q_all[start..start + hd], &mut q_normed[start..start + hd])?;
}
for head in 0..nkv {
let start = head * hd;
self.attn_k_norm
.forward(&k_all[start..start + hd], &mut k_normed[start..start + hd])?;
}
for head in 0..nq {
let start = head * hd;
rope.apply(
&q_normed[start..start + hd],
&mut q_rope[start..start + hd],
pos,
)?;
}
for head in 0..nkv {
let start = head * hd;
rope.apply(
&k_normed[start..start + hd],
&mut k_rope[start..start + hd],
pos,
)?;
}
for head in 0..nkv {
let start = head * hd;
kv_cache.store_key(self.layer_idx, head, pos, &k_rope[start..start + hd]);
kv_cache.store_value(self.layer_idx, head, pos, &v_all[start..start + hd]);
}
let full_seq_len = pos + 1;
if let Some(sw_config) = sliding_window {
let (positions, _count) =
crate::layers::sliding_window::attention_range(pos, full_seq_len, sw_config);
for q_head in 0..nq {
let kv_head = q_head / heads_per_group;
let q_start = q_head * hd;
let all_keys = kv_cache.keys_for(self.layer_idx, kv_head, full_seq_len);
let all_values = kv_cache.values_for(self.layer_idx, kv_head, full_seq_len);
let windowed_keys: Vec<f32> = positions
.iter()
.flat_map(|&p| all_keys[p * hd..(p + 1) * hd].iter().copied())
.collect();
let windowed_values: Vec<f32> = positions
.iter()
.flat_map(|&p| all_values[p * hd..(p + 1) * hd].iter().copied())
.collect();
attention_head(
&q_rope[q_start..q_start + hd],
&windowed_keys,
&windowed_values,
head_output,
positions.len(),
hd,
)?;
attn_out[q_start..q_start + hd].copy_from_slice(head_output);
}
} else {
for q_head in 0..nq {
let kv_head = q_head / heads_per_group;
let q_start = q_head * hd;
let keys = kv_cache.keys_for(self.layer_idx, kv_head, full_seq_len);
let values = kv_cache.values_for(self.layer_idx, kv_head, full_seq_len);
attention_head(
&q_rope[q_start..q_start + hd],
keys,
values,
head_output,
full_seq_len,
hd,
)?;
attn_out[q_start..q_start + hd].copy_from_slice(head_output);
}
}
let did_batch_ffn =
if let (Some(attn_proj_handle), Some(gate_up_handle), Some(down_handle)) = (
self.attn_output.gpu_handle(),
self.fused_gate_up_handle,
self.ffn_down.gpu_handle(),
) {
let inter = self.ffn_gate.out_features();
#[cfg(all(feature = "metal", target_os = "macos"))]
{
let attn_proj_blocks = self.attn_output.blocks();
let gate_blocks = self.ffn_gate.blocks();
let up_blocks = self.ffn_up.blocks();
let down_blocks = self.ffn_down.blocks();
let attn_proj_bytes = blocks_as_bytes(attn_proj_blocks);
let gate_bytes = blocks_as_bytes(gate_blocks);
let up_bytes = blocks_as_bytes(up_blocks);
let down_bytes = blocks_as_bytes(down_blocks);
let metal_result = oxibonsai_kernels::try_metal_ffn(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle.id(),
attn_proj_bytes,
gate_up_handle.id(),
gate_bytes,
up_bytes,
down_handle.id(),
down_bytes,
h,
inter,
);
if metal_result.is_ok() {
true
} else {
kernel.batch_ffn_phase(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle,
gate_up_handle,
down_handle,
h,
inter,
nq * hd,
)?
}
}
#[cfg(not(all(feature = "metal", target_os = "macos")))]
{
kernel.batch_ffn_phase(
hidden,
attn_out,
self.ffn_norm.weight(),
self.ffn_norm.eps(),
attn_proj_handle,
gate_up_handle,
down_handle,
h,
inter,
nq * hd,
)?
}
} else {
false
};
if !did_batch_ffn {
self.attn_output.forward_vec(attn_out, attn_proj, kernel)?;
for i in 0..h {
hidden[i] += attn_proj[i];
}
self.ffn_norm.forward(hidden, normed)?;
if let Some(fused_handle) = self.fused_gate_up_handle {
let inter = gate_out.len();
let total_rows = inter * 2;
kernel.gemv_cached(fused_handle, normed, fused_gate_up, total_rows, h)?;
gate_out[..inter].copy_from_slice(&fused_gate_up[..inter]);
up_out[..inter].copy_from_slice(&fused_gate_up[inter..total_rows]);
} else {
self.ffn_gate.forward_vec(normed, gate_out, kernel)?;
self.ffn_up.forward_vec(normed, up_out, kernel)?;
}
swiglu_fn(gate_out, up_out, swiglu_out);
self.ffn_down.forward_vec(swiglu_out, down_out, kernel)?;
for i in 0..h {
hidden[i] += down_out[i];
}
}
Ok(())
}
/// Returns the index of this layer, as supplied to `new`.
pub fn layer_idx(&self) -> usize {
self.layer_idx
}
/// Attempts to run the whole layer in a single Metal dispatch.
///
/// Returns `None` when any required GPU handle (fused QKV, attention output,
/// fused gate/up, FFN down) is missing, `Some(Ok(()))` when the dispatch
/// succeeded, and `Some(Err(_))` when it was attempted but failed. `forward`
/// treats anything other than `Some(Ok(()))` as a cue to continue with its
/// per-stage path.
#[cfg(all(feature = "metal", target_os = "macos"))]
fn try_full_layer_gpu(
    &self,
    hidden: &mut [f32],
    pos: usize,
    rope: &RopeTable,
    kv_cache: &KvCache,
) -> Option<ModelResult<()>> {
    // Bail out early unless every weight this path needs has a handle.
    let qkv_handle = self.fused_qkv_handle?;
    let out_proj_handle = self.attn_output.gpu_handle()?;
    let gate_up_handle = self.fused_gate_up_handle?;
    let ffn_down_handle = self.ffn_down.gpu_handle()?;

    let hidden_size = self.hidden_size;
    let head_dim = self.head_dim;
    let q_heads = self.num_heads;
    let kv_heads = self.num_kv_heads;
    let intermediate_size = self.ffn_gate.out_features();
    let norm_eps = self.attn_norm.eps();
    let layer_count = kv_cache.num_layers();
    let cache_capacity = kv_cache.max_seq_len();

    // Norm weights are identified by synthetic ids: a fixed base, ten slots
    // per layer, and a fixed offset per norm.
    let norm_id_base = 1_000_000u64 + (self.layer_idx as u64) * 10;
    let attn_norm_id = norm_id_base;
    let q_norm_id = norm_id_base + 1;
    let k_norm_id = norm_id_base + 2;
    let ffn_norm_id = norm_id_base + 3;

    // Raw bytes of Q, K and V back to back, matching the fused upload order.
    let qkv_bytes = [
        blocks_as_bytes(self.attn_q.blocks()),
        blocks_as_bytes(self.attn_k.blocks()),
        blocks_as_bytes(self.attn_v.blocks()),
    ]
    .concat();
    let out_proj_bytes = blocks_as_bytes(self.attn_output.blocks());
    let gate_bytes = blocks_as_bytes(self.ffn_gate.blocks());
    let up_bytes = blocks_as_bytes(self.ffn_up.blocks());
    let down_bytes = blocks_as_bytes(self.ffn_down.blocks());
    let cos_row = rope.cos_at(pos);
    let sin_row = rope.sin_at(pos);

    let dispatch = oxibonsai_kernels::try_metal_full_layer(
        hidden,
        pos,
        self.layer_idx,
        attn_norm_id,
        self.attn_norm.weight(),
        qkv_handle.id(),
        &qkv_bytes,
        q_norm_id,
        self.attn_q_norm.weight(),
        k_norm_id,
        self.attn_k_norm.weight(),
        out_proj_handle.id(),
        out_proj_bytes,
        ffn_norm_id,
        self.ffn_norm.weight(),
        gate_up_handle.id(),
        gate_bytes,
        up_bytes,
        ffn_down_handle.id(),
        down_bytes,
        cos_row,
        sin_row,
        hidden_size,
        intermediate_size,
        q_heads,
        kv_heads,
        head_dim,
        norm_eps,
        cache_capacity,
        layer_count,
    );

    let outcome = dispatch
        .map(|()| {
            tracing::debug!(
                target: "block_profile",
                "L{layer}: full-layer GPU dispatch OK",
                layer = self.layer_idx,
            );
        })
        .map_err(|e| {
            tracing::warn!(
                layer = self.layer_idx,
                error = %e,
                "full-layer GPU dispatch failed, falling back to CPU path"
            );
            crate::error::ModelError::Internal(format!(
                "Metal full-layer dispatch failed: {e}"
            ))
        });
    Some(outcome)
}
/// Attempts to run the whole layer in a single CUDA dispatch.
///
/// Returns `None` when any required GPU handle (fused QKV, attention output,
/// fused gate/up, FFN down) is missing, `Some(Ok(()))` when the dispatch
/// succeeded, and `Some(Err(_))` when it was attempted but failed. `forward`
/// treats anything other than `Some(Ok(()))` as a cue to continue with its
/// per-stage path.
#[cfg(all(
    feature = "native-cuda",
    not(all(feature = "metal", target_os = "macos")),
    any(target_os = "linux", target_os = "windows")
))]
fn try_full_layer_cuda(
    &self,
    hidden: &mut [f32],
    pos: usize,
    rope: &RopeTable,
    kv_cache: &KvCache,
) -> Option<ModelResult<()>> {
    // Bail out early unless every weight this path needs has a handle.
    let qkv_handle = self.fused_qkv_handle?;
    let out_proj_handle = self.attn_output.gpu_handle()?;
    let gate_up_handle = self.fused_gate_up_handle?;
    let ffn_down_handle = self.ffn_down.gpu_handle()?;

    let hidden_size = self.hidden_size;
    let head_dim = self.head_dim;
    let q_heads = self.num_heads;
    let kv_heads = self.num_kv_heads;
    let group_size = q_heads / kv_heads;
    let intermediate_size = self.ffn_gate.out_features();
    let norm_eps = self.attn_norm.eps();
    let layer_count = kv_cache.num_layers();
    let cache_capacity = kv_cache.max_seq_len();

    // Norm weights are identified by synthetic ids: a fixed base, ten slots
    // per layer, and a fixed offset per norm.
    let norm_id_base = 2_000_000u64 + (self.layer_idx as u64) * 10;
    let attn_norm_id = norm_id_base;
    let q_norm_id = norm_id_base + 1;
    let k_norm_id = norm_id_base + 2;
    let ffn_norm_id = norm_id_base + 3;

    // Raw bytes of Q, K and V back to back, matching the fused upload order.
    let qkv_bytes = [
        blocks_as_bytes(self.attn_q.blocks()),
        blocks_as_bytes(self.attn_k.blocks()),
        blocks_as_bytes(self.attn_v.blocks()),
    ]
    .concat();
    let out_proj_bytes = blocks_as_bytes(self.attn_output.blocks());
    let gate_bytes = blocks_as_bytes(self.ffn_gate.blocks());
    let up_bytes = blocks_as_bytes(self.ffn_up.blocks());
    let down_bytes = blocks_as_bytes(self.ffn_down.blocks());
    let cos_row = rope.cos_at(pos);
    let sin_row = rope.sin_at(pos);

    // Note: the output-projection arguments precede the Q/K norm arguments
    // here, unlike the Metal entry point.
    let dispatch = oxibonsai_kernels::try_cuda_full_layer(
        hidden,
        pos,
        self.layer_idx,
        attn_norm_id,
        self.attn_norm.weight(),
        qkv_handle.id(),
        &qkv_bytes,
        out_proj_handle.id(),
        out_proj_bytes,
        q_norm_id,
        self.attn_q_norm.weight(),
        k_norm_id,
        self.attn_k_norm.weight(),
        ffn_norm_id,
        self.ffn_norm.weight(),
        gate_up_handle.id(),
        gate_bytes,
        up_bytes,
        ffn_down_handle.id(),
        down_bytes,
        cos_row,
        sin_row,
        hidden_size,
        intermediate_size,
        q_heads,
        kv_heads,
        head_dim,
        group_size,
        norm_eps,
        cache_capacity,
        layer_count,
    );

    let outcome = dispatch
        .map(|()| {
            tracing::debug!(
                target: "block_profile",
                "L{layer}: CUDA full-layer dispatch OK",
                layer = self.layer_idx,
            );
        })
        .map_err(|e| {
            tracing::warn!(
                layer = self.layer_idx,
                error = %e,
                "CUDA full-layer dispatch failed, falling back to CPU path"
            );
            crate::error::ModelError::Internal(format!(
                "CUDA full-layer dispatch failed: {e}"
            ))
        });
    Some(outcome)
}
/// Weight vector of the attention pre-norm.
pub fn attn_norm_weight(&self) -> &[f32] {
self.attn_norm.weight()
}
/// Epsilon of the attention pre-norm.
pub fn attn_norm_eps(&self) -> f32 {
self.attn_norm.eps()
}
/// Weight vector of the per-head query norm.
pub fn q_norm_weight(&self) -> &[f32] {
self.attn_q_norm.weight()
}
/// Weight vector of the per-head key norm.
pub fn k_norm_weight(&self) -> &[f32] {
self.attn_k_norm.weight()
}
/// Weight vector of the feed-forward pre-norm.
pub fn ffn_norm_weight(&self) -> &[f32] {
self.ffn_norm.weight()
}
/// Returns the index of this layer (same value as `layer_idx`).
pub fn layer_index(&self) -> usize {
self.layer_idx
}
/// Handle of the fused Q/K/V weight buffer, or `None` if none has been stored
/// by `upload_to_gpu`.
pub fn fused_qkv_gpu_handle(&self) -> Option<GpuWeightHandle> {
self.fused_qkv_handle
}
/// GPU handle reported by the attention output projection, if any.
pub fn attn_output_gpu_handle(&self) -> Option<GpuWeightHandle> {
self.attn_output.gpu_handle()
}
/// Handle of the fused gate/up weight buffer, or `None` if none has been
/// stored by `upload_to_gpu`.
pub fn fused_gate_up_gpu_handle(&self) -> Option<GpuWeightHandle> {
self.fused_gate_up_handle
}
/// GPU handle reported by the FFN down projection, if any.
pub fn ffn_down_gpu_handle(&self) -> Option<GpuWeightHandle> {
self.ffn_down.gpu_handle()
}
/// Weight blocks of the query projection.
pub fn attn_q_blocks(&self) -> &[oxibonsai_core::BlockQ1_0G128] {
self.attn_q.blocks()
}
/// Weight blocks of the key projection.
pub fn attn_k_blocks(&self) -> &[oxibonsai_core::BlockQ1_0G128] {
self.attn_k.blocks()
}
/// Weight blocks of the value projection.
pub fn attn_v_blocks(&self) -> &[oxibonsai_core::BlockQ1_0G128] {
self.attn_v.blocks()
}
/// Weight blocks of the attention output projection.
pub fn attn_output_blocks(&self) -> &[oxibonsai_core::BlockQ1_0G128] {
self.attn_output.blocks()
}
/// Weight blocks of the FFN gate projection.
pub fn ffn_gate_blocks(&self) -> &[oxibonsai_core::BlockQ1_0G128] {
self.ffn_gate.blocks()
}
/// Weight blocks of the FFN up projection.
pub fn ffn_up_blocks(&self) -> &[oxibonsai_core::BlockQ1_0G128] {
self.ffn_up.blocks()
}
/// Weight blocks of the FFN down projection.
pub fn ffn_down_blocks(&self) -> &[oxibonsai_core::BlockQ1_0G128] {
self.ffn_down.blocks()
}
/// Output width of the FFN gate projection (the FFN intermediate size).
pub fn ffn_gate_out_features(&self) -> usize {
self.ffn_gate.out_features()
}
}
#[cfg(test)]
mod tests {
    //! Smoke tests for the `TransformerBlock` forward entry points and the
    //! `LayerStats` fraction helpers.
    //!
    //! All forward-pass tests share one tiny block configuration, built by
    //! [`with_test_block`], so the fixture lives in exactly one place.
    use super::*;
    use half::f16;
    use oxibonsai_core::tensor::BlockQ1_0G128;

    /// Hidden size of the toy block.
    const HIDDEN: usize = 128;
    /// Per-head dimension.
    const HEAD_DIM: usize = 64;
    /// Number of query heads.
    const N_Q_HEADS: usize = 2;
    /// Number of key/value heads.
    const N_KV_HEADS: usize = 1;
    /// FFN intermediate size.
    const INTERMEDIATE: usize = 256;
    /// Input features covered by one quantized block (16 `qs` bytes = 128 bits).
    const WEIGHTS_PER_BLOCK: usize = 128;

    /// Builds `n` identical blocks whose scale `d` is `scale` (as f16) and whose
    /// 16 `qs` bytes are all `pattern`.
    fn make_blocks(n: usize, scale: f32, pattern: u8) -> Vec<BlockQ1_0G128> {
        (0..n)
            .map(|_| BlockQ1_0G128 {
                d: f16::from_f32(scale),
                qs: [pattern; 16],
            })
            .collect()
    }

    /// Deterministic ramp input: element `i` is `(i + 1) * 0.01`.
    fn ramp_hidden() -> Vec<f32> {
        (0..HIDDEN).map(|i| (i as f32 + 1.0) * 0.01).collect()
    }

    /// Largest element-wise absolute difference between `a` and `b`
    /// (over the shorter of the two slices).
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).abs())
            .fold(0.0f32, f32::max)
    }

    /// Builds the shared layer-0 test block and hands it to `run`.
    ///
    /// The linear layers are constructed from references to the weight buffers,
    /// so the buffers are kept alive on this stack frame for as long as the
    /// closure runs; the closure form avoids having to return the block together
    /// with the buffers it was built from.
    fn with_test_block<R>(run: impl FnOnce(&TransformerBlock) -> R) -> R {
        let blocks_per_row = HIDDEN / WEIGHTS_PER_BLOCK;
        let q_out = N_Q_HEADS * HEAD_DIM;
        let kv_out = N_KV_HEADS * HEAD_DIM;

        let q_blocks = make_blocks(q_out * blocks_per_row, 0.01, 0xFF);
        let k_blocks = make_blocks(kv_out * blocks_per_row, 0.01, 0xFF);
        let v_blocks = make_blocks(kv_out * blocks_per_row, 0.01, 0xFF);
        let o_blocks = make_blocks(HIDDEN * blocks_per_row, 0.01, 0xFF);
        let gate_blocks = make_blocks(INTERMEDIATE * blocks_per_row, 0.01, 0xFF);
        let up_blocks = make_blocks(INTERMEDIATE * blocks_per_row, 0.01, 0xFF);
        // The down projection reads `INTERMEDIATE` inputs, so each of its rows
        // spans `INTERMEDIATE / WEIGHTS_PER_BLOCK` blocks.
        let down_blocks = make_blocks(HIDDEN * (INTERMEDIATE / WEIGHTS_PER_BLOCK), 0.01, 0xFF);

        let block = TransformerBlock::new(
            0,
            RmsNorm::new(vec![1.0; HIDDEN], 1e-6),
            Linear1Bit::new(&q_blocks, q_out, HIDDEN),
            Linear1Bit::new(&k_blocks, kv_out, HIDDEN),
            Linear1Bit::new(&v_blocks, kv_out, HIDDEN),
            Linear1Bit::new(&o_blocks, HIDDEN, q_out),
            RmsNorm::new(vec![1.0; HEAD_DIM], 1e-6),
            RmsNorm::new(vec![1.0; HEAD_DIM], 1e-6),
            RmsNorm::new(vec![1.0; HIDDEN], 1e-6),
            Linear1Bit::new(&gate_blocks, INTERMEDIATE, HIDDEN),
            Linear1Bit::new(&up_blocks, INTERMEDIATE, HIDDEN),
            Linear1Bit::new(&down_blocks, HIDDEN, INTERMEDIATE),
            N_Q_HEADS,
            N_KV_HEADS,
            HEAD_DIM,
            HIDDEN,
        );
        run(&block)
    }

    /// `forward` at position 0 must succeed and change the hidden state.
    #[test]
    fn transformer_block_smoke_test() {
        with_test_block(|block| {
            let rope = RopeTable::new(HEAD_DIM, 16, 10000.0);
            let kernel = oxibonsai_kernels::KernelDispatcher::auto_detect();
            let mut kv_cache = KvCache::new(1, N_KV_HEADS, HEAD_DIM, 16);
            let mut hidden = ramp_hidden();
            let original = hidden.clone();

            block
                .forward(&mut hidden, 0, &mut kv_cache, &rope, &kernel)
                .expect("block forward should succeed");

            let max_diff = max_abs_diff(&hidden, &original);
            assert!(
                max_diff > 1e-6,
                "forward should modify hidden state, max_diff={max_diff}"
            );
        });
    }

    /// `forward_with_stats` must succeed and report stats for layer 0.
    #[test]
    fn forward_with_stats_returns_timing() {
        with_test_block(|block| {
            let rope = RopeTable::new(HEAD_DIM, 16, 10000.0);
            let kernel = oxibonsai_kernels::KernelDispatcher::auto_detect();
            let mut kv_cache = KvCache::new(1, N_KV_HEADS, HEAD_DIM, 16);
            let mut hidden = ramp_hidden();

            let stats = block
                .forward_with_stats(&mut hidden, 0, &mut kv_cache, &rope, &kernel)
                .expect("forward_with_stats should succeed");

            assert_eq!(stats.layer_idx, 0);
            // Deliberately loose sanity bound: exact timings are machine-dependent.
            assert!(stats.total_us >= stats.projection_us.min(stats.attention_us));
        });
    }

    /// `forward_with_sliding_window` with an explicit window config must succeed
    /// and change the hidden state.
    #[test]
    fn forward_with_sliding_window_smoke() {
        with_test_block(|block| {
            let rope = RopeTable::new(HEAD_DIM, 16, 10000.0);
            let kernel = oxibonsai_kernels::KernelDispatcher::auto_detect();
            let mut kv_cache = KvCache::new(1, N_KV_HEADS, HEAD_DIM, 16);
            let sw_config = SlidingWindowConfig::new(8, 2);
            let mut hidden = ramp_hidden();
            let original = hidden.clone();

            block
                .forward_with_sliding_window(
                    &mut hidden,
                    0,
                    &mut kv_cache,
                    &rope,
                    &kernel,
                    Some(&sw_config),
                )
                .expect("forward_with_sliding_window should succeed");

            let max_diff = max_abs_diff(&hidden, &original);
            assert!(
                max_diff > 1e-6,
                "sliding-window forward should modify hidden state, max_diff={max_diff}"
            );
        });
    }

    /// Fractions are the component time divided by `total_us`.
    #[test]
    fn layer_stats_fractions() {
        let mut stats = LayerStats::new(0);
        stats.total_us = 100;
        stats.attention_us = 60;
        stats.ffn_us = 30;
        assert!((stats.attention_fraction() - 0.6).abs() < 1e-10);
        assert!((stats.ffn_fraction() - 0.3).abs() < 1e-10);
    }

    /// With `total_us == 0` both fractions are defined as 0.0 (no division by zero).
    #[test]
    fn layer_stats_zero_total() {
        let stats = LayerStats::new(5);
        assert!((stats.attention_fraction() - 0.0).abs() < 1e-10);
        assert!((stats.ffn_fraction() - 0.0).abs() < 1e-10);
    }
}