pub mod device;
pub(crate) mod kernels;
pub mod memory;
use std::sync::Arc;
use num_complex::Complex64;
use crate::error::Result;
pub use self::device::GpuDevice;
pub use self::memory::GpuBuffer;
/// Default GPU crossover threshold, in qubits.
///
/// Returned by [`min_qubits`] when the `PRISM_GPU_MIN_QUBITS` environment
/// variable is unset or does not parse as a `usize`.
pub const MIN_QUBITS_DEFAULT: usize = 14;
/// Reports whether a GPU device can be used by this backend.
///
/// Module-level shorthand for [`GpuContext::is_available`], which in turn
/// defers to [`GpuDevice::is_available`].
pub fn is_available() -> bool {
    GpuContext::is_available()
}
/// GPU crossover threshold, in qubits.
///
/// The value is resolved once per process and cached:
/// - if `PRISM_GPU_MIN_QUBITS` is set and parses as a `usize`, that value is used;
/// - otherwise (unset, non-unicode, or unparsable) [`MIN_QUBITS_DEFAULT`] is used.
///
/// Later changes to the environment variable are not observed.
pub fn min_qubits() -> usize {
    static CACHED: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    let resolved = CACHED.get_or_init(|| {
        std::env::var("PRISM_GPU_MIN_QUBITS")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            .unwrap_or(MIN_QUBITS_DEFAULT)
    });
    *resolved
}
/// Shared handle to a single GPU device.
///
/// Every [`GpuState`] and [`GpuTableau`] allocated on the device stores a
/// clone of the `Arc<GpuContext>` it was created with.
#[derive(Debug)]
pub struct GpuContext {
    device: Arc<GpuDevice>,
}

impl GpuContext {
    /// Opens device `device_id` and wraps it in a shareable context.
    ///
    /// # Errors
    /// Propagates any error returned by [`GpuDevice::new`].
    pub fn new(device_id: usize) -> Result<Arc<Self>> {
        let opened = GpuDevice::new(device_id)?;
        let context = GpuContext {
            device: Arc::new(opened),
        };
        Ok(Arc::new(context))
    }

    /// Whether a GPU device can be used; defers to [`GpuDevice::is_available`].
    pub fn is_available() -> bool {
        GpuDevice::is_available()
    }

    /// Device VRAM size in bytes, as reported by [`GpuDevice::vram_bytes`].
    pub fn vram_bytes(&self) -> Result<usize> {
        self.device.vram_bytes()
    }

    /// Currently available VRAM in bytes, as reported by
    /// [`GpuDevice::vram_available`].
    pub fn vram_available(&self) -> Result<usize> {
        self.device.vram_available()
    }

    /// Largest state-vector qubit count the device reports it can hold.
    pub fn max_qubits_for_statevector(&self) -> Result<usize> {
        self.device.max_qubits_for_statevector()
    }

    /// Whether a `num_qubits` state vector (16 bytes per amplitude) fits in
    /// the currently available VRAM.
    ///
    /// Sizes whose byte count is not representable in `usize` return
    /// `Ok(false)` without querying the device.
    ///
    /// # Errors
    /// Propagates errors from [`GpuContext::vram_available`] (for example,
    /// the stub device used in tests reports `BackendUnsupported`).
    pub fn fits_statevector(&self, num_qubits: usize) -> Result<bool> {
        // One complex amplitude = (re, im) pair of f64.
        const BYTES_PER_AMPLITUDE: usize = 2 * std::mem::size_of::<f64>();

        let required = u32::try_from(num_qubits)
            .ok()
            .and_then(|shift| 1usize.checked_shl(shift))
            .and_then(|amplitudes| amplitudes.checked_mul(BYTES_PER_AMPLITUDE));

        match required {
            None => Ok(false),
            Some(bytes) => {
                let free = self.vram_available()?;
                Ok(bytes <= free)
            }
        }
    }

    /// Borrow of the underlying device for crate-internal kernels.
    pub(crate) fn device(&self) -> &GpuDevice {
        &*self.device
    }

    /// Context backed by a stub device, for unit tests.
    #[cfg(test)]
    pub(crate) fn stub_for_tests() -> Arc<Self> {
        let device = Arc::new(GpuDevice::stub_for_tests());
        Arc::new(Self { device })
    }
}
/// Dense state vector resident on a GPU.
///
/// The device buffer holds `2^num_qubits` complex amplitudes stored as
/// interleaved `(re, im)` `f64` pairs. Amplitudes may carry a deferred global
/// scale factor (`pending_norm`) that [`GpuState::export_statevector`]
/// multiplies in on the way out.
#[derive(Debug)]
pub struct GpuState {
    context: Arc<GpuContext>,
    buffer: GpuBuffer<f64>,
    num_qubits: usize,
    pending_norm: f64,
    // Starts as `None`; `RefCell` so it can be filled through `&self`.
    // NOTE(review): presumably populated by the probability kernel in
    // `kernels::dense` — confirm.
    probs_scratch: std::cell::RefCell<Option<GpuBuffer<f64>>>,
}

impl GpuState {
    /// Allocates a `num_qubits` state vector on the device and initialises it
    /// via `kernels::dense::launch_set_initial_state`.
    ///
    /// # Errors
    /// - `InvalidParameter` if `2^num_qubits` interleaved amplitudes (or their
    ///   size in bytes) cannot be represented in `usize`.
    /// - Any allocation or kernel-launch error from the device.
    pub fn new(context: Arc<GpuContext>, num_qubits: usize) -> Result<Self> {
        let len = Self::statevector_buffer_len(num_qubits).ok_or_else(|| {
            crate::error::PrismError::InvalidParameter {
                message: format!("num_qubits={num_qubits} overflows addressable memory"),
            }
        })?;
        let buffer = GpuBuffer::<f64>::alloc_zeros(context.device(), len)?;
        let mut state = Self {
            context: Arc::clone(&context),
            buffer,
            num_qubits,
            pending_norm: 1.0,
            probs_scratch: std::cell::RefCell::new(None),
        };
        kernels::dense::launch_set_initial_state(&context, &mut state)?;
        Ok(state)
    }

    /// Number of `f64` slots for `2^num_qubits` interleaved complex amplitudes,
    /// or `None` if that count or its byte size does not fit in `usize`.
    fn statevector_buffer_len(num_qubits: usize) -> Option<usize> {
        // `checked_shl` only rejects shift amounts >= usize::BITS; it does NOT
        // detect set bits being shifted out (e.g. `2 << 63 == 0`), so the
        // doubling and the byte size are checked separately. `try_from`
        // avoids silently truncating huge counts to `u32`.
        let amplitudes = 1usize.checked_shl(u32::try_from(num_qubits).ok()?)?;
        let len = amplitudes.checked_mul(2)?;
        len.checked_mul(std::mem::size_of::<f64>())?;
        Some(len)
    }

    /// Number of qubits this state represents.
    pub fn num_qubits(&self) -> usize {
        self.num_qubits
    }

    /// Deferred global scale factor applied by `export_statevector`.
    pub fn pending_norm(&self) -> f64 {
        self.pending_norm
    }

    /// Copies the raw interleaved `(re, im)` buffer to the host, WITHOUT
    /// applying `pending_norm`.
    pub fn copy_to_host_raw(&self) -> Result<Vec<f64>> {
        let mut host = vec![0.0_f64; self.buffer.len()];
        self.buffer.copy_to_host(self.context.device(), &mut host)?;
        Ok(host)
    }

    /// Copies the state to the host as complex amplitudes, with
    /// `pending_norm` multiplied into every component.
    pub fn export_statevector(&self) -> Result<Vec<Complex64>> {
        let raw = self.copy_to_host_raw()?;
        let norm = self.pending_norm;
        Ok(raw
            .chunks_exact(2)
            .map(|pair| Complex64::new(pair[0] * norm, pair[1] * norm))
            .collect())
    }

    /// Measurement probabilities, computed by
    /// `kernels::dense::launch_compute_probabilities`.
    pub fn probabilities(&self) -> Result<Vec<f64>> {
        kernels::dense::launch_compute_probabilities(&self.context, self)
    }

    pub(crate) fn context(&self) -> &Arc<GpuContext> {
        &self.context
    }

    pub(crate) fn buffer(&self) -> &GpuBuffer<f64> {
        &self.buffer
    }

    pub(crate) fn buffer_mut(&mut self) -> &mut GpuBuffer<f64> {
        &mut self.buffer
    }

    /// Replaces the deferred scale factor.
    pub(crate) fn set_pending_norm(&mut self, norm: f64) {
        self.pending_norm = norm;
    }

    /// Mutable access to the lazily-filled probability scratch buffer.
    ///
    /// # Panics
    /// Panics if the scratch buffer is already borrowed (`RefCell::borrow_mut`).
    pub(crate) fn probs_scratch(&self) -> std::cell::RefMut<'_, Option<GpuBuffer<f64>>> {
        self.probs_scratch.borrow_mut()
    }
}
#[derive(Debug)]
pub struct GpuTableau {
#[allow(dead_code)]
context: Arc<GpuContext>,
xz: GpuBuffer<u64>,
#[allow(dead_code)]
phase: GpuBuffer<u8>,
num_qubits: usize,
num_words: usize,
}
impl GpuTableau {
pub fn new(context: Arc<GpuContext>, num_qubits: usize) -> Result<Self> {
let num_words = num_qubits.div_ceil(64);
let total_rows = 2 * num_qubits + 1;
let xz_len = total_rows * 2 * num_words.max(1);
let phase_len = total_rows;
let xz = GpuBuffer::<u64>::alloc_zeros(context.device(), xz_len)?;
let phase = GpuBuffer::<u8>::alloc_zeros(context.device(), phase_len)?;
let mut tableau = Self {
context: context.clone(),
xz,
phase,
num_qubits,
num_words,
};
kernels::stabilizer::launch_set_initial_tableau(&context, &mut tableau)?;
Ok(tableau)
}
pub fn num_qubits(&self) -> usize {
self.num_qubits
}
pub fn num_words(&self) -> usize {
self.num_words
}
pub(crate) fn xz_mut(&mut self) -> &mut GpuBuffer<u64> {
&mut self.xz
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::PrismError;

    /// The test-only context must wrap the stub device, not a real one.
    #[test]
    fn stub_context_uses_stub_device() {
        let ctx = GpuContext::stub_for_tests();
        assert!(ctx.device().is_stub());
    }

    #[test]
    fn state_new_on_stub_returns_unsupported() {
        let ctx = GpuContext::stub_for_tests();
        assert!(matches!(
            GpuState::new(ctx, 4).unwrap_err(),
            PrismError::BackendUnsupported { .. }
        ));
    }

    /// Unaddressable sizes must be rejected before any device allocation.
    #[test]
    fn state_new_rejects_unaddressable_qubit_counts() {
        let ctx = GpuContext::stub_for_tests();
        assert!(matches!(
            GpuState::new(ctx, usize::BITS as usize).unwrap_err(),
            PrismError::InvalidParameter { .. }
        ));
    }

    #[test]
    fn tableau_new_on_stub_returns_unsupported() {
        let ctx = GpuContext::stub_for_tests();
        assert!(matches!(
            GpuTableau::new(ctx, 4).unwrap_err(),
            PrismError::BackendUnsupported { .. }
        ));
    }

    #[test]
    fn min_qubits_default_when_env_unset() {
        let n = min_qubits();
        assert!(
            (1..=32).contains(&n),
            "implausible gpu crossover threshold: {n}"
        );
        if std::env::var_os("PRISM_GPU_MIN_QUBITS").is_none() {
            assert_eq!(n, MIN_QUBITS_DEFAULT);
        }
    }

    #[test]
    fn stub_vram_available_rejects_cleanly() {
        let ctx = GpuContext::stub_for_tests();
        assert!(matches!(
            ctx.vram_available().unwrap_err(),
            PrismError::BackendUnsupported { .. }
        ));
    }

    #[test]
    fn stub_fits_statevector_rejects_cleanly() {
        let ctx = GpuContext::stub_for_tests();
        assert!(matches!(
            ctx.fits_statevector(4).unwrap_err(),
            PrismError::BackendUnsupported { .. }
        ));
    }

    #[test]
    fn fits_statevector_rejects_overflowing_qubit_counts() {
        let ctx = GpuContext::stub_for_tests();
        assert!(!ctx.fits_statevector(128).unwrap());
    }

    /// 2^(BITS-4) amplitudes * 16 bytes is the first unrepresentable size;
    /// one qubit fewer is representable and must reach the device query.
    #[test]
    fn fits_statevector_overflow_boundary_is_tight() {
        let ctx = GpuContext::stub_for_tests();
        let bits = usize::BITS as usize;
        assert!(!ctx.fits_statevector(bits - 4).unwrap());
        assert!(matches!(
            ctx.fits_statevector(bits - 5).unwrap_err(),
            PrismError::BackendUnsupported { .. }
        ));
    }
}