Skip to main content

oxibonsai_kernels/
lib.rs

1#![cfg_attr(target_arch = "aarch64", feature(stdarch_aarch64_prefetch))]
2
3//! # oxibonsai-kernels
4//!
5//! Compute kernels for OxiBonsai: 1-bit `Q1_0_g128`, ternary, FP8, Q4_0/Q8_0 and K-quant weights.
6//!
7//! Provides dequantization and fused matrix-multiply operations optimized
8//! for the PrismML 1-bit weight format. The kernels are organised in a
9//! tiered dispatch architecture that auto-selects the fastest implementation
10//! available on the current CPU:
11//!
12//! | Tier | Feature gate | Instruction set |
13//! |------|-------------|-----------------|
14//! | **Reference** | always | Pure scalar Rust (correctness baseline) |
15//! | **AVX2+FMA** | `simd-avx2` | 256-bit SIMD (x86-64) |
16//! | **AVX-512** | `simd-avx512` | 512-bit SIMD (x86-64) |
17//! | **NEON** | `simd-neon` | 128-bit SIMD (AArch64) |
18//!
19//! Runtime dispatch is handled by [`KernelDispatcher`], which queries
20//! SciRS2-Core's SIMD capability cache on construction.
21//!
22//! ## Key Kernels
23//!
24//! | Kernel | Description |
25//! |--------|-------------|
26//! | [`dequant::dequant_1bit_g128`] | Unpack 128 sign bits + FP16 scale → FP32 |
27//! | [`gemv::gemv_1bit_g128`] | 1-bit weight matrix × FP32 vector (single-token decode) |
28//! | [`gemm::gemm_1bit_g128`] | 1-bit weight matrix × FP32 matrix (multi-token prefill) |
29//!
30//! ## Traits
31//!
32//! All tiers implement [`OneBitKernel`], so callers are agnostic to the underlying SIMD
33//! level. The companion traits [`TernaryKernel`] and [`Fp8Kernel`] are also re-exported.
34
35#[cfg(all(feature = "metal", target_os = "macos"))]
36#[macro_use]
37extern crate objc;
38
39pub mod gpu_backend;
40#[cfg(feature = "gpu")]
41pub use gpu_backend::Scirs2Backend;
42pub use gpu_backend::{
43    gpu_gemv_1bit, gpu_matmul, select_backend, CpuBackend, DeviceBuffer, GpuBackend,
44    GpuBackendTrait, GpuError, LaunchConfig,
45};
46
47#[cfg(all(feature = "metal", target_os = "macos"))]
48pub use gpu_backend::{
49    build_cached_weights, build_cached_weights_ternary_only, metal_fused_gate_up_swiglu_fp8_e4m3,
50    metal_fused_gate_up_swiglu_fp8_e5m2, metal_gemm_fp8_e4m3, metal_gemm_fp8_e4m3_residual,
51    metal_gemm_fp8_e5m2, metal_gemm_fp8_e5m2_residual, metal_gemv_fp8_e4m3, metal_gemv_fp8_e5m2,
52    print_gpu_profile_summary, try_metal_ffn, try_metal_forward_greedy_ternary,
53    try_metal_full_forward, try_metal_full_forward_cached, try_metal_full_forward_prefill,
54    try_metal_full_forward_prefill_ternary, try_metal_full_forward_prefill_verify,
55    try_metal_full_forward_prefill_verify_ternary, try_metal_full_forward_ternary,
56    try_metal_full_layer, try_metal_prefill_ternary, try_metal_prefill_verify_ternary,
57    try_metal_qkv, CachedLayerWeights, CachedModelWeights, FullForwardLayerParams,
58    FullForwardLayerParamsTernary, MetalGraph, MetalGraphError, MetalWeightHandle,
59};
60
61#[cfg(all(
62    feature = "native-cuda",
63    any(target_os = "linux", target_os = "windows")
64))]
65pub use gpu_backend::{
66    cuda_gemv_q2k, cuda_gemv_q3k, cuda_gemv_q4k, cuda_gemv_q5k, cuda_gemv_q6k, cuda_gemv_q8k,
67};
68
69#[cfg(all(
70    feature = "native-cuda",
71    any(target_os = "linux", target_os = "windows")
72))]
73pub use gpu_backend::{
74    cuda_gemv_fp8_e4m3, cuda_gemv_fp8_e5m2, cuda_gemv_q4_0, cuda_gemv_q8_0, try_cuda_ffn,
75    try_cuda_full_forward, try_cuda_full_forward_ternary,
76    try_cuda_full_forward_ternary_with_gpu_lm_head, try_cuda_full_forward_with_gpu_lm_head,
77    try_cuda_full_layer, try_cuda_prefill, try_cuda_prefill_q_std, try_cuda_prefill_ternary,
78    try_cuda_qkv, CudaCachedLayerWeights, CudaFullForwardLayerParams,
79    CudaFullForwardLayerParamsTernary, CudaGraph, CudaGraphError, CudaQStdPrefillLayerParams,
80    NativeCudaBackend,
81};
82
83#[cfg(all(
84    feature = "native-cuda",
85    any(target_os = "linux", target_os = "windows")
86))]
87pub use gpu_backend::{try_cuda_prefill_k_quant, CudaKQuantPrefillLayerParams, KQuantFormat};
88
89#[cfg(all(
90    feature = "native-cuda",
91    any(target_os = "linux", target_os = "windows")
92))]
93pub use gpu_backend::{try_cuda_prefill_fp8, CudaFP8PrefillLayerParams};
94
95pub mod dequant;
96pub mod dequant_fp8;
97pub mod dequant_ternary;
98pub mod dispatch;
99pub mod error;
100pub mod fp8_lut;
101pub mod gemm;
102pub mod gemm_fp8;
103pub mod gemm_ternary;
104pub mod gemv;
105pub mod gemv_fp8;
106pub mod gemv_q2k;
107pub mod gemv_q3k;
108pub mod gemv_q4_0;
109pub mod gemv_q4k;
110pub mod gemv_q5k;
111pub mod gemv_q6k;
112pub mod gemv_q8_0;
113pub mod gemv_q8k;
114pub mod gemv_ternary;
115pub mod packing;
116pub mod parallel;
117pub mod parallel_tiled;
118#[cfg(target_arch = "x86_64")]
119pub mod simd_avx2;
120#[cfg(target_arch = "x86_64")]
121pub mod simd_avx512;
122#[cfg(target_arch = "x86_64")]
123pub mod simd_fp8_avx2;
124#[cfg(target_arch = "x86_64")]
125pub mod simd_fp8_avx512;
126#[cfg(target_arch = "aarch64")]
127pub mod simd_fp8_neon;
128#[cfg(target_arch = "aarch64")]
129pub mod simd_neon;
130pub mod tiled;
131pub mod traits;
132pub mod weight_cache;
133
134pub mod aligned;
135pub mod prefetch;
136pub mod simd_float_ops;
137pub mod tuning;
138
139pub use aligned::{AlignedBlocks, AlignedBuffer};
140pub use dispatch::{KernelDispatcher, KernelTier};
141pub use error::{KernelError, KernelResult};
142pub use gemv_q2k::gemv_q2k;
143pub use gemv_q3k::gemv_q3k;
144pub use gemv_q4_0::gemv_q4_0;
145pub use gemv_q4k::gemv_q4k;
146pub use gemv_q5k::gemv_q5k;
147pub use gemv_q6k::gemv_q6k;
148pub use gemv_q8_0::gemv_q8_0;
149pub use gemv_q8k::gemv_q8k;
150pub use parallel::{
151    gemm_fp8_e4m3_par, gemm_fp8_e5m2_par, gemm_ternary_g128_par, gemv_fp8_e4m3_par,
152    gemv_fp8_e5m2_par, gemv_ternary_g128_par,
153};
154pub use parallel_tiled::{gemm_adaptive_ternary, gemv_adaptive, gemv_adaptive_ternary};
155pub use prefetch::{PrefetchConfig, PrefetchLocality, PrefetchStrategy};
156pub use simd_float_ops::{rms_norm_simd, rope_apply_simd, silu_simd, softmax_simd, swiglu_simd};
157pub use traits::{Fp8Kernel, OneBitKernel, TernaryKernel};
158pub use tuning::{PlatformProfile, TunedThresholds, TuningSummary};
159pub use weight_cache::GpuWeightHandle;