use crate::half_precision::{bf16, f16};
use crate::{Result, Tensor, TensorError};
use num_complex::Complex;
use oxifft::{Direction, Flags, Plan};
use std::sync::Arc;
/// Views a slice of `num_complex::Complex<T>` as a slice of oxifft's complex type without
/// copying, so buffers can be handed straight to an oxifft `Plan`.
#[inline]
fn to_oxifft_complex<T: oxifft::Float>(data: &[Complex<T>]) -> &[oxifft::kernel::Complex<T>] {
    let len = data.len();
    let head = data.as_ptr().cast::<oxifft::kernel::Complex<T>>();
    // SAFETY: the returned slice borrows `data` for the same lifetime and covers exactly `len`
    // elements. Soundness relies on `num_complex::Complex<T>` and `oxifft::kernel::Complex<T>`
    // having identical size, alignment and field order (re, im).
    // NOTE(review): num_complex's `Complex` is `#[repr(C)]`; confirm oxifft's type is as well.
    unsafe { std::slice::from_raw_parts(head, len) }
}
/// Mutable counterpart of `to_oxifft_complex`: views a mutable slice of
/// `num_complex::Complex<T>` as a mutable slice of oxifft's complex type, without copying.
#[inline]
fn to_oxifft_complex_mut<T: oxifft::Float>(
    data: &mut [Complex<T>],
) -> &mut [oxifft::kernel::Complex<T>] {
    let len = data.len();
    let head = data.as_mut_ptr().cast::<oxifft::kernel::Complex<T>>();
    // SAFETY: the returned slice takes over the unique borrow of `data` for its full lifetime
    // and spans exactly `len` elements. Soundness relies on both complex types sharing size,
    // alignment and (re, im) field order.
    // NOTE(review): confirm oxifft's `Complex` is `#[repr(C)]` like num_complex's.
    unsafe { std::slice::from_raw_parts_mut(head, len) }
}
/// Forward 1-D FFT for `f16` tensors.
///
/// The transform length is the size of the last dimension. Values are widened to `f32` for
/// the computation, and the complex spectrum is narrowed back to `f16`.
///
/// # Errors
///
/// Fails when the shape has no dimensions, when the FFT plan cannot be created, or when the
/// transform or a tensor conversion fails.
pub fn fft_f16(input: &Tensor<f16>) -> Result<Tensor<Complex<f16>>> {
    let dims = input.shape().dims();
    let len = match dims.last() {
        Some(&last) => last,
        None => {
            return Err(TensorError::invalid_shape_simple(
                "Empty tensor shape".to_string(),
            ))
        }
    };
    let widened = convert_f16_to_f32_tensor(input)?;
    let plan = Plan::dft_1d(len, Direction::Forward, Flags::ESTIMATE).ok_or_else(|| {
        TensorError::invalid_shape_simple("Failed to create FFT plan for f16".to_string())
    })?;
    let spectrum = execute_optimized_fft_1d(&widened, &plan, len)?;
    convert_complex_f32_to_f16_tensor(&spectrum, dims)
}
/// Inverse 1-D FFT for complex `f16` tensors.
///
/// The transform length is the size of the last dimension. Values are widened to `f32`, the
/// backward transform is applied and scaled by `1/n` in the helper, and the result is
/// narrowed back to complex `f16`.
///
/// # Errors
///
/// Fails when the shape has no dimensions, when the IFFT plan cannot be created, or when the
/// transform or a tensor conversion fails.
pub fn ifft_f16(input: &Tensor<Complex<f16>>) -> Result<Tensor<Complex<f16>>> {
    let dims = input.shape().dims();
    let len = match dims {
        [.., last] => *last,
        [] => {
            return Err(TensorError::invalid_shape_simple(
                "Empty tensor shape".to_string(),
            ))
        }
    };
    let widened = convert_complex_f16_to_f32_tensor(input)?;
    let plan = Plan::dft_1d(len, Direction::Backward, Flags::ESTIMATE).ok_or_else(|| {
        TensorError::invalid_shape_simple("Failed to create IFFT plan for f16".to_string())
    })?;
    let signal = execute_optimized_ifft_1d(&widened, &plan, len)?;
    convert_complex_f32_to_f16_tensor(&signal, dims)
}
/// Forward 1-D FFT for `bf16` tensors.
///
/// The transform length is the size of the last dimension. Values are widened to `f32` for
/// the computation, and the complex spectrum is narrowed back to `bf16`.
///
/// # Errors
///
/// Fails when the shape has no dimensions, when the FFT plan cannot be created, or when the
/// transform or a tensor conversion fails.
pub fn fft_bf16(input: &Tensor<bf16>) -> Result<Tensor<Complex<bf16>>> {
    let dims = input.shape().dims();
    let len = match dims.last() {
        Some(&last) => last,
        None => {
            return Err(TensorError::invalid_shape_simple(
                "Empty tensor shape".to_string(),
            ))
        }
    };
    let widened = convert_bf16_to_f32_tensor(input)?;
    let plan = Plan::dft_1d(len, Direction::Forward, Flags::ESTIMATE).ok_or_else(|| {
        TensorError::invalid_shape_simple("Failed to create FFT plan for bf16".to_string())
    })?;
    let spectrum = execute_optimized_fft_1d(&widened, &plan, len)?;
    convert_complex_f32_to_bf16_tensor(&spectrum, dims)
}
/// Inverse 1-D FFT for complex `bf16` tensors.
///
/// The transform length is the size of the last dimension. Values are widened to `f32`, the
/// backward transform is applied and scaled by `1/n` in the helper, and the result is
/// narrowed back to complex `bf16`.
///
/// # Errors
///
/// Fails when the shape has no dimensions, when the IFFT plan cannot be created, or when the
/// transform or a tensor conversion fails.
pub fn ifft_bf16(input: &Tensor<Complex<bf16>>) -> Result<Tensor<Complex<bf16>>> {
    let dims = input.shape().dims();
    let len = match dims {
        [.., last] => *last,
        [] => {
            return Err(TensorError::invalid_shape_simple(
                "Empty tensor shape".to_string(),
            ))
        }
    };
    let widened = convert_complex_bf16_to_f32_tensor(input)?;
    let plan = Plan::dft_1d(len, Direction::Backward, Flags::ESTIMATE).ok_or_else(|| {
        TensorError::invalid_shape_simple("Failed to create IFFT plan for bf16".to_string())
    })?;
    let signal = execute_optimized_ifft_1d(&widened, &plan, len)?;
    convert_complex_f32_to_bf16_tensor(&signal, dims)
}
/// Forward 2-D FFT for `f16` tensors over the last two dimensions (`rows`, `cols`).
///
/// Values are widened to `f32`, transformed, and narrowed back to complex `f16`.
///
/// # Errors
///
/// Fails when the tensor has fewer than two dimensions, or when the transform or a tensor
/// conversion fails.
pub fn fft2_f16(input: &Tensor<f16>) -> Result<Tensor<Complex<f16>>> {
    let dims = input.shape().dims();
    let (rows, cols) = match dims {
        [.., r, c] => (*r, *c),
        _ => {
            return Err(TensorError::invalid_shape_simple(
                "2D FFT requires at least 2 dimensions".to_string(),
            ))
        }
    };
    let widened = convert_f16_to_f32_tensor(input)?;
    let spectrum = execute_optimized_fft_2d(&widened, rows, cols)?;
    convert_complex_f32_to_f16_tensor(&spectrum, dims)
}
/// Inverse 2-D FFT for complex `f16` tensors over the last two dimensions (`rows`, `cols`).
///
/// Values are widened to `f32`, inverse-transformed (normalised by `1/(rows*cols)` in the
/// helper), and narrowed back to complex `f16`.
///
/// # Errors
///
/// Fails when the tensor has fewer than two dimensions, or when the transform or a tensor
/// conversion fails.
pub fn ifft2_f16(input: &Tensor<Complex<f16>>) -> Result<Tensor<Complex<f16>>> {
    let dims = input.shape().dims();
    let (rows, cols) = match dims {
        [.., r, c] => (*r, *c),
        _ => {
            return Err(TensorError::invalid_shape_simple(
                "2D IFFT requires at least 2 dimensions".to_string(),
            ))
        }
    };
    let widened = convert_complex_f16_to_f32_tensor(input)?;
    let signal = execute_optimized_ifft_2d(&widened, rows, cols)?;
    convert_complex_f32_to_f16_tensor(&signal, dims)
}
/// Forward 2-D FFT for `bf16` tensors over the last two dimensions (`rows`, `cols`).
///
/// Values are widened to `f32`, transformed, and narrowed back to complex `bf16`.
///
/// # Errors
///
/// Fails when the tensor has fewer than two dimensions, or when the transform or a tensor
/// conversion fails.
pub fn fft2_bf16(input: &Tensor<bf16>) -> Result<Tensor<Complex<bf16>>> {
    let dims = input.shape().dims();
    let (rows, cols) = match dims {
        [.., r, c] => (*r, *c),
        _ => {
            return Err(TensorError::invalid_shape_simple(
                "2D FFT requires at least 2 dimensions".to_string(),
            ))
        }
    };
    let widened = convert_bf16_to_f32_tensor(input)?;
    let spectrum = execute_optimized_fft_2d(&widened, rows, cols)?;
    convert_complex_f32_to_bf16_tensor(&spectrum, dims)
}
/// Inverse 2-D FFT for complex `bf16` tensors over the last two dimensions (`rows`, `cols`).
///
/// Values are widened to `f32`, inverse-transformed (normalised by `1/(rows*cols)` in the
/// helper), and narrowed back to complex `bf16`.
///
/// # Errors
///
/// Fails when the tensor has fewer than two dimensions, or when the transform or a tensor
/// conversion fails.
pub fn ifft2_bf16(input: &Tensor<Complex<bf16>>) -> Result<Tensor<Complex<bf16>>> {
    let dims = input.shape().dims();
    let (rows, cols) = match dims {
        [.., r, c] => (*r, *c),
        _ => {
            return Err(TensorError::invalid_shape_simple(
                "2D IFFT requires at least 2 dimensions".to_string(),
            ))
        }
    };
    let widened = convert_complex_bf16_to_f32_tensor(input)?;
    let signal = execute_optimized_ifft_2d(&widened, rows, cols)?;
    convert_complex_f32_to_bf16_tensor(&signal, dims)
}
fn convert_f16_to_f32_tensor(input: &Tensor<f16>) -> Result<Tensor<f32>> {
let data: Vec<f32> = input
.data()
.to_vec()
.iter()
.map(|&x| f32::from(x))
.collect();
Tensor::from_data(data, input.shape().dims())
}
/// Widens every `bf16` element to `f32`, keeping the input's shape.
fn convert_bf16_to_f32_tensor(input: &Tensor<bf16>) -> Result<Tensor<f32>> {
    let widen = |b: &bf16| f32::from(*b);
    let widened: Vec<f32> = input.data().to_vec().iter().map(widen).collect();
    Tensor::from_data(widened, input.shape().dims())
}
/// Widens each complex `f16` element (both real and imaginary parts) to complex `f32`,
/// keeping the input's shape.
fn convert_complex_f16_to_f32_tensor(input: &Tensor<Complex<f16>>) -> Result<Tensor<Complex<f32>>> {
    let widen = |z: &Complex<f16>| Complex::new(f32::from(z.re), f32::from(z.im));
    let widened: Vec<Complex<f32>> = input.data().to_vec().iter().map(widen).collect();
    Tensor::from_data(widened, input.shape().dims())
}
/// Widens each complex `bf16` element (both real and imaginary parts) to complex `f32`,
/// keeping the input's shape.
fn convert_complex_bf16_to_f32_tensor(
    input: &Tensor<Complex<bf16>>,
) -> Result<Tensor<Complex<f32>>> {
    let widen = |z: &Complex<bf16>| Complex::new(f32::from(z.re), f32::from(z.im));
    let widened: Vec<Complex<f32>> = input.data().to_vec().iter().map(widen).collect();
    Tensor::from_data(widened, input.shape().dims())
}
/// Narrows each complex `f32` element to complex `f16` via `f16::from_f32` on both parts,
/// building the result with `output_shape`.
fn convert_complex_f32_to_f16_tensor(
    input: &Tensor<Complex<f32>>,
    output_shape: &[usize],
) -> Result<Tensor<Complex<f16>>> {
    let narrow = |z: &Complex<f32>| Complex::new(f16::from_f32(z.re), f16::from_f32(z.im));
    let narrowed: Vec<Complex<f16>> = input.data().to_vec().iter().map(narrow).collect();
    Tensor::from_data(narrowed, output_shape)
}
/// Narrows each complex `f32` element to complex `bf16` via `bf16::from_f32` on both parts,
/// building the result with `output_shape`.
fn convert_complex_f32_to_bf16_tensor(
    input: &Tensor<Complex<f32>>,
    output_shape: &[usize],
) -> Result<Tensor<Complex<bf16>>> {
    let narrow = |z: &Complex<f32>| Complex::new(bf16::from_f32(z.re), bf16::from_f32(z.im));
    let narrowed: Vec<Complex<bf16>> = input.data().to_vec().iter().map(narrow).collect();
    Tensor::from_data(narrowed, output_shape)
}
fn execute_optimized_fft_1d(
input: &Tensor<f32>,
fft: &Plan<f32>,
n: usize,
) -> Result<Tensor<Complex<f32>>> {
let mut input_data: Vec<Complex<f32>> = input
.data()
.to_vec()
.iter()
.map(|&x| Complex::new(x, 0.0))
.collect();
let mut output_data = vec![Complex::new(0.0, 0.0); n];
fft.execute(
to_oxifft_complex(&input_data),
to_oxifft_complex_mut(&mut output_data),
);
Tensor::from_data(output_data, &[n])
}
fn execute_optimized_ifft_1d(
input: &Tensor<Complex<f32>>,
ifft: &Plan<f32>,
n: usize,
) -> Result<Tensor<Complex<f32>>> {
let mut input_data: Vec<Complex<f32>> = input.data().to_vec().to_vec();
let mut output_data = vec![Complex::new(0.0, 0.0); n];
ifft.execute(
to_oxifft_complex(&input_data),
to_oxifft_complex_mut(&mut output_data),
);
let n_inv = 1.0 / (n as f32);
for sample in &mut output_data {
*sample *= n_inv;
}
Tensor::from_data(output_data, &[n])
}
/// Forward 2-D FFT over the last two axes (`rows` x `cols`) of a real tensor.
///
/// Every `rows * cols` plane (one per combination of leading indices) is transformed
/// separately. Each plane gets row-wise transforms (length `cols`) followed by column-wise
/// transforms (length `rows`). The result keeps the input's shape.
///
/// # Errors
///
/// Returns an error if either 1-D plan cannot be created, if the element count is not a
/// multiple of `rows * cols`, or if the output tensor cannot be built.
fn execute_optimized_fft_2d(
    input: &Tensor<f32>,
    rows: usize,
    cols: usize,
) -> Result<Tensor<Complex<f32>>> {
    let dims = input.shape().dims();
    let mut data: Vec<Complex<f32>> = input
        .data()
        .to_vec()
        .iter()
        .map(|&x| Complex::new(x, 0.0))
        .collect();
    // Length-`cols` plan: transforms each row.
    let fft_cols = Plan::dft_1d(cols, Direction::Forward, Flags::ESTIMATE).ok_or_else(|| {
        TensorError::invalid_shape_simple("Failed to create column FFT plan".to_string())
    })?;
    // Length-`rows` plan: transforms each column.
    let fft_rows = Plan::dft_1d(rows, Direction::Forward, Flags::ESTIMATE).ok_or_else(|| {
        TensorError::invalid_shape_simple("Failed to create row FFT plan".to_string())
    })?;
    let plane = rows * cols;
    // Empty planes have nothing to transform; `chunks_exact(0)`/`step_by(0)` would panic.
    if plane == 0 {
        return Tensor::from_data(data, dims);
    }
    if data.len() % plane != 0 {
        return Err(TensorError::invalid_shape_simple(format!(
            "2D FFT plane size {} does not divide tensor size {}",
            plane,
            data.len()
        )));
    }
    let zero = Complex::new(0.0f32, 0.0);
    let mut row_output = vec![zero; cols];
    let mut col_input = vec![zero; rows];
    let mut col_output = vec![zero; rows];
    // Previously only the first plane was transformed and the result was forced to
    // `[rows, cols]`, which broke every tensor with more than two dimensions.
    for matrix in data.chunks_exact_mut(plane) {
        for row in matrix.chunks_exact_mut(cols) {
            fft_cols.execute(to_oxifft_complex(row), to_oxifft_complex_mut(&mut row_output));
            row.copy_from_slice(&row_output);
        }
        for col in 0..cols {
            // Gather the strided column into a contiguous buffer for the plan.
            for (dst, src) in col_input.iter_mut().zip(matrix.iter().skip(col).step_by(cols)) {
                *dst = *src;
            }
            fft_rows.execute(
                to_oxifft_complex(&col_input),
                to_oxifft_complex_mut(&mut col_output),
            );
            for (dst, src) in matrix.iter_mut().skip(col).step_by(cols).zip(&col_output) {
                *dst = *src;
            }
        }
    }
    Tensor::from_data(data, dims)
}
/// Inverse 2-D FFT over the last two axes (`rows` x `cols`) of a complex tensor.
///
/// Every `rows * cols` plane is transformed separately. Each plane gets column-wise inverse
/// transforms (length `rows`) followed by row-wise ones (length `cols`). Finally every
/// element is scaled by `1/(rows*cols)`. The result keeps the input's shape.
///
/// # Errors
///
/// Returns an error if either 1-D plan cannot be created, if the element count is not a
/// multiple of `rows * cols`, or if the output tensor cannot be built.
fn execute_optimized_ifft_2d(
    input: &Tensor<Complex<f32>>,
    rows: usize,
    cols: usize,
) -> Result<Tensor<Complex<f32>>> {
    let dims = input.shape().dims();
    let mut data: Vec<Complex<f32>> = input.data().to_vec();
    // Length-`cols` plan: transforms each row.
    let ifft_cols = Plan::dft_1d(cols, Direction::Backward, Flags::ESTIMATE).ok_or_else(|| {
        TensorError::invalid_shape_simple("Failed to create column IFFT plan".to_string())
    })?;
    // Length-`rows` plan: transforms each column.
    let ifft_rows = Plan::dft_1d(rows, Direction::Backward, Flags::ESTIMATE).ok_or_else(|| {
        TensorError::invalid_shape_simple("Failed to create row IFFT plan".to_string())
    })?;
    let plane = rows * cols;
    // Empty planes have nothing to transform; `chunks_exact(0)`/`step_by(0)` would panic.
    if plane == 0 {
        return Tensor::from_data(data, dims);
    }
    if data.len() % plane != 0 {
        return Err(TensorError::invalid_shape_simple(format!(
            "2D IFFT plane size {} does not divide tensor size {}",
            plane,
            data.len()
        )));
    }
    let zero = Complex::new(0.0f32, 0.0);
    let mut col_input = vec![zero; rows];
    let mut col_output = vec![zero; rows];
    let mut row_output = vec![zero; cols];
    // Previously only the first plane was transformed and the result was forced to
    // `[rows, cols]`, which broke every tensor with more than two dimensions.
    for matrix in data.chunks_exact_mut(plane) {
        for col in 0..cols {
            // Gather the strided column into a contiguous buffer for the plan.
            for (dst, src) in col_input.iter_mut().zip(matrix.iter().skip(col).step_by(cols)) {
                *dst = *src;
            }
            ifft_rows.execute(
                to_oxifft_complex(&col_input),
                to_oxifft_complex_mut(&mut col_output),
            );
            for (dst, src) in matrix.iter_mut().skip(col).step_by(cols).zip(&col_output) {
                *dst = *src;
            }
        }
        for row in matrix.chunks_exact_mut(cols) {
            ifft_cols.execute(to_oxifft_complex(row), to_oxifft_complex_mut(&mut row_output));
            row.copy_from_slice(&row_output);
        }
    }
    // Normalise so that ifft2(fft2(x)) == x.
    let scale = 1.0 / (plane as f32);
    for sample in &mut data {
        *sample *= scale;
    }
    Tensor::from_data(data, dims)
}