zilla-muf 0.1.1

Shared structured-matrix and numerical primitives for sparse attention and state space models (SSMs).
Documentation
use num_traits::Float;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

/// Runs the recurrence `h = a_i * h + b_i` over positions `start..end` as if
/// the state entering the chunk were zero.
///
/// Returns `(cum_a, local_h, total_decay, final_state)`:
/// - `cum_a[k]`: running product of `a[start..=start + k]`, multiplied left to right;
/// - `local_h[k]`: scan output at `start + k` under a zero incoming state;
/// - `total_decay`: product of every `a` in the chunk (`1` for an empty chunk);
/// - `final_state`: last zero-state output (`0` for an empty chunk).
///
/// A caller that later learns the true incoming state `h_in` recovers the real
/// output at `start + k` as `local_h[k] + cum_a[k] * h_in`.
///
/// Expects `start <= end` and `end` within the bounds of both `a` and `b`.
fn compute_chunk<T: Float>(a: &[T], b: &[T], start: usize, end: usize) -> (Vec<T>, Vec<T>, T, T) {
	let width = end - start;
	let mut prefix_decay: Vec<T> = Vec::with_capacity(width);
	let mut zero_state_out: Vec<T> = Vec::with_capacity(width);

	// Accumulator = (decay so far, state so far); seeded with the
	// multiplicative identity and the zero incoming state respectively.
	let (total_decay, final_state) = (start..end).fold((T::one(), T::zero()), |(decay, state), i| {
		let a_i = a[i];
		let decay = decay * a_i;          // extend the cumulative decay product
		let state = a_i * state + b[i];   // one step of the recurrence, locally
		prefix_decay.push(decay);
		zero_state_out.push(state);
		(decay, state)
	});

	// total_decay attenuates an incoming state across the whole chunk;
	// final_state is the chunk's outgoing state given a zero input. Both
	// drive the caller's carry pass.
	(prefix_decay, zero_state_out, total_decay, final_state)
}

/// Chunked scan: splits the recurrence `h_t = a_t * h_{t-1} + b_t` into
/// blocks of `chunk_size`, computes each block locally (assuming a
/// zero incoming state), then does a cheap sequential pass over the
/// O(n / chunk_size) chunk boundaries to propagate the true incoming
/// state and correct each block's output.
///
/// This is the chunking strategy behind Mamba-2's SSD algorithm and
/// chunked linear attention: O(n) work, with only the small per-chunk
/// boundary pass being inherently sequential.
///
/// With the `parallel` feature enabled, the per-chunk local compute and
/// the final per-position correction (both independent across chunks)
/// run via rayon. The carry-propagation pass in between stays
/// sequential regardless — that dependency is structural, not a
/// missed optimization.
///
/// # Numerics
///
/// Chunking reassociates the floating-point operations relative to a
/// single left-to-right fold, so results agree with `sequential_scan`
/// to within rounding error rather than bit-for-bit.
///
/// An incoming state of exactly zero is never multiplied into a chunk.
/// So when a chunk's cumulative decay overflows to infinity (possible
/// when `|a_t| > 1` over a long chunk), the product `inf * 0` cannot
/// inject NaN into the output.
///
/// A *non-zero* incoming state combined with such an overflowing decay
/// can still overflow where the sequential scan would not. Prefer a
/// smaller `chunk_size` if decays can exceed 1 in magnitude.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths, or if `chunk_size` is 0.
///
/// # Example
///
/// ```
/// use zilla_muf::scan::{chunked_scan, sequential_scan};
/// let a = vec![0.5_f64; 8];
/// let b = vec![1.0_f64; 8];
/// let expected = sequential_scan(&a, &b, 0.0);
/// let chunked = chunked_scan(&a, &b, 0.0, 3);
/// for (e, c) in expected.iter().zip(chunked.iter()) {
///     assert!((e - c).abs() < 1e-12);
/// }
///
/// // A zero incoming state contributes exactly nothing, even when the
/// // chunk's cumulative decay (1e300 * 1e300) overflows to infinity.
/// let a = [1e300_f64, 1e300];
/// let b = [0.0_f64, 0.0];
/// assert_eq!(chunked_scan(&a, &b, 0.0, 2), vec![0.0, 0.0]);
/// ```
pub fn chunked_scan<T: Float + Send + Sync>(a: &[T], b: &[T], h0: T, chunk_size: usize) -> Vec<T> {
	assert_eq!(a.len(), b.len(), "a and b must be the same length");
	assert!(chunk_size > 0, "chunk_size must be positive");

	let n = a.len();
	if n == 0 {
		return Vec::new();
	}

	let num_chunks = n.div_ceil(chunk_size);

	// Maps a chunk index k to its [start, end) range, clamping the last
	// chunk to n so an uneven final block is handled correctly.
	let chunk_bounds = |k: usize| (k * chunk_size, ((k + 1) * chunk_size).min(n));

	// Local (zero-incoming-state) scan of chunk k; shared by both the
	// parallel and serial paths below so the math is written once.
	let local_scan = |k: usize| {
		let (start, end) = chunk_bounds(k);
		compute_chunk(a, b, start, end)
	};

	// Phase 1 — independent local scans, one per chunk. Embarrassingly
	// parallel: no chunk depends on another here, so rayon fans them out.
	#[cfg(feature = "parallel")]
	let chunk_results: Vec<(Vec<T>, Vec<T>, T, T)> = (0..num_chunks).into_par_iter().map(local_scan).collect();

	// Serial fallback when the `parallel` feature is off — identical math.
	#[cfg(not(feature = "parallel"))]
	let chunk_results: Vec<(Vec<T>, Vec<T>, T, T)> = (0..num_chunks).map(local_scan).collect();

	// Flatten per-chunk results back into flat per-position arrays.
	// This pass is O(n) and sequential, but it's pure data movement
	// (no math), so it's cheap relative to the compute it follows.
	let mut cum_a_local: Vec<T> = Vec::with_capacity(n);    // per-position cumulative decay
	let mut local_h: Vec<T> = Vec::with_capacity(n);        // per-position zero-state output
	let mut chunk_decay: Vec<T> = Vec::with_capacity(num_chunks); // one total decay per chunk
	let mut chunk_state: Vec<T> = Vec::with_capacity(num_chunks); // one output state per chunk
	for (cum_a_chunk, h_chunk, decay, state) in chunk_results {
		cum_a_local.extend(cum_a_chunk);
		local_h.extend(h_chunk);
		chunk_decay.push(decay);
		chunk_state.push(state);
	}

	let zero = T::zero();

	// The only sequential step: O(num_chunks), propagate true state
	// entering each chunk. Each chunk's incoming state depends on the
	// previous chunk's outgoing state, so this cannot be parallelized
	// without a different (hierarchical) algorithm.
	let mut state_in: Vec<T> = Vec::with_capacity(num_chunks);
	let mut carry = h0;
	for (&decay, &state) in chunk_decay.iter().zip(&chunk_state) {
		state_in.push(carry); // state entering this chunk
		// Same recurrence at chunk granularity. A zero carry contributes
		// nothing, so skip the multiply: an overflowed (infinite) chunk
		// decay would otherwise turn inf * 0 into NaN.
		carry = if carry == zero { state } else { decay * carry + state };
	}

	// Correct each position: local value + decayed incoming state.
	// Independent per position, so this parallelizes at the finest
	// granularity rayon can exploit. As in the carry pass, a zero
	// incoming state is skipped rather than multiplied in.
	let correct = |i: usize| {
		let s = state_in[i / chunk_size]; // state entering the chunk that holds position i
		if s == zero {
			local_h[i]
		} else {
			local_h[i] + cum_a_local[i] * s // stitch the true carry back in
		}
	};

	#[cfg(feature = "parallel")]
	let out: Vec<T> = (0..n).into_par_iter().map(correct).collect();

	#[cfg(not(feature = "parallel"))]
	let out: Vec<T> = (0..n).map(correct).collect();

	out
}

#[cfg(all(test, feature = "parallel"))]
mod parallel_tests {
	use super::*;
	use crate::scan::sequential_scan;

	/// Only compiled/run under `cargo test --features parallel`. The
	/// default `cargo test` never exercises the rayon code path at all,
	/// so this is the only place that actually proves it's correct.
	///
	/// Uses an epsilon, not assert_eq!: chunking reassociates the
	/// floating-point operations relative to a single linear fold, so
	/// results match sequential_scan to within float rounding error,
	/// not bit-for-bit. Any scan test that compares with exact equality
	/// only passes when its inputs happen to make every operation exact
	/// (e.g. small integers or powers of two). Keep that in mind before
	/// adding `assert_eq!` scan tests with arbitrary float inputs.
	///
	/// Both a zero and a non-zero `h0` are checked so the carry
	/// propagation from the initial state is exercised. The length is
	/// asserted explicitly because `zip` would silently ignore a
	/// truncated result.
	#[test]
	fn parallel_path_matches_sequential() {
		let n = 500;
		let a: Vec<f64> = (0..n).map(|i| 0.5 + 0.4 * (i as f64 * 0.013).sin()).collect();
		let b: Vec<f64> = (0..n).map(|i| (i as f64 * 0.029).cos()).collect();
		for h0 in [0.0, 1.7] {
			let expected = sequential_scan(&a, &b, h0);
			for chunk_size in [1, 7, 32, 128, 500, 999] {
				let actual = chunked_scan(&a, &b, h0, chunk_size);
				assert_eq!(actual.len(), expected.len(), "h0={h0}, chunk_size={chunk_size}: length mismatch");
				for (e, got) in expected.iter().zip(actual.iter()) {
					assert!((e - got).abs() < 1e-9, "h0={h0}, chunk_size={chunk_size}: expected {e}, got {got}");
				}
			}
		}
	}
}

#[cfg(test)]
mod tests {
	use super::*;
	use crate::scan::sequential_scan;

	/// Asserts `actual` matches `expected` element-wise within a small
	/// relative tolerance, and that both have the same length.
	///
	/// Chunking reassociates floating-point operations relative to
	/// `sequential_scan`, so bit-identical results are only guaranteed
	/// when every operation is exact. A tolerance keeps these tests valid
	/// for arbitrary inputs. The explicit length check matters because
	/// `zip` would otherwise silently ignore a truncated result.
	fn assert_scan_close(actual: &[f64], expected: &[f64], context: &str) {
		assert_eq!(actual.len(), expected.len(), "{context}: length mismatch");
		for (i, (&got, &want)) in actual.iter().zip(expected).enumerate() {
			let tol = 1e-12 * (1.0 + want.abs());
			assert!((got - want).abs() <= tol, "{context}: index {i}: expected {want}, got {got}");
		}
	}

	#[test]
	fn matches_sequential_cumsum() {
		// With a == 1 every intermediate is a small integer, so all float
		// operations are exact and bit-for-bit equality is a valid check.
		let a = [1.0, 1.0, 1.0, 1.0];
		let b = [1.0, 2.0, 3.0, 4.0];
		let expected = sequential_scan(&a, &b, 0.0);
		for chunk_size in 1..=5 {
			assert_eq!(chunked_scan(&a, &b, 0.0, chunk_size), expected);
		}
	}

	#[test]
	fn matches_sequential_with_decay() {
		let a = [0.5, 0.5, 0.5, 0.5, 0.5];
		let b = [1.0, 1.0, 1.0, 1.0, 1.0];
		let h0 = 0.3;
		let expected = sequential_scan(&a, &b, h0);
		for chunk_size in 1..=6 {
			assert_scan_close(&chunked_scan(&a, &b, h0, chunk_size), &expected, &format!("chunk_size={chunk_size}"));
		}
	}

	#[test]
	fn handles_uneven_chunks() {
		// length 7, chunk_size 3 -> chunks of 3, 3, 1
		let a = [0.9; 7];
		let b: [f64; 7] = [1.0, 0.5, 2.0, 0.0, 1.5, 3.0, 0.2];
		let expected = sequential_scan(&a, &b, 1.0);
		assert_scan_close(&chunked_scan(&a, &b, 1.0, 3), &expected, "chunk_size=3");
	}

	#[test]
	fn empty_input() {
		let a: [f64; 0] = [];
		let b: [f64; 0] = [];
		assert_eq!(chunked_scan(&a, &b, 1.0, 4), Vec::<f64>::new());
	}

	#[test]
	#[should_panic(expected = "a and b must be the same length")]
	fn rejects_mismatched_lengths() {
		let _ = chunked_scan(&[1.0_f64, 1.0], &[1.0_f64], 0.0, 2);
	}

	#[test]
	#[should_panic(expected = "chunk_size must be positive")]
	fn rejects_zero_chunk_size() {
		let _ = chunked_scan(&[1.0_f64], &[1.0_f64], 0.0, 0);
	}
}