vyre-primitives 0.4.1

Compositional primitives for vyre — marker types (always on) + Tier 2.5 LEGO substrate (feature-gated per domain).
Documentation
//! Differentiable algorithm primitives — smoothed argmax + softmax.
//!
//! The breakthrough wasn't auto-diff — it was realizing that **classical
//! algorithms become NN building blocks once you smooth their argmaxes**
//! (Mensch-Blondel 2018, Berthet 2020, Petersen 2022). Every smoothed
//! classical algorithm is a primitive that participates in end-to-end
//! gradient flow.
//!
//! This file ships:
//! - [`crate::math::differentiable::softmax_step`] — `out[i] = exp(x[i] - max) / Σ exp(x[j] - max)`
//!   numerically stable, fixed-point.
//! - `differentiable_argmax` — uses softmax with high temperature
//!   for sharp argmax-like behavior. Forward and gradient agree on
//!   the soft assignment.
//!
//! # Why this primitive is dual-use
//!
//! | Consumer | Use |
//! |---|---|
//! | `vyre-libs::ml::transformers` consumers | attention softmax |
//! | `vyre-libs::ml::structured` consumers | differentiable beam search, top-k, sorting |
//! | `vyre-libs::parsing::c::sema` consumers | identifier→typedef gradient flow |
//! | `vyre-driver` autotuner consumers | smoothed argmax over the workgroup-size / tile-shape / fusion-threshold tuning grid; gradient descent picks the configuration |
//!
//! # Fixed-point convention
//!
//! u32 16.16 fixed-point with `EXP_TABLE_BITS` mantissa for the inner
//! exp lookup. The exp evaluator is a small Taylor expansion at u32
//! precision; for higher precision, callers compose this primitive
//! with their own f32-based exp evaluator and supply the pre-exp'd
//! values.
//!
//! Separate registered ops own differentiable sorting and top-k
//! selection; this module's contract is softmax plus differentiable
//! argmax.

use std::sync::Arc;

use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};

/// Op id.
pub const OP_ID: &str = "vyre-primitives::math::softmax_step";

/// Emit a numerically-stable softmax Program over a length-`n` vector.
///
/// Inputs:
/// - `x`: length-n u32 buffer (16.16 fixed-point logits, post-temperature scaling).
/// - `pre_exp`: caller-supplied length-n u32 buffer of `exp(x[i] - max)`
///   in 16.16 fixed-point. This file does not embed an exp evaluator
///   on GPU — the caller composes a separate exp Program (or uses a
///   precomputed lookup) before calling this.
///
/// Output:
/// - `out`: length-n u32 buffer of softmax probabilities (16.16).
///
/// Single-lane normalization (lane 0 walks the array, sums, then
/// writes normalized values back). For large n compose with
/// `reduce::sum` first to get the partition function, then a separate
/// elementwise divide (which may compose multiple of this primitive's
/// inner divides into one parallel pass).
#[must_use]
pub fn softmax_step(pre_exp: &str, out: &str, n: u32) -> Program {
    if n == 0 {
        return crate::invalid_output_program(
            OP_ID,
            out,
            DataType::U32,
            format!("Fix: softmax_step requires n > 0, got {n}."),
        );
    }

    let t = Expr::InvocationId { axis: 0 };
    let body = vec![Node::if_then(
        Expr::eq(t.clone(), Expr::u32(0)),
        vec![
            // sum = Σ pre_exp[i]
            Node::let_bind("sum", Expr::u32(0)),
            Node::loop_for(
                "i",
                Expr::u32(0),
                Expr::u32(n),
                vec![Node::assign(
                    "sum",
                    Expr::add(Expr::var("sum"), Expr::load(pre_exp, Expr::var("i"))),
                )],
            ),
            // sum_safe = max(sum, 1)
            Node::let_bind(
                "sum_safe",
                Expr::select(
                    Expr::eq(Expr::var("sum"), Expr::u32(0)),
                    Expr::u32(1),
                    Expr::var("sum"),
                ),
            ),
            // out[i] = (pre_exp[i] << 16) / sum_safe  (preserve 16.16
            // by shifting the numerator before dividing).
            Node::loop_for(
                "j",
                Expr::u32(0),
                Expr::u32(n),
                vec![Node::store(
                    out,
                    Expr::var("j"),
                    Expr::div(
                        Expr::shl(Expr::load(pre_exp, Expr::var("j")), Expr::u32(16)),
                        Expr::var("sum_safe"),
                    ),
                )],
            ),
        ],
    )];

    Program::wrapped(
        vec![
            BufferDecl::storage(pre_exp, 0, BufferAccess::ReadOnly, DataType::U32).with_count(n),
            BufferDecl::storage(out, 1, BufferAccess::ReadWrite, DataType::U32).with_count(n),
        ],
        [256, 1, 1],
        vec![Node::Region {
            generator: Ident::from(OP_ID),
            source_region: None,
            body: Arc::new(body),
        }],
    )
}

/// CPU reference: softmax in f64 for clarity.
#[must_use]
pub fn softmax_cpu(x: &[f64]) -> Vec<f64> {
    let mut out = Vec::new();
    softmax_cpu_into(x, &mut out);
    out
}

/// CPU reference: softmax in f64 using caller-owned output.
pub fn softmax_cpu_into(x: &[f64], out: &mut Vec<f64>) {
    out.clear();
    if x.is_empty() {
        return;
    }
    let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    out.reserve(x.len());
    let mut sum = 0.0;
    for &value in x {
        let exp = (value - max).exp();
        sum += exp;
        out.push(exp);
    }
    for value in out.iter_mut() {
        *value /= sum;
    }
}

/// CPU reference: differentiable argmax via temperature-scaled softmax.
/// Higher `temperature` → softer assignment; `temperature → 0+`
/// recovers hard argmax. Returns the soft-assignment vector that sums
/// to 1 (probability over indices).
#[must_use]
pub fn differentiable_argmax_cpu(x: &[f64], temperature: f64) -> Vec<f64> {
    let mut scaled = Vec::new();
    let mut out = Vec::new();
    differentiable_argmax_cpu_into(x, temperature, &mut scaled, &mut out);
    out
}

/// CPU reference: differentiable argmax using caller-owned scratch and output.
pub fn differentiable_argmax_cpu_into(
    x: &[f64],
    temperature: f64,
    scaled: &mut Vec<f64>,
    out: &mut Vec<f64>,
) {
    scaled.clear();
    out.clear();
    if temperature <= 0.0 || !temperature.is_finite() {
        return;
    }
    scaled.reserve(x.len());
    scaled.extend(x.iter().map(|&v| v / temperature));
    softmax_cpu_into(scaled, out);
}

#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10 * (1.0 + a.abs() + b.abs())
    }

    #[test]
    fn cpu_softmax_uniform_input_is_uniform_output() {
        let x = vec![1.0, 1.0, 1.0, 1.0];
        let out = softmax_cpu(&x);
        for v in out {
            assert!(approx_eq(v, 0.25));
        }
    }

    #[test]
    fn cpu_softmax_sums_to_one() {
        let x = vec![0.5, 1.0, 1.5, 2.0, 2.5];
        let out = softmax_cpu(&x);
        let s: f64 = out.iter().sum();
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn cpu_softmax_monotone_in_input() {
        let x = vec![0.0, 1.0, 2.0];
        let out = softmax_cpu(&x);
        assert!(out[0] < out[1]);
        assert!(out[1] < out[2]);
    }

    #[test]
    fn cpu_softmax_handles_large_inputs_no_overflow() {
        // Without max-subtraction, exp(1000) would overflow.
        let x = vec![1000.0, 1000.0, 1000.0];
        let out = softmax_cpu(&x);
        for v in out {
            assert!(v.is_finite());
            assert!(approx_eq(v, 1.0 / 3.0));
        }
    }

    #[test]
    fn cpu_diff_argmax_low_temp_concentrates() {
        // Low temperature → soft argmax should concentrate on the max.
        let x = vec![1.0, 5.0, 2.0];
        let probs = differentiable_argmax_cpu(&x, 0.001);
        // probs[1] should be very close to 1.
        assert!(probs[1] > 0.99);
        assert!(probs[0] < 0.01);
        assert!(probs[2] < 0.01);
    }

    #[test]
    fn cpu_diff_argmax_high_temp_uniform() {
        // High temperature → soft argmax should be near-uniform.
        let x = vec![1.0, 5.0, 2.0];
        let probs = differentiable_argmax_cpu(&x, 1000.0);
        for v in probs {
            assert!((v - 1.0 / 3.0).abs() < 0.01);
        }
    }

    #[test]
    fn cpu_diff_argmax_sums_to_one() {
        let x = vec![0.5, 1.0, 1.5, 2.0];
        let probs = differentiable_argmax_cpu(&x, 1.0);
        let s: f64 = probs.iter().sum();
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn cpu_diff_argmax_into_reuses_buffers() {
        let x = vec![1.0, 5.0, 2.0];
        let mut scaled = Vec::with_capacity(8);
        let mut out = Vec::with_capacity(8);
        let scaled_ptr = scaled.as_ptr();
        let out_ptr = out.as_ptr();
        differentiable_argmax_cpu_into(&x, 1000.0, &mut scaled, &mut out);
        assert_eq!(scaled.as_ptr(), scaled_ptr);
        assert_eq!(out.as_ptr(), out_ptr);
        let s: f64 = out.iter().sum();
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn ir_program_buffer_layout() {
        let p = softmax_step("e", "out", 32);
        assert_eq!(p.workgroup_size, [256, 1, 1]);
        let names: Vec<&str> = p.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["e", "out"]);
        for buf in p.buffers.iter() {
            assert_eq!(buf.count(), 32);
        }
    }

    #[test]
    fn zero_n_traps() {
        let p = softmax_step("e", "out", 0);
        assert!(p.stats().trap());
    }
}