use super::DTypeSupport;
use crate::algorithm::fft::{FftAlgorithms, FftNormalization};
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{BinaryOps, ComplexOps, IndexingOps, ReduceOps, ShapeOps, UtilityOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::Tensor;
/// Heuristic cutover point: when `n_a * n_b` (the direct method's work) is
/// below this value, the O(n_a * n_b) direct convolution is used; otherwise
/// the FFT-based path is chosen.
const FFT_THRESHOLD: usize = 64;
/// Computes the full 1-D linear convolution of `a` and `b`.
///
/// Returns an empty tensor when either input is empty. Otherwise dispatches
/// on problem size: the direct outer-product method for small inputs, the
/// FFT-based method once `n_a * n_b` reaches `FFT_THRESHOLD`.
///
/// # Errors
/// Propagates any failure from the selected convolution implementation.
pub fn convolve_impl<R, C>(
    client: &C,
    a: &Tensor<R>,
    b: &Tensor<R>,
    dtype_support: DTypeSupport,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>
        + BinaryOps<R>
        + IndexingOps<R>
        + ShapeOps<R>
        + UtilityOps<R>
        + ReduceOps<R>
        + FftAlgorithms<R>
        + ComplexOps<R>,
{
    let len_a = a.shape()[0];
    let len_b = b.shape()[0];
    // Convolving with an empty signal yields an empty result.
    if len_a == 0 || len_b == 0 {
        return Ok(Tensor::zeros(&[0], a.dtype(), client.device()));
    }
    // The direct method's cost is the product of the lengths; switch to the
    // FFT path once that quadratic work crosses the threshold.
    if len_a * len_b >= FFT_THRESHOLD {
        convolve_fft(client, a, b, dtype_support)
    } else {
        convolve_direct(client, a, b, dtype_support)
    }
}
/// Direct (outer-product) 1-D linear convolution.
///
/// Forms the outer product `outer[i, j] = a[i] * b[j]` via broadcasting and
/// scatter-adds each entry into output slot `i + j`, which is exactly
/// `c[k] = sum_{i+j=k} a[i] * b[j]`. Work is O(n_a * n_b), so this is
/// intended for small inputs.
///
/// # Errors
/// Propagates any failure from the delegated runtime operations.
fn convolve_direct<R, C>(
    client: &C,
    a: &Tensor<R>,
    b: &Tensor<R>,
    dtype_support: DTypeSupport,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>
        + BinaryOps<R>
        + IndexingOps<R>
        + ShapeOps<R>
        + UtilityOps<R>
        + ReduceOps<R>,
{
    let n_a = a.shape()[0];
    let n_b = b.shape()[0];
    let dtype = a.dtype();
    let device = client.device();
    // Guard empty operands: without this, `n_a + n_b - 1` underflows when
    // both inputs are empty, and a single empty input would produce a
    // spurious zero-filled length-(n - 1) result instead of an empty one.
    // (convolve_impl checks this too, but tests call this function directly.)
    if n_a == 0 || n_b == 0 {
        return Ok(Tensor::zeros(&[0], dtype, device));
    }
    let out_len = n_a + n_b - 1;
    // Broadcast multiply a column against a row to get every pairwise
    // product a[i] * b[j].
    let a_col = a.reshape(&[n_a, 1])?;
    let b_row = b.reshape(&[1, n_b])?;
    let outer = client.mul(&a_col, &b_row)?;
    // The destination index of each product is i + j, built by broadcasting
    // the row index vector against the column index vector.
    let index_dtype = dtype_support.index_dtype;
    let i_indices = client.arange(0.0, n_a as f64, 1.0, index_dtype)?;
    let j_indices = client.arange(0.0, n_b as f64, 1.0, index_dtype)?;
    let i_col = i_indices.reshape(&[n_a, 1])?;
    let j_row = j_indices.reshape(&[1, n_b])?;
    let out_indices = client.add(&i_col, &j_row)?;
    // Flatten both grids and sum-scatter all products into the output.
    let outer_flat = outer.reshape(&[n_a * n_b])?;
    let indices_flat = out_indices.reshape(&[n_a * n_b])?;
    let output = Tensor::zeros(&[out_len], dtype, device);
    client.scatter_reduce(
        &output,
        0,
        &indices_flat,
        &outer_flat,
        crate::ops::ScatterReduceOp::Sum,
        true,
    )
}
fn convolve_fft<R, C>(
client: &C,
a: &Tensor<R>,
b: &Tensor<R>,
dtype_support: DTypeSupport,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ BinaryOps<R>
+ ShapeOps<R>
+ UtilityOps<R>
+ IndexingOps<R>
+ FftAlgorithms<R>
+ ComplexOps<R>,
{
let n_a = a.shape()[0];
let n_b = b.shape()[0];
let dtype = a.dtype();
let out_len = n_a + n_b - 1;
let fft_size = out_len.next_power_of_two();
dtype_support.check(dtype, "convolve_fft")?;
let a_padded = if n_a < fft_size {
let pad_amount = fft_size - n_a;
client.pad(a, &[0, pad_amount], 0.0)?
} else {
a.clone()
};
let b_padded = if n_b < fft_size {
let pad_amount = fft_size - n_b;
client.pad(b, &[0, pad_amount], 0.0)?
} else {
b.clone()
};
let a_fft = client.rfft(&a_padded, FftNormalization::None)?;
let b_fft = client.rfft(&b_padded, FftNormalization::None)?;
let c_fft = complex_mul(client, &a_fft, &b_fft)?;
let c_full = client.irfft(&c_fft, Some(fft_size), FftNormalization::Backward)?;
if out_len < fft_size {
let indices = client.arange(0.0, out_len as f64, 1.0, dtype_support.index_dtype)?;
client.index_select(&c_full, 0, &indices)
} else {
Ok(c_full)
}
}
/// Elementwise product of two spectra produced by `rfft`.
///
/// Delegates to `client.mul`. NOTE(review): this assumes the runtime's
/// `mul` performs true complex multiplication for complex-dtype tensors —
/// a real-only elementwise product here would make the FFT convolution
/// wrong. Confirm against the `BinaryOps::mul` implementation.
fn complex_mul<R, C>(client: &C, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: BinaryOps<R>,
{
client.mul(a, b)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Constructs a CPU device and its default client for use in tests.
    fn create_client() -> (CpuClient, CpuDevice) {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        (client, device)
    }

    #[test]
    fn test_convolve_direct_simple() {
        let (client, device) = create_client();
        // [1, 1] * [1, 1] = [1, 2, 1].
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);
        let out: Vec<f32> = convolve_direct(&client, &a, &b, DTypeSupport::FULL)
            .unwrap()
            .to_vec();
        let expected = [1.0f32, 2.0, 1.0];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_convolve_direct_asymmetric() {
        let (client, device) = create_client();
        // [1, 2] * [3, 4] = [3, 10, 8].
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device);
        let out: Vec<f32> = convolve_direct(&client, &a, &b, DTypeSupport::FULL)
            .unwrap()
            .to_vec();
        let expected = [3.0f32, 10.0, 8.0];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_convolve_fft_simple() {
        let (client, device) = create_client();
        // Same case as the direct test, through the FFT path.
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);
        let out: Vec<f32> = convolve_fft(&client, &a, &b, DTypeSupport::FULL)
            .unwrap()
            .to_vec();
        let expected = [1.0f32, 2.0, 1.0];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected.iter()) {
            // Looser tolerance: FFT round-trip introduces rounding error.
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_convolve_impl_selects_direct() {
        let (client, device) = create_client();
        // 2 * 2 = 4 < FFT_THRESHOLD, so the dispatcher takes the direct path.
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &device);
        let out: Vec<f32> = convolve_impl(&client, &a, &b, DTypeSupport::FULL)
            .unwrap()
            .to_vec();
        let expected = [3.0f32, 10.0, 8.0];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_convolve_impl_selects_fft() {
        let (client, device) = create_client();
        // 10 * 10 = 100 >= FFT_THRESHOLD, so the dispatcher takes the FFT path.
        let a_vals: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let b_vals: Vec<f32> = (0..10).map(|i| i as f32 + 1.0).collect();
        let a = Tensor::<CpuRuntime>::from_slice(&a_vals, &[10], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&b_vals, &[10], &device);
        let result = convolve_impl(&client, &a, &b, DTypeSupport::FULL).unwrap();
        // Output length is 10 + 10 - 1.
        assert_eq!(result.shape()[0], 19);
    }

    #[test]
    fn test_convolve_direct_vs_fft_equivalence() {
        let (client, device) = create_client();
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[5.0f32, 6.0, 7.0, 8.0], &[4], &device);
        let direct: Vec<f32> = convolve_direct(&client, &a, &b, DTypeSupport::FULL)
            .unwrap()
            .to_vec();
        let fft: Vec<f32> = convolve_fft(&client, &a, &b, DTypeSupport::FULL)
            .unwrap()
            .to_vec();
        assert_eq!(direct.len(), fft.len());
        for (i, (d, f)) in direct.iter().zip(fft.iter()).enumerate() {
            assert!(
                (d - f).abs() < 1e-4,
                "Mismatch at index {}: direct={}, fft={}",
                i,
                d,
                f
            );
        }
    }

    #[test]
    fn test_convolve_f64() {
        let (client, device) = create_client();
        // [1, 2, 3] * [4, 5] = [4, 13, 22, 15] in f64.
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f64, 2.0, 3.0], &[3], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[4.0f64, 5.0], &[2], &device);
        let out: Vec<f64> = convolve_impl(&client, &a, &b, DTypeSupport::FULL)
            .unwrap()
            .to_vec();
        let expected = [4.0f64, 13.0, 22.0, 15.0];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12);
        }
    }

    #[test]
    fn test_convolve_empty() {
        let (client, device) = create_client();
        // An empty operand yields an empty result.
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[] as &[f32], &[0], &device);
        let result = convolve_impl(&client, &a, &b, DTypeSupport::FULL).unwrap();
        assert_eq!(result.shape()[0], 0);
    }
}