// zilla-muf 0.1.1
//
// Shared structured-matrix and numerical primitives for sparse attention and
// state space models (SSMs).
use num_traits::Float;

/// Sequential (recurrent) scan — the correctness baseline that every
/// other scan in the crate is tested against.
///
/// Evaluates the first-order linear recurrence
///
/// ```text
/// h_t = a_t * h_{t-1} + b_t,    t = 0, 1, ..., n-1,    with h_{-1} = h0
/// ```
///
/// left-to-right, where `n = a.len()`. It returns every intermediate state
/// `[h_0, h_1, ..., h_{n-1}]`.
///
/// Note the naming: the `h0` *parameter* is the state *before* the first
/// step (`h_{-1}` in the formula), not the first output. The first returned
/// element is `a[0] * h0 + b[0]`, and `h0` itself never appears in the
/// output.
///
/// # Arguments
///
/// - `a`: per-step decay / transition coefficients
/// - `b`: per-step inputs (must be the same length as `a`)
/// - `h0`: the initial state fed in before the first step
///
/// # Returns
///
/// A vector of length `a.len()` holding the state after each step. Empty
/// inputs produce an empty vector.
///
/// Setting `a_t = 1` everywhere collapses this into a plain prefix sum
/// (running total) of `b`, the easiest case to check by eye.
///
/// # Cost
///
/// O(n), one multiply-add per element. The scan is inherently sequential:
/// each step needs the previous `h`, so it cannot be parallelized as-is.
/// `chunked_scan` exists to break that dependency for large inputs; this
/// function is the reference it must match. Generic over the float type
/// `T` (e.g. `f32` for speed, `f64` for numerical testing).
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths. A mismatch is a caller
/// bug, so the function fails loudly instead of silently truncating to the
/// shorter slice.
///
/// # Examples
///
/// ```
/// use zilla_muf::scan::sequential_scan;
/// // a = 1 everywhere → plain prefix sum of b
/// let h = sequential_scan(&[1.0, 1.0, 1.0], &[1.0, 2.0, 3.0], 0.0);
/// assert_eq!(h, vec![1.0, 3.0, 6.0]);
///
/// // General case: h0 = 1.0 seeds the state but is not itself emitted.
/// // 2*1 + 1 = 3, then 0*3 + 5 = 5, then 3*5 - 1 = 14
/// let h = sequential_scan(&[2.0, 0.0, 3.0], &[1.0, 5.0, -1.0], 1.0);
/// assert_eq!(h, vec![3.0, 5.0, 14.0]);
/// ```
pub fn sequential_scan<T: Float>(a: &[T], b: &[T], h0: T) -> Vec<T> {
	// A length mismatch is a caller bug, not a recoverable runtime state,
	// so check once up front instead of letting `zip` truncate silently.
	assert_eq!(a.len(), b.len(), "a and b must be the same length");

	// Zipping the two slices avoids per-index bounds checks. `Zip` + `Map`
	// report an exact length, so `collect` sizes the output once up front.
	let mut h = h0; // running state h_{t-1}, seeded with h_{-1} = h0
	a.iter()
		.zip(b)
		.map(|(&a_t, &b_t)| {
			h = a_t * h + b_t; // decay the previous state, then add this step's input
			h // every intermediate state is part of the output
		})
		.collect()
}

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn cumulative_sum_case() {
		// a = 1 everywhere -> plain running sum
		let a = [1.0, 1.0, 1.0, 1.0];
		let b = [1.0, 2.0, 3.0, 4.0];
		assert_eq!(sequential_scan(&a, &b, 0.0), vec![1.0, 3.0, 6.0, 10.0]);
	}

	#[test]
	fn decay_and_initial_state() {
		// Exercises a_t != 1 and h0 != 0, which the prefix-sum case cannot:
		// h_0 = 2*1 + 1 = 3, h_1 = 0*3 + 5 = 5, h_2 = 3*5 - 1 = 14.
		// Every value is exactly representable, so exact `==` is safe here.
		let a = [2.0, 0.0, 3.0];
		let b = [1.0, 5.0, -1.0];
		assert_eq!(sequential_scan(&a, &b, 1.0), vec![3.0, 5.0, 14.0]);
	}

	#[test]
	fn empty_input_yields_empty_output() {
		// h0 only seeds the recurrence; it is never emitted on its own.
		assert!(sequential_scan::<f64>(&[], &[], 7.0).is_empty());
	}

	#[test]
	fn works_for_f32() {
		let h = sequential_scan(&[1.0f32, 1.0], &[0.5, 0.25], 0.0);
		assert_eq!(h, vec![0.5f32, 0.75]);
	}

	#[test]
	#[should_panic(expected = "a and b must be the same length")]
	fn length_mismatch_panics() {
		let _ = sequential_scan(&[1.0, 2.0], &[1.0], 0.0);
	}
}