use crate::prelude::*;
use cudarc::driver::{CudaSlice, CudaStream, CudaView, CudaViewMut, DevicePtr, DevicePtrMut};
use std::sync::Arc;
/// Named device-resident input arrays (e.g. "rho", "sigma", "lapl", "tau") as read-only CUDA views.
pub type LibXCCudaInput<'a> = HashMap<String, CudaView<'a, f64>>;
/// Named device-resident output arrays as mutable CUDA views; absent keys are passed to libxc as null.
pub type LibXCCudaOutputMut<'a> = HashMap<String, CudaViewMut<'a, f64>>;
/// Resolves a read-only CUDA view to a raw const device pointer bound to `stream`.
/// The synchronization guard returned by `device_ptr` is dropped immediately,
/// matching the lifetime the callers in this module rely on.
fn cuda_view_ptr(view: &CudaView<f64>, stream: &Arc<CudaStream>) -> *const f64 {
    view.device_ptr(stream).0 as *const f64
}
/// Resolves a mutable CUDA view to a raw mutable device pointer bound to `stream`.
fn cuda_view_mut_ptr(view: &mut CudaViewMut<f64>, stream: &Arc<CudaStream>) -> *mut f64 {
    view.device_ptr_mut(stream).0 as *mut f64
}
/// Looks up a mandatory input array by `key`, checks that it holds exactly
/// `npoints * expected_dim` elements, and returns its raw device pointer.
///
/// # Errors
/// `LibXCError::ComputeError` when the key is missing or the size mismatches.
fn require_cuda_input_ptr(
    input: &LibXCCudaInput,
    key: &str,
    npoints: usize,
    expected_dim: i32,
    stream: &Arc<CudaStream>,
) -> Result<*const f64, LibXCError> {
    let Some(view) = input.get(key) else {
        return Err(LibXCError::ComputeError(format!(
            "{key}: required CUDA input not provided"
        )));
    };
    let expected = npoints * (expected_dim as usize);
    let actual = view.len();
    if actual == expected {
        Ok(cuda_view_ptr(view, stream))
    } else {
        Err(LibXCError::ComputeError(format!(
            "{key}: expected size {expected}, got {actual}"
        )))
    }
}
fn conditional_cuda_input_ptr(
input: &LibXCCudaInput,
key: &str,
npoints: usize,
expected_dim: i32,
required: bool,
stream: &Arc<CudaStream>,
) -> Result<*const f64, LibXCError> {
match (input.get(key), required) {
(Some(view), true) => {
let expected = npoints * (expected_dim as usize);
if view.len() != expected {
return Err(LibXCError::ComputeError(format!(
"{key}: expected size {expected}, got {}",
view.len()
)));
}
Ok(cuda_view_ptr(view, stream))
},
(None, true) => {
Err(LibXCError::ComputeError(format!("{key}: required CUDA input not provided")))
},
(_, false) => Ok(std::ptr::null()),
}
}
/// Resolves an optional output array by `key`. An absent key yields a null
/// pointer (the output is simply not computed); a present key must hold
/// exactly `npoints * expected_dim` elements.
///
/// # Errors
/// `LibXCError::ComputeError` when the present view's size mismatches.
fn validate_cuda_output_ptr(
    output: &mut LibXCCudaOutputMut,
    key: &str,
    npoints: usize,
    expected_dim: i32,
    stream: &Arc<CudaStream>,
) -> Result<*mut f64, LibXCError> {
    let Some(view) = output.get_mut(key) else {
        return Ok(std::ptr::null_mut());
    };
    let expected = npoints * (expected_dim as usize);
    let actual = view.len();
    if actual != expected {
        return Err(LibXCError::ComputeError(format!(
            "{key}: expected size {expected}, got {actual}"
        )));
    }
    Ok(cuda_view_mut_ptr(view, stream))
}
/// Resolves every label in `labels` to a device pointer via
/// [`validate_cuda_output_ptr`], collecting only the non-null ones (labels
/// the caller did not provide an output buffer for are omitted).
///
/// # Errors
/// Propagates the first size-mismatch error from any provided output.
fn validate_cuda_output_ptrs(
    output: &mut LibXCCudaOutputMut,
    labels: &[&'static str],
    npoints: usize,
    dim: &ffi::xc_dimensions,
    stream: &Arc<CudaStream>,
) -> Result<HashMap<&'static str, *mut f64>, LibXCError> {
    let mut ptrs = HashMap::with_capacity(labels.len());
    for label in labels.iter().copied() {
        let label_dim = crate::layout_handling::get_dim(dim, label);
        let ptr = validate_cuda_output_ptr(output, label, npoints, label_dim, stream)?;
        if !ptr.is_null() {
            ptrs.insert(label, ptr);
        }
    }
    Ok(ptrs)
}
/// Ensures the functional was initialized for GPU execution.
///
/// # Errors
/// `LibXCError::ComputeError` when the functional is host-resident.
fn guard_on_device(func: &LibXCFunctional) -> Result<(), LibXCError> {
    if func.is_on_device() {
        return Ok(());
    }
    Err(LibXCError::ComputeError(
        "functional was not initialized for GPU; use from_identifier_with_device with OnDevice"
            .into(),
    ))
}
impl LibXCFunctional {
    /// Common front-end shared by every CUDA compute path: checks the
    /// functional is device-resident, fetches the mandatory "rho" input,
    /// validates that its length is a multiple of the spin multiplicity, and
    /// returns the grid-point count together with rho's raw device pointer.
    ///
    /// This replaces five verbatim copies of the same validation sequence.
    /// Note: device residency is now always checked before input presence,
    /// so when both checks fail the residency error takes precedence (same
    /// `LibXCError::ComputeError` variant either way).
    fn cuda_rho_prepare(
        &self,
        input: &LibXCCudaInput,
        stream: &Arc<CudaStream>,
    ) -> Result<(usize, *const f64), LibXCError> {
        guard_on_device(self)?;
        let view = input.get("rho").ok_or_else(|| {
            LibXCError::ComputeError("rho: required CUDA input not provided".into())
        })?;
        let nspin = self.spin() as usize;
        if view.len() % nspin != 0 {
            return Err(LibXCError::ComputeError(
                "rho input has invalid shape: size not divisible by nspin".into(),
            ));
        }
        Ok((view.len() / nspin, cuda_view_ptr(view, stream)))
    }
    /// Validates derivative flags and the "rho" input for an LDA evaluation
    /// and computes the packed output layout.
    fn cuda_lda_prepare(
        &self,
        input: &LibXCCudaInput,
        deriv_flags: impl Into<LibXCDerivativeFlags>,
        stream: &Arc<CudaStream>,
    ) -> Result<(usize, *const f64, LibXCOutputLayout), LibXCError> {
        let flags = deriv_flags.into();
        self.validate_flags(flags)?;
        let (npoints, rho_ptr) = self.cuda_rho_prepare(input, stream)?;
        let layout = self.lda_output_layout(npoints, flags);
        Ok((npoints, rho_ptr, layout))
    }
    /// Runs an LDA evaluation, allocating a zeroed device buffer on `stream`
    /// for the packed outputs. Returns the buffer together with its layout.
    ///
    /// # Errors
    /// Input-validation errors, or `LibXCError::CudaError` when the device
    /// allocation fails.
    pub fn cuda_compute_lda(
        &self,
        stream: &Arc<CudaStream>,
        input: &LibXCCudaInput,
        flags: impl Into<LibXCDerivativeFlags>,
    ) -> Result<(CudaSlice<f64>, LibXCOutputLayout), LibXCError> {
        let (npoints, rho_ptr, layout) = self.cuda_lda_prepare(input, flags, stream)?;
        let mut buffer = stream
            .alloc_zeros::<f64>(layout.total_size)
            .map_err(|e| LibXCError::CudaError(format!("CUDA allocation failed: {e}")))?;
        {
            let (output_base, _sync) = buffer.device_ptr_mut(stream);
            // SAFETY: rho was validated against npoints and the output buffer
            // was allocated with layout.total_size elements on this stream.
            unsafe {
                xc_lda_call(self.ptr, npoints, rho_ptr, output_base as *mut f64, &layout);
            }
        }
        Ok((buffer, layout))
    }
    /// Runs an LDA evaluation into a caller-provided packed buffer, which
    /// must hold at least `layout.total_size` elements. The stream is taken
    /// from the output buffer itself.
    pub fn cuda_compute_lda_with_unsliced_output(
        &self,
        input: &LibXCCudaInput,
        output: &mut CudaSlice<f64>,
        deriv_flags: impl Into<LibXCDerivativeFlags>,
    ) -> Result<LibXCOutputLayout, LibXCError> {
        let stream = output.stream().clone();
        let (npoints, rho_ptr, layout) = self.cuda_lda_prepare(input, deriv_flags, &stream)?;
        if output.len() < layout.total_size {
            return Err(LibXCError::ComputeError(format!(
                "output buffer has too small size: expected {}, got {}",
                layout.total_size,
                output.len()
            )));
        }
        let (output_base, _sync) = output.device_ptr_mut(&stream);
        // SAFETY: rho was validated against npoints and the buffer was checked
        // to hold at least layout.total_size elements.
        unsafe {
            xc_lda_call(self.ptr, npoints, rho_ptr, output_base as *mut f64, &layout);
        }
        Ok(layout)
    }
    /// Runs an LDA evaluation into named per-quantity output views; labels
    /// absent from `output` are skipped (null pointer passed to libxc).
    pub fn cuda_compute_lda_with_output(
        &self,
        stream: &Arc<CudaStream>,
        input: &LibXCCudaInput,
        output: &mut LibXCCudaOutputMut,
    ) -> Result<(), LibXCError> {
        let (npoints, rho_ptr) = self.cuda_rho_prepare(input, stream)?;
        let dim = self.dim();
        let ptrs = validate_cuda_output_ptrs(output, &LDA_OUTPUT_LABELS, npoints, dim, stream)?;
        // SAFETY: rho and every collected output pointer were size-validated
        // against npoints and the functional's dimensions.
        unsafe {
            xc_lda_call_with_output(self.ptr, npoints, rho_ptr, &ptrs);
        }
        Ok(())
    }
    /// Validates derivative flags plus the "rho" and "sigma" inputs for a GGA
    /// evaluation and computes the packed output layout.
    fn cuda_gga_prepare(
        &self,
        input: &LibXCCudaInput,
        deriv_flags: impl Into<LibXCDerivativeFlags>,
        stream: &Arc<CudaStream>,
    ) -> Result<(usize, *const f64, *const f64, LibXCOutputLayout), LibXCError> {
        let flags = deriv_flags.into();
        self.validate_flags(flags)?;
        let (npoints, rho_ptr) = self.cuda_rho_prepare(input, stream)?;
        let dim = self.dim();
        let sigma_ptr = require_cuda_input_ptr(input, "sigma", npoints, dim.sigma, stream)?;
        let layout = self.gga_output_layout(npoints, flags);
        Ok((npoints, rho_ptr, sigma_ptr, layout))
    }
    /// Runs a GGA evaluation, allocating a zeroed device buffer on `stream`
    /// for the packed outputs. Returns the buffer together with its layout.
    ///
    /// # Errors
    /// Input-validation errors, or `LibXCError::CudaError` when the device
    /// allocation fails.
    pub fn cuda_compute_gga(
        &self,
        stream: &Arc<CudaStream>,
        input: &LibXCCudaInput,
        flags: impl Into<LibXCDerivativeFlags>,
    ) -> Result<(CudaSlice<f64>, LibXCOutputLayout), LibXCError> {
        let (npoints, rho_ptr, sigma_ptr, layout) = self.cuda_gga_prepare(input, flags, stream)?;
        let mut buffer = stream
            .alloc_zeros::<f64>(layout.total_size)
            .map_err(|e| LibXCError::CudaError(format!("CUDA allocation failed: {e}")))?;
        {
            let (output_base, _sync) = buffer.device_ptr_mut(stream);
            // SAFETY: rho/sigma were size-validated and the output buffer was
            // allocated with layout.total_size elements on this stream.
            unsafe {
                xc_gga_call(
                    self.ptr,
                    npoints,
                    rho_ptr,
                    sigma_ptr,
                    output_base as *mut f64,
                    &layout,
                );
            }
        }
        Ok((buffer, layout))
    }
    /// Runs a GGA evaluation into a caller-provided packed buffer, which must
    /// hold at least `layout.total_size` elements. The stream is taken from
    /// the output buffer itself.
    pub fn cuda_compute_gga_with_unsliced_output(
        &self,
        input: &LibXCCudaInput,
        output: &mut CudaSlice<f64>,
        deriv_flags: impl Into<LibXCDerivativeFlags>,
    ) -> Result<LibXCOutputLayout, LibXCError> {
        let stream = output.stream().clone();
        let (npoints, rho_ptr, sigma_ptr, layout) =
            self.cuda_gga_prepare(input, deriv_flags, &stream)?;
        if output.len() < layout.total_size {
            return Err(LibXCError::ComputeError(format!(
                "output buffer has too small size: expected {}, got {}",
                layout.total_size,
                output.len()
            )));
        }
        let (output_base, _sync) = output.device_ptr_mut(&stream);
        // SAFETY: inputs were size-validated and the buffer was checked to
        // hold at least layout.total_size elements.
        unsafe {
            xc_gga_call(self.ptr, npoints, rho_ptr, sigma_ptr, output_base as *mut f64, &layout);
        }
        Ok(layout)
    }
    /// Runs a GGA evaluation into named per-quantity output views; labels
    /// absent from `output` are skipped (null pointer passed to libxc).
    pub fn cuda_compute_gga_with_output(
        &self,
        stream: &Arc<CudaStream>,
        input: &LibXCCudaInput,
        output: &mut LibXCCudaOutputMut,
    ) -> Result<(), LibXCError> {
        let (npoints, rho_ptr) = self.cuda_rho_prepare(input, stream)?;
        let dim = self.dim();
        let sigma_ptr = require_cuda_input_ptr(input, "sigma", npoints, dim.sigma, stream)?;
        let ptrs = validate_cuda_output_ptrs(output, &GGA_OUTPUT_LABELS, npoints, dim, stream)?;
        // SAFETY: rho/sigma and every collected output pointer were
        // size-validated against npoints and the functional's dimensions.
        unsafe {
            xc_gga_call_with_output(self.ptr, npoints, rho_ptr, sigma_ptr, &ptrs);
        }
        Ok(())
    }
    /// Validates derivative flags plus the "rho", "sigma" and (when the
    /// functional needs them) "lapl"/"tau" inputs for a meta-GGA evaluation
    /// and computes the packed output layout.
    fn cuda_mgga_prepare(
        &self,
        input: &LibXCCudaInput,
        deriv_flags: impl Into<LibXCDerivativeFlags>,
        stream: &Arc<CudaStream>,
    ) -> Result<
        (usize, *const f64, *const f64, *const f64, *const f64, LibXCOutputLayout),
        LibXCError,
    > {
        let flags = deriv_flags.into();
        self.validate_flags(flags)?;
        let (npoints, rho_ptr) = self.cuda_rho_prepare(input, stream)?;
        let dim = self.dim();
        let needs_lapl = self.needs_laplacian();
        let needs_tau = self.needs_tau();
        let sigma_ptr = require_cuda_input_ptr(input, "sigma", npoints, dim.sigma, stream)?;
        let lapl_ptr =
            conditional_cuda_input_ptr(input, "lapl", npoints, dim.lapl, needs_lapl, stream)?;
        let tau_ptr =
            conditional_cuda_input_ptr(input, "tau", npoints, dim.tau, needs_tau, stream)?;
        let layout = self.mgga_output_layout(npoints, flags);
        Ok((npoints, rho_ptr, sigma_ptr, lapl_ptr, tau_ptr, layout))
    }
    /// Runs a meta-GGA evaluation, allocating a zeroed device buffer on
    /// `stream` for the packed outputs. Returns the buffer and its layout.
    ///
    /// # Errors
    /// Input-validation errors, or `LibXCError::CudaError` when the device
    /// allocation fails.
    pub fn cuda_compute_mgga(
        &self,
        stream: &Arc<CudaStream>,
        input: &LibXCCudaInput,
        flags: impl Into<LibXCDerivativeFlags>,
    ) -> Result<(CudaSlice<f64>, LibXCOutputLayout), LibXCError> {
        let (npoints, rho_ptr, sigma_ptr, lapl_ptr, tau_ptr, layout) =
            self.cuda_mgga_prepare(input, flags, stream)?;
        let mut buffer = stream
            .alloc_zeros::<f64>(layout.total_size)
            .map_err(|e| LibXCError::CudaError(format!("CUDA allocation failed: {e}")))?;
        {
            let (output_base, _sync) = buffer.device_ptr_mut(stream);
            // SAFETY: all inputs were size-validated (lapl/tau are null when
            // unused) and the buffer holds layout.total_size elements.
            unsafe {
                xc_mgga_call(
                    self.ptr,
                    npoints,
                    rho_ptr,
                    sigma_ptr,
                    lapl_ptr,
                    tau_ptr,
                    output_base as *mut f64,
                    &layout,
                );
            }
        }
        Ok((buffer, layout))
    }
    /// Runs a meta-GGA evaluation into a caller-provided packed buffer, which
    /// must hold at least `layout.total_size` elements. The stream is taken
    /// from the output buffer itself.
    pub fn cuda_compute_mgga_with_unsliced_output(
        &self,
        input: &LibXCCudaInput,
        output: &mut CudaSlice<f64>,
        deriv_flags: impl Into<LibXCDerivativeFlags>,
    ) -> Result<LibXCOutputLayout, LibXCError> {
        let stream = output.stream().clone();
        let (npoints, rho_ptr, sigma_ptr, lapl_ptr, tau_ptr, layout) =
            self.cuda_mgga_prepare(input, deriv_flags, &stream)?;
        if output.len() < layout.total_size {
            return Err(LibXCError::ComputeError(format!(
                "output buffer has too small size: expected {}, got {}",
                layout.total_size,
                output.len()
            )));
        }
        let (output_base, _sync) = output.device_ptr_mut(&stream);
        // SAFETY: inputs were size-validated and the buffer was checked to
        // hold at least layout.total_size elements.
        unsafe {
            xc_mgga_call(
                self.ptr,
                npoints,
                rho_ptr,
                sigma_ptr,
                lapl_ptr,
                tau_ptr,
                output_base as *mut f64,
                &layout,
            );
        }
        Ok(layout)
    }
    /// Runs a meta-GGA evaluation into named per-quantity output views;
    /// labels absent from `output` are skipped (null pointer passed to libxc).
    pub fn cuda_compute_mgga_with_output(
        &self,
        stream: &Arc<CudaStream>,
        input: &LibXCCudaInput,
        output: &mut LibXCCudaOutputMut,
    ) -> Result<(), LibXCError> {
        let (npoints, rho_ptr) = self.cuda_rho_prepare(input, stream)?;
        let dim = self.dim();
        let needs_lapl = self.needs_laplacian();
        let needs_tau = self.needs_tau();
        let sigma_ptr = require_cuda_input_ptr(input, "sigma", npoints, dim.sigma, stream)?;
        let lapl_ptr =
            conditional_cuda_input_ptr(input, "lapl", npoints, dim.lapl, needs_lapl, stream)?;
        let tau_ptr =
            conditional_cuda_input_ptr(input, "tau", npoints, dim.tau, needs_tau, stream)?;
        let ptrs = validate_cuda_output_ptrs(output, &MGGA_OUTPUT_LABELS, npoints, dim, stream)?;
        // SAFETY: all inputs (lapl/tau null when unused) and every collected
        // output pointer were size-validated.
        unsafe {
            xc_mgga_call_with_output(
                self.ptr, npoints, rho_ptr, sigma_ptr, lapl_ptr, tau_ptr, &ptrs,
            );
        }
        Ok(())
    }
    /// Dispatches to the LDA/GGA/meta-GGA compute path matching the
    /// functional's family. OEP and LCA families are not supported.
    pub fn cuda_compute_xc(
        &self,
        stream: &Arc<CudaStream>,
        input: &LibXCCudaInput,
        flags: impl Into<LibXCDerivativeFlags>,
    ) -> Result<(CudaSlice<f64>, LibXCOutputLayout), LibXCError> {
        use crate::prelude::libxc_enum_items::*;
        match self.family() {
            LDA | HybLDA => self.cuda_compute_lda(stream, input, flags),
            GGA | HybGGA => self.cuda_compute_gga(stream, input, flags),
            MGGA | HybMGGA => self.cuda_compute_mgga(stream, input, flags),
            OEP | LCA => Err(LibXCError::ComputeError(
                "cuda_compute_xc: OEP/LCA family is not supported".into(),
            )),
        }
    }
    /// Family-dispatching variant of the `*_with_unsliced_output` methods.
    /// OEP and LCA families are not supported.
    pub fn cuda_compute_xc_with_unsliced_output(
        &self,
        input: &LibXCCudaInput,
        output: &mut CudaSlice<f64>,
        deriv_flags: impl Into<LibXCDerivativeFlags>,
    ) -> Result<LibXCOutputLayout, LibXCError> {
        use crate::prelude::libxc_enum_items::*;
        match self.family() {
            LDA | HybLDA => self.cuda_compute_lda_with_unsliced_output(input, output, deriv_flags),
            GGA | HybGGA => self.cuda_compute_gga_with_unsliced_output(input, output, deriv_flags),
            MGGA | HybMGGA => {
                self.cuda_compute_mgga_with_unsliced_output(input, output, deriv_flags)
            },
            OEP | LCA => Err(LibXCError::ComputeError(
                "cuda_compute_xc: OEP/LCA family is not supported".into(),
            )),
        }
    }
    /// Family-dispatching variant of the `*_with_output` methods.
    /// OEP and LCA families are not supported.
    pub fn cuda_compute_xc_with_output(
        &self,
        stream: &Arc<CudaStream>,
        input: &LibXCCudaInput,
        output: &mut LibXCCudaOutputMut,
    ) -> Result<(), LibXCError> {
        use crate::prelude::libxc_enum_items::*;
        match self.family() {
            LDA | HybLDA => self.cuda_compute_lda_with_output(stream, input, output),
            GGA | HybGGA => self.cuda_compute_gga_with_output(stream, input, output),
            MGGA | HybMGGA => self.cuda_compute_mgga_with_output(stream, input, output),
            OEP | LCA => Err(LibXCError::ComputeError(
                "cuda_compute_xc: OEP/LCA family is not supported".into(),
            )),
        }
    }
}