use crate::kernels::*;
use cudarc::cufft::{sys as cufft_sys, CudaFft};
use cudarc::driver::{
CudaContext, CudaFunction, CudaSlice, CudaStream, LaunchConfig, PushKernelArg,
};
/// Framing parameters and NVRTC target used by [`GpuDsp`].
#[derive(Debug, Clone)]
pub struct DspConfig {
    /// Number of samples per frame; also used as the FFT length.
    pub frame_size: usize,
    /// Number of samples between the starts of consecutive frames.
    pub hop: usize,
    /// Architecture string passed to NVRTC's compile options (e.g. `"sm_86"`).
    pub gpu_arch: &'static str,
}
impl Default for DspConfig {
fn default() -> Self {
Self {
frame_size: 2048,
hop: 512,
gpu_arch: "sm_86",
}
}
}
/// Handle bundling a CUDA stream, the framing configuration and the compiled
/// kernel entry points used by the DSP methods on this type.
pub struct GpuDsp {
/// Stream on which every allocation, copy, FFT and kernel launch is issued.
pub stream: std::sync::Arc<CudaStream>,
/// Framing parameters (`frame_size`, `hop`) and the NVRTC target.
pub config: DspConfig,
/// Entry point loaded under the name `window_frames`.
pub window_func: CudaFunction,
/// Entry point loaded under the name `magnitude`.
pub magnitude_func: CudaFunction,
/// Entry point loaded under the name `median_filter`.
pub median_func: CudaFunction,
/// Entry point loaded under the name `soft_mask`.
pub soft_mask_func: CudaFunction,
/// Entry point loaded under the name `overlap_add`.
pub overlap_add_func: CudaFunction,
}
impl GpuDsp {
/// Creates a CUDA context on device 0, compiles the five DSP kernels with
/// NVRTC for `config.gpu_arch`, and resolves their entry points.
///
/// Returns `None` after printing a `[moe-gpu-dsp]` diagnostic to stderr when
/// the configuration is unusable (`frame_size` or `hop` is zero), or when
/// CUDA initialisation, NVRTC compilation, module loading, or any kernel
/// lookup fails. The "ready" line is only printed once every entry point has
/// been resolved.
pub fn new(config: DspConfig) -> Option<Self> {
    let t = std::time::Instant::now();

    // A zero hop makes frame counting divide by zero, and a zero frame size
    // leaves nothing to transform; reject both before touching the GPU.
    if config.frame_size == 0 || config.hop == 0 {
        eprintln!(
            "[moe-gpu-dsp] invalid config: frame_size={} hop={} (both must be non-zero)",
            config.frame_size, config.hop
        );
        return None;
    }

    let ctx = match CudaContext::new(0) {
        Ok(c) => c,
        Err(e) => {
            eprintln!("[moe-gpu-dsp] CUDA init failed: {e}");
            return None;
        }
    };
    let stream = ctx.default_stream();

    let opts = cudarc::nvrtc::CompileOptions {
        arch: Some(config.gpu_arch),
        ..Default::default()
    };
    // All kernels are concatenated into one translation unit so a single
    // module exposes every entry point.
    let src = format!(
        "{}\n{}\n{}\n{}\n{}",
        KERNEL_WINDOW_FRAMES,
        KERNEL_MAGNITUDE,
        KERNEL_MEDIAN_FILTER,
        KERNEL_SOFT_MASK,
        KERNEL_OVERLAP_ADD
    );
    let ptx = match cudarc::nvrtc::compile_ptx_with_opts(&src, opts) {
        Ok(p) => p,
        Err(e) => {
            eprintln!("[moe-gpu-dsp] NVRTC failed: {e}");
            return None;
        }
    };
    let module = match ctx.load_module(ptx) {
        Ok(m) => m,
        Err(e) => {
            eprintln!("[moe-gpu-dsp] module load failed: {e}");
            return None;
        }
    };

    // Looks up one entry point, logging the kernel name on failure.
    let load = |name: &str| -> Option<CudaFunction> {
        match module.load_function(name) {
            Ok(func) => Some(func),
            Err(e) => {
                eprintln!("[moe-gpu-dsp] {name}: {e}");
                None
            }
        }
    };

    // Resolve every kernel before announcing readiness, so a missing entry
    // point is never preceded by a misleading "ready" line.
    let window_func = load("window_frames")?;
    let magnitude_func = load("magnitude")?;
    let median_func = load("median_filter")?;
    let soft_mask_func = load("soft_mask")?;
    let overlap_add_func = load("overlap_add")?;

    let name = ctx.name().unwrap_or_else(|_| "unknown".into());
    eprintln!(
        "[moe-gpu-dsp] {} ready ({:.2}s)",
        name,
        t.elapsed().as_secs_f64()
    );

    Some(Self {
        stream,
        config,
        window_func,
        magnitude_func,
        median_func,
        soft_mask_func,
        overlap_add_func,
    })
}
/// Builds a 1-D launch configuration with 256 threads per block and enough
/// blocks to cover `total` threads (rounded up).
///
/// The block count is computed in `usize`, so totals at or above `u32::MAX`
/// neither overflow nor silently truncate. A count that does not fit in
/// `u32` saturates at `u32::MAX`; such a grid exceeds CUDA's grid-size limit,
/// so the launch is expected to report an error rather than run with too few
/// blocks. `total == 0` yields a grid of zero blocks.
pub fn launch_cfg(total: usize) -> LaunchConfig {
    const BLOCK_SIZE: u32 = 256;
    let per_block = BLOCK_SIZE as usize;
    // Ceiling division without the `total + per_block - 1` overflow hazard.
    let blocks = total / per_block + usize::from(total % per_block != 0);
    // Saturating narrowing: the value is clamped first, so the cast is lossless.
    let grid_x = blocks.min(u32::MAX as usize) as u32;
    LaunchConfig {
        grid_dim: (grid_x, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// Number of complex bins per frame: `frame_size / 2 + 1`.
pub fn n_bins(&self) -> usize {
    let half_frame = self.config.frame_size / 2;
    half_frame + 1
}
/// Number of complete frames that fit in a signal of `len` samples.
///
/// Returns `(len - frame_size) / hop + 1` when the signal holds at least one
/// full frame, and `0` when it is shorter than `frame_size`. A zero `hop`
/// (possible because `config` is a public field) also yields `0` instead of
/// panicking with a division by zero, since no framing exists for it.
pub fn n_frames(&self, len: usize) -> usize {
    let frame_size = self.config.frame_size;
    let hop = self.config.hop;
    if hop == 0 {
        return 0;
    }
    // `checked_sub` is `None` exactly when the signal is shorter than a frame.
    len.checked_sub(frame_size)
        .map_or(0, |remaining| remaining / hop + 1)
}
/// Uploads `signal` to the device and launches the `window_frames` kernel to
/// fill a `n_frames * frame_size` buffer.
///
/// Returns `(device copy of the signal, framed buffer, frame count)`.
/// Scalar kernel arguments are narrowed to `i32` with `as`.
///
/// # Panics
/// Panics if an allocation, the host-to-device copy, or the launch fails.
pub fn window_frames(&self, signal: &[f32]) -> (CudaSlice<f32>, CudaSlice<f32>, usize) {
    let sample_count = signal.len();
    let frame_count = self.n_frames(sample_count);
    let frame_size = self.config.frame_size;
    let framed_len = frame_count * frame_size;

    let mut device_signal = self.stream.alloc_zeros::<f32>(sample_count).unwrap();
    self.stream.memcpy_htod(signal, &mut device_signal).unwrap();
    let mut frames = self.stream.alloc_zeros::<f32>(framed_len).unwrap();

    let sample_count_arg = sample_count as i32;
    let frame_size_arg = frame_size as i32;
    let hop_arg = self.config.hop as i32;
    let frame_count_arg = frame_count as i32;
    let cfg = Self::launch_cfg(framed_len);

    // SAFETY: both buffers are live allocations made on `self.stream` above.
    // The argument list is assumed to match the `window_frames` kernel
    // signature, which is defined outside this file — TODO confirm.
    unsafe {
        self.stream
            .launch_builder(&self.window_func)
            .arg(&device_signal)
            .arg(&mut frames)
            .arg(&sample_count_arg)
            .arg(&frame_size_arg)
            .arg(&hop_arg)
            .arg(&frame_count_arg)
            .launch(cfg)
            .unwrap();
    }

    (device_signal, frames, frame_count)
}
/// Runs a batched real-to-complex FFT of length `frame_size` over `n_frames`
/// frames stored in `windowed`.
///
/// Returns a newly allocated buffer of `n_frames * n_bins()` complex values.
/// A fresh cuFFT plan is created on every call.
///
/// # Panics
/// Panics if plan creation, the output allocation, or the transform fails.
pub fn batch_fft_r2c(
    &self,
    windowed: &CudaSlice<f32>,
    n_frames: usize,
) -> CudaSlice<cufft_sys::float2> {
    let bins = self.n_bins();
    let transform_len = self.config.frame_size as i32;
    let batch = n_frames as i32;

    let plan = CudaFft::plan_1d(
        transform_len,
        cufft_sys::cufftType::CUFFT_R2C,
        batch,
        std::sync::Arc::clone(&self.stream),
    )
    .expect("[moe-gpu-dsp] R2C plan failed");

    let mut spectrum = self
        .stream
        .alloc_zeros::<cufft_sys::float2>(n_frames * bins)
        .unwrap();
    plan.exec_r2c(windowed, &mut spectrum)
        .expect("[moe-gpu-dsp] R2C failed");
    spectrum
}
/// Launches the `magnitude` kernel over a complex spectrum, producing one
/// `f32` per bin (`n_bins() * n_frames` values in total).
///
/// The complex buffer is reinterpreted as `2 * n_bins() * n_frames` floats
/// before being handed to the kernel.
///
/// # Panics
/// Panics if the reinterpretation is rejected, or if the allocation or the
/// launch fails.
pub fn magnitude(
    &self,
    complex: &CudaSlice<cufft_sys::float2>,
    n_frames: usize,
) -> CudaSlice<f32> {
    let bins = self.n_bins();
    let bin_count = bins * n_frames;

    let spectrum_floats = Self::complex_as_floats(complex, bin_count);
    let mut magnitudes = self.stream.alloc_zeros::<f32>(bin_count).unwrap();

    let bins_arg = bins as i32;
    let frames_arg = n_frames as i32;
    let cfg = Self::launch_cfg(bin_count);

    // SAFETY: `magnitudes` is a live allocation made above and
    // `spectrum_floats` is a checked view of `complex`. The argument list is
    // assumed to match the `magnitude` kernel signature, which is defined
    // outside this file — TODO confirm.
    unsafe {
        self.stream
            .launch_builder(&self.magnitude_func)
            .arg(&spectrum_floats)
            .arg(&mut magnitudes)
            .arg(&bins_arg)
            .arg(&frames_arg)
            .launch(cfg)
            .unwrap();
    }
    magnitudes
}
/// Reinterprets `count` complex values as a view of `2 * count` floats
/// without copying.
///
/// # Panics
/// Panics if `transmute` rejects the requested length.
pub fn complex_as_floats(
    complex: &CudaSlice<cufft_sys::float2>,
    count: usize,
) -> cudarc::driver::CudaView<'_, f32> {
    let float_len = count * 2;
    // SAFETY: presumably sound because `float2` is laid out as two
    // consecutive `f32`s — confirm against the cuFFT bindings.
    let view = unsafe { complex.transmute::<f32>(float_len) };
    view.unwrap()
}
/// Launches the `median_filter` kernel over an `n_rows * n_cols` buffer and
/// returns a newly allocated output of the same element count.
///
/// `horizontal` is passed to the kernel as an `i32` flag (1 or 0);
/// `kernel_size` is forwarded unchanged.
///
/// NOTE(review): unlike the other launches in this type, the scalar
/// arguments precede the buffer pointers here — confirm this matches the
/// kernel's parameter order.
///
/// # Panics
/// Panics if the allocation or the launch fails.
pub fn median_filter(
    &self,
    input: &CudaSlice<f32>,
    n_rows: usize,
    n_cols: usize,
    kernel_size: i32,
    horizontal: bool,
) -> CudaSlice<f32> {
    let element_count = n_rows * n_cols;
    let mut filtered = self.stream.alloc_zeros::<f32>(element_count).unwrap();

    let direction_flag = i32::from(horizontal);
    let rows_arg = n_rows as i32;
    let cols_arg = n_cols as i32;
    let cfg = Self::launch_cfg(element_count);

    // SAFETY: `filtered` is a live allocation made above; `input` is assumed
    // to hold at least `n_rows * n_cols` elements and the argument list is
    // assumed to match the kernel signature (defined outside this file).
    unsafe {
        self.stream
            .launch_builder(&self.median_func)
            .arg(&rows_arg)
            .arg(&cols_arg)
            .arg(&kernel_size)
            .arg(&direction_flag)
            .arg(input)
            .arg(&mut filtered)
            .launch(cfg)
            .unwrap();
    }
    filtered
}
/// Launches the `soft_mask` kernel with two magnitude buffers and a
/// float-viewed complex spectrum.
///
/// Returns a newly allocated buffer of `2 * n_bins * n_frames` floats
/// (presumably interleaved real/imaginary pairs — confirm against the
/// kernel). One thread is requested per bin.
///
/// # Panics
/// Panics if the allocation or the launch fails.
pub fn soft_mask(
    &self,
    a_mag: &CudaSlice<f32>,
    b_mag: &CudaSlice<f32>,
    complex_floats: &CudaSlice<f32>,
    n_bins: usize,
    n_frames: usize,
) -> CudaSlice<f32> {
    let bin_count = n_bins * n_frames;
    let mut masked = self.stream.alloc_zeros::<f32>(bin_count * 2).unwrap();

    let bins_arg = n_bins as i32;
    let frames_arg = n_frames as i32;
    let cfg = Self::launch_cfg(bin_count);

    // SAFETY: `masked` is a live allocation made above; the caller-supplied
    // buffers are assumed to be large enough for `n_bins * n_frames` bins and
    // the argument list is assumed to match the kernel signature (defined
    // outside this file).
    unsafe {
        self.stream
            .launch_builder(&self.soft_mask_func)
            .arg(a_mag)
            .arg(b_mag)
            .arg(complex_floats)
            .arg(&mut masked)
            .arg(&bins_arg)
            .arg(&frames_arg)
            .launch(cfg)
            .unwrap();
    }
    masked
}
/// Runs a batched complex-to-real inverse FFT on `masked_floats`, combines
/// the resulting frames with the `overlap_add` kernel, and returns the
/// result as a host vector of `output_len` samples.
///
/// `masked_floats` is reinterpreted in place as `n_frames * n_bins()`
/// complex values. After the device-to-host copy every sample is multiplied
/// by `1 / frame_size` (cuFFT transforms are unnormalised).
///
/// # Panics
/// Panics if plan creation, the reinterpretation, an allocation, the
/// transform, the launch, or the device-to-host copy fails.
pub fn batch_ifft_c2r_ola(
    &self,
    masked_floats: &mut CudaSlice<f32>,
    n_frames: usize,
    output_len: usize,
) -> Vec<f32> {
    let bins = self.n_bins();
    let frame_size = self.config.frame_size;
    let hop = self.config.hop;
    let framed_len = n_frames * frame_size;

    let plan = CudaFft::plan_1d(
        frame_size as i32,
        cufft_sys::cufftType::CUFFT_C2R,
        n_frames as i32,
        std::sync::Arc::clone(&self.stream),
    )
    .expect("[moe-gpu-dsp] C2R plan failed");

    // SAFETY: presumably sound because `float2` is two consecutive `f32`s;
    // the length is checked by `transmute_mut`, which returns `None` on
    // mismatch.
    let mut spectrum = unsafe {
        masked_floats
            .transmute_mut::<cufft_sys::float2>(n_frames * bins)
            .unwrap()
    };
    let mut frames = self.stream.alloc_zeros::<f32>(framed_len).unwrap();
    plan.exec_c2r(&mut spectrum, &mut frames)
        .expect("[moe-gpu-dsp] C2R failed");

    let mut accumulated = self.stream.alloc_zeros::<f32>(output_len).unwrap();
    let frame_size_arg = frame_size as i32;
    let hop_arg = hop as i32;
    let frames_arg = n_frames as i32;
    let output_len_arg = output_len as i32;
    let cfg = Self::launch_cfg(framed_len);

    // SAFETY: both buffers are live allocations made above. The argument
    // list is assumed to match the `overlap_add` kernel signature, which is
    // defined outside this file — TODO confirm.
    unsafe {
        self.stream
            .launch_builder(&self.overlap_add_func)
            .arg(&frames)
            .arg(&mut accumulated)
            .arg(&frame_size_arg)
            .arg(&hop_arg)
            .arg(&frames_arg)
            .arg(&output_len_arg)
            .launch(cfg)
            .unwrap();
    }

    let mut samples = self.stream.clone_dtoh(&accumulated).unwrap();
    let scale = 1.0 / frame_size as f32;
    samples.iter_mut().for_each(|sample| *sample *= scale);
    samples
}
}