use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::fft;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Backward node for [`fft_differentiable`]: complex-to-complex forward DFT.
///
/// Complex tensors carry a trailing axis of size 2 holding the (re, im) pair,
/// so the transform axis is `ndim - 2`.
#[derive(Debug)]
pub struct FftBackward<T: Float> {
    input: Tensor<T>,
    n: Option<usize>,
}

impl<T: Float> FftBackward<T> {
    /// Records the forward input and the optional transform length `n`.
    pub fn new(input: Tensor<T>, n: Option<usize>) -> Self {
        Self { input, n }
    }
}

impl<T: Float> GradFn<T> for FftBackward<T> {
    /// `grad_x = N * ifft(grad_y)` with `N` the transform length, then trimmed
    /// or zero-padded back to the input's length (they differ when `n` was given).
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let inv = fft::ifft(grad_output, self.n)?;
        let transform_len = grad_output.shape()[grad_output.ndim() - 2];
        let scale = T::from(transform_len).unwrap();
        let scaled: Vec<T> = inv.data_vec()?.into_iter().map(|v| v * scale).collect();
        let (data, shape) = fit_axis_to_input(scaled, inv.shape(), 2, &self.input);
        Ok(vec![Some(host_grad_like(data, shape, grad_output)?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "FftBackward"
    }
}

/// Backward node for [`ifft_differentiable`]: complex-to-complex inverse DFT.
#[derive(Debug)]
pub struct IfftBackward<T: Float> {
    input: Tensor<T>,
    n: Option<usize>,
}

impl<T: Float> IfftBackward<T> {
    /// Records the forward input and the optional transform length `n`.
    pub fn new(input: Tensor<T>, n: Option<usize>) -> Self {
        Self { input, n }
    }
}

impl<T: Float> GradFn<T> for IfftBackward<T> {
    /// `grad_x = fft(grad_y) / N` with `N` the transform length, then trimmed
    /// or zero-padded back to the input's length.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let fwd = fft::fft(grad_output, self.n)?;
        let transform_len = grad_output.shape()[grad_output.ndim() - 2];
        let scale = T::from(1.0).unwrap() / T::from(transform_len).unwrap();
        let scaled: Vec<T> = fwd.data_vec()?.into_iter().map(|v| v * scale).collect();
        let (data, shape) = fit_axis_to_input(scaled, fwd.shape(), 2, &self.input);
        Ok(vec![Some(host_grad_like(data, shape, grad_output)?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IfftBackward"
    }
}

/// Backward node for [`rfft_differentiable`]: real input of length `fft_n`
/// (after padding/truncation) to a half spectrum of `fft_n / 2 + 1` bins.
#[derive(Debug)]
pub struct RfftBackward<T: Float> {
    input: Tensor<T>,
    _n: Option<usize>,
    fft_n: usize,
}

impl<T: Float> RfftBackward<T> {
    /// Records the forward input, the user-supplied `n`, and the effective
    /// transform length `fft_n`.
    pub fn new(input: Tensor<T>, n: Option<usize>, fft_n: usize) -> Self {
        Self {
            input,
            _n: n,
            fft_n,
        }
    }
}

impl<T: Float> GradFn<T> for RfftBackward<T> {
    /// The adjoint of `rfft` is `Re(sum_k g_k e^{+2πikn/N})` over the stored
    /// bins only. `irfft` counts each non-self-conjugate bin (`1..=(N-1)/2`)
    /// twice and divides by `N`, so those bins are halved first and the result
    /// is scaled by `N`. The gradient is then fitted to the input's length.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let n = self.fft_n;
        let one = T::from(1.0).unwrap();
        let half = T::from(0.5).unwrap();
        let last_paired = n.saturating_sub(1) / 2;

        let g_shape = grad_output.shape().to_vec();
        let mut g = grad_output.data_vec()?;
        weight_axis(&mut g, &g_shape, g_shape.len() - 2, |k| {
            if (1..=last_paired).contains(&k) {
                half
            } else {
                one
            }
        });
        let g = Tensor::from_storage(TensorStorage::cpu(g), g_shape, false)?;

        let inv = fft::irfft(&g, Some(n))?;
        let scale = T::from(n).unwrap();
        let scaled: Vec<T> = inv.data_vec()?.into_iter().map(|v| v * scale).collect();
        let (data, shape) = fit_axis_to_input(scaled, inv.shape(), 1, &self.input);
        Ok(vec![Some(host_grad_like(data, shape, grad_output)?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "RfftBackward"
    }
}

/// Backward node for [`irfft_differentiable`]: half spectrum to a real signal
/// of length `output_n`.
#[derive(Debug)]
pub struct IrfftBackward<T: Float> {
    input: Tensor<T>,
    _n: Option<usize>,
    output_n: usize,
}

impl<T: Float> IrfftBackward<T> {
    /// Records the forward input, the user-supplied `n`, and the real output
    /// length `output_n`.
    pub fn new(input: Tensor<T>, n: Option<usize>, output_n: usize) -> Self {
        Self {
            input,
            _n: n,
            output_n,
        }
    }
}

impl<T: Float> GradFn<T> for IrfftBackward<T> {
    /// The adjoint of `irfft(x, N)` is `c_k / N * rfft(grad_y, N)_k`, where
    /// `c_k = 2` for the non-self-conjugate bins `1..=(N-1)/2` (each stands for
    /// a conjugate pair in the full spectrum) and `1` otherwise. The gradient is
    /// then fitted to the input's number of bins.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let n = self.output_n;
        let fwd = fft::rfft(grad_output, Some(n))?;
        let single = T::from(1.0).unwrap() / T::from(n).unwrap();
        let double = T::from(2.0).unwrap() / T::from(n).unwrap();
        let last_paired = n.saturating_sub(1) / 2;

        let shape = fwd.shape().to_vec();
        let mut data = fwd.data_vec()?;
        weight_axis(&mut data, &shape, shape.len() - 2, |k| {
            if (1..=last_paired).contains(&k) {
                double
            } else {
                single
            }
        });
        let (data, shape) = fit_axis_to_input(data, &shape, 2, &self.input);
        Ok(vec![Some(host_grad_like(data, shape, grad_output)?)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IrfftBackward"
    }
}

/// Multiplies each element of the row-major `data` (described by `shape`) by
/// `weight(k)`, where `k` is the element's index along `axis`.
fn weight_axis<T: Float>(
    data: &mut [T],
    shape: &[usize],
    axis: usize,
    weight: impl Fn(usize) -> T,
) {
    let len = shape[axis];
    let inner: usize = shape[axis + 1..].iter().product();
    // An empty axis means an empty buffer, so the loop (and its divisions)
    // never runs.
    for (i, v) in data.iter_mut().enumerate() {
        let k = (i / inner) % len;
        *v = *v * weight(k);
    }
}

/// Zero-pads or truncates `axis` of the row-major `data` to `new_len` and
/// returns the new buffer together with its shape.
fn resize_axis<T: Float>(
    data: Vec<T>,
    shape: &[usize],
    axis: usize,
    new_len: usize,
) -> (Vec<T>, Vec<usize>) {
    let old_len = shape[axis];
    let mut new_shape = shape.to_vec();
    new_shape[axis] = new_len;
    if old_len == new_len {
        return (data, new_shape);
    }
    let outer: usize = shape[..axis].iter().product();
    let inner: usize = shape[axis + 1..].iter().product();
    let keep = old_len.min(new_len) * inner;
    let mut out = vec![T::from(0.0).unwrap(); outer * new_len * inner];
    for o in 0..outer {
        let src = o * old_len * inner;
        let dst = o * new_len * inner;
        out[dst..dst + keep].copy_from_slice(&data[src..src + keep]);
    }
    (out, new_shape)
}

/// Resizes the axis `axis_from_end` positions from the end of a gradient
/// buffer so it matches the same axis of `input`. This undoes the forward
/// pass's `n` padding/truncation. The buffer is left untouched when the ranks
/// differ, because the layouts are then not comparable.
fn fit_axis_to_input<T: Float>(
    data: Vec<T>,
    shape: &[usize],
    axis_from_end: usize,
    input: &Tensor<T>,
) -> (Vec<T>, Vec<usize>) {
    let in_shape = input.shape();
    if in_shape.len() != shape.len() || shape.len() < axis_from_end {
        return (data, shape.to_vec());
    }
    let axis = shape.len() - axis_from_end;
    resize_axis(data, shape, axis, in_shape[axis])
}

/// Wraps host gradient data in a tensor and moves it to `like`'s device when
/// that device is CUDA.
fn host_grad_like<T: Float>(
    data: Vec<T>,
    shape: Vec<usize>,
    like: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    let t = Tensor::from_storage(TensorStorage::cpu(data), shape, false)?;
    let device = like.device();
    Ok(if device.is_cuda() { t.to(device)? } else { t })
}
/// Differentiable forward complex FFT along the transform axis.
///
/// Computes [`fft::fft`]. When autograd is enabled and `input` requires
/// gradients, the result is rebuilt on `input`'s device with an
/// [`FftBackward`] node attached; otherwise the plain result is returned.
pub fn fft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let output = fft::fft(input, n)?;
    let tracked = is_grad_enabled() && input.requires_grad();
    if !tracked {
        return Ok(output);
    }
    let node = Arc::new(FftBackward::new(input.clone(), n));
    let host = output.data_vec()?;
    let storage = TensorStorage::on_device(host, device)?;
    Tensor::from_operation(storage, output.shape().to_vec(), node)
}
/// Differentiable inverse complex FFT along the transform axis.
///
/// Computes [`fft::ifft`]. When autograd is enabled and `input` requires
/// gradients, the result is rebuilt on `input`'s device with an
/// [`IfftBackward`] node attached.
pub fn ifft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let output = fft::ifft(input, n)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(output);
    }
    let node = Arc::new(IfftBackward::new(input.clone(), n));
    let storage = TensorStorage::on_device(output.data_vec()?, device)?;
    let dims = output.shape().to_vec();
    Tensor::from_operation(storage, dims, node)
}
/// Differentiable real-to-complex FFT over the last axis of `input`.
///
/// The effective transform length is `n`, or the input's last-axis length
/// when `n` is `None`. That length is recorded in the [`RfftBackward`] node
/// attached when gradients are tracked.
pub fn rfft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let last_len = input.shape().last().copied().unwrap();
    let transform_len = n.unwrap_or(last_len);
    let spectrum = fft::rfft(input, n)?;
    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(RfftBackward::new(input.clone(), n, transform_len));
        let storage = TensorStorage::on_device(spectrum.data_vec()?, device)?;
        return Tensor::from_operation(storage, spectrum.shape().to_vec(), node);
    }
    Ok(spectrum)
}
/// Differentiable inverse real FFT (`irfft`).
///
/// `input` is a complex half spectrum whose trailing axis of size 2 holds the
/// (re, im) pair; its bins lie along axis `ndim - 2`. When gradients are
/// tracked, an [`IrfftBackward`] node is attached that records the output
/// length: `n`, or `2 * (bins - 1)` when `n` is `None`.
///
/// # Errors
/// Propagates errors from `fft::irfft`, `TensorStorage::on_device` and
/// `Tensor::from_operation`.
pub fn irfft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    // Transform first so malformed inputs surface as `fft::irfft` errors rather
    // than panics in the shape arithmetic below.
    let result = fft::irfft(input, n)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result);
    }
    let shape = input.shape();
    let half_n = shape[shape.len() - 2];
    // Compute the default lazily: with an explicit `n` it is never needed, and
    // `saturating_sub` keeps an empty spectrum from underflowing.
    let output_n = n.unwrap_or_else(|| 2 * half_n.saturating_sub(1));
    let grad_fn = Arc::new(IrfftBackward::new(input.clone(), n, output_n));
    let storage = TensorStorage::on_device(result.data_vec()?, device)?;
    Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
}
/// Backward node for [`fftn_differentiable`]: N-dimensional forward DFT.
///
/// The gradient is `norm_n * ifftn(grad_output)` evaluated with the same
/// `s`/`axes` as the forward call. `norm_n` is computed by `fftn_norm_n`.
#[derive(Debug)]
pub struct FftnBackward<T: Float> {
    input: Tensor<T>,
    s: Option<Vec<usize>>,
    axes: Option<Vec<isize>>,
    norm_n: usize,
}

impl<T: Float> FftnBackward<T> {
    /// Captures the forward input, its `s`/`axes` arguments, and the count
    /// `norm_n` that rescales the inverse transform.
    pub fn new(
        input: Tensor<T>,
        s: Option<Vec<usize>>,
        axes: Option<Vec<isize>>,
        norm_n: usize,
    ) -> Self {
        Self {
            input,
            s,
            axes,
            norm_n,
        }
    }
}

impl<T: Float> GradFn<T> for FftnBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let target = grad_output.device();
        let inverse = fft::ifftn(grad_output, self.s.as_deref(), self.axes.as_deref())?;
        let factor = T::from(self.norm_n).unwrap();
        let values: Vec<T> = inverse
            .data_vec()?
            .into_iter()
            .map(|x| x * factor)
            .collect();
        let grad = Tensor::from_storage(
            TensorStorage::cpu(values),
            inverse.shape().to_vec(),
            false,
        )?;
        let grad = if target.is_cuda() { grad.to(target)? } else { grad };
        Ok(vec![Some(grad)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "FftnBackward"
    }
}
/// Backward node for [`ifftn_differentiable`]: N-dimensional inverse DFT.
///
/// The gradient is `fftn(grad_output) / norm_n` evaluated with the same
/// `s`/`axes` as the forward call.
#[derive(Debug)]
pub struct IfftnBackward<T: Float> {
    input: Tensor<T>,
    s: Option<Vec<usize>>,
    axes: Option<Vec<isize>>,
    norm_n: usize,
}

impl<T: Float> IfftnBackward<T> {
    /// Captures the forward input, its `s`/`axes` arguments, and the count
    /// `norm_n` used to rescale the forward transform.
    pub fn new(
        input: Tensor<T>,
        s: Option<Vec<usize>>,
        axes: Option<Vec<isize>>,
        norm_n: usize,
    ) -> Self {
        Self {
            input,
            s,
            axes,
            norm_n,
        }
    }
}

impl<T: Float> GradFn<T> for IfftnBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }
        let target = grad_output.device();
        let forward = fft::fftn(grad_output, self.s.as_deref(), self.axes.as_deref())?;
        let factor = T::from(1.0).unwrap() / T::from(self.norm_n).unwrap();
        let mut values = forward.data_vec()?;
        for x in values.iter_mut() {
            *x = *x * factor;
        }
        let grad = Tensor::from_storage(
            TensorStorage::cpu(values),
            forward.shape().to_vec(),
            false,
        )?;
        let grad = if target.is_cuda() { grad.to(target)? } else { grad };
        Ok(vec![Some(grad)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IfftnBackward"
    }
}
/// Backward node recorded by `rfftn_differentiable` (real-to-complex N-D FFT).
///
/// `s` is always `Some` when built by `rfftn_differentiable`. It holds the
/// real-domain sizes of the transformed axes, so `irfftn` can restore the
/// input's extent.
#[derive(Debug)]
pub struct RfftnBackward<T: Float> {
input: Tensor<T>,
s: Option<Vec<usize>>,
axes: Option<Vec<isize>>,
}
impl<T: Float> RfftnBackward<T> {
/// Stores the forward input and the `s`/`axes` used to invert the transform.
pub fn new(input: Tensor<T>, s: Option<Vec<usize>>, axes: Option<Vec<isize>>) -> Self {
Self { input, s, axes }
}
}
impl<T: Float> GradFn<T> for RfftnBackward<T> {
/// Returns `irfftn(grad_output, s, axes)` as the input gradient, or `None`
/// when the input does not require gradients.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let grad_input = if self.input.requires_grad() {
// NOTE(review): by analogy with the 1-D case, the exact adjoint of
// rfftn is N * irfftn(g') where g' halves the non-self-conjugate bins
// of the last transformed axis (cf. PyTorch fft_r2c_backward). Plain
// irfftn appears to omit both the N scale and that weighting. Confirm
// with a finite-difference check.
Some(fft::irfftn(
grad_output,
self.s.as_deref(),
self.axes.as_deref(),
)?)
} else {
None
};
Ok(vec![grad_input])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"RfftnBackward"
}
}
/// Backward node recorded by `irfftn_differentiable` (complex-to-real N-D FFT).
///
/// `s` is always `Some` when built by `irfftn_differentiable`. It holds the
/// real output sizes, which `rfftn` uses to map the gradient back to a half
/// spectrum.
#[derive(Debug)]
pub struct IrfftnBackward<T: Float> {
input: Tensor<T>,
s: Option<Vec<usize>>,
axes: Option<Vec<isize>>,
}
impl<T: Float> IrfftnBackward<T> {
/// Stores the forward input and the `s`/`axes` used for the adjoint transform.
pub fn new(input: Tensor<T>, s: Option<Vec<usize>>, axes: Option<Vec<isize>>) -> Self {
Self { input, s, axes }
}
}
impl<T: Float> GradFn<T> for IrfftnBackward<T> {
/// Returns `rfftn(grad_output, s, axes)` as the input gradient, or `None`
/// when the input does not require gradients.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let grad_input = if self.input.requires_grad() {
// NOTE(review): by analogy with the 1-D case, the exact adjoint of
// irfftn is rfftn(g) / N with the non-self-conjugate bins of the last
// transformed axis doubled (cf. PyTorch fft_c2r_backward). Plain
// rfftn appears to omit both. The result is also not trimmed back to
// the input's bin count if `s` padded or truncated it. Confirm.
Some(fft::rfftn(
grad_output,
self.s.as_deref(),
self.axes.as_deref(),
)?)
} else {
None
};
Ok(vec![grad_input])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"IrfftnBackward"
}
}
/// Backward node recorded by `hfft_differentiable`.
///
/// `input_n` is the length of the input's axis `ndim - 2`, i.e. the number of
/// half-spectrum bins.
#[derive(Debug)]
pub struct HfftBackward<T: Float> {
input: Tensor<T>,
input_n: usize,
}
impl<T: Float> HfftBackward<T> {
/// Stores the forward input and its half-spectrum bin count.
pub fn new(input: Tensor<T>, input_n: usize) -> Self {
Self { input, input_n }
}
}
impl<T: Float> GradFn<T> for HfftBackward<T> {
/// Returns `ihfft(grad_output, 2 * (input_n - 1))` as the input gradient,
/// or `None` when the input does not require gradients.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let grad_input = if self.input.requires_grad() {
// NOTE(review): this assumes the forward used the default length. Any
// explicit `n` passed to `hfft_differentiable` is not stored here, and
// `input_n == 0` underflows. Separately, hfft is an unnormalized
// c2r transform, so its adjoint likely needs an N scale and a x2
// weight on interior bins that plain ihfft does not apply. Confirm
// with a finite-difference check.
let n_forward = 2 * (self.input_n - 1);
Some(fft::ihfft(grad_output, Some(n_forward))?)
} else {
None
};
Ok(vec![grad_input])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"HfftBackward"
}
}
/// Backward node recorded by `ihfft_differentiable`.
///
/// `input_n` is the length of the real input's last axis.
#[derive(Debug)]
pub struct IhfftBackward<T: Float> {
input: Tensor<T>,
input_n: usize,
}
impl<T: Float> IhfftBackward<T> {
/// Stores the forward input and its real-signal length.
pub fn new(input: Tensor<T>, input_n: usize) -> Self {
Self { input, input_n }
}
}
impl<T: Float> GradFn<T> for IhfftBackward<T> {
/// Returns `hfft(grad_output, input_n)` as the input gradient, or `None`
/// when the input does not require gradients.
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let grad_input = if self.input.requires_grad() {
// NOTE(review): the exact adjoint of ihfft(x, n) is
// (1/n) * Re(sum_k conj(g_k) e^{+2*pi*i*k*m/n}) over the stored bins
// only. hfft presumably counts interior bins twice and omits the 1/n.
// An explicit `n` passed to the forward call is also not used here.
// Confirm.
Some(fft::hfft(grad_output, Some(self.input_n))?)
} else {
None
};
Ok(vec![grad_input])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"IhfftBackward"
}
}
/// Number of elements in the region transformed by `fftn`/`ifftn`, never
/// less than 1.
///
/// Resolution order:
/// 1. An explicit `s` wins: the product of its sizes.
/// 2. Otherwise, with `axes`, the product of the input's sizes along those
///    axes. Negative axes count back from the logical (complex) rank, which
///    excludes the trailing re/im axis.
/// 3. Otherwise, the product of every axis except the trailing re/im one.
fn fftn_norm_n<T: Float>(input: &Tensor<T>, s: Option<&[usize]>, axes: Option<&[isize]>) -> usize {
    if let Some(sizes) = s {
        return sizes.iter().product::<usize>().max(1);
    }
    let dims = input.shape();
    let rank = dims.len();
    match axes {
        Some(axis_list) => {
            let logical_rank = rank.saturating_sub(1) as isize;
            axis_list
                .iter()
                .map(|&a| {
                    let idx = if a < 0 {
                        (logical_rank + a) as usize
                    } else {
                        a as usize
                    };
                    dims[idx]
                })
                .fold(1usize, |acc, d| acc.saturating_mul(d))
                .max(1)
        }
        None => match rank {
            0 | 1 => 1,
            _ => dims[..rank - 1].iter().product::<usize>().max(1),
        },
    }
}
/// Number of real-domain elements in the region transformed by `rfftn`,
/// never less than 1.
///
/// Uses `s` when given. Otherwise it uses the input's sizes along `axes`
/// (negative axes count back from the full rank, since the input is real), or
/// every axis when neither is given.
fn rfftn_norm_n<T: Float>(input: &Tensor<T>, s: Option<&[usize]>, axes: Option<&[isize]>) -> usize {
    let dims = input.shape();
    let count = match (s, axes) {
        (Some(sizes), _) => sizes.iter().product::<usize>(),
        (None, Some(axis_list)) => {
            let rank = dims.len() as isize;
            axis_list.iter().fold(1usize, |acc, &a| {
                let idx = if a < 0 {
                    (rank + a) as usize
                } else {
                    a as usize
                };
                acc.saturating_mul(dims[idx])
            })
        }
        (None, None) => dims.iter().product::<usize>(),
    };
    count.max(1)
}
/// Differentiable N-dimensional forward complex FFT.
///
/// Computes [`fft::fftn`]. When gradients are tracked, an [`FftnBackward`]
/// node is attached. It stores `s`, `axes`, and the transformed element count
/// from `fftn_norm_n`.
pub fn fftn_differentiable<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let output = fft::fftn(input, s, axes)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(output);
    }
    let count = fftn_norm_n(input, s, axes);
    let node = Arc::new(FftnBackward::new(
        input.clone(),
        s.map(<[usize]>::to_vec),
        axes.map(<[isize]>::to_vec),
        count,
    ));
    let storage = TensorStorage::on_device(output.data_vec()?, device)?;
    Tensor::from_operation(storage, output.shape().to_vec(), node)
}
/// Differentiable N-dimensional inverse complex FFT.
///
/// Computes [`fft::ifftn`]. When gradients are tracked, an [`IfftnBackward`]
/// node is attached. It stores `s`, `axes`, and the transformed element count
/// from `fftn_norm_n`.
pub fn ifftn_differentiable<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let output = fft::ifftn(input, s, axes)?;
    let tracked = is_grad_enabled() && input.requires_grad();
    if !tracked {
        return Ok(output);
    }
    let count = fftn_norm_n(input, s, axes);
    let saved_s = s.map(|dims| dims.to_vec());
    let saved_axes = axes.map(|list| list.to_vec());
    let node = Arc::new(IfftnBackward::new(input.clone(), saved_s, saved_axes, count));
    let storage = TensorStorage::on_device(output.data_vec()?, device)?;
    Tensor::from_operation(storage, output.shape().to_vec(), node)
}
/// Differentiable N-dimensional real-to-complex FFT.
///
/// Computes [`fft::rfftn`]. When gradients are tracked, an [`RfftnBackward`]
/// node is attached. It stores the real-domain size of every transformed
/// axis, taken from `s` if given, otherwise from the input's extents along
/// `axes`, otherwise the whole input shape.
pub fn rfftn_differentiable<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    // No runtime effect; presumably keeps the otherwise-unused helper
    // referenced so it does not trigger a dead_code warning.
    let _ = rfftn_norm_n::<T>;
    let output = fft::rfftn(input, s, axes)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(output);
    }
    let in_shape = input.shape();
    let real_sizes: Vec<usize> = if let Some(sizes) = s {
        sizes.to_vec()
    } else if let Some(axis_list) = axes {
        let rank = in_shape.len() as isize;
        axis_list
            .iter()
            .map(|&a| in_shape[(if a < 0 { rank + a } else { a }) as usize])
            .collect()
    } else {
        in_shape.to_vec()
    };
    let node = Arc::new(RfftnBackward::new(
        input.clone(),
        Some(real_sizes),
        axes.map(|list| list.to_vec()),
    ));
    let storage = TensorStorage::on_device(output.data_vec()?, device)?;
    Tensor::from_operation(storage, output.shape().to_vec(), node)
}
/// Differentiable N-dimensional complex-to-real inverse FFT.
///
/// Computes [`fft::irfftn`]. When gradients are tracked, an
/// [`IrfftnBackward`] node is attached. It stores one real-output size per
/// transformed axis, which the backward `rfftn` needs:
/// - `s` when given;
/// - otherwise the output's sizes along `axes`;
/// - otherwise the full output shape.
///
/// # Errors
/// Propagates errors from `fft::irfftn`, `TensorStorage::on_device` and
/// `Tensor::from_operation`.
pub fn irfftn_differentiable<T: Float>(
    input: &Tensor<T>,
    s: Option<&[usize]>,
    axes: Option<&[isize]>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let result = fft::irfftn(input, s, axes)?;
    if !(is_grad_enabled() && input.requires_grad()) {
        return Ok(result);
    }
    let out_shape = result.shape();
    let s_back: Vec<usize> = match (s, axes) {
        (Some(sizes), _) => sizes.to_vec(),
        // Only the listed axes were transformed, so `s_back` must have one
        // entry per axis. The full output shape would mismatch `axes` in
        // length. Axes are resolved against the real output's rank, which
        // presumably equals the complex input's logical rank (TODO confirm).
        (None, Some(axis_list)) => {
            let rank = out_shape.len() as isize;
            axis_list
                .iter()
                .map(|&a| out_shape[(if a < 0 { rank + a } else { a }) as usize])
                .collect()
        }
        (None, None) => out_shape.to_vec(),
    };
    let grad_fn = Arc::new(IrfftnBackward::new(
        input.clone(),
        Some(s_back),
        axes.map(|v| v.to_vec()),
    ));
    let storage = TensorStorage::on_device(result.data_vec()?, device)?;
    Tensor::from_operation(storage, result.shape().to_vec(), grad_fn)
}
/// Differentiable Hermitian FFT (`hfft`).
///
/// Reads the half-spectrum bin count from axis `ndim - 2` of `input` before
/// transforming, then attaches an [`HfftBackward`] node when gradients are
/// tracked.
pub fn hfft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let dims = input.shape();
    let bins = dims[dims.len() - 2];
    let output = fft::hfft(input, n)?;
    if is_grad_enabled() && input.requires_grad() {
        let node = Arc::new(HfftBackward::new(input.clone(), bins));
        let storage = TensorStorage::on_device(output.data_vec()?, device)?;
        Tensor::from_operation(storage, output.shape().to_vec(), node)
    } else {
        Ok(output)
    }
}
/// Differentiable inverse Hermitian FFT (`ihfft`).
///
/// Reads the real input's last-axis length before transforming, then attaches
/// an [`IhfftBackward`] node when gradients are tracked.
pub fn ihfft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let device = input.device();
    let signal_len = input.shape().last().copied().unwrap();
    let output = fft::ihfft(input, n)?;
    let tracked = is_grad_enabled() && input.requires_grad();
    if !tracked {
        return Ok(output);
    }
    let node = Arc::new(IhfftBackward::new(input.clone(), signal_len));
    let storage = TensorStorage::on_device(output.data_vec()?, device)?;
    Tensor::from_operation(storage, output.shape().to_vec(), node)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
// Complex fixtures below use a trailing axis of size 2; the data is laid out
// as interleaved (re, im) pairs, e.g. [1.0, 0.0, ...] is 1 + 0i.
/// Builds a CPU tensor that requires gradients.
fn leaf(data: &[f64], shape: &[usize]) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
}
/// Builds a CPU tensor that does not require gradients.
fn no_grad_leaf(data: &[f64], shape: &[usize]) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
/// Asserts equal lengths and element-wise absolute difference below `tol`.
fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
assert_eq!(
actual.len(),
expected.len(),
"length mismatch: {} vs {}",
actual.len(),
expected.len()
);
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < tol,
"index {i}: {a} vs {e} (diff {})",
(a - e).abs()
);
}
}
#[test]
fn fft_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let result = fft_differentiable(&input, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "FftBackward");
}
#[test]
fn fft_differentiable_no_grad_when_not_needed() {
let input = no_grad_leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let result = fft_differentiable(&input, None).unwrap();
assert!(result.grad_fn().is_none());
}
// grad_x = N * ifft(grad_y). The ifft of an all-ones length-4 signal is a unit
// impulse, so the expected gradient is 4 at index 0 and zero elsewhere.
#[test]
fn fft_backward_identity_check() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let result = fft_differentiable(&input, None).unwrap();
let grad_out = no_grad_leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[4, 2]);
let grads = result.grad_fn().unwrap().backward(&grad_out).unwrap();
assert!(grads[0].is_some());
let g = grads[0].as_ref().unwrap();
assert_eq!(g.shape(), &[4, 2]);
let gd = g.data().unwrap();
assert_close(gd, &[4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 1e-10);
}
// grad_x = fft(grad_y) / N. The fft of a unit impulse is all ones, so each
// entry of the gradient is 1/4.
#[test]
fn ifft_backward_identity_check() {
let input = leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[4, 2]);
let result = ifft_differentiable(&input, None).unwrap();
let grad_out = no_grad_leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let grads = result.grad_fn().unwrap().backward(&grad_out).unwrap();
assert!(grads[0].is_some());
let g = grads[0].as_ref().unwrap();
let gd = g.data().unwrap();
assert_close(gd, &[0.25, 0.0, 0.25, 0.0, 0.25, 0.0, 0.25, 0.0], 1e-10);
}
#[test]
fn rfft_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4]);
let result = rfft_differentiable(&input, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "RfftBackward");
}
// A 3-bin half spectrum with an explicit output length of 4.
#[test]
fn irfft_differentiable_attaches_grad_fn() {
let input = leaf(&[10.0, 0.0, -2.0, 2.0, -2.0, 0.0], &[3, 2]);
let result = irfft_differentiable(&input, Some(4)).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "IrfftBackward");
}
#[test]
fn no_grad_context_disables_tracking() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
let result =
crate::autograd::no_grad::no_grad(|| fft_differentiable(&input, None).unwrap());
assert!(result.grad_fn().is_none());
}
#[test]
fn fftn_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[2, 2, 2]);
let result = fftn_differentiable(&input, None, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "FftnBackward");
}
#[test]
fn ifftn_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[2, 2, 2]);
let result = ifftn_differentiable(&input, None, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "IfftnBackward");
}
#[test]
fn fftn_no_grad_when_not_needed() {
let input = no_grad_leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[2, 2, 2]);
let result = fftn_differentiable(&input, None, None).unwrap();
assert!(result.grad_fn().is_none());
}
// The logical shape is 2x2, so norm_n = 4. ifftn of all ones is an impulse,
// and scaling by 4 gives 4 at the origin.
#[test]
fn fftn_backward_returns_real_grad_for_impulse() {
let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[2, 2, 2]);
let result = fftn_differentiable(&input, None, None).unwrap();
let grad_out = no_grad_leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[2, 2, 2]);
let grads = result.grad_fn().unwrap().backward(&grad_out).unwrap();
let g = grads[0].as_ref().unwrap();
assert_close(
g.data().unwrap(),
&[4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1e-9,
);
}
#[test]
fn rfftn_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let result = rfftn_differentiable(&input, None, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "RfftnBackward");
}
#[test]
fn irfftn_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[2, 2, 2]);
let result = irfftn_differentiable(&input, Some(&[2, 2]), None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "IrfftnBackward");
}
#[test]
fn hfft_differentiable_attaches_grad_fn() {
let input = leaf(&[10.0, 0.0, -2.0, 2.0, -2.0, 0.0], &[3, 2]);
let result = hfft_differentiable(&input, Some(4)).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "HfftBackward");
}
#[test]
fn ihfft_differentiable_attaches_grad_fn() {
let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4]);
let result = ihfft_differentiable(&input, None).unwrap();
assert!(result.grad_fn().is_some());
assert_eq!(result.grad_fn().unwrap().name(), "IhfftBackward");
}
// With neither `s` nor `axes`, every axis but the trailing re/im one counts.
#[test]
fn fftn_norm_n_default_inner_dims() {
let input = no_grad_leaf(&vec![0.0; 2 * 3 * 4 * 2], &[2, 3, 4, 2]);
let n = fftn_norm_n(&input, None, None);
assert_eq!(n, 2 * 3 * 4);
}
// An explicit `s` wins over the input shape: 3 * 5.
#[test]
fn fftn_norm_n_with_explicit_s() {
let input = no_grad_leaf(&[0.0; 8 * 2], &[2, 2, 2, 2]);
let n = fftn_norm_n(&input, Some(&[3, 5]), None);
assert_eq!(n, 15);
}
// Only the listed axes count: shape[1] * shape[2] = 3 * 4.
#[test]
fn fftn_norm_n_with_axes() {
let input = no_grad_leaf(&vec![0.0; 2 * 3 * 4 * 2], &[2, 3, 4, 2]);
let n = fftn_norm_n(&input, None, Some(&[1, 2]));
assert_eq!(n, 12);
}
}