Skip to main content

ferrum_kernels/
lib.rs

1//! Ferrum unified compute kernels for high-performance inference.
2//!
3//! Provides the `Backend` trait and implementations for CUDA, Metal, and CPU.
4//! On CUDA builds, kernels are compiled to PTX during `cargo build` and loaded
5//! on demand at runtime.
6
7pub mod backend;
8
9pub mod linear;
10pub use linear::Linear;
11
12pub mod moe_host;
13
14#[cfg(all(target_os = "macos", feature = "metal"))]
15pub mod moe_post_ops;
16#[cfg(all(target_os = "macos", feature = "metal"))]
17pub mod moe_post_ops_batched;
18#[cfg(all(target_os = "macos", feature = "metal"))]
19pub mod moe_router;
20#[cfg(all(target_os = "macos", feature = "metal"))]
21pub mod q4_k;
22#[cfg(all(target_os = "macos", feature = "metal"))]
23pub mod q4_k_gemm;
24#[cfg(all(target_os = "macos", feature = "metal"))]
25pub mod q4_k_gemv;
26#[cfg(all(target_os = "macos", feature = "metal"))]
27pub mod q4_k_gemv_v2;
28#[cfg(all(target_os = "macos", feature = "metal"))]
29pub mod q4_k_moe_id_gate_up_silu;
30#[cfg(all(target_os = "macos", feature = "metal"))]
31pub mod q4_k_moe_id_gate_up_silu_batched;
32#[cfg(all(target_os = "macos", feature = "metal"))]
33pub mod q4_k_moe_id_gemm;
34#[cfg(all(target_os = "macos", feature = "metal"))]
35pub mod q4_k_moe_id_gemv;
36#[cfg(all(target_os = "macos", feature = "metal"))]
37pub mod q4_k_moe_id_gemv_batched;
38#[cfg(all(target_os = "macos", feature = "metal"))]
39pub mod q6_k_gemm;
40#[cfg(all(target_os = "macos", feature = "metal"))]
41pub mod q6_k_gemv;
42#[cfg(all(target_os = "macos", feature = "metal"))]
43pub mod q6_k_moe_id_gemm;
44#[cfg(all(target_os = "macos", feature = "metal"))]
45pub mod q6_k_moe_id_gemv;
46#[cfg(all(target_os = "macos", feature = "metal"))]
47pub mod q6_k_moe_id_gemv_batched;
48
// Crate-private module wrapping the build-generated `ptx.rs`. It is compiled
// only when the `cuda` feature is enabled, so non-CUDA builds never reference
// the generated file.
49#[cfg(feature = "cuda")]
50pub(crate) mod ptx {
51    // Generated by build.rs from all .cu sources. Some kernels (e.g.
52    // SOFTMAX, BATCHED_FLASH_DECODE_ATTENTION) are emitted unconditionally
53    // but only loaded behind specific code paths, so dead_code fires in
54    // configs that don't hit them.
55    #![allow(dead_code)]
    // Textually splice in `ptx.rs` from the build-script output directory
    // (`OUT_DIR`, resolved at compile time via `env!`). The generated items
    // are not visible in this file — presumably one item per kernel; confirm
    // against build.rs.
56    include!(concat!(env!("OUT_DIR"), "/ptx.rs"));
57}
58
59#[cfg(feature = "cuda")]
60mod fused_add_rms_norm;
61#[cfg(feature = "cuda")]
62pub use fused_add_rms_norm::fused_add_rms_norm;
63
64#[cfg(feature = "cuda")]
65mod fused_silu_mul;
66#[cfg(feature = "cuda")]
67pub use fused_silu_mul::fused_silu_mul;
68
69#[cfg(feature = "cuda")]
70mod rms_norm;
71#[cfg(feature = "cuda")]
72pub use rms_norm::rms_norm;
73
74#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
75mod triton_meta;
76#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
77mod triton_ptx;
78#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
79mod triton_rms_norm;
80#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
81pub use triton_rms_norm::rms_norm_triton;
82
83#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
84mod triton_residual_add;
85#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
86pub use triton_residual_add::residual_add_triton;
87
88#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
89mod triton_residual_add_inplace;
90#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
91pub use triton_residual_add_inplace::residual_add_inplace_triton;
92
93#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
94mod triton_fused_silu_mul;
95#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
96pub use triton_fused_silu_mul::fused_silu_mul_triton;
97
98#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
99mod triton_fused_add_rms_norm;
100#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
101pub use triton_fused_add_rms_norm::fused_add_rms_norm_triton;
102
103#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
104mod triton_layer_norm;
105#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
106pub use triton_layer_norm::layer_norm_triton;
107
108#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
109mod triton_softmax;
110#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
111pub use triton_softmax::softmax_triton;
112
113#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
114mod triton_gelu;
115#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
116pub use triton_gelu::gelu_triton;
117
118#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
119mod triton_add_bias;
120#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
121pub use triton_add_bias::add_bias_triton;
122
123#[cfg(all(feature = "cuda", feature = "triton-kernels"))]
124pub mod triton_w4a16;
125
126#[cfg(feature = "cuda")]
127mod rope;
128#[cfg(feature = "cuda")]
129pub use rope::rope;
130
131#[cfg(feature = "cuda")]
132mod decode_attention;
133#[cfg(feature = "cuda")]
134pub use decode_attention::decode_attention;
135
136#[cfg(feature = "cuda")]
137mod residual_add;
138#[cfg(feature = "cuda")]
139pub use residual_add::residual_add;
140
141#[cfg(feature = "cuda")]
142pub mod cublas;
143
144#[cfg(feature = "cuda")]
145pub mod decode_buffers;
146
147#[cfg(feature = "cuda")]
148pub mod weight_store;
149
150#[cfg(feature = "cuda")]
151pub mod cuda_graph;
152
153#[cfg(feature = "cuda")]
154pub mod quant;
155
156#[cfg(feature = "cuda")]
157pub mod marlin;
158
159#[cfg(feature = "cuda")]
160pub mod gpu_paged_kv;
161
162#[cfg(feature = "cuda")]
163pub mod cuda_decode;
164
165#[cfg(feature = "cuda")]
166pub mod nccl_comm;
167
168#[cfg(feature = "cuda")]
169pub mod tp_decode;