zilla-muf 0.1.1

Shared structured-matrix and numerical primitives for sparse attention and state space models (SSMs).
Documentation
use num_traits::Float;

/// Softmax of a slice, evaluated in an overflow-safe way.
///
/// Mathematically this is `exp(x_i) / sum_j exp(x_j)`. Before
/// exponentiating, the largest element is subtracted from every input.
/// The common factor `exp(-max)` shows up in both the numerator and the
/// denominator and cancels, so the result is unchanged, but no exponent
/// is ever positive and `exp()` cannot overflow to `+inf`. Without the
/// shift, large inputs produce `inf / inf = NaN`; the
/// `handles_large_values_without_overflow` test covers that case.
///
/// # Returns
///
/// A vector with one entry per element of `x`; an empty slice gives an
/// empty vector. For finite inputs the entries sum to 1 up to rounding.
/// A NaN input, or a maximum of `+inf` / `-inf` (where the shift becomes
/// `inf - inf`), makes every output NaN.
///
/// Cost: O(n) time, one allocation of `x.len()` elements.
///
/// # Example
///
/// ```
/// use zilla_muf::stable_ops::softmax;
/// let p = softmax(&[1.0_f64, 2.0, 3.0]);
/// assert!((p.iter().sum::<f64>() - 1.0).abs() < 1e-9);
/// assert!(p[0] < p[1] && p[1] < p[2]); // monotone with input
/// ```
pub fn softmax<T: Float>(x: &[T]) -> Vec<T> {
	if x.is_empty() {
		return Vec::new();
	}

	// The largest input becomes the reference point: after the shift it
	// maps to exp(0) = 1 and every other exponent is <= 0.
	let peak = x.iter().fold(T::neg_infinity(), |best, &v| best.max(v));

	// Exponentiate the shifted values and accumulate the normalizer in
	// the same left-to-right pass. Each term lies in [0, 1]; very
	// negative shifts may underflow to 0.
	let mut total = T::zero();
	let mut out = Vec::with_capacity(x.len());
	for &v in x {
		let e = (v - peak).exp();
		total = total + e;
		out.push(e);
	}

	// Normalize in place so the buffer is reused for the result.
	for p in &mut out {
		*p = *p / total;
	}
	out
}

/// Numerically stable log-sum-exp: `ln(sum_i exp(x_i))`.
///
/// Uses the identity
///   `ln(sum exp(x_i)) = max(x) + ln(sum exp(x_i - max(x)))`
/// so every exponential is taken on a shifted (non-positive) value that
/// can't overflow. This is the scalar normalizer behind log-domain
/// attention and some SSM gating; the naive
/// `x.iter().map(exp).sum().ln()` returns `+inf` for large inputs (see
/// the `log_sum_exp_handles_large_values` test). Cost: O(n).
///
/// # Returns
///
/// * `T::neg_infinity()` for an empty slice (equivalent to `ln(0)`).
/// * `T::neg_infinity()` when every element is `-inf`: each term is
///   `exp(-inf) = 0`, so the sum is 0 and its log is `-inf`. This is the
///   "fully masked row" case and agrees with the empty-slice convention.
/// * `+inf` when any element is `+inf` and none is NaN.
/// * NaN when any element is NaN.
/// * Otherwise the finite value of `ln(sum_i exp(x_i))`.
///
/// # Example
///
/// ```
/// use zilla_muf::stable_ops::log_sum_exp;
/// let lse = log_sum_exp(&[1.0_f64, 2.0, 3.0]);
/// let expected = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp()).ln();
/// assert!((lse - expected).abs() < 1e-9);
/// ```
pub fn log_sum_exp<T: Float>(x: &[T]) -> T {
	if x.is_empty() {
		return T::neg_infinity(); // ln(0) by convention
	}
	let max = x.iter().cloned().fold(T::neg_infinity(), T::max);
	if max.is_infinite() {
		// With an infinite max the shift below would evaluate
		// `inf - inf = NaN` and poison the result. The answer is the max
		// itself: `+inf` dominates any sum, and an all-`-inf` slice sums
		// to zero, whose log is `-inf`. The max fold can skip NaN
		// elements, so check for them explicitly to keep NaN propagating.
		return if x.iter().any(|&v| v.is_nan()) {
			T::nan()
		} else {
			max
		};
	}
	let sum = x.iter().map(|&v| (v - max).exp()).fold(T::zero(), |a, b| a + b);
	max + sum.ln() // undo the shift: the + max lives in log space
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Textbook softmax with no max-shift. It overflows for large inputs,
	/// so it is only used on small-magnitude data, where it serves as an
	/// independently written reference for the stable implementation.
	fn naive_softmax(x: &[f64]) -> Vec<f64> {
		let total: f64 = x.iter().map(|v| v.exp()).sum();
		x.iter().map(|v| v.exp() / total).collect()
	}

	/// Textbook log-sum-exp with no max-shift; same caveat as
	/// `naive_softmax`.
	fn naive_log_sum_exp(x: &[f64]) -> f64 {
		let total: f64 = x.iter().map(|v| v.exp()).sum();
		total.ln()
	}

	#[test]
	fn matches_naive_for_small_values() {
		let input = [1.0_f64, 2.0, 3.0];
		let naive = naive_softmax(&input);
		let stable = softmax(&input);
		stable.iter().zip(&naive).for_each(|(s, n)| {
			assert!((s - n).abs() < 1e-12, "stable={s}, naive={n}");
		});
	}

	#[test]
	fn sums_to_one() {
		let total: f64 = softmax(&[-3.0_f64, 0.5, 2.0, 7.0]).iter().sum();
		assert!((total - 1.0).abs() < 1e-9);
	}

	#[test]
	fn handles_large_values_without_overflow() {
		// Unshifted, exp(1000.0) is inf and inf / inf is NaN. The shifted
		// implementation has to return the uniform distribution instead.
		let probs = softmax(&[1000.0_f64; 3]);
		for v in &probs {
			assert!(v.is_finite(), "expected finite, got {v}");
			assert!((v - 1.0 / 3.0).abs() < 1e-9);
		}
	}

	#[test]
	fn log_sum_exp_matches_naive_for_small_values() {
		let input = [1.0_f64, 2.0, 3.0];
		let naive = naive_log_sum_exp(&input);
		let stable = log_sum_exp(&input);
		assert!((stable - naive).abs() < 1e-12, "stable={stable}, naive={naive}");
	}

	#[test]
	fn softmax_empty_returns_empty() {
		assert!(softmax::<f64>(&[]).is_empty());
	}

	#[test]
	fn log_sum_exp_empty_returns_neg_infinity() {
		assert_eq!(log_sum_exp::<f64>(&[]), f64::NEG_INFINITY);
	}

	#[test]
	fn log_sum_exp_handles_large_values() {
		// Unshifted, this would be ln(inf) = inf.
		let got = log_sum_exp(&[1000.0_f64; 3]);
		assert!(got.is_finite());
		let want = 1000.0 + 3.0_f64.ln();
		assert!((got - want).abs() < 1e-9);
	}
}