runmat-accelerate 0.4.5

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/Spir-V)
Documentation
use super::*;

#[derive(Clone, Copy)]
pub(crate) enum WindowKind {
    Hann = 0,
    Hamming = 1,
    Blackman = 2,
}

impl WgpuProvider {
    pub(crate) fn window_exec(
        &self,
        kind: WindowKind,
        len: usize,
        periodic: bool,
    ) -> Result<GpuTensorHandle> {
        let (logical_u32, total_u32) = window_lengths(len, periodic)?;
        let shape_vec = vec![len, 1];
        if len == 1 {
            let one = [1.0f64];
            return self.upload(&HostTensorView {
                data: &one,
                shape: &[1, 1],
            });
        }
        let out_buffer = self.create_storage_buffer_checked(len, "runmat-window-out")?;
        if len == 0 {
            return Ok(self.register_existing_buffer(out_buffer, shape_vec, 0));
        }

        let chunk_capacity = (crate::backend::wgpu::config::MAX_DISPATCH_WORKGROUPS as usize)
            * crate::backend::wgpu::config::WORKGROUP_SIZE as usize;
        let mut offset = 0usize;
        while offset < len {
            let chunk_len = (len - offset).min(chunk_capacity).max(1);
            let params = crate::backend::wgpu::params::WindowParams {
                len: logical_u32,
                total: total_u32,
                chunk: chunk_len as u32,
                offset: offset as u32,
                kind: kind as u32,
                _pad0: 0,
                _pad1: 0,
                _pad2: 0,
            };
            let params_buffer = self.uniform_buffer(&params, "runmat-window-params");

            let bind_group = self
                .device_ref()
                .create_bind_group(&wgpu::BindGroupDescriptor {
                    label: Some("runmat-window-bind"),
                    layout: &self.pipelines.window.layout,
                    entries: &[
                        wgpu::BindGroupEntry {
                            binding: 0,
                            resource: out_buffer.as_ref().as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 1,
                            resource: params_buffer.as_entire_binding(),
                        },
                    ],
                });

            let workgroups = crate::backend::wgpu::dispatch::common::dispatch_size(
                chunk_len as u32,
                crate::backend::wgpu::config::WORKGROUP_SIZE,
            );
            crate::backend::wgpu::dispatch::creation::run(
                self.device_ref(),
                self.queue_ref(),
                &self.pipelines.window.pipeline,
                &bind_group,
                workgroups,
                "runmat-window-encoder",
                "runmat-window-pass",
            );
            offset += chunk_len;
        }

        Ok(self.register_existing_buffer(out_buffer, shape_vec, len))
    }
}

fn window_lengths(len: usize, periodic: bool) -> Result<(u32, u32)> {
    if len > u32::MAX as usize || (periodic && len == u32::MAX as usize) {
        return Err(anyhow!("window: length exceeds GPU limits"));
    }
    let logical_u32 = len as u32;
    let total_u32 = if periodic {
        logical_u32 + 1
    } else {
        logical_u32
    };
    Ok((logical_u32, total_u32))
}

#[cfg(test)]
mod tests {
    use super::window_lengths;

    #[test]
    fn periodic_window_rejects_u32_max_len() {
        let err = window_lengths(u32::MAX as usize, true).expect_err("expected overflow guard");
        assert!(err.to_string().contains("length exceeds GPU limits"));
    }

    #[test]
    fn symmetric_window_allows_u32_max_len() {
        let (logical, total) = window_lengths(u32::MAX as usize, false).expect("symmetric max len");
        assert_eq!(logical, u32::MAX);
        assert_eq!(total, u32::MAX);
    }
}