//! Differentiable wrappers around the FFT routines in [`crate::fft`].
//!
//! Each wrapper runs the plain transform and, when gradient tracking is
//! enabled, attaches a `GradFn` node implementing the corresponding adjoint
//! transform. Complex tensors are stored with a trailing dimension of size 2
//! holding `[re, im]` pairs, so the transform axis is the second-to-last
//! dimension of a complex tensor.

use std::sync::Arc;

use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::FerrotorchResult;
use crate::fft;
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
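
/// Backward function for [`fft_differentiable`].
///
/// The adjoint of the unnormalised DFT matrix is `N` times its inverse, so
/// the vector-Jacobian product of `y = fft(x)` is `N * ifft(grad_output)`.
/// `N` is read from the second-to-last dimension of the gradient because
/// complex values occupy a trailing `[re, im]` axis of size 2.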
#[derive(Debug)]
pub struct FftBackward<T: Float> {
    input: Tensor<T>,
    n: Option<usize>,
}

impl<T: Float> FftBackward<T> {
    pub fn new(input: Tensor<T>, n: Option<usize>) -> Self {
        Self { input, n }
    }
}

impl<T: Float> GradFn<T> for FftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            let inv = fft::ifft(grad_output, self.n)?;
            let fft_n = grad_output.shape()[grad_output.ndim() - 2];
            let scale = T::from(fft_n).unwrap();
            let inv_data = inv.data()?;
            let scaled: Vec<T> = inv_data.iter().map(|&v| v * scale).collect();
            Some(Tensor::from_storage(
                TensorStorage::cpu(scaled),
                inv.shape().to_vec(),
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "FftBackward"
    }
}
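
/// Backward function for [`ifft_differentiable`].
///
/// With the `1/N` normalisation carried by the inverse transform, the
/// vector-Jacobian product of `y = ifft(x)` is `(1/N) * fft(grad_output)`,
/// where `N` is again the second-to-last dimension of the gradient.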
#[derive(Debug)]
pub struct IfftBackward<T: Float> {
    input: Tensor<T>,
    n: Option<usize>,
}

impl<T: Float> IfftBackward<T> {
    pub fn new(input: Tensor<T>, n: Option<usize>) -> Self {
        Self { input, n }
    }
}

impl<T: Float> GradFn<T> for IfftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            let fwd = fft::fft(grad_output, self.n)?;
            let fft_n = grad_output.shape()[grad_output.ndim() - 2];
            let scale = T::from(1.0).unwrap() / T::from(fft_n).unwrap();
            let fwd_data = fwd.data()?;
            let scaled: Vec<T> = fwd_data.iter().map(|&v| v * scale).collect();
            Some(Tensor::from_storage(
                TensorStorage::cpu(scaled),
                fwd.shape().to_vec(),
                false,
            )?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IfftBackward"
    }
}
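
/// Backward function for [`rfft_differentiable`].
///
/// The incoming half-spectrum gradient is mapped back to the real domain by
/// `irfft` with the forward transform length `fft_n`, so the gradient has
/// the length the forward transform operated on.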
#[derive(Debug)]
pub struct RfftBackward<T: Float> {
    input: Tensor<T>,
    _n: Option<usize>,
    fft_n: usize,
}

impl<T: Float> RfftBackward<T> {
    pub fn new(input: Tensor<T>, n: Option<usize>, fft_n: usize) -> Self {
        Self { input, _n: n, fft_n }
    }
}

impl<T: Float> GradFn<T> for RfftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            Some(fft::irfft(grad_output, Some(self.fft_n))?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "RfftBackward"
    }
}
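
/// Backward function for [`irfft_differentiable`].
///
/// The incoming real-valued gradient is mapped back to a half spectrum by
/// `rfft` with the real output length `output_n`, yielding
/// `output_n / 2 + 1` complex bins.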
#[derive(Debug)]
pub struct IrfftBackward<T: Float> {
    input: Tensor<T>,
    _n: Option<usize>,
    output_n: usize,
}

impl<T: Float> IrfftBackward<T> {
    pub fn new(input: Tensor<T>, n: Option<usize>, output_n: usize) -> Self {
        Self { input, _n: n, output_n }
    }
}

impl<T: Float> GradFn<T> for IrfftBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let grad_input = if self.input.requires_grad() {
            Some(fft::rfft(grad_output, Some(self.output_n))?)
        } else {
            None
        };
        Ok(vec![grad_input])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "IrfftBackward"
    }
}
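
/// Complex FFT with autograd support.
///
/// Runs `fft::fft(input, n)` and, when gradient tracking is enabled and the
/// input requires a gradient, rebuilds the result as an operation node with
/// an [`FftBackward`] gradient function attached.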
pub fn fft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let result = fft::fft(input, n)?;
    if is_grad_enabled() && input.requires_grad() {
        let grad_fn = Arc::new(FftBackward::new(input.clone(), n));
        Tensor::from_operation(
            TensorStorage::cpu(result.data()?.to_vec()),
            result.shape().to_vec(),
            grad_fn,
        )
    } else {
        Ok(result)
    }
}
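
/// Inverse complex FFT with autograd support.
///
/// Runs `fft::ifft(input, n)` and attaches an [`IfftBackward`] gradient
/// function when gradient tracking is active.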
pub fn ifft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let result = fft::ifft(input, n)?;
    if is_grad_enabled() && input.requires_grad() {
        let grad_fn = Arc::new(IfftBackward::new(input.clone(), n));
        Tensor::from_operation(
            TensorStorage::cpu(result.data()?.to_vec()),
            result.shape().to_vec(),
            grad_fn,
        )
    } else {
        Ok(result)
    }
}
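
/// Real-to-complex FFT with autograd support.
///
/// The transform length defaults to the size of the last input dimension and
/// is recorded in [`RfftBackward`] so the backward pass can restore a
/// real-valued gradient of that length.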
pub fn rfft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let input_n = *input.shape().last().unwrap();
    let fft_n = n.unwrap_or(input_n);
    let result = fft::rfft(input, n)?;
    if is_grad_enabled() && input.requires_grad() {
        let grad_fn = Arc::new(RfftBackward::new(input.clone(), n, fft_n));
        Tensor::from_operation(
            TensorStorage::cpu(result.data()?.to_vec()),
            result.shape().to_vec(),
            grad_fn,
        )
    } else {
        Ok(result)
    }
}
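
/// Complex-to-real inverse FFT with autograd support.
///
/// The real output length defaults to `2 * (half_n - 1)`, where `half_n` is
/// the number of complex bins in the input; it is recorded in
/// [`IrfftBackward`] so the backward pass can rebuild a half-spectrum
/// gradient.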
pub fn irfft_differentiable<T: Float>(
    input: &Tensor<T>,
    n: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
    let shape = input.shape();
    let half_n = shape[shape.len() - 2];
    let output_n = n.unwrap_or(2 * (half_n - 1));
    let result = fft::irfft(input, n)?;
    if is_grad_enabled() && input.requires_grad() {
        let grad_fn = Arc::new(IrfftBackward::new(input.clone(), n, output_n));
        Tensor::from_operation(
            TensorStorage::cpu(result.data()?.to_vec()),
            result.shape().to_vec(),
            grad_fn,
        )
    } else {
        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::TensorStorage;

    fn leaf(data: &[f64], shape: &[usize]) -> Tensor<f64> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
    }

    fn no_grad_leaf(data: &[f64], shape: &[usize]) -> Tensor<f64> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }

    fn assert_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        );
        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "index {i}: {a} vs {e} (diff {})",
                (a - e).abs()
            );
        }
    }

    #[test]
    fn fft_differentiable_attaches_grad_fn() {
        let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
        let result = fft_differentiable(&input, None).unwrap();
        assert!(result.grad_fn().is_some());
        assert_eq!(result.grad_fn().unwrap().name(), "FftBackward");
    }

    #[test]
    fn fft_differentiable_no_grad_when_not_needed() {
        let input = no_grad_leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
        let result = fft_differentiable(&input, None).unwrap();
        assert!(result.grad_fn().is_none());
    }

    #[test]
    fn fft_backward_identity_check() {
        let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
        let result = fft_differentiable(&input, None).unwrap();
        let grad_out = no_grad_leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[4, 2]);
        let grads = result.grad_fn().unwrap().backward(&grad_out).unwrap();
        assert!(grads[0].is_some());
        let g = grads[0].as_ref().unwrap();
        assert_eq!(g.shape(), &[4, 2]);
        let gd = g.data().unwrap();
        assert_close(gd, &[4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 1e-10);
    }

    #[test]
    fn ifft_backward_identity_check() {
        let input = leaf(&[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0], &[4, 2]);
        let result = ifft_differentiable(&input, None).unwrap();
        let grad_out = no_grad_leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
        let grads = result.grad_fn().unwrap().backward(&grad_out).unwrap();
        assert!(grads[0].is_some());
        let g = grads[0].as_ref().unwrap();
        let gd = g.data().unwrap();
        assert_close(gd, &[0.25, 0.0, 0.25, 0.0, 0.25, 0.0, 0.25, 0.0], 1e-10);
    }

    #[test]
    fn rfft_differentiable_attaches_grad_fn() {
        let input = leaf(&[1.0, 2.0, 3.0, 4.0], &[4]);
        let result = rfft_differentiable(&input, None).unwrap();
        assert!(result.grad_fn().is_some());
        assert_eq!(result.grad_fn().unwrap().name(), "RfftBackward");
    }

    #[test]
    fn irfft_differentiable_attaches_grad_fn() {
        let input = leaf(&[10.0, 0.0, -2.0, 2.0, -2.0, 0.0], &[3, 2]);
        let result = irfft_differentiable(&input, Some(4)).unwrap();
        assert!(result.grad_fn().is_some());
        assert_eq!(result.grad_fn().unwrap().name(), "IrfftBackward");
    }

    #[test]
    fn no_grad_context_disables_tracking() {
        let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[4, 2]);
        let result = crate::autograd::no_grad::no_grad(|| {
            fft_differentiable(&input, None).unwrap()
        });
        assert!(result.grad_fn().is_none());
    }
}