#[cfg(test)]
mod tests;
use std::borrow::Cow;
use std::sync::Arc;
use num_complex::Complex64;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::backend::simd;
use crate::backend::statevector::StatevectorBackend;
use crate::backend::{dense_probability_len, dense_statevector_len, measurement_inv_norm, Backend};
use crate::circuit::{smallvec, Instruction, SmallVec};
use crate::distributed::DistributedContext;
use crate::error::{PrismError, Result};
use crate::gates::{DiagEntry, Gate};
/// Identifier returned by `Backend::name` and embedded in this backend's
/// `PrismError::BackendUnsupported` values.
const BACKEND_NAME: &str = "distributed_statevector";
/// Returns `true` when both off-diagonal entries of `mat` have magnitude below `1e-12`.
#[inline]
fn is_diagonal_2x2(mat: &[[Complex64; 2]; 2]) -> bool {
    const TOLERANCE: f64 = 1e-12;
    let [[_, upper_right], [lower_left, _]] = mat;
    upper_right.norm() < TOLERANCE && lower_left.norm() < TOLERANCE
}
/// Invokes `f` once for every qubit index the instruction refers to.
///
/// The explicit `targets` are visited first, in order. Batched gate variants
/// (`BatchPhase`, `BatchRzz`, `MultiFused`, `Multi2q`, `DiagonalBatch`) carry further
/// qubit indices inside their payload; those are visited afterwards in payload order.
/// Duplicates are not filtered: a qubit that appears several times is reported several times.
fn for_each_gate_qubit(gate: &Gate, targets: &[usize], mut f: impl FnMut(usize)) {
    targets.iter().copied().for_each(&mut f);
    match gate {
        Gate::BatchPhase(data) => {
            for entry in data.phases.iter() {
                f(entry.0);
            }
        }
        Gate::BatchRzz(data) => {
            for edge in data.edges.iter() {
                f(edge.0);
                f(edge.1);
            }
        }
        Gate::MultiFused(data) => {
            for entry in data.gates.iter() {
                f(entry.0);
            }
        }
        Gate::Multi2q(data) => {
            for entry in data.gates.iter() {
                f(entry.0);
                f(entry.1);
            }
        }
        Gate::DiagonalBatch(data) => {
            for entry in data.entries.iter() {
                match entry {
                    DiagEntry::Phase1q { qubit, .. } => f(*qubit),
                    DiagEntry::Phase2q { q0, q1, .. } | DiagEntry::Parity2q { q0, q1, .. } => {
                        f(*q0);
                        f(*q1);
                    }
                }
            }
        }
        _ => {}
    }
}
/// Lists, without duplicates and in first-seen order, the qubits that `apply_gate` tries to
/// bring into the local range (via `make_local`) before executing `gate`.
///
/// Gates that the distributed kernels treat purely as phases (`Cz`, `Swap`, `Rzz`, the
/// phase/parity batches, controlled gates for which `controlled_phase()` is `Some`, and
/// diagonal single-qubit matrices) request nothing. Non-diagonal work requests the qubit
/// the matrix acts on: the target of `Cx`/`Cu`/`Mcu`, both qubits of every fused two-qubit
/// matrix, and every non-diagonal single-qubit matrix.
fn required_local_qubits(gate: &Gate, targets: &[usize]) -> SmallVec<[usize; 8]> {
    let mut req: SmallVec<[usize; 8]> = SmallVec::new();
    // Order-preserving "insert if absent".
    let mut need = |q: usize| {
        if !req.contains(&q) {
            req.push(q);
        }
    };
    match gate {
        Gate::Cx => need(targets[1]),
        Gate::Cz
        | Gate::Swap
        | Gate::Rzz(_)
        | Gate::BatchPhase(_)
        | Gate::BatchRzz(_)
        | Gate::DiagonalBatch(_) => {}
        Gate::Cu(_) => {
            if gate.controlled_phase().is_none() {
                need(targets[1]);
            }
        }
        Gate::Mcu(data) => {
            if gate.controlled_phase().is_none() {
                // The target follows the control qubits in `targets`.
                need(targets[data.num_controls as usize]);
            }
        }
        Gate::Fused2q(_) => {
            need(targets[0]);
            need(targets[1]);
        }
        Gate::Multi2q(data) => {
            for entry in data.gates.iter() {
                need(entry.0);
                need(entry.1);
            }
        }
        Gate::MultiFused(data) => {
            for (q, mat) in data.gates.iter() {
                if is_diagonal_2x2(mat) {
                    continue;
                }
                need(*q);
            }
        }
        single if single.num_qubits() == 1 => {
            if !single.is_diagonal_1q() {
                need(targets[0]);
            }
        }
        _ => {}
    }
    req
}
/// Statevector backend whose amplitudes are split across the ranks of a
/// `DistributedContext`.
///
/// Each rank keeps a slice of the state inside `inner`. The lowest `num_qubits -
/// global_qubits` physical positions index into that local slice; the remaining positions are
/// encoded in the bits of the rank number.
pub struct DistributedStatevectorBackend {
/// Provides `rank()`, `size()` and the communicator used for every exchange.
context: Arc<DistributedContext>,
/// Local backend holding this rank's amplitudes, classical bits and pending norm.
inner: StatevectorBackend,
/// Total number of circuit qubits, as passed to `init`.
num_qubits: usize,
/// Number of qubits encoded in the rank index (`size.trailing_zeros()` in `init`).
global_qubits: usize,
/// Receive buffer for `sendrecv_c64`; resized on demand by each exchange routine.
recv: Vec<Complex64>,
/// Seed given to `new`; `init` re-seeds `meas_rng` from it.
seed: u64,
/// Upper bound on amplitudes per message in the chunked exchange routines.
exchange_chunk: usize,
/// Running count of point-to-point exchanges since the last `init`.
exchange_messages: u64,
/// Running count of amplitudes sent in those exchanges.
exchange_amplitudes: u64,
/// RNG that draws measurement outcomes in `measure_dist`.
meas_rng: ChaCha8Rng,
/// Circuit qubit -> physical position.
qubit_map: Vec<usize>,
/// Physical position -> circuit qubit (inverse of `qubit_map`).
phys_map: Vec<usize>,
/// Cached flag: `true` while `qubit_map` is the identity permutation.
map_identity: bool,
/// When set, `apply_gate` relabels qubits (and turns SWAP into a pure relabel).
relabel: bool,
/// Per circuit qubit, the value of `tick` at its most recent use.
last_used: Vec<u64>,
/// Counter incremented once per instruction by `touch_instruction`.
tick: u64,
/// Send-side staging buffer used by `relabel_swap`.
pack: Vec<Complex64>,
}
impl DistributedStatevectorBackend {
/// Creates a backend bound to `context`.
///
/// `seed` seeds both the inner statevector backend and the measurement RNG. The exchange
/// chunk size and the relabel flag come from `crate::distributed::exchange_chunk()` and
/// `crate::distributed::relabel_enabled()`. All sizes start at zero and the qubit maps are
/// empty until `init` runs.
pub fn new(context: Arc<DistributedContext>, seed: u64) -> Self {
    let inner = StatevectorBackend::new(seed);
    let exchange_chunk = crate::distributed::exchange_chunk();
    let meas_rng = ChaCha8Rng::seed_from_u64(seed);
    let relabel = crate::distributed::relabel_enabled();
    Self {
        context,
        inner,
        seed,
        meas_rng,
        exchange_chunk,
        relabel,
        num_qubits: 0,
        global_qubits: 0,
        exchange_messages: 0,
        exchange_amplitudes: 0,
        tick: 0,
        map_identity: true,
        qubit_map: Vec::new(),
        phys_map: Vec::new(),
        last_used: Vec::new(),
        recv: Vec::new(),
        pack: Vec::new(),
    }
}
/// Test-only override of the exchange chunk size; a request of `0` is stored as `1`.
#[cfg(test)]
pub(crate) fn set_exchange_chunk(&mut self, chunk: usize) {
    self.exchange_chunk = if chunk == 0 { 1 } else { chunk };
}
/// Enables or disables qubit relabelling in `apply_gate`. Any permutation already recorded
/// in the qubit maps is left in place.
pub fn set_relabel(&mut self, enabled: bool) {
self.relabel = enabled;
}
/// Number of point-to-point exchanges this rank has counted since the last `init`.
pub fn exchange_messages(&self) -> u64 {
self.exchange_messages
}
/// Total number of amplitudes this rank has counted as sent since the last `init`.
pub fn exchange_amplitudes(&self) -> u64 {
self.exchange_amplitudes
}
/// Records one exchange of `amplitudes` amplitudes in the two statistics counters.
#[inline]
fn count_exchange(&mut self, amplitudes: usize) {
self.exchange_messages += 1;
self.exchange_amplitudes += amplitudes as u64;
}
/// Number of qubits stored in each rank's local slice (`num_qubits - global_qubits`).
#[inline]
fn local_qubits(&self) -> usize {
self.num_qubits - self.global_qubits
}
/// `true` when the context reports exactly one rank.
#[inline]
fn is_single_rank(&self) -> bool {
self.context.size() == 1
}
/// Converts a global physical position into the index of the corresponding rank bit.
///
/// `q` must be at least `local_qubits()`; the unsigned subtraction underflows otherwise.
#[inline]
fn global_bit(&self, q: usize) -> usize {
q - self.local_qubits()
}
/// Reports whether this rank's index has the bit for global physical position `q` set,
/// i.e. whether this rank holds the half of the state in which that qubit is `1`.
#[inline]
fn rank_bit_set(&self, q: usize) -> bool {
(self.context.rank() >> self.global_bit(q)) & 1 == 1
}
/// Advances the instruction counter and stamps every circuit qubit used by the instruction
/// with the new value, so `pick_victim` can tell which qubits are in use right now.
fn touch_instruction(&mut self, gate: &Gate, targets: &[usize]) {
    self.tick += 1;
    let stamp = self.tick;
    let last_used = &mut self.last_used;
    for_each_gate_qubit(gate, targets, |q| last_used[q] = stamp);
}
/// Recomputes the cached `map_identity` flag: `true` exactly when every circuit qubit sits
/// at the physical position with the same index.
fn refresh_map_identity(&mut self) {
    let positions = self.qubit_map.len();
    self.map_identity = self.qubit_map.iter().copied().eq(0..positions);
}
/// Implements a circuit-level SWAP of qubits `a` and `b` purely by relabelling: the two
/// qubits trade physical positions in both maps and no amplitude is moved.
fn swap_circuit_qubits(&mut self, a: usize, b: usize) {
    if a != b {
        let (pos_a, pos_b) = (self.qubit_map[a], self.qubit_map[b]);
        self.qubit_map.swap(a, b);
        self.phys_map.swap(pos_a, pos_b);
        self.refresh_map_identity();
    }
}
/// Chooses the local physical position to evict when a global qubit must become local.
///
/// Candidates are the local positions whose circuit qubit was not stamped by the current
/// instruction (`last_used != tick`). The one with the smallest stamp wins; ties go to the
/// lowest position. Returns `None` when every local qubit is in use by the current
/// instruction.
fn pick_victim(&self) -> Option<usize> {
    (0..self.local_qubits())
        .map(|pos| (self.last_used[self.phys_map[pos]], pos))
        .filter(|&(stamp, _)| stamp != self.tick)
        // `min_by_key` returns the first of several equal minima.
        .min_by_key(|&(stamp, _)| stamp)
        .map(|(_, pos)| pos)
}
/// Tries to move every circuit qubit in `req` into the local range.
///
/// Qubits that are already local are skipped. For each remaining one a victim position is
/// chosen with `pick_victim` and exchanged via `relabel_swap`. If no victim is available
/// the function stops early and leaves the remaining qubits where they are.
fn make_local(&mut self, req: &[usize]) {
    for &q in req {
        let pos = self.qubit_map[q];
        if pos >= self.local_qubits() {
            match self.pick_victim() {
                Some(victim) => self.relabel_swap(victim, pos),
                None => return,
            }
        }
    }
}
/// Physically exchanges local position `local_pos` with global position `global_pos` and
/// updates both qubit maps.
///
/// Only half of each rank's amplitudes move: a rank whose `global_pos` bit is 0 trades the
/// amplitudes whose `local_pos` bit is 1, and a rank whose bit is 1 trades those whose
/// `local_pos` bit is 0. The partner is the rank that differs in exactly that one rank bit.
/// The moving amplitudes are gathered into `pack`, exchanged in pieces of at most
/// `exchange_chunk`, and the received values are scattered back into the same slots.
fn relabel_swap(&mut self, local_pos: usize, global_pos: usize) {
    let partner = self.context.rank() ^ (1usize << self.global_bit(global_pos));
    let gbit = self.rank_bit_set(global_pos);
    let stride = 1usize << local_pos;
    // Value of the `local_pos` bit in every index this rank sends.
    let fixed = if gbit { 0 } else { stride };
    let moving = self.inner.state.len() / 2;
    let chunk = self.exchange_chunk.min(moving).max(1);
    if self.pack.len() != chunk {
        self.pack.resize(chunk, Complex64::new(0.0, 0.0));
    }
    if self.recv.len() != chunk {
        self.recv.resize(chunk, Complex64::new(0.0, 0.0));
    }
    // Maps the k-th moving amplitude to its state index by inserting the fixed bit at
    // `local_pos` into the flat counter.
    let index_of =
        |flat: usize| ((flat >> local_pos) << (local_pos + 1)) | fixed | (flat & (stride - 1));
    let mut off = 0;
    while off < moving {
        let count = (off + chunk).min(moving) - off;
        for (k, slot) in self.pack[..count].iter_mut().enumerate() {
            *slot = self.inner.state[index_of(off + k)];
        }
        self.count_exchange(count);
        self.context
            .comm()
            .sendrecv_c64(partner, &self.pack[..count], &mut self.recv[..count]);
        // The pattern previously read `(k, &)`, which has no binding name and does not
        // parse; bind the received amplitude as `amp`.
        for (k, &amp) in self.recv[..count].iter().enumerate() {
            self.inner.state[index_of(off + k)] = amp;
        }
        off += count;
    }
    let local_q = self.phys_map[local_pos];
    let global_q = self.phys_map[global_pos];
    self.qubit_map[local_q] = global_pos;
    self.qubit_map[global_q] = local_pos;
    self.phys_map.swap(local_pos, global_pos);
    self.refresh_map_identity();
}
/// Exchanges the contents of physical positions `a` and `b` (with `a < b`) and updates both
/// qubit maps to match.
///
/// Three cases, chosen by where the positions fall relative to the local-qubit count:
/// both local (a SWAP gate on the inner backend), `a` local and `b` global (delegated to
/// `relabel_swap`, which also updates the maps), or both global (a rank whose two rank bits
/// differ trades its entire local slice with the rank that has both bits flipped).
fn swap_physical_positions(&mut self, a: usize, b: usize) {
debug_assert!(
a < b,
"positions must be ordered: branch selection assumes a < b"
);
let local = self.local_qubits();
if b < local {
// Both positions are local.
self.inner
.apply(&Instruction::Gate {
gate: Gate::Swap,
targets: smallvec![a, b],
})
.expect("local SWAP cannot fail");
} else if a < local {
// Mixed case: `relabel_swap` moves the data and rewrites the maps itself.
self.relabel_swap(a, b);
return;
} else {
// Both positions are global. Ranks whose two bits are equal keep their data.
let ga = self.global_bit(a);
let gb = self.global_bit(b);
let rank = self.context.rank();
if ((rank >> ga) ^ (rank >> gb)) & 1 == 1 {
let partner = rank ^ ((1usize << ga) | (1usize << gb));
let len = self.inner.state.len();
let chunk = self.exchange_chunk.min(len).max(1);
if self.recv.len() != chunk {
self.recv.resize(chunk, Complex64::new(0.0, 0.0));
}
// Replace the local slice with the partner's, one chunk at a time.
let mut off = 0;
while off < len {
let end = (off + chunk).min(len);
self.count_exchange(end - off);
let recv = &mut self.recv[..end - off];
self.context
.comm()
.sendrecv_c64(partner, &self.inner.state[off..end], recv);
self.inner.state[off..end].copy_from_slice(recv);
off = end;
}
}
}
// Record that the two circuit qubits have traded positions.
let qa = self.phys_map[a];
let qb = self.phys_map[b];
self.qubit_map[qa] = b;
self.qubit_map[qb] = a;
self.phys_map.swap(a, b);
self.refresh_map_identity();
}
/// Physically undoes the accumulated relabelling until `qubit_map` is the identity.
///
/// Each round finds the lowest physical position holding the wrong circuit qubit and swaps
/// it with the position where the right qubit currently lives.
fn restore_identity_map(&mut self) {
    while !self.map_identity {
        let misplaced = (0..self.num_qubits).find(|&p| self.phys_map[p] != p);
        match misplaced {
            Some(pos) => {
                let src = self.qubit_map[pos];
                self.swap_physical_positions(pos, src);
            }
            None => break,
        }
    }
}
/// Translates an instruction from circuit qubit indices to physical positions.
///
/// With an identity map the gate is borrowed unchanged. Otherwise `targets` are mapped
/// through `qubit_map`, and gate variants that carry qubit indices in their payload
/// (`MultiFused`, `Multi2q`, `BatchPhase`, `BatchRzz`, `DiagonalBatch`) are cloned with those
/// indices rewritten. All other gates are still borrowed.
fn to_physical<'g>(
    &self,
    gate: &'g Gate,
    targets: &[usize],
) -> (Cow<'g, Gate>, SmallVec<[usize; 4]>) {
    if self.map_identity {
        return (Cow::Borrowed(gate), targets.into());
    }
    let map = &self.qubit_map;
    let ptargets: SmallVec<[usize; 4]> = targets.iter().map(|&q| map[q]).collect();
    let pgate = match gate {
        Gate::MultiFused(data) => {
            let mut owned = data.clone();
            owned
                .gates
                .iter_mut()
                .for_each(|entry| entry.0 = map[entry.0]);
            Cow::Owned(Gate::MultiFused(owned))
        }
        Gate::Multi2q(data) => {
            let mut owned = data.clone();
            owned.gates.iter_mut().for_each(|entry| {
                entry.0 = map[entry.0];
                entry.1 = map[entry.1];
            });
            Cow::Owned(Gate::Multi2q(owned))
        }
        Gate::BatchPhase(data) => {
            let mut owned = data.clone();
            owned
                .phases
                .iter_mut()
                .for_each(|entry| entry.0 = map[entry.0]);
            Cow::Owned(Gate::BatchPhase(owned))
        }
        Gate::BatchRzz(data) => {
            let mut owned = data.clone();
            owned.edges.iter_mut().for_each(|entry| {
                entry.0 = map[entry.0];
                entry.1 = map[entry.1];
            });
            Cow::Owned(Gate::BatchRzz(owned))
        }
        Gate::DiagonalBatch(data) => {
            let mut owned = data.clone();
            for entry in owned.entries.iter_mut() {
                match entry {
                    DiagEntry::Phase1q { qubit, .. } => *qubit = map[*qubit],
                    DiagEntry::Phase2q { q0, q1, .. } | DiagEntry::Parity2q { q0, q1, .. } => {
                        *q0 = map[*q0];
                        *q1 = map[*q1];
                    }
                }
            }
            Cow::Owned(Gate::DiagonalBatch(owned))
        }
        _ => Cow::Borrowed(gate),
    };
    (pgate, ptargets)
}
/// Returns `true` when every physical position used by the instruction is below the
/// local-qubit count, so the inner backend can execute it without communication.
fn instruction_qubits_local(&self, gate: &Gate, targets: &[usize]) -> bool {
    let limit = self.local_qubits();
    let mut touches_global = false;
    for_each_gate_qubit(gate, targets, |q| {
        if q >= limit {
            touches_global = true;
        }
    });
    !touches_global
}
/// Reorders a dense vector indexed by physical basis state into one indexed by circuit
/// basis state.
///
/// For each circuit index `c`, bit `q` of `c` is placed at bit `qubit_map[q]` of the
/// physical index that is read. With an identity map the input is returned untouched.
fn unpermuted<T: Copy + Default>(&self, phys: Vec<T>) -> Vec<T> {
    if self.map_identity {
        return phys;
    }
    (0..phys.len())
        .map(|circuit_index| {
            let physical_index = self
                .qubit_map
                .iter()
                .enumerate()
                .fold(0usize, |acc, (q, &pos)| {
                    acc | (((circuit_index >> q) & 1) << pos)
                });
            phys[physical_index]
        })
        .collect()
}
/// Applies the single-qubit matrix `mat` to global physical position `target`.
///
/// This rank and the rank differing only in `target`'s bit exchange their slices in pieces
/// of at most `exchange_chunk`. Each piece is passed to `simd::combine_global_half` together
/// with the two entries from the matrix row selected by this rank's bit: the diagonal entry
/// for the local amplitudes and the off-diagonal entry for the received ones.
// NOTE(review): `combine_global_half` presumably computes
// `state = c_self * state + c_remote * recv`, mirroring the explicit loop in
// `apply_global_controlled_1q` — confirm against `backend::simd`.
fn apply_global_1q(&mut self, target: usize, mat: [[Complex64; 2]; 2]) {
    let partner = self.context.rank() ^ (1usize << self.global_bit(target));
    let row = usize::from(self.rank_bit_set(target));
    let (c_self, c_remote) = (mat[row][row], mat[row][1 - row]);
    let len = self.inner.state.len();
    let chunk = self.exchange_chunk.min(len).max(1);
    if self.recv.len() != chunk {
        self.recv.resize(chunk, Complex64::new(0.0, 0.0));
    }
    for start in (0..len).step_by(chunk) {
        let stop = len.min(start + chunk);
        self.count_exchange(stop - start);
        let incoming = &mut self.recv[..stop - start];
        self.context
            .comm()
            .sendrecv_c64(partner, &self.inner.state[start..stop], incoming);
        simd::combine_global_half(
            &mut self.inner.state[start..stop],
            incoming,
            c_self,
            c_remote,
        );
    }
}
/// Applies `diag(d0, d1)` to global physical position `target`. No communication is needed:
/// the whole local slice is scaled by `d1` when this rank's bit is set and by `d0` otherwise.
fn apply_global_diagonal_1q(&mut self, target: usize, d0: Complex64, d1: Complex64) {
    let factor = match self.rank_bit_set(target) {
        true => d1,
        false => d0,
    };
    simd::scale_complex_slice(&mut self.inner.state, factor);
}
/// Applies `mat` to local position `target`, conditioned on the given local controls, by
/// delegating to the inner backend.
///
/// With no controls the matrix is applied directly; with one control it is wrapped in
/// `Gate::cu`; with more in `Gate::mcu`. The instruction's targets are the controls followed
/// by `target`.
fn apply_local_controlled_1q(
    &mut self,
    local_controls: &[usize],
    target: usize,
    mat: [[Complex64; 2]; 2],
) {
    let control_count = local_controls.len();
    if control_count == 0 {
        self.inner
            .apply_1q_matrix(target, &mat)
            .expect("local 1q matrix");
        return;
    }
    let gate = if control_count == 1 {
        Gate::cu(mat)
    } else {
        Gate::mcu(mat, control_count as u8)
    };
    let targets: SmallVec<[usize; 4]> = local_controls
        .iter()
        .copied()
        .chain(std::iter::once(target))
        .collect();
    self.inner
        .apply(&Instruction::Gate { gate, targets })
        .expect("local controlled 1q");
}
/// Applies `mat` to global position `target`, conditioned on the given local controls.
///
/// The full local slice is exchanged in one message with the rank that differs in
/// `target`'s bit. Amplitudes whose index has every control bit set are then replaced by
/// `c_self * own + c_remote * received`, where the two coefficients are the diagonal and
/// off-diagonal entries of the matrix row selected by this rank's bit. With no controls the
/// combination is delegated to `simd::combine_global_half`.
fn apply_global_controlled_1q(
    &mut self,
    local_controls: &[usize],
    target: usize,
    mat: [[Complex64; 2]; 2],
) {
    let partner = self.context.rank() ^ (1usize << self.global_bit(target));
    let len = self.inner.state.len();
    if self.recv.len() != len {
        self.recv.resize(len, Complex64::new(0.0, 0.0));
    }
    self.count_exchange(len);
    self.context
        .comm()
        .sendrecv_c64(partner, &self.inner.state, &mut self.recv);
    let row = usize::from(self.rank_bit_set(target));
    let (c_self, c_remote) = (mat[row][row], mat[row][1 - row]);
    if local_controls.is_empty() {
        simd::combine_global_half(&mut self.inner.state, &self.recv, c_self, c_remote);
        return;
    }
    let ctrl_mask: usize = local_controls.iter().map(|&c| 1usize << c).sum();
    for (index, (own, &remote)) in self
        .inner
        .state
        .iter_mut()
        .zip(self.recv.iter())
        .enumerate()
    {
        if index & ctrl_mask != ctrl_mask {
            continue;
        }
        *own = c_self * *own + c_remote * remote;
    }
}
/// Applies a (multi-)controlled single-qubit matrix where controls and target may each be
/// local or global.
///
/// A global control is resolved from the rank index: if its bit is clear on this rank the
/// gate does nothing here and the function returns immediately. Local controls are
/// forwarded to the local or global kernel, chosen by where `target` lives.
fn apply_controlled_dist(
    &mut self,
    controls: &[usize],
    target: usize,
    mat: [[Complex64; 2]; 2],
) {
    let local = self.local_qubits();
    let mut local_controls = SmallVec::<[usize; 4]>::new();
    for &control in controls {
        if control >= local {
            if self.rank_bit_set(control) {
                continue;
            }
            return;
        }
        local_controls.push(control);
    }
    match target < local {
        true => self.apply_local_controlled_1q(&local_controls, target, mat),
        false => self.apply_global_controlled_1q(&local_controls, target, mat),
    }
}
/// Multiplies by `phase` every basis state in which all of `qubits` are `1`.
///
/// Global qubits are resolved from the rank index: if any of them is `0` on this rank
/// nothing is done. The remaining local qubits are handed to `apply_local_corner_phase`.
fn apply_controlled_phase_dist(&mut self, qubits: &[usize], phase: Complex64) {
    let local = self.local_qubits();
    let mut local_qubits = SmallVec::<[usize; 8]>::new();
    for &q in qubits {
        if q >= local {
            if self.rank_bit_set(q) {
                continue;
            }
            return;
        }
        local_qubits.push(q);
    }
    self.apply_local_corner_phase(&local_qubits, phase);
}
/// Applies `phase` to the local amplitudes in which every qubit of `local_qubits` is `1`.
///
/// * no qubits — the whole slice is scaled by `phase`;
/// * one qubit — `diag(1, phase)` is applied to it;
/// * two qubits — `diag(1, phase)` wrapped in `Gate::cu`;
/// * more — wrapped in `Gate::mcu` with `n - 1` controls.
fn apply_local_corner_phase(&mut self, local_qubits: &[usize], phase: Complex64) {
    let zero = Complex64::new(0.0, 0.0);
    let one = Complex64::new(1.0, 0.0);
    let diag = [[one, zero], [zero, phase]];
    match local_qubits {
        [] => {
            simd::scale_complex_slice(&mut self.inner.state, phase);
        }
        [only] => {
            self.inner
                .apply_1q_matrix(*only, &diag)
                .expect("local diagonal phase");
        }
        several => {
            let count = several.len();
            let gate = if count == 2 {
                Gate::cu(diag)
            } else {
                Gate::mcu(diag, (count - 1) as u8)
            };
            let instruction = Instruction::Gate {
                gate,
                targets: several.iter().copied().collect(),
            };
            self.inner
                .apply(&instruction)
                .expect("local controlled phase");
        }
    }
}
/// Applies `Rzz(theta)` to `q0`/`q1`: equal-parity states get `exp(-i*theta/2)`,
/// opposite-parity states get `exp(+i*theta/2)`.
fn apply_rzz_dist(&mut self, q0: usize, q1: usize, theta: f64) {
    let (phase_same, phase_diff) = (
        Complex64::from_polar(1.0, -theta / 2.0),
        Complex64::from_polar(1.0, theta / 2.0),
    );
    self.apply_rzz_phases_dist(q0, q1, phase_same, phase_diff);
}
/// Applies a two-qubit parity diagonal: `phase_same` where `q0` and `q1` agree,
/// `phase_diff` where they differ. No communication is needed in any case.
///
/// * both local — a one-entry `DiagonalBatch` on the inner backend;
/// * both global — the parity is fixed per rank, so the whole slice is scaled;
/// * mixed — the global bit is fixed per rank, leaving a diagonal on the local qubit.
fn apply_rzz_phases_dist(
    &mut self,
    q0: usize,
    q1: usize,
    phase_same: Complex64,
    phase_diff: Complex64,
) {
    use crate::gates::{DiagEntry, DiagonalBatchData};
    let local = self.local_qubits();
    let (q0_local, q1_local) = (q0 < local, q1 < local);

    if q0_local && q1_local {
        let batch = DiagonalBatchData {
            entries: vec![DiagEntry::Parity2q {
                q0,
                q1,
                same: phase_same,
                diff: phase_diff,
            }],
        };
        let instruction = Instruction::Gate {
            gate: Gate::DiagonalBatch(Box::new(batch)),
            targets: smallvec![q0, q1],
        };
        self.inner
            .apply(&instruction)
            .expect("local parity diagonal");
        return;
    }

    if !(q0_local || q1_local) {
        let bits_differ = self.rank_bit_set(q0) != self.rank_bit_set(q1);
        let factor = if bits_differ { phase_diff } else { phase_same };
        simd::scale_complex_slice(&mut self.inner.state, factor);
        return;
    }

    let (local_q, global_q) = if q0_local { (q0, q1) } else { (q1, q0) };
    // With the global bit set, local `0` is the differing-parity case and local `1` the
    // equal-parity case; with it clear the roles are reversed.
    let (d0, d1) = if self.rank_bit_set(global_q) {
        (phase_diff, phase_same)
    } else {
        (phase_same, phase_diff)
    };
    let zero = Complex64::new(0.0, 0.0);
    self.inner
        .apply_1q_matrix(local_q, &[[d0, zero], [zero, d1]])
        .expect("local parity residual");
}
/// Applies a SWAP between physical positions `a` and `b` by moving amplitudes. The qubit
/// maps are not touched.
///
/// * both local — a SWAP gate on the inner backend;
/// * both global — ranks whose two bits differ replace their slice with that of the rank
///   having both bits flipped; other ranks do nothing;
/// * mixed — the full slice is exchanged with the rank differing in the global bit, and
///   every amplitude whose local bit differs from this rank's global bit is replaced by the
///   partner's amplitude at the index with the local bit flipped.
fn apply_swap_dist(&mut self, a: usize, b: usize) {
    let local = self.local_qubits();
    let (a_local, b_local) = (a < local, b < local);

    if a_local && b_local {
        let swap = Instruction::Gate {
            gate: Gate::Swap,
            targets: smallvec![a, b],
        };
        self.inner.apply(&swap).expect("local swap");
        return;
    }

    if !a_local && !b_local {
        if self.rank_bit_set(a) == self.rank_bit_set(b) {
            return;
        }
        let partner = self.context.rank()
            ^ (1usize << self.global_bit(a))
            ^ (1usize << self.global_bit(b));
        let len = self.inner.state.len();
        if self.recv.len() != len {
            self.recv.resize(len, Complex64::new(0.0, 0.0));
        }
        self.count_exchange(len);
        self.context
            .comm()
            .sendrecv_c64(partner, &self.inner.state, &mut self.recv);
        self.inner.state.copy_from_slice(&self.recv);
        return;
    }

    let (local_q, global_q) = if a_local { (a, b) } else { (b, a) };
    let partner = self.context.rank() ^ (1usize << self.global_bit(global_q));
    let len = self.inner.state.len();
    if self.recv.len() != len {
        self.recv.resize(len, Complex64::new(0.0, 0.0));
    }
    self.count_exchange(len);
    self.context
        .comm()
        .sendrecv_c64(partner, &self.inner.state, &mut self.recv);
    let global_bit = self.rank_bit_set(global_q);
    let half = 1usize << local_q;
    for (index, amp) in self.inner.state.iter_mut().enumerate() {
        let local_bit = index & half != 0;
        if local_bit != global_bit {
            *amp = self.recv[index ^ half];
        }
    }
}
/// Dispatches a dense 4x4 two-qubit matrix to the kernel matching how many of its qubits
/// are global: none, one, or both.
fn apply_2q_dist(&mut self, q0: usize, q1: usize, mat: &[[Complex64; 4]; 4]) {
    let local = self.local_qubits();
    let global_count = usize::from(q0 >= local) + usize::from(q1 >= local);
    match global_count {
        0 => self.apply_local_fused_2q(q0, q1, mat),
        1 => self.apply_2q_one_global(q0, q1, mat),
        _ => self.apply_2q_two_global(q0, q1, mat),
    }
}
/// Applies a 4x4 matrix to two local qubits by wrapping it in `Gate::Fused2q` and handing
/// it to the inner backend.
fn apply_local_fused_2q(&mut self, q0: usize, q1: usize, mat: &[[Complex64; 4]; 4]) {
    let fused = Instruction::Gate {
        gate: Gate::Fused2q(Box::new(*mat)),
        targets: smallvec![q0, q1],
    };
    self.inner.apply(&fused).expect("local fused 2q");
}
/// Applies a dense 4x4 matrix to one local and one global qubit.
///
/// The full local slice is exchanged with the rank differing in the global qubit's bit.
/// The matrix basis index uses `q0` as the high bit and `q1` as the low bit, so the local
/// and global bits are combined according to which of the two is global.
///
/// Amplitudes are processed as sibling pairs (indices that differ only in the local qubit's
/// bit). Both outputs of a pair depend only on that pair's own and received amplitudes, so
/// the inputs are read before either output is written and no copy of the state is needed.
fn apply_2q_one_global(&mut self, q0: usize, q1: usize, mat: &[[Complex64; 4]; 4]) {
    let local = self.local_qubits();
    let (local_q, global_q, global_is_q0) = if q0 < local {
        (q0, q1, false)
    } else {
        (q1, q0, true)
    };
    let partner = self.context.rank() ^ (1usize << self.global_bit(global_q));
    let len = self.inner.state.len();
    if self.recv.len() != len {
        self.recv.resize(len, Complex64::new(0.0, 0.0));
    }
    self.count_exchange(len);
    self.context
        .comm()
        .sendrecv_c64(partner, &self.inner.state, &mut self.recv);
    let g = self.rank_bit_set(global_q) as usize;
    let half = 1usize << local_q;
    let basis = |gbit: usize, lbit: usize| -> usize {
        if global_is_q0 {
            (gbit << 1) | lbit
        } else {
            (lbit << 1) | gbit
        }
    };
    // Matrix rows for the two outputs of a pair (local bit 0, then 1), and the columns for
    // own-0, own-1, remote-0 and remote-1 inputs.
    let rows = [basis(g, 0), basis(g, 1)];
    let cols = [basis(g, 0), basis(g, 1), basis(1 - g, 0), basis(1 - g, 1)];
    for lo in 0..len {
        if lo & half != 0 {
            continue;
        }
        let hi = lo | half;
        let own = [self.inner.state[lo], self.inner.state[hi]];
        let remote = [self.recv[lo], self.recv[hi]];
        for (index, row) in [(lo, rows[0]), (hi, rows[1])] {
            let mut acc = mat[row][cols[0]] * own[0];
            acc += mat[row][cols[1]] * own[1];
            acc += mat[row][cols[2]] * remote[0];
            acc += mat[row][cols[3]] * remote[1];
            self.inner.state[index] = acc;
        }
    }
}
/// Applies a dense 4x4 matrix to two global qubits.
///
/// The two rank bits select one matrix row per rank (`q0` is the high bit of the basis
/// index). The rank exchanges its full slice with each of the three ranks that share all
/// other rank bits, in increasing basis order, then combines its own slice and the three
/// received ones with the coefficients of that row.
///
/// Each output amplitude depends only on the amplitudes at the same index, so the result is
/// accumulated in place once all exchanges have completed. Exchanges are recorded through
/// `count_exchange`, like every other exchange path in this backend.
fn apply_2q_two_global(&mut self, q0: usize, q1: usize, mat: &[[Complex64; 4]; 4]) {
    let b0 = self.global_bit(q0);
    let b1 = self.global_bit(q1);
    let rank = self.context.rank();
    let g0 = (rank >> b0) & 1;
    let g1 = (rank >> b1) & 1;
    let rank_basis = (g0 << 1) | g1;
    let len = self.inner.state.len();
    let mut partners: [Option<Vec<Complex64>>; 4] = [None, None, None, None];
    for (c, slot) in partners.iter_mut().enumerate() {
        if c == rank_basis {
            continue;
        }
        let c0 = (c >> 1) & 1;
        let c1 = c & 1;
        // Same rank with the two gate bits replaced by the bits of basis state `c`.
        let partner = (rank & !(1 << b0) & !(1 << b1)) | (c0 << b0) | (c1 << b1);
        let mut buf = vec![Complex64::new(0.0, 0.0); len];
        self.count_exchange(len);
        self.context
            .comm()
            .sendrecv_c64(partner, &self.inner.state, &mut buf);
        *slot = Some(buf);
    }
    let self_coeff = mat[rank_basis][rank_basis];
    for (i, amp) in self.inner.state.iter_mut().enumerate() {
        let mut acc = self_coeff * *amp;
        for (c, slot) in partners.iter().enumerate() {
            if let Some(remote) = slot {
                acc += mat[rank_basis][c] * remote[i];
            }
        }
        *amp = acc;
    }
}
/// Executes a multi-qubit (or batched) gate whose physical targets include at least one
/// global position, by routing it to the matching distributed kernel.
///
/// Controlled gates that `controlled_phase()` recognises are treated as corner phases;
/// batched gates are unrolled entry by entry.
///
/// # Errors
///
/// Returns `PrismError::BackendUnsupported` for any gate variant not listed here.
fn apply_global_multi_qubit(&mut self, gate: &Gate, targets: &[usize]) -> Result<()> {
    match gate {
        Gate::Cx => self.apply_controlled_dist(&targets[..1], targets[1], Gate::X.matrix_2x2()),
        Gate::Cz => self.apply_controlled_phase_dist(
            &[targets[0], targets[1]],
            -Complex64::new(1.0, 0.0),
        ),
        Gate::Swap => self.apply_swap_dist(targets[0], targets[1]),
        Gate::Rzz(theta) => self.apply_rzz_dist(targets[0], targets[1], *theta),
        Gate::Cu(mat) => match gate.controlled_phase() {
            Some(phase) => self.apply_controlled_phase_dist(&[targets[0], targets[1]], phase),
            None => self.apply_controlled_dist(&targets[..1], targets[1], **mat),
        },
        Gate::Mcu(data) => {
            let num_ctrl = data.num_controls as usize;
            let controls = &targets[..num_ctrl];
            let target = targets[num_ctrl];
            match gate.controlled_phase() {
                Some(phase) => {
                    let corner: Vec<usize> = controls
                        .iter()
                        .copied()
                        .chain(std::iter::once(target))
                        .collect();
                    self.apply_controlled_phase_dist(&corner, phase);
                }
                None => self.apply_controlled_dist(controls, target, data.mat),
            }
        }
        Gate::Fused2q(mat) => self.apply_2q_dist(targets[0], targets[1], mat),
        Gate::Multi2q(data) => {
            for (q0, q1, mat) in data.gates.iter() {
                self.apply_2q_dist(*q0, *q1, mat);
            }
        }
        Gate::MultiFused(data) => {
            for (q, mat) in data.gates.iter() {
                let q = *q;
                if q < self.local_qubits() {
                    self.inner.apply_1q_matrix(q, mat).expect("local 1q matrix");
                } else if is_diagonal_2x2(mat) {
                    self.apply_global_diagonal_1q(q, mat[0][0], mat[1][1]);
                } else {
                    self.apply_global_1q(q, *mat);
                }
            }
        }
        Gate::BatchPhase(data) => {
            // `targets[0]` is the shared control; the payload lists (target, phase) pairs.
            let control = targets[0];
            for entry in data.phases.iter() {
                self.apply_controlled_phase_dist(&[control, entry.0], entry.1);
            }
        }
        Gate::BatchRzz(data) => {
            for edge in data.edges.iter() {
                self.apply_rzz_dist(edge.0, edge.1, edge.2);
            }
        }
        Gate::DiagonalBatch(data) => {
            for entry in data.entries.iter() {
                self.apply_diag_entry_dist(entry);
            }
        }
        _ => return Err(self.unsupported("gate spanning a global qubit")),
    }
    Ok(())
}
/// Applies one entry of a `DiagonalBatch` using physical positions.
///
/// Single-qubit phases go to the inner backend or to `apply_global_diagonal_1q` depending
/// on where the qubit lives; two-qubit phases and parity entries go to the corresponding
/// distributed phase kernels.
fn apply_diag_entry_dist(&mut self, entry: &crate::gates::DiagEntry) {
    use crate::gates::DiagEntry;
    match entry {
        DiagEntry::Phase1q { qubit, d0, d1 } => {
            let (qubit, d0, d1) = (*qubit, *d0, *d1);
            if qubit >= self.local_qubits() {
                self.apply_global_diagonal_1q(qubit, d0, d1);
            } else {
                let zero = Complex64::new(0.0, 0.0);
                self.inner
                    .apply_1q_matrix(qubit, &[[d0, zero], [zero, d1]])
                    .expect("local diagonal 1q");
            }
        }
        DiagEntry::Phase2q { q0, q1, phase } => {
            self.apply_controlled_phase_dist(&[*q0, *q1], *phase);
        }
        DiagEntry::Parity2q {
            q0, q1, same, diff, ..
        } => {
            self.apply_rzz_phases_dist(*q0, *q1, *same, *diff);
        }
    }
}
/// Probability, over the whole distributed state, of measuring physical position `qubit` in
/// the given `outcome`.
///
/// Each rank sums the squared magnitudes of its amplitudes consistent with the outcome:
/// the matching half of every block for a local qubit, or all-or-nothing for a global qubit
/// depending on the rank bit. The contributions are summed across ranks and scaled by the
/// square of the inner backend's pending norm.
fn prob_outcome_global(&self, qubit: usize, outcome: bool) -> f64 {
    let norm_sq = self.inner.pending_norm * self.inner.pending_norm;
    let local_prob = if qubit < self.local_qubits() {
        let half = 1usize << qubit;
        self.inner
            .state
            .chunks(half << 1)
            .map(|block| {
                let (lo, hi) = block.split_at(half);
                simd::norm_sqr_sum(if outcome { hi } else { lo })
            })
            .fold(0.0f64, |acc, part| acc + part)
    } else if self.rank_bit_set(qubit) == outcome {
        simd::norm_sqr_sum(&self.inner.state)
    } else {
        0.0
    };
    self.context.comm().allreduce_sum_f64(local_prob) * norm_sq
}
/// Probability of measuring physical position `qubit` as `1`; shorthand for
/// `prob_outcome_global(qubit, true)`.
fn prob_one_global(&self, qubit: usize) -> f64 {
self.prob_outcome_global(qubit, true)
}
/// Measures circuit qubit `qubit` and stores the result in `classical_bit`.
///
/// The outcome is drawn from `meas_rng` against the global probability of `1`. The state is
/// then projected onto the outcome and the pending norm is multiplied by
/// `measurement_inv_norm(outcome, prob_one)`.
// NOTE(review): every rank draws from its own `meas_rng`; the ranks agree on the outcome
// only if those RNGs stay in lockstep (same seed, same number of draws) — confirm.
fn measure_dist(&mut self, qubit: usize, classical_bit: usize) {
    let position = self.physical_qubit(qubit);
    let prob_one = self.prob_one_global(position);
    let draw: f64 = self.meas_rng.random();
    let outcome = draw < prob_one;
    self.inner.classical_bits[classical_bit] = outcome;
    self.collapse(position, outcome);
    self.inner.pending_norm *= measurement_inv_norm(outcome, prob_one);
}
/// Physical position currently holding circuit qubit `qubit`; skips the table lookup while
/// the map is the identity.
#[inline]
fn physical_qubit(&self, qubit: usize) -> usize {
    if !self.map_identity {
        return self.qubit_map[qubit];
    }
    qubit
}
/// Draws `num_shots` basis-state indices from the distributed probability distribution.
///
/// The qubit map is first restored to the identity, so the returned indices are expressed
/// in circuit qubit order. Every rank runs the same seeded RNG stream; for each shot only
/// the rank owning the drawn probability mass writes the index, and a sum-allreduce then
/// gives every rank the full list.
///
/// # Errors
///
/// Returns `PrismError::BackendUnsupported` above 53 qubits, because indices are carried
/// through `f64` during the reduction. Also propagates any error from the inner backend's
/// `probabilities()`.
// NOTE(review): correctness relies on every rank passing the same `num_shots` and `seed`;
// nothing visible here enforces that — confirm at the call sites.
pub fn sample_state_indices(&mut self, num_shots: usize, seed: u64) -> Result<Vec<u64>> {
if self.num_qubits > 53 {
return Err(self.unsupported(
"shot sampling above 53 qubits: index transport is exact only below 2^53",
));
}
if num_shots == 0 {
return Ok(Vec::new());
}
self.restore_identity_map();
debug_assert!(self.map_identity);
// Turn this rank's probabilities into a running (cumulative) sum in place.
let mut local_cdf = self.inner.probabilities()?;
let mut acc = 0.0f64;
for p in local_cdf.iter_mut() {
acc += *p;
*p = acc;
}
// Gather every rank's total mass and build the rank-level cumulative distribution.
let masses = self.context.comm().allgather_f64(&[acc]);
let mut rank_cdf = Vec::with_capacity(masses.len());
let mut total = 0.0f64;
for &m in &masses {
total += m;
rank_cdf.push(total);
}
// Force the final entry to 1.0 so a draw in [0, 1) always lands on some rank.
if let Some(last) = rank_cdf.last_mut() {
*last = 1.0;
}
let rank = self.context.rank();
let local_qubits = self.local_qubits();
let mut rng = ChaCha8Rng::seed_from_u64(seed);
// Slots stay 0.0 on non-owning ranks, so the sum-reduction below yields the owner's value.
let mut indices = vec![0.0f64; num_shots];
for slot in indices.iter_mut() {
let r: f64 = rng.random();
let owner = rank_cdf.partition_point(|&c| c < r);
if owner != rank {
continue;
}
// Offset of the draw inside this rank's share of the distribution.
let residual = if owner == 0 {
r
} else {
r - rank_cdf[owner - 1]
};
let local_idx = crate::sim::shots::sample_from_cdf(&local_cdf, residual);
// Global index = rank bits above the local index bits.
*slot = (((rank as u64) << local_qubits) | local_idx as u64) as f64;
}
// Reduce in bounded pieces rather than in one call over the whole vector.
const REDUCE_CHUNK: usize = 1 << 20;
for chunk in indices.chunks_mut(REDUCE_CHUNK) {
self.context.comm().allreduce_sum_f64_slice(chunk);
}
Ok(indices.into_iter().map(|v| v as u64).collect())
}
/// Projects the state onto `outcome` for physical position `qubit` by zeroing every
/// amplitude inconsistent with it. The state is not renormalised here; callers adjust the
/// pending norm.
///
/// For a local qubit the non-matching half of every block is cleared. For a global qubit a
/// rank whose bit disagrees with `outcome` clears its whole slice, and other ranks do
/// nothing. Both paths clear through `simd::zero_slice`, as `reset_dist` does.
fn collapse(&mut self, qubit: usize, outcome: bool) {
    if qubit < self.local_qubits() {
        let half = 1usize << qubit;
        for block in self.inner.state.chunks_mut(half << 1) {
            let (lo, hi) = block.split_at_mut(half);
            // Clear the half that contradicts the outcome.
            simd::zero_slice(if outcome { lo } else { hi });
        }
    } else if self.rank_bit_set(qubit) != outcome {
        simd::zero_slice(&mut self.inner.state);
    }
}
/// Resets circuit qubit `qubit` to `0`.
///
/// When the global probability of `0` is positive, the state is projected onto that outcome
/// and the pending norm is multiplied by `1 / sqrt(prob_zero)`. Otherwise every rank clears
/// its slice, rank 0 sets its first amplitude to one, and the pending norm becomes 1.
// NOTE(review): the zero-probability branch replaces the whole register with the all-zero
// basis state rather than only flipping the target qubit — confirm this is intended.
fn reset_dist(&mut self, qubit: usize) {
    let position = self.physical_qubit(qubit);
    let prob_zero = self.prob_outcome_global(position, false);
    if prob_zero > 0.0 {
        self.collapse(position, false);
        self.inner.pending_norm *= 1.0 / prob_zero.sqrt();
        return;
    }
    simd::zero_slice(&mut self.inner.state);
    if self.context.rank() == 0 {
        if let Some(first) = self.inner.state.first_mut() {
            *first = Complex64::new(1.0, 0.0);
        }
    }
    self.inner.pending_norm = 1.0;
}
/// Applies one gate, given in circuit qubit indices, to the distributed state.
///
/// Steps, in order:
/// 1. With no global qubits the gate is forwarded to the inner backend unchanged.
/// 2. With relabelling enabled, the gate's qubits are stamped as in use, a SWAP is absorbed
///    into the qubit map without moving data, and the qubits reported by
///    `required_local_qubits` are brought into the local range where possible.
/// 3. The gate is translated to physical positions.
/// 4. If all positions are local the inner backend runs it; otherwise a global single-qubit
///    kernel or `apply_global_multi_qubit` handles it.
///
/// # Errors
///
/// Returns `PrismError::BackendUnsupported` for a `QftBlock` while the qubit map is
/// permuted, or for a gate `apply_global_multi_qubit` cannot handle; also propagates errors
/// from the inner backend.
fn apply_gate(&mut self, gate: &Gate, targets: &[usize]) -> Result<()> {
if self.global_qubits == 0 {
return self.inner.apply(&Instruction::Gate {
gate: gate.clone(),
targets: targets.into(),
});
}
if self.relabel {
self.touch_instruction(gate, targets);
if matches!(gate, Gate::Swap) {
// A SWAP is realised purely as a relabel.
self.swap_circuit_qubits(targets[0], targets[1]);
return Ok(());
}
let req = required_local_qubits(gate, targets);
if !req.is_empty() {
self.make_local(&req);
}
}
if matches!(gate, Gate::QftBlock { .. }) && !self.map_identity {
return Err(self.unsupported("QftBlock with a permuted qubit map"));
}
let (pgate, ptargets) = self.to_physical(gate, targets);
if self.instruction_qubits_local(&pgate, &ptargets) {
return self.inner.apply(&Instruction::Gate {
gate: pgate.into_owned(),
targets: ptargets,
});
}
let pgate = pgate.as_ref();
if pgate.num_qubits() == 1 {
// A single-qubit gate that was not local must sit on a global position.
let target = ptargets[0];
let mat = pgate.matrix_2x2();
if pgate.is_diagonal_1q() {
self.apply_global_diagonal_1q(target, mat[0][0], mat[1][1]);
} else {
self.apply_global_1q(target, mat);
}
return Ok(());
}
self.apply_global_multi_qubit(pgate, &ptargets)
}
/// Builds a `PrismError::BackendUnsupported` for this backend describing `operation`.
fn unsupported(&self, operation: &str) -> PrismError {
    let backend = String::from(BACKEND_NAME);
    let operation = String::from(operation);
    PrismError::BackendUnsupported { backend, operation }
}
}
impl Backend for DistributedStatevectorBackend {
/// Returns the backend identifier, `BACKEND_NAME`.
fn name(&self) -> &'static str {
BACKEND_NAME
}
/// Always `true`: this backend reports support for fused gates.
fn supports_fused_gates(&self) -> bool {
true
}
/// QFT blocks are supported only when running on a single rank and the inner backend
/// supports them.
fn supports_qft_block(&self) -> bool {
self.is_single_rank() && self.inner.supports_qft_block()
}
/// Prepares the backend for a circuit of `num_qubits` qubits and `num_classical_bits`
/// classical bits.
///
/// The top `log2(size)` qubits become global and the rest are handed to the inner backend.
/// The measurement RNG, exchange counters, qubit maps and usage stamps are all reset. After
/// the inner backend is initialised, every rank other than 0 zeroes its first amplitude.
///
/// # Errors
///
/// Returns `PrismError::BackendUnsupported` when the rank count is not a power of two, or
/// when more than one rank is used and fewer than
/// `crate::distributed::min_local_qubits()` qubits would remain local. Also propagates
/// errors from the inner backend's `init`.
fn init(&mut self, num_qubits: usize, num_classical_bits: usize) -> Result<()> {
    let size = self.context.size();
    let reject = |operation: String| PrismError::BackendUnsupported {
        backend: BACKEND_NAME.to_string(),
        operation,
    };
    if !size.is_power_of_two() {
        return Err(reject(format!("rank count {size} is not a power of two")));
    }
    let p = size.trailing_zeros() as usize;
    let min_local = crate::distributed::min_local_qubits();
    if size > 1 && num_qubits < p + min_local {
        return Err(reject(format!(
            "{num_qubits} qubits across {size} ranks leaves fewer than \
             {min_local} local qubits per rank"
        )));
    }

    self.num_qubits = num_qubits;
    self.global_qubits = p;
    self.meas_rng = ChaCha8Rng::seed_from_u64(self.seed);
    self.exchange_messages = 0;
    self.exchange_amplitudes = 0;
    self.qubit_map = (0..num_qubits).collect();
    self.phys_map = self.qubit_map.clone();
    self.map_identity = true;
    self.last_used = vec![0; num_qubits];
    self.tick = 0;

    self.inner.init(num_qubits - p, num_classical_bits)?;
    if self.context.rank() != 0 {
        if let Some(first) = self.inner.state.first_mut() {
            *first = Complex64::new(0.0, 0.0);
        }
    }
    Ok(())
}
/// Executes one circuit instruction.
///
/// Measurements and resets use the distributed kernels, barriers are no-ops, and a
/// conditional gate runs only when its condition holds for the classical bits recorded so
/// far.
///
/// # Errors
///
/// Propagates any error returned by `apply_gate`.
fn apply(&mut self, instruction: &Instruction) -> Result<()> {
    match instruction {
        Instruction::Measure {
            qubit,
            classical_bit,
        } => self.measure_dist(*qubit, *classical_bit),
        Instruction::Reset { qubit } => self.reset_dist(*qubit),
        Instruction::Barrier { .. } => {}
        Instruction::Conditional {
            condition,
            gate,
            targets,
        } => {
            if condition.evaluate(self.inner.classical_results()) {
                return self.apply_gate(gate, targets);
            }
        }
        Instruction::Gate { gate, targets } => return self.apply_gate(gate, targets),
    }
    Ok(())
}
/// Classical bits recorded so far, as stored by the inner backend.
fn classical_results(&self) -> &[bool] {
self.inner.classical_results()
}
/// Returns the full probability vector in circuit qubit order.
///
/// With more than one rank, the local vectors are all-gathered and the current qubit
/// permutation is undone on the result.
///
/// # Errors
///
/// Propagates errors from the inner backend and from `dense_probability_len`.
fn probabilities(&self) -> Result<Vec<f64>> {
    let local = self.inner.probabilities()?;
    if self.global_qubits != 0 {
        dense_probability_len(BACKEND_NAME, self.num_qubits)?;
        let gathered = self.context.comm().allgather_f64(&local);
        return Ok(self.unpermuted(gathered));
    }
    Ok(local)
}
/// Total number of circuit qubits (local plus global) set by the last `init`.
fn num_qubits(&self) -> usize {
self.num_qubits
}
/// Returns the full statevector in circuit qubit order.
///
/// With more than one rank, the local slices exported by the inner backend are all-gathered
/// and the current qubit permutation is undone on the result.
///
/// # Errors
///
/// Propagates errors from the inner backend and from `dense_statevector_len`.
fn export_statevector(&self) -> Result<Vec<Complex64>> {
    let local = self.inner.export_statevector()?;
    if self.global_qubits != 0 {
        dense_statevector_len(BACKEND_NAME, "statevector export", self.num_qubits)?;
        let gathered = self.context.comm().allgather_c64(&local);
        return Ok(self.unpermuted(gathered));
    }
    Ok(local)
}
/// Probability of measuring circuit qubit `qubit` as `1`. This performs an all-reduce
/// through `prob_one_global` and always returns `Ok`.
fn qubit_probability(&self, qubit: usize) -> Result<f64> {
Ok(self.prob_one_global(self.physical_qubit(qubit)))
}
/// Resets circuit qubit `qubit` via `reset_dist`; always returns `Ok`.
fn reset(&mut self, qubit: usize) -> Result<()> {
self.reset_dist(qubit);
Ok(())
}
}