use crate::dft::problem::Sign;
use crate::kernel::{Complex, Float};
use crate::prelude::*;
use super::ct::{CooleyTukeySolver, CtVariant};
/// Minimum transform length handled by the matrix decomposition in this file;
/// shorter inputs are handed to the Cooley-Tukey base solver, and `applicable`
/// rejects them.
const CACHE_OBLIVIOUS_THRESHOLD: usize = 1024;
/// Upper bound on the tile edge length (in elements) used by the blocked
/// transpose routines; it is clamped to the matrix dimensions at use.
const TRANSPOSE_BLOCK_SIZE: usize = 64;
/// DFT solver that treats a power-of-two length input as an `n1 x n2` matrix
/// and computes the transform as column FFTs, a twiddle multiplication, row
/// FFTs and a final transpose.
///
/// Transforms shorter than `CACHE_OBLIVIOUS_THRESHOLD` are delegated to the
/// wrapped Cooley-Tukey solver.
pub struct CacheObliviousSolver<T: Float> {
/// Solver used for every transform below the size threshold.
base_solver: CooleyTukeySolver<T>,
/// Zero-sized marker carrying the element type `T`.
_marker: core::marker::PhantomData<T>,
}
impl<T: Float> Default for CacheObliviousSolver<T> {
/// Builds the solver exactly as [`CacheObliviousSolver::new`] does.
fn default() -> Self {
Self::new()
}
}
impl<T: Float> CacheObliviousSolver<T> {
    /// Creates a solver whose small-size fallback is a decimation-in-time
    /// Cooley-Tukey solver.
    #[must_use]
    pub fn new() -> Self {
        Self {
            base_solver: CooleyTukeySolver::new(CtVariant::Dit),
            _marker: core::marker::PhantomData,
        }
    }

    /// Returns the fixed identifier of this solver.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "dft-cache-oblivious"
    }

    /// Reports whether this solver is intended for a transform of length `n`:
    /// `n` must be a power of two and at least `CACHE_OBLIVIOUS_THRESHOLD`.
    #[must_use]
    pub fn applicable(n: usize) -> bool {
        n.is_power_of_two() && n >= CACHE_OBLIVIOUS_THRESHOLD
    }

    /// Computes the DFT of `input` into `output` with the given `sign`.
    ///
    /// Inputs shorter than `CACHE_OBLIVIOUS_THRESHOLD` are forwarded to the
    /// base solver unchanged. Longer inputs are copied into a scratch buffer
    /// and processed by the matrix decomposition, so `input` is never modified.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `input` and `output` differ in length or if
    /// the length is not a power of two.
    pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
        let n = input.len();
        debug_assert_eq!(n, output.len());
        debug_assert!(n.is_power_of_two(), "Size must be power of 2");
        if n < CACHE_OBLIVIOUS_THRESHOLD {
            self.base_solver.execute(input, output, sign);
            return;
        }
        // Copy straight from the input instead of zero-filling first.
        let mut scratch = input.to_vec();
        self.four_step(&mut scratch, output, sign);
    }

    /// Computes the DFT of `data` and stores the result back into `data`.
    ///
    /// Inputs shorter than `CACHE_OBLIVIOUS_THRESHOLD` use the base solver's
    /// in-place routine. Longer inputs are copied once into a scratch buffer
    /// which then serves as the working matrix, with `data` as the output.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if a length at or above the threshold is not a
    /// power of two.
    pub fn execute_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
        let n = data.len();
        if n < CACHE_OBLIVIOUS_THRESHOLD {
            self.base_solver.execute_inplace(data, sign);
            return;
        }
        debug_assert!(n.is_power_of_two(), "Size must be power of 2");
        // One copy is enough: the snapshot of `data` is the working matrix.
        let mut scratch = data.to_vec();
        self.four_step(&mut scratch, data, sign);
    }

    /// Runs the matrix decomposition on `scratch` (overwritten as working
    /// storage) and writes the transform into `output`.
    ///
    /// `scratch` is read as a row-major `n1 x n2` matrix. The steps are:
    /// length-`n1` FFTs down every column, multiplication by the twiddle
    /// factors, length-`n2` FFTs along every row, and a transpose of the
    /// result so that it ends up in natural order.
    fn four_step(&self, scratch: &mut [Complex<T>], output: &mut [Complex<T>], sign: Sign) {
        let n = scratch.len();
        let (n1, n2) = balanced_factorization(n);
        self.execute_column_ffts(scratch, n1, n2, sign);
        apply_twiddle_factors(scratch, n1, n2, n, sign);
        self.execute_row_ffts(scratch, output, n1, n2, sign);
        transpose_output(output, n1, n2);
    }

    /// Replaces every column of the row-major `n1 x n2` `matrix` with its
    /// length-`n1` transform.
    ///
    /// Columns are strided in memory, so each one is gathered into a
    /// contiguous buffer, transformed, and scattered back.
    fn execute_column_ffts(&self, matrix: &mut [Complex<T>], n1: usize, n2: usize, sign: Sign) {
        let mut col_buf = vec![Complex::<T>::zero(); n1];
        let mut col_out = vec![Complex::<T>::zero(); n1];
        for j in 0..n2 {
            for (i, slot) in col_buf.iter_mut().enumerate() {
                *slot = matrix[i * n2 + j];
            }
            // `execute` itself falls back to the base solver for short columns.
            self.execute(&col_buf, &mut col_out, sign);
            for (i, &value) in col_out.iter().enumerate() {
                matrix[i * n2 + j] = value;
            }
        }
    }

    /// Transforms every length-`n2` row of the row-major `n1 x n2` `matrix`
    /// and writes it to the same row position in `output`.
    fn execute_row_ffts(
        &self,
        matrix: &[Complex<T>],
        output: &mut [Complex<T>],
        n1: usize,
        n2: usize,
        sign: Sign,
    ) {
        for i in 0..n1 {
            let row_start = i * n2;
            let row_end = row_start + n2;
            // Rows are contiguous, so the transform can write directly into
            // its destination; `execute` picks the base solver for short rows.
            self.execute(
                &matrix[row_start..row_end],
                &mut output[row_start..row_end],
                sign,
            );
        }
    }
}
/// Splits a power-of-two `n` into two power-of-two factors `(n1, n2)` with
/// `n1 * n2 == n` and `n1 <= n2`.
///
/// `n1` receives the lower half of the exponent (rounded down), so the two
/// factors are equal for even exponents and differ by a factor of two for odd
/// ones.
fn balanced_factorization(n: usize) -> (usize, usize) {
    debug_assert!(n.is_power_of_two());
    let half_exponent = n.trailing_zeros() >> 1;
    let n1 = 1usize << half_exponent;
    // Dividing by a power of two is a right shift by its exponent.
    (n1, n >> half_exponent)
}
/// Multiplies element `(i, j)` of the row-major `n1 x n2` `matrix` by
/// `cis(sign * 2*pi * i * j / n)`.
///
/// Row 0 and column 0 are skipped because their factor is exactly one.
fn apply_twiddle_factors<T: Float>(
    matrix: &mut [Complex<T>],
    n1: usize,
    n2: usize,
    n: usize,
    sign: Sign,
) {
    let direction = T::from_isize(sign.value() as isize);
    let base_angle = T::TWO_PI / T::from_usize(n);
    // Partial products are hoisted in the same left-to-right order the full
    // expression would be evaluated in, so the angles are bit-identical.
    let signed_step = direction * base_angle;
    for i in 1..n1 {
        let row_step = signed_step * T::from_usize(i);
        let row_offset = i * n2;
        for j in 1..n2 {
            let factor = Complex::cis(row_step * T::from_usize(j));
            let cell = &mut matrix[row_offset + j];
            *cell = *cell * factor;
        }
    }
}
/// Transposes the row-major `n1 x n2` matrix stored in `data`.
///
/// Non-square shapes go through the buffered rectangular routine; square
/// shapes use the in-place blocked swap.
fn transpose_output<T: Float>(data: &mut [Complex<T>], n1: usize, n2: usize) {
    if n1 != n2 {
        transpose_rectangular(data, n1, n2);
        return;
    }
    transpose_square_blocked(data, n1);
}
/// Transposes the row-major `n x n` matrix in `data` in place, tile by tile.
///
/// Only tiles on or above the diagonal are visited; each visited element is
/// swapped with its mirror image, so every off-diagonal pair is exchanged
/// exactly once.
fn transpose_square_blocked<T: Float>(data: &mut [Complex<T>], n: usize) {
    let tile = TRANSPOSE_BLOCK_SIZE.min(n);
    // `step_by` rejects a zero step; `tile` is zero only when `n` is zero, in
    // which case the ranges are empty anyway.
    let stride = tile.max(1);
    for row_lo in (0..n).step_by(stride) {
        let row_hi = (row_lo + tile).min(n);
        for col_lo in (row_lo..n).step_by(stride) {
            let col_hi = (col_lo + tile).min(n);
            for r in row_lo..row_hi {
                // In a diagonal tile start just right of the diagonal; in any
                // other tile `col_lo` is already past `r`.
                let first_col = col_lo.max(r + 1);
                for c in first_col..col_hi {
                    data.swap(r * n + c, c * n + r);
                }
            }
        }
    }
}
/// Transposes the row-major `n1 x n2` matrix held in the first `n1 * n2`
/// elements of `data`, leaving a row-major `n2 x n1` matrix in their place.
///
/// The transpose is built tile by tile in a temporary buffer of `n1 * n2`
/// elements and then copied back.
///
/// # Panics
///
/// Panics if `data` holds fewer than `n1 * n2` elements.
fn transpose_rectangular<T: Float>(data: &mut [Complex<T>], n1: usize, n2: usize) {
    let total = n1 * n2;
    if total == 0 {
        // Nothing to move. Returning here also keeps the tile size below from
        // being zero, which would otherwise stall the tile loops.
        return;
    }
    let mut temp = vec![Complex::<T>::zero(); total];
    // Both dimensions are non-zero at this point, so the tile size is >= 1.
    let block = TRANSPOSE_BLOCK_SIZE.min(n1.min(n2));
    for bi in (0..n1).step_by(block) {
        let bi_end = (bi + block).min(n1);
        for bj in (0..n2).step_by(block) {
            let bj_end = (bj + block).min(n2);
            for i in bi..bi_end {
                for j in bj..bj_end {
                    temp[j * n1 + i] = data[i * n2 + j];
                }
            }
        }
    }
    data[..total].copy_from_slice(&temp);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dft::solvers::direct::DirectSolver;

    /// Returns `true` when both the real and imaginary parts of `a` and `b`
    /// differ by strictly less than `tol`.
    fn complex_approx_eq<T: Float>(a: Complex<T>, b: Complex<T>, tol: f64) -> bool {
        let limit = T::from_f64(tol);
        let re_gap = num_traits::Float::abs(a.re - b.re);
        let im_gap = num_traits::Float::abs(a.im - b.im);
        re_gap < limit && im_gap < limit
    }

    /// Largest absolute difference over all real and imaginary components of
    /// the paired elements of `a` and `b`, as an `f64`.
    fn max_abs_error<T: Float>(a: &[Complex<T>], b: &[Complex<T>]) -> f64 {
        let mut worst = 0.0_f64;
        for (x, y) in a.iter().zip(b) {
            let delta = *x - *y;
            for part in [delta.re, delta.im] {
                let magnitude = num_traits::ToPrimitive::to_f64(&num_traits::Float::abs(part))
                    .unwrap_or(f64::MAX);
                if magnitude > worst {
                    worst = magnitude;
                }
            }
        }
        worst
    }

    /// Test signal of length `n`: `sin(t)` in the real part and
    /// `cos(harmonic * t)` in the imaginary part, with `t` sweeping one turn.
    fn tone(n: usize, harmonic: f64) -> Vec<Complex<f64>> {
        (0..n)
            .map(|i| {
                let t = core::f64::consts::TAU * (i as f64) / (n as f64);
                Complex::new(t.sin(), (harmonic * t).cos())
            })
            .collect()
    }

    /// Forward transform of `input` with the solver under test.
    fn forward_cache_oblivious(input: &[Complex<f64>]) -> Vec<Complex<f64>> {
        let mut output: Vec<Complex<f64>> = vec![Complex::zero(); input.len()];
        CacheObliviousSolver::<f64>::new().execute(input, &mut output, Sign::Forward);
        output
    }

    /// Forward transform of `input` with the reference Cooley-Tukey solver.
    fn forward_cooley_tukey(input: &[Complex<f64>]) -> Vec<Complex<f64>> {
        let mut output: Vec<Complex<f64>> = vec![Complex::zero(); input.len()];
        CooleyTukeySolver::<f64>::new(CtVariant::Dit).execute(input, &mut output, Sign::Forward);
        output
    }

    /// Asserts that the solver agrees with the direct solver on a tone input.
    fn assert_matches_direct(n: usize, harmonic: f64, tol: f64) {
        let input = tone(n, harmonic);
        let output_co = forward_cache_oblivious(&input);
        let mut output_direct: Vec<Complex<f64>> = vec![Complex::zero(); n];
        DirectSolver::<f64>::new().execute(&input, &mut output_direct, Sign::Forward);
        let err = max_abs_error(&output_co, &output_direct);
        assert!(
            err < tol,
            "Cache-oblivious vs direct error too large: {err} for N={n}"
        );
    }

    /// Asserts that forward, backward and `1/n` scaling reproduce `original`.
    fn assert_round_trip(original: &[Complex<f64>], tol: f64) {
        let n = original.len();
        let solver = CacheObliviousSolver::<f64>::new();
        let mut transformed: Vec<Complex<f64>> = vec![Complex::zero(); n];
        let mut recovered: Vec<Complex<f64>> = vec![Complex::zero(); n];
        solver.execute(original, &mut transformed, Sign::Forward);
        solver.execute(&transformed, &mut recovered, Sign::Backward);
        let scale = 1.0 / n as f64;
        for x in &mut recovered {
            *x = Complex::new(x.re * scale, x.im * scale);
        }
        let err = max_abs_error(original, &recovered);
        assert!(err < tol, "Round-trip error too large: {err} for N={n}");
    }

    #[test]
    fn test_balanced_factorization() {
        let expected = [
            (1024, (32, 32)),
            (2048, (32, 64)),
            (4096, (64, 64)),
            (8192, (64, 128)),
            (16384, (128, 128)),
        ];
        for (n, factors) in expected {
            assert_eq!(balanced_factorization(n), factors);
        }
    }

    #[test]
    fn test_transpose_square() {
        let original: Vec<Complex<f64>> =
            (0..16).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let mut data = original.clone();
        transpose_square_blocked(&mut data, 4);
        for (i, j) in (0..4).flat_map(|i| (0..4).map(move |j| (i, j))) {
            assert!(
                complex_approx_eq(data[j * 4 + i], original[i * 4 + j], 1e-15),
                "Transpose failed at ({i}, {j})"
            );
        }
    }

    #[test]
    fn test_transpose_rectangular() {
        let original: Vec<Complex<f64>> = (0..16)
            .map(|i| Complex::new(f64::from(i), f64::from(i) * 0.5))
            .collect();
        let mut data = original.clone();
        transpose_rectangular(&mut data, 2, 8);
        for (i, j) in (0..2).flat_map(|i| (0..8).map(move |j| (i, j))) {
            assert!(
                complex_approx_eq(data[j * 2 + i], original[i * 8 + j], 1e-15),
                "Transpose failed at ({i}, {j})"
            );
        }
    }

    #[test]
    fn test_cache_oblivious_vs_direct_1024() {
        assert_matches_direct(1024, 1.0, 1e-6);
    }

    #[test]
    fn test_cache_oblivious_vs_direct_2048() {
        assert_matches_direct(2048, 2.0, 1e-6);
    }

    #[test]
    fn test_cache_oblivious_vs_direct_4096() {
        assert_matches_direct(4096, 3.0, 1e-5);
    }

    #[test]
    fn test_cache_oblivious_vs_ct_8192() {
        let n = 8192;
        let input = tone(n, 5.0);
        let output_co = forward_cache_oblivious(&input);
        let output_ct = forward_cooley_tukey(&input);
        let err = max_abs_error(&output_co, &output_ct);
        assert!(
            err < 1e-6,
            "Cache-oblivious vs CT error too large: {err} for N={n}"
        );
    }

    #[test]
    fn test_cache_oblivious_round_trip_1024() {
        let n: usize = 1024;
        let original: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new(f64::from(i as u32), f64::from(i as u32) * 0.5))
            .collect();
        assert_round_trip(&original, 1e-9);
    }

    #[test]
    fn test_cache_oblivious_round_trip_4096() {
        assert_round_trip(&tone(4096, 1.0), 1e-8);
    }

    #[test]
    fn test_cache_oblivious_round_trip_8192() {
        assert_round_trip(&tone(8192, 7.0), 1e-7);
    }

    #[test]
    fn test_cache_oblivious_inplace_matches_outofplace() {
        let input = tone(1024, 1.0);
        let out_of_place = forward_cache_oblivious(&input);
        let mut in_place = input;
        CacheObliviousSolver::<f64>::new().execute_inplace(&mut in_place, Sign::Forward);
        let err = max_abs_error(&out_of_place, &in_place);
        assert!(err < 1e-15, "In-place vs out-of-place error: {err}");
    }

    #[test]
    fn test_cache_oblivious_f32() {
        let n: usize = 1024;
        let input: Vec<Complex<f32>> = (0..n)
            .map(|i| {
                let t = core::f32::consts::TAU * (i as f32) / (n as f32);
                Complex::new(t.sin(), t.cos())
            })
            .collect();
        let mut output_co: Vec<Complex<f32>> = vec![Complex::zero(); n];
        let mut output_ct: Vec<Complex<f32>> = vec![Complex::zero(); n];
        CacheObliviousSolver::<f32>::new().execute(&input, &mut output_co, Sign::Forward);
        CooleyTukeySolver::<f32>::new(CtVariant::Dit).execute(
            &input,
            &mut output_ct,
            Sign::Forward,
        );
        let err = max_abs_error(&output_co, &output_ct);
        assert!(
            err < 1e-2,
            "f32 cache-oblivious vs CT error too large: {err}"
        );
    }

    #[test]
    fn test_below_threshold_falls_back() {
        let n: usize = 512;
        let input: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new(f64::from(i as u32), 0.0))
            .collect();
        let output_co = forward_cache_oblivious(&input);
        let output_ct = forward_cooley_tukey(&input);
        let err = max_abs_error(&output_co, &output_ct);
        assert!(err < 1e-10, "Below-threshold fallback error: {err}");
    }

    #[test]
    fn test_twiddle_factors_identity() {
        let (n1, n2) = (2, 2);
        let n = 4;
        let one = Complex::new(1.0, 0.0);
        let mut matrix: Vec<Complex<f64>> = vec![one; n];
        apply_twiddle_factors(&mut matrix, n1, n2, n, Sign::Forward);
        // Everything in row 0 or column 0 must be left untouched.
        for untouched in &matrix[..3] {
            assert!(complex_approx_eq(*untouched, one, 1e-15));
        }
        let expected = Complex::cis(-core::f64::consts::TAU / 4.0);
        assert!(
            complex_approx_eq(matrix[3], expected, 1e-14),
            "Twiddle at (1,1): got {:?}, expected {:?}",
            matrix[3],
            expected
        );
    }
}