//! zilla-muf 0.1.1
//!
//! Shared structured-matrix and numerical primitives for sparse attention and
//! state space models (SSMs).
use num_complex::Complex;
use num_traits::Float;

/// Evaluates the Cauchy kernel used by the original S4 (DPLR)
/// kernel-generation step.
///
/// For every evaluation point `z_k` in `nodes` (typically roots of unity),
/// this sums each pole's residue divided by its offset from that point:
///
///   out[k] = sum_i weights[i] / (nodes[k] - poles[i])
///
/// Equivalently, it applies the implicit Cauchy matrix
/// `M[k][i] = 1 / (nodes[k] - poles[i])` to `weights` without ever
/// materializing `M`. `poles` are the diagonalized state matrix's
/// eigenvalues (lambda_i) and `weights` are the per-pole residues (w_i,
/// e.g. `C_i * B_i`). The result has one entry per node, in node order.
///
/// Cost: O(nodes.len() * poles.len()).
///
/// # Panics
///
/// Panics if `poles` and `weights` have different lengths.
///
/// # Singularities
///
/// If a node coincides with a pole, that term divides by zero and the
/// output for the node is non-finite (NaN/infinite). This is a property of
/// the Cauchy kernel itself, not a bug. Callers sampling at roots of unity
/// should keep poles off the sampled points of the unit circle.
pub fn cauchy_matvec<T: Float>(nodes: &[Complex<T>], poles: &[Complex<T>], weights: &[Complex<T>]) -> Vec<Complex<T>> {
	assert_eq!(poles.len(), weights.len(), "poles and weights must be the same length");

	let mut result = Vec::with_capacity(nodes.len());
	for &node in nodes {
		// Accumulate in pole order starting from zero, so each node's value
		// is the left-to-right sum of its per-pole terms.
		let mut total = Complex::new(T::zero(), T::zero());
		for (&pole, &residue) in poles.iter().zip(weights) {
			total = total + residue / (node - pole);
		}
		result.push(total);
	}
	result
}

/// Evaluates the Vandermonde kernel used by S4D's (diagonal-only) direct
/// kernel-generation step.
///
/// Given the diagonalized state matrix's eigenvalues `poles` (lambda_i) and
/// the per-pole residues `weights` (w_i, e.g. `C_i * B_i`), this produces
/// the SSM's convolution kernel of the requested `length`:
///
///   kernel[t] = sum_i weights[i] * poles[i]^t,   for t in 0..length
///
/// This is the implicit Vandermonde matrix `V[t][i] = poles[i]^t` applied to
/// `weights`. Powers are built up by one complex multiply per pole per step
/// rather than exponentiating from scratch, so the cost is
/// O(length * poles.len()).
///
/// `kernel[0]` is always `sum(weights)`, because every `poles[i]^0 == 1`.
/// An empty `poles` yields `length` zeros; `length == 0` yields an empty
/// vector.
///
/// The result stays complex. Poles usually come in conjugate pairs, which
/// makes the true kernel real. Reducing to real values (e.g. taking `.re`,
/// or pairing conjugates first) is left to the caller.
///
/// # Panics
///
/// Panics if `poles` and `weights` have different lengths.
pub fn vandermonde_matvec<T: Float>(poles: &[Complex<T>], weights: &[Complex<T>], length: usize) -> Vec<Complex<T>> {
	assert_eq!(poles.len(), weights.len(), "poles and weights must be the same length");

	// lambda_i^t for the step currently being emitted; seeded at lambda_i^0 = 1.
	let mut powers = vec![Complex::new(T::one(), T::zero()); poles.len()];

	(0..length)
		.map(|_| {
			// Weighted sum of the current powers, accumulated in pole order.
			let mut term_sum = Complex::new(T::zero(), T::zero());
			for (&power, &w) in powers.iter().zip(weights) {
				term_sum = term_sum + w * power;
			}
			// Step every pole forward one power: lambda_i^t -> lambda_i^{t+1}.
			for (power, &lambda) in powers.iter_mut().zip(poles) {
				*power = *power * lambda;
			}
			term_sum
		})
		.collect()
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Asserts that `actual` has exactly as many entries as `expected` and
	/// that each entry lies within `tol` (complex modulus) of its
	/// counterpart. The length check comes first because `zip` stops at the
	/// shorter input. Without it, a truncated or empty result would pass.
	fn assert_all_close(expected: &[Complex<f64>], actual: &[Complex<f64>], tol: f64) {
		assert_eq!(actual.len(), expected.len(), "output length mismatch");
		for (idx, (e, got)) in expected.iter().zip(actual.iter()).enumerate() {
			assert!((e - got).norm() < tol, "index {idx}: expected {e}, got {got}");
		}
	}

	#[test]
	fn cauchy_matches_dense_reference() {
		// Independent oracle: same formula, but built as an explicit
		// dense matrix first, summed in the opposite loop order.
		fn dense_reference(nodes: &[Complex<f64>], poles: &[Complex<f64>], weights: &[Complex<f64>]) -> Vec<Complex<f64>> {
			let mut out = vec![Complex::new(0.0, 0.0); nodes.len()];
			for i in 0..poles.len() {
				for k in 0..nodes.len() {
					out[k] += weights[i] / (nodes[k] - poles[i]);
				}
			}
			out
		}

		let l = 8;
		let nodes: Vec<Complex<f64>> = (0..l)
			.map(|k| Complex::from_polar(1.0, 2.0 * std::f64::consts::PI * k as f64 / l as f64))
			.collect();
		// Poles sit strictly inside the unit circle, so no node hits a pole.
		let poles: Vec<Complex<f64>> = vec![
			Complex::new(-0.5, 0.3),
			Complex::new(-0.5, -0.3),
			Complex::new(-0.2, 0.7),
			Complex::new(-0.2, -0.7),
		];
		let weights: Vec<Complex<f64>> = vec![
			Complex::new(1.0, 0.2),
			Complex::new(1.0, -0.2),
			Complex::new(0.5, 0.1),
			Complex::new(0.5, -0.1),
		];

		let expected = dense_reference(&nodes, &poles, &weights);
		let actual = cauchy_matvec(&nodes, &poles, &weights);

		assert_all_close(&expected, &actual, 1e-9);
	}

	#[test]
	fn cauchy_single_term_is_exact() {
		// One node, one pole: (2 + i) / (1 - (-1)) = 1 + 0.5i, and every
		// intermediate of that division is exactly representable.
		let nodes = vec![Complex::new(1.0, 0.0)];
		let poles = vec![Complex::new(-1.0, 0.0)];
		let weights = vec![Complex::new(2.0, 1.0)];
		assert_eq!(cauchy_matvec(&nodes, &poles, &weights), vec![Complex::new(1.0, 0.5)]);
	}

	#[test]
	fn cauchy_empty_nodes_gives_empty_output() {
		let nodes: Vec<Complex<f64>> = Vec::new();
		let poles = vec![Complex::new(-0.5, 0.3)];
		let weights = vec![Complex::new(1.0, 0.0)];
		assert!(cauchy_matvec(&nodes, &poles, &weights).is_empty());
	}

	#[test]
	fn cauchy_without_poles_is_zero_per_node() {
		let nodes = vec![Complex::new(1.0, 0.0), Complex::new(0.0, 1.0)];
		let out = cauchy_matvec::<f64>(&nodes, &[], &[]);
		assert_eq!(out, vec![Complex::new(0.0, 0.0); 2]);
	}

	#[test]
	#[should_panic(expected = "poles and weights must be the same length")]
	fn cauchy_rejects_mismatched_lengths() {
		let nodes = vec![Complex::new(1.0, 0.0)];
		let poles = vec![Complex::new(-0.5, 0.0), Complex::new(-0.2, 0.0)];
		let weights = vec![Complex::new(1.0, 0.0)];
		let _ = cauchy_matvec(&nodes, &poles, &weights);
	}

	#[test]
	fn vandermonde_matches_direct_powu() {
		// Independent oracle: Complex::powu (repeated squaring) instead
		// of the implementation's incremental multiplication.
		let poles: Vec<Complex<f64>> = vec![
			Complex::new(0.9, 0.1),
			Complex::new(0.9, -0.1),
			Complex::new(0.7, 0.4),
		];
		let weights: Vec<Complex<f64>> = vec![
			Complex::new(1.0, 0.0),
			Complex::new(1.0, 0.0),
			Complex::new(0.5, -0.2),
		];
		let length = 12;

		let expected: Vec<Complex<f64>> = (0..length as u32)
			.map(|t| {
				poles
					.iter()
					.zip(weights.iter())
					.fold(Complex::new(0.0, 0.0), |acc, (&p, &w)| acc + w * p.powu(t))
			})
			.collect();

		let actual = vandermonde_matvec(&poles, &weights, length);

		assert_all_close(&expected, &actual, 1e-9);
	}

	#[test]
	fn vandermonde_real_pole_is_exact_geometric_series() {
		// A single real pole 0.5 with unit weight gives 0.5^t, exactly
		// representable for these small t.
		let poles = vec![Complex::new(0.5, 0.0)];
		let weights = vec![Complex::new(1.0, 0.0)];
		let expected = vec![
			Complex::new(1.0, 0.0),
			Complex::new(0.5, 0.0),
			Complex::new(0.25, 0.0),
			Complex::new(0.125, 0.0),
		];
		assert_eq!(vandermonde_matvec(&poles, &weights, 4), expected);
	}

	#[test]
	fn vandermonde_at_t_zero_is_sum_of_weights() {
		// poles^0 == 1 for all poles, so kernel[0] = sum(weights).
		let poles: Vec<Complex<f64>> = vec![Complex::new(0.5, 0.5), Complex::new(-0.3, 0.2)];
		let weights: Vec<Complex<f64>> = vec![Complex::new(2.0, 0.0), Complex::new(1.0, 1.0)];
		let result = vandermonde_matvec(&poles, &weights, 3);
		assert_eq!(result.len(), 3);
		let expected_t0 = weights[0] + weights[1];
		assert!((result[0] - expected_t0).norm() < 1e-12);
	}

	#[test]
	fn vandermonde_zero_length_is_empty() {
		let poles = vec![Complex::new(0.5, 0.5)];
		let weights = vec![Complex::new(1.0, 0.0)];
		assert!(vandermonde_matvec(&poles, &weights, 0).is_empty());
	}

	#[test]
	fn vandermonde_without_poles_is_all_zero() {
		let out = vandermonde_matvec::<f64>(&[], &[], 4);
		assert_eq!(out, vec![Complex::new(0.0, 0.0); 4]);
	}

	#[test]
	#[should_panic(expected = "poles and weights must be the same length")]
	fn vandermonde_rejects_mismatched_lengths() {
		let poles = vec![Complex::new(0.5, 0.0)];
		let weights = vec![Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)];
		let _ = vandermonde_matvec(&poles, &weights, 3);
	}
}