use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use quantrs2_core::{gate::GateOp, qubit::QubitId};
use crate::builder::Circuit;
use crate::commutation::CommutationAnalyzer;
use crate::dag::{circuit_to_dag, CircuitDag};
/// A group of gates from a circuit that is scheduled as one unit.
///
/// `Default` yields an empty slice with id 0, which is convenient as a
/// starting point for builders (`CircuitSlice { id, ..Default::default() }`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitSlice {
    /// Identifier of this slice; referenced by other slices' `dependencies`/`dependents`.
    pub id: usize,
    /// Indices (into the circuit's gate list) of the gates in this slice.
    pub gate_indices: Vec<usize>,
    /// Ids of every qubit touched by a gate in this slice.
    pub qubits: HashSet<u32>,
    /// Ids of slices that must be executed before this one.
    pub dependencies: HashSet<usize>,
    /// Ids of slices that list this slice as a dependency.
    pub dependents: HashSet<usize>,
    /// Scheduling level assigned to this slice by the slicer.
    pub depth: usize,
}
/// How a circuit should be partitioned into [`CircuitSlice`]s.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlicingStrategy {
    /// Pack consecutive gates until a slice would touch more than this many qubits.
    MaxQubits(usize),
    /// Cut the gate list into consecutive chunks of at most this many gates.
    MaxGates(usize),
    /// Group this many consecutive DAG depth layers into each slice.
    DepthBased(usize),
    /// Greedily group gates that share qubits.
    MinCommunication,
    /// Split the gates into consecutive chunks sized for this many processors.
    LoadBalanced(usize),
    /// Merge gates into slices along shared qubits.
    ConnectivityBased,
}
/// Output of a slicing run: the slices plus a parallel execution plan.
#[derive(Debug)]
pub struct SlicingResult {
    /// The slices produced by the chosen strategy.
    pub slices: Vec<CircuitSlice>,
    /// Sum, over every dependency edge, of the qubits shared by the two slices.
    pub communication_cost: usize,
    /// Number of levels in `schedule`.
    pub parallel_depth: usize,
    /// Slice ids grouped by level; the slices of one level can run in parallel.
    pub schedule: Vec<Vec<usize>>,
}
/// Partitions circuits into [`CircuitSlice`]s according to a [`SlicingStrategy`].
pub struct CircuitSlicer {
    /// Commutation analyzer held by the slicer (not read by the slicing
    /// strategies implemented in this file).
    commutation_analyzer: CommutationAnalyzer,
}
impl CircuitSlicer {
#[must_use]
pub fn new() -> Self {
Self {
commutation_analyzer: CommutationAnalyzer::new(),
}
}
#[must_use]
pub fn slice_circuit<const N: usize>(
&self,
circuit: &Circuit<N>,
strategy: SlicingStrategy,
) -> SlicingResult {
match strategy {
SlicingStrategy::MaxQubits(max_qubits) => self.slice_by_max_qubits(circuit, max_qubits),
SlicingStrategy::MaxGates(max_gates) => self.slice_by_max_gates(circuit, max_gates),
SlicingStrategy::DepthBased(max_depth) => self.slice_by_depth(circuit, max_depth),
SlicingStrategy::MinCommunication => self.slice_min_communication(circuit),
SlicingStrategy::LoadBalanced(num_processors) => {
self.slice_load_balanced(circuit, num_processors)
}
SlicingStrategy::ConnectivityBased => self.slice_by_connectivity(circuit),
}
}
fn slice_by_max_qubits<const N: usize>(
&self,
circuit: &Circuit<N>,
max_qubits: usize,
) -> SlicingResult {
let mut slices = Vec::new();
let mut current_slice = CircuitSlice {
id: 0,
gate_indices: Vec::new(),
qubits: HashSet::new(),
dependencies: HashSet::new(),
dependents: HashSet::new(),
depth: 0,
};
let mut qubit_last_slice: HashMap<u32, usize> = HashMap::new();
for (gate_idx, gate) in circuit.gates().iter().enumerate() {
let gate_qubits: HashSet<u32> = gate
.qubits()
.iter()
.map(quantrs2_core::QubitId::id)
.collect();
let combined_qubits: HashSet<u32> =
current_slice.qubits.union(&gate_qubits).copied().collect();
if !current_slice.gate_indices.is_empty() && combined_qubits.len() > max_qubits {
let slice_id = slices.len();
current_slice.id = slice_id;
for &qubit in ¤t_slice.qubits {
qubit_last_slice.insert(qubit, slice_id);
}
slices.push(current_slice);
current_slice = CircuitSlice {
id: slice_id + 1,
gate_indices: vec![gate_idx],
qubits: gate_qubits.clone(),
dependencies: HashSet::new(),
dependents: HashSet::new(),
depth: 0,
};
for &qubit in &gate_qubits {
if let Some(&prev_slice) = qubit_last_slice.get(&qubit) {
current_slice.dependencies.insert(prev_slice);
slices[prev_slice].dependents.insert(slice_id + 1);
}
}
} else {
current_slice.gate_indices.push(gate_idx);
current_slice.qubits.extend(gate_qubits);
}
}
if !current_slice.gate_indices.is_empty() {
let slice_id = slices.len();
current_slice.id = slice_id;
slices.push(current_slice);
}
self.calculate_depths_and_schedule(slices)
}
fn slice_by_max_gates<const N: usize>(
&self,
circuit: &Circuit<N>,
max_gates: usize,
) -> SlicingResult {
let mut slices = Vec::new();
let gates = circuit.gates();
for (chunk_idx, chunk) in gates.chunks(max_gates).enumerate() {
let mut slice = CircuitSlice {
id: chunk_idx,
gate_indices: Vec::new(),
qubits: HashSet::new(),
dependencies: HashSet::new(),
dependents: HashSet::new(),
depth: 0,
};
let base_idx = chunk_idx * max_gates;
for (local_idx, gate) in chunk.iter().enumerate() {
slice.gate_indices.push(base_idx + local_idx);
slice
.qubits
.extend(gate.qubits().iter().map(quantrs2_core::QubitId::id));
}
slices.push(slice);
}
self.add_qubit_dependencies(&mut slices, gates);
self.calculate_depths_and_schedule(slices)
}
fn slice_by_depth<const N: usize>(
&self,
circuit: &Circuit<N>,
max_depth: usize,
) -> SlicingResult {
let dag = circuit_to_dag(circuit);
let mut slices = Vec::new();
let max_circuit_depth = dag.max_depth();
for depth_start in (0..=max_circuit_depth).step_by(max_depth) {
let depth_end = (depth_start + max_depth).min(max_circuit_depth + 1);
let mut slice = CircuitSlice {
id: slices.len(),
gate_indices: Vec::new(),
qubits: HashSet::new(),
dependencies: HashSet::new(),
dependents: HashSet::new(),
depth: depth_start / max_depth,
};
for depth in depth_start..depth_end {
for &node_id in &dag.nodes_at_depth(depth) {
slice.gate_indices.push(node_id);
let node = &dag.nodes()[node_id];
slice
.qubits
.extend(node.gate.qubits().iter().map(quantrs2_core::QubitId::id));
}
}
if !slice.gate_indices.is_empty() {
slices.push(slice);
}
}
for i in 1..slices.len() {
slices[i].dependencies.insert(i - 1);
slices[i - 1].dependents.insert(i);
}
self.calculate_depths_and_schedule(slices)
}
fn slice_min_communication<const N: usize>(&self, circuit: &Circuit<N>) -> SlicingResult {
let gates = circuit.gates();
let n_gates = gates.len();
let mut adjacency = vec![vec![0.0; n_gates]; n_gates];
for i in 0..n_gates {
for j in i + 1..n_gates {
let qubits_i: HashSet<u32> = gates[i]
.qubits()
.iter()
.map(quantrs2_core::QubitId::id)
.collect();
let qubits_j: HashSet<u32> = gates[j]
.qubits()
.iter()
.map(quantrs2_core::QubitId::id)
.collect();
let shared_qubits = qubits_i.intersection(&qubits_j).count();
if shared_qubits > 0 {
adjacency[i][j] = shared_qubits as f64;
adjacency[j][i] = shared_qubits as f64;
}
}
}
let num_slices = (n_gates as f64).sqrt().ceil() as usize;
let mut slices = Vec::new();
let mut assigned = vec![false; n_gates];
for slice_id in 0..num_slices {
let mut slice = CircuitSlice {
id: slice_id,
gate_indices: Vec::new(),
qubits: HashSet::new(),
dependencies: HashSet::new(),
dependents: HashSet::new(),
depth: 0,
};
for gate_idx in 0..n_gates {
if !assigned[gate_idx] {
let affinity = slice
.gate_indices
.iter()
.map(|&idx| adjacency[gate_idx][idx])
.sum::<f64>();
if slice.gate_indices.is_empty() || affinity > 0.0 {
slice.gate_indices.push(gate_idx);
slice.qubits.extend(
gates[gate_idx]
.qubits()
.iter()
.map(quantrs2_core::QubitId::id),
);
assigned[gate_idx] = true;
if slice.gate_indices.len() >= n_gates / num_slices {
break;
}
}
}
}
if !slice.gate_indices.is_empty() {
slices.push(slice);
}
}
for gate_idx in 0..n_gates {
if !assigned[gate_idx] {
let mut best_slice = 0;
let mut best_affinity = 0.0;
for (slice_idx, slice) in slices.iter().enumerate() {
let affinity = slice
.gate_indices
.iter()
.map(|&idx| adjacency[gate_idx][idx])
.sum::<f64>();
if affinity > best_affinity {
best_affinity = affinity;
best_slice = slice_idx;
}
}
slices[best_slice].gate_indices.push(gate_idx);
slices[best_slice].qubits.extend(
gates[gate_idx]
.qubits()
.iter()
.map(quantrs2_core::QubitId::id),
);
}
}
self.add_qubit_dependencies(&mut slices, gates);
self.calculate_depths_and_schedule(slices)
}
fn slice_load_balanced<const N: usize>(
&self,
circuit: &Circuit<N>,
num_processors: usize,
) -> SlicingResult {
let gates = circuit.gates();
let gates_per_processor = gates.len().div_ceil(num_processors);
self.slice_by_max_gates(circuit, gates_per_processor)
}
fn slice_by_connectivity<const N: usize>(&self, circuit: &Circuit<N>) -> SlicingResult {
let gates = circuit.gates();
let mut slices: Vec<CircuitSlice> = Vec::new();
let mut gate_to_slice: HashMap<usize, usize> = HashMap::new();
for (gate_idx, gate) in gates.iter().enumerate() {
let gate_qubits: HashSet<u32> = gate
.qubits()
.iter()
.map(quantrs2_core::QubitId::id)
.collect();
let mut connected_slices: Vec<usize> = Vec::new();
for (slice_idx, slice) in slices.iter().enumerate() {
if !slice.qubits.is_disjoint(&gate_qubits) {
connected_slices.push(slice_idx);
}
}
if connected_slices.is_empty() {
let slice_id = slices.len();
let slice = CircuitSlice {
id: slice_id,
gate_indices: vec![gate_idx],
qubits: gate_qubits,
dependencies: HashSet::new(),
dependents: HashSet::new(),
depth: 0,
};
slices.push(slice);
gate_to_slice.insert(gate_idx, slice_id);
} else if connected_slices.len() == 1 {
let slice_idx = connected_slices[0];
slices[slice_idx].gate_indices.push(gate_idx);
slices[slice_idx].qubits.extend(gate_qubits);
gate_to_slice.insert(gate_idx, slice_idx);
} else {
let main_slice = connected_slices[0];
slices[main_slice].gate_indices.push(gate_idx);
slices[main_slice].qubits.extend(gate_qubits);
gate_to_slice.insert(gate_idx, main_slice);
for &slice_idx in connected_slices[1..].iter().rev() {
let slice = slices.remove(slice_idx);
let gate_indices = slice.gate_indices.clone();
slices[main_slice].gate_indices.extend(slice.gate_indices);
slices[main_slice].qubits.extend(slice.qubits);
for &g_idx in &gate_indices {
gate_to_slice.insert(g_idx, main_slice);
}
}
}
}
for (new_id, slice) in slices.iter_mut().enumerate() {
slice.id = new_id;
}
self.add_order_dependencies(&mut slices, gates, &gate_to_slice);
self.calculate_depths_and_schedule(slices)
}
fn add_qubit_dependencies(
&self,
slices: &mut [CircuitSlice],
gates: &[Arc<dyn GateOp + Send + Sync>],
) {
let mut qubit_last_slice: HashMap<u32, usize> = HashMap::new();
for slice in slices.iter_mut() {
for &gate_idx in &slice.gate_indices {
let gate_qubits = gates[gate_idx].qubits();
for qubit in gate_qubits {
if let Some(&prev_slice) = qubit_last_slice.get(&qubit.id()) {
if prev_slice != slice.id {
slice.dependencies.insert(prev_slice);
}
}
}
}
for &qubit in &slice.qubits {
qubit_last_slice.insert(qubit, slice.id);
}
}
for i in 0..slices.len() {
let deps: Vec<usize> = slices[i].dependencies.iter().copied().collect();
for dep in deps {
slices[dep].dependents.insert(i);
}
}
}
fn add_order_dependencies(
&self,
slices: &mut [CircuitSlice],
gates: &[Arc<dyn GateOp + Send + Sync>],
gate_to_slice: &HashMap<usize, usize>,
) {
for (gate_idx, gate) in gates.iter().enumerate() {
let slice_idx = gate_to_slice[&gate_idx];
let gate_qubits: HashSet<u32> = gate
.qubits()
.iter()
.map(quantrs2_core::QubitId::id)
.collect();
for prev_idx in 0..gate_idx {
let prev_slice = gate_to_slice[&prev_idx];
if prev_slice != slice_idx {
let prev_qubits: HashSet<u32> = gates[prev_idx]
.qubits()
.iter()
.map(quantrs2_core::QubitId::id)
.collect();
if !gate_qubits.is_disjoint(&prev_qubits) {
slices[slice_idx].dependencies.insert(prev_slice);
slices[prev_slice].dependents.insert(slice_idx);
}
}
}
}
}
fn calculate_depths_and_schedule(&self, mut slices: Vec<CircuitSlice>) -> SlicingResult {
let mut in_degree: HashMap<usize, usize> = HashMap::new();
for slice in &slices {
in_degree.insert(slice.id, slice.dependencies.len());
}
let mut queue = VecDeque::new();
let mut schedule = Vec::new();
let mut depths = HashMap::new();
for slice in &slices {
if slice.dependencies.is_empty() {
queue.push_back(slice.id);
depths.insert(slice.id, 0);
}
}
while !queue.is_empty() {
let mut current_level = Vec::new();
let level_size = queue.len();
for _ in 0..level_size {
let slice_id = queue
.pop_front()
.expect("queue is not empty (checked in while condition)");
current_level.push(slice_id);
if let Some(slice) = slices.iter().find(|s| s.id == slice_id) {
for &dep_id in &slice.dependents {
if let Some(degree) = in_degree.get_mut(&dep_id) {
*degree -= 1;
if *degree == 0 {
queue.push_back(dep_id);
if let Some(¤t_depth) = depths.get(&slice_id) {
depths.insert(dep_id, current_depth + 1);
}
}
}
}
}
}
schedule.push(current_level);
}
for slice in &mut slices {
slice.depth = depths.get(&slice.id).copied().unwrap_or(0);
}
let communication_cost = self.calculate_communication_cost(&slices);
SlicingResult {
slices,
communication_cost,
parallel_depth: schedule.len(),
schedule,
}
}
fn calculate_communication_cost(&self, slices: &[CircuitSlice]) -> usize {
let mut total_cost = 0;
for slice in slices {
for &dep_id in &slice.dependencies {
if let Some(dep_slice) = slices.iter().find(|s| s.id == dep_id) {
let shared: HashSet<_> = slice.qubits.intersection(&dep_slice.qubits).collect();
total_cost += shared.len();
}
}
}
total_cost
}
}
impl Default for CircuitSlicer {
fn default() -> Self {
Self::new()
}
}
impl<const N: usize> Circuit<N> {
    /// Slices this circuit with a freshly constructed [`CircuitSlicer`].
    ///
    /// Shorthand for `CircuitSlicer::new().slice_circuit(self, strategy)`.
    #[must_use]
    pub fn slice(&self, strategy: SlicingStrategy) -> SlicingResult {
        CircuitSlicer::new().slice_circuit(self, strategy)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};

    /// A two-qubit budget must split a four-qubit circuit and keep every
    /// slice within that budget.
    #[test]
    fn test_slice_by_max_qubits() {
        let mut circ = Circuit::<4>::new();
        circ.add_gate(Hadamard { target: QubitId(0) })
            .expect("failed to add H gate to qubit 0");
        circ.add_gate(Hadamard { target: QubitId(1) })
            .expect("failed to add H gate to qubit 1");
        circ.add_gate(Hadamard { target: QubitId(2) })
            .expect("failed to add H gate to qubit 2");
        circ.add_gate(Hadamard { target: QubitId(3) })
            .expect("failed to add H gate to qubit 3");
        circ.add_gate(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        })
        .expect("failed to add CNOT gate on qubits 0,1");
        circ.add_gate(CNOT {
            control: QubitId(2),
            target: QubitId(3),
        })
        .expect("failed to add CNOT gate on qubits 2,3");

        let outcome = CircuitSlicer::new().slice_circuit(&circ, SlicingStrategy::MaxQubits(2));
        assert!(outcome.slices.len() >= 2);
        assert!(outcome.slices.iter().all(|s| s.qubits.len() <= 2));
    }

    /// Six gates with a two-gate cap yield exactly three slices.
    #[test]
    fn test_slice_by_max_gates() {
        let mut circ = Circuit::<3>::new();
        for i in 0_u32..6 {
            circ.add_gate(Hadamard {
                target: QubitId(i % 3),
            })
            .expect("failed to add Hadamard gate in loop");
        }

        let outcome = CircuitSlicer::new().slice_circuit(&circ, SlicingStrategy::MaxGates(2));
        assert_eq!(outcome.slices.len(), 3);
        assert!(outcome.slices.iter().all(|s| s.gate_indices.len() <= 2));
    }

    /// Gates acting on qubits used by an earlier slice must create at least
    /// one dependency edge.
    #[test]
    fn test_slice_dependencies() {
        let mut circ = Circuit::<2>::new();
        circ.add_gate(Hadamard { target: QubitId(0) })
            .expect("failed to add H gate to qubit 0");
        circ.add_gate(Hadamard { target: QubitId(1) })
            .expect("failed to add H gate to qubit 1");
        circ.add_gate(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        })
        .expect("failed to add CNOT gate on qubits 0,1");
        circ.add_gate(PauliX { target: QubitId(0) })
            .expect("failed to add X gate to qubit 0");

        let outcome = CircuitSlicer::new().slice_circuit(&circ, SlicingStrategy::MaxGates(2));
        let has_dependencies = outcome
            .slices
            .iter()
            .any(|s| !s.dependencies.is_empty());
        assert!(has_dependencies);
    }

    /// Independent single-qubit gates all land in the first schedule level.
    #[test]
    fn test_parallel_schedule() {
        let mut circ = Circuit::<4>::new();
        circ.add_gate(Hadamard { target: QubitId(0) })
            .expect("failed to add H gate to qubit 0");
        circ.add_gate(Hadamard { target: QubitId(1) })
            .expect("failed to add H gate to qubit 1");
        circ.add_gate(Hadamard { target: QubitId(2) })
            .expect("failed to add H gate to qubit 2");
        circ.add_gate(Hadamard { target: QubitId(3) })
            .expect("failed to add H gate to qubit 3");

        let outcome = CircuitSlicer::new().slice_circuit(&circ, SlicingStrategy::MaxQubits(1));
        assert_eq!(outcome.parallel_depth, 1);
        assert_eq!(outcome.schedule[0].len(), 4);
    }
}