#[cfg(test)]
mod tests;
use std::sync::Arc;
use num_complex::Complex64;
use crate::backend::simd;
use crate::backend::statevector::StatevectorBackend;
use crate::backend::{dense_probability_len, dense_statevector_len, Backend};
use crate::circuit::Instruction;
use crate::distributed::DistributedContext;
use crate::error::{PrismError, Result};
/// Backend identifier returned by `name()` and embedded in this backend's error values.
const BACKEND_NAME: &str = "distributed_statevector";
/// Statevector backend that shards the amplitudes across the ranks of a
/// [`DistributedContext`].
///
/// With `2^p` ranks, the highest `p` qubits are "global": they are selected by
/// the rank index, and each rank keeps the state of the remaining (local)
/// qubits in `inner`. When there are no global qubits (a single rank), every
/// operation is delegated to `inner`.
pub struct DistributedStatevectorBackend {
/// Shared handle to the communicator; provides this rank's index, the rank
/// count, and the exchange/gather operations.
context: Arc<DistributedContext>,
/// Single-node backend holding this rank's local slice of the state.
inner: StatevectorBackend,
/// Total number of qubits across all ranks (set by `init`).
num_qubits: usize,
/// Number of high-order qubits addressed by rank index (log2 of the rank count).
global_qubits: usize,
/// Reusable receive buffer for the partner rank's slice during non-diagonal
/// single-qubit gates on a global qubit.
recv: Vec<Complex64>,
}
impl DistributedStatevectorBackend {
    /// Creates a backend bound to `context`, seeding the wrapped single-node
    /// statevector backend with `seed`.
    ///
    /// No qubits are laid out until `init` is called.
    pub fn new(context: Arc<DistributedContext>, seed: u64) -> Self {
        let inner = StatevectorBackend::new(seed);
        Self {
            context,
            inner,
            num_qubits: 0,
            global_qubits: 0,
            recv: Vec::new(),
        }
    }

    /// Number of low-order qubits stored in this rank's local slice.
    #[inline]
    fn local_qubits(&self) -> usize {
        self.num_qubits - self.global_qubits
    }

    /// Returns `true` when the communicator holds exactly one rank.
    #[inline]
    fn is_single_rank(&self) -> bool {
        self.context.size() == 1
    }

    /// Returns `true` when none of `targets` is a global qubit.
    #[inline]
    fn targets_are_local(&self, targets: &[usize]) -> bool {
        let first_global = self.local_qubits();
        !targets.iter().any(|&qubit| qubit >= first_global)
    }

    /// Bit position inside the rank index that encodes global qubit `q`.
    #[inline]
    fn global_bit(&self, q: usize) -> usize {
        q - self.local_qubits()
    }

    /// Returns `true` if this rank's index has the bit for global qubit `q` set.
    #[inline]
    fn rank_bit_set(&self, q: usize) -> bool {
        let mask = 1usize << self.global_bit(q);
        (self.context.rank() & mask) != 0
    }

    /// Applies the 2x2 matrix `mat` to the global qubit `target`.
    ///
    /// The partner rank is the one whose index differs only in `target`'s bit.
    /// This rank's whole local slice is exchanged with the partner's into
    /// `recv`. `simd::combine_global_half` is then handed the entries of
    /// `mat` from the row this rank owns: the diagonal entry for its own
    /// amplitudes and the off-diagonal entry for the partner's.
    fn apply_global_1q(&mut self, target: usize, mat: [[Complex64; 2]; 2]) {
        let partner = self.context.rank() ^ (1usize << self.global_bit(target));
        let slice_len = self.inner.state.len();
        // `resize` is a no-op when the buffer already has the right length.
        self.recv.resize(slice_len, Complex64::new(0.0, 0.0));
        self.context
            .comm()
            .sendrecv_c64(partner, &self.inner.state, &mut self.recv);
        // Row 0 when the target bit is clear on this rank, row 1 when it is set.
        let row = usize::from(self.rank_bit_set(target));
        let own_coeff = mat[row][row];
        let partner_coeff = mat[row][1 - row];
        simd::combine_global_half(&mut self.inner.state, &self.recv, own_coeff, partner_coeff);
    }

    /// Applies `diag(d0, d1)` to the global qubit `target`.
    ///
    /// Every amplitude on this rank shares the same value of `target`, so the
    /// whole slice is scaled by a single factor and no communication is needed.
    fn apply_global_diagonal_1q(&mut self, target: usize, d0: Complex64, d1: Complex64) {
        let factor = [d0, d1][usize::from(self.rank_bit_set(target))];
        simd::scale_complex_slice(&mut self.inner.state, factor);
    }

    /// Builds the `BackendUnsupported` error for `operation` on this backend.
    fn unsupported(&self, operation: &str) -> PrismError {
        let backend = BACKEND_NAME.to_owned();
        PrismError::BackendUnsupported {
            backend,
            operation: operation.to_owned(),
        }
    }
}
impl Backend for DistributedStatevectorBackend {
    fn name(&self) -> &'static str {
        BACKEND_NAME
    }

    /// Fused gates are only offered on a single rank. There, `apply` forwards
    /// every instruction to the inner backend, so the answer must also
    /// reflect the inner backend's own capability.
    fn supports_fused_gates(&self) -> bool {
        self.is_single_rank() && self.inner.supports_fused_gates()
    }

    /// QFT blocks are only offered on a single rank and only if the inner
    /// backend supports them.
    fn supports_qft_block(&self) -> bool {
        self.is_single_rank() && self.inner.supports_qft_block()
    }

    /// Lays out `num_qubits` qubits across the ranks.
    ///
    /// With `2^p` ranks, the top `p` qubits become global (selected by rank
    /// index) and the inner backend is initialised with the remaining
    /// `num_qubits - p` local qubits. Every rank except rank 0 then zeroes
    /// its first local amplitude. Assuming the inner backend starts in
    /// |0…0⟩, only rank 0 holds the initial amplitude.
    ///
    /// # Errors
    /// - `BackendUnsupported` if the rank count is not a power of two.
    /// - `BackendUnsupported` if, on more than one rank, fewer than
    ///   `min_local_qubits()` qubits would remain local.
    /// - Any error returned by the inner backend's `init`.
    ///
    /// On error the previously committed qubit layout is left unchanged.
    fn init(&mut self, num_qubits: usize, num_classical_bits: usize) -> Result<()> {
        let size = self.context.size();
        if !size.is_power_of_two() {
            return Err(self.unsupported(&format!(
                "rank count {size} is not a power of two"
            )));
        }
        let p = size.trailing_zeros() as usize;
        let min_local = crate::distributed::min_local_qubits();
        if size > 1 && num_qubits < p + min_local {
            return Err(self.unsupported(&format!(
                "{num_qubits} qubits across {size} ranks leaves fewer than \
                 {min_local} local qubits per rank"
            )));
        }
        // Cannot underflow: either p == 0 (single rank) or num_qubits >= p + min_local.
        let local_qubits = num_qubits - p;
        self.inner.init(local_qubits, num_classical_bits)?;
        // Commit the layout only after the inner backend accepted it. Otherwise a
        // failed init would leave `num_qubits()` out of step with the inner state.
        self.num_qubits = num_qubits;
        self.global_qubits = p;
        if self.context.rank() != 0 {
            if let Some(amp) = self.inner.state.get_mut(0) {
                *amp = Complex64::new(0.0, 0.0);
            }
        }
        Ok(())
    }

    /// Applies one instruction to the distributed state.
    ///
    /// - With no global qubits, everything is delegated to the inner backend.
    /// - Gates whose targets are all local are also delegated.
    /// - Single-qubit gates on a global qubit are handled here.
    /// - Measurement, reset, classical conditionals and multi-qubit gates
    ///   touching a global qubit return `BackendUnsupported`.
    fn apply(&mut self, instruction: &Instruction) -> Result<()> {
        if self.global_qubits == 0 {
            return self.inner.apply(instruction);
        }
        match instruction {
            Instruction::Gate { gate, targets } => {
                if self.targets_are_local(targets) {
                    return self.inner.apply(instruction);
                }
                if gate.num_qubits() == 1 {
                    let target = targets[0];
                    let mat = gate.matrix_2x2();
                    if gate.is_diagonal_1q() {
                        // Diagonal gates only rescale this rank's slice; no exchange needed.
                        self.apply_global_diagonal_1q(target, mat[0][0], mat[1][1]);
                    } else {
                        self.apply_global_1q(target, mat);
                    }
                    Ok(())
                } else {
                    Err(self.unsupported("multi-qubit gate spanning a global qubit"))
                }
            }
            // Barriers do not modify the state.
            Instruction::Barrier { .. } => Ok(()),
            Instruction::Measure { .. } => Err(self.unsupported("distributed measurement")),
            Instruction::Reset { .. } => Err(self.unsupported("distributed reset")),
            Instruction::Conditional { .. } => {
                Err(self.unsupported("distributed classical conditional"))
            }
        }
    }

    /// Classical bits recorded by the inner backend.
    fn classical_results(&self) -> &[bool] {
        self.inner.classical_results()
    }

    /// Returns the basis-state probabilities.
    ///
    /// On a single rank this is the inner backend's result. Otherwise the
    /// gathered length is first validated with `dense_probability_len`. The
    /// local probabilities are then combined with `allgather_f64`, which is
    /// presumably a collective, so every rank must reach this call.
    fn probabilities(&self) -> Result<Vec<f64>> {
        if self.global_qubits == 0 {
            return self.inner.probabilities();
        }
        // Reject a result too large to materialise before doing any local work.
        dense_probability_len(BACKEND_NAME, self.num_qubits)?;
        let local = self.inner.probabilities()?;
        Ok(self.context.comm().allgather_f64(&local))
    }

    fn num_qubits(&self) -> usize {
        self.num_qubits
    }

    /// Exports the full statevector.
    ///
    /// On a single rank this is the inner backend's export. Otherwise the
    /// gathered length is first validated with `dense_statevector_len`. The
    /// local slices are then combined with `allgather_c64`, which is
    /// presumably a collective, so every rank must reach this call.
    fn export_statevector(&self) -> Result<Vec<Complex64>> {
        if self.global_qubits == 0 {
            return self.inner.export_statevector();
        }
        // Reject a result too large to materialise before copying the local slice.
        dense_statevector_len(BACKEND_NAME, "statevector export", self.num_qubits)?;
        let local = self.inner.export_statevector()?;
        Ok(self.context.comm().allgather_c64(&local))
    }
}