zilla-muf 0.1.1

Shared structured-matrix and numerical primitives for sparse attention and state space models (SSMs).
Documentation
// src/structured/semiseparable.rs
use num_traits::Float;
use crate::scan::chunked_scan;

/// Semiseparable matrix-vector product, diagonal/selective parameterization
/// (Mamba-2 / SSD style).
///
/// Computes `y = M x` for the implicit lower-triangular matrix
///
/// ```text
/// M[i][j] = C_i · B_j * (a_{j+1} * a_{j+2} * ... * a_i)   for j <= i
/// M[i][j] = 0                                             for j >  i
/// ```
///
/// without materializing `M`. Row by row this is exactly the SSM recurrence
///
/// ```text
/// h_i = a_i * h_{i-1} + B_i * x_i      (zero initial state)
/// y_i = C_i · h_i
/// ```
///
/// — i.e. the matvec IS the SSM recurrence; this function just exposes it
/// as "apply this implicit matrix to a vector" for callers that think in
/// those terms (e.g. an attention-equivalent API).
///
/// # Arguments
///
/// - `a`: length `n` — scalar decay per timestep, shared by every state channel
/// - `b`: length `n * rank`, row-major — the `B_t` vectors (state dim = `rank`)
/// - `c`: length `n * rank`, row-major — the `C_t` vectors
/// - `x`: length `n` — the input sequence
/// - `rank`: state dimension, i.e. the row width of `b` and `c`
/// - `chunk_size`: forwarded unchanged to `chunked_scan`
///
/// # Returns
///
/// `y`, of length `n`. With `rank == 0` every output is zero.
///
/// # Cost
///
/// O(n * rank) time: one length-`n` `chunked_scan` per state channel.
/// Scratch memory is O(n) on top of whatever `chunked_scan` allocates. One input
/// buffer is reused across channels, and each channel's contribution is folded
/// into `y` as soon as it is scanned, so the `n x rank` state trajectory is
/// never stored.
///
/// # Panics
///
/// With `n = a.len()`, panics if `n * rank` overflows `usize`, if `b.len()` or
/// `c.len()` differs from `n * rank`, or if `x.len() != n`.
pub fn semiseparable_matvec<T: Float + Send + Sync>(
	a: &[T],
	b: &[T],
	c: &[T],
	x: &[T],
	rank: usize,
	chunk_size: usize,
) -> Vec<T> {
	let n = a.len();
	// Shape contract: B and C are n x rank row-major; a and x are length n.
	// `checked_mul` so an absurd `rank` cannot wrap in release builds and let a
	// mis-sized buffer slip past the length checks below.
	let n_rank = n.checked_mul(rank).expect("n * rank overflows usize");
	assert_eq!(b.len(), n_rank, "b must be n * rank");
	assert_eq!(c.len(), n_rank, "c must be n * rank");
	assert_eq!(x.len(), n, "x must be length n");

	let mut y = vec![T::zero(); n];
	// Input buffer for one channel's scan; cleared and refilled per channel
	// so the allocation happens once.
	let mut u_r: Vec<T> = Vec::with_capacity(n);

	// The state recurrence is independent across the `rank` channels, so run one
	// scalar scan per channel. Channel r is driven by input (B_t[r] * x_t); the
	// shared decay `a` couples timesteps within a channel. Each channel's term
	// C_t[r] * h_t[r] is added into y immediately. Channels are visited in
	// ascending r, which is the same summation order as a left-to-right dot
	// product over the rank dim.
	for r in 0..rank {
		// Column r of a row-major n x rank buffer: every `rank`-th element from r.
		u_r.clear();
		u_r.extend(
			b.iter()
				.skip(r)
				.step_by(rank)
				.zip(x)
				.map(|(&b_tr, &x_t)| b_tr * x_t),
		);
		let h_r = chunked_scan(a, &u_r, T::zero(), chunk_size); // reuse the scan kernel

		// Output projection for this channel: y_t += C_t[r] * h_t[r].
		// `Float` does not imply `AddAssign`, hence the explicit `*y_t = *y_t + …`.
		let c_r = c.iter().skip(r).step_by(rank);
		for (t, (y_t, &c_tr)) in y.iter_mut().zip(c_r).enumerate() {
			*y_t = *y_t + c_tr * h_r[t];
		}
	}
	y
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Builds the dense n x n semiseparable matrix directly and does a
	/// naive O(n^2 * rank) matvec — a correctness oracle independent of
	/// chunked_scan, so this test actually proves something.
	///
	/// Entry (i, j) for j <= i is C_i · B_j scaled by the decay product
	/// a[j+1] * ... * a[i] (an empty product, 1, on the diagonal). Entries above
	/// the diagonal are zero and are skipped.
	fn dense_reference(a: &[f64], b: &[f64], c: &[f64], x: &[f64], rank: usize) -> Vec<f64> {
		let n = a.len();
		let mut y = vec![0.0; n];
		for i in 0..n {
			for j in 0..=i {
				let mut decay = 1.0;
				for k in (j + 1)..=i {
					decay *= a[k];
				}
				let mut m_ij = 0.0;
				for r in 0..rank {
					m_ij += c[i * rank + r] * decay * b[j * rank + r];
				}
				y[i] += m_ij * x[j];
			}
		}
		y
	}

	/// Asserts `actual` matches `expected` in length and element-wise, to a
	/// tolerance of 1e-9 scaled by max(|expected|, 1). The explicit length check
	/// matters because `zip` alone stops at the shorter side, so a short
	/// (or empty) result would otherwise pass vacuously.
	fn assert_close(expected: &[f64], actual: &[f64], context: &str) {
		assert_eq!(actual.len(), expected.len(), "length mismatch ({context})");
		for (i, (e, got)) in expected.iter().zip(actual).enumerate() {
			let tol = 1e-9 * e.abs().max(1.0);
			assert!((e - got).abs() <= tol, "y[{i}]: expected {e}, got {got} ({context})");
		}
	}

	#[test]
	fn matches_dense_reference() {
		let n = 6;
		let a: Vec<f64> = vec![0.9, 0.8, 0.95, 0.7, 0.85, 0.6];
		let x: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();

		// rank 1 is the scalar-state SSM; rank 3 exercises the row-major striding.
		for rank in [1, 3] {
			let b: Vec<f64> = (0..n * rank).map(|i| (i as f64 * 0.37).sin()).collect();
			let c: Vec<f64> = (0..n * rank).map(|i| (i as f64 * 0.53).cos()).collect();

			let expected = dense_reference(&a, &b, &c, &x, rank);
			// Chunk sizes below, equal to and above n, including one (4) that does not divide n.
			for chunk_size in [1, 2, 3, 4, 6, 10] {
				let actual = semiseparable_matvec(&a, &b, &c, &x, rank, chunk_size);
				assert_close(&expected, &actual, &format!("rank = {rank}, chunk_size = {chunk_size}"));
			}
		}
	}

	#[test]
	fn rank_zero_yields_zeros() {
		let a = [0.9, 0.8, 0.7];
		let x = [1.0, 2.0, 3.0];
		let y = semiseparable_matvec(&a, &[], &[], &x, 0, 2);
		assert_eq!(y, vec![0.0; 3]);
	}

	#[test]
	#[should_panic(expected = "b must be n * rank")]
	fn rejects_misshaped_b() {
		let _ = semiseparable_matvec(&[0.5; 4], &[0.0; 7], &[0.0; 8], &[1.0; 4], 2, 2);
	}

	#[test]
	#[should_panic(expected = "c must be n * rank")]
	fn rejects_misshaped_c() {
		let _ = semiseparable_matvec(&[0.5; 4], &[0.0; 8], &[0.0; 7], &[1.0; 4], 2, 2);
	}

	#[test]
	#[should_panic(expected = "x must be length n")]
	fn rejects_misshaped_x() {
		let _ = semiseparable_matvec(&[0.5; 4], &[0.0; 8], &[0.0; 8], &[1.0; 3], 2, 2);
	}
}