use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Operation identifier for the softmax step. Within this file it is used as the
/// generator name of the emitted region and as the op id handed to
/// `crate::invalid_output_program` when the arguments are rejected.
pub const OP_ID: &str = "vyre-primitives::math::softmax_step";
/// Builds the IR [`Program`] that normalizes `n` pre-exponentiated values.
///
/// The program declares two `u32` storage buffers with `n` elements each:
/// `pre_exp` at binding 0 (read-only) and `out` at binding 1 (read-write), and
/// uses a workgroup size of `[256, 1, 1]`. The whole body is guarded so that only
/// the invocation whose id on axis 0 equals 0 runs it. That invocation:
///
/// 1. accumulates the `pre_exp` elements into a local `sum` with a loop over `i`
///    bounded by `0` and `n`;
/// 2. binds `sum_safe`, which is `1` when `sum` is zero and `sum` otherwise, so the
///    division below never has a zero divisor;
/// 3. with a loop over `j` bounded by `0` and `n`, stores
///    `(pre_exp[j] << 16) / sum_safe` into `out[j]`.
///
/// If `n == 0`, the result of `crate::invalid_output_program` is returned instead.
///
/// NOTE(review): the left shift by 16 looks like a 16.16 fixed-point output scale
/// and presumably relies on `pre_exp` values (and their sum) staying small enough
/// not to wrap in `u32` — confirm against the producers of `pre_exp`.
#[must_use]
pub fn softmax_step(pre_exp: &str, out: &str, n: u32) -> Program {
    if n == 0 {
        let message = format!("Fix: softmax_step requires n > 0, got {n}.");
        return crate::invalid_output_program(OP_ID, out, DataType::U32, message);
    }

    // Pass 1: sum += pre_exp[i].
    let running_total = Expr::add(Expr::var("sum"), Expr::load(pre_exp, Expr::var("i")));
    let accumulate = Node::loop_for(
        "i",
        Expr::u32(0),
        Expr::u32(n),
        vec![Node::assign("sum", running_total)],
    );

    // Replace a zero sum by one so the division in pass 2 has a non-zero divisor.
    let divisor = Expr::select(
        Expr::eq(Expr::var("sum"), Expr::u32(0)),
        Expr::u32(1),
        Expr::var("sum"),
    );
    let bind_divisor = Node::let_bind("sum_safe", divisor);

    // Pass 2: out[j] = (pre_exp[j] << 16) / sum_safe.
    let shifted = Expr::shl(Expr::load(pre_exp, Expr::var("j")), Expr::u32(16));
    let quotient = Expr::div(shifted, Expr::var("sum_safe"));
    let normalize = Node::loop_for(
        "j",
        Expr::u32(0),
        Expr::u32(n),
        vec![Node::store(out, Expr::var("j"), quotient)],
    );

    // Only the invocation with id 0 on axis 0 executes the serial work above.
    let is_lead_invocation = Expr::eq(Expr::InvocationId { axis: 0 }, Expr::u32(0));
    let body = vec![Node::if_then(
        is_lead_invocation,
        vec![
            Node::let_bind("sum", Expr::u32(0)),
            accumulate,
            bind_divisor,
            normalize,
        ],
    )];

    let buffers = vec![
        BufferDecl::storage(pre_exp, 0, BufferAccess::ReadOnly, DataType::U32).with_count(n),
        BufferDecl::storage(out, 1, BufferAccess::ReadWrite, DataType::U32).with_count(n),
    ];
    let region = Node::Region {
        generator: Ident::from(OP_ID),
        source_region: None,
        body: Arc::new(body),
    };
    Program::wrapped(buffers, [256, 1, 1], vec![region])
}
/// CPU reference softmax that returns a freshly allocated vector.
///
/// Convenience wrapper around [`softmax_cpu_into`]: it creates an empty `Vec`,
/// lets `softmax_cpu_into` fill it from `x`, and returns it. An empty `x` yields
/// an empty vector.
#[must_use]
pub fn softmax_cpu(x: &[f64]) -> Vec<f64> {
    let mut probabilities = Vec::new();
    softmax_cpu_into(x, &mut probabilities);
    probabilities
}
/// CPU reference softmax that writes into a caller-provided buffer.
///
/// `out` is always cleared first. For an empty `x` it is left empty. Otherwise it
/// ends up holding one value per element of `x`: `exp(x[i] - max)` divided by the
/// sum of those exponentials, where `max` is the largest element of `x`.
/// Subtracting the maximum keeps every exponent argument at or below zero for
/// finite inputs, so large inputs do not overflow `exp`.
///
/// The existing allocation of `out` is reused when its capacity is sufficient.
pub fn softmax_cpu_into(x: &[f64], out: &mut Vec<f64>) {
    out.clear();
    if x.is_empty() {
        return;
    }

    let peak = x.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    // Stage the shifted exponentials directly in the output buffer.
    out.reserve(x.len());
    out.extend(x.iter().map(|&logit| (logit - peak).exp()));

    // Left-to-right accumulation starting from 0.0.
    let total = out.iter().fold(0.0, |acc, &term| acc + term);
    out.iter_mut().for_each(|term| *term /= total);
}
/// Temperature-scaled softmax over `x`, returned as a new vector.
///
/// Allocating wrapper around [`differentiable_argmax_cpu_into`]: it supplies a
/// temporary scratch vector and an output vector, and returns the output. The
/// result is empty whenever `differentiable_argmax_cpu_into` leaves its output
/// empty (for example when `temperature` is not a positive finite number).
#[must_use]
pub fn differentiable_argmax_cpu(x: &[f64], temperature: f64) -> Vec<f64> {
    let mut scratch = Vec::new();
    let mut probabilities = Vec::new();
    differentiable_argmax_cpu_into(x, temperature, &mut scratch, &mut probabilities);
    probabilities
}
/// Temperature-scaled softmax over `x` using caller-provided buffers.
///
/// Both `scaled` and `out` are cleared on entry. If `temperature` is not a finite
/// value strictly greater than zero (this includes NaN and infinities), the
/// function returns with both buffers empty. Otherwise `scaled` is filled with
/// `x[i] / temperature` and `out` receives the result of [`softmax_cpu_into`]
/// applied to `scaled`.
///
/// Existing allocations of both buffers are reused when their capacity suffices.
pub fn differentiable_argmax_cpu_into(
    x: &[f64],
    temperature: f64,
    scaled: &mut Vec<f64>,
    out: &mut Vec<f64>,
) {
    scaled.clear();
    out.clear();

    // NaN fails the `>` comparison, so it is rejected together with
    // zero, negative and infinite temperatures.
    let temperature_is_usable = temperature.is_finite() && temperature > 0.0;
    if !temperature_is_usable {
        return;
    }

    scaled.reserve(x.len());
    for &logit in x {
        scaled.push(logit / temperature);
    }
    softmax_cpu_into(scaled, out);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance comparison: the allowed gap scales with the magnitude of both
    /// operands on top of a fixed `1e-10` floor.
    fn approx_eq(a: f64, b: f64) -> bool {
        let tolerance = 1e-10 * (1.0 + a.abs() + b.abs());
        (a - b).abs() < tolerance
    }

    // Equal inputs must share the probability mass equally.
    #[test]
    fn cpu_softmax_uniform_input_is_uniform_output() {
        let probs = softmax_cpu(&[1.0, 1.0, 1.0, 1.0]);
        assert!(probs.iter().all(|&p| approx_eq(p, 0.25)));
    }

    // The outputs form a probability distribution.
    #[test]
    fn cpu_softmax_sums_to_one() {
        let probs = softmax_cpu(&[0.5, 1.0, 1.5, 2.0, 2.5]);
        let total: f64 = probs.iter().sum();
        assert!(approx_eq(total, 1.0));
    }

    // Larger inputs map to strictly larger probabilities.
    #[test]
    fn cpu_softmax_monotone_in_input() {
        let probs = softmax_cpu(&[0.0, 1.0, 2.0]);
        assert!(probs[0] < probs[1]);
        assert!(probs[1] < probs[2]);
    }

    // exp(1000) overflows f64; the max-shift must keep results finite.
    #[test]
    fn cpu_softmax_handles_large_inputs_no_overflow() {
        let probs = softmax_cpu(&[1000.0, 1000.0, 1000.0]);
        assert!(probs.iter().all(|p| p.is_finite()));
        assert!(probs.iter().all(|&p| approx_eq(p, 1.0 / 3.0)));
    }

    // A tiny temperature pushes nearly all mass onto the largest input.
    #[test]
    fn cpu_diff_argmax_low_temp_concentrates() {
        let logits = [1.0, 5.0, 2.0];
        let probs = differentiable_argmax_cpu(&logits, 0.001);
        assert!(probs[1] > 0.99);
        assert!(probs[0] < 0.01);
        assert!(probs[2] < 0.01);
    }

    // A huge temperature flattens the distribution towards uniform.
    #[test]
    fn cpu_diff_argmax_high_temp_uniform() {
        let logits = [1.0, 5.0, 2.0];
        let probs = differentiable_argmax_cpu(&logits, 1000.0);
        assert!(probs.iter().all(|&p| (p - 1.0 / 3.0).abs() < 0.01));
    }

    #[test]
    fn cpu_diff_argmax_sums_to_one() {
        let probs = differentiable_argmax_cpu(&[0.5, 1.0, 1.5, 2.0], 1.0);
        let total: f64 = probs.iter().sum();
        assert!(approx_eq(total, 1.0));
    }

    // Pre-sized buffers must not be reallocated by the `_into` variant.
    #[test]
    fn cpu_diff_argmax_into_reuses_buffers() {
        let logits = [1.0, 5.0, 2.0];
        let mut scratch: Vec<f64> = Vec::with_capacity(8);
        let mut probs: Vec<f64> = Vec::with_capacity(8);
        let scratch_before = scratch.as_ptr();
        let probs_before = probs.as_ptr();

        differentiable_argmax_cpu_into(&logits, 1000.0, &mut scratch, &mut probs);

        assert_eq!(scratch.as_ptr(), scratch_before);
        assert_eq!(probs.as_ptr(), probs_before);
        let total: f64 = probs.iter().sum();
        assert!(approx_eq(total, 1.0));
    }

    // Checks workgroup size, buffer order and element counts of the IR program.
    #[test]
    fn ir_program_buffer_layout() {
        let program = softmax_step("e", "out", 32);
        assert_eq!(program.workgroup_size, [256, 1, 1]);

        let names: Vec<&str> = program.buffers.iter().map(|b| b.name()).collect();
        assert_eq!(names, vec!["e", "out"]);

        for decl in program.buffers.iter() {
            assert_eq!(decl.count(), 32);
        }
    }

    // n == 0 is rejected by returning a program whose stats report a trap.
    #[test]
    fn zero_n_traps() {
        let program = softmax_step("e", "out", 0);
        assert!(program.stats().trap());
    }
}