#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
/// Selects which part of a linear convolution (or correlation) is returned.
///
/// With input lengths `a.len()` and `b.len()`, the variants produce:
///
/// * [`ConvMode::Full`]: the complete result, `a.len() + b.len() - 1` samples.
/// * [`ConvMode::Same`]: `max(a.len(), b.len())` samples taken from the middle of
///   the full result.
/// * [`ConvMode::Valid`]: only the samples where the shorter input fully overlaps
///   the longer one, `max - min + 1` samples.
///
/// Marked `#[non_exhaustive]` so further modes can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConvMode {
/// The complete linear convolution, `a.len() + b.len() - 1` samples.
Full,
/// A window of `max(a.len(), b.len())` samples from the middle of the full result.
Same,
/// Only fully-overlapping samples: `max(a.len(), b.len()) - min(a.len(), b.len()) + 1`.
Valid,
}
pub fn convolve<T: Float>(a: &[T], b: &[T]) -> Vec<T> {
convolve_mode(a, b, ConvMode::Full)
}
/// Linear convolution of two real sequences, trimmed according to `mode`.
///
/// An alias of [`convolve_with_mode`]; see that function for the algorithm.
pub fn convolve_mode<T: Float>(a: &[T], b: &[T], mode: ConvMode) -> Vec<T> {
    convolve_with_mode(a, b, mode)
}
pub fn convolve_with_mode<T: Float>(a: &[T], b: &[T], mode: ConvMode) -> Vec<T> {
if a.is_empty() || b.is_empty() {
return Vec::new();
}
if a.len() < 32 && b.len() < 32 {
return convolve_direct(a, b, mode);
}
let full_len = a.len() + b.len() - 1;
let fft_len = full_len.next_power_of_two();
let mut a_complex = vec![Complex::<T>::zero(); fft_len];
let mut b_complex = vec![Complex::<T>::zero(); fft_len];
for (i, &val) in a.iter().enumerate() {
a_complex[i] = Complex::new(val, T::ZERO);
}
for (i, &val) in b.iter().enumerate() {
b_complex[i] = Complex::new(val, T::ZERO);
}
let Some(fft_plan) = Plan::dft_1d(fft_len, Direction::Forward, Flags::ESTIMATE) else {
return convolve_direct(a, b, mode);
};
let Some(ifft_plan) = Plan::dft_1d(fft_len, Direction::Backward, Flags::ESTIMATE) else {
return convolve_direct(a, b, mode);
};
let mut a_fft = vec![Complex::<T>::zero(); fft_len];
let mut b_fft = vec![Complex::<T>::zero(); fft_len];
fft_plan.execute(&a_complex, &mut a_fft);
fft_plan.execute(&b_complex, &mut b_fft);
let mut product = vec![Complex::<T>::zero(); fft_len];
for i in 0..fft_len {
product[i] = a_fft[i] * b_fft[i];
}
let mut result_complex = vec![Complex::<T>::zero(); fft_len];
ifft_plan.execute(&product, &mut result_complex);
let scale = T::ONE / T::from_usize(fft_len);
let full_result: Vec<T> = result_complex
.iter()
.take(full_len)
.map(|c| c.re * scale)
.collect();
extract_mode(&full_result, a.len(), b.len(), mode)
}
/// Full linear convolution of two complex sequences.
///
/// Shorthand for [`convolve_complex_mode`] with [`ConvMode::Full`].
pub fn convolve_complex<T: Float>(a: &[Complex<T>], b: &[Complex<T>]) -> Vec<Complex<T>> {
    let full = ConvMode::Full;
    convolve_complex_mode(a, b, full)
}
/// Linear convolution of two complex sequences, trimmed according to `mode`.
///
/// Mirrors the real-valued [`convolve_with_mode`]:
///
/// * When both inputs are shorter than 32 samples, the product is summed directly.
///   This is both faster and free of FFT rounding error.
/// * Otherwise the inputs are zero-padded to the next power of two that is at least
///   `a.len() + b.len() - 1`. They are multiplied in the frequency domain, and the
///   inverse transform is scaled by `1 / fft_len`.
/// * If either FFT plan cannot be created, the direct algorithm is used.
///
/// Returns an empty vector when either input is empty.
pub fn convolve_complex_mode<T: Float>(
    a: &[Complex<T>],
    b: &[Complex<T>],
    mode: ConvMode,
) -> Vec<Complex<T>> {
    // Same cutover as the real-valued path, so small inputs behave consistently.
    const DIRECT_THRESHOLD: usize = 32;

    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    if a.len() < DIRECT_THRESHOLD && b.len() < DIRECT_THRESHOLD {
        return convolve_complex_direct(a, b, mode);
    }

    let full_len = a.len() + b.len() - 1;
    let fft_len = full_len.next_power_of_two();

    let Some(fft_plan) = Plan::dft_1d(fft_len, Direction::Forward, Flags::ESTIMATE) else {
        return convolve_complex_direct(a, b, mode);
    };
    let Some(ifft_plan) = Plan::dft_1d(fft_len, Direction::Backward, Flags::ESTIMATE) else {
        return convolve_complex_direct(a, b, mode);
    };

    let mut a_padded = vec![Complex::<T>::zero(); fft_len];
    let mut b_padded = vec![Complex::<T>::zero(); fft_len];
    a_padded[..a.len()].copy_from_slice(a);
    b_padded[..b.len()].copy_from_slice(b);

    let mut a_fft = vec![Complex::<T>::zero(); fft_len];
    let mut b_fft = vec![Complex::<T>::zero(); fft_len];
    fft_plan.execute(&a_padded, &mut a_fft);
    fft_plan.execute(&b_padded, &mut b_fft);

    // Pointwise spectrum product, accumulated in place into `a_fft`.
    for (x, &y) in a_fft.iter_mut().zip(b_fft.iter()) {
        *x = *x * y;
    }

    let mut result = vec![Complex::<T>::zero(); fft_len];
    ifft_plan.execute(&a_fft, &mut result);

    // Drop the padding tail before scaling; only the first `full_len` samples are used.
    result.truncate(full_len);
    let scale = T::ONE / T::from_usize(fft_len);
    for c in &mut result {
        *c = Complex::new(c.re * scale, c.im * scale);
    }
    extract_mode_complex(&result, a.len(), b.len(), mode)
}
/// Circular (cyclic) convolution of two real sequences.
///
/// Both inputs are zero-padded to `n = max(a.len(), b.len())` and convolved with
/// period `n`, so the result has `n` samples. The result is empty when both
/// inputs are empty.
///
/// The product is formed in the frequency domain with length-`n` FFT plans, and
/// the inverse is scaled by `1 / n`. If either plan cannot be created, a direct
/// O(n^2) sum is used instead.
pub fn convolve_circular<T: Float>(a: &[T], b: &[T]) -> Vec<T> {
    let n = a.len().max(b.len());
    if n == 0 {
        // Nothing to convolve; never ask the planner for a zero-length transform
        // or divide by a zero length when scaling.
        return Vec::new();
    }

    let mut a_padded = a.to_vec();
    a_padded.resize(n, T::ZERO);
    let mut b_padded = b.to_vec();
    b_padded.resize(n, T::ZERO);

    let Some(fft_plan) = Plan::dft_1d(n, Direction::Forward, Flags::ESTIMATE) else {
        return convolve_circular_direct(&a_padded, &b_padded);
    };
    let Some(ifft_plan) = Plan::dft_1d(n, Direction::Backward, Flags::ESTIMATE) else {
        return convolve_circular_direct(&a_padded, &b_padded);
    };

    let a_complex: Vec<Complex<T>> = a_padded.iter().map(|&x| Complex::new(x, T::ZERO)).collect();
    let b_complex: Vec<Complex<T>> = b_padded.iter().map(|&x| Complex::new(x, T::ZERO)).collect();

    let mut a_fft = vec![Complex::<T>::zero(); n];
    let mut b_fft = vec![Complex::<T>::zero(); n];
    fft_plan.execute(&a_complex, &mut a_fft);
    fft_plan.execute(&b_complex, &mut b_fft);

    // Pointwise spectrum product, accumulated in place into `a_fft`.
    for (x, &y) in a_fft.iter_mut().zip(b_fft.iter()) {
        *x = *x * y;
    }

    let mut result = vec![Complex::<T>::zero(); n];
    ifft_plan.execute(&a_fft, &mut result);
    let scale = T::ONE / T::from_usize(n);
    result.iter().map(|c| c.re * scale).collect()
}
/// Full cross-correlation of two real sequences.
///
/// Shorthand for [`correlate_mode`] with [`ConvMode::Full`].
pub fn correlate<T: Float>(a: &[T], b: &[T]) -> Vec<T> {
    let full = ConvMode::Full;
    correlate_mode(a, b, full)
}
/// Cross-correlation of two real sequences, trimmed according to `mode`.
///
/// Computed as the convolution of `a` with `b` reversed. Returns an empty vector
/// when `b` is empty.
pub fn correlate_mode<T: Float>(a: &[T], b: &[T], mode: ConvMode) -> Vec<T> {
    if b.is_empty() {
        return Vec::new();
    }
    let mut flipped = b.to_vec();
    flipped.reverse();
    convolve_with_mode(a, &flipped, mode)
}
/// Full cross-correlation of two complex sequences.
///
/// Shorthand for [`correlate_complex_mode`] with [`ConvMode::Full`].
pub fn correlate_complex<T: Float>(a: &[Complex<T>], b: &[Complex<T>]) -> Vec<Complex<T>> {
    let full = ConvMode::Full;
    correlate_complex_mode(a, b, full)
}
/// Cross-correlation of two complex sequences, trimmed according to `mode`.
///
/// Computed as the convolution of `a` with the complex conjugate of `b` in
/// reverse order. Returns an empty vector when `b` is empty.
pub fn correlate_complex_mode<T: Float>(
    a: &[Complex<T>],
    b: &[Complex<T>],
    mode: ConvMode,
) -> Vec<Complex<T>> {
    if b.is_empty() {
        return Vec::new();
    }
    let mut kernel: Vec<Complex<T>> = b.iter().map(|c| c.conj()).collect();
    kernel.reverse();
    convolve_complex_mode(a, &kernel, mode)
}
pub fn polynomial_multiply<T: Float>(a: &[T], b: &[T]) -> Vec<T> {
convolve(a, b)
}
pub fn polynomial_multiply_complex<T: Float>(
a: &[Complex<T>],
b: &[Complex<T>],
) -> Vec<Complex<T>> {
convolve_complex(a, b)
}
/// Raises the polynomial with coefficient vector `p` to the non-negative power `n`.
///
/// Uses binary exponentiation (repeated squaring), so it performs
/// O(log n) polynomial multiplications.
///
/// Edge cases:
/// * `n == 0` returns `[1]` (the constant polynomial), even for an empty `p`.
/// * An empty `p` with `n >= 1` returns an empty vector.
pub fn polynomial_power<T: Float>(p: &[T], n: u32) -> Vec<T> {
    if n == 0 {
        return vec![T::ONE];
    }
    if n == 1 || p.is_empty() {
        return p.to_vec();
    }

    // `None` stands for the multiplicative identity. This avoids an explicit
    // `[1] * base` product, which would go through the FFT (and pick up rounding
    // error) once `base` has 32 or more coefficients.
    let mut result: Option<Vec<T>> = None;
    let mut base = p.to_vec();
    let mut exp = n;
    loop {
        if exp & 1 == 1 {
            result = Some(match result {
                Some(acc) => polynomial_multiply(&acc, &base),
                None => base.clone(),
            });
        }
        exp >>= 1;
        if exp == 0 {
            // Stop before squaring again: that product would have twice the final
            // degree and would never be used.
            break;
        }
        base = polynomial_multiply(&base, &base);
    }
    // `n >= 2` has at least one set bit, so `result` is always `Some` here.
    result.unwrap_or_else(|| vec![T::ONE])
}
/// Schoolbook O(a.len() * b.len()) linear convolution, trimmed to `mode`.
///
/// Intended for non-empty inputs; the public entry points check for emptiness
/// before calling this.
fn convolve_direct<T: Float>(a: &[T], b: &[T], mode: ConvMode) -> Vec<T> {
    let mut acc = vec![T::ZERO; a.len() + b.len() - 1];
    for (offset, &x) in a.iter().enumerate() {
        // Each sample of `a` adds a scaled copy of `b`, shifted right by `offset`.
        for (slot, &y) in acc[offset..].iter_mut().zip(b.iter()) {
            *slot = *slot + x * y;
        }
    }
    extract_mode(&acc, a.len(), b.len(), mode)
}
/// Schoolbook O(a.len() * b.len()) complex linear convolution, trimmed to `mode`.
///
/// Intended for non-empty inputs; the public entry points check for emptiness
/// before calling this.
fn convolve_complex_direct<T: Float>(
    a: &[Complex<T>],
    b: &[Complex<T>],
    mode: ConvMode,
) -> Vec<Complex<T>> {
    let mut acc = vec![Complex::<T>::zero(); a.len() + b.len() - 1];
    for (offset, &x) in a.iter().enumerate() {
        // Each sample of `a` adds a scaled copy of `b`, shifted right by `offset`.
        for (slot, &y) in acc[offset..].iter_mut().zip(b.iter()) {
            *slot = *slot + x * y;
        }
    }
    extract_mode_complex(&acc, a.len(), b.len(), mode)
}
/// Direct O(n^2) circular convolution with period `n = a.len()`.
///
/// `b` is indexed modulo `n`, so it must have at least `a.len()` elements. The
/// caller passes two buffers padded to the same length.
fn convolve_circular_direct<T: Float>(a: &[T], b: &[T]) -> Vec<T> {
    let period = a.len();
    (0..period)
        .map(|i| {
            a.iter().enumerate().fold(T::ZERO, |sum, (j, &aj)| {
                // (i - j) mod period; adding `period` first keeps the index non-negative.
                sum + aj * b[(period + i - j) % period]
            })
        })
        .collect()
}
/// Slices the part of a full linear-convolution result selected by `mode`.
///
/// `full` must hold the complete result for inputs of lengths `a_len` and
/// `b_len`, i.e. `full.len() == a_len + b_len - 1`. The windows follow NumPy's
/// `convolve` convention and are symmetric in the two input lengths:
///
/// * `Full`: all of `full`.
/// * `Same`: `max(a_len, b_len)` samples starting at `(min(a_len, b_len) - 1) / 2`.
/// * `Valid`: `max - min + 1` samples starting at `min - 1` (full overlap only).
///
/// Returns an empty vector when either length is zero.
fn extract_mode<T: Clone>(full: &[T], a_len: usize, b_len: usize, mode: ConvMode) -> Vec<T> {
    if a_len == 0 || b_len == 0 {
        return Vec::new();
    }
    let shorter = a_len.min(b_len);
    let longer = a_len.max(b_len);
    match mode {
        ConvMode::Full => full.to_vec(),
        ConvMode::Same => {
            // Centre on the shorter length, not `b_len`. Otherwise a kernel much
            // longer than the signal pushes the window past the end of `full`.
            let start = (shorter - 1) / 2;
            full[start..start + longer].to_vec()
        }
        ConvMode::Valid => {
            let start = shorter - 1;
            full[start..start + (longer - shorter + 1)].to_vec()
        }
    }
}
/// Slices the part of a full complex linear-convolution result selected by `mode`.
///
/// `full` must hold the complete result for inputs of lengths `a_len` and
/// `b_len`, i.e. `full.len() == a_len + b_len - 1`. The windows follow NumPy's
/// `convolve` convention and are symmetric in the two input lengths:
///
/// * `Full`: all of `full`.
/// * `Same`: `max(a_len, b_len)` samples starting at `(min(a_len, b_len) - 1) / 2`.
/// * `Valid`: `max - min + 1` samples starting at `min - 1` (full overlap only).
///
/// Returns an empty vector when either length is zero.
fn extract_mode_complex<T: Float>(
    full: &[Complex<T>],
    a_len: usize,
    b_len: usize,
    mode: ConvMode,
) -> Vec<Complex<T>> {
    if a_len == 0 || b_len == 0 {
        return Vec::new();
    }
    let shorter = a_len.min(b_len);
    let longer = a_len.max(b_len);
    match mode {
        ConvMode::Full => full.to_vec(),
        ConvMode::Same => {
            // Centre on the shorter length, not `b_len`. Otherwise a kernel much
            // longer than the signal pushes the window past the end of `full`.
            let start = (shorter - 1) / 2;
            full[start..start + longer].to_vec()
        }
        ConvMode::Valid => {
            let start = shorter - 1;
            full[start..start + (longer - shorter + 1)].to_vec()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Asserts equal lengths and element-wise agreement within `tol`.
    ///
    /// Comparing through `zip` alone would silently ignore a length mismatch.
    fn assert_all_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (x, y)) in actual.iter().zip(expected).enumerate() {
            assert!(approx_eq(*x, *y, tol), "index {i}: got {x}, expected {y}");
        }
    }

    #[test]
    fn test_convolve_simple() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![0.0, 1.0, 0.5];
        assert_all_close(&convolve(&a, &b), &[0.0, 1.0, 2.5, 4.0, 1.5], 1e-10);
    }

    #[test]
    fn test_convolve_fft_path_matches_naive() {
        // An input length of 32 or more forces the FFT path in `convolve_with_mode`.
        let a: Vec<f64> = (0..40).map(|i| f64::from(i % 7) - 3.0).collect();
        let b: Vec<f64> = (0..33).map(|i| f64::from(i % 5) * 0.5).collect();
        let mut expected = vec![0.0; a.len() + b.len() - 1];
        for (i, &ai) in a.iter().enumerate() {
            for (j, &bj) in b.iter().enumerate() {
                expected[i + j] += ai * bj;
            }
        }
        assert_all_close(&convolve(&a, &b), &expected, 1e-9);
    }

    #[test]
    fn test_polynomial_multiply() {
        let p1 = vec![1.0, 2.0];
        let p2 = vec![3.0, 4.0];
        assert_all_close(&polynomial_multiply(&p1, &p2), &[3.0, 10.0, 8.0], 1e-10);
    }

    #[test]
    fn test_polynomial_power() {
        let p = vec![1.0, 1.0];
        assert_all_close(&polynomial_power(&p, 2), &[1.0, 2.0, 1.0], 1e-10);
    }

    #[test]
    fn test_polynomial_power_cubic() {
        let p = vec![1.0, 1.0];
        assert_all_close(&polynomial_power(&p, 3), &[1.0, 3.0, 3.0, 1.0], 1e-10);
    }

    #[test]
    fn test_polynomial_power_edge_cases() {
        let p = vec![2.0, 5.0];
        assert_all_close(&polynomial_power(&p, 0), &[1.0], 1e-10);
        assert_all_close(&polynomial_power(&p, 1), &p, 1e-10);
        assert!(polynomial_power::<f64>(&[], 3).is_empty());
    }

    #[test]
    fn test_correlate() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![0.0, 1.0, 2.0];
        let b_rev = vec![2.0, 1.0, 0.0];
        assert_all_close(&correlate(&a, &b), &convolve(&a, &b_rev), 1e-10);
    }

    #[test]
    fn test_circular_convolution() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let identity = vec![1.0, 0.0, 0.0, 0.0];
        assert_all_close(&convolve_circular(&a, &identity), &a, 1e-10);
        // A unit impulse at index 1 rotates the signal right by one sample.
        let shift = vec![0.0, 1.0, 0.0, 0.0];
        assert_all_close(&convolve_circular(&a, &shift), &[4.0, 1.0, 2.0, 3.0], 1e-10);
    }

    #[test]
    fn test_convolve_empty() {
        let a: Vec<f64> = vec![];
        let b = vec![1.0, 2.0];
        assert!(convolve(&a, &b).is_empty());
        assert!(convolve(&b, &a).is_empty());
    }

    #[test]
    fn test_convolve_mode_same() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![1.0, 1.0, 1.0];
        // Full result: [1, 3, 6, 9, 12, 9, 5].
        let result = convolve_with_mode(&a, &b, ConvMode::Same);
        assert_all_close(&result, &[3.0, 6.0, 9.0, 12.0, 9.0], 1e-10);
    }

    #[test]
    fn test_convolve_mode_valid() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![1.0, 1.0, 1.0];
        let result = convolve_with_mode(&a, &b, ConvMode::Valid);
        assert_all_close(&result, &[6.0, 9.0, 12.0], 1e-10);
    }
}