ferrum-kernels 0.7.1

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
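//! Ferrum unified compute kernels for high-performance inference.
//!
//! Provides the `Backend` trait and implementations for CUDA, Metal, and CPU.
//! With the `cuda` feature enabled, kernels are compiled to PTX during
//! `cargo build` and loaded on demand at runtime. Metal kernels are built only
//! on macOS with the `metal` feature, and the Triton-generated CUDA kernels
//! additionally require the `triton-kernels` feature.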

// Modules compiled unconditionally (no `cfg` gate on any target or feature set).
pub mod backend;
pub mod linear;
pub mod moe_host;

pub use self::linear::Linear;

// Metal kernel modules. Every declaration below is gated on macOS *and* the
// `metal` feature, so none of these modules exist on other targets or in
// builds without that feature. All of them are public.
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod moe_post_ops;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod moe_post_ops_batched;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod moe_router;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k_gemm;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k_gemv;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k_gemv_v2;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k_moe_id_gate_up_silu;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k_moe_id_gate_up_silu_batched;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k_moe_id_gemm;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k_moe_id_gemv;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q4_k_moe_id_gemv_batched;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q6_k_gemm;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q6_k_gemv;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q6_k_moe_id_gemm;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q6_k_moe_id_gemv;
#[cfg(all(target_os = "macos", feature = "metal"))]
pub mod q6_k_moe_id_gemv_batched;

// Crate-internal PTX table, generated by build.rs from every .cu source and
// spliced in from OUT_DIR. Some kernels (e.g. SOFTMAX,
// BATCHED_FLASH_DECODE_ATTENTION) are always emitted but only loaded on
// particular code paths, so `dead_code` is silenced for the configurations
// that never reference them.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
pub(crate) mod ptx {
    include!(concat!(env!("OUT_DIR"), "/ptx.rs"));
}

// CUDA kernel modules (only with the `cuda` feature). The modules themselves
// stay private; only their same-named items are re-exported at the crate root.
#[cfg(feature = "cuda")]
mod fused_add_rms_norm;
#[cfg(feature = "cuda")]
mod fused_silu_mul;
#[cfg(feature = "cuda")]
mod rms_norm;

#[cfg(feature = "cuda")]
pub use self::{
    fused_add_rms_norm::fused_add_rms_norm, fused_silu_mul::fused_silu_mul, rms_norm::rms_norm,
};

// Triton-generated CUDA kernels: compiled only when both the `cuda` and
// `triton-kernels` features are enabled. `triton_meta` and `triton_ptx` are
// private and re-export nothing. Each remaining private module contributes a
// single `*_triton` item to the crate root. `triton_w4a16` is exposed as a
// public module instead.
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_meta;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_ptx;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_rms_norm;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_residual_add;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_residual_add_inplace;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_fused_silu_mul;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_fused_add_rms_norm;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_layer_norm;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_softmax;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_gelu;
#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
mod triton_add_bias;

#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
pub mod triton_w4a16;

#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
pub use self::{
    triton_add_bias::add_bias_triton,
    triton_fused_add_rms_norm::fused_add_rms_norm_triton,
    triton_fused_silu_mul::fused_silu_mul_triton,
    triton_gelu::gelu_triton,
    triton_layer_norm::layer_norm_triton,
    triton_residual_add::residual_add_triton,
    triton_residual_add_inplace::residual_add_inplace_triton,
    triton_rms_norm::rms_norm_triton,
    triton_softmax::softmax_triton,
};

// CUDA `rope`, `decode_attention` and `residual_add` modules (only with the
// `cuda` feature). The modules stay private; only their same-named items are
// re-exported at the crate root.
#[cfg(feature = "cuda")]
mod rope;
#[cfg(feature = "cuda")]
mod decode_attention;
#[cfg(feature = "cuda")]
mod residual_add;

#[cfg(feature = "cuda")]
pub use self::{decode_attention::decode_attention, residual_add::residual_add, rope::rope};

// Public modules compiled only with the `cuda` feature. They are exposed as
// whole modules rather than through crate-root re-exports.
#[cfg(feature = "cuda")]
pub mod cublas;

#[cfg(feature = "cuda")]
pub mod decode_buffers;

#[cfg(feature = "cuda")]
pub mod weight_store;

#[cfg(feature = "cuda")]
pub mod cuda_graph;

#[cfg(feature = "cuda")]
pub mod quant;

#[cfg(feature = "cuda")]
pub mod marlin;

#[cfg(feature = "cuda")]
pub mod gpu_paged_kv;

#[cfg(feature = "cuda")]
pub mod cuda_decode;

#[cfg(feature = "cuda")]
pub mod nccl_comm;

#[cfg(feature = "cuda")]
pub mod tp_decode;