// Crate-level attribute: when compiling for aarch64, turn on the
// `stdarch_aarch64_prefetch` feature gate. On every other architecture this
// attribute expands to nothing. `feature(...)` gates need a nightly toolchain.
1#![cfg_attr(target_arch = "aarch64", feature(stdarch_aarch64_prefetch))]
2
// Link the `objc` crate and import its macros crate-wide (`#[macro_use]`).
// Compiled only when the `metal` feature is enabled AND the target OS is macOS.
3#[cfg(all(feature = "metal", target_os = "macos"))]
36#[macro_use]
37extern crate objc;
38
// GPU backend module; selected items from it are re-exported at the crate root.
39pub mod gpu_backend;
// `Scirs2Backend` is re-exported only when the `gpu` feature is enabled.
40#[cfg(feature = "gpu")]
41pub use gpu_backend::Scirs2Backend;
// The names below are re-exported unconditionally (no feature or target gate),
// so `CpuBackend`, `select_backend`, etc. are reachable from the crate root in
// every build configuration.
42pub use gpu_backend::{
43 gpu_gemv_1bit, gpu_matmul, select_backend, CpuBackend, DeviceBuffer, GpuBackend,
44 GpuBackendTrait, GpuError, LaunchConfig,
45};
46
// Metal-specific re-exports from `gpu_backend`.
// Available only when the `metal` feature is enabled AND the target OS is macOS;
// on any other configuration none of these names exist at the crate root.
// NOTE(review): judging by the names, this covers FP8 GEMM/GEMV entry points,
// `try_metal_*` forward/prefill helpers and cached-weight types — confirm the
// exact semantics in `gpu_backend`.
47#[cfg(all(feature = "metal", target_os = "macos"))]
48pub use gpu_backend::{
49 build_cached_weights, build_cached_weights_ternary_only, metal_fused_gate_up_swiglu_fp8_e4m3,
50 metal_fused_gate_up_swiglu_fp8_e5m2, metal_gemm_fp8_e4m3, metal_gemm_fp8_e4m3_residual,
51 metal_gemm_fp8_e5m2, metal_gemm_fp8_e5m2_residual, metal_gemv_fp8_e4m3, metal_gemv_fp8_e5m2,
52 print_gpu_profile_summary, try_metal_ffn, try_metal_forward_greedy_ternary,
53 try_metal_full_forward, try_metal_full_forward_cached, try_metal_full_forward_prefill,
54 try_metal_full_forward_prefill_ternary, try_metal_full_forward_prefill_verify,
55 try_metal_full_forward_prefill_verify_ternary, try_metal_full_forward_ternary,
56 try_metal_full_layer, try_metal_prefill_ternary, try_metal_prefill_verify_ternary,
57 try_metal_qkv, CachedLayerWeights, CachedModelWeights, FullForwardLayerParams,
58 FullForwardLayerParamsTernary, MetalGraph, MetalGraphError, MetalWeightHandle,
59};
60
// CUDA-specific re-exports from `gpu_backend`.
// All four `pub use` groups below carry the same gate: the `native-cuda`
// feature must be enabled AND the target OS must be Linux or Windows.

// Group 1: `cuda_gemv_q*k` entry points.
// NOTE(review): presumably one GEMV per K-quant format — confirm in `gpu_backend`.
61#[cfg(all(
62 feature = "native-cuda",
63 any(target_os = "linux", target_os = "windows")
64))]
65pub use gpu_backend::{
66 cuda_gemv_q2k, cuda_gemv_q3k, cuda_gemv_q4k, cuda_gemv_q5k, cuda_gemv_q6k, cuda_gemv_q8k,
67};

// Group 2: remaining `cuda_gemv_*` functions, the `try_cuda_*` helpers, and the
// CUDA parameter/graph/backend types.
69#[cfg(all(
70 feature = "native-cuda",
71 any(target_os = "linux", target_os = "windows")
72))]
73pub use gpu_backend::{
74 cuda_gemv_fp8_e4m3, cuda_gemv_fp8_e5m2, cuda_gemv_q4_0, cuda_gemv_q8_0, try_cuda_ffn,
75 try_cuda_full_forward, try_cuda_full_forward_ternary,
76 try_cuda_full_forward_ternary_with_gpu_lm_head, try_cuda_full_forward_with_gpu_lm_head,
77 try_cuda_full_layer, try_cuda_prefill, try_cuda_prefill_q_std, try_cuda_prefill_ternary,
78 try_cuda_qkv, CudaCachedLayerWeights, CudaFullForwardLayerParams,
79 CudaFullForwardLayerParamsTernary, CudaGraph, CudaGraphError, CudaQStdPrefillLayerParams,
80 NativeCudaBackend,
81};

// Group 3: `try_cuda_prefill_k_quant` together with its parameter and format types.
83#[cfg(all(
84 feature = "native-cuda",
85 any(target_os = "linux", target_os = "windows")
86))]
87pub use gpu_backend::{try_cuda_prefill_k_quant, CudaKQuantPrefillLayerParams, KQuantFormat};

// Group 4: `try_cuda_prefill_fp8` together with its parameter type.
89#[cfg(all(
90 feature = "native-cuda",
91 any(target_os = "linux", target_os = "windows")
92))]
93pub use gpu_backend::{try_cuda_prefill_fp8, CudaFP8PrefillLayerParams};
94
// Public kernel modules. Everything in this first run is declared without a
// `cfg` gate, so it is part of the crate on every target.
95pub mod dequant;
96pub mod dequant_fp8;
97pub mod dequant_ternary;
98pub mod dispatch;
99pub mod error;
100pub mod fp8_lut;
101pub mod gemm;
102pub mod gemm_fp8;
103pub mod gemm_ternary;
104pub mod gemv;
105pub mod gemv_fp8;
106pub mod gemv_q2k;
107pub mod gemv_q3k;
108pub mod gemv_q4_0;
109pub mod gemv_q4k;
110pub mod gemv_q5k;
111pub mod gemv_q6k;
112pub mod gemv_q8_0;
113pub mod gemv_q8k;
114pub mod gemv_ternary;
115pub mod packing;
116pub mod parallel;
117pub mod parallel_tiled;
// SIMD modules that exist only when compiling for x86_64.
118#[cfg(target_arch = "x86_64")]
119pub mod simd_avx2;
120#[cfg(target_arch = "x86_64")]
121pub mod simd_avx512;
122#[cfg(target_arch = "x86_64")]
123pub mod simd_fp8_avx2;
124#[cfg(target_arch = "x86_64")]
125pub mod simd_fp8_avx512;
// SIMD modules that exist only when compiling for aarch64.
126#[cfg(target_arch = "aarch64")]
127pub mod simd_fp8_neon;
128#[cfg(target_arch = "aarch64")]
129pub mod simd_neon;
// Ungated again: these modules are present on every target.
130pub mod tiled;
131pub mod traits;
132pub mod weight_cache;
133
// Additional public modules, declared without any `cfg` gate (present on every
// target). Selected items from each are re-exported at the crate root.
134pub mod aligned;
135pub mod prefetch;
136pub mod simd_float_ops;
137pub mod tuning;
138
// Crate-root re-exports. None of these is `cfg`-gated, so each name is
// reachable directly from the crate root in every build configuration,
// without spelling out the defining module's path.
139pub use aligned::{AlignedBlocks, AlignedBuffer};
140pub use dispatch::{KernelDispatcher, KernelTier};
141pub use error::{KernelError, KernelResult};
// Each `gemv_q*` module exposes a function of the same name; re-exporting the
// function puts it at the crate root under that name.
142pub use gemv_q2k::gemv_q2k;
143pub use gemv_q3k::gemv_q3k;
144pub use gemv_q4_0::gemv_q4_0;
145pub use gemv_q4k::gemv_q4k;
146pub use gemv_q5k::gemv_q5k;
147pub use gemv_q6k::gemv_q6k;
148pub use gemv_q8_0::gemv_q8_0;
149pub use gemv_q8k::gemv_q8k;
// NOTE(review): the `_par` suffix presumably marks multi-threaded variants —
// confirm in `parallel`.
150pub use parallel::{
151 gemm_fp8_e4m3_par, gemm_fp8_e5m2_par, gemm_ternary_g128_par, gemv_fp8_e4m3_par,
152 gemv_fp8_e5m2_par, gemv_ternary_g128_par,
153};
154pub use parallel_tiled::{gemm_adaptive_ternary, gemv_adaptive, gemv_adaptive_ternary};
155pub use prefetch::{PrefetchConfig, PrefetchLocality, PrefetchStrategy};
156pub use simd_float_ops::{rms_norm_simd, rope_apply_simd, silu_simd, softmax_simd, swiglu_simd};
157pub use traits::{Fp8Kernel, OneBitKernel, TernaryKernel};
158pub use tuning::{PlatformProfile, TunedThresholds, TuningSummary};
159pub use weight_cache::GpuWeightHandle;