quantrs2_circuit/
slicing.rs

1//! Circuit slicing for parallel execution.
2//!
3//! This module provides functionality to slice quantum circuits into
4//! smaller subcircuits that can be executed in parallel or distributed
5//! across multiple quantum processors.
6
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::Arc;
9
10use quantrs2_core::{gate::GateOp, qubit::QubitId};
11
12use crate::builder::Circuit;
13use crate::commutation::CommutationAnalyzer;
14use crate::dag::{circuit_to_dag, CircuitDag};
15
/// A slice of a circuit that can be executed independently.
///
/// Slices produced by [`CircuitSlicer`] form a dependency graph through the
/// `dependencies` / `dependents` sets, both of which hold slice ids.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitSlice {
    /// Unique identifier for this slice.
    pub id: usize,
    /// Gates in this slice (indices into the original circuit).
    pub gate_indices: Vec<usize>,
    /// Ids of the qubits used by the gates in this slice.
    pub qubits: HashSet<u32>,
    /// Ids of slices that must be executed before this one.
    pub dependencies: HashSet<usize>,
    /// Ids of slices that depend on this one.
    pub dependents: HashSet<usize>,
    /// Depth of this slice in the dependency graph (0 for slices without dependencies).
    pub depth: usize,
}
32
/// Strategy for slicing circuits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SlicingStrategy {
    /// Slice by maximum number of qubits per slice.
    MaxQubits(usize),
    /// Slice by maximum number of gates per slice.
    MaxGates(usize),
    /// Slice by circuit depth (number of DAG depth levels per slice).
    DepthBased(usize),
    /// Slice to minimize communication between slices.
    MinCommunication,
    /// Slice for load balancing across the given number of processors.
    LoadBalanced(usize),
    /// Custom slicing based on qubit connectivity.
    ConnectivityBased,
}
49
50/// Result of circuit slicing
51#[derive(Debug)]
52pub struct SlicingResult {
53    /// The slices
54    pub slices: Vec<CircuitSlice>,
55    /// Communication cost between slices (number of qubits)
56    pub communication_cost: usize,
57    /// Maximum parallel depth
58    pub parallel_depth: usize,
59    /// Slice scheduling order
60    pub schedule: Vec<Vec<usize>>, // Groups of slices that can run in parallel
61}
62
63/// Circuit slicer
64pub struct CircuitSlicer {
65    /// Commutation analyzer for optimization
66    commutation_analyzer: CommutationAnalyzer,
67}
68
69impl CircuitSlicer {
70    /// Create a new circuit slicer
71    pub fn new() -> Self {
72        Self {
73            commutation_analyzer: CommutationAnalyzer::new(),
74        }
75    }
76
77    /// Slice a circuit according to the given strategy
78    pub fn slice_circuit<const N: usize>(
79        &self,
80        circuit: &Circuit<N>,
81        strategy: SlicingStrategy,
82    ) -> SlicingResult {
83        match strategy {
84            SlicingStrategy::MaxQubits(max_qubits) => self.slice_by_max_qubits(circuit, max_qubits),
85            SlicingStrategy::MaxGates(max_gates) => self.slice_by_max_gates(circuit, max_gates),
86            SlicingStrategy::DepthBased(max_depth) => self.slice_by_depth(circuit, max_depth),
87            SlicingStrategy::MinCommunication => self.slice_min_communication(circuit),
88            SlicingStrategy::LoadBalanced(num_processors) => {
89                self.slice_load_balanced(circuit, num_processors)
90            }
91            SlicingStrategy::ConnectivityBased => self.slice_by_connectivity(circuit),
92        }
93    }
94
95    /// Slice circuit limiting qubits per slice
96    fn slice_by_max_qubits<const N: usize>(
97        &self,
98        circuit: &Circuit<N>,
99        max_qubits: usize,
100    ) -> SlicingResult {
101        let mut slices = Vec::new();
102        let mut current_slice = CircuitSlice {
103            id: 0,
104            gate_indices: Vec::new(),
105            qubits: HashSet::new(),
106            dependencies: HashSet::new(),
107            dependents: HashSet::new(),
108            depth: 0,
109        };
110
111        // Track which slice each qubit was last used in
112        let mut qubit_last_slice: HashMap<u32, usize> = HashMap::new();
113
114        for (gate_idx, gate) in circuit.gates().iter().enumerate() {
115            let gate_qubits: HashSet<u32> = gate.qubits().iter().map(|q| q.id()).collect();
116
117            // Check if adding this gate would exceed qubit limit
118            let combined_qubits: HashSet<u32> =
119                current_slice.qubits.union(&gate_qubits).cloned().collect();
120
121            if !current_slice.gate_indices.is_empty() && combined_qubits.len() > max_qubits {
122                // Need to start a new slice
123                let slice_id = slices.len();
124                current_slice.id = slice_id;
125
126                // Update dependencies based on qubit usage
127                for &qubit in &current_slice.qubits {
128                    qubit_last_slice.insert(qubit, slice_id);
129                }
130
131                slices.push(current_slice);
132
133                // Start new slice
134                current_slice = CircuitSlice {
135                    id: slice_id + 1,
136                    gate_indices: vec![gate_idx],
137                    qubits: gate_qubits.clone(),
138                    dependencies: HashSet::new(),
139                    dependents: HashSet::new(),
140                    depth: 0,
141                };
142
143                // Add dependencies from previous slices
144                for &qubit in &gate_qubits {
145                    if let Some(&prev_slice) = qubit_last_slice.get(&qubit) {
146                        current_slice.dependencies.insert(prev_slice);
147                        slices[prev_slice].dependents.insert(slice_id + 1);
148                    }
149                }
150            } else {
151                // Add gate to current slice
152                current_slice.gate_indices.push(gate_idx);
153                current_slice.qubits.extend(gate_qubits);
154            }
155        }
156
157        // Don't forget the last slice
158        if !current_slice.gate_indices.is_empty() {
159            let slice_id = slices.len();
160            current_slice.id = slice_id;
161            slices.push(current_slice);
162        }
163
164        // Calculate depths and schedule
165        self.calculate_depths_and_schedule(slices)
166    }
167
168    /// Slice circuit limiting gates per slice
169    fn slice_by_max_gates<const N: usize>(
170        &self,
171        circuit: &Circuit<N>,
172        max_gates: usize,
173    ) -> SlicingResult {
174        let mut slices = Vec::new();
175        let gates = circuit.gates();
176
177        // Simple slicing by gate count
178        for (chunk_idx, chunk) in gates.chunks(max_gates).enumerate() {
179            let mut slice = CircuitSlice {
180                id: chunk_idx,
181                gate_indices: Vec::new(),
182                qubits: HashSet::new(),
183                dependencies: HashSet::new(),
184                dependents: HashSet::new(),
185                depth: 0,
186            };
187
188            let base_idx = chunk_idx * max_gates;
189            for (local_idx, gate) in chunk.iter().enumerate() {
190                slice.gate_indices.push(base_idx + local_idx);
191                slice.qubits.extend(gate.qubits().iter().map(|q| q.id()));
192            }
193
194            slices.push(slice);
195        }
196
197        // Add dependencies based on qubit usage
198        self.add_qubit_dependencies(&mut slices, gates);
199
200        // Calculate depths and schedule
201        self.calculate_depths_and_schedule(slices)
202    }
203
204    /// Slice circuit by depth levels
205    fn slice_by_depth<const N: usize>(
206        &self,
207        circuit: &Circuit<N>,
208        max_depth: usize,
209    ) -> SlicingResult {
210        let dag = circuit_to_dag(circuit);
211        let mut slices = Vec::new();
212
213        // Group gates by depth levels
214        let max_circuit_depth = dag.max_depth();
215        for depth_start in (0..=max_circuit_depth).step_by(max_depth) {
216            let depth_end = (depth_start + max_depth).min(max_circuit_depth + 1);
217
218            let mut slice = CircuitSlice {
219                id: slices.len(),
220                gate_indices: Vec::new(),
221                qubits: HashSet::new(),
222                dependencies: HashSet::new(),
223                dependents: HashSet::new(),
224                depth: depth_start / max_depth,
225            };
226
227            // Collect all gates in this depth range
228            for depth in depth_start..depth_end {
229                for &node_id in &dag.nodes_at_depth(depth) {
230                    slice.gate_indices.push(node_id);
231                    let node = &dag.nodes()[node_id];
232                    slice
233                        .qubits
234                        .extend(node.gate.qubits().iter().map(|q| q.id()));
235                }
236            }
237
238            if !slice.gate_indices.is_empty() {
239                slices.push(slice);
240            }
241        }
242
243        // Dependencies are implicit in depth-based slicing
244        for i in 1..slices.len() {
245            slices[i].dependencies.insert(i - 1);
246            slices[i - 1].dependents.insert(i);
247        }
248
249        self.calculate_depths_and_schedule(slices)
250    }
251
252    /// Slice to minimize communication between slices
253    fn slice_min_communication<const N: usize>(&self, circuit: &Circuit<N>) -> SlicingResult {
254        // Use spectral clustering approach
255        let gates = circuit.gates();
256        let n_gates = gates.len();
257
258        // Build adjacency matrix based on qubit sharing
259        let mut adjacency = vec![vec![0.0; n_gates]; n_gates];
260
261        for i in 0..n_gates {
262            for j in i + 1..n_gates {
263                let qubits_i: HashSet<u32> = gates[i].qubits().iter().map(|q| q.id()).collect();
264                let qubits_j: HashSet<u32> = gates[j].qubits().iter().map(|q| q.id()).collect();
265
266                let shared_qubits = qubits_i.intersection(&qubits_j).count();
267                if shared_qubits > 0 {
268                    adjacency[i][j] = shared_qubits as f64;
269                    adjacency[j][i] = shared_qubits as f64;
270                }
271            }
272        }
273
274        // Simple clustering: greedy approach
275        let num_slices = (n_gates as f64).sqrt().ceil() as usize;
276        let mut slices = Vec::new();
277        let mut assigned = vec![false; n_gates];
278
279        // Create initial clusters
280        for slice_id in 0..num_slices {
281            let mut slice = CircuitSlice {
282                id: slice_id,
283                gate_indices: Vec::new(),
284                qubits: HashSet::new(),
285                dependencies: HashSet::new(),
286                dependents: HashSet::new(),
287                depth: 0,
288            };
289
290            // Find unassigned gate with highest connectivity to slice
291            for gate_idx in 0..n_gates {
292                if !assigned[gate_idx] {
293                    // Compute affinity to current slice
294                    let affinity = slice
295                        .gate_indices
296                        .iter()
297                        .map(|&idx| adjacency[gate_idx][idx])
298                        .sum::<f64>();
299
300                    // Add to slice if first gate or has affinity
301                    if slice.gate_indices.is_empty() || affinity > 0.0 {
302                        slice.gate_indices.push(gate_idx);
303                        slice
304                            .qubits
305                            .extend(gates[gate_idx].qubits().iter().map(|q| q.id()));
306                        assigned[gate_idx] = true;
307
308                        // Limit slice size
309                        if slice.gate_indices.len() >= n_gates / num_slices {
310                            break;
311                        }
312                    }
313                }
314            }
315
316            if !slice.gate_indices.is_empty() {
317                slices.push(slice);
318            }
319        }
320
321        // Assign remaining gates
322        for gate_idx in 0..n_gates {
323            if !assigned[gate_idx] {
324                // Add to slice with highest affinity
325                let mut best_slice = 0;
326                let mut best_affinity = 0.0;
327
328                for (slice_idx, slice) in slices.iter().enumerate() {
329                    let affinity = slice
330                        .gate_indices
331                        .iter()
332                        .map(|&idx| adjacency[gate_idx][idx])
333                        .sum::<f64>();
334
335                    if affinity > best_affinity {
336                        best_affinity = affinity;
337                        best_slice = slice_idx;
338                    }
339                }
340
341                slices[best_slice].gate_indices.push(gate_idx);
342                slices[best_slice]
343                    .qubits
344                    .extend(gates[gate_idx].qubits().iter().map(|q| q.id()));
345            }
346        }
347
348        // Add dependencies
349        self.add_qubit_dependencies(&mut slices, gates);
350
351        self.calculate_depths_and_schedule(slices)
352    }
353
354    /// Slice for load balancing across processors
355    fn slice_load_balanced<const N: usize>(
356        &self,
357        circuit: &Circuit<N>,
358        num_processors: usize,
359    ) -> SlicingResult {
360        let gates = circuit.gates();
361        let gates_per_processor = (gates.len() + num_processors - 1) / num_processors;
362
363        // Use max gates strategy with balanced load
364        self.slice_by_max_gates(circuit, gates_per_processor)
365    }
366
367    /// Slice based on qubit connectivity
368    fn slice_by_connectivity<const N: usize>(&self, circuit: &Circuit<N>) -> SlicingResult {
369        // Group gates by connected components of qubits
370        let gates = circuit.gates();
371        let mut slices: Vec<CircuitSlice> = Vec::new();
372        let mut gate_to_slice: HashMap<usize, usize> = HashMap::new();
373
374        for (gate_idx, gate) in gates.iter().enumerate() {
375            let gate_qubits: HashSet<u32> = gate.qubits().iter().map(|q| q.id()).collect();
376
377            // Find slices that share qubits with this gate
378            let mut connected_slices: Vec<usize> = Vec::new();
379            for (slice_idx, slice) in slices.iter().enumerate() {
380                if !slice.qubits.is_disjoint(&gate_qubits) {
381                    connected_slices.push(slice_idx);
382                }
383            }
384
385            if connected_slices.is_empty() {
386                // Create new slice
387                let slice_id = slices.len();
388                let slice = CircuitSlice {
389                    id: slice_id,
390                    gate_indices: vec![gate_idx],
391                    qubits: gate_qubits,
392                    dependencies: HashSet::new(),
393                    dependents: HashSet::new(),
394                    depth: 0,
395                };
396                slices.push(slice);
397                gate_to_slice.insert(gate_idx, slice_id);
398            } else if connected_slices.len() == 1 {
399                // Add to existing slice
400                let slice_idx = connected_slices[0];
401                slices[slice_idx].gate_indices.push(gate_idx);
402                slices[slice_idx].qubits.extend(gate_qubits);
403                gate_to_slice.insert(gate_idx, slice_idx);
404            } else {
405                // Merge slices
406                let main_slice = connected_slices[0];
407                slices[main_slice].gate_indices.push(gate_idx);
408                slices[main_slice].qubits.extend(gate_qubits);
409                gate_to_slice.insert(gate_idx, main_slice);
410
411                // Merge other slices into main
412                for &slice_idx in connected_slices[1..].iter().rev() {
413                    let slice = slices.remove(slice_idx);
414                    let gate_indices = slice.gate_indices.clone();
415                    slices[main_slice].gate_indices.extend(slice.gate_indices);
416                    slices[main_slice].qubits.extend(slice.qubits);
417
418                    // Update gate mappings
419                    for &g_idx in &gate_indices {
420                        gate_to_slice.insert(g_idx, main_slice);
421                    }
422                }
423            }
424        }
425
426        // Renumber slices
427        for (new_id, slice) in slices.iter_mut().enumerate() {
428            slice.id = new_id;
429        }
430
431        // Add dependencies based on gate order
432        self.add_order_dependencies(&mut slices, gates, &gate_to_slice);
433
434        self.calculate_depths_and_schedule(slices)
435    }
436
437    /// Add dependencies based on qubit usage
438    fn add_qubit_dependencies(
439        &self,
440        slices: &mut [CircuitSlice],
441        gates: &[Arc<dyn GateOp + Send + Sync>],
442    ) {
443        let mut qubit_last_slice: HashMap<u32, usize> = HashMap::new();
444
445        for slice in slices.iter_mut() {
446            for &gate_idx in &slice.gate_indices {
447                let gate_qubits = gates[gate_idx].qubits();
448
449                // Check dependencies
450                for qubit in gate_qubits {
451                    if let Some(&prev_slice) = qubit_last_slice.get(&qubit.id()) {
452                        if prev_slice != slice.id {
453                            slice.dependencies.insert(prev_slice);
454                        }
455                    }
456                }
457            }
458
459            // Update last slice for qubits
460            for &qubit in &slice.qubits {
461                qubit_last_slice.insert(qubit, slice.id);
462            }
463        }
464
465        // Add dependent relationships
466        for i in 0..slices.len() {
467            let deps: Vec<usize> = slices[i].dependencies.iter().cloned().collect();
468            for dep in deps {
469                slices[dep].dependents.insert(i);
470            }
471        }
472    }
473
474    /// Add dependencies based on gate ordering
475    fn add_order_dependencies(
476        &self,
477        slices: &mut [CircuitSlice],
478        gates: &[Arc<dyn GateOp + Send + Sync>],
479        gate_to_slice: &HashMap<usize, usize>,
480    ) {
481        for (gate_idx, gate) in gates.iter().enumerate() {
482            let slice_idx = gate_to_slice[&gate_idx];
483            let gate_qubits: HashSet<u32> = gate.qubits().iter().map(|q| q.id()).collect();
484
485            // Look for earlier gates on same qubits
486            for prev_idx in 0..gate_idx {
487                let prev_slice = gate_to_slice[&prev_idx];
488                if prev_slice != slice_idx {
489                    let prev_qubits: HashSet<u32> =
490                        gates[prev_idx].qubits().iter().map(|q| q.id()).collect();
491
492                    if !gate_qubits.is_disjoint(&prev_qubits) {
493                        slices[slice_idx].dependencies.insert(prev_slice);
494                        slices[prev_slice].dependents.insert(slice_idx);
495                    }
496                }
497            }
498        }
499    }
500
501    /// Calculate slice depths and parallel schedule
502    fn calculate_depths_and_schedule(&self, mut slices: Vec<CircuitSlice>) -> SlicingResult {
503        // Calculate depths using topological sort
504        let mut in_degree: HashMap<usize, usize> = HashMap::new();
505        for slice in &slices {
506            in_degree.insert(slice.id, slice.dependencies.len());
507        }
508
509        let mut queue = VecDeque::new();
510        let mut schedule = Vec::new();
511        let mut depths = HashMap::new();
512
513        // Initialize with slices having no dependencies
514        for slice in &slices {
515            if slice.dependencies.is_empty() {
516                queue.push_back(slice.id);
517                depths.insert(slice.id, 0);
518            }
519        }
520
521        // Process slices level by level
522        while !queue.is_empty() {
523            let mut current_level = Vec::new();
524            let level_size = queue.len();
525
526            for _ in 0..level_size {
527                let slice_id = queue.pop_front().unwrap();
528                current_level.push(slice_id);
529
530                // Update dependents
531                if let Some(slice) = slices.iter().find(|s| s.id == slice_id) {
532                    for &dep_id in &slice.dependents {
533                        *in_degree.get_mut(&dep_id).unwrap() -= 1;
534
535                        if in_degree[&dep_id] == 0 {
536                            queue.push_back(dep_id);
537                            depths.insert(dep_id, depths[&slice_id] + 1);
538                        }
539                    }
540                }
541            }
542
543            schedule.push(current_level);
544        }
545
546        // Update slice depths
547        for slice in &mut slices {
548            slice.depth = depths.get(&slice.id).copied().unwrap_or(0);
549        }
550
551        // Calculate communication cost
552        let communication_cost = self.calculate_communication_cost(&slices);
553
554        SlicingResult {
555            slices,
556            communication_cost,
557            parallel_depth: schedule.len(),
558            schedule,
559        }
560    }
561
562    /// Calculate total communication cost between slices
563    fn calculate_communication_cost(&self, slices: &[CircuitSlice]) -> usize {
564        let mut total_cost = 0;
565
566        for slice in slices {
567            for &dep_id in &slice.dependencies {
568                if let Some(dep_slice) = slices.iter().find(|s| s.id == dep_id) {
569                    // Count shared qubits
570                    let shared: HashSet<_> = slice.qubits.intersection(&dep_slice.qubits).collect();
571                    total_cost += shared.len();
572                }
573            }
574        }
575
576        total_cost
577    }
578}
579
580impl Default for CircuitSlicer {
581    fn default() -> Self {
582        Self::new()
583    }
584}
585
586/// Extension trait for circuit slicing
587impl<const N: usize> Circuit<N> {
588    /// Slice this circuit using the given strategy
589    pub fn slice(&self, strategy: SlicingStrategy) -> SlicingResult {
590        let slicer = CircuitSlicer::new();
591        slicer.slice_circuit(self, strategy)
592    }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};

    /// Hadamard acting on qubit `q`.
    fn hadamard(q: u32) -> Hadamard {
        Hadamard { target: QubitId(q) }
    }

    /// CNOT with the given control and target qubits.
    fn cnot(control: u32, target: u32) -> CNOT {
        CNOT {
            control: QubitId(control),
            target: QubitId(target),
        }
    }

    #[test]
    fn test_slice_by_max_qubits() {
        let mut circuit = Circuit::<4>::new();
        // Touch every qubit, then entangle the pairs (0, 1) and (2, 3).
        for q in 0..4 {
            circuit.add_gate(hadamard(q)).unwrap();
        }
        circuit.add_gate(cnot(0, 1)).unwrap();
        circuit.add_gate(cnot(2, 3)).unwrap();

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxQubits(2));

        // More than one slice is needed, and none may exceed the 2-qubit budget.
        assert!(result.slices.len() >= 2);
        assert!(result.slices.iter().all(|slice| slice.qubits.len() <= 2));
    }

    #[test]
    fn test_slice_by_max_gates() {
        let mut circuit = Circuit::<3>::new();
        // Six Hadamards cycling over qubits 0, 1, 2.
        for i in 0..6u32 {
            circuit.add_gate(hadamard(i % 3)).unwrap();
        }

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxGates(2));

        assert_eq!(result.slices.len(), 3);
        assert!(result.slices.iter().all(|slice| slice.gate_indices.len() <= 2));
    }

    #[test]
    fn test_slice_dependencies() {
        let mut circuit = Circuit::<2>::new();
        circuit.add_gate(hadamard(0)).unwrap();
        circuit.add_gate(hadamard(1)).unwrap();
        circuit.add_gate(cnot(0, 1)).unwrap();
        circuit.add_gate(PauliX { target: QubitId(0) }).unwrap();

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxGates(2));

        // The CNOT/X slice reuses qubits of the Hadamard slice, so some slice must depend on another.
        assert!(result.slices.iter().any(|slice| !slice.dependencies.is_empty()));
    }

    #[test]
    fn test_parallel_schedule() {
        let mut circuit = Circuit::<4>::new();
        // Independent single-qubit gates on distinct qubits.
        for q in 0..4 {
            circuit.add_gate(hadamard(q)).unwrap();
        }

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxQubits(1));

        // All H gates land in the first (and only) parallel level.
        assert_eq!(result.parallel_depth, 1);
        assert_eq!(result.schedule[0].len(), 4);
    }
}