//! zilla-muf 0.1.0
//!
//! Shared structured-matrix and numerical primitives for sparse attention and
//! state space models (SSMs).
use num_traits::Float;

/// Numerically stable softmax over a slice.
///
/// Every entry is shifted by the largest entry before exponentiation, so the
/// largest exponent evaluated is `exp(0)`; this keeps large-magnitude finite
/// inputs from overflowing. The shifted exponentials are then divided by
/// their sum.
///
/// Returns a vector with the same length as `x`; an empty slice yields an
/// empty vector. Inputs are not validated.
pub fn softmax<T: Float>(x: &[T]) -> Vec<T> {
	// Running maximum, seeded with -inf so the first entry always replaces it.
	let mut peak = T::neg_infinity();
	for &value in x {
		peak = peak.max(value);
	}

	// Single pass: store each shifted exponential and accumulate the
	// normaliser in input order.
	let mut weights = Vec::with_capacity(x.len());
	let mut total = T::zero();
	for &value in x {
		let weight = (value - peak).exp();
		total = total + weight;
		weights.push(weight);
	}

	for weight in weights.iter_mut() {
		*weight = *weight / total;
	}
	weights
}

/// Computes `ln(sum_i exp(x[i]))` without overflowing on large finite inputs.
///
/// The entries are shifted by their maximum before exponentiation, so the
/// largest exponent evaluated is `exp(0)`, and the shift is added back after
/// taking the logarithm.
///
/// Edge cases:
/// - An empty slice returns negative infinity (the log of an empty sum).
/// - If every entry is negative infinity, the result is negative infinity.
/// - If any entry is positive infinity (and none is NaN), the result is
///   positive infinity.
/// - A NaN entry makes the result NaN.
pub fn log_sum_exp<T: Float>(x: &[T]) -> T {
	let max = x.iter().copied().fold(T::neg_infinity(), T::max);
	// Shifting by an infinite maximum would evaluate `inf - inf` or
	// `-inf - -inf`, both NaN, and poison the whole result. In that case the
	// unshifted sum already yields the correct infinite answer, so only shift
	// when the maximum is finite.
	let shift = if max.is_finite() { max } else { T::zero() };
	let sum = x
		.iter()
		.map(|&v| (v - shift).exp())
		.fold(T::zero(), |acc, term| acc + term);
	shift + sum.ln()
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Textbook softmax with no max-shift. It overflows for large inputs,
	/// so it is only used on small-magnitude data as an independent
	/// reference rather than a copy of the implementation under test.
	fn naive_softmax(x: &[f64]) -> Vec<f64> {
		let total: f64 = x.iter().map(|v| v.exp()).sum();
		x.iter().map(|v| v.exp() / total).collect()
	}

	/// Textbook `ln(sum(exp(x)))` with no max-shift; small inputs only.
	fn naive_log_sum_exp(x: &[f64]) -> f64 {
		let total: f64 = x.iter().map(|v| v.exp()).sum();
		total.ln()
	}

	#[test]
	fn matches_naive_for_small_values() {
		let input: [f64; 3] = [1.0, 2.0, 3.0];
		let reference = naive_softmax(&input);
		let computed = softmax(&input);
		computed.iter().zip(&reference).for_each(|(s, n)| {
			assert!((s - n).abs() < 1e-12, "stable={s}, naive={n}");
		});
	}

	#[test]
	fn sums_to_one() {
		let input: [f64; 4] = [-3.0, 0.5, 2.0, 7.0];
		let total: f64 = softmax(&input).iter().sum();
		assert!((total - 1.0).abs() < 1e-9);
	}

	#[test]
	fn handles_large_values_without_overflow() {
		// exp(1000.0) is inf in f64, so an unshifted softmax would compute
		// inf / inf = NaN. The shifted version must give the uniform
		// distribution instead.
		let input: [f64; 3] = [1000.0; 3];
		for v in softmax(&input) {
			assert!(v.is_finite(), "expected finite, got {v}");
			assert!((v - 1.0 / 3.0).abs() < 1e-9);
		}
	}

	#[test]
	fn log_sum_exp_matches_naive_for_small_values() {
		let input: [f64; 3] = [1.0, 2.0, 3.0];
		let naive = naive_log_sum_exp(&input);
		let stable = log_sum_exp(&input);
		assert!((stable - naive).abs() < 1e-12, "stable={stable}, naive={naive}");
	}

	#[test]
	fn log_sum_exp_handles_large_values() {
		// An unshifted ln(sum(exp(x))) would compute ln(inf) = inf here.
		let input: [f64; 3] = [1000.0; 3];
		let value = log_sum_exp(&input);
		assert!(value.is_finite());
		let target = 1000.0 + 3.0_f64.ln();
		assert!((value - target).abs() < 1e-9);
	}
}