use std::{ffi::c_void, sync::Arc};
use cudarc::driver::{CudaFunction, CudaSlice, CudaStream, DevicePtr, LaunchConfig, PushKernelArg};
use lin_alg::f32::{Vec3, vec3s_to_dev};
use crate::{
PmeRecip,
fft::{destroy_plan, exec_forward, exec_inverse},
self_energy,
};
/// GPU-side state owned by a `PmeRecip` instance.
pub(crate) struct GpuData {
    // Opaque handle to the FFT plan created by the GPU FFT backend.
    // Freed via `destroy_plan` in `Drop for PmeRecip`, then nulled to
    // prevent a double free.
    pub planner_gpu: *mut c_void,
    // Precomputed k-vector / B-spline tables resident in device memory.
    pub gpu_tables: GpuTables,
    // Compiled CUDA kernel handles used by `forces_gpu`.
    pub kernels: Kernels,
}
/// Handles to the compiled CUDA kernels used in the PME pipeline.
pub(crate) struct Kernels {
    // Spreads per-atom charges onto the real-space grid.
    pub kernel_spread: CudaFunction,
    // Applies the influence function and k-space gradient.
    pub kernel_ghat: CudaFunction,
    // Gathers grid field values back to per-atom forces.
    pub kernel_gather: CudaFunction,
    // Accumulates the reciprocal-space energy over the half spectrum.
    pub kernel_half_spectrum: CudaFunction,
}
/// Device-resident lookup tables, one per axis.
pub(crate) struct GpuTables {
    // k-vector component tables (per-axis).
    pub kx: CudaSlice<f32>,
    pub ky: CudaSlice<f32>,
    pub kz: CudaSlice<f32>,
    // Squared B-spline moduli (|b(m)|^2) per axis.
    pub bx: CudaSlice<f32>,
    pub by: CudaSlice<f32>,
    pub bz: CudaSlice<f32>,
}
impl GpuTables {
    /// Uploads the host-side k-vector tables (`k`) and squared B-spline
    /// modulus tables (`bmod2`) — one `Vec` per axis — to device memory.
    ///
    /// # Panics
    /// Panics if any host-to-device copy fails.
    pub(crate) fn new(
        k: (&Vec<f32>, &Vec<f32>, &Vec<f32>),
        bmod2: (&Vec<f32>, &Vec<f32>, &Vec<f32>),
        stream: &Arc<CudaStream>,
    ) -> Self {
        // Single helper so each table is uploaded identically.
        let htod = |host: &Vec<f32>| stream.clone_htod(host).unwrap();

        Self {
            kx: htod(k.0),
            ky: htod(k.1),
            kz: htod(k.2),
            bx: htod(bmod2.0),
            by: htod(bmod2.1),
            bz: htod(bmod2.2),
        }
    }
}
impl PmeRecip {
    /// Computes reciprocal-space PME forces and energy on the GPU.
    ///
    /// Pipeline: spread charges onto the real-space grid, forward FFT,
    /// apply the influence function and k-space gradient, accumulate the
    /// half-spectrum energy, inverse-FFT the three field components, and
    /// gather per-atom forces.
    ///
    /// Returns `(forces, energy)`; `forces` has one entry per position.
    /// The self-energy correction is added to the reciprocal-space sum.
    ///
    /// # Panics
    /// Panics if GPU data was not initialized, if `posits` and `q` differ
    /// in length, or if any device allocation/copy fails.
    pub fn forces_gpu(
        &mut self,
        stream: &Arc<CudaStream>,
        posits: &[Vec3],
        q: &[f32],
    ) -> (Vec<Vec3>, f32) {
        let Some(data) = &mut self.gpu_data else {
            panic!("Error: Computing forces on GPU without having initialized on GPU");
        };
        assert_eq!(posits.len(), q.len());

        let (nx, ny, nz) = self.plan_dims;
        let n_real = nx * ny * nz;
        // Real-to-complex FFT stores only the non-redundant half spectrum:
        // nz/2 + 1 complex values along the last axis.
        let nzc = nz / 2 + 1;
        let n_cplx = nx * ny * nzc;
        // Interleaved (re, im) f32 pairs.
        let complex_len = n_cplx * 2;

        // Upload inputs.
        let pos_dev = vec3s_to_dev(stream, posits);
        let q_dev = stream.clone_htod(q).unwrap();

        // Spread charges to the real grid, then forward FFT to k-space.
        let mut rho_real_dev: CudaSlice<f32> = stream.alloc_zeros(n_real).unwrap();
        let mut rho_dev: CudaSlice<f32> = stream.alloc_zeros(complex_len).unwrap();
        spread_charges(
            stream,
            &data.kernels.kernel_spread,
            &pos_dev,
            &q_dev,
            &mut rho_real_dev,
            posits.len() as u32,
            self.plan_dims,
            self.box_dims,
        );
        // SAFETY: both buffers are live device allocations of the sizes the
        // plan was created for; the raw pointers are only used by the FFT call.
        unsafe {
            exec_forward(
                data.planner_gpu,
                cuda_slice_to_ptr_mut(&rho_real_dev, stream),
                cuda_slice_to_ptr_mut(&rho_dev, stream),
            );
        }

        // k-space field components E_k = -i k G(k) rho(k).
        let mut ekx_dev: CudaSlice<f32> = stream.alloc_zeros(complex_len).unwrap();
        let mut eky_dev: CudaSlice<f32> = stream.alloc_zeros(complex_len).unwrap();
        let mut ekz_dev: CudaSlice<f32> = stream.alloc_zeros(complex_len).unwrap();
        let ekx_ptr = cuda_slice_to_ptr_mut(&ekx_dev, stream);
        let eky_ptr = cuda_slice_to_ptr_mut(&eky_dev, stream);
        let ekz_ptr = cuda_slice_to_ptr_mut(&ekz_dev, stream);
        apply_ghat_and_grad(
            stream,
            &data.kernels.kernel_ghat,
            &rho_dev,
            &mut ekx_dev,
            &mut eky_dev,
            &mut ekz_dev,
            &data.gpu_tables,
            self.plan_dims,
            self.vol,
            self.alpha,
        );

        // Per-element energy contributions, reduced on the host.
        let mut out_partial_gpu: CudaSlice<f64> = stream.alloc_zeros(n_cplx).unwrap();
        energy_half_spectrum(
            stream,
            &data.kernels.kernel_half_spectrum,
            &mut rho_dev,
            &mut out_partial_gpu,
            &data.gpu_tables,
            self.plan_dims,
            self.vol,
            self.alpha,
        );
        let energy: f64 = stream
            .clone_dtoh(&out_partial_gpu)
            .unwrap()
            .into_iter()
            .sum();
        let energy = (energy + self_energy(q, self.alpha)) as f32;

        // Inverse-FFT the three field components back to the real grid.
        let ex_dev: CudaSlice<f32> = stream.alloc_zeros(n_real).unwrap();
        let ey_dev: CudaSlice<f32> = stream.alloc_zeros(n_real).unwrap();
        let ez_dev: CudaSlice<f32> = stream.alloc_zeros(n_real).unwrap();
        let ex_ptr = cuda_slice_to_ptr_mut(&ex_dev, stream);
        let ey_ptr = cuda_slice_to_ptr_mut(&ey_dev, stream);
        let ez_ptr = cuda_slice_to_ptr_mut(&ez_dev, stream);
        // SAFETY: all six buffers are live device allocations matching the
        // plan's dimensions; pointers are consumed only by this FFT call.
        unsafe {
            exec_inverse(
                data.planner_gpu,
                ekx_ptr,
                eky_ptr,
                ekz_ptr,
                ex_ptr,
                ey_ptr,
                ez_ptr,
            );
        }

        // Interpolate the field at each atom and accumulate q*E per atom,
        // packed as [fx, fy, fz] triples.
        let n_atoms = posits.len();
        let mut out_f_gpu: CudaSlice<f32> = stream.alloc_zeros(3 * n_atoms).unwrap();
        gather_forces_to_atoms(
            stream,
            &data.kernels.kernel_gather,
            &pos_dev,
            &q_dev,
            &ex_dev,
            &ey_dev,
            &ez_dev,
            &mut out_f_gpu,
            self.plan_dims,
            self.box_dims,
        );

        // Unpack [fx, fy, fz] triples; negate to convert field*charge into force.
        let f_host: Vec<f32> = stream.clone_dtoh(&out_f_gpu).unwrap();
        let f: Vec<Vec3> = f_host
            .chunks_exact(3)
            .map(|c| {
                -Vec3 {
                    x: c[0],
                    y: c[1],
                    z: c[2],
                }
            })
            .collect();

        (f, energy)
    }
}
impl Drop for PmeRecip {
    /// Releases the GPU FFT plan, if one was ever created.
    fn drop(&mut self) {
        if let Some(data) = self.gpu_data.as_mut() {
            if !data.planner_gpu.is_null() {
                // SAFETY: `planner_gpu` is a live handle created at GPU init;
                // nulling it afterwards guards against a double free.
                unsafe { destroy_plan(data.planner_gpu) };
                data.planner_gpu = std::ptr::null_mut();
            }
        }
    }
}
/// Returns the raw device address of `buf` as a const `c_void` pointer for FFI.
///
/// NOTE(review): the guard returned by `device_ptr` is dropped immediately;
/// callers must keep `buf` alive (and the stream ordered) for as long as the
/// pointer is used.
pub(crate) fn _cuda_slice_to_ptr<T>(buf: &CudaSlice<T>, stream: &Arc<CudaStream>) -> *const c_void {
    let (addr, _guard) = buf.device_ptr(stream);
    addr as *const c_void
}
/// Returns the raw device address of `buf` as a mutable `c_void` pointer for FFI.
///
/// NOTE(review): the guard returned by `device_ptr` is dropped immediately;
/// callers must keep `buf` alive (and the stream ordered) for as long as the
/// pointer is used.
pub(crate) fn cuda_slice_to_ptr_mut<T>(
    buf: &CudaSlice<T>,
    stream: &Arc<CudaStream>,
) -> *mut c_void {
    let (addr, _guard) = buf.device_ptr(stream);
    addr as *mut c_void
}
/// Launches the charge-spreading kernel: deposits each atom's charge onto the
/// real-space grid `rho_dev` (one thread per atom).
fn spread_charges(
    stream: &Arc<CudaStream>,
    kernel: &CudaFunction,
    pos_dev: &CudaSlice<f32>,
    q_dev: &CudaSlice<f32>,
    rho_dev: &mut CudaSlice<f32>,
    n_posits: u32,
    plan_dims: (usize, usize, usize),
    box_dims: (f32, f32, f32),
) {
    // Scalar arguments as the i32/f32 types the kernel ABI expects.
    let dims = (plan_dims.0 as i32, plan_dims.1 as i32, plan_dims.2 as i32);
    let n_atoms = n_posits as i32;
    let cfg = launch_cfg(n_posits, 256);

    let mut builder = stream.launch_builder(kernel);
    builder.arg(pos_dev);
    builder.arg(q_dev);
    builder.arg(rho_dev);
    builder.arg(&n_atoms);
    builder.arg(&dims.0);
    builder.arg(&dims.1);
    builder.arg(&dims.2);
    builder.arg(&box_dims.0);
    builder.arg(&box_dims.1);
    builder.arg(&box_dims.2);
    // SAFETY: argument count, order, and types match the compiled kernel.
    unsafe { builder.launch(cfg) }.unwrap();
}
/// Launches the influence-function kernel: multiplies `rho_dev` by G(k) and
/// writes the three k-space field components (one thread per half-spectrum
/// complex element).
fn apply_ghat_and_grad(
    stream: &Arc<CudaStream>,
    kernel: &CudaFunction,
    rho_dev: &CudaSlice<f32>,
    ekx_dev: &mut CudaSlice<f32>,
    eky_dev: &mut CudaSlice<f32>,
    ekz_dev: &mut CudaSlice<f32>,
    tables: &GpuTables,
    plan_dims: (usize, usize, usize),
    vol: f32,
    alpha: f32,
) {
    let (nx, ny, nz) = plan_dims;
    let dims = (nx as i32, ny as i32, nz as i32);
    // r2c half spectrum: nz/2 + 1 complex entries along the last axis.
    let n_half = nx * ny * (nz / 2 + 1);
    let n_real = (nx * ny * nz) as i32;
    let cfg = launch_cfg(n_half as u32, 256);

    let mut builder = stream.launch_builder(kernel);
    builder.arg(rho_dev);
    builder.arg(ekx_dev);
    builder.arg(eky_dev);
    builder.arg(ekz_dev);
    builder.arg(&tables.kx);
    builder.arg(&tables.ky);
    builder.arg(&tables.kz);
    builder.arg(&tables.bx);
    builder.arg(&tables.by);
    builder.arg(&tables.bz);
    builder.arg(&dims.0);
    builder.arg(&dims.1);
    builder.arg(&dims.2);
    builder.arg(&vol);
    builder.arg(&alpha);
    builder.arg(&n_real);
    // SAFETY: argument count, order, and types match the compiled kernel.
    unsafe { builder.launch(cfg) }.unwrap();
}
/// Launches the energy kernel: accumulates per-element reciprocal-space energy
/// contributions over the half spectrum into `out_partial_gpu` (f64 partials,
/// reduced on the host).
fn energy_half_spectrum(
    stream: &Arc<CudaStream>,
    kernel: &CudaFunction,
    rho_cplx_gpu: &mut CudaSlice<f32>,
    out_partial_gpu: &mut CudaSlice<f64>,
    tables: &GpuTables,
    plan_dims: (usize, usize, usize),
    vol: f32,
    alpha: f32,
) {
    let (nx, ny, nz) = plan_dims;
    let dims = (nx as i32, ny as i32, nz as i32);
    // One thread per half-spectrum complex element.
    let n_half = (nx * ny * (nz / 2 + 1)) as i32;
    let cfg = launch_cfg(n_half as u32, 256);

    let mut builder = stream.launch_builder(kernel);
    builder.arg(rho_cplx_gpu);
    builder.arg(&tables.kx);
    builder.arg(&tables.ky);
    builder.arg(&tables.kz);
    builder.arg(&tables.bx);
    builder.arg(&tables.by);
    builder.arg(&tables.bz);
    builder.arg(&dims.0);
    builder.arg(&dims.1);
    builder.arg(&dims.2);
    builder.arg(&vol);
    builder.arg(&alpha);
    builder.arg(out_partial_gpu);
    // SAFETY: argument count, order, and types match the compiled kernel.
    unsafe { builder.launch(cfg) }.unwrap();
}
/// Launches the gather kernel: interpolates the real-space field at each atom
/// position and writes q*E per atom into `out_partial_gpu` as packed
/// `[fx, fy, fz]` triples (one thread per atom).
fn gather_forces_to_atoms(
    stream: &Arc<CudaStream>,
    kernel: &CudaFunction,
    pos_gpu: &CudaSlice<f32>,
    q_gpu: &CudaSlice<f32>,
    ex_gpu: &CudaSlice<f32>,
    ey_gpu: &CudaSlice<f32>,
    ez_gpu: &CudaSlice<f32>,
    out_partial_gpu: &mut CudaSlice<f32>,
    plan_dims: (usize, usize, usize),
    box_dims: (f32, f32, f32),
) {
    let dims = (plan_dims.0 as i32, plan_dims.1 as i32, plan_dims.2 as i32);
    // Positions are packed as [x, y, z] triples, so atom count = len / 3.
    let n_atoms = pos_gpu.len() / 3;
    let n_atoms_i32 = n_atoms as i32;
    let cfg = launch_cfg(n_atoms as u32, 256);

    let mut builder = stream.launch_builder(kernel);
    builder.arg(pos_gpu);
    builder.arg(ex_gpu);
    builder.arg(ey_gpu);
    builder.arg(ez_gpu);
    builder.arg(q_gpu);
    builder.arg(out_partial_gpu);
    builder.arg(&n_atoms_i32);
    builder.arg(&dims.0);
    builder.arg(&dims.1);
    builder.arg(&dims.2);
    builder.arg(&box_dims.0);
    builder.arg(&box_dims.1);
    builder.arg(&box_dims.2);
    // SAFETY: argument count, order, and types match the compiled kernel.
    unsafe { builder.launch(cfg) }.unwrap();
}
/// Builds a 1-D launch configuration covering `n` work items with `block`
/// threads per block, always launching at least one block (so `n == 0` still
/// yields a valid config).
fn launch_cfg(n: u32, block: u32) -> LaunchConfig {
    // div_ceil avoids the u32 overflow that `(n + block - 1) / block` hits
    // when `n` is within `block - 1` of u32::MAX; result is otherwise identical.
    LaunchConfig {
        grid_dim: (n.div_ceil(block).max(1), 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    }
}