pub(crate) mod bts;
pub(crate) mod dense;
pub(crate) mod stabilizer;
use cudarc::driver::{DeviceRepr, ValidAsZeroBits};
use crate::error::{PrismError, Result};
use crate::gpu::device::GpuDevice;
use crate::gpu::memory::GpuBuffer;
/// Wraps a failure of GPU operation `op` in a [`PrismError::BackendUnsupported`].
///
/// The error's `backend` is always `"gpu"`; its `operation` text is `op`
/// followed by `": "` and the `Display` rendering of `err`.
pub(super) fn launch_err(op: &str, err: impl std::fmt::Display) -> PrismError {
    let operation = format!("{op}: {err}");
    let backend = String::from("gpu");
    PrismError::BackendUnsupported { backend, operation }
}
/// Builds the error reported when a launch parameter does not fit a kernel-side limit.
///
/// * `op` - label of the operation being launched.
/// * `name` - name of the offending parameter.
/// * `value` - the value that was too large.
/// * `limit` - label of the limit that was exceeded (e.g. `"i32"`).
///
/// Returns a [`PrismError::BackendUnsupported`] for the `"gpu"` backend whose
/// `operation` text embeds all four pieces.
pub(super) fn launch_limit_err(op: &str, name: &str, value: usize, limit: &str) -> PrismError {
    let operation = format!("{op}: {name}={value} exceeds {limit} kernel limit");
    PrismError::BackendUnsupported {
        backend: "gpu".to_owned(),
        operation,
    }
}
/// Narrows `value` to `i32` for use as a kernel argument.
///
/// # Errors
///
/// Returns the error produced by `launch_limit_err(op, name, value, "i32")`
/// when `value` does not fit in an `i32`.
pub(super) fn require_i32(op: &str, name: &str, value: usize) -> Result<i32> {
    match i32::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(launch_limit_err(op, name, value, "i32")),
    }
}
/// Narrows `value` to `u32` for use as a kernel argument or launch dimension.
///
/// # Errors
///
/// Returns the error produced by `launch_limit_err(op, name, value, "u32")`
/// when `value` does not fit in a `u32`.
pub(super) fn require_u32(op: &str, name: &str, value: usize) -> Result<u32> {
    let Ok(narrowed) = u32::try_from(value) else {
        return Err(launch_limit_err(op, name, value, "u32"));
    };
    Ok(narrowed)
}
/// Computes how many blocks of size `block` are needed to cover `value` items.
///
/// `value` is first narrowed to `u32`, then divided by `block` rounding up.
/// The result is clamped to at least 1, so `value == 0` still yields one block.
///
/// # Errors
///
/// Propagates the error from `require_u32` when `value` does not fit in a `u32`.
///
/// # Panics
///
/// Panics if `block` is zero, because `u32::div_ceil` panics on a zero divisor.
pub(super) fn div_ceil_grid(op: &str, name: &str, value: usize, block: u32) -> Result<u32> {
    let work_items = require_u32(op, name, value)?;
    let blocks = work_items.div_ceil(block);
    Ok(blocks.max(1))
}
/// Set of optional, element-typed GPU buffers, one slot per field.
///
/// The derived `Default` leaves every slot as `None`.
///
/// NOTE(review): these slots have the `Option<GpuBuffer<T>>` shape that
/// `ensure_scratch` / `ensure_capacity` accept, so they are presumably
/// filled through those helpers — confirm against the callers.
#[derive(Default)]
pub(crate) struct LauncherScratch {
/// `f64` slot "a".
pub(crate) f64_a: Option<GpuBuffer<f64>>,
/// `i32` slot "a".
pub(crate) i32_a: Option<GpuBuffer<i32>>,
/// `i32` slot "b".
pub(crate) i32_b: Option<GpuBuffer<i32>>,
/// `i32` slot "c".
pub(crate) i32_c: Option<GpuBuffer<i32>>,
/// `u32` slot "a".
pub(crate) u32_a: Option<GpuBuffer<u32>>,
/// `f64` slot; NOTE(review): name suggests per-block partial results of a
/// measurement kernel — confirm.
pub(crate) measure_partials: Option<GpuBuffer<f64>>,
/// `f64` slot; NOTE(review): name suggests the final measurement result —
/// confirm.
pub(crate) measure_result: Option<GpuBuffer<f64>>,
}
/// Makes sure `slot` holds a buffer large enough for `host`, then uploads `host` into it.
///
/// The required length is `host.len()`, but never less than 1. When `slot`
/// is empty or its buffer is shorter than that, a new buffer is obtained from
/// `GpuBuffer::alloc_zeros` and replaces the old one; otherwise the existing
/// buffer is reused. A non-empty `host` is then written with `copy_from_host`;
/// an empty `host` skips the copy.
///
/// NOTE(review): a reused buffer may be longer than `host`; this assumes
/// `copy_from_host` accepts a shorter host slice — confirm.
///
/// # Errors
///
/// Propagates failures from `GpuBuffer::alloc_zeros` and `copy_from_host`.
/// If allocation fails, `slot` is left as it was.
pub(crate) fn ensure_scratch<'a, T: DeviceRepr + ValidAsZeroBits>(
    slot: &'a mut Option<GpuBuffer<T>>,
    device: &GpuDevice,
    host: &[T],
) -> Result<&'a GpuBuffer<T>> {
    let min_len = host.len().max(1);
    let reusable = slot.as_ref().is_some_and(|buf| buf.len() >= min_len);

    let buf = if reusable {
        // `reusable` is only true when the slot is populated.
        slot.as_mut().unwrap()
    } else {
        // Allocate before touching the slot so a failure leaves it unchanged.
        let fresh = GpuBuffer::<T>::alloc_zeros(device, min_len)?;
        slot.insert(fresh)
    };

    if !host.is_empty() {
        buf.copy_from_host(device, host)?;
    }
    Ok(&*buf)
}
/// Makes sure `slot` holds a buffer of at least `needed` elements (minimum 1).
///
/// An existing buffer that is already long enough is returned untouched; this
/// function does not write to it. Otherwise a new buffer is obtained from
/// `GpuBuffer::alloc_zeros` and replaces whatever the slot held.
///
/// # Errors
///
/// Propagates failures from `GpuBuffer::alloc_zeros`. If allocation fails,
/// `slot` is left as it was.
pub(crate) fn ensure_capacity<'a, T: DeviceRepr + ValidAsZeroBits>(
    slot: &'a mut Option<GpuBuffer<T>>,
    device: &GpuDevice,
    needed: usize,
) -> Result<&'a mut GpuBuffer<T>> {
    let min_len = needed.max(1);
    let fits = matches!(slot.as_ref(), Some(buf) if buf.len() >= min_len);
    if fits {
        // `fits` is only true when the slot is populated.
        return Ok(slot.as_mut().unwrap());
    }

    // Allocate first; the old buffer (if any) is only dropped on success.
    let fresh = GpuBuffer::<T>::alloc_zeros(device, min_len)?;
    Ok(slot.insert(fresh))
}
/// Returns the combined kernel source of the three submodules.
///
/// The text is the output of `dense::kernel_source()`,
/// `stabilizer::kernel_source()` and `bts::kernel_source()`, in that order,
/// joined by a single newline between consecutive parts.
pub(crate) fn kernel_source() -> String {
    format!(
        "{}\n{}\n{}",
        dense::kernel_source(),
        stabilizer::kernel_source(),
        bts::kernel_source(),
    )
}
/// Names of the kernel entry points expected in the text returned by
/// `kernel_source()`.
///
/// The `sync_tests` module checks that this list and the `__global__ void`
/// definitions found in that source contain exactly the same names, so a
/// kernel added to or removed from the source must be mirrored here.
///
/// NOTE(review): per that test's failure message, `GpuDevice::new`
/// pre-resolves these names; `GpuDevice::new` is not visible here — confirm.
pub(crate) const KERNEL_NAMES: &[&str] = &[
// Unprefixed names: presumably the kernels defined in `dense` — confirm.
"set_initial_state",
"apply_gate_1q",
"apply_diagonal_1q",
"apply_cx",
"apply_cz",
"apply_swap",
"apply_parity_phase",
"apply_cu",
"apply_cu_phase",
"apply_mcu",
"apply_mcu_phase",
"apply_fused_2q",
"measure_prob_one",
"measure_prob_one_finalize",
"measure_collapse",
"compute_probabilities",
"scale_state",
"apply_multi_fused_diagonal",
"apply_batch_phase",
"apply_batch_rzz",
"apply_diagonal_batch",
"apply_multi_fused_tiled",
// `stab_`-prefixed names: presumably the kernels defined in `stabilizer` — confirm.
"stab_set_initial_tableau",
"stab_apply_word_grouped",
"stab_rowmul_words",
"stab_measure_find_pivot",
"stab_measure_cascade",
"stab_measure_fixup",
"stab_measure_deterministic",
// `bts_`-prefixed names: presumably the kernels defined in `bts` — confirm.
"bts_sample_meas_major",
"bts_popcount_rows",
"bts_count_meas_major_upto8",
"bts_count_shot_major_upto8",
"bts_count_used_slots",
"bts_compact_counts_upto8",
"bts_transpose_meas_to_shot",
"bts_apply_noise_masks_meas_major",
"bts_generate_and_apply_noise_meas_major_by_row",
];
#[cfg(test)]
mod sync_tests {
    use super::{kernel_source, KERNEL_NAMES};
    use std::collections::BTreeSet;

    /// Collects the identifier that follows every `__global__ void ` occurrence in `src`.
    ///
    /// The identifier is the longest run of alphanumeric or `_` characters
    /// directly after the marker; occurrences followed by no such character
    /// are ignored.
    fn source_entry_points(src: &str) -> BTreeSet<String> {
        const MARKER: &str = "__global__ void ";
        src.match_indices(MARKER)
            .map(|(start, hit)| {
                src[start + hit.len()..]
                    .chars()
                    .take_while(|ch| *ch == '_' || ch.is_alphanumeric())
                    .collect::<String>()
            })
            .filter(|ident| !ident.is_empty())
            .collect()
    }

    /// `KERNEL_NAMES` and the kernels defined in `kernel_source()` must be the same set.
    #[test]
    fn kernel_names_match_source_entry_points() {
        let in_source = source_entry_points(&kernel_source());
        let declared: BTreeSet<String> = KERNEL_NAMES
            .iter()
            .map(|name| String::from(*name))
            .collect();

        // Kernels present in the source but not listed.
        let missing: Vec<&String> = in_source.difference(&declared).collect();
        assert!(
            missing.is_empty(),
            "kernels defined in source but absent from KERNEL_NAMES \
             (GpuDevice::new would not pre-resolve them): {missing:?}"
        );

        // Listed names with no definition in the source.
        let orphan: Vec<&String> = declared.difference(&in_source).collect();
        assert!(
            orphan.is_empty(),
            "KERNEL_NAMES entries with no matching kernel in source: {orphan:?}"
        );
    }
}