use burn::backend::wgpu::WgpuDevice;
use crate::error::Result;
use realfft::num_complex::Complex;
pub struct GpuContext {
device: WgpuDevice,
cached_kernel_spectrum: Option<(Vec<Complex<f64>>, usize, f64)>,
}
impl GpuContext {
pub fn new() -> Result<Self> {
let device = WgpuDevice::default();
Ok(Self {
device,
cached_kernel_spectrum: None,
})
}
pub fn device(&self) -> &WgpuDevice {
&self.device
}
pub fn get_or_compute_kernel_spectrum(
&mut self,
kernel: &[f64],
fft_size: usize,
bandwidth: f64,
) -> Result<Vec<Complex<f64>>> {
if let Some((cached_spectrum, cached_size, cached_bw)) = &self.cached_kernel_spectrum {
if *cached_size == fft_size && (cached_bw - bandwidth).abs() < 1e-10 {
return Ok(cached_spectrum.clone());
}
}
use realfft::RealFftPlanner;
let mut planner = RealFftPlanner::<f64>::new();
let r2c = planner.plan_fft_forward(fft_size);
let m = kernel.len();
let mut kernel_padded = vec![0.0; fft_size];
let kernel_start = (fft_size - m) / 2;
let first_half = (m + 1) / 2;
kernel_padded[kernel_start..kernel_start + first_half].copy_from_slice(&kernel[m / 2..]);
let second_half = m / 2;
if second_half > 0 {
kernel_padded[..second_half].copy_from_slice(&kernel[..second_half]);
}
let mut kernel_spectrum = r2c.make_output_vec();
r2c.process(&mut kernel_padded, &mut kernel_spectrum)
.map_err(|e| crate::error::PeacoQCError::StatsError(format!("FFT forward failed: {}", e)))?;
self.cached_kernel_spectrum = Some((kernel_spectrum.clone(), fft_size, bandwidth));
Ok(kernel_spectrum)
}
pub fn clear_kernel_cache(&mut self) {
self.cached_kernel_spectrum = None;
}
}