use crate::dequant;
use crate::error::KernelResult;
use crate::gemm;
use crate::gemv;
use crate::traits::{OneBitKernel, TernaryKernel};
use crate::weight_cache::GpuWeightHandle;
use oxibonsai_core::tensor::BlockQ1_0G128;
#[cfg(feature = "gpu")]
use std::sync::Arc;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelTier {
Reference,
#[cfg(target_arch = "x86_64")]
Avx2,
#[cfg(target_arch = "x86_64")]
Avx512,
#[cfg(target_arch = "aarch64")]
Neon,
#[cfg(feature = "gpu")]
Gpu,
}
impl std::fmt::Display for KernelTier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Reference => write!(f, "reference"),
#[cfg(target_arch = "x86_64")]
Self::Avx2 => write!(f, "avx2+fma"),
#[cfg(target_arch = "x86_64")]
Self::Avx512 => write!(f, "avx512f+bw+vl"),
#[cfg(target_arch = "aarch64")]
Self::Neon => write!(f, "neon"),
#[cfg(feature = "gpu")]
Self::Gpu => write!(f, "gpu"),
}
}
}
pub struct KernelDispatcher {
tier: KernelTier,
#[cfg(feature = "gpu")]
gpu_backend: Option<Arc<dyn crate::gpu_backend::GpuBackendTrait>>,
}
impl std::fmt::Debug for KernelDispatcher {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("KernelDispatcher")
.field("tier", &self.tier)
.finish_non_exhaustive()
}
}
impl KernelDispatcher {
pub fn auto_detect() -> Self {
#[cfg(feature = "gpu")]
{
let backend = crate::gpu_backend::select_backend();
if backend.is_accelerated() {
tracing::info!(backend = backend.name(), "GPU backend available");
return Self {
tier: KernelTier::Gpu,
gpu_backend: Some(Arc::from(backend)),
};
}
}
let caps = scirs2_core::simd::detect::get_cpu_features();
let tier = Self::select_tier(caps);
tracing::info!(tier = %tier, "selected kernel tier");
Self {
tier,
#[cfg(feature = "gpu")]
gpu_backend: None,
}
}
pub fn with_tier(tier: KernelTier) -> Self {
Self {
tier,
#[cfg(feature = "gpu")]
gpu_backend: None,
}
}
#[cfg(feature = "gpu")]
pub fn with_gpu(backend: Arc<dyn crate::gpu_backend::GpuBackendTrait>) -> Self {
Self {
tier: KernelTier::Gpu,
gpu_backend: Some(backend),
}
}
pub fn tier(&self) -> KernelTier {
self.tier
}
fn select_tier(caps: &scirs2_core::simd::detect::CpuFeatures) -> KernelTier {
#[cfg(target_arch = "x86_64")]
{
if caps.has_avx512f {
return KernelTier::Avx512;
}
if caps.has_avx2 && caps.has_fma {
return KernelTier::Avx2;
}
}
#[cfg(target_arch = "aarch64")]
{
if caps.has_neon {
return KernelTier::Neon;
}
}
let _ = caps;
KernelTier::Reference
}
}
#[cfg(feature = "gpu")]
const GPU_MIN_ROWS: usize = 1024;
impl KernelDispatcher {
#[cfg(feature = "gpu")]
fn cpu_tier() -> KernelTier {
let caps = scirs2_core::simd::detect::get_cpu_features();
Self::select_tier(caps)
}
#[cfg(feature = "gpu")]
fn cpu_dequant(blocks: &[BlockQ1_0G128], output: &mut [f32]) -> KernelResult<()> {
match Self::cpu_tier() {
KernelTier::Reference => dequant::dequant_1bit_g128(blocks, output),
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe { crate::simd_avx2::dequant_1bit_g128_avx2(blocks, output) },
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::dequant_1bit_g128_avx512(blocks, output)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe { crate::simd_neon::dequant_1bit_g128_neon(blocks, output) },
#[cfg(feature = "gpu")]
KernelTier::Gpu => dequant::dequant_1bit_g128(blocks, output),
}
}
#[cfg(feature = "gpu")]
fn cpu_gemv(
blocks: &[BlockQ1_0G128],
input: &[f32],
output: &mut [f32],
n_rows: usize,
k: usize,
) -> KernelResult<()> {
match Self::cpu_tier() {
KernelTier::Reference => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::gemv_1bit_g128_avx2_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::gemv_1bit_g128_avx512_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::gemv_1bit_g128_neon_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
}
}
#[cfg(feature = "gpu")]
fn cpu_gemm(
blocks: &[BlockQ1_0G128],
input: &[f32],
output: &mut [f32],
m: usize,
n_rows: usize,
k: usize,
) -> KernelResult<()> {
match Self::cpu_tier() {
KernelTier::Reference => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::gemm_1bit_g128_avx2_prefetch(blocks, input, output, m, n_rows, k)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::gemm_1bit_g128_avx512(blocks, input, output, m, n_rows, k)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::gemm_1bit_g128_neon_prefetch(blocks, input, output, m, n_rows, k)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
}
}
#[cfg(feature = "gpu")]
fn cpu_dequant_ternary(
blocks: &[oxibonsai_core::BlockTQ2_0_g128],
output: &mut [f32],
) -> KernelResult<()> {
match Self::cpu_tier() {
KernelTier::Reference => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::dequant_tq2_0_g128_avx2(blocks, output)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::dequant_tq2_0_g128_avx512(blocks, output)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::dequant_tq2_0_g128_neon(blocks, output)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
}
}
#[cfg(feature = "gpu")]
fn cpu_gemv_ternary(
blocks: &[oxibonsai_core::BlockTQ2_0_g128],
input: &[f32],
output: &mut [f32],
n_rows: usize,
k: usize,
) -> KernelResult<()> {
match Self::cpu_tier() {
KernelTier::Reference => {
crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
}
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::gemv_tq2_0_g128_avx2_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::gemv_tq2_0_g128_avx512_prefetch(
blocks, input, output, n_rows, k,
)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::gemv_tq2_0_g128_neon_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => {
crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
}
}
}
#[cfg(feature = "gpu")]
fn cpu_gemm_ternary(
blocks: &[oxibonsai_core::BlockTQ2_0_g128],
input: &[f32],
output: &mut [f32],
m: usize,
n_rows: usize,
k: usize,
) -> KernelResult<()> {
match Self::cpu_tier() {
KernelTier::Reference => {
crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
}
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::gemm_tq2_0_g128_avx2(blocks, input, output, m, n_rows, k)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::gemm_tq2_0_g128_avx512(blocks, input, output, m, n_rows, k)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::gemm_tq2_0_g128_neon(blocks, input, output, m, n_rows, k)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => {
crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
}
}
}
#[cfg(feature = "gpu")]
fn blocks_as_bytes(blocks: &[BlockQ1_0G128]) -> &[u8] {
let ptr = blocks.as_ptr() as *const u8;
let len = std::mem::size_of_val(blocks);
unsafe { std::slice::from_raw_parts(ptr, len) }
}
}
impl OneBitKernel for KernelDispatcher {
fn dequant(&self, blocks: &[BlockQ1_0G128], output: &mut [f32]) -> KernelResult<()> {
match self.tier {
KernelTier::Reference => dequant::dequant_1bit_g128(blocks, output),
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe { crate::simd_avx2::dequant_1bit_g128_avx2(blocks, output) },
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::dequant_1bit_g128_avx512(blocks, output)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe { crate::simd_neon::dequant_1bit_g128_neon(blocks, output) },
#[cfg(feature = "gpu")]
KernelTier::Gpu => Self::cpu_dequant(blocks, output),
}
}
fn gemv(
&self,
blocks: &[BlockQ1_0G128],
input: &[f32],
output: &mut [f32],
n_rows: usize,
k: usize,
) -> KernelResult<()> {
match self.tier {
KernelTier::Reference => gemv::gemv_1bit_g128(blocks, input, output, n_rows, k),
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::gemv_1bit_g128_avx2_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::gemv_1bit_g128_avx512_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::gemv_1bit_g128_neon_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => {
if n_rows < GPU_MIN_ROWS {
return Self::cpu_gemv(blocks, input, output, n_rows, k);
}
if let Some(ref backend) = self.gpu_backend {
let bytes = Self::blocks_as_bytes(blocks);
match backend.gemv_q1_g128(bytes, input, n_rows, k) {
Ok(result) => {
let copy_len = output.len().min(result.len());
output[..copy_len].copy_from_slice(&result[..copy_len]);
return Ok(());
}
Err(e) => {
tracing::warn!(error = %e, "GPU gemv failed, falling back to CPU");
return Self::cpu_gemv(blocks, input, output, n_rows, k);
}
}
}
Self::cpu_gemv(blocks, input, output, n_rows, k)
}
}
}
fn gemm(
&self,
blocks: &[BlockQ1_0G128],
input: &[f32],
output: &mut [f32],
m: usize,
n_rows: usize,
k: usize,
) -> KernelResult<()> {
match self.tier {
KernelTier::Reference => gemm::gemm_1bit_g128(blocks, input, output, m, n_rows, k),
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::gemm_1bit_g128_avx2_prefetch(blocks, input, output, m, n_rows, k)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::gemm_1bit_g128_avx512(blocks, input, output, m, n_rows, k)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::gemm_1bit_g128_neon_prefetch(blocks, input, output, m, n_rows, k)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => {
if n_rows < GPU_MIN_ROWS {
return Self::cpu_gemm(blocks, input, output, m, n_rows, k);
}
if let Some(ref backend) = self.gpu_backend {
let bytes = Self::blocks_as_bytes(blocks);
match backend.gemm_q1_g128(bytes, input, m, n_rows, k) {
Ok(result) => {
let copy_len = output.len().min(result.len());
output[..copy_len].copy_from_slice(&result[..copy_len]);
return Ok(());
}
Err(e) => {
tracing::warn!(error = %e, "GPU gemm failed, falling back to CPU");
return Self::cpu_gemm(blocks, input, output, m, n_rows, k);
}
}
}
Self::cpu_gemm(blocks, input, output, m, n_rows, k)
}
}
}
fn name(&self) -> &'static str {
match self.tier {
KernelTier::Reference => "Q1_0_g128 reference (scalar)",
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => "Q1_0_g128 AVX2+FMA (256-bit)",
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => "Q1_0_g128 AVX-512 (512-bit)",
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => "Q1_0_g128 NEON (128-bit)",
#[cfg(feature = "gpu")]
KernelTier::Gpu => "Q1_0_g128 GPU (accelerated)",
}
}
fn upload_weights(&self, blocks: &[BlockQ1_0G128]) -> Option<GpuWeightHandle> {
#[cfg(feature = "gpu")]
{
if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
let bytes = Self::blocks_as_bytes(blocks);
match backend.upload_weights_raw(bytes) {
Ok(handle) => return Some(handle),
Err(e) => {
tracing::warn!(error = %e, "failed to upload weights to GPU");
}
}
}
}
let _ = blocks;
None
}
fn gemv_cached(
&self,
handle: GpuWeightHandle,
input: &[f32],
output: &mut [f32],
n_rows: usize,
k: usize,
) -> KernelResult<()> {
#[cfg(feature = "gpu")]
{
if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
match backend.gemv_q1_g128_cached(handle, input, n_rows, k) {
Ok(result) => {
let len = output.len().min(result.len());
output[..len].copy_from_slice(&result[..len]);
return Ok(());
}
Err(e) => {
tracing::warn!(error = %e, "cached GPU gemv failed, cannot fallback without blocks");
return Err(crate::error::KernelError::GpuError(e.to_string()));
}
}
}
}
let _ = (handle, input, output, n_rows, k);
Err(crate::error::KernelError::UnsupportedOperation(
"gemv_cached requires GPU tier".into(),
))
}
fn batch_attn_phase(
&self,
hidden: &[f32],
norm_weight: &[f32],
norm_eps: f32,
qkv_handle: GpuWeightHandle,
q_rows: usize,
k_rows: usize,
h: usize,
) -> KernelResult<Option<(Vec<f32>, Vec<f32>, Vec<f32>)>> {
let _ = (hidden, norm_weight, norm_eps, qkv_handle, q_rows, k_rows, h);
Ok(None)
}
fn batch_ffn_phase(
&self,
hidden: &mut [f32],
attn_out: &[f32],
norm_weight: &[f32],
norm_eps: f32,
attn_proj_handle: GpuWeightHandle,
gate_up_handle: GpuWeightHandle,
down_handle: GpuWeightHandle,
h: usize,
intermediate: usize,
attn_proj_k: usize,
) -> KernelResult<bool> {
#[cfg(feature = "gpu")]
{
if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
match backend.batch_ffn_phase(
hidden,
attn_out,
norm_weight,
norm_eps,
attn_proj_handle,
gate_up_handle,
down_handle,
h,
intermediate,
attn_proj_k,
) {
Ok(true) => return Ok(true),
Ok(false) => return Ok(false),
Err(e) => {
tracing::warn!(error = %e, "batch FFN phase failed, falling back");
return Ok(false);
}
}
}
}
let _ = (
hidden,
attn_out,
norm_weight,
norm_eps,
attn_proj_handle,
gate_up_handle,
down_handle,
h,
intermediate,
attn_proj_k,
);
Ok(false)
}
}
impl TernaryKernel for KernelDispatcher {
fn dequant_ternary_g128(
&self,
blocks: &[oxibonsai_core::BlockTQ2_0_g128],
output: &mut [f32],
) -> KernelResult<()> {
match self.tier {
KernelTier::Reference => crate::dequant_ternary::dequant_tq2_0_g128(blocks, output),
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::dequant_tq2_0_g128_avx2(blocks, output)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::dequant_tq2_0_g128_avx512(blocks, output)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::dequant_tq2_0_g128_neon(blocks, output)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => Self::cpu_dequant_ternary(blocks, output),
}
}
fn gemv_ternary_g128(
&self,
blocks: &[oxibonsai_core::BlockTQ2_0_g128],
input: &[f32],
output: &mut [f32],
n_rows: usize,
k: usize,
) -> KernelResult<()> {
match self.tier {
KernelTier::Reference => {
crate::gemv_ternary::gemv_tq2_0_g128(blocks, input, output, n_rows, k)
}
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::gemv_tq2_0_g128_avx2_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::gemv_tq2_0_g128_avx512_prefetch(
blocks, input, output, n_rows, k,
)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::gemv_tq2_0_g128_neon_prefetch(blocks, input, output, n_rows, k)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => Self::cpu_gemv_ternary(blocks, input, output, n_rows, k),
}
}
fn gemm_ternary_g128(
&self,
blocks: &[oxibonsai_core::BlockTQ2_0_g128],
input: &[f32],
output: &mut [f32],
m: usize,
n_rows: usize,
k: usize,
) -> KernelResult<()> {
match self.tier {
KernelTier::Reference => {
crate::gemm_ternary::gemm_tq2_0_g128(blocks, input, output, m, n_rows, k)
}
#[cfg(target_arch = "x86_64")]
KernelTier::Avx2 => unsafe {
crate::simd_avx2::gemm_tq2_0_g128_avx2(blocks, input, output, m, n_rows, k)
},
#[cfg(target_arch = "x86_64")]
KernelTier::Avx512 => unsafe {
crate::simd_avx512::gemm_tq2_0_g128_avx512(blocks, input, output, m, n_rows, k)
},
#[cfg(target_arch = "aarch64")]
KernelTier::Neon => unsafe {
crate::simd_neon::gemm_tq2_0_g128_neon(blocks, input, output, m, n_rows, k)
},
#[cfg(feature = "gpu")]
KernelTier::Gpu => Self::cpu_gemm_ternary(blocks, input, output, m, n_rows, k),
}
}
fn upload_weights_ternary(
&self,
blocks: &[oxibonsai_core::BlockTQ2_0_g128],
) -> Option<GpuWeightHandle> {
#[cfg(feature = "gpu")]
{
if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
match backend.upload_weights_ternary(blocks) {
Ok(handle) => return Some(handle),
Err(e) => {
use std::sync::atomic::{AtomicBool, Ordering};
static WARNED: AtomicBool = AtomicBool::new(false);
if !WARNED.swap(true, Ordering::Relaxed) {
tracing::warn!(
error = %e,
backend = backend.name(),
"ternary weight GPU upload not supported by backend; \
falling back to CPU SIMD for ternary GEMV (this message \
is shown once per process)"
);
}
}
}
}
}
let _ = blocks;
None
}
fn gemv_ternary_g128_cached(
&self,
handle: GpuWeightHandle,
input: &[f32],
output: &mut [f32],
n_rows: usize,
k: usize,
) -> KernelResult<()> {
#[cfg(feature = "gpu")]
{
if let (KernelTier::Gpu, Some(ref backend)) = (self.tier, &self.gpu_backend) {
match backend.gemv_tq2_g128_cached(handle, input, n_rows, k) {
Ok(result) => {
let len = output.len().min(result.len());
output[..len].copy_from_slice(&result[..len]);
return Ok(());
}
Err(e) => {
tracing::warn!(error = %e, "cached GPU ternary gemv failed, cannot fallback without blocks");
return Err(crate::error::KernelError::GpuError(e.to_string()));
}
}
}
}
let _ = (handle, input, output, n_rows, k);
Err(crate::error::KernelError::UnsupportedOperation(
"gemv_ternary_g128_cached requires GPU tier".into(),
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn auto_detect_creates_dispatcher() {
let dispatcher = KernelDispatcher::auto_detect();
let _tier = dispatcher.tier();
let _name = dispatcher.name();
}
#[test]
fn reference_tier_works() {
let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
assert_eq!(dispatcher.tier(), KernelTier::Reference);
assert_eq!(dispatcher.name(), "Q1_0_g128 reference (scalar)");
}
#[cfg(target_arch = "x86_64")]
#[test]
fn avx2_tier_name() {
if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")) {
return;
}
let dispatcher = KernelDispatcher::with_tier(KernelTier::Avx2);
assert_eq!(dispatcher.tier(), KernelTier::Avx2);
assert_eq!(dispatcher.name(), "Q1_0_g128 AVX2+FMA (256-bit)");
}
#[cfg(target_arch = "aarch64")]
#[test]
fn neon_tier_name() {
let dispatcher = KernelDispatcher::with_tier(KernelTier::Neon);
assert_eq!(dispatcher.tier(), KernelTier::Neon);
assert_eq!(dispatcher.name(), "Q1_0_g128 NEON (128-bit)");
}
#[test]
fn dispatcher_exposes_ternary_gemv() {
use crate::TernaryKernel;
use half::f16;
use oxibonsai_core::BlockTQ2_0_g128;
let dispatcher = KernelDispatcher::auto_detect();
let block_pos = BlockTQ2_0_g128 {
qs: [0xAA; 32],
d: f16::from_f32(1.0),
};
let block_neg = BlockTQ2_0_g128 {
qs: [0x00; 32],
d: f16::from_f32(1.0),
};
let blocks = vec![block_pos, block_neg];
let input = vec![1.0f32; 128];
let mut output = vec![0.0f32; 2];
dispatcher
.gemv_ternary_g128(&blocks, &input, &mut output, 2, 128)
.expect("gemv_ternary_g128 should succeed");
assert!(
(output[0] - 128.0).abs() < 1.0,
"row0 expected ~128.0, got {}",
output[0]
);
assert!(
(output[1] + 128.0).abs() < 1.0,
"row1 expected ~-128.0, got {}",
output[1]
);
}
#[test]
fn dispatcher_ternary_reference_tier() {
use crate::TernaryKernel;
use half::f16;
use oxibonsai_core::BlockTQ2_0_g128;
let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
let blocks = vec![BlockTQ2_0_g128 {
qs: [0xAA; 32],
d: f16::from_f32(1.0),
}];
let input = vec![1.0f32; 128];
let mut output = vec![0.0f32; 1];
dispatcher
.gemv_ternary_g128(&blocks, &input, &mut output, 1, 128)
.expect("gemv_ternary_g128 should succeed");
assert!((output[0] - 128.0).abs() < 1.0);
}
#[test]
fn ternary_upload_non_gpu_returns_none() {
use crate::TernaryKernel;
use half::f16;
use oxibonsai_core::BlockTQ2_0_g128;
let dispatcher = KernelDispatcher::with_tier(KernelTier::Reference);
let block = BlockTQ2_0_g128 {
qs: [0xAAu8; 32],
d: f16::from_f32(1.0),
};
let handle = dispatcher.upload_weights_ternary(&[block]);
assert!(
handle.is_none(),
"expected None for non-GPU tier, got {:?}",
handle
);
}
}