#![cfg_attr(target_arch = "aarch64", feature(stdarch_aarch64_prefetch))]
#[cfg(all(feature = "metal", target_os = "macos"))]
#[macro_use]
extern crate objc;
pub mod gpu_backend;
#[cfg(feature = "gpu")]
pub use gpu_backend::Scirs2Backend;
pub use gpu_backend::{
gpu_gemv_1bit, gpu_matmul, select_backend, CpuBackend, DeviceBuffer, GpuBackend,
GpuBackendTrait, GpuError, LaunchConfig,
};
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use gpu_backend::{
build_cached_weights, build_cached_weights_ternary_only, metal_fused_gate_up_swiglu_fp8_e4m3,
metal_fused_gate_up_swiglu_fp8_e5m2, metal_gemm_fp8_e4m3, metal_gemm_fp8_e4m3_residual,
metal_gemm_fp8_e5m2, metal_gemm_fp8_e5m2_residual, metal_gemv_fp8_e4m3, metal_gemv_fp8_e5m2,
print_gpu_profile_summary, try_metal_ffn, try_metal_forward_greedy_ternary,
try_metal_full_forward, try_metal_full_forward_cached, try_metal_full_forward_prefill,
try_metal_full_forward_prefill_ternary, try_metal_full_forward_prefill_verify,
try_metal_full_forward_prefill_verify_ternary, try_metal_full_forward_ternary,
try_metal_full_layer, try_metal_prefill_ternary, try_metal_prefill_verify_ternary,
try_metal_qkv, CachedLayerWeights, CachedModelWeights, FullForwardLayerParams,
FullForwardLayerParamsTernary, MetalGraph, MetalGraphError, MetalWeightHandle,
};
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub use gpu_backend::{
cuda_gemv_q2k, cuda_gemv_q3k, cuda_gemv_q4k, cuda_gemv_q5k, cuda_gemv_q6k, cuda_gemv_q8k,
};
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub use gpu_backend::{
cuda_gemv_fp8_e4m3, cuda_gemv_fp8_e5m2, cuda_gemv_q4_0, cuda_gemv_q8_0, try_cuda_ffn,
try_cuda_full_forward, try_cuda_full_forward_ternary,
try_cuda_full_forward_ternary_with_gpu_lm_head, try_cuda_full_forward_with_gpu_lm_head,
try_cuda_full_layer, try_cuda_prefill, try_cuda_prefill_q_std, try_cuda_prefill_ternary,
try_cuda_qkv, CudaCachedLayerWeights, CudaFullForwardLayerParams,
CudaFullForwardLayerParamsTernary, CudaGraph, CudaGraphError, CudaQStdPrefillLayerParams,
NativeCudaBackend,
};
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub use gpu_backend::{try_cuda_prefill_k_quant, CudaKQuantPrefillLayerParams, KQuantFormat};
#[cfg(all(
feature = "native-cuda",
any(target_os = "linux", target_os = "windows")
))]
pub use gpu_backend::{try_cuda_prefill_fp8, CudaFP8PrefillLayerParams};
pub mod dequant;
pub mod dequant_fp8;
pub mod dequant_ternary;
pub mod dispatch;
pub mod error;
pub mod fp8_lut;
pub mod gemm;
pub mod gemm_fp8;
pub mod gemm_ternary;
pub mod gemv;
pub mod gemv_fp8;
pub mod gemv_q2k;
pub mod gemv_q3k;
pub mod gemv_q4_0;
pub mod gemv_q4k;
pub mod gemv_q5k;
pub mod gemv_q6k;
pub mod gemv_q8_0;
pub mod gemv_q8k;
pub mod gemv_ternary;
pub mod packing;
pub mod parallel;
pub mod parallel_tiled;
#[cfg(target_arch = "x86_64")]
pub mod simd_avx2;
#[cfg(target_arch = "x86_64")]
pub mod simd_avx512;
#[cfg(target_arch = "x86_64")]
pub mod simd_fp8_avx2;
#[cfg(target_arch = "x86_64")]
pub mod simd_fp8_avx512;
#[cfg(target_arch = "aarch64")]
pub mod simd_fp8_neon;
#[cfg(target_arch = "aarch64")]
pub mod simd_neon;
pub mod tiled;
pub mod traits;
pub mod weight_cache;
pub mod aligned;
pub mod prefetch;
pub mod simd_float_ops;
pub mod tuning;
pub use aligned::{AlignedBlocks, AlignedBuffer};
pub use dispatch::{KernelDispatcher, KernelTier};
pub use error::{KernelError, KernelResult};
pub use gemv_q2k::gemv_q2k;
pub use gemv_q3k::gemv_q3k;
pub use gemv_q4_0::gemv_q4_0;
pub use gemv_q4k::gemv_q4k;
pub use gemv_q5k::gemv_q5k;
pub use gemv_q6k::gemv_q6k;
pub use gemv_q8_0::gemv_q8_0;
pub use gemv_q8k::gemv_q8k;
pub use parallel::{
gemm_fp8_e4m3_par, gemm_fp8_e5m2_par, gemm_ternary_g128_par, gemv_fp8_e4m3_par,
gemv_fp8_e5m2_par, gemv_ternary_g128_par,
};
pub use parallel_tiled::{gemm_adaptive_ternary, gemv_adaptive, gemv_adaptive_ternary};
pub use prefetch::{PrefetchConfig, PrefetchLocality, PrefetchStrategy};
pub use simd_float_ops::{rms_norm_simd, rope_apply_simd, silu_simd, softmax_simd, swiglu_simd};
pub use traits::{Fp8Kernel, OneBitKernel, TernaryKernel};
pub use tuning::{PlatformProfile, TunedThresholds, TuningSummary};
pub use weight_cache::GpuWeightHandle;