use crate::error::{PeacoQCError, Result};
use crate::gpu::context::GpuContext;
use realfft::num_complex::Complex;
use realfft::RealFftPlanner;
pub struct KdeContext<'a> {
pub data: &'a [f64],
pub grid: &'a [f64],
pub bandwidth: f64,
pub n: f64,
}
pub fn kde_fft_batched_gpu(
contexts: &[KdeContext],
gpu_ctx: &mut GpuContext,
) -> Result<Vec<Vec<f64>>> {
if contexts.is_empty() {
return Ok(Vec::new());
}
let mut results = Vec::with_capacity(contexts.len());
for ctx in contexts {
let density = kde_fft_single_with_context(ctx, gpu_ctx)?;
results.push(density);
}
Ok(results)
}
fn kde_fft_single_with_context(
ctx: &KdeContext,
gpu_ctx: &mut GpuContext,
) -> Result<Vec<f64>> {
let KdeContext { data, grid, bandwidth, n } = ctx;
let m = grid.len();
if m < 2 {
return Err(PeacoQCError::StatsError("Grid must have at least 2 points".to_string()));
}
let grid_min = grid[0];
let grid_max = grid[m - 1];
let grid_spacing = (grid_max - grid_min) / (m - 1) as f64;
let mut binned = vec![0.0; m];
for &x in *data {
let idx = ((x - grid_min) / grid_spacing).floor() as isize;
if idx >= 0 && (idx as usize) < m {
binned[idx as usize] += 1.0;
}
}
let kernel_center = (m - 1) as f64 / 2.0;
let mut kernel = Vec::with_capacity(m);
for i in 0..m {
let grid_pos = (i as f64 - kernel_center) * grid_spacing;
let u = grid_pos / *bandwidth;
kernel.push(gaussian_kernel(u));
}
let fft_size = (2 * m).next_power_of_two();
let mut planner = RealFftPlanner::<f64>::new();
let r2c = planner.plan_fft_forward(fft_size);
let c2r = planner.plan_fft_inverse(fft_size);
let mut binned_padded = vec![0.0; fft_size];
binned_padded[..m].copy_from_slice(&binned);
let mut binned_spectrum = r2c.make_output_vec();
r2c.process(&mut binned_padded, &mut binned_spectrum)
.map_err(|e| PeacoQCError::StatsError(format!("FFT forward failed: {}", e)))?;
let kernel_spectrum = gpu_ctx.get_or_compute_kernel_spectrum(&kernel, fft_size, *bandwidth)?;
let conv_spectrum = multiply_spectra_gpu_with_context(&binned_spectrum, &kernel_spectrum, gpu_ctx)?;
let mut conv_result = c2r.make_output_vec();
let mut conv_spectrum_mut = conv_spectrum;
c2r.process(&mut conv_spectrum_mut, &mut conv_result)
.map_err(|e| PeacoQCError::StatsError(format!("FFT inverse failed: {}", e)))?;
let kernel_start = (fft_size - m) / 2;
let mut density = Vec::with_capacity(m);
for i in 0..m {
let idx = (kernel_start + i) % fft_size;
density.push(conv_result[idx]);
}
let density: Vec<f64> = density
.iter()
.map(|&val| val / (fft_size as f64 * *n * *bandwidth))
.collect();
Ok(density)
}
fn multiply_spectra_gpu_with_context(
a: &[Complex<f64>],
b: &[Complex<f64>],
gpu_ctx: &GpuContext,
) -> Result<Vec<Complex<f64>>> {
use burn::tensor::{Tensor, TensorData};
let device = gpu_ctx.device();
let n = a.len();
let mut a_real = Vec::with_capacity(n);
let mut a_imag = Vec::with_capacity(n);
let mut b_real = Vec::with_capacity(n);
let mut b_imag = Vec::with_capacity(n);
for &c in a {
a_real.push(c.re);
a_imag.push(c.im);
}
for &c in b {
b_real.push(c.re);
b_imag.push(c.im);
}
let a_re_bytes: Vec<u8> = a_real.iter().flat_map(|x| x.to_le_bytes()).collect();
let a_im_bytes: Vec<u8> = a_imag.iter().flat_map(|x| x.to_le_bytes()).collect();
let b_re_bytes: Vec<u8> = b_real.iter().flat_map(|x| x.to_le_bytes()).collect();
let b_im_bytes: Vec<u8> = b_imag.iter().flat_map(|x| x.to_le_bytes()).collect();
type Backend = burn::backend::wgpu::Wgpu;
let a_re_tensor = Tensor::<Backend, 1, burn::tensor::Float>::from_data(
TensorData::new(a_re_bytes.into(), vec![n]), device
);
let a_im_tensor = Tensor::<Backend, 1, burn::tensor::Float>::from_data(
TensorData::new(a_im_bytes.into(), vec![n]), device
);
let b_re_tensor = Tensor::<Backend, 1, burn::tensor::Float>::from_data(
TensorData::new(b_re_bytes.into(), vec![n]), device
);
let b_im_tensor = Tensor::<Backend, 1, burn::tensor::Float>::from_data(
TensorData::new(b_im_bytes.into(), vec![n]), device
);
let re_result = a_re_tensor.clone().mul(b_re_tensor.clone())
.sub(a_im_tensor.clone().mul(b_im_tensor.clone()));
let im_result = a_re_tensor.clone().mul(b_im_tensor.clone())
.add(a_im_tensor.clone().mul(b_re_tensor.clone()));
let re_data = re_result.to_data();
let im_data = im_result.to_data();
let re_values: Vec<f64> = re_data.as_slice::<f64>().unwrap().to_vec();
let im_values: Vec<f64> = im_data.as_slice::<f64>().unwrap().to_vec();
let mut result = Vec::with_capacity(n);
for i in 0..n {
result.push(Complex::new(re_values[i], im_values[i]));
}
Ok(result)
}
#[inline]
fn gaussian_kernel(u: f64) -> f64 {
const INV_SQRT_2PI: f64 = 0.3989422804014327;
INV_SQRT_2PI * (-0.5 * u * u).exp()
}