Skip to main content

provable_contracts/kernels/
mod.rs

//! Kernel implementations: scalar reference, AVX2 SIMD, and CUDA PTX.
//!
//! Each submodule provides three variants of its kernel:
//! - `fn {name}_scalar(...)` — Pure Rust scalar reference (ground truth)
//! - `unsafe fn {name}_avx2(...)` — AVX2 SIMD implementation
//! - `fn {name}_ptx() -> &'static str` — PTX assembly source string
7
// Kernel code naturally uses single-character math variable names (m, n, k, q, v, etc.),
// raw string hashes for PTX assembly, and unsafe intrinsics inside unsafe fns.
10#![allow(
11    clippy::many_single_char_names,
12    clippy::similar_names,
13    clippy::needless_raw_string_hashes,
14    clippy::too_many_arguments,
15    clippy::too_many_lines,
16    clippy::cast_precision_loss,
17    clippy::cast_possible_truncation,
18    clippy::cast_possible_wrap,
19    clippy::cast_sign_loss,
20    clippy::explicit_iter_loop,
21    clippy::needless_range_loop,
22    clippy::float_cmp,
23    clippy::wildcard_imports,
24    clippy::doc_markdown,
25    unsafe_op_in_unsafe_fn
26)]
27
28pub mod ops;
29pub mod ulp;
30
31// Group A — Elementwise
32pub mod activation;
33pub mod silu_standalone;
34
35// Group B — Normalization
36pub mod batchnorm;
37pub mod layernorm;
38pub mod rmsnorm;
39pub mod softmax;
40
41// Group C — Gated + Positional + Loss
42pub mod absolute_position;
43pub mod bias_add;
44pub mod cross_entropy;
45pub mod dropout;
46pub mod gelu;
47pub mod rope;
48pub mod swiglu;
49
50// Group D — Matrix + Projection
51pub mod attention;
52pub mod flash_attention;
53pub mod gqa;
54pub mod linear;
55pub mod matmul;
56pub mod tied_embeddings;
57pub mod transpose;
58
59// Group E — Optimizer + Sequence + Classical ML + IO
60pub mod adamw;
61pub mod alibi;
62pub mod cma_es;
63pub mod conv1d;
64pub mod embedding;
65pub mod f16_convert;
66pub mod gated_delta_net;
67pub mod kmeans;
68pub mod lbfgs;
69pub mod pagerank;
70pub mod sampling;
71pub mod ssm;
72
73#[cfg(kani)]
74mod kani_proofs;
75
/// Which implementation of a kernel to dispatch to.
///
/// The variants keep their declaration order (scalar, AVX2, PTX), which also
/// determines the discriminant values seen through `as` casts.
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum Backend {
    /// Scalar reference written in plain Rust; the ground truth for comparisons.
    Scalar,
    /// SIMD implementation targeting x86-64 AVX2.
    Avx2,
    /// CUDA PTX variant, exposed as an assembly source string.
    Ptx,
}