---
# Document metadata: provenance, authorship, and the issue/paper trail
# motivating this contract (GH-384 performance regression vs ndarray).
metadata:
  version: "1.0.0"
  created: "2026-03-02"
  author: "PAIML Engineering"
  description: "Correctness and performance specification for softmax kernels"
  references:
    - "GH-384: Softmax 1.6x slower than ndarray"
    - "Bridle (1990). Training Stochastic Model Recognition Algorithms as Networks can Lead to Maximum Mutual Information Estimation of Parameters"
    - "Milakov & Gimelshein (2018). Online normalizer calculation for softmax. arXiv:1805.02867"
  issues:
    - "https://github.com/paiml/aprender/issues/384"
# Mathematical definitions the kernels must satisfy. Each equation lists
# its formula, domain, and the algebraic properties checked by the
# FALSIFY-SM-* contract tests.
equations:
  softmax:
    formula: "softmax(x)_i = exp(x_i - max(x)) / Σ_j exp(x_j - max(x))"
    domain: "x ∈ R^n, n ≥ 1"
    properties:
      - probability_simplex: "Σ_i softmax(x)_i = 1.0"
      - non_negative: "softmax(x)_i >= 0 ∀i"
      - monotone: "x_i > x_j → softmax(x)_i > softmax(x)_j"
      - shift_invariant: "softmax(x + c) == softmax(x) ∀c ∈ R"
      - max_dominance: "as x_i → ∞, softmax(x)_i → 1.0"
  log_softmax:
    formula: "log_softmax(x)_i = x_i - max(x) - log(Σ_j exp(x_j - max(x)))"
    domain: "x ∈ R^n, n ≥ 1"
    properties:
      - negative: "log_softmax(x)_i ≤ 0 ∀i"
      - log_sum_one: "Σ_i exp(log_softmax(x)_i) = 1.0"
      - consistency: "log_softmax(x) == log(softmax(x)) (to float precision)"
# Required implementation shape. The Rust snippet under `pattern` is a
# literal block scalar (|): newlines are preserved so it can be diffed
# directly against the kernel source.
implementation:
  three_pass_inline:
    description: "Tensor softmax MUST use 3-pass inline pattern (no per-row allocation)"
    pattern: |
      // Pass 1: max — auto-vectorizable
      let mut max_val = f32::NEG_INFINITY;
      for &v in row { max_val = max_val.max(v); }
      // Pass 2: exp + sum — auto-vectorizable
      let mut sum = 0.0f32;
      for i in 0..n { let e = (row[i] - max_val).exp(); out[i] = e; sum += e; }
      // Pass 3: normalize — auto-vectorizable
      let inv_sum = 1.0 / sum;
      for i in 0..n { out[i] *= inv_sum; }
    rationale: "Avoids 2 intermediate Vec allocations per row that softmax_1d creates"
  numerical_stability:
    description: "MUST subtract max before exp to prevent overflow"
    assertion: "No intermediate value exceeds exp(88.7) ≈ f32::MAX"
  one_path_rule:
    description: "All Tensor softmax delegates to nn::functional::softmax (UCBD §4)"
    canonical_source: "src/nn/functional.rs::softmax"
# Performance bounds, expressed as throughput ratios vs the ndarray
# reference implementation (1.0 = parity; higher is better). `min_ratio`
# gates CI; `target_ratio` is aspirational. Ratios are true floats and
# intentionally unquoted; dates are quoted so parsers keep them as strings.
performance:
  benchmark_crate: "aprender-bench-compute"
  benchmark_file: "benches/softmax.rs"
  reference: "ndarray 3-pass scalar implementation"
  bounds:
    softmax_32000:
      min_ratio_vs_ndarray: 0.55
      target_ratio: 0.75
      measured_ratio: 0.64
      measured_date: "2026-03-02"
    softmax_128256:
      min_ratio_vs_ndarray: 0.70
      target_ratio: 0.90
      measured_ratio: 0.87
      measured_date: "2026-03-02"
    softmax_1d_4096:
      min_ratio_vs_ndarray: 0.70
      target_ratio: 0.90
      measured_ratio: 0.87
      measured_date: "2026-03-02"
# Falsification tests: each entry maps a FALSIFY-SM id to the property it
# attempts to violate and its last recorded status in the contract suite.
falsification:
  tests_file: "tests/contracts/softmax_contract.rs"
  FALSIFY-SM-001:
    name: "Probability simplex"
    assertion: "sum(softmax(x)) ≈ 1.0 (within 1e-5)"
    status: "PASS"
  FALSIFY-SM-002:
    name: "Non-negativity"
    assertion: "∀i: softmax(x)[i] >= 0"
    status: "PASS"
  FALSIFY-SM-003:
    name: "Shift invariance"
    assertion: "softmax(x + c) ≈ softmax(x) (within 1e-5)"
    status: "PASS"
  FALSIFY-SM-004:
    name: "Numerical stability"
    assertion: "softmax([1000, 1001, 1002]) produces valid probabilities (no NaN/Inf)"
    status: "PASS"
  FALSIFY-SM-005:
    name: "Log-softmax consistency"
    assertion: "exp(log_softmax(x)) ≈ softmax(x) (within 1e-5)"
    status: "PASS"
  FALSIFY-SM-006:
    name: "Monotonicity"
    assertion: "x[i] > x[j] → softmax(x)[i] > softmax(x)[j]"
    status: "PASS"
# QA gate: the aggregate pass/fail criterion combining the implementation,
# performance, and falsification sections above.
qa_gate:
  id: "F-SOFTMAX-001"
  name: "Softmax Kernel Contract"
  description: "Validates softmax correctness and performance bounds"
  checks:
    - "Tensor softmax uses 3-pass inline pattern (no softmax_1d delegation)"
    - "Max subtraction before exp in all paths"
    - "Benchmark ratio >= min_ratio for all measured bounds"
    - "All FALSIFY tests pass"
  pass_criteria: "All checks pass"