pub mod device;
pub(crate) mod kernels;
pub mod memory;
use std::sync::Arc;
use num_complex::Complex64;
use crate::error::Result;
pub use self::device::GpuDevice;
pub use self::memory::GpuBuffer;
/// Shared handle to a single GPU device.
///
/// Handed out behind an [`Arc`] so that every [`GpuState`] allocated on the
/// device can hold on to it.
#[derive(Debug)]
pub struct GpuContext {
    device: Arc<GpuDevice>,
}

impl GpuContext {
    /// Opens the device with index `device_id` and wraps it in a shared context.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`GpuDevice::new`].
    pub fn new(device_id: usize) -> Result<Arc<Self>> {
        let device = GpuDevice::new(device_id).map(Arc::new)?;
        Ok(Arc::new(GpuContext { device }))
    }

    /// Reports whether a GPU backend can be used; delegates to
    /// [`GpuDevice::is_available`].
    pub fn is_available() -> bool {
        GpuDevice::is_available()
    }

    /// Delegates to [`GpuDevice::vram_bytes`] for this context's device.
    ///
    /// # Errors
    ///
    /// Propagates any error from the device query.
    pub fn vram_bytes(&self) -> Result<usize> {
        self.device.vram_bytes()
    }

    /// Delegates to [`GpuDevice::max_qubits_for_statevector`] for this context's device.
    ///
    /// # Errors
    ///
    /// Propagates any error from the device query.
    pub fn max_qubits_for_statevector(&self) -> Result<usize> {
        self.device.max_qubits_for_statevector()
    }

    /// Borrows the underlying device (crate-internal; used by buffers and kernels).
    pub(crate) fn device(&self) -> &GpuDevice {
        &*self.device
    }

    /// Test-only constructor backed by [`GpuDevice::stub_for_tests`].
    #[cfg(test)]
    pub(crate) fn stub_for_tests() -> Arc<Self> {
        let device = Arc::new(GpuDevice::stub_for_tests());
        Arc::new(Self { device })
    }
}
#[derive(Debug)]
pub struct GpuState {
context: Arc<GpuContext>,
buffer: GpuBuffer<f64>,
num_qubits: usize,
pending_norm: f64,
probs_scratch: std::cell::RefCell<Option<GpuBuffer<f64>>>,
}
impl GpuState {
pub fn new(context: Arc<GpuContext>, num_qubits: usize) -> Result<Self> {
let len = 2usize.checked_shl(num_qubits as u32).ok_or_else(|| {
crate::error::PrismError::InvalidParameter {
message: format!("num_qubits={num_qubits} overflows addressable memory"),
}
})?;
let buffer = GpuBuffer::<f64>::alloc_zeros(context.device(), len)?;
let mut state = Self {
context: context.clone(),
buffer,
num_qubits,
pending_norm: 1.0,
probs_scratch: std::cell::RefCell::new(None),
};
kernels::dense::launch_set_initial_state(&context, &mut state)?;
Ok(state)
}
pub fn num_qubits(&self) -> usize {
self.num_qubits
}
pub fn pending_norm(&self) -> f64 {
self.pending_norm
}
pub fn copy_to_host_raw(&self) -> Result<Vec<f64>> {
let mut host = vec![0.0_f64; self.buffer.len()];
self.buffer.copy_to_host(self.context.device(), &mut host)?;
Ok(host)
}
pub fn export_statevector(&self) -> Result<Vec<Complex64>> {
let raw = self.copy_to_host_raw()?;
let norm = self.pending_norm;
let out = raw
.chunks_exact(2)
.map(|p| Complex64::new(p[0] * norm, p[1] * norm))
.collect();
Ok(out)
}
pub fn probabilities(&self) -> Result<Vec<f64>> {
kernels::dense::launch_compute_probabilities(&self.context, self)
}
pub(crate) fn context(&self) -> &Arc<GpuContext> {
&self.context
}
pub(crate) fn buffer(&self) -> &GpuBuffer<f64> {
&self.buffer
}
pub(crate) fn buffer_mut(&mut self) -> &mut GpuBuffer<f64> {
&mut self.buffer
}
pub(crate) fn set_pending_norm(&mut self, norm: f64) {
self.pending_norm = norm;
}
pub(crate) fn probs_scratch(&self) -> std::cell::RefMut<'_, Option<GpuBuffer<f64>>> {
self.probs_scratch.borrow_mut()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::PrismError;

    /// The test-only constructor must yield a context backed by the stub device.
    #[test]
    fn stub_context_reports_available_false() {
        let stub = GpuContext::stub_for_tests();
        let device = stub.device();
        assert!(device.is_stub());
    }

    /// Allocating a state on the stub device must fail with `BackendUnsupported`.
    #[test]
    fn state_new_on_stub_returns_unsupported() {
        let stub = GpuContext::stub_for_tests();
        let err = GpuState::new(stub, 4).unwrap_err();
        assert!(matches!(err, PrismError::BackendUnsupported { .. }));
    }
}