/// CUDA C source for the `window_frames` kernel.
///
/// The kernel slices `signal` into `n_frames` frames of `frame_size` samples, with
/// frame starts `hop` samples apart. It multiplies each frame by a symmetric Hann window.
///
/// * `output` — `n_frames * frame_size` floats, frame-major (`frame * frame_size + sample`).
/// * Source samples at or beyond `signal_len` are read as 0 (zero padding).
/// * A 1-sample frame uses window value 1.0, because the symmetric formula would be 0/0 there.
///
/// Launch: 1D grid, one thread per output element. It needs at least
/// `n_frames * frame_size` threads in total; extra threads exit early.
pub const KERNEL_WINDOW_FRAMES: &str = r#"
extern "C" __global__ void window_frames(
    const float* __restrict__ signal,
    float* __restrict__ output,
    const int signal_len,
    const int frame_size,
    const int hop,
    const int n_frames
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = n_frames * frame_size;
    if (idx >= total) return;
    int frame = idx / frame_size;
    int sample = idx % frame_size;
    int src = frame * hop + sample;
    // Zero-pad past the end of the signal.
    float val = (src < signal_len) ? signal[src] : 0.0f;
    // Symmetric Hann window. For a 1-sample frame the formula is 0/0 (NaN),
    // so use 1.0f there (matching numpy.hanning(1)).
    float w = (frame_size > 1)
        ? 0.5f * (1.0f - cosf(2.0f * 3.14159265359f * (float)sample / (float)(frame_size - 1)))
        : 1.0f;
    output[idx] = val * w;
}
"#;
/// CUDA C source for the `magnitude` kernel.
///
/// The kernel computes `sqrt(re^2 + im^2)` for every complex element.
///
/// * `complex_data` — interleaved `(re, im)` float pairs, frame-major. Element
///   `(frame, bin)` starts at float index `2 * (frame * n_bins + bin)`. Presumably
///   this is the output of a per-frame FFT; TODO confirm against the caller.
/// * `mag_out` — `n_bins * n_frames` floats written bin-major (`bin * n_frames + frame`),
///   so the output is transposed relative to the input layout.
///
/// Launch: 1D grid, one thread per element. It needs at least `n_bins * n_frames`
/// threads in total; extra threads exit early.
pub const KERNEL_MAGNITUDE: &str = r#"
extern "C" __global__ void magnitude(
const float* __restrict__ complex_data,
float* __restrict__ mag_out,
const int n_bins,
const int n_frames
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = n_bins * n_frames;
if (idx >= total) return;
int frame = idx / n_bins;
int bin = idx % n_bins;
int src = (frame * n_bins + bin) * 2;
float re = complex_data[src];
float im = complex_data[src + 1];
int dst = bin * n_frames + frame;
mag_out[dst] = sqrtf(re * re + im * im);
}
"#;
pub const KERNEL_MEDIAN_FILTER: &str = r#"
extern "C" __global__ void median_filter(
const int n_rows,
const int n_cols,
const int kernel_size,
const int horizontal,
const float* __restrict__ input,
float* __restrict__ output
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = n_rows * n_cols;
if (idx >= total) return;
int row = idx / n_cols;
int col = idx % n_cols;
int half = kernel_size / 2;
float window[31];
int count = 0;
int center, axis_len;
if (horizontal) {
center = col;
axis_len = n_cols;
} else {
center = row;
axis_len = n_rows;
}
int start = (center >= half) ? center - half : 0;
int end = (center + half + 1 < axis_len) ? center + half + 1 : axis_len;
for (int i = start; i < end; i++) {
int src_idx;
if (horizontal) {
src_idx = row * n_cols + i;
} else {
src_idx = i * n_cols + col;
}
float val = input[src_idx];
int j = count;
while (j > 0 && window[j - 1] > val) {
window[j] = window[j - 1];
j--;
}
window[j] = val;
count++;
}
output[idx] = window[count / 2];
}
"#;
/// CUDA C source for the `soft_mask` kernel.
///
/// The kernel scales both the real and imaginary parts of each complex element of
/// `complex_in` by `mask = a^2 / (a^2 + b^2 + 1e-10)` and writes the result to `complex_out`.
///
/// * `a_mag`, `b_mag` — `n_bins * n_frames` magnitudes, bin-major (`bin * n_frames + frame`).
/// * `complex_in`, `complex_out` — interleaved `(re, im)` float pairs, frame-major.
///   Element `(frame, bin)` starts at float index `2 * (frame * n_bins + bin)`.
/// * The `1e-10` term keeps the mask finite: it is 0 when both magnitudes are 0.
///   Presumably `a`/`b` are two competing filtered spectrograms, e.g. from a
///   harmonic/percussive split; TODO confirm against the caller.
///
/// Launch: 1D grid, one thread per element. It needs at least `n_bins * n_frames`
/// threads in total; extra threads exit early.
pub const KERNEL_SOFT_MASK: &str = r#"
extern "C" __global__ void soft_mask(
const float* __restrict__ a_mag,
const float* __restrict__ b_mag,
const float* __restrict__ complex_in,
float* __restrict__ complex_out,
const int n_bins,
const int n_frames
) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total = n_bins * n_frames;
if (idx >= total) return;
int frame = idx / n_bins;
int bin = idx % n_bins;
int mag_idx = bin * n_frames + frame;
float aa = a_mag[mag_idx] * a_mag[mag_idx];
float bb = b_mag[mag_idx] * b_mag[mag_idx];
float mask = aa / (aa + bb + 1e-10f);
int c_idx = (frame * n_bins + bin) * 2;
complex_out[c_idx] = complex_in[c_idx] * mask;
complex_out[c_idx + 1] = complex_in[c_idx + 1] * mask;
}
"#;
/// CUDA C source for the `overlap_add` kernel.
///
/// The kernel multiplies each time-domain frame by a symmetric Hann window and
/// accumulates it into `output` starting at sample `frame * hop`.
///
/// * `frames` — `n_frames * frame_size` floats, frame-major (`frame * frame_size + sample`).
/// * `output` — accumulated with `atomicAdd` and never cleared here. The caller must
///   zero it (e.g. with `cudaMemset`) before launch. Contributions at or beyond
///   `output_len` are dropped.
/// * A 1-sample frame uses window value 1.0, because the symmetric formula would be 0/0 there.
/// * No window-sum / FFT-scale normalization is applied here. Presumably the caller
///   does it; verify.
///
/// Launch: 1D grid, one thread per frame sample. It needs at least
/// `n_frames * frame_size` threads in total; extra threads exit early.
pub const KERNEL_OVERLAP_ADD: &str = r#"
extern "C" __global__ void overlap_add(
    const float* __restrict__ frames,
    float* __restrict__ output,
    const int frame_size,
    const int hop,
    const int n_frames,
    const int output_len
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total = n_frames * frame_size;
    if (idx >= total) return;
    int frame = idx / frame_size;
    int sample = idx % frame_size;
    int dst = frame * hop + sample;
    if (dst >= output_len) return;
    // Symmetric Hann window. For a 1-sample frame the formula is 0/0 (NaN),
    // so use 1.0f there (matching numpy.hanning(1)).
    float w = (frame_size > 1)
        ? 0.5f * (1.0f - cosf(2.0f * 3.14159265359f * (float)sample / (float)(frame_size - 1)))
        : 1.0f;
    // Overlapping frames hit the same dst, so the accumulate must be atomic.
    atomicAdd(&output[dst], frames[idx] * w);
}
"#;