use crate::error::Error::RocFFT;
use crate::error::Result;
use crate::hip::{DeviceMemory, Stream};
use crate::rocfft::{
description::PlanDescription,
error,
execution::ExecutionInfo,
plan::{ArrayType, PlacementType, Plan, Precision, TransformType},
};
/// Derives the output lengths for a real forward transform from its input lengths.
///
/// The first entry `n` is replaced by `n / 2 + 1` (integer division); every
/// remaining entry is copied through unchanged. An empty slice produces an
/// empty vector.
pub fn get_real_forward_output_length(input_lengths: &[usize]) -> Vec<usize> {
    let mut lengths = Vec::with_capacity(input_lengths.len());
    if let Some((&first, rest)) = input_lengths.split_first() {
        // Only the leading dimension shrinks; the others keep their size.
        lengths.push(first / 2 + 1);
        lengths.extend_from_slice(rest);
    }
    lengths
}
/// Builds and runs a `TransformType::ComplexForward` plan on device memory.
///
/// # Arguments
/// * `input` - device buffer handed to the plan as its input.
/// * `output` - when `Some`, the plan is created as `NotInPlace` and this
///   buffer is passed as the output; when `None`, the plan is `InPlace` and
///   an empty output pointer list is passed.
/// * `lengths` - transform lengths; the number of entries must be 1, 2 or 3.
/// * `precision` - precision forwarded to `Plan::new`.
/// * `stream` - optional stream; when present it is attached to a fresh
///   `ExecutionInfo` used for the execution.
///
/// # Errors
/// Returns `InvalidDimensions` when `lengths` has fewer than 1 or more than 3
/// entries, and propagates any error from plan creation, execution-info
/// setup, or plan execution.
///
/// # Safety
/// NOTE(review): the precise contract that makes this function `unsafe` is
/// not visible in this file. Presumably the caller must guarantee that the
/// buffers are large enough for `lengths` and `precision` — confirm.
pub unsafe fn complex_forward_transform<T>(
    input: &DeviceMemory<T>,
    output: Option<&DeviceMemory<T>>,
    lengths: &[usize],
    precision: Precision,
    stream: Option<&Stream>,
) -> Result<()> {
    // A separate output buffer means an out-of-place transform.
    let placement = if output.is_some() {
        PlacementType::NotInPlace
    } else {
        PlacementType::InPlace
    };

    let rank = lengths.len();
    if !(1..=3).contains(&rank) {
        return Err(RocFFT(error::Error::InvalidDimensions));
    }

    // The literal `1` is presumably the number of transforms (batch) — TODO confirm.
    let mut plan = Plan::new(
        placement,
        TransformType::ComplexForward,
        precision,
        rank,
        lengths,
        1,
        None,
    )?;

    // Execution info is only needed when the caller supplies a stream.
    let mut exec_info = if let Some(s) = stream {
        let mut info = ExecutionInfo::new()?;
        unsafe { info.set_stream(s.as_raw() as *mut std::ffi::c_void) }?;
        Some(info)
    } else {
        None
    };

    let in_buffers = [input.as_ptr()];
    // Empty for in-place transforms, one pointer otherwise.
    let out_buffers: Vec<_> = output.into_iter().map(|buf| buf.as_ptr()).collect();

    plan.execute(&in_buffers, &out_buffers, exec_info.as_mut())?;
    Ok(())
}
/// Builds and runs a `TransformType::ComplexInverse` plan on device memory.
///
/// # Arguments
/// * `input` - device buffer handed to the plan as its input.
/// * `output` - when `Some`, the plan is created as `NotInPlace` and this
///   buffer is passed as the output; when `None`, the plan is `InPlace` and
///   an empty output pointer list is passed.
/// * `lengths` - transform lengths; the number of entries must be 1, 2 or 3.
/// * `precision` - precision forwarded to `Plan::new`.
/// * `scale` - when `true`, a plan description with scale factor
///   `1 / product(lengths)` is attached to the plan; when `false`, no
///   description is used.
/// * `stream` - optional stream; when present it is attached to a fresh
///   `ExecutionInfo` used for the execution.
///
/// # Errors
/// Returns `InvalidDimensions` when `lengths` has fewer than 1 or more than 3
/// entries, and propagates any error from the description, plan,
/// execution-info, or execution calls.
pub fn complex_inverse_transform<T>(
    input: &DeviceMemory<T>,
    output: Option<&DeviceMemory<T>>,
    lengths: &[usize],
    precision: Precision,
    scale: bool,
    stream: Option<&Stream>,
) -> Result<()> {
    // A separate output buffer means an out-of-place transform.
    let placement = if output.is_some() {
        PlacementType::NotInPlace
    } else {
        PlacementType::InPlace
    };

    let rank = lengths.len();
    if !(1..=3).contains(&rank) {
        return Err(RocFFT(error::Error::InvalidDimensions));
    }

    let element_count: usize = lengths.iter().product();

    // Only build a description when the caller asked for normalization.
    let description = if scale {
        let mut desc = PlanDescription::new()?;
        desc.set_scale_factor(1.0 / element_count as f64)?;
        Some(desc)
    } else {
        None
    };

    // The literal `1` is presumably the number of transforms (batch) — TODO confirm.
    let mut plan = Plan::new(
        placement,
        TransformType::ComplexInverse,
        precision,
        rank,
        lengths,
        1,
        description.as_ref(),
    )?;

    // Execution info is only needed when the caller supplies a stream.
    let mut exec_info = if let Some(s) = stream {
        unsafe {
            let mut info = ExecutionInfo::new()?;
            info.set_stream(s.as_raw() as *mut std::ffi::c_void)?;
            Some(info)
        }
    } else {
        None
    };

    let in_buffers = [input.as_ptr()];
    // Empty for in-place transforms, one pointer otherwise.
    let out_buffers: Vec<_> = output.into_iter().map(|buf| buf.as_ptr()).collect();

    plan.execute(&in_buffers, &out_buffers, exec_info.as_mut())?;
    Ok(())
}
/// Builds and runs an out-of-place `TransformType::RealForward` plan.
///
/// # Arguments
/// * `input` - device buffer handed to the plan as its input.
/// * `output` - device buffer handed to the plan as its output.
/// * `lengths` - transform lengths; the number of entries must be 1, 2 or 3.
/// * `precision` - precision forwarded to `Plan::new`.
/// * `stream` - optional stream; when present it is attached to a fresh
///   `ExecutionInfo` used for the execution.
///
/// # Errors
/// Returns `InvalidDimensions` when `lengths` has fewer than 1 or more than 3
/// entries, and propagates any error from plan creation, execution-info
/// setup, or plan execution.
pub fn real_forward_transform<T, U>(
    input: &DeviceMemory<T>,
    output: &DeviceMemory<U>,
    lengths: &[usize],
    precision: Precision,
    stream: Option<&Stream>,
) -> Result<()> {
    let rank = lengths.len();
    if !(1..=3).contains(&rank) {
        return Err(RocFFT(error::Error::InvalidDimensions));
    }

    // The literal `1` is presumably the number of transforms (batch) — TODO confirm.
    let mut plan = Plan::new(
        PlacementType::NotInPlace,
        TransformType::RealForward,
        precision,
        rank,
        lengths,
        1,
        None,
    )?;

    // Execution info is only needed when the caller supplies a stream.
    let mut exec_info = if let Some(s) = stream {
        unsafe {
            let mut info = ExecutionInfo::new()?;
            info.set_stream(s.as_raw() as *mut std::ffi::c_void)?;
            Some(info)
        }
    } else {
        None
    };

    let in_buffers = [input.as_ptr()];
    let out_buffers = [output.as_ptr()];

    plan.execute(&in_buffers, &out_buffers, exec_info.as_mut())?;
    Ok(())
}
/// Builds and runs an out-of-place `TransformType::RealInverse` plan.
///
/// # Arguments
/// * `input` - device buffer handed to the plan as its input.
/// * `output` - device buffer handed to the plan as its output.
/// * `lengths` - transform lengths; the number of entries must be 1, 2 or 3.
/// * `precision` - precision forwarded to `Plan::new`.
/// * `scale` - when `true`, a plan description with scale factor
///   `1 / product(lengths)` is attached to the plan; when `false`, no
///   description is used.
/// * `stream` - optional stream; when present it is attached to a fresh
///   `ExecutionInfo` used for the execution.
///
/// # Errors
/// Returns `InvalidDimensions` when `lengths` has fewer than 1 or more than 3
/// entries, and propagates any error from the description, plan,
/// execution-info, or execution calls.
pub fn complex_to_real_transform<T, U>(
    input: &DeviceMemory<T>,
    output: &DeviceMemory<U>,
    lengths: &[usize],
    precision: Precision,
    scale: bool,
    stream: Option<&Stream>,
) -> Result<()> {
    let rank = lengths.len();
    if !(1..=3).contains(&rank) {
        return Err(RocFFT(error::Error::InvalidDimensions));
    }

    let element_count: usize = lengths.iter().product();

    // Only build a description when the caller asked for normalization.
    let description = if scale {
        let mut desc = PlanDescription::new()?;
        desc.set_scale_factor(1.0 / element_count as f64)?;
        Some(desc)
    } else {
        None
    };

    // The literal `1` is presumably the number of transforms (batch) — TODO confirm.
    let mut plan = Plan::new(
        PlacementType::NotInPlace,
        TransformType::RealInverse,
        precision,
        rank,
        lengths,
        1,
        description.as_ref(),
    )?;

    // Execution info is only needed when the caller supplies a stream.
    let mut exec_info = if let Some(s) = stream {
        unsafe {
            let mut info = ExecutionInfo::new()?;
            info.set_stream(s.as_raw() as *mut std::ffi::c_void)?;
            Some(info)
        }
    } else {
        None
    };

    let in_buffers = [input.as_ptr()];
    let out_buffers = [output.as_ptr()];

    plan.execute(&in_buffers, &out_buffers, exec_info.as_mut())?;
    Ok(())
}
/// Creates a two-dimensional plan whose data layout uses caller-supplied strides.
///
/// The plan lengths are `[width, height]` and the strides are
/// `[in_row_stride, in_col_stride]`. The same strides and the same distance
/// (`height * in_col_stride`) are applied to both the input and the output
/// side of the layout, and both offsets are `[0]`.
///
/// The array types follow `transform_type`:
/// * `RealForward` - real input, interleaved-complex output;
/// * `RealInverse` - interleaved-complex input, real output;
/// * anything else - interleaved-complex on both sides.
///
/// # Errors
/// Propagates any error from creating the plan description, setting the data
/// layout, or creating the plan.
pub fn create_2d_fft_plan_with_strides(
    width: usize,
    height: usize,
    in_row_stride: usize,
    in_col_stride: usize,
    transform_type: TransformType,
    precision: Precision,
    placement: PlacementType,
) -> Result<Plan> {
    let mut description = PlanDescription::new()?;

    let (input_kind, output_kind) = match transform_type {
        TransformType::RealForward => (ArrayType::Real, ArrayType::ComplexInterleaved),
        TransformType::RealInverse => (ArrayType::ComplexInterleaved, ArrayType::Real),
        _ => (ArrayType::ComplexInterleaved, ArrayType::ComplexInterleaved),
    };

    // Input and output share one set of strides and one distance.
    let strides = vec![in_row_stride, in_col_stride];
    let distance = height * in_col_stride;

    description.set_data_layout(
        input_kind,
        output_kind,
        Some(&[0]),
        Some(&[0]),
        Some(&strides),
        distance,
        Some(&strides),
        distance,
    )?;

    let lengths = [width, height];
    // The literal `1` is presumably the number of transforms (batch) — TODO confirm.
    Ok(Plan::new(
        placement,
        transform_type,
        precision,
        lengths.len(),
        &lengths,
        1,
        Some(&description),
    )?)
}
pub fn fft_convolution_1d<T>(
signal: &DeviceMemory<T>,
kernel: &DeviceMemory<T>,
output: &mut DeviceMemory<T>,
precision: Precision,
stream: Option<&Stream>,
) -> Result<()>
where
T: Copy
+ Default
+ std::ops::Mul<Output = T>
+ std::ops::Neg<Output = T>
+ std::ops::Add<Output = T>,
{
let signal_size = signal.count();
let kernel_size = kernel.count();
if signal_size < kernel_size {
return Err(RocFFT(error::Error::InvalidArgValue));
}
let padded_size = signal_size + kernel_size - 1;
let lengths = vec![padded_size];
let mut padded_signal = DeviceMemory::<T>::new(padded_size * 2)?; let mut padded_kernel = DeviceMemory::<T>::new(padded_size * 2)?; let mut fft_result = DeviceMemory::<T>::new(padded_size * 2)?;
let mut forward_plan = Plan::new(
PlacementType::NotInPlace,
TransformType::ComplexForward,
precision,
1, &lengths,
1, None,
)?;
let mut desc = PlanDescription::new()?;
desc.set_scale_factor(1.0 / padded_size as f64)?;
let mut inverse_plan = Plan::new(
PlacementType::NotInPlace,
TransformType::ComplexInverse,
precision,
1, &lengths,
1, Some(&desc),
)?;
let mut exec_info = match stream {
Some(s) => unsafe {
let mut info = ExecutionInfo::new()?;
info.set_stream(s.as_raw() as *mut std::ffi::c_void)?;
Some(info)
},
None => None,
};
let mut host_padded_signal = vec![T::default(); padded_size * 2];
let mut host_padded_kernel = vec![T::default(); padded_size * 2];
let mut host_signal = vec![T::default(); signal_size];
signal.copy_to_host(&mut host_signal)?;
for i in 0..signal_size {
host_padded_signal[i * 2] = host_signal[i]; }
let mut host_kernel = vec![T::default(); kernel_size];
kernel.copy_to_host(&mut host_kernel)?;
for i in 0..kernel_size {
host_padded_kernel[i * 2] = host_kernel[i]; }
padded_signal.copy_from_host(&host_padded_signal)?;
padded_kernel.copy_from_host(&host_padded_kernel)?;
let input_ptr = [padded_signal.as_ptr()];
let kernel_ptr = [padded_kernel.as_ptr()];
let result_ptr = [fft_result.as_ptr()];
forward_plan.execute(&input_ptr, &result_ptr, exec_info.as_mut())?;
let mut host_fft_signal = vec![T::default(); padded_size * 2];
fft_result.copy_to_host(&mut host_fft_signal)?;
forward_plan.execute(&kernel_ptr, &result_ptr, exec_info.as_mut())?;
let mut host_fft_kernel = vec![T::default(); padded_size * 2];
fft_result.copy_to_host(&mut host_fft_kernel)?;
let mut host_mult_result = vec![T::default(); padded_size * 2];
for i in 0..padded_size {
let idx = i * 2;
let s_real = host_fft_signal[idx];
let s_imag = host_fft_signal[idx + 1];
let k_real = host_fft_kernel[idx];
let k_imag = host_fft_kernel[idx + 1];
host_mult_result[idx] = multiply_add(s_real, k_real, multiply_neg(s_imag, k_imag)); host_mult_result[idx + 1] = multiply_add(s_real, k_imag, multiply(s_imag, k_real)); }
fft_result.copy_from_host(&host_mult_result)?;
let ifft_result = DeviceMemory::<T>::new(padded_size * 2)?;
let ifft_ptr = [ifft_result.as_ptr()];
inverse_plan.execute(&result_ptr, &ifft_ptr, exec_info.as_mut())?;
let mut host_ifft_result = vec![T::default(); padded_size * 2];
ifft_result.copy_to_host(&mut host_ifft_result)?;
let mut host_output = vec![T::default(); signal_size];
for i in 0..signal_size {
host_output[i] = host_ifft_result[i * 2]; }
output.copy_from_host(&host_output)?;
Ok(())
}
/// Returns the product `a * b`.
fn multiply<T>(a: T, b: T) -> T
where
    T: std::ops::Mul<Output = T>,
{
    std::ops::Mul::mul(a, b)
}
/// Returns the negated product `-(a * b)`.
fn multiply_neg<T>(a: T, b: T) -> T
where
    T: std::ops::Mul<Output = T> + std::ops::Neg<Output = T>,
{
    -(a * b)
}
/// Returns `a * b + c`, computed as a multiplication followed by an addition.
fn multiply_add<T>(a: T, b: T, c: T) -> T
where
    T: std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
{
    let product = a * b;
    product + c
}