use crate::infer::GraphExt as _;
use crate::{DType, Graph, NodeId, Op, Shape, fft::FftNorm};
impl Graph {
pub fn pad_last_axis_to_pow2(&mut self, x: NodeId) -> NodeId {
let shape = self.shape(x).clone();
let rank = shape.rank();
let last = rank - 1;
let n = shape.dim(last).unwrap_static();
let n_pad = crate::fft::next_pow2(n);
if n_pad == n {
return x;
}
let pad_len = n_pad - n;
let mut pad_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
pad_dims[last] = pad_len;
let pad_shape = Shape::new(&pad_dims, shape.dtype());
let zeros = self.zeros_tensor(&pad_shape);
self.concat_(vec![x, zeros], last)
}
pub fn split_spectrum(&mut self, spectrum: NodeId) -> (NodeId, NodeId) {
let shape = self.shape(spectrum).clone();
let meta = crate::fft::fft_meta(&shape);
let last = shape.rank() - 1;
let n = meta.n_complex;
let re = self.narrow_(spectrum, last, 0, n);
let im = self.narrow_(spectrum, last, n, n);
(re, im)
}
pub fn fft_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
assert_eq!(
self.shape(x).dtype(),
DType::F32,
"fft_real: requires F32 real input"
);
let padded = self.pad_last_axis_to_pow2(x);
let shape = self.shape(padded).clone();
let rank = shape.rank();
let last = rank - 1;
let n = shape.dim(last).unwrap_static();
let mut im_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
im_dims[last] = n;
let im_shape = Shape::new(&im_dims, DType::F32);
let zero_im = self.zeros_tensor(&im_shape);
let block = self.concat_(vec![padded, zero_im], last);
let spectrum = self.fft_norm(block, false, norm);
self.split_spectrum(spectrum)
}
pub fn fft_batch_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
self.fft_real(x, norm)
}
pub fn rfft(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
let (re, im) = self.fft_real(x, norm);
let rank = self.shape(re).rank();
let last = rank - 1;
let n = self.shape(re).dim(last).unwrap_static();
let half = n / 2 + 1;
(
self.narrow_(re, last, 0, half),
self.narrow_(im, last, 0, half),
)
}
pub fn irfft(&mut self, re_half: NodeId, im_half: NodeId, n: usize, norm: FftNorm) -> NodeId {
assert_eq!(
*self.shape(re_half),
*self.shape(im_half),
"irfft: re/im shape mismatch"
);
let n_pad = crate::fft::next_pow2(n);
let half = n_pad / 2 + 1;
let rank = self.shape(re_half).rank();
let last = rank - 1;
assert_eq!(
self.shape(re_half).dim(last).unwrap_static(),
half,
"irfft: expected half-spectrum length {half}, got {}",
self.shape(re_half).dim(last).unwrap_static()
);
let (re_full, im_full) = if half > 2 {
let mirror_len = half - 2;
let mirror_re = self.narrow_(re_half, last, 1, mirror_len);
let mirror_im = self.narrow_(im_half, last, 1, mirror_len);
let mirror_re_rev = self.reverse_last_axis(mirror_re);
let mirror_im_rev = self.reverse_last_axis(mirror_im);
let neg = self.scalar_f32(-1.0);
let mirror_im_neg = self.mul(mirror_im_rev, neg);
(
self.concat_(vec![re_half, mirror_re_rev], last),
self.concat_(vec![im_half, mirror_im_neg], last),
)
} else {
(re_half, im_half)
};
let recovered = self.ifft_spectrum(re_full, im_full, norm);
self.narrow_(recovered, last, 0, n)
}
pub fn stft(&mut self, x: NodeId, frame_len: usize, hop: usize, norm: FftNorm) -> NodeId {
assert!(
frame_len > 0 && hop > 0,
"stft: frame_len and hop must be positive"
);
let shape = self.shape(x).clone();
let rank = shape.rank();
let last = rank - 1;
let t = shape.dim(last).unwrap_static();
assert!(
t >= frame_len,
"stft: signal length {t} < frame_len {frame_len}"
);
let n_frames = 1 + (t - frame_len) / hop;
let mut frames = Vec::with_capacity(n_frames);
for f in 0..n_frames {
let start = f * hop;
let frame = self.narrow_(x, last, start, frame_len);
let (re, im) = self.rfft(frame, norm);
let block = self.concat_(vec![re, im], last);
frames.push(block);
}
if frames.len() == 1 {
let f = frames[0];
let mut dims: Vec<i64> = self
.shape(f)
.dims()
.iter()
.map(|d| d.unwrap_static() as i64)
.collect();
dims.insert(0, 1);
return self.reshape_(f, dims);
}
let mut rows = Vec::new();
for f in frames {
let mut dims: Vec<i64> = self
.shape(f)
.dims()
.iter()
.map(|d| d.unwrap_static() as i64)
.collect();
dims.insert(0, 1);
rows.push(self.reshape_(f, dims));
}
self.concat_(rows, 0)
}
pub fn fft_conv1d(&mut self, a: NodeId, b: NodeId, n_fft: usize, norm: FftNorm) -> NodeId {
let n_fft = n_fft.max(crate::fft::next_pow2(
self.shape(a).dim(self.shape(a).rank() - 1).unwrap_static()
+ self.shape(b).dim(self.shape(b).rank() - 1).unwrap_static()
- 1,
));
let pad_a = self.pad_axis_to_len(a, n_fft);
let pad_b = self.pad_axis_to_len(b, n_fft);
let (a_re, a_im) = self.rfft(pad_a, norm);
let (b_re, b_im) = self.rfft(pad_b, norm);
let ar_br = self.mul(a_re, b_re);
let ai_bi = self.mul(a_im, b_im);
let prod_re = self.sub(ar_br, ai_bi);
let ar_bi = self.mul(a_re, b_im);
let ai_br = self.mul(a_im, b_re);
let prod_im = self.add(ar_bi, ai_br);
let out_len = self.shape(a).dim(self.shape(a).rank() - 1).unwrap_static()
+ self.shape(b).dim(self.shape(b).rank() - 1).unwrap_static()
- 1;
self.irfft(prod_re, prod_im, out_len.max(1), norm)
}
pub fn fftfreq_tensor(&mut self, n: usize) -> NodeId {
let xs = crate::fft::fftfreq(n);
let mut bytes = Vec::with_capacity(n * 8);
for x in &xs {
bytes.extend_from_slice(&x.to_le_bytes());
}
self.add_node(
Op::Constant { data: bytes },
vec![],
Shape::new(&[n], DType::F64),
)
}
pub fn rfftfreq_tensor(&mut self, n: usize) -> NodeId {
let xs = crate::fft::rfftfreq(n);
let half = xs.len();
let mut bytes = Vec::with_capacity(half * 8);
for x in &xs {
bytes.extend_from_slice(&x.to_le_bytes());
}
self.add_node(
Op::Constant { data: bytes },
vec![],
Shape::new(&[half], DType::F64),
)
}
pub fn psd_real(&mut self, x: NodeId, norm: FftNorm) -> NodeId {
let (re, im) = self.rfft(x, norm);
self.psd(re, im)
}
pub fn ifft_spectrum(&mut self, re: NodeId, im: NodeId, norm: FftNorm) -> NodeId {
let re_shape = self.shape(re).clone();
assert_eq!(
re_shape,
*self.shape(im),
"ifft_spectrum: re/im shape mismatch"
);
let rank = re_shape.rank();
let last = rank - 1;
let n = re_shape.dim(last).unwrap_static();
let block = self.concat_(vec![re, im], last);
let full = self.fft_norm(block, true, norm);
self.narrow_(full, last, 0, n)
}
pub fn psd(&mut self, re: NodeId, im: NodeId) -> NodeId {
let n = self
.shape(re)
.dim(self.shape(re).rank() - 1)
.unwrap_static();
let re2 = self.mul(re, re);
let im2 = self.mul(im, im);
let power = self.add(re2, im2);
let inv_n = self.scalar_f32(1.0 / n as f32);
self.mul(power, inv_n)
}
fn reverse_last_axis(&mut self, x: NodeId) -> NodeId {
let shape = self.shape(x).clone();
let rank = shape.rank();
let last = rank - 1;
let len = shape.dim(last).unwrap_static();
if len <= 1 {
return x;
}
let prefix_elems: usize = shape
.dims()
.iter()
.take(last)
.map(|d| d.unwrap_static())
.product();
let mut idx_bytes = Vec::with_capacity(prefix_elems * len * 4);
for _ in 0..prefix_elems.max(1) {
for i in (0..len).rev() {
idx_bytes.extend_from_slice(&(i as i32).to_le_bytes());
}
}
let idx_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
let idx = self.add_node(
Op::Constant { data: idx_bytes },
vec![],
Shape::new(&idx_dims, DType::I32),
);
self.gather_(x, idx, last)
}
fn pad_axis_to_len(&mut self, x: NodeId, len: usize) -> NodeId {
let shape = self.shape(x).clone();
let last = shape.rank() - 1;
let n = shape.dim(last).unwrap_static();
if n >= len {
return self.narrow_(x, last, 0, len);
}
let pad_len = len - n;
let mut pad_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
pad_dims[last] = pad_len;
let zeros = self.zeros_tensor(&Shape::new(&pad_dims, shape.dtype()));
self.concat_(vec![x, zeros], last)
}
fn zeros_tensor(&mut self, shape: &Shape) -> NodeId {
let n = shape.num_elements().unwrap();
let bytes = vec![0u8; n * shape.dtype().size_bytes()];
self.add_node(Op::Constant { data: bytes }, vec![], shape.clone())
}
fn scalar_f32(&mut self, v: f32) -> NodeId {
self.add_node(
Op::Constant {
data: v.to_le_bytes().to_vec(),
},
vec![],
Shape::scalar(DType::F32),
)
}
}