runmat-accelerate 0.5.0

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/Spir-V)
Documentation
use super::*;

impl WgpuProvider {
    pub(super) fn fft_uniform_buffer<T: Pod>(
        &self,
        data: &T,
        label: &'static str,
    ) -> Arc<wgpu::Buffer> {
        let size = std::mem::size_of::<T>() as u64;
        let buffer = self.kernel_resources.uniform_buffer(
            self.device_ref(),
            UniformBufferKey::LenOpParams,
            size,
            label,
        );
        self.queue_ref()
            .write_buffer(buffer.as_ref(), 0, bytes_of(data));
        buffer
    }

    pub(super) fn fft_storage_param_buffer<T: Pod>(
        &self,
        data: &T,
        label: &str,
    ) -> Arc<wgpu::Buffer> {
        let size = std::mem::size_of::<T>() as u64;
        let buffer = Arc::new(self.device_ref().create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: size.max(1),
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        }));
        self.queue_ref()
            .write_buffer(buffer.as_ref(), 0, bytes_of(data));
        buffer
    }

    pub(super) fn fft_debug_enabled() -> bool {
        std::env::var_os("RUNMAT_FFT_DEBUG").is_some()
    }

    pub(super) fn fft_debug_dump_scalar_buffer(
        &self,
        label: &str,
        buffer: &Arc<wgpu::Buffer>,
        scalar_len: usize,
    ) {
        if !Self::fft_debug_enabled() || scalar_len == 0 {
            return;
        }
        let sample_len = scalar_len.min(32);
        let size_bytes = (sample_len as u64).saturating_mul(self.element_size as u64);
        let staging = self.device_ref().create_buffer(&wgpu::BufferDescriptor {
            label: Some("runmat-fft-debug-staging"),
            size: size_bytes.max(1),
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let mut encoder =
            self.device_ref()
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("runmat-fft-debug-copy"),
                });
        encoder.copy_buffer_to_buffer(buffer.as_ref(), 0, &staging, 0, size_bytes);
        self.submit(encoder);
        let Ok(bytes) = self.map_readback_bytes_sync(staging, size_bytes, "fft-debug") else {
            eprintln!("[fft-debug] {label}: readback failed");
            return;
        };
        match self.precision {
            NumericPrecision::F64 => {
                let vals = bytemuck::cast_slice::<u8, f64>(&bytes);
                eprintln!("[fft-debug] {label} f64 {:?}", vals);
            }
            NumericPrecision::F32 => {
                let vals = bytemuck::cast_slice::<u8, f32>(&bytes);
                eprintln!("[fft-debug] {label} f32 {:?}", vals);
            }
        }
    }

    pub(super) fn fft_stage_buffer_pair(
        &self,
        len: usize,
        label_a: &str,
        label_b: &str,
    ) -> (Arc<wgpu::Buffer>, Arc<wgpu::Buffer>) {
        let a = self.create_storage_buffer(len, label_a);
        let mut b = self.create_storage_buffer(len, label_b);
        if Arc::ptr_eq(&a, &b) {
            let size_bytes = (len.max(1) as u64).saturating_mul(self.element_size as u64);
            b = Arc::new(self.device_ref().create_buffer(&wgpu::BufferDescriptor {
                label: Some(label_b),
                size: size_bytes,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            }));
        }
        if Self::fft_debug_enabled() {
            eprintln!(
                "[fft-debug] stage buffers a={:p} b={:p}",
                Arc::as_ptr(&a),
                Arc::as_ptr(&b)
            );
        }
        (a, b)
    }

    pub(super) fn fft_twiddle_buffer(
        &self,
        len: usize,
        half_only: bool,
        label: &str,
    ) -> Result<Arc<wgpu::Buffer>> {
        let mode = if half_only { 1u8 } else { 0u8 };
        if let Ok(cache) = self.fft_twiddle_cache.lock() {
            if let Some(existing) = cache.get(&(len, mode)) {
                return Ok(existing.clone());
            }
        }

        let count = if half_only { len / 2 } else { len };
        let twiddle_scalar_len = count
            .checked_mul(2)
            .ok_or_else(|| anyhow!("fft_dim: twiddle buffer length overflow"))?;
        let size_bytes =
            (twiddle_scalar_len.max(1) as u64).saturating_mul(self.element_size as u64);
        let twiddle = Arc::new(self.device_ref().create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: size_bytes,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        }));
        let tau = std::f64::consts::TAU;
        match self.precision {
            NumericPrecision::F64 => {
                let mut tw = Vec::with_capacity(twiddle_scalar_len);
                for k in 0..count {
                    let angle = -tau * (k as f64) / (len as f64);
                    tw.push(angle.cos());
                    tw.push(angle.sin());
                }
                self.queue_ref()
                    .write_buffer(twiddle.as_ref(), 0, cast_slice(&tw));
            }
            NumericPrecision::F32 => {
                let mut tw = Vec::with_capacity(twiddle_scalar_len);
                for k in 0..count {
                    let angle = -tau * (k as f64) / (len as f64);
                    tw.push(angle.cos() as f32);
                    tw.push(angle.sin() as f32);
                }
                self.queue_ref()
                    .write_buffer(twiddle.as_ref(), 0, cast_slice(&tw));
            }
        }

        if let Ok(mut cache) = self.fft_twiddle_cache.lock() {
            cache.insert((len, mode), twiddle.clone());
        }
        Ok(twiddle)
    }
}