pub mod device;
pub(crate) mod kernels;
pub mod memory;
use std::sync::Arc;
use num_complex::Complex64;
use crate::error::Result;
pub use self::device::GpuDevice;
pub use self::memory::GpuBuffer;
/// Fallback for [`min_qubits`] when `PRISM_GPU_MIN_QUBITS` is unset or not a valid `usize`.
pub const MIN_QUBITS_DEFAULT: usize = 14;
/// Reports GPU availability as determined by [`GpuContext::is_available`]
/// (which in turn defers to `GpuDevice::is_available`).
pub fn is_available() -> bool {
    let available = GpuContext::is_available();
    available
}
/// GPU crossover threshold, in qubits.
///
/// Read once from `PRISM_GPU_MIN_QUBITS` and cached for the rest of the process.
/// An unset, non-UTF-8, or non-numeric value falls back to [`MIN_QUBITS_DEFAULT`].
pub fn min_qubits() -> usize {
    static THRESHOLD: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *THRESHOLD.get_or_init(|| {
        std::env::var("PRISM_GPU_MIN_QUBITS")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            .unwrap_or(MIN_QUBITS_DEFAULT)
    })
}
/// Default for `bts_min_shots()` when `PRISM_GPU_BTS_MIN_SHOTS` is unset or unparsable.
pub const BTS_MIN_SHOTS_DEFAULT: usize = 131_072;
/// Default for `bts_min_rank()` when `PRISM_GPU_BTS_MIN_RANK` is unset or unparsable.
pub const BTS_MIN_RANK_DEFAULT: usize = 4;
/// Default for `bts_min_weight_factor()` when `PRISM_GPU_BTS_MIN_WEIGHT_FACTOR` is unset
/// or unparsable.
pub const BTS_MIN_WEIGHT_FACTOR_DEFAULT: usize = 2;
/// Minimum shot count for the GPU "BTS" path (presumably the point at which that path is
/// preferred — confirm against callers).
///
/// Read once from `PRISM_GPU_BTS_MIN_SHOTS` and cached; any value that is missing or does
/// not parse as `usize` yields `BTS_MIN_SHOTS_DEFAULT`. Zero is accepted as-is.
pub(crate) fn bts_min_shots() -> usize {
    static SHOTS: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *SHOTS.get_or_init(|| {
        match std::env::var("PRISM_GPU_BTS_MIN_SHOTS").map(|raw| raw.parse::<usize>()) {
            Ok(Ok(shots)) => shots,
            _ => BTS_MIN_SHOTS_DEFAULT,
        }
    })
}
/// Minimum rank for the GPU "BTS" path, read once from `PRISM_GPU_BTS_MIN_RANK` and cached.
///
/// A parsed value is clamped to at least 1. A missing or unparsable value yields
/// `BTS_MIN_RANK_DEFAULT` unchanged.
pub(crate) fn bts_min_rank() -> usize {
    static RANK: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *RANK.get_or_init(|| {
        std::env::var("PRISM_GPU_BTS_MIN_RANK")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            // Only an explicit value is clamped; the default is used verbatim.
            .map_or(BTS_MIN_RANK_DEFAULT, |rank| rank.max(1))
    })
}
pub(crate) fn bts_min_weight_factor() -> usize {
static CACHED: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
*CACHED.get_or_init(|| {
if let Ok(val) = std::env::var("PRISM_GPU_BTS_MIN_WEIGHT_FACTOR") {
if let Ok(n) = val.parse::<usize>() {
return n.max(1);
}
}
BTS_MIN_WEIGHT_FACTOR_DEFAULT
})
}
/// Fallback for [`stabilizer_min_qubits`] when `PRISM_STABILIZER_GPU_MIN_QUBITS` is unset
/// or unparsable.
pub const STABILIZER_MIN_QUBITS_DEFAULT: usize = 100_000;
/// Qubit threshold for the GPU stabilizer path.
///
/// Read once from `PRISM_STABILIZER_GPU_MIN_QUBITS` and cached for the process lifetime.
/// A missing or unparsable value yields [`STABILIZER_MIN_QUBITS_DEFAULT`].
pub fn stabilizer_min_qubits() -> usize {
    static CUTOFF: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *CUTOFF.get_or_init(|| {
        let parsed = std::env::var("PRISM_STABILIZER_GPU_MIN_QUBITS")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok());
        parsed.unwrap_or(STABILIZER_MIN_QUBITS_DEFAULT)
    })
}
/// Shared handle to one GPU device plus the scratch state used by kernel launchers.
///
/// Constructed via [`GpuContext::new`] and handed out as an `Arc`, so several device
/// objects can share a single context.
pub struct GpuContext {
    device: Arc<GpuDevice>,
    // Mutex-guarded so the context can be shared; see `launcher_scratch()`.
    launcher_scratch: std::sync::Mutex<kernels::LauncherScratch>,
}
impl std::fmt::Debug for GpuContext {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The scratch mutex is intentionally left out of the debug output.
        let mut builder = f.debug_struct("GpuContext");
        builder.field("device", &self.device);
        builder.finish_non_exhaustive()
    }
}
impl GpuContext {
    /// Opens the device identified by `device_id` and wraps it in a new shared context.
    ///
    /// # Errors
    /// Propagates any error returned by `GpuDevice::new`.
    pub fn new(device_id: usize) -> Result<Arc<Self>> {
        let device = Arc::new(GpuDevice::new(device_id)?);
        let launcher_scratch = std::sync::Mutex::new(kernels::LauncherScratch::default());
        Ok(Arc::new(Self {
            device,
            launcher_scratch,
        }))
    }
    /// Whether a GPU can be used, as reported by `GpuDevice::is_available`.
    pub fn is_available() -> bool {
        GpuDevice::is_available()
    }
    /// Device memory size in bytes, as reported by `GpuDevice::vram_bytes`.
    pub fn vram_bytes(&self) -> Result<usize> {
        self.device.vram_bytes()
    }
    /// Currently available device memory in bytes, as reported by `GpuDevice::vram_available`.
    pub fn vram_available(&self) -> Result<usize> {
        self.device.vram_available()
    }
    /// Largest state-vector qubit count the device reports it can hold.
    pub fn max_qubits_for_statevector(&self) -> Result<usize> {
        self.device.max_qubits_for_statevector()
    }
    /// Whether a `num_qubits`-qubit state vector (16 bytes per amplitude) fits in the
    /// currently available device memory.
    ///
    /// Qubit counts whose byte size cannot be represented in `usize` return `Ok(false)`
    /// without querying the device.
    ///
    /// # Errors
    /// Propagates errors from [`vram_available`](Self::vram_available), e.g.
    /// `BackendUnsupported` on a stub device.
    pub fn fits_statevector(&self, num_qubits: usize) -> Result<bool> {
        // 2^n amplitudes * 16 bytes; `None` whenever either step would overflow.
        let needed = (num_qubits < usize::BITS as usize)
            .then(|| 1usize << num_qubits)
            .and_then(|amplitudes| amplitudes.checked_mul(16));
        let Some(needed) = needed else {
            return Ok(false);
        };
        self.vram_available().map(|available| needed <= available)
    }
    pub(crate) fn device(&self) -> &GpuDevice {
        &self.device
    }
    /// Locks the launcher scratch.
    ///
    /// Poisoning is deliberately ignored: if a previous holder panicked, the guard is
    /// recovered and returned anyway.
    pub(crate) fn launcher_scratch(&self) -> std::sync::MutexGuard<'_, kernels::LauncherScratch> {
        match self.launcher_scratch.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }
    /// Context backed by a stub device, for unit tests that need no real GPU.
    #[cfg(test)]
    pub(crate) fn stub_for_tests() -> Arc<Self> {
        let device = Arc::new(GpuDevice::stub_for_tests());
        Arc::new(Self {
            device,
            launcher_scratch: std::sync::Mutex::new(kernels::LauncherScratch::default()),
        })
    }
}
#[derive(Debug)]
pub struct GpuState {
context: Arc<GpuContext>,
buffer: GpuBuffer<f64>,
num_qubits: usize,
pending_norm: f64,
probs_scratch: std::cell::RefCell<Option<GpuBuffer<f64>>>,
}
impl GpuState {
pub fn new(context: Arc<GpuContext>, num_qubits: usize) -> Result<Self> {
let len = 2usize.checked_shl(num_qubits as u32).ok_or_else(|| {
crate::error::PrismError::InvalidParameter {
message: format!("num_qubits={num_qubits} overflows addressable memory"),
}
})?;
let buffer = GpuBuffer::<f64>::alloc_zeros(context.device(), len)?;
let mut state = Self {
context: context.clone(),
buffer,
num_qubits,
pending_norm: 1.0,
probs_scratch: std::cell::RefCell::new(None),
};
kernels::dense::launch_set_initial_state(&context, &mut state)?;
Ok(state)
}
pub fn num_qubits(&self) -> usize {
self.num_qubits
}
pub fn pending_norm(&self) -> f64 {
self.pending_norm
}
pub fn copy_to_host_raw(&self) -> Result<Vec<f64>> {
let mut host = vec![0.0_f64; self.buffer.len()];
self.buffer.copy_to_host(self.context.device(), &mut host)?;
Ok(host)
}
pub fn export_statevector(&self) -> Result<Vec<Complex64>> {
let raw = self.copy_to_host_raw()?;
let norm = self.pending_norm;
let out = raw
.chunks_exact(2)
.map(|p| Complex64::new(p[0] * norm, p[1] * norm))
.collect();
Ok(out)
}
pub fn probabilities(&self) -> Result<Vec<f64>> {
kernels::dense::launch_compute_probabilities(&self.context, self)
}
pub(crate) fn context(&self) -> &Arc<GpuContext> {
&self.context
}
pub(crate) fn buffer(&self) -> &GpuBuffer<f64> {
&self.buffer
}
pub(crate) fn buffer_mut(&mut self) -> &mut GpuBuffer<f64> {
&mut self.buffer
}
pub(crate) fn set_pending_norm(&mut self, norm: f64) {
self.pending_norm = norm;
}
pub(crate) fn probs_scratch(&self) -> std::cell::RefMut<'_, Option<GpuBuffer<f64>>> {
self.probs_scratch.borrow_mut()
}
}
#[derive(Debug)]
pub struct GpuTableau {
#[allow(dead_code)]
context: Arc<GpuContext>,
xz: GpuBuffer<u64>,
#[allow(dead_code)]
phase: GpuBuffer<u8>,
measure_pivot: GpuBuffer<i32>,
measure_outcome: GpuBuffer<u8>,
num_qubits: usize,
num_words: usize,
}
impl GpuTableau {
pub fn new(context: Arc<GpuContext>, num_qubits: usize) -> Result<Self> {
let num_words = num_qubits.div_ceil(64);
let total_rows = 2 * num_qubits + 1;
let xz_len = total_rows * 2 * num_words.max(1);
let phase_len = total_rows;
let xz = GpuBuffer::<u64>::alloc_zeros(context.device(), xz_len)?;
let phase = GpuBuffer::<u8>::alloc_zeros(context.device(), phase_len)?;
let measure_pivot = GpuBuffer::<i32>::alloc_zeros(context.device(), 1)?;
let measure_outcome = GpuBuffer::<u8>::alloc_zeros(context.device(), 1)?;
let mut tableau = Self {
context: context.clone(),
xz,
phase,
measure_pivot,
measure_outcome,
num_qubits,
num_words,
};
kernels::stabilizer::launch_set_initial_tableau(&context, &mut tableau)?;
Ok(tableau)
}
pub fn num_qubits(&self) -> usize {
self.num_qubits
}
pub fn num_words(&self) -> usize {
self.num_words
}
pub(crate) fn xz_mut(&mut self) -> &mut GpuBuffer<u64> {
&mut self.xz
}
pub(crate) fn xz_phase_mut(&mut self) -> (&mut GpuBuffer<u64>, &mut GpuBuffer<u8>) {
(&mut self.xz, &mut self.phase)
}
pub(crate) fn xz_pivot_mut(&mut self) -> (&mut GpuBuffer<u64>, &mut GpuBuffer<i32>) {
(&mut self.xz, &mut self.measure_pivot)
}
pub(crate) fn xz_phase_outcome_mut(
&mut self,
) -> (&mut GpuBuffer<u64>, &mut GpuBuffer<u8>, &mut GpuBuffer<u8>) {
(&mut self.xz, &mut self.phase, &mut self.measure_outcome)
}
pub(crate) fn total_rows(&self) -> usize {
2 * self.num_qubits + 1
}
pub fn copy_to_host(&self) -> Result<(Vec<u64>, Vec<bool>)> {
let device = self.context.device();
let mut xz = vec![0u64; self.xz.len()];
self.xz.copy_to_host(device, &mut xz)?;
let mut phase_bytes = vec![0u8; self.phase.len()];
self.phase.copy_to_host(device, &mut phase_bytes)?;
let phase = phase_bytes.iter().map(|&b| b != 0).collect();
Ok((xz, phase))
}
pub fn copy_from_host(&mut self, xz: &[u64], phase: &[bool]) -> Result<()> {
let device = self.context.device();
self.xz.copy_from_host(device, xz)?;
let phase_bytes: Vec<u8> = phase.iter().map(|&b| u8::from(b)).collect();
self.phase.copy_from_host(device, &phase_bytes)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::PrismError;

    /// True when `err` is the error every device-touching call must yield on the stub.
    fn is_unsupported(err: &PrismError) -> bool {
        matches!(err, PrismError::BackendUnsupported { .. })
    }

    #[test]
    fn stub_context_reports_available_false() {
        let stub = GpuContext::stub_for_tests();
        let device = stub.device();
        assert!(device.is_stub());
    }

    #[test]
    fn state_new_on_stub_returns_unsupported() {
        let stub = GpuContext::stub_for_tests();
        let err = GpuState::new(stub, 4).unwrap_err();
        assert!(is_unsupported(&err));
    }

    #[test]
    fn tableau_new_on_stub_returns_unsupported() {
        let stub = GpuContext::stub_for_tests();
        let err = GpuTableau::new(stub, 4).unwrap_err();
        assert!(is_unsupported(&err));
    }

    #[test]
    fn min_qubits_default_when_env_unset() {
        let threshold = min_qubits();
        let plausible = 1..=32;
        assert!(
            plausible.contains(&threshold),
            "implausible gpu crossover threshold: {threshold}"
        );
    }

    #[test]
    fn stub_vram_available_rejects_cleanly() {
        let stub = GpuContext::stub_for_tests();
        let err = stub.vram_available().unwrap_err();
        assert!(is_unsupported(&err));
    }

    #[test]
    fn stub_fits_statevector_rejects_cleanly() {
        let stub = GpuContext::stub_for_tests();
        let err = stub.fits_statevector(4).unwrap_err();
        assert!(is_unsupported(&err));
    }

    #[test]
    fn fits_statevector_rejects_overflowing_qubit_counts() {
        // Must short-circuit to `Ok(false)` before the stub's VRAM query can fail.
        let stub = GpuContext::stub_for_tests();
        let fits = stub.fits_statevector(128).unwrap();
        assert!(!fits);
    }
}