1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//! Kernel implementations: scalar reference, AVX2 SIMD, and CUDA PTX.
//!
//! Each submodule provides three variants of its kernel:
//! - `fn {name}_scalar(...)` — Pure Rust scalar reference (ground truth)
//! - `unsafe fn {name}_avx2(...)` — AVX2 SIMD implementation
//! - `fn {name}_ptx() -> &'static str` — PTX assembly source string
// Kernel code naturally uses single-character math variable names (m, n, k, q, v, etc.),
// raw string hashes for PTX assembly, and unsafe intrinsics inside unsafe fns.
//
// This is an inner attribute (`#![...]`), so the allowances cover this module and
// every submodule declared in it.
#![allow(
// Naming: short math-style identifiers, including ones that differ by one character.
clippy::many_single_char_names,
clippy::similar_names,
// Raw string literals written with `#` delimiters (the PTX sources, per the note above).
clippy::needless_raw_string_hashes,
// Function shape: long parameter lists and long bodies.
clippy::too_many_arguments,
clippy::too_many_lines,
// Numeric `as` casts that may lose precision, truncate, wrap, or drop the sign.
// NOTE(review): presumably index <-> float/int conversions in the kernels — confirm.
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
// Loop style: explicit `.iter()` loops and index-based `for i in 0..n` loops.
clippy::explicit_iter_loop,
clippy::needless_range_loop,
// Exact `==` / `!=` comparisons on floating-point values.
clippy::float_cmp,
// Glob imports (`use path::*`).
clippy::wildcard_imports,
// Doc comments mentioning identifiers without backticks.
clippy::doc_markdown,
// rustc lint (not Clippy): unsafe operations written directly in an `unsafe fn`
// body without their own `unsafe { }` block.
unsafe_op_in_unsafe_fn
)]
// Submodule declarations. The "Group A".."Group E" headings below are this file's
// own labels for organizing the kernel modules.

// Declared outside the lettered groups.
// NOTE(review): presumably shared support code rather than kernels (operation
// definitions, ULP error measurement) — confirm against the module sources.
pub mod ops;
pub mod ulp;
// Group A — Elementwise
pub mod activation;
pub mod silu_standalone;
// Group B — Normalization
pub mod batchnorm;
pub mod layernorm;
pub mod rmsnorm;
pub mod softmax;
// Group C — Gated + Positional + Loss
pub mod absolute_position;
pub mod bias_add;
pub mod cross_entropy;
pub mod dropout;
pub mod gelu;
pub mod rope;
pub mod swiglu;
// Group D — Matrix + Projection
pub mod attention;
pub mod flash_attention;
pub mod gqa;
pub mod linear;
pub mod matmul;
pub mod tied_embeddings;
pub mod transpose;
// Group E — Optimizer + Sequence + Classical ML + IO
pub mod adamw;
pub mod alibi;
pub mod cma_es;
pub mod conv1d;
pub mod embedding;
pub mod f16_convert;
pub mod gated_delta_net;
pub mod kmeans;
pub mod lbfgs;
pub mod pagerank;
pub mod sampling;
pub mod ssm;
// Private module, compiled only when the `kani` cfg flag is set; it is absent from
// ordinary builds. NOTE(review): presumably holds Kani proof harnesses — confirm.
#[cfg(kani)]
mod kani_proofs;
/// Backend selector for kernel dispatch.
///
/// Each variant names one implementation flavour of a kernel. The type is a
/// fieldless `Copy` enum, so it is cheap to pass by value and compare.
/// `Hash` is derived alongside `Eq` so a `Backend` can be used directly as a
/// `HashMap` key or `HashSet` element (e.g. per-backend lookup tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// Pure Rust scalar reference implementation.
    Scalar,
    /// x86-64 AVX2 SIMD implementation.
    Avx2,
    /// CUDA PTX kernel (returned as assembly source string).
    Ptx,
}