// Submodule declarations. None of these carries a cfg attribute here; only
// the glob re-exports further down are gated.
mod archive;
mod attention;
mod decode;
mod decode_ternary;
mod prefill;
mod utility;
// `archive` is re-exported when either the `metal` or the `cuda` feature is
// enabled; this gate has no target-OS condition.
// NOTE(review): with `metal` enabled on a non-macOS target, `archive` is
// re-exported but none of the modules below are — confirm this is intended.
#[cfg(any(feature = "metal", feature = "cuda"))]
pub use archive::*;
// Every remaining module is re-exported only when the `metal` feature is
// enabled AND the target OS is macOS.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use attention::*;
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use decode::*;
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use decode_ternary::*;
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use prefill::*;
#[cfg(all(feature = "metal", target_os = "macos"))]
pub use utility::*;
#[cfg(test)]
mod tests {
    // Pull in the re-exported kernel source constants; gated so the import is
    // not flagged as unused when neither backend is compiled in.
    #[cfg(any(all(feature = "metal", target_os = "macos"), feature = "cuda"))]
    use super::*;

    /// Checks a table of `SOURCE => "needle"` pairs, top to bottom.
    ///
    /// Each row expands to `assert!(SOURCE.contains("needle"));`, so the first
    /// row whose needle is absent fails the test at that point.
    #[cfg(any(all(feature = "metal", target_os = "macos"), feature = "cuda"))]
    macro_rules! assert_each_contains {
        ($($source:ident => $needle:literal),+ $(,)?) => {
            $( assert!($source.contains($needle)); )+
        };
    }

    /// Every Metal shader source must contain the `kernel void <name>`
    /// declaration text listed next to it.
    #[test]
    #[cfg(all(feature = "metal", target_os = "macos"))]
    fn metal_kernels_contain_entry_points() {
        assert_each_contains! {
            MSL_GEMV_Q1_G128 => "kernel void gemv_q1_g128",
            MSL_GEMV_Q1_G128_V2 => "kernel void gemv_q1_g128_v2",
            MSL_GEMV_Q1_G128_V3 => "kernel void gemv_q1_g128_v3",
            MSL_GEMM_Q1_G128 => "kernel void gemm_q1_g128",
            MSL_SOFTMAX => "kernel void softmax",
            MSL_RELU => "kernel void relu",
            MSL_RMSNORM => "kernel void rmsnorm",
            MSL_SILU => "kernel void silu",
            MSL_MATVEC_F32 => "kernel void matvec_f32",
            MSL_SWIGLU => "kernel void swiglu",
            MSL_SWIGLU_FUSED => "kernel void swiglu_fused",
            MSL_RESIDUAL_ADD => "kernel void residual_add",
            MSL_RMSNORM_WEIGHTED => "kernel void rmsnorm_weighted",
            MSL_RMSNORM_WEIGHTED_V2 => "kernel void rmsnorm_weighted_v2",
            MSL_BATCHED_RMSNORM_V2 => "kernel void batched_rmsnorm_v2",
            MSL_BATCHED_ATTENTION_SCORES => "kernel void batched_attention_scores",
            MSL_BATCHED_ATTENTION_SCORES_V2 => "kernel void batched_attention_scores_v2",
            MSL_BATCHED_SOFTMAX => "kernel void batched_softmax",
            MSL_BATCHED_ATTENTION_WEIGHTED_SUM => "kernel void batched_attention_weighted_sum",
            MSL_FUSED_QK_NORM => "kernel void fused_qk_norm",
            MSL_FUSED_QK_ROPE => "kernel void fused_qk_rope",
            MSL_FUSED_KV_STORE => "kernel void fused_kv_store",
            MSL_GEMV_Q1_G128_RESIDUAL => "kernel void gemv_q1_g128_residual",
            MSL_GEMV_Q1_G128_V7 => "kernel void gemv_q1_g128_v7",
            MSL_GEMV_Q1_G128_V7_RESIDUAL => "kernel void gemv_q1_g128_v7_residual",
            MSL_GEMV_Q1_G128_V8 => "kernel void gemv_q1_g128_v8",
            MSL_GEMV_Q1_G128_V8_RESIDUAL => "kernel void gemv_q1_g128_v8_residual",
            MSL_GEMV_Q1_G128_V9 => "kernel void gemv_q1_g128_v9",
            MSL_GEMV_Q1_G128_V9_RESIDUAL => "kernel void gemv_q1_g128_v9_residual",
            MSL_GEMV_Q1_G128_V10 => "kernel void gemv_q1_g128_v10",
            MSL_GEMV_Q1_G128_V10_RESIDUAL => "kernel void gemv_q1_g128_v10_residual",
            MSL_ARGMAX => "kernel void argmax",
            MSL_FUSED_GATE_UP_SWIGLU_Q1 => "kernel void fused_gate_up_swiglu_q1",
            MSL_GEMV_TQ2_G128_V1 => "kernel void gemv_tq2_g128_v1",
            MSL_GEMM_Q1_G128_V7 => "kernel void gemm_q1_g128_v7",
            MSL_GEMM_Q1_G128_V7_RESIDUAL => "kernel void gemm_q1_g128_v7_residual",
            MSL_FUSED_GATE_UP_SWIGLU_GEMM_Q1 => "kernel void fused_gate_up_swiglu_gemm_q1",
        }
    }

    /// Every CUDA kernel source must mention the name listed next to it.
    /// This is a plain substring check with no declaration prefix.
    #[test]
    #[cfg(feature = "cuda")]
    fn cuda_kernels_contain_entry_points() {
        assert_each_contains! {
            CUDA_GEMV_Q1_G128 => "gemv_q1_g128",
            CUDA_GEMM_Q1_G128 => "gemm_q1_g128",
            CUDA_SOFTMAX => "softmax",
            CUDA_RELU => "relu",
            CUDA_RMSNORM => "rmsnorm",
            CUDA_SILU => "silu",
            CUDA_MATVEC_F32 => "matvec_f32",
            CUDA_SWIGLU => "swiglu",
            CUDA_SWIGLU_FUSED => "swiglu_fused",
            CUDA_RESIDUAL_ADD => "residual_add",
            CUDA_RMSNORM_WEIGHTED => "rmsnorm_weighted",
        }
    }
}