1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//! Trait definition for 1-bit compute kernels.
//!
//! [`OneBitKernel`] is the common interface implemented by every kernel tier
//! (reference, AVX2, AVX-512, NEON). The [`KernelDispatcher`](crate::KernelDispatcher)
//! implements this trait and delegates to the best available tier at runtime.
use crate::error::KernelResult;
use crate::weight_cache::GpuWeightHandle;
use oxibonsai_core::tensor::BlockQ1_0G128;
/// Trait for Q1\_0\_g128 compute kernel implementations.
///
/// Each tier (reference, portable SIMD, platform SIMD) implements this trait.
pub trait OneBitKernel: Send + Sync {
/// Dequantize blocks to FP32 values.
///
/// For each block: `output[i] = bit[i] ? +d : -d`
fn dequant(&self, blocks: &[BlockQ1_0G128], output: &mut [f32]) -> KernelResult<()>;
/// Fused 1-bit matrix × FP32 vector product (GEMV).
///
/// Computes `output[row] = sum_col(weight[row, col] * input[col])`
/// where weights are Q1\_0\_g128 packed.
///
/// - `blocks`: Row-major packed weight blocks, `n_rows * (k / 128)` blocks total
/// - `input`: FP32 input vector of length `k`
/// - `output`: FP32 output vector of length `n_rows`
/// - `n_rows`: Number of output rows (N dimension)
/// - `k`: Inner dimension (must be multiple of 128)
fn gemv(
&self,
blocks: &[BlockQ1_0G128],
input: &[f32],
output: &mut [f32],
n_rows: usize,
k: usize,
) -> KernelResult<()>;
/// Fused 1-bit matrix × FP32 matrix product (GEMM).
///
/// Computes `output[m, n] = sum_k(weight[n, k] * input[m, k])`
///
/// - `blocks`: Weight blocks in row-major order, `n_rows * (k / 128)` blocks
/// - `input`: Row-major FP32 input [m × k]
/// - `output`: Row-major FP32 output [m × n_rows]
/// - `m`: Batch/sequence dimension
/// - `n_rows`: Number of weight matrix rows (output columns)
/// - `k`: Inner dimension (must be multiple of 128)
fn gemm(
&self,
blocks: &[BlockQ1_0G128],
input: &[f32],
output: &mut [f32],
m: usize,
n_rows: usize,
k: usize,
) -> KernelResult<()>;
/// Display name for this kernel implementation.
fn name(&self) -> &'static str;
/// Upload weight blocks to GPU memory for future cached GEMV/GEMM calls.
///
/// Returns `Some(handle)` if the kernel supports GPU caching (i.e. the
/// GPU tier), or `None` for CPU-only tiers.
fn upload_weights(&self, _blocks: &[BlockQ1_0G128]) -> Option<GpuWeightHandle> {
None
}
/// GEMV using a pre-uploaded weight buffer (no host→device copy for weights).
///
/// Falls back to `Err(UnsupportedOperation)` by default; only the GPU tier
/// overrides this.
fn gemv_cached(
&self,
_handle: GpuWeightHandle,
_input: &[f32],
_output: &mut [f32],
_n_rows: usize,
_k: usize,
) -> KernelResult<()> {
Err(crate::error::KernelError::UnsupportedOperation(
"gemv_cached not supported by this kernel tier".into(),
))
}
/// Batch-accelerated attention input phase (RMSNorm + QKV in one command buffer).
///
/// Returns `Ok(Some((q, k, v)))` if batching succeeded, or `Ok(None)` if
/// not supported by this kernel tier.
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
fn batch_attn_phase(
&self,
_hidden: &[f32],
_norm_weight: &[f32],
_norm_eps: f32,
_qkv_handle: GpuWeightHandle,
_q_rows: usize,
_k_rows: usize,
_h: usize,
) -> KernelResult<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>> {
Ok(None)
}
/// Batch-accelerated FFN phase (attn_proj + residual + norm + gate_up + swiglu + down + residual).
///
/// Returns `Ok(true)` if batching succeeded and `hidden` was modified
/// in-place, or `Ok(false)` if not supported.
#[allow(clippy::too_many_arguments)]
fn batch_ffn_phase(
&self,
_hidden: &mut [f32],
_attn_out: &[f32],
_norm_weight: &[f32],
_norm_eps: f32,
_attn_proj_handle: GpuWeightHandle,
_gate_up_handle: GpuWeightHandle,
_down_handle: GpuWeightHandle,
_h: usize,
_intermediate: usize,
_attn_proj_k: usize,
) -> KernelResult<bool> {
Ok(false)
}
}