use crate::error::{FFTError, FFTResult};
use crate::oxifft_plan_cache;
use oxifft::{Complex as OxiComplex, Direction, Flags};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::Complex;
pub fn rfft_oxifft(input: &ArrayView1<f64>) -> FFTResult<Array1<Complex<f64>>> {
let n = input.len();
let input_vec: Vec<f64> = input.to_vec();
let output_len = n / 2 + 1;
let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); output_len];
oxifft_plan_cache::execute_r2c(&input_vec, &mut output)?;
let result: Vec<Complex<f64>> = output.iter().map(|c| Complex::new(c.re, c.im)).collect();
Ok(Array1::from_vec(result))
}
/// Computes the forward complex-to-complex FFT of a one-dimensional signal.
///
/// This function applies no normalization. The output has the same length as
/// `input`.
///
/// # Errors
///
/// Propagates any error reported by the oxifft plan cache.
pub fn fft_oxifft(input: &ArrayView1<Complex<f64>>) -> FFTResult<Array1<Complex<f64>>> {
    let src: Vec<OxiComplex<f64>> = input
        .iter()
        .map(|z| OxiComplex::new(z.re, z.im))
        .collect();
    let mut dst: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); src.len()];
    oxifft_plan_cache::execute_c2c(&src, &mut dst, Direction::Forward)?;
    let spectrum = dst
        .into_iter()
        .map(|z| Complex::new(z.re, z.im))
        .collect::<Vec<Complex<f64>>>();
    Ok(Array1::from_vec(spectrum))
}
/// Computes the inverse complex-to-complex FFT of a one-dimensional spectrum.
///
/// The backward transform's output is multiplied by `1 / N` (where `N` is
/// `input.len()`), so the result has the same length as `input`.
///
/// # Errors
///
/// Propagates any error reported by the oxifft plan cache.
pub fn ifft_oxifft(input: &ArrayView1<Complex<f64>>) -> FFTResult<Array1<Complex<f64>>> {
    let len = input.len();
    let src: Vec<OxiComplex<f64>> = input
        .iter()
        .map(|z| OxiComplex::new(z.re, z.im))
        .collect();
    let mut dst: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); len];
    oxifft_plan_cache::execute_c2c(&src, &mut dst, Direction::Backward)?;
    // Normalize here: the backward execution result is divided by N.
    let inv_len = (len as f64).recip();
    let signal = dst
        .into_iter()
        .map(|z| Complex::new(z.re * inv_len, z.im * inv_len))
        .collect::<Vec<Complex<f64>>>();
    Ok(Array1::from_vec(signal))
}
pub fn irfft_oxifft(input: &ArrayView1<Complex<f64>>, n: usize) -> FFTResult<Array1<f64>> {
let input_oxifft: Vec<OxiComplex<f64>> =
input.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
let mut output: Vec<f64> = vec![0.0; n];
oxifft_plan_cache::execute_c2r(&input_oxifft, &mut output, n)?;
let scale = 1.0 / (n as f64);
let result: Vec<f64> = output.iter().map(|&x| x * scale).collect();
Ok(Array1::from_vec(result))
}
/// Computes the forward two-dimensional complex FFT of a `rows x cols` array.
///
/// This function applies no normalization. The result has the same shape as
/// `input`.
///
/// # Errors
///
/// Propagates plan-cache errors. Returns [`FFTError::ComputationError`] if the
/// flat output cannot be reshaped back to `(rows, cols)`.
pub fn fft2_oxifft(input: &ArrayView2<Complex<f64>>) -> FFTResult<Array2<Complex<f64>>> {
    let (rows, cols) = input.dim();
    // `iter()` visits elements in logical row-major order. That is the same
    // order `from_shape_vec` uses below to rebuild the array.
    let src: Vec<OxiComplex<f64>> = input
        .iter()
        .map(|z| OxiComplex::new(z.re, z.im))
        .collect();
    let mut dst: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); rows * cols];
    oxifft_plan_cache::execute_c2c_2d(&src, &mut dst, rows, cols, Direction::Forward)?;
    let spectrum = dst
        .into_iter()
        .map(|z| Complex::new(z.re, z.im))
        .collect::<Vec<Complex<f64>>>();
    Array2::from_shape_vec((rows, cols), spectrum).map_err(|err| {
        FFTError::ComputationError(format!("Failed to reshape result: {:?}", err))
    })
}
/// Computes the inverse two-dimensional complex FFT of a `rows x cols` array.
///
/// The backward transform's output is multiplied by `1 / (rows * cols)`. The
/// result has the same shape as `input`.
///
/// # Errors
///
/// Propagates plan-cache errors. Returns [`FFTError::ComputationError`] if the
/// flat output cannot be reshaped back to `(rows, cols)`.
pub fn ifft2_oxifft(input: &ArrayView2<Complex<f64>>) -> FFTResult<Array2<Complex<f64>>> {
    let (rows, cols) = input.dim();
    let total = rows * cols;
    let src: Vec<OxiComplex<f64>> = input
        .iter()
        .map(|z| OxiComplex::new(z.re, z.im))
        .collect();
    let mut dst: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); total];
    oxifft_plan_cache::execute_c2c_2d(&src, &mut dst, rows, cols, Direction::Backward)?;
    // Divide by the total element count to normalize the backward result.
    let inv_total = (total as f64).recip();
    let signal = dst
        .into_iter()
        .map(|z| Complex::new(z.re * inv_total, z.im * inv_total))
        .collect::<Vec<Complex<f64>>>();
    Array2::from_shape_vec((rows, cols), signal).map_err(|err| {
        FFTError::ComputationError(format!("Failed to reshape result: {:?}", err))
    })
}
/// Computes the two-dimensional FFT of a real-valued `rows x cols` array via oxifft.
///
/// The last axis is halved, so the result has shape `(rows, cols / 2 + 1)`.
/// No normalization is applied.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] if `input` has zero columns. Returns
/// [`FFTError::ComputationError`] if the result cannot be reshaped. Errors
/// reported by the oxifft plan cache are propagated unchanged.
pub fn rfft2_oxifft(input: &ArrayView2<f64>) -> FFTResult<Array2<Complex<f64>>> {
    let (rows, cols) = input.dim();
    // `cols / 2 + 1` evaluates to 1 when `cols == 0`, which would size the
    // output buffer for a transform that does not exist. Reject the
    // degenerate shape up front.
    if cols == 0 {
        return Err(FFTError::ValueError(format!(
            "rfft2 input must have at least one column, got shape ({}, {})",
            rows, cols
        )));
    }
    // `iter()` visits elements in logical row-major order regardless of the
    // view's memory layout. That is the same order `from_shape_vec` uses below.
    let input_vec: Vec<f64> = input.iter().copied().collect();
    let out_cols = cols / 2 + 1;
    let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); rows * out_cols];
    oxifft_plan_cache::execute_r2c_2d(&input_vec, &mut output, rows, cols)?;
    let result: Vec<Complex<f64>> = output.iter().map(|c| Complex::new(c.re, c.im)).collect();
    Array2::from_shape_vec((rows, out_cols), result)
        .map_err(|e| FFTError::ComputationError(format!("Failed to reshape result: {:?}", e)))
}
/// Computes the inverse two-dimensional real FFT, producing a `shape` real array.
///
/// `input` must have shape `(rows, cols / 2 + 1)`, where `(rows, cols) =
/// shape`. The result is scaled by `1 / (rows * cols)`.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] if `input` has the wrong shape. Returns
/// [`FFTError::ComputationError`] if the result cannot be reshaped. Errors
/// reported by the oxifft plan cache are propagated unchanged.
pub fn irfft2_oxifft(
    input: &ArrayView2<Complex<f64>>,
    shape: (usize, usize),
) -> FFTResult<Array2<f64>> {
    let (rows, cols) = shape;
    let half_cols = cols / 2 + 1;
    let (in_rows, in_cols) = input.dim();
    let shape_matches = in_rows == rows && in_cols == half_cols;
    if !shape_matches {
        return Err(FFTError::ValueError(format!(
            "Input shape ({}, {}) doesn't match expected ({}, {}) for output shape ({}, {})",
            in_rows, in_cols, rows, half_cols, rows, cols
        )));
    }
    let src: Vec<OxiComplex<f64>> = input
        .iter()
        .map(|z| OxiComplex::new(z.re, z.im))
        .collect();
    let total = rows * cols;
    let mut dst: Vec<f64> = vec![0.0; total];
    oxifft_plan_cache::execute_c2r_2d(&src, &mut dst, rows, cols)?;
    // Normalize in place by the number of output samples.
    let inv_total = (total as f64).recip();
    dst.iter_mut().for_each(|v| *v *= inv_total);
    Array2::from_shape_vec((rows, cols), dst).map_err(|err| {
        FFTError::ComputationError(format!("Failed to reshape result: {:?}", err))
    })
}
/// Computes the discrete cosine transform provided by `execute_dct2`.
///
/// This is presumably DCT-II, given the name and the `1 / (2N)` scaling used
/// by the matching inverse; confirm against the oxifft documentation. This
/// function applies no normalization. The output has the same length as
/// `input`.
///
/// # Errors
///
/// Propagates any error reported by the oxifft plan cache.
pub fn dct2_oxifft(input: &ArrayView1<f64>) -> FFTResult<Array1<f64>> {
    let samples = input.iter().copied().collect::<Vec<f64>>();
    let mut coeffs = vec![0.0_f64; samples.len()];
    oxifft_plan_cache::execute_dct2(&samples, &mut coeffs)?;
    Ok(Array1::from_vec(coeffs))
}
pub fn idct2_oxifft(input: &ArrayView1<f64>) -> FFTResult<Array1<f64>> {
let n = input.len();
let input_vec: Vec<f64> = input.to_vec();
let mut output: Vec<f64> = vec![0.0; n];
oxifft_plan_cache::execute_idct2(&input_vec, &mut output)?;
let scale = 1.0 / (2.0 * n as f64);
let result: Vec<f64> = output.iter().map(|&x| x * scale).collect();
Ok(Array1::from_vec(result))
}
/// Computes the discrete sine transform provided by `execute_dst2`.
///
/// This is presumably DST-II, given the name and the `1 / (2N)` scaling used
/// by the matching inverse; confirm against the oxifft documentation. This
/// function applies no normalization. The output has the same length as
/// `input`.
///
/// # Errors
///
/// Propagates any error reported by the oxifft plan cache.
pub fn dst2_oxifft(input: &ArrayView1<f64>) -> FFTResult<Array1<f64>> {
    let samples = input.iter().copied().collect::<Vec<f64>>();
    let mut coeffs = vec![0.0_f64; samples.len()];
    oxifft_plan_cache::execute_dst2(&samples, &mut coeffs)?;
    Ok(Array1::from_vec(coeffs))
}
pub fn idst2_oxifft(input: &ArrayView1<f64>) -> FFTResult<Array1<f64>> {
let n = input.len();
let input_vec: Vec<f64> = input.to_vec();
let mut output: Vec<f64> = vec![0.0; n];
oxifft_plan_cache::execute_idst2(&input_vec, &mut output)?;
let scale = 1.0 / (2.0 * n as f64);
let result: Vec<f64> = output.iter().map(|&x| x * scale).collect();
Ok(Array1::from_vec(result))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// The purely real complex signal `[1, 2, 3, 4]` shared by the complex tests.
    fn complex_ramp() -> Array1<Complex<f64>> {
        array![
            Complex::new(1.0, 0.0),
            Complex::new(2.0, 0.0),
            Complex::new(3.0, 0.0),
            Complex::new(4.0, 0.0)
        ]
    }

    #[test]
    fn test_fft_oxifft_basic() {
        let signal = complex_ramp();
        let spectrum = fft_oxifft(&signal.view()).expect("FFT failed");
        assert_eq!(spectrum.len(), 4);
        // The zero-frequency bin of the forward transform is the plain sum 1+2+3+4.
        let dc = spectrum[0];
        assert_relative_eq!(dc.re, 10.0, epsilon = 1e-10);
        assert_relative_eq!(dc.im, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_ifft_oxifft_roundtrip() {
        let signal = complex_ramp();
        let spectrum = fft_oxifft(&signal.view()).expect("FFT failed");
        let recovered = ifft_oxifft(&spectrum.view()).expect("IFFT failed");
        for (expected, actual) in signal.iter().zip(recovered.iter()) {
            assert_relative_eq!(expected.re, actual.re, epsilon = 1e-10);
            assert_relative_eq!(expected.im, actual.im, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_rfft_oxifft_basic() {
        let signal = array![1.0, 2.0, 3.0, 4.0];
        let half_spectrum = rfft_oxifft(&signal.view()).expect("RFFT failed");
        // A length-4 real signal has 4 / 2 + 1 = 3 non-negative-frequency bins.
        assert_eq!(half_spectrum.len(), 3);
    }

    #[test]
    fn test_irfft_oxifft_roundtrip() {
        let signal = array![1.0, 2.0, 3.0, 4.0];
        let half_spectrum = rfft_oxifft(&signal.view()).expect("RFFT failed");
        let recovered = irfft_oxifft(&half_spectrum.view(), 4).expect("IRFFT failed");
        for (expected, actual) in signal.iter().zip(recovered.iter()) {
            assert_relative_eq!(*expected, *actual, epsilon = 1e-10);
        }
    }
}