#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use crate::api::{Direction, Flags, Plan};
use crate::kernel::{Complex, Float};
/// A dual number: a primal `value` paired with a first-order derivative `deriv`.
///
/// The arithmetic operator implementations for this type propagate `deriv`
/// with the usual sum, product and quotient rules.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Dual<T: Float> {
/// Primal (function) value.
pub value: T,
/// Derivative carried alongside `value`.
pub deriv: T,
}
impl<T: Float> Dual<T> {
pub fn new(value: T, deriv: T) -> Self {
Self { value, deriv }
}
pub fn constant(value: T) -> Self {
Self {
value,
deriv: T::ZERO,
}
}
pub fn variable(value: T) -> Self {
Self {
value,
deriv: T::ONE,
}
}
}
impl<T: Float> core::ops::Add for Dual<T> {
    type Output = Self;

    /// Sum rule: values add and derivatives add.
    fn add(self, rhs: Self) -> Self {
        let value = self.value + rhs.value;
        let deriv = self.deriv + rhs.deriv;
        Self { value, deriv }
    }
}
impl<T: Float> core::ops::Sub for Dual<T> {
    type Output = Self;

    /// Difference rule: values subtract and derivatives subtract.
    fn sub(self, rhs: Self) -> Self {
        let value = self.value - rhs.value;
        let deriv = self.deriv - rhs.deriv;
        Self { value, deriv }
    }
}
impl<T: Float> core::ops::Mul for Dual<T> {
    type Output = Self;

    /// Product rule: `(a * b)' = a * b' + a' * b`.
    fn mul(self, rhs: Self) -> Self {
        let value = self.value * rhs.value;
        let deriv = self.value * rhs.deriv + self.deriv * rhs.value;
        Self { value, deriv }
    }
}
impl<T: Float> core::ops::Div for Dual<T> {
    type Output = Self;

    /// Quotient rule: `(a / b)' = (a' * b - a * b') / b^2`.
    ///
    /// No check is made for `rhs.value` being zero; the result then follows
    /// whatever `T`'s division produces.
    fn div(self, rhs: Self) -> Self {
        let quotient = self.value / rhs.value;
        let numerator = self.deriv * rhs.value - self.value * rhs.deriv;
        let denominator = rhs.value * rhs.value;
        Self {
            value: quotient,
            deriv: numerator / denominator,
        }
    }
}
/// Complex-valued dual number: a complex `value` paired with a complex `deriv`.
///
/// The operator implementations for this type propagate `deriv` with the sum
/// and product rules using complex arithmetic.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DualComplex<T: Float> {
/// Primal complex value.
pub value: Complex<T>,
/// Complex derivative carried alongside `value`.
pub deriv: Complex<T>,
}
impl<T: Float> DualComplex<T> {
pub fn new(re: T, im: T, dre: T, dim: T) -> Self {
Self {
value: Complex::new(re, im),
deriv: Complex::new(dre, dim),
}
}
pub fn from_complex(value: Complex<T>, deriv: Complex<T>) -> Self {
Self { value, deriv }
}
pub fn constant(value: Complex<T>) -> Self {
Self {
value,
deriv: Complex::zero(),
}
}
pub fn variable(value: Complex<T>) -> Self {
Self {
value,
deriv: Complex::new(T::ONE, T::ZERO),
}
}
pub fn zero() -> Self {
Self {
value: Complex::zero(),
deriv: Complex::zero(),
}
}
}
impl<T: Float> core::ops::Add for DualComplex<T> {
    type Output = Self;

    /// Sum rule: values add and derivatives add.
    fn add(self, rhs: Self) -> Self {
        let value = self.value + rhs.value;
        let deriv = self.deriv + rhs.deriv;
        Self { value, deriv }
    }
}
impl<T: Float> core::ops::Sub for DualComplex<T> {
    type Output = Self;

    /// Difference rule: values subtract and derivatives subtract.
    fn sub(self, rhs: Self) -> Self {
        let value = self.value - rhs.value;
        let deriv = self.deriv - rhs.deriv;
        Self { value, deriv }
    }
}
impl<T: Float> core::ops::Mul for DualComplex<T> {
    type Output = Self;

    /// Product rule with complex arithmetic: `(a * b)' = a * b' + a' * b`.
    fn mul(self, rhs: Self) -> Self {
        let value = self.value * rhs.value;
        let deriv = self.value * rhs.deriv + self.deriv * rhs.value;
        Self { value, deriv }
    }
}
impl<T: Float> core::ops::Mul<Complex<T>> for DualComplex<T> {
    type Output = Self;

    /// Scales by a plain complex factor: both the value and the derivative are
    /// multiplied by `rhs`, which carries no derivative of its own.
    fn mul(self, rhs: Complex<T>) -> Self {
        let value = self.value * rhs;
        let deriv = self.deriv * rhs;
        Self { value, deriv }
    }
}
/// A forward and a backward 1-D DFT plan of the same length, used together to
/// push tangents and gradients through the transform.
pub struct DiffFftPlan<T: Float> {
/// Plan created with `Direction::Forward`.
fwd_plan: Plan<T>,
/// Plan created with `Direction::Backward`.
inv_plan: Plan<T>,
/// Transform length both plans were created for.
size: usize,
}
impl<T: Float> DiffFftPlan<T> {
    /// Builds the forward and backward plans for transforms of length `size`.
    ///
    /// Both plans are requested with `Flags::MEASURE`. Returns `None` as soon
    /// as either `Plan::dft_1d` call yields `None`.
    pub fn new(size: usize) -> Option<Self> {
        let forward = Plan::dft_1d(size, Direction::Forward, Flags::MEASURE)?;
        let backward = Plan::dft_1d(size, Direction::Backward, Flags::MEASURE)?;
        Some(Self {
            fwd_plan: forward,
            inv_plan: backward,
            size,
        })
    }

    /// Runs the forward plan, writing the transform of `input` into `output`.
    pub fn forward(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
        self.fwd_plan.execute(input, output);
    }

    /// Runs the backward plan into `output`, then multiplies every element of
    /// `output` by `1 / size`.
    pub fn inverse(&self, input: &[Complex<T>], output: &mut [Complex<T>]) {
        self.inv_plan.execute(input, output);
        let inv_len = T::ONE / T::from_usize(self.size);
        output.iter_mut().for_each(|sample| {
            *sample = Complex::new(sample.re * inv_len, sample.im * inv_len);
        });
    }

    /// Forward-mode pass: transforms the values and the tangents of `input`
    /// separately with the forward plan.
    ///
    /// Returns `(y, dy)`, where `y` is the transform of the values and `dy` the
    /// transform of the derivatives. Both buffers have `input.len()` elements.
    pub fn forward_dual(&self, input: &[DualComplex<T>]) -> (Vec<Complex<T>>, Vec<Complex<T>>) {
        let (values, tangents): (Vec<Complex<T>>, Vec<Complex<T>>) =
            input.iter().map(|dual| (dual.value, dual.deriv)).unzip();
        let blank = || vec![Complex::<T>::zero(); input.len()];
        let (mut y, mut dy) = (blank(), blank());
        self.forward(&values, &mut y);
        self.forward(&tangents, &mut dy);
        (y, dy)
    }

    /// Reverse-mode pass for `forward`: applies the backward plan to
    /// `grad_output` without any scaling and returns the result.
    pub fn backward(&self, grad_output: &[Complex<T>]) -> Vec<Complex<T>> {
        let mut grad_input = vec![Complex::<T>::zero(); grad_output.len()];
        self.inv_plan.execute(grad_output, &mut grad_input);
        grad_input
    }

    /// Reverse-mode pass for `inverse`: applies the forward plan to
    /// `grad_output` and multiplies every element by `1 / grad_output.len()`.
    pub fn backward_inverse(&self, grad_output: &[Complex<T>]) -> Vec<Complex<T>> {
        let len = grad_output.len();
        let mut grad_input = vec![Complex::<T>::zero(); len];
        self.forward(grad_output, &mut grad_input);
        let inv_len = T::ONE / T::from_usize(len);
        grad_input
            .iter_mut()
            .for_each(|g| *g = Complex::new(g.re * inv_len, g.im * inv_len));
        grad_input
    }

    /// Transform length this plan pair was created for.
    pub fn size(&self) -> usize {
        self.size
    }
}
/// One-shot forward-mode FFT of `input`.
///
/// Creates a [`DiffFftPlan`] sized to `input.len()` and returns its
/// `forward_dual` result, or `None` when the plan cannot be created.
pub fn fft_dual<T: Float>(input: &[DualComplex<T>]) -> Option<(Vec<Complex<T>>, Vec<Complex<T>>)> {
    DiffFftPlan::<T>::new(input.len()).map(|plan| plan.forward_dual(input))
}
/// One-shot reverse-mode pass for the forward FFT.
///
/// Creates a [`DiffFftPlan`] sized to `grad_output.len()` and returns its
/// `backward` result, or `None` when the plan cannot be created.
pub fn grad_fft<T: Float>(grad_output: &[Complex<T>]) -> Option<Vec<Complex<T>>> {
    DiffFftPlan::<T>::new(grad_output.len()).map(|plan| plan.backward(grad_output))
}
/// One-shot reverse-mode pass for the normalised inverse FFT.
///
/// Creates a [`DiffFftPlan`] sized to `grad_output.len()` and returns its
/// `backward_inverse` result, or `None` when the plan cannot be created.
pub fn grad_ifft<T: Float>(grad_output: &[Complex<T>]) -> Option<Vec<Complex<T>>> {
    DiffFftPlan::<T>::new(grad_output.len()).map(|plan| plan.backward_inverse(grad_output))
}
/// Vector-Jacobian product of the FFT for the cotangent `v`.
///
/// Thin alias: forwards `v` to [`grad_fft`] and returns its result unchanged,
/// including `None` when `grad_fft` returns `None`.
pub fn vjp_fft<T: Float>(v: &[Complex<T>]) -> Option<Vec<Complex<T>>> {
grad_fft(v)
}
pub fn jvp_fft<T: Float>(v: &[Complex<T>]) -> Option<Vec<Complex<T>>> {
use crate::api::fft;
Some(fft(v))
}
/// Builds the dense `n x n` Jacobian of the forward DFT.
///
/// Entry `[k][j]` is `exp(-2*pi*i * k * j / n)`, stored as
/// `Complex::new(cos(angle), sin(angle))`. Returns an empty vector for `n == 0`.
pub fn fft_jacobian<T: Float>(n: usize) -> Vec<Vec<Complex<T>>> {
    let two_pi = T::from_f64(2.0 * core::f64::consts::PI);
    let n_t = T::from_usize(n);
    (0..n)
        .map(|k| {
            (0..n)
                .map(|j| {
                    // The entry is periodic in `k * j` with period `n`. Reducing the
                    // product while it is still an exact integer keeps the angle
                    // within one turn, so its floating-point error does not grow
                    // with the index product. `k * j < n * n` cannot overflow for
                    // any `n` whose matrix fits in memory.
                    let reduced = T::from_usize((k * j) % n);
                    let angle = -two_pi * reduced / n_t;
                    Complex::new(Float::cos(angle), Float::sin(angle))
                })
                .collect()
        })
        .collect()
}
pub mod real {
    use super::*;

    /// Maps a half-spectrum gradient back to a length-`n` real vector.
    ///
    /// `grad_output` is copied into the low bins of a length-`n` complex
    /// buffer, the remaining high bins are filled with the conjugates of their
    /// mirror bins, a `Direction::Backward` transform is run, and the real
    /// parts scaled by `1 / n` are returned.
    ///
    /// Returns `None` when the plan cannot be created, when `grad_output` has
    /// more than `n` elements, or when it is too short to supply every bin that
    /// has to be mirrored.
    ///
    /// NOTE(review): the `1 / n` factor makes this the normalised inverse real
    /// transform of `grad_output`; confirm that is the convention callers
    /// expect for the gradient.
    pub fn grad_rfft<T: Float>(grad_output: &[Complex<T>], n: usize) -> Option<Vec<T>> {
        // Bins 1..ceil(n / 2) each have a distinct conjugate partner at `n - i`.
        // For even `n` the bin at `n / 2` is its own partner and is not mirrored.
        let mirror_end = n - n / 2;
        let too_long = grad_output.len() > n;
        let too_short = mirror_end > 1 && grad_output.len() < mirror_end;
        if too_long || too_short {
            return None;
        }

        let fft_plan = Plan::<T>::dft_1d(n, Direction::Backward, Flags::ESTIMATE)?;
        let mut full_grad = vec![Complex::<T>::zero(); n];
        full_grad[..grad_output.len()].copy_from_slice(grad_output);
        for i in 1..mirror_end {
            full_grad[n - i] = grad_output[i].conj();
        }

        let mut result = vec![Complex::<T>::zero(); n];
        fft_plan.execute(&full_grad, &mut result);
        let scale = T::ONE / T::from_usize(n);
        Some(result.iter().map(|c| c.re * scale).collect())
    }

    /// Maps a real gradient to a half-spectrum gradient.
    ///
    /// `grad_output` is promoted to complex numbers with zero imaginary part, a
    /// `Direction::Forward` transform of length `n_output` is run, and the
    /// first `n_output / 2 + 1` bins scaled by `1 / n_output` are returned.
    ///
    /// Returns `None` when the plan cannot be created.
    ///
    /// NOTE(review): `grad_output.len()` is presumably expected to equal
    /// `n_output`; what `Plan::execute` does on a mismatch is not visible
    /// here, so confirm before relying on it.
    pub fn grad_irfft<T: Float>(grad_output: &[T], n_output: usize) -> Option<Vec<Complex<T>>> {
        let fft_plan = Plan::<T>::dft_1d(n_output, Direction::Forward, Flags::ESTIMATE)?;
        let complex_grad: Vec<Complex<T>> = grad_output
            .iter()
            .map(|&r| Complex::new(r, T::ZERO))
            .collect();
        let mut result = vec![Complex::<T>::zero(); n_output];
        fft_plan.execute(&complex_grad, &mut result);

        let n_freq = n_output / 2 + 1;
        let scale = T::ONE / T::from_usize(n_output);
        Some(
            result
                .into_iter()
                .take(n_freq)
                .map(|c| Complex::new(c.re * scale, c.im * scale))
                .collect(),
        )
    }
}
pub mod fft2d {
    use super::*;

    /// Reverse-mode pass for a 2-D FFT over a row-major `rows x cols` buffer.
    ///
    /// Applies `DiffFftPlan::backward` along every column and then along every
    /// row of the intermediate result. Returns `None` when `grad_output.len()`
    /// differs from `rows * cols` or when either 1-D plan cannot be created.
    pub fn grad_fft2d<T: Float>(
        grad_output: &[Complex<T>],
        rows: usize,
        cols: usize,
    ) -> Option<Vec<Complex<T>>> {
        let total = rows * cols;
        if grad_output.len() != total {
            return None;
        }
        let row_plan = DiffFftPlan::<T>::new(cols)?;
        let col_plan = DiffFftPlan::<T>::new(rows)?;

        // Column pass: gather each strided column, transform it, scatter it back.
        let mut after_columns = vec![Complex::<T>::zero(); total];
        let mut column: Vec<Complex<T>> = Vec::with_capacity(rows);
        for c in 0..cols {
            column.clear();
            column.extend((0..rows).map(|r| grad_output[r * cols + c]));
            for (r, g) in col_plan.backward(&column).into_iter().enumerate() {
                after_columns[r * cols + c] = g;
            }
        }

        // Row pass: rows are contiguous, so they are transformed straight from slices.
        let mut result = vec![Complex::<T>::zero(); total];
        for r in 0..rows {
            let span = r * cols..(r + 1) * cols;
            let grad_row = row_plan.backward(&after_columns[span.clone()]);
            result[span].copy_from_slice(&grad_row);
        }
        Some(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Accumulates `conj(lhs[i]) * rhs[i]` over the paired elements.
    fn conj_dot(lhs: &[Complex<f64>], rhs: &[Complex<f64>]) -> Complex<f64> {
        lhs.iter()
            .zip(rhs.iter())
            .map(|(&a, &b)| a.conj() * b)
            .fold(Complex::zero(), |acc, term| acc + term)
    }

    #[test]
    fn test_dual_arithmetic() {
        let a = Dual::new(2.0, 1.0);
        let b = Dual::new(3.0, 0.0);

        let Dual { value, deriv } = a + b;
        assert!(approx_eq(value, 5.0, 1e-10));
        assert!(approx_eq(deriv, 1.0, 1e-10));

        // Product rule: 2 * 0 + 1 * 3 = 3.
        let Dual { value, deriv } = a * b;
        assert!(approx_eq(value, 6.0, 1e-10));
        assert!(approx_eq(deriv, 3.0, 1e-10));
    }

    #[test]
    fn test_dual_complex_arithmetic() {
        let a = DualComplex::new(1.0, 0.0, 1.0, 0.0);
        let b = DualComplex::new(0.0, 1.0, 0.0, 0.0);

        let DualComplex { value, deriv } = a + b;
        assert!(approx_eq(value.re, 1.0, 1e-10));
        assert!(approx_eq(value.im, 1.0, 1e-10));
        assert!(approx_eq(deriv.re, 1.0, 1e-10));
        assert!(approx_eq(deriv.im, 0.0, 1e-10));
    }

    #[test]
    fn test_fft_forward_mode() {
        let n = 8;
        // Constant signal of ones; only the first sample carries a tangent.
        let input: Vec<DualComplex<f64>> = (0..n)
            .map(|k| {
                let tangent = if k == 0 { 1.0 } else { 0.0 };
                DualComplex::new(1.0, 0.0, tangent, 0.0)
            })
            .collect();

        let result = fft_dual(&input);
        assert!(result.is_some());
        let (y, dy) = result.expect("fft_dual failed");

        // A constant signal transforms to a single spike of height n in bin 0.
        assert!(approx_eq(y[0].re, n as f64, 1e-10));
        for bin in &y[1..n] {
            assert!(approx_eq(bin.re, 0.0, 1e-10));
            assert!(approx_eq(bin.im, 0.0, 1e-10));
        }
        // A unit impulse tangent transforms to all ones.
        for bin in &dy[0..n] {
            assert!(approx_eq(bin.re, 1.0, 1e-10));
            assert!(approx_eq(bin.im, 0.0, 1e-10));
        }
    }

    #[test]
    fn test_fft_backward_mode() {
        let n = 8;
        let x: Vec<Complex<f64>> = (0..n)
            .map(|k| {
                let t = k as f64;
                Complex::new(t.cos(), t.sin())
            })
            .collect();
        let v: Vec<Complex<f64>> = (0..n)
            .map(|k| {
                let t = (k + 1) as f64;
                Complex::new(t.sin(), t.cos())
            })
            .collect();

        let plan = DiffFftPlan::new(n).expect("Plan creation failed");
        let mut y = vec![Complex::<f64>::zero(); n];
        plan.forward(&x, &mut y);
        let grad_x = plan.backward(&v);

        // Adjoint identity: <v, forward(x)> must match <backward(v), x>.
        let inner_vy = conj_dot(&v, &y);
        let inner_gx = conj_dot(&grad_x, &x);
        assert!(
            approx_eq(inner_vy.re, inner_gx.re, 1e-8),
            "Adjoint property failed: {} != {}",
            inner_vy.re,
            inner_gx.re
        );
    }

    #[test]
    fn test_fft_jacobian_small() {
        let n = 4;
        let jac = fft_jacobian::<f64>(n);

        assert_eq!(jac.len(), n);
        for row in &jac {
            assert_eq!(row.len(), n);
        }
        // Row 0 has angle zero everywhere, so every entry is 1 + 0i.
        for entry in &jac[0][0..n] {
            assert!(approx_eq(entry.re, 1.0, 1e-10));
            assert!(approx_eq(entry.im, 0.0, 1e-10));
        }
    }

    #[test]
    fn test_vjp_jvp_consistency() {
        let n = 8;
        let u: Vec<Complex<f64>> = (0..n)
            .map(|k| Complex::new(f64::from(k) * 0.1, 0.0))
            .collect();
        let v: Vec<Complex<f64>> = (0..n)
            .map(|k| Complex::new(0.0, f64::from(k) * 0.1))
            .collect();

        let jvp_u = jvp_fft(&u).expect("JVP failed");
        let vjp_v = vjp_fft(&v).expect("VJP failed");

        let inner1 = conj_dot(&v, &jvp_u);
        let inner2 = conj_dot(&vjp_v, &u);
        assert!(
            approx_eq(inner1.re, inner2.re, 1e-8),
            "VJP/JVP consistency failed"
        );
    }
}