use std::collections::HashMap;
use std::sync::Arc;
use cudarc::driver::{CudaContext, CudaFunction, CudaModule, CudaStream};
use cudarc::nvrtc::{compile_ptx_with_opts, CompileOptions};
use crate::error::{PrismError, Result};
use super::kernels::{kernel_source, KERNEL_NAMES};
/// Handle to a CUDA device together with its compiled kernel module.
///
/// Under `cfg(test)` it may instead wrap a stub with no GPU behind it
/// (see `GpuDevice::stub_for_tests`).
#[derive(Debug)]
pub struct GpuDevice {
    /// Either the live driver objects or the test stub.
    inner: DeviceInner,
}
/// Backing state of a [`GpuDevice`].
// `dead_code` is allowed because `Stub` is only constructed under `cfg(test)`,
// and `module` is stored but never read in this file. Presumably it is held to
// keep the loaded module alive alongside its functions; confirm against cudarc's
// ownership model before removing it.
#[derive(Debug)]
#[allow(dead_code)]
enum DeviceInner {
    /// A live device: context, its default stream, the loaded PTX module, and
    /// every kernel resolved from it.
    Real {
        context: Arc<CudaContext>,
        stream: Arc<CudaStream>,
        module: Arc<CudaModule>,
        /// Kernels keyed by their entry in `KERNEL_NAMES`.
        functions: HashMap<&'static str, CudaFunction>,
    },
    /// Placeholder with no GPU. The device accessors in this file return
    /// `BackendUnsupported` errors for it.
    Stub,
}
impl GpuDevice {
pub fn new(device_id: usize) -> Result<Self> {
let context = CudaContext::new(device_id).map_err(|e| Self::driver_err("init", e))?;
let stream = context.default_stream();
let arch = detect_arch(&context).unwrap_or("compute_60");
let opts = CompileOptions {
arch: Some(arch),
..Default::default()
};
let source = kernel_source();
let ptx =
compile_ptx_with_opts(&source, opts).map_err(|e| PrismError::BackendUnsupported {
backend: "gpu".to_string(),
operation: format!("PTX compilation (arch={arch}): {e}"),
})?;
let module = context
.load_module(ptx)
.map_err(|e| Self::driver_err("load_module", e))?;
let mut functions = HashMap::with_capacity(KERNEL_NAMES.len());
for &name in KERNEL_NAMES {
let func = module
.load_function(name)
.map_err(|e| Self::driver_err(&format!("load_function `{name}`"), e))?;
functions.insert(name, func);
}
Ok(Self {
inner: DeviceInner::Real {
context,
stream,
module,
functions,
},
})
}
pub fn is_available() -> bool {
CudaContext::new(0).is_ok()
}
pub fn vram_bytes(&self) -> Result<usize> {
match &self.inner {
DeviceInner::Real { context, .. } => context
.total_mem()
.map_err(|e| Self::driver_err("vram_bytes", e)),
DeviceInner::Stub => Err(Self::stub_unsupported("vram_bytes")),
}
}
pub fn max_qubits_for_statevector(&self) -> Result<usize> {
let bytes = self.vram_bytes()?;
let elements = bytes / 16;
if elements == 0 {
return Ok(0);
}
Ok(63 - elements.leading_zeros() as usize)
}
#[cfg(test)]
pub(crate) fn stub_for_tests() -> Self {
Self {
inner: DeviceInner::Stub,
}
}
pub(crate) fn stream(&self) -> Result<&Arc<CudaStream>> {
match &self.inner {
DeviceInner::Real { stream, .. } => Ok(stream),
DeviceInner::Stub => Err(Self::stub_unsupported("stream access")),
}
}
pub(crate) fn function(&self, name: &str) -> Result<CudaFunction> {
match &self.inner {
DeviceInner::Real { functions, .. } => {
functions
.get(name)
.cloned()
.ok_or_else(|| PrismError::BackendUnsupported {
backend: "gpu".to_string(),
operation: format!("unknown kernel `{name}` (not in KERNEL_NAMES)"),
})
}
DeviceInner::Stub => Err(Self::stub_unsupported(&format!("function `{name}`"))),
}
}
#[allow(dead_code)]
pub(crate) fn is_stub(&self) -> bool {
matches!(self.inner, DeviceInner::Stub)
}
fn driver_err(op: &str, err: impl std::fmt::Display) -> PrismError {
PrismError::BackendUnsupported {
backend: "gpu".to_string(),
operation: format!("{op}: {err}"),
}
}
fn stub_unsupported(op: &str) -> PrismError {
PrismError::BackendUnsupported {
backend: "gpu".to_string(),
operation: format!("{op} (stub device)"),
}
}
}
/// Maps the device's compute capability to an NVRTC virtual architecture
/// string (`compute_XY`).
///
/// Returns `None` if the capability query fails or the `(major, minor)` pair
/// is not in the table below. Callers pick their own fallback.
fn detect_arch(context: &Arc<CudaContext>) -> Option<&'static str> {
    let capability = context.compute_capability().ok()?;

    // Capabilities this crate has an explicit PTX target for.
    let table = [
        ((6, 0), "compute_60"),
        ((6, 1), "compute_61"),
        ((6, 2), "compute_62"),
        ((7, 0), "compute_70"),
        ((7, 2), "compute_72"),
        ((7, 5), "compute_75"),
        ((8, 0), "compute_80"),
        ((8, 6), "compute_86"),
        ((8, 7), "compute_87"),
        ((8, 9), "compute_89"),
        ((9, 0), "compute_90"),
    ];

    table
        .iter()
        .find(|&&(cc, _)| cc == capability)
        .map(|&(_, arch)| arch)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared constructor so each test starts from a fresh stub device.
    fn stub_device() -> GpuDevice {
        GpuDevice::stub_for_tests()
    }

    #[test]
    fn stub_reports_as_stub() {
        assert!(stub_device().is_stub());
    }

    #[test]
    fn stub_stream_returns_unsupported() {
        let device = stub_device();
        let error = device.stream().unwrap_err();
        assert!(matches!(error, PrismError::BackendUnsupported { .. }));
    }
}