zilla-muf 0.1.1

Shared structured-matrix and numerical primitives for sparse attention and state space models (SSMs).
Documentation
// fft_conv.rs
use rustfft::{num_complex::Complex, FftNum, FftPlanner};

/// FFT-based causal convolution — computes the exact same quantity as
/// `toeplitz_matvec` (the long-convolution view of an SSM kernel applied
/// to a signal), but in O(n log n) instead of O(n^2) via the standard
/// "pad, FFT, multiply, inverse FFT, truncate" approach.
///
/// `signal` and `kernel` must be the same length n. Both are zero-padded
/// to the next power of two at or above `2n - 1`, so the circular
/// convolution computed by the FFT cannot wrap around and corrupt the
/// result. Only the first n outputs (the causal portion) are returned.
///
/// Returns an empty vector when n == 0.
///
/// # Panics
///
/// - Panics if `signal.len() != kernel.len()`.
/// - Panics if the padded FFT length cannot be converted into `T` via
///   `from_usize`. This is not expected for the usual float types.
#[must_use]
pub fn fft_conv<T: FftNum>(signal: &[T], kernel: &[T]) -> Vec<T> {
	let n = signal.len();
	assert_eq!(kernel.len(), n, "signal and kernel must be the same length");
	if n == 0 {
		return Vec::new();
	}

	// Linear convolution of two length-n sequences has 2n-1 outputs.
	// FFT multiplication is *circular*, so pad to at least that length
	// (rounded up to a power of two) so the tail cannot wrap around into
	// the first n samples.
	let conv_len = 2 * n - 1;
	let fft_len = conv_len.next_power_of_two();

	// A fresh planner is built on every call, so no plan is reused across
	// calls to this function. Both plans below come from this one planner.
	let mut planner = FftPlanner::<T>::new();
	let fft = planner.plan_fft_forward(fft_len);
	let ifft = planner.plan_fft_inverse(fft_len);

	// Lift a real sequence into a complex buffer, zero-padded to fft_len.
	let to_padded_complex = |xs: &[T]| -> Vec<Complex<T>> {
		xs.iter()
			.map(|&v| Complex::new(v, T::zero()))
			.chain(std::iter::repeat(Complex::new(T::zero(), T::zero())))
			.take(fft_len)
			.collect()
	};
	let mut signal_buf = to_padded_complex(signal);
	let mut kernel_buf = to_padded_complex(kernel);

	// Convolution theorem: transform both to the frequency domain...
	fft.process(&mut signal_buf);
	fft.process(&mut kernel_buf);

	// ...where convolution becomes pointwise multiplication. (`*=` is not
	// used because Complex<T>: MulAssign needs a NumAssign bound FftNum
	// does not provide.)
	for (s, k) in signal_buf.iter_mut().zip(&kernel_buf) {
		*s = *s * *k;
	}

	// Back to the time domain to recover the convolution.
	ifft.process(&mut signal_buf);

	// rustfft's inverse is unnormalized, so divide by fft_len.
	// Take the real part; the imaginary part is only rounding noise.
	// Keep just the first n samples, the causal portion that lines up
	// with toeplitz_matvec's output.
	let scale = T::from_usize(fft_len).expect("fft_len fits in T");
	signal_buf.iter().take(n).map(|c| c.re / scale).collect()
}

#[cfg(test)]
mod tests {
	use super::*;
	use crate::structured::toeplitz_matvec;

	/// Element-wise comparison with an absolute tolerance. The lengths are
	/// checked first, because `zip` alone silently ignores missing or extra
	/// outputs.
	fn assert_close(expected: &[f64], actual: &[f64], tol: f64) {
		assert_eq!(actual.len(), expected.len(), "output length mismatch");
		for (i, (e, got)) in expected.iter().zip(actual).enumerate() {
			assert!((e - got).abs() < tol, "index {i}: expected {e}, got {got}");
		}
	}

	#[test]
	fn matches_toeplitz_matvec_reference() {
		// Sizes straddle power-of-two boundaries so the padding logic is
		// exercised on both sides. For example, n = 8 pads 15 -> 16, while
		// n = 9 pads 17 -> 32. n = 1 is the degenerate fft_len == 1 case.
		for n in [1usize, 2, 8, 9, 17, 33] {
			let signal: Vec<f64> = (0..n).map(|i| (i as f64 * 0.23).sin()).collect();
			let kernel: Vec<f64> = (0..n)
				.map(|i| (i as f64 * 0.41).cos() * 0.8_f64.powi(i as i32))
				.collect();

			let expected = toeplitz_matvec(&kernel, &signal);
			let actual = fft_conv(&signal, &kernel);

			assert_close(&expected, &actual, 1e-9);
		}
	}

	#[test]
	fn impulse_kernel_passes_signal_through() {
		let n = 5;
		let mut kernel = vec![0.0; n];
		kernel[0] = 1.0;
		let signal: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();
		let result = fft_conv(&signal, &kernel);
		assert_close(&signal, &result, 1e-9);
	}

	#[test]
	fn delayed_impulse_shifts_signal_causally() {
		// A unit impulse at lag 2 should delay the signal by two samples and
		// leave zeros in front. Nothing may wrap around from the tail.
		let n = 6;
		let mut kernel = vec![0.0; n];
		kernel[2] = 1.0;
		let signal: Vec<f64> = (0..n).map(|i| i as f64 + 1.0).collect();
		let result = fft_conv(&signal, &kernel);
		assert_close(&[0.0, 0.0, 1.0, 2.0, 3.0, 4.0], &result, 1e-9);
	}

	#[test]
	fn works_for_f32() {
		// y0 = k0*x0 = 1 and y1 = k0*x1 + k1*x0 = 2 + 1 = 3.
		let result = fft_conv(&[1.0f32, 2.0], &[1.0f32, 1.0]);
		assert_eq!(result.len(), 2);
		assert!((result[0] - 1.0).abs() < 1e-5, "got {}", result[0]);
		assert!((result[1] - 3.0).abs() < 1e-5, "got {}", result[1]);
	}

	#[test]
	#[should_panic(expected = "signal and kernel must be the same length")]
	fn mismatched_lengths_panic() {
		let _ = fft_conv(&[1.0f64, 2.0, 3.0], &[1.0f64, 2.0]);
	}

	#[test]
	fn empty_input() {
		let result: Vec<f64> = fft_conv(&[], &[]);
		assert!(result.is_empty());
	}
}