1use crate::infer::GraphExt as _;
24use crate::{DType, Graph, NodeId, Op, Shape, fft::FftNorm};
25
26impl Graph {
27 pub fn pad_last_axis_to_pow2(&mut self, x: NodeId) -> NodeId {
29 let shape = self.shape(x).clone();
30 let rank = shape.rank();
31 let last = rank - 1;
32 let n = shape.dim(last).unwrap_static();
33 let n_pad = crate::fft::next_pow2(n);
34 if n_pad == n {
35 return x;
36 }
37 let pad_len = n_pad - n;
38 let mut pad_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
39 pad_dims[last] = pad_len;
40 let pad_shape = Shape::new(&pad_dims, shape.dtype());
41 let zeros = self.zeros_tensor(&pad_shape);
42 self.concat_(vec![x, zeros], last)
43 }
44
45 pub fn split_spectrum(&mut self, spectrum: NodeId) -> (NodeId, NodeId) {
47 let shape = self.shape(spectrum).clone();
48 let meta = crate::fft::fft_meta(&shape);
49 let last = shape.rank() - 1;
50 let n = meta.n_complex;
51 let re = self.narrow_(spectrum, last, 0, n);
52 let im = self.narrow_(spectrum, last, n, n);
53 (re, im)
54 }
55
56 pub fn fft_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
58 assert_eq!(
59 self.shape(x).dtype(),
60 DType::F32,
61 "fft_real: requires F32 real input"
62 );
63 let padded = self.pad_last_axis_to_pow2(x);
64 let shape = self.shape(padded).clone();
65 let rank = shape.rank();
66 let last = rank - 1;
67 let n = shape.dim(last).unwrap_static();
68 let mut im_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
69 im_dims[last] = n;
70 let im_shape = Shape::new(&im_dims, DType::F32);
71 let zero_im = self.zeros_tensor(&im_shape);
72 let block = self.concat_(vec![padded, zero_im], last);
73 let spectrum = self.fft_norm(block, false, norm);
74 self.split_spectrum(spectrum)
75 }
76
77 pub fn fft_batch_real(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
80 self.fft_real(x, norm)
81 }
82
83 pub fn rfft(&mut self, x: NodeId, norm: FftNorm) -> (NodeId, NodeId) {
88 let (re, im) = self.fft_real(x, norm);
89 let rank = self.shape(re).rank();
90 let last = rank - 1;
91 let n = self.shape(re).dim(last).unwrap_static();
92 let half = n / 2 + 1;
93 (
94 self.narrow_(re, last, 0, half),
95 self.narrow_(im, last, 0, half),
96 )
97 }
98
99 pub fn irfft(&mut self, re_half: NodeId, im_half: NodeId, n: usize, norm: FftNorm) -> NodeId {
104 assert_eq!(
105 *self.shape(re_half),
106 *self.shape(im_half),
107 "irfft: re/im shape mismatch"
108 );
109 let n_pad = crate::fft::next_pow2(n);
110 let half = n_pad / 2 + 1;
111 let rank = self.shape(re_half).rank();
112 let last = rank - 1;
113 assert_eq!(
114 self.shape(re_half).dim(last).unwrap_static(),
115 half,
116 "irfft: expected half-spectrum length {half}, got {}",
117 self.shape(re_half).dim(last).unwrap_static()
118 );
119 let (re_full, im_full) = if half > 2 {
120 let mirror_len = half - 2;
121 let mirror_re = self.narrow_(re_half, last, 1, mirror_len);
122 let mirror_im = self.narrow_(im_half, last, 1, mirror_len);
123 let mirror_re_rev = self.reverse_last_axis(mirror_re);
124 let mirror_im_rev = self.reverse_last_axis(mirror_im);
125 let neg = self.scalar_f32(-1.0);
126 let mirror_im_neg = self.mul(mirror_im_rev, neg);
127 (
128 self.concat_(vec![re_half, mirror_re_rev], last),
129 self.concat_(vec![im_half, mirror_im_neg], last),
130 )
131 } else {
132 (re_half, im_half)
133 };
134 let recovered = self.ifft_spectrum(re_full, im_full, norm);
135 self.narrow_(recovered, last, 0, n)
136 }
137
138 pub fn stft(&mut self, x: NodeId, frame_len: usize, hop: usize, norm: FftNorm) -> NodeId {
142 assert!(
143 frame_len > 0 && hop > 0,
144 "stft: frame_len and hop must be positive"
145 );
146 let shape = self.shape(x).clone();
147 let rank = shape.rank();
148 let last = rank - 1;
149 let t = shape.dim(last).unwrap_static();
150 assert!(
151 t >= frame_len,
152 "stft: signal length {t} < frame_len {frame_len}"
153 );
154 let n_frames = 1 + (t - frame_len) / hop;
155 let mut frames = Vec::with_capacity(n_frames);
156 for f in 0..n_frames {
157 let start = f * hop;
158 let frame = self.narrow_(x, last, start, frame_len);
159 let (re, im) = self.rfft(frame, norm);
160 let block = self.concat_(vec![re, im], last);
161 frames.push(block);
162 }
163 if frames.len() == 1 {
164 let f = frames[0];
165 let mut dims: Vec<i64> = self
166 .shape(f)
167 .dims()
168 .iter()
169 .map(|d| d.unwrap_static() as i64)
170 .collect();
171 dims.insert(0, 1);
172 return self.reshape_(f, dims);
173 }
174 let mut rows = Vec::new();
175 for f in frames {
176 let mut dims: Vec<i64> = self
177 .shape(f)
178 .dims()
179 .iter()
180 .map(|d| d.unwrap_static() as i64)
181 .collect();
182 dims.insert(0, 1);
183 rows.push(self.reshape_(f, dims));
184 }
185 self.concat_(rows, 0)
186 }
187
188 pub fn fft_conv1d(&mut self, a: NodeId, b: NodeId, n_fft: usize, norm: FftNorm) -> NodeId {
193 let n_fft = n_fft.max(crate::fft::next_pow2(
194 self.shape(a).dim(self.shape(a).rank() - 1).unwrap_static()
195 + self.shape(b).dim(self.shape(b).rank() - 1).unwrap_static()
196 - 1,
197 ));
198 let pad_a = self.pad_axis_to_len(a, n_fft);
199 let pad_b = self.pad_axis_to_len(b, n_fft);
200 let (a_re, a_im) = self.rfft(pad_a, norm);
201 let (b_re, b_im) = self.rfft(pad_b, norm);
202 let ar_br = self.mul(a_re, b_re);
203 let ai_bi = self.mul(a_im, b_im);
204 let prod_re = self.sub(ar_br, ai_bi);
205 let ar_bi = self.mul(a_re, b_im);
206 let ai_br = self.mul(a_im, b_re);
207 let prod_im = self.add(ar_bi, ai_br);
208 let out_len = self.shape(a).dim(self.shape(a).rank() - 1).unwrap_static()
209 + self.shape(b).dim(self.shape(b).rank() - 1).unwrap_static()
210 - 1;
211 self.irfft(prod_re, prod_im, out_len.max(1), norm)
212 }
213
214 pub fn fftfreq_tensor(&mut self, n: usize) -> NodeId {
216 let xs = crate::fft::fftfreq(n);
217 let mut bytes = Vec::with_capacity(n * 8);
218 for x in &xs {
219 bytes.extend_from_slice(&x.to_le_bytes());
220 }
221 self.add_node(
222 Op::Constant { data: bytes },
223 vec![],
224 Shape::new(&[n], DType::F64),
225 )
226 }
227
228 pub fn rfftfreq_tensor(&mut self, n: usize) -> NodeId {
230 let xs = crate::fft::rfftfreq(n);
231 let half = xs.len();
232 let mut bytes = Vec::with_capacity(half * 8);
233 for x in &xs {
234 bytes.extend_from_slice(&x.to_le_bytes());
235 }
236 self.add_node(
237 Op::Constant { data: bytes },
238 vec![],
239 Shape::new(&[half], DType::F64),
240 )
241 }
242
243 pub fn psd_real(&mut self, x: NodeId, norm: FftNorm) -> NodeId {
245 let (re, im) = self.rfft(x, norm);
246 self.psd(re, im)
247 }
248
249 pub fn ifft_spectrum(&mut self, re: NodeId, im: NodeId, norm: FftNorm) -> NodeId {
251 let re_shape = self.shape(re).clone();
252 assert_eq!(
253 re_shape,
254 *self.shape(im),
255 "ifft_spectrum: re/im shape mismatch"
256 );
257 let rank = re_shape.rank();
258 let last = rank - 1;
259 let n = re_shape.dim(last).unwrap_static();
260 let block = self.concat_(vec![re, im], last);
261 let full = self.fft_norm(block, true, norm);
262 self.narrow_(full, last, 0, n)
263 }
264
265 pub fn psd(&mut self, re: NodeId, im: NodeId) -> NodeId {
267 let n = self
268 .shape(re)
269 .dim(self.shape(re).rank() - 1)
270 .unwrap_static();
271 let re2 = self.mul(re, re);
272 let im2 = self.mul(im, im);
273 let power = self.add(re2, im2);
274 let inv_n = self.scalar_f32(1.0 / n as f32);
275 self.mul(power, inv_n)
276 }
277
278 fn reverse_last_axis(&mut self, x: NodeId) -> NodeId {
279 let shape = self.shape(x).clone();
280 let rank = shape.rank();
281 let last = rank - 1;
282 let len = shape.dim(last).unwrap_static();
283 if len <= 1 {
284 return x;
285 }
286 let prefix_elems: usize = shape
287 .dims()
288 .iter()
289 .take(last)
290 .map(|d| d.unwrap_static())
291 .product();
292 let mut idx_bytes = Vec::with_capacity(prefix_elems * len * 4);
293 for _ in 0..prefix_elems.max(1) {
294 for i in (0..len).rev() {
295 idx_bytes.extend_from_slice(&(i as i32).to_le_bytes());
296 }
297 }
298 let idx_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
299 let idx = self.add_node(
300 Op::Constant { data: idx_bytes },
301 vec![],
302 Shape::new(&idx_dims, DType::I32),
303 );
304 self.gather_(x, idx, last)
305 }
306
307 fn pad_axis_to_len(&mut self, x: NodeId, len: usize) -> NodeId {
308 let shape = self.shape(x).clone();
309 let last = shape.rank() - 1;
310 let n = shape.dim(last).unwrap_static();
311 if n >= len {
312 return self.narrow_(x, last, 0, len);
313 }
314 let pad_len = len - n;
315 let mut pad_dims: Vec<usize> = shape.dims().iter().map(|d| d.unwrap_static()).collect();
316 pad_dims[last] = pad_len;
317 let zeros = self.zeros_tensor(&Shape::new(&pad_dims, shape.dtype()));
318 self.concat_(vec![x, zeros], last)
319 }
320
321 fn zeros_tensor(&mut self, shape: &Shape) -> NodeId {
322 let n = shape.num_elements().unwrap();
323 let bytes = vec![0u8; n * shape.dtype().size_bytes()];
324 self.add_node(Op::Constant { data: bytes }, vec![], shape.clone())
325 }
326
327 fn scalar_f32(&mut self, v: f32) -> NodeId {
328 self.add_node(
329 Op::Constant {
330 data: v.to_le_bytes().to_vec(),
331 },
332 vec![],
333 Shape::scalar(DType::F32),
334 )
335 }
336}