use metaltile::{bench_kernel, kernel};
use crate::bench_types::DType;
// Anonymous constant that is never read; its only visible effect is to reference
// `DType`. NOTE(review): presumably this keeps the `DType` import from being
// reported as unused (or pins the benchmark dtype for tooling) — confirm.
const _: DType = DType::F32;
/// Fused log-mel spectrogram: one program instance per `(frame, mel_bin)` output.
///
/// The flat program id is decoded as `idx = frame * n_mels + mel_bin`. For that
/// output the kernel evaluates a direct (non-FFT) windowed DFT of the `n_fft`
/// samples starting at `frame * hop_length`, forms the power `re² + im²` for
/// bins `0..n_freq`, weights each bin by `mel_weight[mel_bin * n_freq + k]`,
/// and stores `log(sum + log_eps)` to `out[idx]`.
///
/// # Parameters
/// - `audio`: input samples, read at `frame * hop_length + t` for `t` in `0..n_fft`.
/// - `window`: analysis window, read at `0..n_fft`.
/// - `mel_weight`: filterbank weights, read as `n_freq` consecutive entries per mel bin.
/// - `out`: result buffer, written at `idx`.
/// - `n_fft`: DFT length (samples per frame).
/// - `n_freq`: number of DFT bins accumulated per output.
/// - `n_mels`: number of mel bins per frame.
/// - `hop_length`: sample stride between consecutive frames.
/// - `log_eps`: offset added to the accumulated energy before the logarithm.
///
/// # Notes
/// - No bounds checks are performed; the launch grid is assumed to cover only
///   valid `(frame, mel_bin)` pairs — TODO confirm against the host-side launcher.
/// - The twiddle phase `k * t` is reduced modulo `n_fft` in integer arithmetic
///   before it is converted to `f32`. `exp(-2πi·k·t/N)` is periodic in `k·t`
///   with period `N`, so this is exact, and it keeps the angle in `(-2π, 0]`
///   instead of letting it grow to roughly `2π·n_freq`, where `f32` spacing and
///   `sin`/`cos` accuracy degrade.
/// - The integer product `k * t` must fit in `u32`, i.e. `n_freq * n_fft < 2³²`.
#[bench_kernel(
    op = "mel_spectrogram",
    subop = "mel_spectrogram",
    class = GenericEmpty,
    tol = 1e-3,
    kernel_mode = Grid3D,
)]
#[kernel]
pub fn mel_spectrogram<T>(
    audio: Tensor<T>,
    window: Tensor<T>,
    mel_weight: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] n_fft: u32,
    #[constexpr] n_freq: u32,
    #[constexpr] n_mels: u32,
    #[constexpr] hop_length: u32,
    #[constexpr] log_eps: f32,
) {
    // Decode the flat program id into (frame, mel_bin).
    let idx = program_id::<0>();
    let mel_bin = idx % n_mels;
    let frame = idx / n_mels;
    let frame_start = frame * hop_length;

    // -2π / N: angular step of the DFT per unit of phase.
    let n_fft_f = n_fft.cast::<f32>();
    let neg_two_pi_over_n = -6.283185307179586f32 / n_fft_f;

    // Offset of this mel bin's weights inside `mel_weight`.
    let mel_row = mel_bin * n_freq;

    let mut mel_acc = 0.0f32;
    for k in range(0u32, n_freq, 1u32) {
        // Direct DFT of the windowed frame at frequency bin `k`.
        let mut re = 0.0f32;
        let mut im = 0.0f32;
        for t in range(0u32, n_fft, 1u32) {
            let sample = load(audio[frame_start + t]).cast::<f32>();
            let win = load(window[t]).cast::<f32>();
            let xw = sample * win;
            // Exact integer range reduction: the twiddle factor only depends on
            // (k * t) mod n_fft, so the float angle never leaves one period.
            let phase = (k * t) % n_fft;
            let angle = neg_two_pi_over_n * phase.cast::<f32>();
            re = re + xw * cos(angle);
            im = im + xw * sin(angle);
        }
        let power = re * re + im * im;
        let w = load(mel_weight[mel_row + k]).cast::<f32>();
        mel_acc = mel_acc + w * power;
    }

    // `log_eps` keeps the logarithm finite when the accumulated energy is zero.
    let log_mel = log(mel_acc + log_eps);
    store(out[idx], log_mel.cast::<T>());
}
/// Frames and windows the input signal: one program instance per `(frame, t)` sample.
///
/// The flat program id is decoded as `idx = frame * n_fft + t`. The kernel reads
/// `audio[frame * hop_length + t]`, multiplies it by `window[t]` in `f32`, and
/// writes the product to `out_re[idx]`; `out_im[idx]` is set to zero.
///
/// # Parameters
/// - `audio`: input samples.
/// - `window`: analysis window, read at `0..n_fft`.
/// - `out_re`: receives the windowed sample at `idx`.
/// - `out_im`: receives `0` at `idx`.
/// - `n_fft`: number of samples per frame.
/// - `hop_length`: sample stride between consecutive frames.
///
/// NOTE(review): the real/imaginary output pair presumably feeds a separate
/// complex FFT stage — confirm against the caller. No bounds checks are done
/// here; the launch grid is assumed to match the output size.
#[bench_kernel(
    op = "mel_spectrogram",
    subop = "stft_window",
    class = GenericEmpty,
    tol = 1e-3,
    kernel_mode = Grid3D,
)]
#[kernel]
pub fn mel_stft_window<T>(
    audio: Tensor<T>,
    window: Tensor<T>,
    mut out_re: Tensor<T>,
    mut out_im: Tensor<T>,
    #[constexpr] n_fft: u32,
    #[constexpr] hop_length: u32,
) {
    // Split the flat id into the frame number and the position inside the frame.
    let gid = program_id::<0>();
    let frame_idx = gid / n_fft;
    let offset = gid % n_fft;

    // Consecutive frames start `hop_length` samples apart in `audio`.
    let src = frame_idx * hop_length + offset;
    let x = load(audio[src]).cast::<f32>();
    let taper = load(window[offset]).cast::<f32>();

    // Real part carries the windowed sample; imaginary part is cleared.
    store(out_re[gid], (x * taper).cast::<T>());
    store(out_im[gid], 0.0f32.cast::<T>());
}
/// Applies the mel filterbank to a precomputed spectrum and takes the log:
/// one program instance per `(frame, mel_bin)` output.
///
/// The flat program id is decoded as `idx = frame * n_mels + mel_bin`. For each
/// bin `k` in `0..n_freq` the kernel reads `fft_re` / `fft_im` at
/// `frame * n_fft + k`, forms the power `re² + im²`, weights it by
/// `mel_weight[mel_bin * n_freq + k]`, and finally stores
/// `log(sum + log_eps)` to `out[idx]`.
///
/// # Parameters
/// - `fft_re`, `fft_im`: real and imaginary spectrum parts, `n_fft` entries per frame.
/// - `mel_weight`: filterbank weights, `n_freq` consecutive entries per mel bin.
/// - `out`: result buffer, written at `idx`.
/// - `n_fft`: per-frame stride of the spectrum buffers.
/// - `n_freq`: number of bins accumulated per output.
/// - `n_mels`: number of mel bins per frame.
/// - `log_eps`: offset added to the accumulated energy before the logarithm.
///
/// No bounds checks are performed; the launch grid is assumed to cover only
/// valid outputs — TODO confirm against the host-side launcher.
#[bench_kernel(
    op = "mel_spectrogram",
    subop = "filterbank",
    class = GenericEmpty,
    tol = 1e-3,
    kernel_mode = Grid3D,
)]
#[kernel]
pub fn mel_filterbank<T>(
    fft_re: Tensor<T>,
    fft_im: Tensor<T>,
    mel_weight: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] n_fft: u32,
    #[constexpr] n_freq: u32,
    #[constexpr] n_mels: u32,
    #[constexpr] log_eps: f32,
) {
    // Split the flat id into the frame number and the mel filter index.
    let gid = program_id::<0>();
    let frame_idx = gid / n_mels;
    let filter_idx = gid % n_mels;

    // Start offsets of this frame's spectrum and of this filter's weights.
    let spectrum_base = frame_idx * n_fft;
    let weight_base = filter_idx * n_freq;

    let mut energy = 0.0f32;
    for bin in range(0u32, n_freq, 1u32) {
        let weight = load(mel_weight[weight_base + bin]).cast::<f32>();
        let src = spectrum_base + bin;
        let real = load(fft_re[src]).cast::<f32>();
        let imag = load(fft_im[src]).cast::<f32>();
        let magnitude_sq = real * real + imag * imag;
        energy = energy + weight * magnitude_sq;
    }

    // `log_eps` keeps the logarithm finite when the accumulated energy is zero.
    store(out[gid], log(energy + log_eps).cast::<T>());
}