1use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::Arc;
9
10use quantrs2_core::{gate::GateOp, qubit::QubitId};
11
12use crate::builder::Circuit;
13use crate::commutation::CommutationAnalyzer;
14use crate::dag::{circuit_to_dag, CircuitDag};
15
/// A group of gates produced by [`CircuitSlicer`], together with its
/// dependency edges to the other slices of the same [`SlicingResult`].
#[derive(Debug, Clone)]
pub struct CircuitSlice {
    /// Identifier of this slice. The slicer assigns ids equal to the slice's
    /// position in `SlicingResult::slices`.
    pub id: usize,
    /// Indices of the gates in this slice, as positions in `Circuit::gates()`
    /// (for `DepthBased` slicing these are DAG node ids — presumably the same
    /// numbering; TODO confirm).
    pub gate_indices: Vec<usize>,
    /// Numeric ids (`QubitId::id`) of every qubit touched by this slice's gates.
    pub qubits: HashSet<u32>,
    /// Ids of the slices that must execute before this one.
    pub dependencies: HashSet<usize>,
    /// Ids of the slices that list this slice in their `dependencies`.
    pub dependents: HashSet<usize>,
    /// Level of this slice in the parallel schedule; slices without
    /// dependencies are at level 0.
    pub depth: usize,
}
32
/// How [`CircuitSlicer::slice_circuit`] partitions a circuit into slices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SlicingStrategy {
    /// Walk the gates in order and start a new slice whenever adding the next
    /// gate would make the current slice touch more than this many distinct
    /// qubits. A single gate that alone exceeds the limit gets its own slice.
    MaxQubits(usize),
    /// Split the gate list into consecutive chunks of at most this many gates.
    MaxGates(usize),
    /// Put this many consecutive DAG depth levels into each slice.
    DepthBased(usize),
    /// Greedily group gates that share qubits into roughly
    /// `ceil(sqrt(number_of_gates))` slices.
    MinCommunication,
    /// Split the gate list into contiguous chunks sized so that this many
    /// processors each receive at most `ceil(number_of_gates / processors)` gates.
    LoadBalanced(usize),
    /// Group gates whose qubits are connected (directly or transitively)
    /// through shared gates.
    ConnectivityBased,
}
49
/// Output of [`CircuitSlicer::slice_circuit`].
#[derive(Debug)]
pub struct SlicingResult {
    /// The slices; each slice's `id` equals its index in this vector.
    pub slices: Vec<CircuitSlice>,
    /// Sum, over every dependency edge, of the number of qubits shared by a
    /// slice and the slice it depends on.
    pub communication_cost: usize,
    /// Number of levels in `schedule`.
    pub parallel_depth: usize,
    /// Slice ids grouped into levels. A slice appears in a level only after
    /// every slice it depends on has appeared in an earlier level.
    pub schedule: Vec<Vec<usize>>,
}
62
/// Partitions circuits into [`CircuitSlice`]s according to a [`SlicingStrategy`].
pub struct CircuitSlicer {
    /// Commutation analyzer owned by the slicer.
    // NOTE(review): no slicing strategy in this file reads this field —
    // confirm whether it is still needed.
    commutation_analyzer: CommutationAnalyzer,
}
68
69impl CircuitSlicer {
70 #[must_use]
72 pub fn new() -> Self {
73 Self {
74 commutation_analyzer: CommutationAnalyzer::new(),
75 }
76 }
77
78 #[must_use]
80 pub fn slice_circuit<const N: usize>(
81 &self,
82 circuit: &Circuit<N>,
83 strategy: SlicingStrategy,
84 ) -> SlicingResult {
85 match strategy {
86 SlicingStrategy::MaxQubits(max_qubits) => self.slice_by_max_qubits(circuit, max_qubits),
87 SlicingStrategy::MaxGates(max_gates) => self.slice_by_max_gates(circuit, max_gates),
88 SlicingStrategy::DepthBased(max_depth) => self.slice_by_depth(circuit, max_depth),
89 SlicingStrategy::MinCommunication => self.slice_min_communication(circuit),
90 SlicingStrategy::LoadBalanced(num_processors) => {
91 self.slice_load_balanced(circuit, num_processors)
92 }
93 SlicingStrategy::ConnectivityBased => self.slice_by_connectivity(circuit),
94 }
95 }
96
97 fn slice_by_max_qubits<const N: usize>(
99 &self,
100 circuit: &Circuit<N>,
101 max_qubits: usize,
102 ) -> SlicingResult {
103 let mut slices = Vec::new();
104 let mut current_slice = CircuitSlice {
105 id: 0,
106 gate_indices: Vec::new(),
107 qubits: HashSet::new(),
108 dependencies: HashSet::new(),
109 dependents: HashSet::new(),
110 depth: 0,
111 };
112
113 let mut qubit_last_slice: HashMap<u32, usize> = HashMap::new();
115
116 for (gate_idx, gate) in circuit.gates().iter().enumerate() {
117 let gate_qubits: HashSet<u32> = gate
118 .qubits()
119 .iter()
120 .map(quantrs2_core::QubitId::id)
121 .collect();
122
123 let combined_qubits: HashSet<u32> =
125 current_slice.qubits.union(&gate_qubits).copied().collect();
126
127 if !current_slice.gate_indices.is_empty() && combined_qubits.len() > max_qubits {
128 let slice_id = slices.len();
130 current_slice.id = slice_id;
131
132 for &qubit in ¤t_slice.qubits {
134 qubit_last_slice.insert(qubit, slice_id);
135 }
136
137 slices.push(current_slice);
138
139 current_slice = CircuitSlice {
141 id: slice_id + 1,
142 gate_indices: vec![gate_idx],
143 qubits: gate_qubits.clone(),
144 dependencies: HashSet::new(),
145 dependents: HashSet::new(),
146 depth: 0,
147 };
148
149 for &qubit in &gate_qubits {
151 if let Some(&prev_slice) = qubit_last_slice.get(&qubit) {
152 current_slice.dependencies.insert(prev_slice);
153 slices[prev_slice].dependents.insert(slice_id + 1);
154 }
155 }
156 } else {
157 current_slice.gate_indices.push(gate_idx);
159 current_slice.qubits.extend(gate_qubits);
160 }
161 }
162
163 if !current_slice.gate_indices.is_empty() {
165 let slice_id = slices.len();
166 current_slice.id = slice_id;
167 slices.push(current_slice);
168 }
169
170 self.calculate_depths_and_schedule(slices)
172 }
173
174 fn slice_by_max_gates<const N: usize>(
176 &self,
177 circuit: &Circuit<N>,
178 max_gates: usize,
179 ) -> SlicingResult {
180 let mut slices = Vec::new();
181 let gates = circuit.gates();
182
183 for (chunk_idx, chunk) in gates.chunks(max_gates).enumerate() {
185 let mut slice = CircuitSlice {
186 id: chunk_idx,
187 gate_indices: Vec::new(),
188 qubits: HashSet::new(),
189 dependencies: HashSet::new(),
190 dependents: HashSet::new(),
191 depth: 0,
192 };
193
194 let base_idx = chunk_idx * max_gates;
195 for (local_idx, gate) in chunk.iter().enumerate() {
196 slice.gate_indices.push(base_idx + local_idx);
197 slice
198 .qubits
199 .extend(gate.qubits().iter().map(quantrs2_core::QubitId::id));
200 }
201
202 slices.push(slice);
203 }
204
205 self.add_qubit_dependencies(&mut slices, gates);
207
208 self.calculate_depths_and_schedule(slices)
210 }
211
212 fn slice_by_depth<const N: usize>(
214 &self,
215 circuit: &Circuit<N>,
216 max_depth: usize,
217 ) -> SlicingResult {
218 let dag = circuit_to_dag(circuit);
219 let mut slices = Vec::new();
220
221 let max_circuit_depth = dag.max_depth();
223 for depth_start in (0..=max_circuit_depth).step_by(max_depth) {
224 let depth_end = (depth_start + max_depth).min(max_circuit_depth + 1);
225
226 let mut slice = CircuitSlice {
227 id: slices.len(),
228 gate_indices: Vec::new(),
229 qubits: HashSet::new(),
230 dependencies: HashSet::new(),
231 dependents: HashSet::new(),
232 depth: depth_start / max_depth,
233 };
234
235 for depth in depth_start..depth_end {
237 for &node_id in &dag.nodes_at_depth(depth) {
238 slice.gate_indices.push(node_id);
239 let node = &dag.nodes()[node_id];
240 slice
241 .qubits
242 .extend(node.gate.qubits().iter().map(quantrs2_core::QubitId::id));
243 }
244 }
245
246 if !slice.gate_indices.is_empty() {
247 slices.push(slice);
248 }
249 }
250
251 for i in 1..slices.len() {
253 slices[i].dependencies.insert(i - 1);
254 slices[i - 1].dependents.insert(i);
255 }
256
257 self.calculate_depths_and_schedule(slices)
258 }
259
260 fn slice_min_communication<const N: usize>(&self, circuit: &Circuit<N>) -> SlicingResult {
262 let gates = circuit.gates();
264 let n_gates = gates.len();
265
266 let mut adjacency = vec![vec![0.0; n_gates]; n_gates];
268
269 for i in 0..n_gates {
270 for j in i + 1..n_gates {
271 let qubits_i: HashSet<u32> = gates[i]
272 .qubits()
273 .iter()
274 .map(quantrs2_core::QubitId::id)
275 .collect();
276 let qubits_j: HashSet<u32> = gates[j]
277 .qubits()
278 .iter()
279 .map(quantrs2_core::QubitId::id)
280 .collect();
281
282 let shared_qubits = qubits_i.intersection(&qubits_j).count();
283 if shared_qubits > 0 {
284 adjacency[i][j] = shared_qubits as f64;
285 adjacency[j][i] = shared_qubits as f64;
286 }
287 }
288 }
289
290 let num_slices = (n_gates as f64).sqrt().ceil() as usize;
292 let mut slices = Vec::new();
293 let mut assigned = vec![false; n_gates];
294
295 for slice_id in 0..num_slices {
297 let mut slice = CircuitSlice {
298 id: slice_id,
299 gate_indices: Vec::new(),
300 qubits: HashSet::new(),
301 dependencies: HashSet::new(),
302 dependents: HashSet::new(),
303 depth: 0,
304 };
305
306 for gate_idx in 0..n_gates {
308 if !assigned[gate_idx] {
309 let affinity = slice
311 .gate_indices
312 .iter()
313 .map(|&idx| adjacency[gate_idx][idx])
314 .sum::<f64>();
315
316 if slice.gate_indices.is_empty() || affinity > 0.0 {
318 slice.gate_indices.push(gate_idx);
319 slice.qubits.extend(
320 gates[gate_idx]
321 .qubits()
322 .iter()
323 .map(quantrs2_core::QubitId::id),
324 );
325 assigned[gate_idx] = true;
326
327 if slice.gate_indices.len() >= n_gates / num_slices {
329 break;
330 }
331 }
332 }
333 }
334
335 if !slice.gate_indices.is_empty() {
336 slices.push(slice);
337 }
338 }
339
340 for gate_idx in 0..n_gates {
342 if !assigned[gate_idx] {
343 let mut best_slice = 0;
345 let mut best_affinity = 0.0;
346
347 for (slice_idx, slice) in slices.iter().enumerate() {
348 let affinity = slice
349 .gate_indices
350 .iter()
351 .map(|&idx| adjacency[gate_idx][idx])
352 .sum::<f64>();
353
354 if affinity > best_affinity {
355 best_affinity = affinity;
356 best_slice = slice_idx;
357 }
358 }
359
360 slices[best_slice].gate_indices.push(gate_idx);
361 slices[best_slice].qubits.extend(
362 gates[gate_idx]
363 .qubits()
364 .iter()
365 .map(quantrs2_core::QubitId::id),
366 );
367 }
368 }
369
370 self.add_qubit_dependencies(&mut slices, gates);
372
373 self.calculate_depths_and_schedule(slices)
374 }
375
376 fn slice_load_balanced<const N: usize>(
378 &self,
379 circuit: &Circuit<N>,
380 num_processors: usize,
381 ) -> SlicingResult {
382 let gates = circuit.gates();
383 let gates_per_processor = gates.len().div_ceil(num_processors);
384
385 self.slice_by_max_gates(circuit, gates_per_processor)
387 }
388
389 fn slice_by_connectivity<const N: usize>(&self, circuit: &Circuit<N>) -> SlicingResult {
391 let gates = circuit.gates();
393 let mut slices: Vec<CircuitSlice> = Vec::new();
394 let mut gate_to_slice: HashMap<usize, usize> = HashMap::new();
395
396 for (gate_idx, gate) in gates.iter().enumerate() {
397 let gate_qubits: HashSet<u32> = gate
398 .qubits()
399 .iter()
400 .map(quantrs2_core::QubitId::id)
401 .collect();
402
403 let mut connected_slices: Vec<usize> = Vec::new();
405 for (slice_idx, slice) in slices.iter().enumerate() {
406 if !slice.qubits.is_disjoint(&gate_qubits) {
407 connected_slices.push(slice_idx);
408 }
409 }
410
411 if connected_slices.is_empty() {
412 let slice_id = slices.len();
414 let slice = CircuitSlice {
415 id: slice_id,
416 gate_indices: vec![gate_idx],
417 qubits: gate_qubits,
418 dependencies: HashSet::new(),
419 dependents: HashSet::new(),
420 depth: 0,
421 };
422 slices.push(slice);
423 gate_to_slice.insert(gate_idx, slice_id);
424 } else if connected_slices.len() == 1 {
425 let slice_idx = connected_slices[0];
427 slices[slice_idx].gate_indices.push(gate_idx);
428 slices[slice_idx].qubits.extend(gate_qubits);
429 gate_to_slice.insert(gate_idx, slice_idx);
430 } else {
431 let main_slice = connected_slices[0];
433 slices[main_slice].gate_indices.push(gate_idx);
434 slices[main_slice].qubits.extend(gate_qubits);
435 gate_to_slice.insert(gate_idx, main_slice);
436
437 for &slice_idx in connected_slices[1..].iter().rev() {
439 let slice = slices.remove(slice_idx);
440 let gate_indices = slice.gate_indices.clone();
441 slices[main_slice].gate_indices.extend(slice.gate_indices);
442 slices[main_slice].qubits.extend(slice.qubits);
443
444 for &g_idx in &gate_indices {
446 gate_to_slice.insert(g_idx, main_slice);
447 }
448 }
449 }
450 }
451
452 for (new_id, slice) in slices.iter_mut().enumerate() {
454 slice.id = new_id;
455 }
456
457 self.add_order_dependencies(&mut slices, gates, &gate_to_slice);
459
460 self.calculate_depths_and_schedule(slices)
461 }
462
463 fn add_qubit_dependencies(
465 &self,
466 slices: &mut [CircuitSlice],
467 gates: &[Arc<dyn GateOp + Send + Sync>],
468 ) {
469 let mut qubit_last_slice: HashMap<u32, usize> = HashMap::new();
470
471 for slice in slices.iter_mut() {
472 for &gate_idx in &slice.gate_indices {
473 let gate_qubits = gates[gate_idx].qubits();
474
475 for qubit in gate_qubits {
477 if let Some(&prev_slice) = qubit_last_slice.get(&qubit.id()) {
478 if prev_slice != slice.id {
479 slice.dependencies.insert(prev_slice);
480 }
481 }
482 }
483 }
484
485 for &qubit in &slice.qubits {
487 qubit_last_slice.insert(qubit, slice.id);
488 }
489 }
490
491 for i in 0..slices.len() {
493 let deps: Vec<usize> = slices[i].dependencies.iter().copied().collect();
494 for dep in deps {
495 slices[dep].dependents.insert(i);
496 }
497 }
498 }
499
500 fn add_order_dependencies(
502 &self,
503 slices: &mut [CircuitSlice],
504 gates: &[Arc<dyn GateOp + Send + Sync>],
505 gate_to_slice: &HashMap<usize, usize>,
506 ) {
507 for (gate_idx, gate) in gates.iter().enumerate() {
508 let slice_idx = gate_to_slice[&gate_idx];
509 let gate_qubits: HashSet<u32> = gate
510 .qubits()
511 .iter()
512 .map(quantrs2_core::QubitId::id)
513 .collect();
514
515 for prev_idx in 0..gate_idx {
517 let prev_slice = gate_to_slice[&prev_idx];
518 if prev_slice != slice_idx {
519 let prev_qubits: HashSet<u32> = gates[prev_idx]
520 .qubits()
521 .iter()
522 .map(quantrs2_core::QubitId::id)
523 .collect();
524
525 if !gate_qubits.is_disjoint(&prev_qubits) {
526 slices[slice_idx].dependencies.insert(prev_slice);
527 slices[prev_slice].dependents.insert(slice_idx);
528 }
529 }
530 }
531 }
532 }
533
534 fn calculate_depths_and_schedule(&self, mut slices: Vec<CircuitSlice>) -> SlicingResult {
536 let mut in_degree: HashMap<usize, usize> = HashMap::new();
538 for slice in &slices {
539 in_degree.insert(slice.id, slice.dependencies.len());
540 }
541
542 let mut queue = VecDeque::new();
543 let mut schedule = Vec::new();
544 let mut depths = HashMap::new();
545
546 for slice in &slices {
548 if slice.dependencies.is_empty() {
549 queue.push_back(slice.id);
550 depths.insert(slice.id, 0);
551 }
552 }
553
554 while !queue.is_empty() {
556 let mut current_level = Vec::new();
557 let level_size = queue.len();
558
559 for _ in 0..level_size {
560 let slice_id = queue
561 .pop_front()
562 .expect("queue is not empty (checked in while condition)");
563 current_level.push(slice_id);
564
565 if let Some(slice) = slices.iter().find(|s| s.id == slice_id) {
567 for &dep_id in &slice.dependents {
568 if let Some(degree) = in_degree.get_mut(&dep_id) {
569 *degree -= 1;
570
571 if *degree == 0 {
572 queue.push_back(dep_id);
573 if let Some(¤t_depth) = depths.get(&slice_id) {
574 depths.insert(dep_id, current_depth + 1);
575 }
576 }
577 }
578 }
579 }
580 }
581
582 schedule.push(current_level);
583 }
584
585 for slice in &mut slices {
587 slice.depth = depths.get(&slice.id).copied().unwrap_or(0);
588 }
589
590 let communication_cost = self.calculate_communication_cost(&slices);
592
593 SlicingResult {
594 slices,
595 communication_cost,
596 parallel_depth: schedule.len(),
597 schedule,
598 }
599 }
600
601 fn calculate_communication_cost(&self, slices: &[CircuitSlice]) -> usize {
603 let mut total_cost = 0;
604
605 for slice in slices {
606 for &dep_id in &slice.dependencies {
607 if let Some(dep_slice) = slices.iter().find(|s| s.id == dep_id) {
608 let shared: HashSet<_> = slice.qubits.intersection(&dep_slice.qubits).collect();
610 total_cost += shared.len();
611 }
612 }
613 }
614
615 total_cost
616 }
617}
618
619impl Default for CircuitSlicer {
620 fn default() -> Self {
621 Self::new()
622 }
623}
624
625impl<const N: usize> Circuit<N> {
627 #[must_use]
629 pub fn slice(&self, strategy: SlicingStrategy) -> SlicingResult {
630 let slicer = CircuitSlicer::new();
631 slicer.slice_circuit(self, strategy)
632 }
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};

    /// Builds a 4-qubit circuit with one Hadamard on each qubit, in qubit order.
    fn hadamard_layer() -> Circuit<4> {
        let mut circuit = Circuit::<4>::new();
        let messages = [
            "failed to add H gate to qubit 0",
            "failed to add H gate to qubit 1",
            "failed to add H gate to qubit 2",
            "failed to add H gate to qubit 3",
        ];
        for (qubit, message) in (0u32..).zip(messages) {
            circuit
                .add_gate(Hadamard {
                    target: QubitId(qubit),
                })
                .expect(message);
        }
        circuit
    }

    #[test]
    fn test_slice_by_max_qubits() {
        let mut circuit = hadamard_layer();
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("failed to add CNOT gate on qubits 0,1");
        circuit
            .add_gate(CNOT {
                control: QubitId(2),
                target: QubitId(3),
            })
            .expect("failed to add CNOT gate on qubits 2,3");

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxQubits(2));

        assert!(result.slices.len() >= 2);
        assert!(result.slices.iter().all(|slice| slice.qubits.len() <= 2));
    }

    #[test]
    fn test_slice_by_max_gates() {
        let mut circuit = Circuit::<3>::new();
        for step in 0u32..6 {
            circuit
                .add_gate(Hadamard {
                    target: QubitId(step % 3),
                })
                .expect("failed to add Hadamard gate in loop");
        }

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxGates(2));

        assert_eq!(result.slices.len(), 3);
        assert!(result
            .slices
            .iter()
            .all(|slice| slice.gate_indices.len() <= 2));
    }

    #[test]
    fn test_slice_dependencies() {
        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("failed to add H gate to qubit 0");
        circuit
            .add_gate(Hadamard { target: QubitId(1) })
            .expect("failed to add H gate to qubit 1");
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("failed to add CNOT gate on qubits 0,1");
        circuit
            .add_gate(PauliX { target: QubitId(0) })
            .expect("failed to add X gate to qubit 0");

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxGates(2));

        let any_dependent = result
            .slices
            .iter()
            .any(|slice| !slice.dependencies.is_empty());
        assert!(any_dependent);
    }

    #[test]
    fn test_parallel_schedule() {
        let circuit = hadamard_layer();

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxQubits(1));

        assert_eq!(result.parallel_depth, 1);
        assert_eq!(result.schedule[0].len(), 4);
    }
}