use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::fft;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Backward node for [`fft_differentiable`].
///
/// Forward: `y = fft(x, n)`, complex values stored as trailing `[re, im]`
/// pairs (axis -1 has length 2; the transform runs along axis -2). For an
/// unnormalized DFT `F` of length `L`, the adjoint is `F^H = L * ifft`, so the
/// gradient is `L * ifft(grad_output)`. If `n` padded or truncated the input,
/// the result is truncated / zero-padded back to the input's length along
/// axis -2 so the gradient has the input's shape.
#[derive(Debug)]
pub struct FftBackward<T: Float> {
    /// Forward input; consulted for `requires_grad` and its transform length.
    input: Tensor<T>,
    /// The `n` argument given to the forward call.
    n: Option<usize>,
}

impl<T: Float> FftBackward<T> {
    /// Creates the backward node for `fft(input, n)`.
    pub fn new(input: Tensor<T>, n: Option<usize>) -> Self {
        Self { input, n }
    }
}

impl<T: Float> GradFn<T> for FftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        /// Zero-pads or truncates axis -2 of a row-major `[..., len, 2]` buffer to
        /// `target[rank - 2]` pairs. The buffer is returned unchanged unless
        /// `target` is a complex-pair shape of the same rank that differs from
        /// `shape` only along axis -2.
        fn fit_pairs_to<U: Float>(
            data: Vec<U>,
            shape: &[usize],
            target: &[usize],
        ) -> (Vec<U>, Vec<usize>) {
            let rank = shape.len();
            let resizable = rank >= 2
                && target.len() == rank
                && shape[rank - 1] == 2
                && target[rank - 1] == 2
                && shape[..rank - 2] == target[..rank - 2]
                && shape[rank - 2] != target[rank - 2];
            if !resizable {
                return (data, shape.to_vec());
            }
            let (len, want) = (shape[rank - 2], target[rank - 2]);
            let keep = len.min(want) * 2;
            let zero = U::from(0.0).unwrap();
            let batch: usize = shape[..rank - 2].iter().product();
            let mut out = Vec::with_capacity(batch * want * 2);
            for b in 0..batch {
                let row_start = b * len * 2;
                out.extend_from_slice(&data[row_start..row_start + keep]);
                // Positions the forward zero-padded receive no gradient.
                out.resize(out.len() + (want * 2 - keep), zero);
            }
            let mut new_shape = shape.to_vec();
            new_shape[rank - 2] = want;
            (out, new_shape)
        }

        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let inv = fft::ifft(grad_output, self.n)?;
        // Transform length = grad_output's extent along axis -2.
        let fft_n = grad_output.shape()[grad_output.ndim() - 2];
        let scale = T::from(fft_n).unwrap();
        let scaled: Vec<T> = inv.data_vec()?.into_iter().map(|v| v * scale).collect();
        let (data, shape) = fit_pairs_to(scaled, inv.shape(), self.input.shape());
        let t = Tensor::from_storage(TensorStorage::cpu(data), shape, false)?;
        Ok(vec![Some(if device.is_cuda() { t.to(device)? } else { t })])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "FftBackward"
    }
}
/// Backward node for [`ifft_differentiable`].
///
/// Forward: `y = ifft(x, n) = F^H x / L` on trailing `[re, im]` pairs along
/// axis -2. Its adjoint is `F g / L = fft(grad_output) / L`. If `n` padded or
/// truncated the input, the result is truncated / zero-padded back to the
/// input's length along axis -2 so the gradient has the input's shape.
#[derive(Debug)]
pub struct IfftBackward<T: Float> {
    /// Forward input; consulted for `requires_grad` and its transform length.
    input: Tensor<T>,
    /// The `n` argument given to the forward call.
    n: Option<usize>,
}

impl<T: Float> IfftBackward<T> {
    /// Creates the backward node for `ifft(input, n)`.
    pub fn new(input: Tensor<T>, n: Option<usize>) -> Self {
        Self { input, n }
    }
}

impl<T: Float> GradFn<T> for IfftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        /// Zero-pads or truncates axis -2 of a row-major `[..., len, 2]` buffer to
        /// `target[rank - 2]` pairs. The buffer is returned unchanged unless
        /// `target` is a complex-pair shape of the same rank that differs from
        /// `shape` only along axis -2.
        fn fit_pairs_to<U: Float>(
            data: Vec<U>,
            shape: &[usize],
            target: &[usize],
        ) -> (Vec<U>, Vec<usize>) {
            let rank = shape.len();
            let resizable = rank >= 2
                && target.len() == rank
                && shape[rank - 1] == 2
                && target[rank - 1] == 2
                && shape[..rank - 2] == target[..rank - 2]
                && shape[rank - 2] != target[rank - 2];
            if !resizable {
                return (data, shape.to_vec());
            }
            let (len, want) = (shape[rank - 2], target[rank - 2]);
            let keep = len.min(want) * 2;
            let zero = U::from(0.0).unwrap();
            let batch: usize = shape[..rank - 2].iter().product();
            let mut out = Vec::with_capacity(batch * want * 2);
            for b in 0..batch {
                let row_start = b * len * 2;
                out.extend_from_slice(&data[row_start..row_start + keep]);
                // Positions the forward zero-padded receive no gradient.
                out.resize(out.len() + (want * 2 - keep), zero);
            }
            let mut new_shape = shape.to_vec();
            new_shape[rank - 2] = want;
            (out, new_shape)
        }

        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let fwd = fft::fft(grad_output, self.n)?;
        // Transform length = grad_output's extent along axis -2.
        let fft_n = grad_output.shape()[grad_output.ndim() - 2];
        let scale = T::from(1.0).unwrap() / T::from(fft_n).unwrap();
        let scaled: Vec<T> = fwd.data_vec()?.into_iter().map(|v| v * scale).collect();
        let (data, shape) = fit_pairs_to(scaled, fwd.shape(), self.input.shape());
        let t = Tensor::from_storage(TensorStorage::cpu(data), shape, false)?;
        Ok(vec![Some(if device.is_cuda() { t.to(device)? } else { t })])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IfftBackward"
    }
}
/// Backward node for [`rfft_differentiable`].
///
/// Forward: `y = rfft(x, n)`. A real signal of length `m` along the last axis
/// is padded or truncated to `fft_n` points, and its first `fft_n / 2 + 1` DFT
/// bins are kept as trailing `[re, im]` pairs. The gradient is computed in
/// three steps:
/// 1. Zero-pad the half-spectrum gradient back to `fft_n` bins.
/// 2. Apply `F^H = fft_n * ifft` and keep the real part.
/// 3. Truncate or zero-pad to the input length `m`, so the gradient matches
///    the input's shape.
#[derive(Debug)]
pub struct RfftBackward<T: Float> {
    /// Forward input; consulted for `requires_grad` and its last-axis length.
    input: Tensor<T>,
    /// `n` exactly as given to the forward call (informational only).
    _n: Option<usize>,
    /// Transform length the forward used (`n`, or the input length).
    fft_n: usize,
}

impl<T: Float> RfftBackward<T> {
    /// Creates the backward node for `rfft(input, n)` with transform length `fft_n`.
    pub fn new(input: Tensor<T>, n: Option<usize>, fft_n: usize) -> Self {
        Self {
            input,
            _n: n,
            fft_n,
        }
    }
}

impl<T: Float> GradFn<T> for RfftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let go_shape = grad_output.shape();
        if go_shape.len() < 2 || go_shape[go_shape.len() - 1] != 2 {
            return Err(crate::error::FerrotorchError::InvalidArgument {
                message: format!(
                    "RfftBackward: grad_output must have trailing complex pair, got {go_shape:?}"
                ),
            });
        }
        let k = go_shape[go_shape.len() - 2];
        let n = self.fft_n;
        // Real input length along the transformed axis; differs from `n` when
        // the forward padded or truncated.
        let m = self.input.shape().last().copied().unwrap_or(n);
        let batch_shape = &go_shape[..go_shape.len() - 2];
        // No `.max(1)`: an empty batch dimension must yield zero rows rather
        // than one row read out of an empty buffer.
        let batch_size: usize = batch_shape.iter().product();
        let zero = T::from(0.0).unwrap();

        // Zero-pad every row of the half spectrum from k to n bins.
        let go_data = grad_output.data_vec()?;
        let copied = k.min(n) * 2;
        let mut padded = vec![zero; batch_size * n * 2];
        for b in 0..batch_size {
            let (src, dst) = (b * k * 2, b * n * 2);
            padded[dst..dst + copied].copy_from_slice(&go_data[src..src + copied]);
        }
        let mut padded_shape = batch_shape.to_vec();
        padded_shape.extend_from_slice(&[n, 2]);
        let padded_t = Tensor::from_storage(TensorStorage::cpu(padded), padded_shape, false)?;

        // F^H = n * ifft; keep the real part, then fit each row to length m.
        let inv_data = fft::ifft(&padded_t, Some(n))?.data_vec()?;
        let scale = T::from(n).unwrap();
        let kept = n.min(m);
        let mut grad_x_data = Vec::with_capacity(batch_size * m);
        for b in 0..batch_size {
            let row = &inv_data[b * n * 2..(b + 1) * n * 2];
            grad_x_data.extend(row.chunks_exact(2).take(kept).map(|pair| pair[0] * scale));
            // Samples the forward truncated away receive zero gradient.
            grad_x_data.resize(grad_x_data.len() + (m - kept), zero);
        }
        let mut out_shape = batch_shape.to_vec();
        out_shape.push(m);
        let t = Tensor::from_storage(TensorStorage::cpu(grad_x_data), out_shape, false)?;
        Ok(vec![Some(if device.is_cuda() { t.to(device)? } else { t })])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "RfftBackward"
    }
}
/// Backward node for [`irfft_differentiable`].
///
/// Forward: `y = irfft(x, n)` maps a half spectrum (trailing `[re, im]` pairs)
/// to `output_n` real samples, reading only the first `output_n / 2 + 1` bins.
/// The gradient for each bin is `c_k / n * rfft(grad_output, n)`:
/// - `c_k = 1` for the DC bin, and for the last bin when `n` is even.
/// - `c_k = 2` for every other bin.
///
/// Input bins beyond those the forward read get a zero gradient. The result is
/// laid out with the input's half-spectrum length.
#[derive(Debug)]
pub struct IrfftBackward<T: Float> {
    /// Forward input; consulted for `requires_grad` and its half-spectrum length.
    input: Tensor<T>,
    /// `n` exactly as given to the forward call (informational only).
    _n: Option<usize>,
    /// Real output length produced by the forward.
    output_n: usize,
}

impl<T: Float> IrfftBackward<T> {
    /// Creates the backward node for `irfft(input, n)` producing `output_n` samples.
    pub fn new(input: Tensor<T>, n: Option<usize>, output_n: usize) -> Self {
        Self {
            input,
            _n: n,
            output_n,
        }
    }
}

impl<T: Float> GradFn<T> for IrfftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let n = self.output_n;
        // Number of bins in rfft(grad_output, n) == bins the forward consumed.
        let k = n / 2 + 1;
        let f = fft::rfft(grad_output, Some(n))?;
        let mut out_shape = f.shape().to_vec();
        let f_data = f.data_vec()?;
        let batch_size = f_data.len() / (2 * k);
        // Half-spectrum length of the forward input (falls back to k if the
        // input is not laid out as `[..., len, 2]`).
        let k_in = match self.input.shape() {
            [.., len, 2] => *len,
            _ => k,
        };
        let inv_n = T::from(1.0).unwrap() / T::from(n).unwrap();
        let two = T::from(2.0).unwrap();
        let zero = T::from(0.0).unwrap();
        let mut out = Vec::with_capacity(batch_size * k_in * 2);
        for b in 0..batch_size {
            for kk in 0..k_in {
                if kk >= k {
                    // The forward never read this bin.
                    out.extend_from_slice(&[zero, zero]);
                    continue;
                }
                let interior = kk > 0 && (kk < k - 1 || n % 2 == 1);
                let factor = if interior { two * inv_n } else { inv_n };
                let at = (b * k + kk) * 2;
                out.push(f_data[at] * factor);
                out.push(f_data[at + 1] * factor);
            }
        }
        let rank = out_shape.len();
        out_shape[rank - 2] = k_in;
        let t = Tensor::from_storage(TensorStorage::cpu(out), out_shape, false)?;
        Ok(vec![Some(if device.is_cuda() { t.to(device)? } else { t })])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IrfftBackward"
    }
}
/// Differentiable wrapper around [`fft::fft`].
///
/// Always runs the forward transform. When autograd is enabled and `input`
/// requires grad, the result is re-wrapped on `input`'s device with an
/// [`FftBackward`] node; otherwise the plain forward result is returned.
pub fn fft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let result = fft::fft(input, n)?;
    if !is_grad_enabled() || !input.requires_grad() {
        return Ok(result);
    }
    let node = Arc::new(FftBackward::new(input.clone(), n));
    let shape = result.shape().to_vec();
    let storage = TensorStorage::on_device(result.data_vec()?, device)?;
    Tensor::from_operation(storage, shape, node)
}
/// Differentiable wrapper around [`fft::ifft`].
///
/// Always runs the forward inverse transform. When autograd is enabled and
/// `input` requires grad, the result is re-wrapped on `input`'s device with an
/// [`IfftBackward`] node; otherwise the plain forward result is returned.
pub fn ifft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let result = fft::ifft(input, n)?;
    if !is_grad_enabled() || !input.requires_grad() {
        return Ok(result);
    }
    let node = Arc::new(IfftBackward::new(input.clone(), n));
    let shape = result.shape().to_vec();
    let storage = TensorStorage::on_device(result.data_vec()?, device)?;
    Tensor::from_operation(storage, shape, node)
}
pub fn rfft_differentiable<T: Float>(
input: &Tensor<T>,
n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
let device = input.device();
let input_n = *input.shape().last().unwrap();
let fft_n = n.unwrap_or(input_n);
let result = fft::rfft(input, n)?;
if is_grad_enabled() && input.requires_grad() {
let grad_fn = Arc::new(RfftBackward::new(input.clone(), n, fft_n));
let storage = TensorStorage::on_device(result.data_vec()?, device)?;
Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
} else {
Ok(result)
}
}
pub fn irfft_differentiable<T: Float>(
input: &Tensor<T>,
n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
let device = input.device();
let shape = input.shape();
let half_n = shape[shape.len() - 2];
let output_n = n.unwrap_or(2 * (half_n - 1));
let result = fft::irfft(input, n)?;
if is_grad_enabled() && input.requires_grad() {
let grad_fn = Arc::new(IrfftBackward::new(input.clone(), n, output_n));
let storage = TensorStorage::on_device(result.data_vec()?, device)?;
Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
} else {
Ok(result)
}
}
/// Backward node for [`fftn_differentiable`].
///
/// The adjoint of the unnormalized N-d DFT over the transformed axes is
/// `norm_n * ifftn(grad_output, s, axes)`, where `norm_n` is the product of
/// the transformed extents.
///
/// NOTE(review): when `s` differs from the input's extents, the gradient keeps
/// the `s`-sized layout rather than the input's shape — confirm this is the
/// intended contract.
#[derive(Debug)]
pub struct FftnBackward<T: Float> {
    /// Forward input; consulted for `requires_grad`.
    input: Tensor<T>,
    /// Per-axis transform sizes as given to the forward.
    s: Option<Vec<usize>>,
    /// Transformed axes as given to the forward.
    axes: Option<Vec<isize>>,
    /// Scale applied to the inverse transform.
    norm_n: usize,
}

impl<T: Float> FftnBackward<T> {
    /// Creates the backward node for `fftn(input, s, axes)`.
    pub fn new(
        input: Tensor<T>,
        s: Option<Vec<usize>>,
        axes: Option<Vec<isize>>,
        norm_n: usize,
    ) -> Self {
        Self {
            input,
            s,
            axes,
            norm_n,
        }
    }
}

impl<T: Float> GradFn<T> for FftnBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let inverse = fft::ifftn(grad_output, self.s.as_deref(), self.axes.as_deref())?;
        let factor = T::from(self.norm_n).unwrap();
        let values: Vec<T> = inverse.data_vec()?.into_iter().map(|x| x * factor).collect();
        let shape = inverse.shape().to_vec();
        let grad = Tensor::from_storage(TensorStorage::cpu(values), shape, false)?;
        let grad = if device.is_cuda() { grad.to(device)? } else { grad };
        Ok(vec![Some(grad)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "FftnBackward"
    }
}
/// Backward node for [`ifftn_differentiable`].
///
/// The forward is `F^H x / norm_n`, so its adjoint is
/// `fftn(grad_output, s, axes) / norm_n`.
///
/// NOTE(review): when `s` differs from the input's extents, the gradient keeps
/// the `s`-sized layout rather than the input's shape — confirm this is the
/// intended contract.
#[derive(Debug)]
pub struct IfftnBackward<T: Float> {
    /// Forward input; consulted for `requires_grad`.
    input: Tensor<T>,
    /// Per-axis transform sizes as given to the forward.
    s: Option<Vec<usize>>,
    /// Transformed axes as given to the forward.
    axes: Option<Vec<isize>>,
    /// Divisor applied to the forward transform.
    norm_n: usize,
}

impl<T: Float> IfftnBackward<T> {
    /// Creates the backward node for `ifftn(input, s, axes)`.
    pub fn new(
        input: Tensor<T>,
        s: Option<Vec<usize>>,
        axes: Option<Vec<isize>>,
        norm_n: usize,
    ) -> Self {
        Self {
            input,
            s,
            axes,
            norm_n,
        }
    }
}

impl<T: Float> GradFn<T> for IfftnBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let forward = fft::fftn(grad_output, self.s.as_deref(), self.axes.as_deref())?;
        let factor = T::from(1.0).unwrap() / T::from(self.norm_n).unwrap();
        let values: Vec<T> = forward.data_vec()?.into_iter().map(|x| x * factor).collect();
        let shape = forward.shape().to_vec();
        let grad = Tensor::from_storage(TensorStorage::cpu(values), shape, false)?;
        let grad = if device.is_cuda() { grad.to(device)? } else { grad };
        Ok(vec![Some(grad)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IfftnBackward"
    }
}
/// Backward node for [`rfftn_differentiable`].
///
/// Computes the gradient in three steps:
/// 1. Zero-pad the half-spectrum gradient along the halved axis back to
///    `last_axis_n` bins.
/// 2. Apply `norm_n * ifftn(.., s, axes)`.
/// 3. Keep the real part.
///
/// `out_shape` is the forward output shape (with trailing `[re, im]` pair) and
/// is used to validate `grad_output`.
#[derive(Debug)]
pub struct RfftnBackward<T: Float> {
    /// Forward input; consulted for `requires_grad`.
    input: Tensor<T>,
    /// Per-axis transform sizes used by the forward.
    s: Option<Vec<usize>>,
    /// Transformed axes as given to the forward.
    axes: Option<Vec<isize>>,
    /// Shape of the forward output, including the trailing pair axis.
    out_shape: Vec<usize>,
    /// Full (un-halved) length of the last transformed axis.
    last_axis_n: usize,
    /// Index of the last transformed axis in the logical (pair) shape.
    last_axis_logical: usize,
    /// Scale applied to the inverse transform.
    norm_n: usize,
}

impl<T: Float> RfftnBackward<T> {
    /// Creates the backward node for `rfftn(input, s, axes)`.
    pub fn new(
        input: Tensor<T>,
        s: Option<Vec<usize>>,
        axes: Option<Vec<isize>>,
        out_shape: Vec<usize>,
        last_axis_n: usize,
        last_axis_logical: usize,
        norm_n: usize,
    ) -> Self {
        Self {
            input,
            s,
            axes,
            out_shape,
            last_axis_n,
            last_axis_logical,
            norm_n,
        }
    }
}

impl<T: Float> GradFn<T> for RfftnBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let go_shape = grad_output.shape();
        if go_shape != self.out_shape.as_slice() {
            return Err(crate::error::FerrotorchError::InvalidArgument {
                message: format!(
                    "RfftnBackward: grad_output shape {go_shape:?} does not match \
                     forward output {:?}",
                    self.out_shape
                ),
            });
        }
        let go_data = grad_output.data_vec()?;
        let mut padded_shape = self.out_shape.clone();
        padded_shape[self.last_axis_logical] = self.last_axis_n;
        let padded_total: usize = padded_shape.iter().product();
        let mut padded = vec![T::from(0.0).unwrap(); padded_total];

        // Strides of the logical (pair-indexed) grad layout and of the padded
        // element layout; both are loop-invariant, so compute them once.
        let logical_rank = go_shape.len() - 1;
        let logical_strides = row_major_strides(&go_shape[..logical_rank]);
        let pad_strides = row_major_strides(&padded_shape);

        // Pair `flat` of the dense grad buffer sits at element offset 2*flat;
        // scatter it to the same logical coordinate in the padded buffer.
        for (flat, pair) in go_data.chunks_exact(2).enumerate() {
            let mut rem = flat;
            let mut dst = 0usize;
            for (&ls, &ps) in logical_strides.iter().zip(&pad_strides) {
                dst += (rem / ls) * ps;
                rem %= ls;
            }
            padded[dst] = pair[0];
            padded[dst + 1] = pair[1];
        }

        let padded_t = Tensor::from_storage(TensorStorage::cpu(padded), padded_shape, false)?;
        let inv = fft::ifftn(&padded_t, self.s.as_deref(), self.axes.as_deref())?;
        let scale = T::from(self.norm_n).unwrap();
        // Real part of each complex pair, scaled.
        let grad_x_data: Vec<T> = inv
            .data_vec()?
            .chunks_exact(2)
            .map(|pair| pair[0] * scale)
            .collect();
        let mut out_shape = inv.shape().to_vec();
        out_shape.pop();
        let t = Tensor::from_storage(TensorStorage::cpu(grad_x_data), out_shape, false)?;
        Ok(vec![Some(if device.is_cuda() { t.to(device)? } else { t })])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "RfftnBackward"
    }
}
/// Element strides of a dense row-major (C-order) layout of `shape`.
///
/// The last axis has stride 1, and each earlier axis's stride is the product
/// of all later extents, e.g. `[2, 3, 4]` → `[12, 4, 1]`. An empty shape
/// yields an empty vector.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut out = vec![1usize; shape.len()];
    for axis in (1..shape.len()).rev() {
        out[axis - 1] = out[axis] * shape[axis];
    }
    out
}
/// Backward node for [`irfftn_differentiable`].
///
/// The gradient is `rfftn(grad_output, s, axes)` scaled per bin along the
/// halved axis:
/// - `1 / norm_n` for bin 0, and for the last bin when `last_axis_n` is even.
/// - `2 / norm_n` for every other bin.
#[derive(Debug)]
pub struct IrfftnBackward<T: Float> {
    /// Forward input; consulted for `requires_grad`.
    input: Tensor<T>,
    /// Per-axis real output sizes used by the forward.
    s: Option<Vec<usize>>,
    /// Transformed axes as given to the forward.
    axes: Option<Vec<isize>>,
    /// Real length of the last transformed axis.
    last_axis_n: usize,
    /// Index of the last transformed axis in the logical (pair) shape.
    last_axis_logical: usize,
    /// Normalization divisor.
    norm_n: usize,
}

impl<T: Float> IrfftnBackward<T> {
    /// Creates the backward node for `irfftn(input, s, axes)`.
    pub fn new(
        input: Tensor<T>,
        s: Option<Vec<usize>>,
        axes: Option<Vec<isize>>,
        last_axis_n: usize,
        last_axis_logical: usize,
        norm_n: usize,
    ) -> Self {
        Self {
            input,
            s,
            axes,
            last_axis_n,
            last_axis_logical,
            norm_n,
        }
    }
}

impl<T: Float> GradFn<T> for IrfftnBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let spectrum = fft::rfftn(grad_output, self.s.as_deref(), self.axes.as_deref())?;
        let spec_shape = spectrum.shape().to_vec();
        let spec = spectrum.data_vec()?;
        let inv_norm = T::from(1.0).unwrap() / T::from(self.norm_n).unwrap();
        let doubled = T::from(2.0).unwrap() * inv_norm;
        let bins = spec_shape[self.last_axis_logical];
        let even_len = self.last_axis_n % 2 == 0;
        let pair_dims = &spec_shape[..spec_shape.len() - 1];
        let pair_strides = row_major_strides(pair_dims);
        let pair_count: usize = pair_dims.iter().product();
        let mut out = vec![T::from(0.0).unwrap(); spec.len()];
        for p in 0..pair_count {
            // Coordinate of pair `p` along the halved axis.
            let bin = (p / pair_strides[self.last_axis_logical]) % bins;
            let edge = bin == 0 || (even_len && bin == bins - 1);
            let weight = if edge { inv_norm } else { doubled };
            out[2 * p] = spec[2 * p] * weight;
            out[2 * p + 1] = spec[2 * p + 1] * weight;
        }
        let grad = Tensor::from_storage(TensorStorage::cpu(out), spec_shape, false)?;
        let grad = if device.is_cuda() { grad.to(device)? } else { grad };
        Ok(vec![Some(grad)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IrfftnBackward"
    }
}
/// Backward node for [`hfft_differentiable`].
///
/// The forward maps `input_n` complex bins (trailing `[re, im]` pairs) to
/// `output_n` real samples. The gradient for each bin is
/// `conj(c_k * rfft(grad_output, output_n))`:
/// - `c_k = 1` for bin 0, and for the last bin when `output_n` is even.
/// - `c_k = 2` for every other bin.
///
/// `rfft` yields `output_n / 2 + 1` bins, which can differ from `input_n` when
/// an explicit `n` was passed. Extra input bins get zero gradient, and the
/// result always has `input_n` bins.
#[derive(Debug)]
pub struct HfftBackward<T: Float> {
    /// Forward input; consulted for `requires_grad`.
    input: Tensor<T>,
    /// Number of complex bins in the forward input (its axis -2 extent).
    input_n: usize,
    /// Real output length of the forward.
    output_n: usize,
}

impl<T: Float> HfftBackward<T> {
    /// Creates the backward node for an `hfft` from `input_n` bins to `output_n` samples.
    pub fn new(input: Tensor<T>, input_n: usize, output_n: usize) -> Self {
        Self {
            input,
            input_n,
            output_n,
        }
    }
}

impl<T: Float> GradFn<T> for HfftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let n = self.output_n;
        // Row length of rfft(grad_output, n) — NOT the input's bin count.
        let k = n / 2 + 1;
        let k_in = self.input_n;
        let f = fft::rfft(grad_output, Some(n))?;
        let mut out_shape = f.shape().to_vec();
        let f_data = f.data_vec()?;
        let batch_size = f_data.len() / (2 * k);
        let one = T::from(1.0).unwrap();
        let two = T::from(2.0).unwrap();
        let zero = T::from(0.0).unwrap();
        let mut out = Vec::with_capacity(batch_size * k_in * 2);
        for b in 0..batch_size {
            for kk in 0..k_in {
                if kk >= k {
                    // The forward never read this bin.
                    out.extend_from_slice(&[zero, zero]);
                    continue;
                }
                let is_boundary = kk == 0 || (n % 2 == 0 && kk == k - 1);
                let factor = if is_boundary { one } else { two };
                let at = (b * k + kk) * 2;
                // Conjugate: hfft conjugates its input before the real inverse transform.
                out.push(f_data[at] * factor);
                out.push(-f_data[at + 1] * factor);
            }
        }
        let rank = out_shape.len();
        out_shape[rank - 2] = k_in;
        let t = Tensor::from_storage(TensorStorage::cpu(out), out_shape, false)?;
        Ok(vec![Some(if device.is_cuda() { t.to(device)? } else { t })])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "HfftBackward"
    }
}
/// Backward node for [`ihfft_differentiable`].
///
/// The forward is `conj(rfft(x, n)) / n` on a real signal. The gradient is
/// computed in three steps:
/// 1. Conjugate the half-spectrum gradient and zero-pad it to `n` bins.
/// 2. Apply `ifft` (whose `1/n` absorbs the forward's `1/n` against
///    `F^H = n * ifft`) and keep the real part.
/// 3. Truncate or zero-pad to the real input's length, so the gradient matches
///    the input's shape.
#[derive(Debug)]
pub struct IhfftBackward<T: Float> {
    /// Forward input; consulted for `requires_grad` and its last-axis length.
    input: Tensor<T>,
    /// Transform length used by the forward (`n`, or the input length).
    input_n: usize,
}

impl<T: Float> IhfftBackward<T> {
    /// Creates the backward node for an `ihfft` of transform length `input_n`.
    pub fn new(input: Tensor<T>, input_n: usize) -> Self {
        Self { input, input_n }
    }
}

impl<T: Float> GradFn<T> for IhfftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let device = grad_output.device();
        let n = self.input_n;
        let go_shape = grad_output.shape();
        if go_shape.len() < 2 || go_shape[go_shape.len() - 1] != 2 {
            return Err(crate::error::FerrotorchError::InvalidArgument {
                message: format!(
                    "IhfftBackward: grad_output must have trailing complex pair, got {go_shape:?}"
                ),
            });
        }
        let k = go_shape[go_shape.len() - 2];
        // Real input length along the transformed axis; differs from `n` when
        // the forward padded or truncated.
        let m = self.input.shape().last().copied().unwrap_or(n);
        let batch_shape = &go_shape[..go_shape.len() - 2];
        // No `.max(1)`: an empty batch dimension must yield zero rows rather
        // than one row read out of an empty buffer.
        let batch_size: usize = batch_shape.iter().product();
        let zero = T::from(0.0).unwrap();

        // Conjugate and zero-pad each row of the half spectrum to n bins.
        let go_data = grad_output.data_vec()?;
        let copy_pairs = k.min(n);
        let mut padded = vec![zero; batch_size * n * 2];
        for b in 0..batch_size {
            let (src, dst) = (b * k * 2, b * n * 2);
            for kk in 0..copy_pairs {
                padded[dst + kk * 2] = go_data[src + kk * 2];
                padded[dst + kk * 2 + 1] = -go_data[src + kk * 2 + 1];
            }
        }
        let mut padded_shape = batch_shape.to_vec();
        padded_shape.extend_from_slice(&[n, 2]);
        let padded_t = Tensor::from_storage(TensorStorage::cpu(padded), padded_shape, false)?;

        // Real part of ifft, fitted to the input length m.
        let inv_data = fft::ifft(&padded_t, Some(n))?.data_vec()?;
        let kept = n.min(m);
        let mut grad_x_data = Vec::with_capacity(batch_size * m);
        for b in 0..batch_size {
            let row = &inv_data[b * n * 2..(b + 1) * n * 2];
            grad_x_data.extend(row.chunks_exact(2).take(kept).map(|pair| pair[0]));
            // Samples the forward truncated away receive zero gradient.
            grad_x_data.resize(grad_x_data.len() + (m - kept), zero);
        }
        let mut out_shape = batch_shape.to_vec();
        out_shape.push(m);
        let t = Tensor::from_storage(TensorStorage::cpu(grad_x_data), out_shape, false)?;
        Ok(vec![Some(if device.is_cuda() { t.to(device)? } else { t })])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IhfftBackward"
    }
}
/// Normalization count for a complex N-d transform, never less than 1.
///
/// The count is chosen as follows:
/// - With `s`: the product of `s`.
/// - Otherwise, with `axes`: the product of the input extents along those
///   axes. Negative axes count from the end of the logical shape, which
///   excludes the trailing `[re, im]` axis.
/// - Otherwise: the product of every dimension except the trailing pair axis,
///   or 1 for inputs of rank < 2.
fn fftn_norm_n<T: Float>(input: &Tensor<T>, s: Option<&[usize]>, axes: Option<&[isize]>) -> usize {
    if let Some(sizes) = s {
        return sizes.iter().product::<usize>().max(1);
    }
    let shape = input.shape();
    let ndim = shape.len();
    match axes {
        Some(list) => {
            let logical_ndim = ndim.saturating_sub(1) as isize;
            list.iter()
                .fold(1usize, |acc, &a| {
                    let axis = if a < 0 {
                        (logical_ndim + a) as usize
                    } else {
                        a as usize
                    };
                    acc.saturating_mul(shape[axis])
                })
                .max(1)
        }
        None if ndim < 2 => 1,
        None => shape[..ndim - 1].iter().product::<usize>().max(1),
    }
}
/// Normalization count for a real-input N-d transform, never less than 1.
///
/// The count is chosen as follows:
/// - With `s`: the product of `s`.
/// - Otherwise, with `axes`: the product of the input extents along those
///   axes. Negative axes count from the end of the real input's full shape.
/// - Otherwise: the product of the whole shape.
fn rfftn_norm_n<T: Float>(input: &Tensor<T>, s: Option<&[usize]>, axes: Option<&[isize]>) -> usize {
    if let Some(sizes) = s {
        return sizes.iter().product::<usize>().max(1);
    }
    let shape = input.shape();
    let rank = shape.len() as isize;
    let total = match axes {
        Some(list) => list.iter().fold(1usize, |acc, &a| {
            let axis = if a < 0 {
                (rank + a) as usize
            } else {
                a as usize
            };
            acc.saturating_mul(shape[axis])
        }),
        None => shape.iter().product::<usize>(),
    };
    total.max(1)
}
/// Differentiable wrapper around [`fft::fftn`].
///
/// Always runs the forward transform. When autograd is enabled and `input`
/// requires grad, the result is re-wrapped on `input`'s device with an
/// [`FftnBackward`] node. That node carries `s`, `axes`, and the normalization
/// count from `fftn_norm_n`.
pub fn fftn_differentiable<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let result = fft::fftn(input, s, axes)?;
    if !is_grad_enabled() || !input.requires_grad() {
        return Ok(result);
    }
    let norm_n = fftn_norm_n(input, s, axes);
    let node = Arc::new(FftnBackward::new(
        input.clone(),
        s.map(<[usize]>::to_vec),
        axes.map(<[isize]>::to_vec),
        norm_n,
    ));
    let shape = result.shape().to_vec();
    let storage = TensorStorage::on_device(result.data_vec()?, device)?;
    Tensor::from_operation(storage, shape, node)
}
/// Differentiable wrapper around [`fft::ifftn`].
///
/// Always runs the forward inverse transform. When autograd is enabled and
/// `input` requires grad, the result is re-wrapped on `input`'s device with an
/// [`IfftnBackward`] node. That node carries `s`, `axes`, and the
/// normalization count from `fftn_norm_n`.
pub fn ifftn_differentiable<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let result = fft::ifftn(input, s, axes)?;
    if !is_grad_enabled() || !input.requires_grad() {
        return Ok(result);
    }
    let norm_n = fftn_norm_n(input, s, axes);
    let node = Arc::new(IfftnBackward::new(
        input.clone(),
        s.map(<[usize]>::to_vec),
        axes.map(<[isize]>::to_vec),
        norm_n,
    ));
    let shape = result.shape().to_vec();
    let storage = TensorStorage::on_device(result.data_vec()?, device)?;
    Tensor::from_operation(storage, shape, node)
}
pub fn rfftn_differentiable<T: Float>(
input: &Tensor<T>,
s: Option<&[usize]>,
axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
let device = input.device();
let _ = rfftn_norm_n::<T>; let result = fft::rfftn(input, s, axes)?;
if is_grad_enabled() && input.requires_grad() {
let s_back: Vec<usize> = match (s, axes) {
(Some(s_slice), _) => s_slice.to_vec(),
(None, Some(axes_slice)) => {
let shape = input.shape();
axes_slice
.iter()
.map(|&a| {
let resolved = if a < 0 {
(shape.len() as isize + a) as usize
} else {
a as usize
};
shape[resolved]
})
.collect()
}
(None, None) => input.shape().to_vec(),
};
let in_shape = input.shape();
let last_axis_logical = match axes {
Some(axes_slice) => {
let a = *axes_slice.last().unwrap();
if a < 0 {
(in_shape.len() as isize + a) as usize
} else {
a as usize
}
}
None => in_shape.len() - 1,
};
let last_axis_n = s_back
.last()
.copied()
.unwrap_or(in_shape[last_axis_logical]);
let norm_n: usize = s_back.iter().product::<usize>().max(1);
let out_shape = result.shape().to_vec();
let grad_fn = Arc::new(RfftnBackward::new(
input.clone(),
Some(s_back),
axes.map(|v| v.to_vec()),
out_shape,
last_axis_n,
last_axis_logical,
norm_n,
));
let storage = TensorStorage::on_device(result.data_vec()?, device)?;
Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
} else {
Ok(result)
}
}
pub fn irfftn_differentiable<T: Float>(
input: &Tensor<T>,
s: Option<&[usize]>,
axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
let device = input.device();
let result = fft::irfftn(input, s, axes)?;
if is_grad_enabled() && input.requires_grad() {
let s_back: Vec<usize> = match s {
Some(s_slice) => s_slice.to_vec(),
None => result.shape().to_vec(),
};
let res_shape = result.shape();
let last_axis_logical_real = match axes {
Some(axes_slice) => {
let a = *axes_slice.last().unwrap();
if a < 0 {
(res_shape.len() as isize + a) as usize
} else {
a as usize
}
}
None => res_shape.len() - 1,
};
let last_axis_logical = last_axis_logical_real;
let last_axis_n = *s_back.last().unwrap_or(&res_shape[last_axis_logical_real]);
let norm_n: usize = s_back.iter().product::<usize>().max(1);
let grad_fn = Arc::new(IrfftnBackward::new(
input.clone(),
Some(s_back),
axes.map(|v| v.to_vec()),
last_axis_n,
last_axis_logical,
norm_n,
));
let storage = TensorStorage::on_device(result.data_vec()?, device)?;
Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
} else {
Ok(result)
}
}
pub fn hfft_differentiable<T: Float>(
input: &Tensor<T>,
n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
let device = input.device();
let shape = input.shape();
let input_n = shape[shape.len() - 2];
let result = fft::hfft(input, n)?;
let output_n = *result.shape().last().unwrap();
if is_grad_enabled() && input.requires_grad() {
let grad_fn = Arc::new(HfftBackward::new(input.clone(), input_n, output_n));
let storage = TensorStorage::on_device(result.data_vec()?, device)?;
Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
} else {
Ok(result)
}
}
pub fn ihfft_differentiable<T: Float>(
input: &Tensor<T>,
n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
let device = input.device();
let shape = input.shape();
let input_n = n.unwrap_or(*shape.last().unwrap());
let result = fft::ihfft(input, n)?;
if is_grad_enabled() && input.requires_grad() {
let grad_fn = Arc::new(IhfftBackward::new(input.clone(), input_n));
let storage = TensorStorage::on_device(result.data_vec()?, device)?;
Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
} else {
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
/// Builds a CPU `f64` tensor with `requires_grad = true`.
fn leaf(data: &[f64], shape: &[usize]) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
}
/// Builds a CPU `f64` tensor with `requires_grad = false`.
fn no_grad_leaf(data: &[f64], shape: &[usize]) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
/// Asserts equal lengths and element-wise `|actual - expected| < tol`.
fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
assert_eq!(
actual.len(),
expected.len(),
"length mismatch: {} vs {}",
actual.len(),
expected.len()
);
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < tol,
"index {i}: {a} vs {e} (diff {})",
(a - e).abs()
);
}
}
// Input shape [4, 2] is four complex values stored as [re, im] pairs.
#[test]
fn fft_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let result = fft_differentiable(&input, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "FftBackward");
}
#[test]
fn fft_differentiable_no_grad_when_not_needed() {
let input = no_grad_leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let result = fft_differentiable(&input, None).unwrap();
assert!(result.grad_fn().is_none());
}
// Backward with an all-ones upstream gradient: ifft(ones) is a unit impulse,
// and FftBackward scales by the transform length 4, giving 4 at index 0.
#[test]
fn fft_backward_identity_check() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let result = fft_differentiable(&input, None).unwrap();
let grad_out = no_grad_leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[4, 2]);
let grads = result.grad_fn().unwrap().backward(&grad_out).unwrap();
assert!(grads[0].is_some());
let g = grads[0].as_ref().unwrap();
assert_eq!(g.shape(), &[4, 2]);
let gd = g.data().unwrap();
assert_close(gd, &[4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 1e-10);
}
// Backward with an impulse upstream gradient: fft(impulse) is all ones,
// and IfftBackward divides by the transform length 4, giving 0.25 everywhere.
#[test]
fn ifft_backward_identity_check() {
let input = leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[4, 2]);
let result = ifft_differentiable(&input, None).unwrap();
let grad_out = no_grad_leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let grads = result.grad_fn().unwrap().backward(&grad_out).unwrap();
assert!(grads[0].is_some());
let g = grads[0].as_ref().unwrap();
let gd = g.data().unwrap();
assert_close(gd, &[0.25, 0.0, 0.25, 0.0, 0.25, 0.0, 0.25, 0.0], 1e-10);
}
// Real input of shape [4] (no trailing pair axis).
#[test]
fn rfft_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4]);
let result = rfft_differentiable(&input, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "RfftBackward");
}
// Half spectrum of 3 bins ([3, 2]) inverted to n = 4 real samples.
#[test]
fn irfft_differentiable_attaches_grad_fn() {
let input = leaf(&[10.0, 0.0, -2.0, 2.0, -2.0, 0.0], &[3, 2]);
let result = irfft_differentiable(&input, Some(4)).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "IrfftBackward");
}
// Inside `no_grad`, `is_grad_enabled()` is false, so no node is attached.
#[test]
fn no_grad_context_disables_tracking() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let result =
crate::autograd::no_grad::no_grad(|| fft_differentiable(&input, None).unwrap());
assert!(result.grad_fn().is_none());
}
#[test]
fn fftn_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[2, 2, 2]);
let result = fftn_differentiable(&input, None, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "FftnBackward");
}
#[test]
fn ifftn_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[2, 2, 2]);
let result = ifftn_differentiable(&input, None, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "IfftnBackward");
}
#[test]
fn fftn_no_grad_when_not_needed() {
let input = no_grad_leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[2, 2, 2]);
let result = fftn_differentiable(&input, None, None).unwrap();
assert!(result.grad_fn().is_none());
}
// 2x2 complex grid: ifftn(ones) is an impulse, scaled by norm_n = 4.
#[test]
fn fftn_backward_returns_real_grad_for_impulse() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[2, 2, 2]);
let result = fftn_differentiable(&input, None, None).unwrap();
let grad_out = no_grad_leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[2, 2, 2]);
let grads = result.grad_fn().unwrap().backward(&grad_out).unwrap();
let g = grads[0].as_ref().unwrap();
assert_close(
g.data().unwrap(),
&[4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1e-9,
);
}
#[test]
fn rfftn_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let result = rfftn_differentiable(&input, None, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "RfftnBackward");
}
#[test]
fn irfftn_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[2, 2, 2]);
let result = irfftn_differentiable(&input, Some(&[2, 2]), None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "IrfftnBackward");
}
#[test]
fn hfft_differentiable_attaches_grad_fn() {
let input = leaf(&[10.0, 0.0, -2.0, 2.0, -2.0, 0.0], &[3, 2]);
let result = hfft_differentiable(&input, Some(4)).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "HfftBackward");
}
#[test]
fn ihfft_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4]);
let result = ihfft_differentiable(&input, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "IhfftBackward");
}
// Default: product of every dimension except the trailing pair axis.
#[test]
fn fftn_norm_n_default_inner_dims() {
let input = no_grad_leaf(&vec![0.0; 2 * 3 * 4 * 2], &[2, 3, 4, 2]);
let n = fftn_norm_n(&input, None, None);
assert_eq!(n, 2 * 3 * 4);
}
// Explicit `s` wins: product of s = 3 * 5.
#[test]
fn fftn_norm_n_with_explicit_s() {
let input = no_grad_leaf(&[0.0; 8 * 2], &[2, 2, 2, 2]);
let n = fftn_norm_n(&input, Some(&[3, 5]), None);
assert_eq!(n, 15);
}
// Axes [1, 2] of [2, 3, 4, 2] select extents 3 and 4.
#[test]
fn fftn_norm_n_with_axes() {
let input = no_grad_leaf(&vec![0.0; 2 * 3 * 4 * 2], &[2, 3, 4, 2]);
let n = fftn_norm_n(&input, None, Some(&[1, 2]));
assert_eq!(n, 12);
}
}