Skip to main content

zilla_muf/scan/
sequential.rs

1use num_traits::Float;
2
3/// Sequential (recurrent) scan — the correctness baseline that every
4/// other scan in the crate is tested against.
5///
6/// Evaluates the first-order linear recurrence
7///   `h_t = a_t * h_{t-1} + b_t`,   with `h_{-1} = h0`
8/// left-to-right, returning the full state sequence `[h_0, ..., h_{n-1}]`.
9///
10/// - `a`: per-step decay / transition coefficients
11/// - `b`: per-step inputs (must be the same length as `a`)
12/// - `h0`: the initial state fed in before the first step
13///
14/// Setting `a_t = 1` everywhere collapses this into a plain prefix sum
15/// (running total) of `b` — the easiest case to eyeball, which is exactly
16/// what the test below checks.
17///
18/// Cost: O(n), one multiply-add per element. Inherently sequential: each
19/// step needs the previous `h`, so it can't be parallelized as-is.
20/// `chunked_scan` exists to break that dependency for large inputs; this
21/// function is the reference it must match. Generic over the float type
22/// `T` (f32 for speed, f64 for numerical testing).
23///
24/// # Example
25///
26/// ```
27/// use zilla_muf::scan::sequential_scan;
28/// // a = 1 everywhere → plain prefix sum of b
29/// let h = sequential_scan(&[1.0, 1.0, 1.0], &[1.0, 2.0, 3.0], 0.0);
30/// assert_eq!(h, vec![1.0, 3.0, 6.0]);
31/// ```
32pub fn sequential_scan<T: Float>(a: &[T], b: &[T], h0: T) -> Vec<T> {
33	// A length mismatch is a caller bug, not a recoverable runtime state —
34	// fail loudly rather than silently truncating to the shorter slice.
35	assert_eq!(a.len(), b.len(), "a and b must be the same length");
36
37	let mut h = h0;                            // running state, seeded with h0
38	let mut out = Vec::with_capacity(a.len()); // exact size is known up front
39	for i in 0..a.len() {
40		h = a[i] * h + b[i]; // one recurrence step: decay the state, add input
41		out.push(h);         // every intermediate state is part of the output
42	}
43	out
44}
45
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn cumulative_sum_case() {
		// a = 1 everywhere -> plain running sum
		let a = [1.0, 1.0, 1.0, 1.0];
		let b = [1.0, 2.0, 3.0, 4.0];
		assert_eq!(sequential_scan(&a, &b, 0.0), vec![1.0, 3.0, 6.0, 10.0]);
	}

	#[test]
	fn decay_and_initial_state() {
		// a != 1 and h0 != 0, so both terms of the recurrence matter.
		// 0.5*4 + 1 = 3, 0.5*3 + 1 = 2.5, 0.5*2.5 + 1 = 2.25 — all exactly
		// representable in binary floating point, so `==` is safe here.
		let a = [0.5, 0.5, 0.5];
		let b = [1.0, 1.0, 1.0];
		assert_eq!(sequential_scan(&a, &b, 4.0), vec![3.0, 2.5, 2.25]);
	}

	#[test]
	fn zero_decay_forgets_history() {
		// With a = 0 each state is just the current input; h0 must not leak.
		let a = [0.0, 0.0];
		let b = [5.0, 7.0];
		assert_eq!(sequential_scan(&a, &b, 100.0), vec![5.0, 7.0]);
	}

	#[test]
	fn works_for_f32() {
		let a = [2.0f32, 2.0];
		let b = [1.0f32, 0.0];
		assert_eq!(sequential_scan(&a, &b, 1.0), vec![3.0f32, 6.0]);
	}

	#[test]
	fn empty_input_returns_empty() {
		// The seed is not emitted, so no steps means no output.
		let out: Vec<f64> = sequential_scan(&[], &[], 3.0);
		assert!(out.is_empty());
	}

	#[test]
	#[should_panic(expected = "a and b must be the same length")]
	fn length_mismatch_panics() {
		let a = [1.0f64, 1.0];
		let b = [1.0f64];
		let _ = sequential_scan(&a, &b, 0.0);
	}
}