ferrum-kernels 0.7.2

Unified compute kernels (CUDA/Metal/CPU) and model runner for Ferrum inference
Documentation
//! Pre-compiled Triton-rs PTX + metadata embedded as static strings.
//!
//! These PTX files are generated by `tools/regen-triton-ptx.sh` on a
//! Linux+CUDA box that has triton-rs checked out as a sibling repo. The
//! bytes are committed to git so that downstream cargo builds don't need
//! the Triton C++ toolchain.
//!
//! Each kernel exposes a const PTX (text) and META (JSON with kernel name,
//! shared_mem, num_warps, global_scratch_size, profile_scratch_size).
//!
//! Only present when the `triton-kernels` feature is enabled.

// Compile this whole module only when the `triton-kernels` feature is enabled.
#![cfg(feature = "triton-kernels")]
// NOTE(review): presumably silences dead-code warnings for embedded kernels
// that no loader references in a given build — confirm this is still needed.
#![allow(dead_code)]

/// Embedded Triton-rs build artifacts for the `rms_norm_f32` kernel.
pub mod rms_norm_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/rms_norm_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/rms_norm_f32.ptx");
}

/// Embedded Triton-rs build artifacts for the `residual_add_f32` kernel.
pub mod residual_add_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/residual_add_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/residual_add_f32.ptx");
}

/// Embedded Triton-rs build artifacts for the `residual_add_inplace_f32` kernel.
pub mod residual_add_inplace_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/residual_add_inplace_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/residual_add_inplace_f32.ptx");
}

/// Embedded Triton-rs build artifacts for the `fused_silu_mul_f32` kernel.
pub mod fused_silu_mul_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/fused_silu_mul_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/fused_silu_mul_f32.ptx");
}

/// Embedded Triton-rs build artifacts for the `fused_add_rms_norm_f32` kernel.
pub mod fused_add_rms_norm_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/fused_add_rms_norm_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/fused_add_rms_norm_f32.ptx");
}

/// Embedded Triton-rs build artifacts for the `layer_norm_f32` kernel.
pub mod layer_norm_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/layer_norm_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/layer_norm_f32.ptx");
}

/// Embedded Triton-rs build artifacts for the `softmax_f32` kernel.
pub mod softmax_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/softmax_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/softmax_f32.ptx");
}

/// Embedded Triton-rs build artifacts for the `gelu_f32` kernel.
pub mod gelu_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/gelu_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/gelu_f32.ptx");
}

/// Embedded Triton-rs build artifacts for the `add_bias_f32` kernel.
pub mod add_bias_f32 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/add_bias_f32.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/add_bias_f32.ptx");
}

// ── f16 variants (used by the LLM decode path; Qwen / Llama / TinyLlama
//   all run with f16 weights). Since the kernels were made dtype-generic,
//   each of the 9 kernels is emitted under the Rust fn name `<X>_typed`;
//   the JSON `name` field reflects this, so loaders should read
//   `meta.name` rather than hard-coding the kernel name.

/// Embedded Triton-rs build artifacts for the f16 variant `rms_norm_f16`.
pub mod rms_norm_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/rms_norm_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/rms_norm_f16.ptx");
}

/// Embedded Triton-rs build artifacts for the f16 variant `residual_add_f16`.
pub mod residual_add_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/residual_add_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/residual_add_f16.ptx");
}

/// Embedded Triton-rs build artifacts for the f16 variant `residual_add_inplace_f16`.
pub mod residual_add_inplace_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/residual_add_inplace_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/residual_add_inplace_f16.ptx");
}

/// Embedded Triton-rs build artifacts for the f16 variant `fused_silu_mul_f16`.
pub mod fused_silu_mul_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/fused_silu_mul_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/fused_silu_mul_f16.ptx");
}

/// Embedded Triton-rs build artifacts for the f16 variant `fused_add_rms_norm_f16`.
pub mod fused_add_rms_norm_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/fused_add_rms_norm_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/fused_add_rms_norm_f16.ptx");
}

/// Embedded Triton-rs build artifacts for the f16 variant `layer_norm_f16`.
pub mod layer_norm_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/layer_norm_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/layer_norm_f16.ptx");
}

/// Embedded Triton-rs build artifacts for the f16 variant `softmax_f16`.
pub mod softmax_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/softmax_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/softmax_f16.ptx");
}

/// Embedded Triton-rs build artifacts for the f16 variant `gelu_f16`.
pub mod gelu_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/gelu_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/gelu_f16.ptx");
}

/// Embedded Triton-rs build artifacts for the f16 variant `add_bias_f16`.
pub mod add_bias_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/add_bias_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/add_bias_f16.ptx");
}

// decode_attention seq-major, HEAD_DIM=64 (TinyLlama / 1.1B-class Llamas)
// and HEAD_DIM=128 (Qwen3, Llama-3, …). Both PTX kernels handle arbitrary
// valid_kv_len (partial last KV tile gets per-position score masking).
/// Embedded Triton-rs build artifacts for the `decode_attention_f16_h64`
/// kernel (seq-major decode attention, HEAD_DIM=64).
pub mod decode_attention_f16_h64 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/decode_attention_f16_h64.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/decode_attention_f16_h64.ptx");
}

/// Embedded Triton-rs build artifacts for the `decode_attention_f16_h128`
/// kernel (seq-major decode attention, HEAD_DIM=128).
pub mod decode_attention_f16_h128 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/decode_attention_f16_h128.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/decode_attention_f16_h128.ptx");
}

// w4a16 GPTQ INT4-weight × FP16-act fused GEMM. Tile <BM=64, BN=64, BK=32>.
// Drop-in alternative to Marlin's `gemm_gptq` path; gated at load time via
// FERRUM_TRITON_INT4=1 so the existing Marlin path stays default.
/// Embedded Triton-rs build artifacts for the `w4a16_gptq_f16` kernel
/// (GPTQ INT4-weight × FP16-activation fused GEMM).
pub mod w4a16_gptq_f16 {
    /// JSON metadata for the kernel (kernel name, shared_mem, num_warps,
    /// global_scratch_size, profile_scratch_size), embedded at compile time.
    pub const META: &str = include_str!("../triton_ptx/w4a16_gptq_f16.json");

    /// Pre-compiled PTX text for the kernel, embedded at compile time.
    pub const PTX: &str = include_str!("../triton_ptx/w4a16_gptq_f16.ptx");
}