1use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::Arc;
9
10use quantrs2_core::{gate::GateOp, qubit::QubitId};
11
12use crate::builder::Circuit;
13use crate::commutation::CommutationAnalyzer;
14use crate::dag::{circuit_to_dag, CircuitDag};
15
/// A group of gates taken from a circuit, produced by [`CircuitSlicer`].
///
/// Slices reference gates by index into the circuit's gate list and are linked to
/// each other through `dependencies` / `dependents`, both of which hold slice ids.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitSlice {
    /// Identifier of this slice; the slicer assigns ids equal to the slice's
    /// position in the result's slice list.
    pub id: usize,
    /// Indices (into the circuit's gate list) of the gates contained in this slice.
    pub gate_indices: Vec<usize>,
    /// Ids of every qubit touched by at least one gate of this slice.
    pub qubits: HashSet<u32>,
    /// Ids of slices that must be executed before this slice.
    pub dependencies: HashSet<usize>,
    /// Ids of slices that list this slice among their dependencies.
    pub dependents: HashSet<usize>,
    /// Scheduling level of this slice (0 for slices without dependencies).
    pub depth: usize,
}
32
/// Strategy used by [`CircuitSlicer`] to partition a circuit into slices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SlicingStrategy {
    /// Walk the gates in order and start a new slice whenever adding the next gate
    /// would make the current slice touch more than this many qubits. A gate that
    /// is wider than the limit on its own still ends up in a slice by itself.
    MaxQubits(usize),
    /// Split the gate list into consecutive chunks of at most this many gates.
    MaxGates(usize),
    /// Group the circuit DAG's depth levels into bands of this many levels, one
    /// slice per non-empty band, chained one after another.
    DepthBased(usize),
    /// Greedily group gates that share qubits into roughly `ceil(sqrt(gate count))`
    /// slices.
    MinCommunication,
    /// Split the gate list into consecutive, equally sized chunks so that there is
    /// at most one chunk per processor; the payload is the processor count.
    LoadBalanced(usize),
    /// Put every group of gates that is connected through shared qubits into its
    /// own slice.
    ConnectivityBased,
}
49
50#[derive(Debug)]
52pub struct SlicingResult {
53 pub slices: Vec<CircuitSlice>,
55 pub communication_cost: usize,
57 pub parallel_depth: usize,
59 pub schedule: Vec<Vec<usize>>, }
62
63pub struct CircuitSlicer {
65 commutation_analyzer: CommutationAnalyzer,
67}
68
69impl CircuitSlicer {
70 pub fn new() -> Self {
72 Self {
73 commutation_analyzer: CommutationAnalyzer::new(),
74 }
75 }
76
77 pub fn slice_circuit<const N: usize>(
79 &self,
80 circuit: &Circuit<N>,
81 strategy: SlicingStrategy,
82 ) -> SlicingResult {
83 match strategy {
84 SlicingStrategy::MaxQubits(max_qubits) => self.slice_by_max_qubits(circuit, max_qubits),
85 SlicingStrategy::MaxGates(max_gates) => self.slice_by_max_gates(circuit, max_gates),
86 SlicingStrategy::DepthBased(max_depth) => self.slice_by_depth(circuit, max_depth),
87 SlicingStrategy::MinCommunication => self.slice_min_communication(circuit),
88 SlicingStrategy::LoadBalanced(num_processors) => {
89 self.slice_load_balanced(circuit, num_processors)
90 }
91 SlicingStrategy::ConnectivityBased => self.slice_by_connectivity(circuit),
92 }
93 }
94
95 fn slice_by_max_qubits<const N: usize>(
97 &self,
98 circuit: &Circuit<N>,
99 max_qubits: usize,
100 ) -> SlicingResult {
101 let mut slices = Vec::new();
102 let mut current_slice = CircuitSlice {
103 id: 0,
104 gate_indices: Vec::new(),
105 qubits: HashSet::new(),
106 dependencies: HashSet::new(),
107 dependents: HashSet::new(),
108 depth: 0,
109 };
110
111 let mut qubit_last_slice: HashMap<u32, usize> = HashMap::new();
113
114 for (gate_idx, gate) in circuit.gates().iter().enumerate() {
115 let gate_qubits: HashSet<u32> = gate.qubits().iter().map(|q| q.id()).collect();
116
117 let combined_qubits: HashSet<u32> =
119 current_slice.qubits.union(&gate_qubits).cloned().collect();
120
121 if !current_slice.gate_indices.is_empty() && combined_qubits.len() > max_qubits {
122 let slice_id = slices.len();
124 current_slice.id = slice_id;
125
126 for &qubit in ¤t_slice.qubits {
128 qubit_last_slice.insert(qubit, slice_id);
129 }
130
131 slices.push(current_slice);
132
133 current_slice = CircuitSlice {
135 id: slice_id + 1,
136 gate_indices: vec![gate_idx],
137 qubits: gate_qubits.clone(),
138 dependencies: HashSet::new(),
139 dependents: HashSet::new(),
140 depth: 0,
141 };
142
143 for &qubit in &gate_qubits {
145 if let Some(&prev_slice) = qubit_last_slice.get(&qubit) {
146 current_slice.dependencies.insert(prev_slice);
147 slices[prev_slice].dependents.insert(slice_id + 1);
148 }
149 }
150 } else {
151 current_slice.gate_indices.push(gate_idx);
153 current_slice.qubits.extend(gate_qubits);
154 }
155 }
156
157 if !current_slice.gate_indices.is_empty() {
159 let slice_id = slices.len();
160 current_slice.id = slice_id;
161 slices.push(current_slice);
162 }
163
164 self.calculate_depths_and_schedule(slices)
166 }
167
168 fn slice_by_max_gates<const N: usize>(
170 &self,
171 circuit: &Circuit<N>,
172 max_gates: usize,
173 ) -> SlicingResult {
174 let mut slices = Vec::new();
175 let gates = circuit.gates();
176
177 for (chunk_idx, chunk) in gates.chunks(max_gates).enumerate() {
179 let mut slice = CircuitSlice {
180 id: chunk_idx,
181 gate_indices: Vec::new(),
182 qubits: HashSet::new(),
183 dependencies: HashSet::new(),
184 dependents: HashSet::new(),
185 depth: 0,
186 };
187
188 let base_idx = chunk_idx * max_gates;
189 for (local_idx, gate) in chunk.iter().enumerate() {
190 slice.gate_indices.push(base_idx + local_idx);
191 slice.qubits.extend(gate.qubits().iter().map(|q| q.id()));
192 }
193
194 slices.push(slice);
195 }
196
197 self.add_qubit_dependencies(&mut slices, gates);
199
200 self.calculate_depths_and_schedule(slices)
202 }
203
204 fn slice_by_depth<const N: usize>(
206 &self,
207 circuit: &Circuit<N>,
208 max_depth: usize,
209 ) -> SlicingResult {
210 let dag = circuit_to_dag(circuit);
211 let mut slices = Vec::new();
212
213 let max_circuit_depth = dag.max_depth();
215 for depth_start in (0..=max_circuit_depth).step_by(max_depth) {
216 let depth_end = (depth_start + max_depth).min(max_circuit_depth + 1);
217
218 let mut slice = CircuitSlice {
219 id: slices.len(),
220 gate_indices: Vec::new(),
221 qubits: HashSet::new(),
222 dependencies: HashSet::new(),
223 dependents: HashSet::new(),
224 depth: depth_start / max_depth,
225 };
226
227 for depth in depth_start..depth_end {
229 for &node_id in &dag.nodes_at_depth(depth) {
230 slice.gate_indices.push(node_id);
231 let node = &dag.nodes()[node_id];
232 slice
233 .qubits
234 .extend(node.gate.qubits().iter().map(|q| q.id()));
235 }
236 }
237
238 if !slice.gate_indices.is_empty() {
239 slices.push(slice);
240 }
241 }
242
243 for i in 1..slices.len() {
245 slices[i].dependencies.insert(i - 1);
246 slices[i - 1].dependents.insert(i);
247 }
248
249 self.calculate_depths_and_schedule(slices)
250 }
251
252 fn slice_min_communication<const N: usize>(&self, circuit: &Circuit<N>) -> SlicingResult {
254 let gates = circuit.gates();
256 let n_gates = gates.len();
257
258 let mut adjacency = vec![vec![0.0; n_gates]; n_gates];
260
261 for i in 0..n_gates {
262 for j in i + 1..n_gates {
263 let qubits_i: HashSet<u32> = gates[i].qubits().iter().map(|q| q.id()).collect();
264 let qubits_j: HashSet<u32> = gates[j].qubits().iter().map(|q| q.id()).collect();
265
266 let shared_qubits = qubits_i.intersection(&qubits_j).count();
267 if shared_qubits > 0 {
268 adjacency[i][j] = shared_qubits as f64;
269 adjacency[j][i] = shared_qubits as f64;
270 }
271 }
272 }
273
274 let num_slices = (n_gates as f64).sqrt().ceil() as usize;
276 let mut slices = Vec::new();
277 let mut assigned = vec![false; n_gates];
278
279 for slice_id in 0..num_slices {
281 let mut slice = CircuitSlice {
282 id: slice_id,
283 gate_indices: Vec::new(),
284 qubits: HashSet::new(),
285 dependencies: HashSet::new(),
286 dependents: HashSet::new(),
287 depth: 0,
288 };
289
290 for gate_idx in 0..n_gates {
292 if !assigned[gate_idx] {
293 let affinity = slice
295 .gate_indices
296 .iter()
297 .map(|&idx| adjacency[gate_idx][idx])
298 .sum::<f64>();
299
300 if slice.gate_indices.is_empty() || affinity > 0.0 {
302 slice.gate_indices.push(gate_idx);
303 slice
304 .qubits
305 .extend(gates[gate_idx].qubits().iter().map(|q| q.id()));
306 assigned[gate_idx] = true;
307
308 if slice.gate_indices.len() >= n_gates / num_slices {
310 break;
311 }
312 }
313 }
314 }
315
316 if !slice.gate_indices.is_empty() {
317 slices.push(slice);
318 }
319 }
320
321 for gate_idx in 0..n_gates {
323 if !assigned[gate_idx] {
324 let mut best_slice = 0;
326 let mut best_affinity = 0.0;
327
328 for (slice_idx, slice) in slices.iter().enumerate() {
329 let affinity = slice
330 .gate_indices
331 .iter()
332 .map(|&idx| adjacency[gate_idx][idx])
333 .sum::<f64>();
334
335 if affinity > best_affinity {
336 best_affinity = affinity;
337 best_slice = slice_idx;
338 }
339 }
340
341 slices[best_slice].gate_indices.push(gate_idx);
342 slices[best_slice]
343 .qubits
344 .extend(gates[gate_idx].qubits().iter().map(|q| q.id()));
345 }
346 }
347
348 self.add_qubit_dependencies(&mut slices, gates);
350
351 self.calculate_depths_and_schedule(slices)
352 }
353
354 fn slice_load_balanced<const N: usize>(
356 &self,
357 circuit: &Circuit<N>,
358 num_processors: usize,
359 ) -> SlicingResult {
360 let gates = circuit.gates();
361 let gates_per_processor = (gates.len() + num_processors - 1) / num_processors;
362
363 self.slice_by_max_gates(circuit, gates_per_processor)
365 }
366
367 fn slice_by_connectivity<const N: usize>(&self, circuit: &Circuit<N>) -> SlicingResult {
369 let gates = circuit.gates();
371 let mut slices: Vec<CircuitSlice> = Vec::new();
372 let mut gate_to_slice: HashMap<usize, usize> = HashMap::new();
373
374 for (gate_idx, gate) in gates.iter().enumerate() {
375 let gate_qubits: HashSet<u32> = gate.qubits().iter().map(|q| q.id()).collect();
376
377 let mut connected_slices: Vec<usize> = Vec::new();
379 for (slice_idx, slice) in slices.iter().enumerate() {
380 if !slice.qubits.is_disjoint(&gate_qubits) {
381 connected_slices.push(slice_idx);
382 }
383 }
384
385 if connected_slices.is_empty() {
386 let slice_id = slices.len();
388 let slice = CircuitSlice {
389 id: slice_id,
390 gate_indices: vec![gate_idx],
391 qubits: gate_qubits,
392 dependencies: HashSet::new(),
393 dependents: HashSet::new(),
394 depth: 0,
395 };
396 slices.push(slice);
397 gate_to_slice.insert(gate_idx, slice_id);
398 } else if connected_slices.len() == 1 {
399 let slice_idx = connected_slices[0];
401 slices[slice_idx].gate_indices.push(gate_idx);
402 slices[slice_idx].qubits.extend(gate_qubits);
403 gate_to_slice.insert(gate_idx, slice_idx);
404 } else {
405 let main_slice = connected_slices[0];
407 slices[main_slice].gate_indices.push(gate_idx);
408 slices[main_slice].qubits.extend(gate_qubits);
409 gate_to_slice.insert(gate_idx, main_slice);
410
411 for &slice_idx in connected_slices[1..].iter().rev() {
413 let slice = slices.remove(slice_idx);
414 let gate_indices = slice.gate_indices.clone();
415 slices[main_slice].gate_indices.extend(slice.gate_indices);
416 slices[main_slice].qubits.extend(slice.qubits);
417
418 for &g_idx in &gate_indices {
420 gate_to_slice.insert(g_idx, main_slice);
421 }
422 }
423 }
424 }
425
426 for (new_id, slice) in slices.iter_mut().enumerate() {
428 slice.id = new_id;
429 }
430
431 self.add_order_dependencies(&mut slices, gates, &gate_to_slice);
433
434 self.calculate_depths_and_schedule(slices)
435 }
436
437 fn add_qubit_dependencies(
439 &self,
440 slices: &mut [CircuitSlice],
441 gates: &[Arc<dyn GateOp + Send + Sync>],
442 ) {
443 let mut qubit_last_slice: HashMap<u32, usize> = HashMap::new();
444
445 for slice in slices.iter_mut() {
446 for &gate_idx in &slice.gate_indices {
447 let gate_qubits = gates[gate_idx].qubits();
448
449 for qubit in gate_qubits {
451 if let Some(&prev_slice) = qubit_last_slice.get(&qubit.id()) {
452 if prev_slice != slice.id {
453 slice.dependencies.insert(prev_slice);
454 }
455 }
456 }
457 }
458
459 for &qubit in &slice.qubits {
461 qubit_last_slice.insert(qubit, slice.id);
462 }
463 }
464
465 for i in 0..slices.len() {
467 let deps: Vec<usize> = slices[i].dependencies.iter().cloned().collect();
468 for dep in deps {
469 slices[dep].dependents.insert(i);
470 }
471 }
472 }
473
474 fn add_order_dependencies(
476 &self,
477 slices: &mut [CircuitSlice],
478 gates: &[Arc<dyn GateOp + Send + Sync>],
479 gate_to_slice: &HashMap<usize, usize>,
480 ) {
481 for (gate_idx, gate) in gates.iter().enumerate() {
482 let slice_idx = gate_to_slice[&gate_idx];
483 let gate_qubits: HashSet<u32> = gate.qubits().iter().map(|q| q.id()).collect();
484
485 for prev_idx in 0..gate_idx {
487 let prev_slice = gate_to_slice[&prev_idx];
488 if prev_slice != slice_idx {
489 let prev_qubits: HashSet<u32> =
490 gates[prev_idx].qubits().iter().map(|q| q.id()).collect();
491
492 if !gate_qubits.is_disjoint(&prev_qubits) {
493 slices[slice_idx].dependencies.insert(prev_slice);
494 slices[prev_slice].dependents.insert(slice_idx);
495 }
496 }
497 }
498 }
499 }
500
501 fn calculate_depths_and_schedule(&self, mut slices: Vec<CircuitSlice>) -> SlicingResult {
503 let mut in_degree: HashMap<usize, usize> = HashMap::new();
505 for slice in &slices {
506 in_degree.insert(slice.id, slice.dependencies.len());
507 }
508
509 let mut queue = VecDeque::new();
510 let mut schedule = Vec::new();
511 let mut depths = HashMap::new();
512
513 for slice in &slices {
515 if slice.dependencies.is_empty() {
516 queue.push_back(slice.id);
517 depths.insert(slice.id, 0);
518 }
519 }
520
521 while !queue.is_empty() {
523 let mut current_level = Vec::new();
524 let level_size = queue.len();
525
526 for _ in 0..level_size {
527 let slice_id = queue.pop_front().unwrap();
528 current_level.push(slice_id);
529
530 if let Some(slice) = slices.iter().find(|s| s.id == slice_id) {
532 for &dep_id in &slice.dependents {
533 *in_degree.get_mut(&dep_id).unwrap() -= 1;
534
535 if in_degree[&dep_id] == 0 {
536 queue.push_back(dep_id);
537 depths.insert(dep_id, depths[&slice_id] + 1);
538 }
539 }
540 }
541 }
542
543 schedule.push(current_level);
544 }
545
546 for slice in &mut slices {
548 slice.depth = depths.get(&slice.id).copied().unwrap_or(0);
549 }
550
551 let communication_cost = self.calculate_communication_cost(&slices);
553
554 SlicingResult {
555 slices,
556 communication_cost,
557 parallel_depth: schedule.len(),
558 schedule,
559 }
560 }
561
562 fn calculate_communication_cost(&self, slices: &[CircuitSlice]) -> usize {
564 let mut total_cost = 0;
565
566 for slice in slices {
567 for &dep_id in &slice.dependencies {
568 if let Some(dep_slice) = slices.iter().find(|s| s.id == dep_id) {
569 let shared: HashSet<_> = slice.qubits.intersection(&dep_slice.qubits).collect();
571 total_cost += shared.len();
572 }
573 }
574 }
575
576 total_cost
577 }
578}
579
580impl Default for CircuitSlicer {
581 fn default() -> Self {
582 Self::new()
583 }
584}
585
586impl<const N: usize> Circuit<N> {
588 pub fn slice(&self, strategy: SlicingStrategy) -> SlicingResult {
590 let slicer = CircuitSlicer::new();
591 slicer.slice_circuit(self, strategy)
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX};

    /// Appends one Hadamard gate per entry of `targets`, in order.
    fn add_hadamards<const N: usize>(circuit: &mut Circuit<N>, targets: &[u32]) {
        for &target in targets {
            circuit.add_gate(Hadamard { target: QubitId(target) }).unwrap();
        }
    }

    /// Appends a CNOT gate acting on `control` and `target`.
    fn add_cnot<const N: usize>(circuit: &mut Circuit<N>, control: u32, target: u32) {
        circuit
            .add_gate(CNOT {
                control: QubitId(control),
                target: QubitId(target),
            })
            .unwrap();
    }

    #[test]
    fn test_slice_by_max_qubits() {
        let mut circuit = Circuit::<4>::new();
        add_hadamards(&mut circuit, &[0, 1, 2, 3]);
        add_cnot(&mut circuit, 0, 1);
        add_cnot(&mut circuit, 2, 3);

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxQubits(2));

        assert!(result.slices.len() >= 2);
        assert!(result.slices.iter().all(|slice| slice.qubits.len() <= 2));
    }

    #[test]
    fn test_slice_by_max_gates() {
        let mut circuit = Circuit::<3>::new();
        let targets: Vec<u32> = (0..6).map(|i| i % 3).collect();
        add_hadamards(&mut circuit, &targets);

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxGates(2));

        assert_eq!(result.slices.len(), 3);
        assert!(result
            .slices
            .iter()
            .all(|slice| slice.gate_indices.len() <= 2));
    }

    #[test]
    fn test_slice_dependencies() {
        let mut circuit = Circuit::<2>::new();
        add_hadamards(&mut circuit, &[0, 1]);
        add_cnot(&mut circuit, 0, 1);
        circuit.add_gate(PauliX { target: QubitId(0) }).unwrap();

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxGates(2));

        assert!(result
            .slices
            .iter()
            .any(|slice| !slice.dependencies.is_empty()));
    }

    #[test]
    fn test_parallel_schedule() {
        let mut circuit = Circuit::<4>::new();
        add_hadamards(&mut circuit, &[0, 1, 2, 3]);

        let result = CircuitSlicer::new().slice_circuit(&circuit, SlicingStrategy::MaxQubits(1));

        assert_eq!(result.parallel_depth, 1);
        assert_eq!(result.schedule[0].len(), 4);
    }
}