use metaltile::{bench_kernel, kernel};
// Inverse STFT via windowed overlap-add with window-square normalisation; one output sample per
// program (`t = program_id::<0>()`).
//
// For output sample `t`, every frame `f` whose window covers `t` (0 <= t - f*hop_length < n_fft)
// contributes `x_f[tau] * window[tau]` with `tau = t - f*hop_length`, where `x_f[tau]` is the
// inverse real DFT of frame `f` rebuilt from its one-sided bins `0..n_freq`. The stored value is
// `sum_f x_f[tau]*w[tau] / sum_f w[tau]^2`, or 0 where that window-square sum is <= 1e-8.
//
// Indexing below reads `spec_re` / `spec_im` as row-major `[n_frames, n_freq]`, `window` at
// `tau in 0..n_fft`, and writes `out[t]`.
// NOTE(review): assumes `n_freq == n_fft / 2 + 1` (one-sided spectrum), `n_frames >= 1`, and that
// the launch grid covers exactly the output length `(n_frames - 1) * hop_length + n_fft` — there
// is no bounds guard on `t`; confirm against the host-side launcher.
#[bench_kernel(
op="vocoder",
subop="istft",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Grid3D,
)]
#[kernel]
pub fn vocoder_istft<T>(
    spec_re: Tensor<T>,
    spec_im: Tensor<T>,
    window: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] n_frames: u32,
    #[constexpr] n_fft: u32,
    #[constexpr] n_freq: u32,
    #[constexpr] hop_length: u32,
) {
    let t = program_id::<0>();

    let n_fft_f = n_fft.cast::<f32>();
    let inv_n = 1.0f32 / n_fft_f;
    // 2*pi / n_fft: angle of one unit of (tau * k) mod n_fft.
    let two_pi_over_n = 6.283185307179586f32 / n_fft_f;

    // Last frame starting at or before t: floor(t / hop), clamped to the final frame.
    let f_hi_raw = t / hop_length;
    let f_hi = select(f_hi_raw < n_frames, f_hi_raw, n_frames - 1u32);
    // First frame whose window still reaches t: ceil((t + 1 - n_fft) / hop) when t >= n_fft,
    // otherwise frame 0. (The subtraction in the unselected branch may wrap; it is discarded.)
    let has_lo = t + 1u32 > n_fft;
    let lo_num = select(has_lo, t + 1u32 - n_fft, 0u32);
    let f_lo = select(has_lo, (lo_num + hop_length - 1u32) / hop_length, 0u32);

    let mut num = 0.0f32;
    let mut den = 0.0f32;
    for f in range(f_lo, f_hi + 1u32, 1u32) {
        // Offset of t inside frame f; 0 <= tau < n_fft by construction of f_lo / f_hi.
        let tau = t - f * hop_length;
        let row = f * n_freq;

        let mut sample = 0.0f32;
        // Exact (tau * k) mod n_fft, advanced per bin. Keeping the sin/cos argument in
        // [0, 2*pi) avoids the f32 accuracy loss of evaluating them at ~pi*n_fft radians.
        // Since tau < n_fft and phase < n_fft, phase + tau < 2*n_fft: one conditional
        // subtract is enough and the sum cannot overflow u32.
        let mut phase = 0u32;
        for k in range(0u32, n_freq, 1u32) {
            let re = load(spec_re[row + k]).cast::<f32>();
            let im = load(spec_im[row + k]).cast::<f32>();
            let angle = two_pi_over_n * phase.cast::<f32>();
            // Re(X[k] * e^{i*angle}) = re*cos - im*sin.
            let contrib = re * cos(angle) - im * sin(angle);
            // DC, and the Nyquist bin only when n_fft is even, have no conjugate partner in the
            // full spectrum; every other one-sided bin stands for itself plus its mirror.
            // (`k == n_fft / 2` would wrongly halve the last bin when n_fft is odd.)
            let is_unpaired = (k == 0u32) | (k * 2u32 == n_fft);
            let weight_k = select(is_unpaired, 1.0f32, 2.0f32);
            sample = sample + weight_k * contrib;

            let next_phase = phase + tau;
            phase = select(next_phase >= n_fft, next_phase - n_fft, next_phase);
        }
        sample = sample * inv_n;

        let win = load(window[tau]).cast::<f32>();
        num = num + sample * win;
        den = den + win * win;
    }

    // Normalise by the window-square sum; emit 0 where it is (near) zero to avoid dividing by ~0.
    let safe_den = select(den > 1e-8f32, den, 1.0f32);
    let out_val = select(den > 1e-8f32, num / safe_den, 0.0f32);
    store(out[t], out_val.cast::<T>());
}