use num_traits::Float;
use crate::scan::chunked_scan;
/// Computes a semiseparable matrix-vector product without forming the matrix.
///
/// For each rank channel `r` the routine builds the driving sequence
/// `b[t * rank + r] * x[t]`, passes it to `chunked_scan` together with the
/// coefficients `a`, a zero initial value and `chunk_size`, and stores the
/// returned sequence as column `r` of an `n x rank` state buffer. Output
/// element `i` is the dot product of row `i` of `c` with row `i` of that
/// buffer.
///
/// The unit test in this file compares the result with the dense
/// lower-triangular product
/// `y[i] = sum_{j <= i} sum_r c[i, r] * (a[j + 1] * ... * a[i]) * b[j, r] * x[j]`.
/// NOTE(review): that equivalence presumes `chunked_scan` evaluates the
/// recurrence `h[t] = a[t] * h[t - 1] + u[t]` starting from the supplied
/// initial value; its definition lives in `crate::scan` — confirm there.
///
/// # Arguments
///
/// * `a` - per-step coefficients; its length defines `n`.
/// * `b` - `n * rank` values, read as `b[t * rank + r]`.
/// * `c` - `n * rank` values, read as `c[i * rank + r]`.
/// * `x` - input vector of length `n`.
/// * `rank` - number of channels stored per step in `b` and `c`.
/// * `chunk_size` - forwarded unchanged to `chunked_scan`.
///
/// # Returns
///
/// A vector of length `n`.
///
/// # Panics
///
/// Panics if `b.len() != n * rank`, `c.len() != n * rank` or `x.len() != n`,
/// and if `chunked_scan` returns fewer than `n` elements.
pub fn semiseparable_matvec<T: Float + Send + Sync>(
    a: &[T],
    b: &[T],
    c: &[T],
    x: &[T],
    rank: usize,
    chunk_size: usize,
) -> Vec<T> {
    let n = a.len();
    assert_eq!(b.len(), n * rank, "b must be n * rank");
    assert_eq!(c.len(), n * rank, "c must be n * rank");
    assert_eq!(x.len(), n, "x must be length n");

    // State buffer laid out like `b` and `c`: entry (step, channel) lives at
    // `step * rank + channel`.
    let mut states = vec![T::zero(); n * rank];
    for channel in 0..rank {
        // Gather the strided column of `b` for this channel, scaled by the input.
        let mut drive = Vec::with_capacity(n);
        for (step, &x_val) in x.iter().enumerate() {
            drive.push(b[step * rank + channel] * x_val);
        }

        let scanned = chunked_scan(a, &drive, T::zero(), chunk_size);

        // Scatter the scan result back into the strided column of the buffer.
        for step in 0..n {
            states[step * rank + channel] = scanned[step];
        }
    }

    // Contract each row of `c` with the matching row of the state buffer,
    // summing channels in increasing order.
    let mut y = Vec::with_capacity(n);
    for row in 0..n {
        let mut acc = T::zero();
        for channel in 0..rank {
            let idx = row * rank + channel;
            acc = acc + c[idx] * states[idx];
        }
        y.push(acc);
    }
    y
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deliberately naive dense evaluation used as ground truth.
    ///
    /// Computes `y = M x` for the lower-triangular matrix
    /// `M[i][j] = sum_r c[i * rank + r] * (a[j + 1] * ... * a[i]) * b[j * rank + r]`
    /// with `j <= i`; the product of `a` values is empty (equal to 1) on the
    /// diagonal.
    fn dense_reference(a: &[f64], b: &[f64], c: &[f64], x: &[f64], rank: usize) -> Vec<f64> {
        let n = a.len();
        let mut y = vec![0.0; n];
        for i in 0..n {
            for j in 0..=i {
                // Accumulated decay between source step `j` and target step `i`.
                let mut decay = 1.0;
                for k in (j + 1)..=i {
                    decay *= a[k];
                }
                // Entry (i, j) of the dense matrix, summed over rank channels.
                let mut m_ij = 0.0;
                for r in 0..rank {
                    m_ij += c[i * rank + r] * decay * b[j * rank + r];
                }
                y[i] += m_ij * x[j];
            }
        }
        y
    }

    /// The scan-based product must agree with the dense reference for every
    /// chunk size tried.
    #[test]
    fn matches_dense_reference() {
        let n = 6;
        let rank = 3;
        let a: Vec<f64> = vec![0.9, 0.8, 0.95, 0.7, 0.85, 0.6];
        let b: Vec<f64> = (0..n * rank).map(|i| (i as f64 * 0.37).sin()).collect();
        let c: Vec<f64> = (0..n * rank).map(|i| (i as f64 * 0.53).cos()).collect();
        let x: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();
        let expected = dense_reference(&a, &b, &c, &x, rank);
        // Chunk sizes smaller than n, equal to n, and larger than n.
        for chunk_size in [1, 2, 3, 6, 10] {
            let actual = semiseparable_matvec(&a, &b, &c, &x, rank, chunk_size);
            // `zip` stops at the shorter iterator, so without this check a
            // truncated (or empty) result would pass the loop below vacuously.
            assert_eq!(
                actual.len(),
                expected.len(),
                "wrong output length for chunk_size {chunk_size}"
            );
            for (e, got) in expected.iter().zip(actual.iter()) {
                assert!((e - got).abs() < 1e-9, "expected {e}, got {got}");
            }
        }
    }
}