1use crate::builder::Circuit;
8use crate::dag::{circuit_to_dag, CircuitDag, DagNode};
9use crate::routing::{CouplingMap, RoutedCircuit, RoutingResult};
10use quantrs2_core::{
11 error::{QuantRS2Error, QuantRS2Result},
12 gate::{
13 multi::SWAP,
14 single::{RotationX, RotationY, RotationZ},
15 GateOp,
16 },
17 qubit::QubitId,
18};
19use scirs2_core::random::{seq::SliceRandom, thread_rng, Rng};
20use std::collections::{HashMap, HashSet, VecDeque};
21
/// Tuning parameters for the SABRE-style router defined in this module.
#[derive(Debug, Clone)]
pub struct SabreConfig {
    /// Maximum number of routing iterations before routing is reported as failed.
    pub max_iterations: usize,
    /// How many levels of DAG successors beyond the front layer are gathered into the
    /// lookahead (extended) set.
    pub lookahead_depth: usize,
    /// Decay factor consumed by the swap-scoring heuristic.
    pub decay_factor: f64,
    /// Weight of the lookahead (extended set) term relative to the front-layer term when a
    /// candidate SWAP is scored.
    pub extended_set_weight: f64,
    /// Intended cap on the number of SWAPs inserted per iteration.
    ///
    /// NOTE(review): the router in this file inserts at most one SWAP per iteration and never
    /// reads this field — confirm whether it should.
    pub max_swaps_per_iteration: usize,
    /// When `true`, the SWAP is drawn at random from the best-scoring candidates instead of
    /// always taking the single best one.
    pub stochastic: bool,
}

impl Default for SabreConfig {
    /// 1000 iterations, lookahead depth 20, decay 0.001, extended-set weight 0.5,
    /// 10 swaps per iteration, deterministic selection.
    fn default() -> Self {
        Self {
            stochastic: false,
            max_swaps_per_iteration: 10,
            extended_set_weight: 0.5,
            decay_factor: 0.001,
            lookahead_depth: 20,
            max_iterations: 1000,
        }
    }
}

impl SabreConfig {
    /// Smaller preset: 100 iterations, lookahead depth 5, decay 0.01, extended-set weight 0.3,
    /// 5 swaps per iteration, deterministic selection.
    #[must_use]
    pub const fn basic() -> Self {
        Self {
            stochastic: false,
            max_swaps_per_iteration: 5,
            extended_set_weight: 0.3,
            decay_factor: 0.01,
            lookahead_depth: 5,
            max_iterations: 100,
        }
    }

    /// The [`Default`] preset with randomized SWAP selection switched on.
    #[must_use]
    pub fn stochastic() -> Self {
        Self {
            stochastic: true,
            ..Self::default()
        }
    }
}
75
76pub struct SabreRouter {
78 coupling_map: CouplingMap,
79 config: SabreConfig,
80}
81
82impl SabreRouter {
83 #[must_use]
85 pub const fn new(coupling_map: CouplingMap, config: SabreConfig) -> Self {
86 Self {
87 coupling_map,
88 config,
89 }
90 }
91
92 pub fn route<const N: usize>(&self, circuit: &Circuit<N>) -> QuantRS2Result<RoutedCircuit<N>> {
94 let dag = circuit_to_dag(circuit);
95 let mut logical_to_physical = self.initial_mapping(&dag);
96 let mut physical_to_logical: HashMap<usize, usize> = logical_to_physical
97 .iter()
98 .map(|(&logical, &physical)| (physical, logical))
99 .collect();
100
101 let mut routed_gates = Vec::new();
102 let mut executable = self.find_executable_gates(&dag, &logical_to_physical);
103 let mut remaining_gates: HashSet<usize> = (0..dag.nodes().len()).collect();
104 let mut iteration = 0;
105
106 while !remaining_gates.is_empty() && iteration < self.config.max_iterations {
107 iteration += 1;
108
109 while let Some(gate_id) = executable.pop() {
111 if remaining_gates.contains(&gate_id) {
112 let node = &dag.nodes()[gate_id];
113 let routed_gate = self.map_gate_to_physical(node, &logical_to_physical)?;
114 routed_gates.push(routed_gate);
115 remaining_gates.remove(&gate_id);
116
117 for &succ in &node.successors {
119 if remaining_gates.contains(&succ)
120 && self.is_gate_executable(&dag.nodes()[succ], &logical_to_physical)
121 {
122 executable.push(succ);
123 }
124 }
125 }
126 }
127
128 if !remaining_gates.is_empty() {
130 let swaps = self.find_best_swaps(&dag, &remaining_gates, &logical_to_physical)?;
131
132 if swaps.is_empty() {
133 return Err(QuantRS2Error::RoutingError(
134 "Cannot find valid SWAP operations".to_string(),
135 ));
136 }
137
138 for (p1, p2) in swaps {
140 let swap_gate = Box::new(SWAP {
142 qubit1: QubitId::new(p1 as u32),
143 qubit2: QubitId::new(p2 as u32),
144 }) as Box<dyn GateOp>;
145 routed_gates.push(swap_gate);
146
147 let l1 = physical_to_logical[&p1];
149 let l2 = physical_to_logical[&p2];
150
151 logical_to_physical.insert(l1, p2);
152 logical_to_physical.insert(l2, p1);
153 physical_to_logical.insert(p1, l2);
154 physical_to_logical.insert(p2, l1);
155 }
156
157 executable = self.find_executable_gates_from_remaining(
159 &dag,
160 &remaining_gates,
161 &logical_to_physical,
162 );
163 }
164 }
165
166 if !remaining_gates.is_empty() {
167 return Err(QuantRS2Error::RoutingError(format!(
168 "Routing failed: {} gates remaining after {} iterations",
169 remaining_gates.len(),
170 iteration
171 )));
172 }
173
174 let total_swaps = routed_gates.iter().filter(|g| g.name() == "SWAP").count();
175 let circuit_depth = self.calculate_depth(&routed_gates);
176
177 Ok(RoutedCircuit::new(
178 routed_gates,
179 logical_to_physical,
180 RoutingResult {
181 total_swaps,
182 circuit_depth,
183 routing_overhead: if circuit_depth > 0 {
184 total_swaps as f64 / circuit_depth as f64
185 } else {
186 0.0
187 },
188 },
189 ))
190 }
191
192 fn initial_mapping(&self, dag: &CircuitDag) -> HashMap<usize, usize> {
194 let mut mapping = HashMap::new();
195 let logical_qubits = self.extract_logical_qubits(dag);
196
197 for (i, &logical) in logical_qubits.iter().enumerate() {
199 if i < self.coupling_map.num_qubits() {
200 mapping.insert(logical, i);
201 }
202 }
203
204 mapping
205 }
206
207 fn extract_logical_qubits(&self, dag: &CircuitDag) -> Vec<usize> {
209 let mut qubits = HashSet::new();
210
211 for node in dag.nodes() {
212 for qubit in node.gate.qubits() {
213 qubits.insert(qubit.id() as usize);
214 }
215 }
216
217 let mut qubit_vec: Vec<usize> = qubits.into_iter().collect();
218 qubit_vec.sort_unstable();
219 qubit_vec
220 }
221
222 fn find_executable_gates(
224 &self,
225 dag: &CircuitDag,
226 mapping: &HashMap<usize, usize>,
227 ) -> Vec<usize> {
228 let mut executable = Vec::new();
229
230 for node in dag.nodes() {
231 if node.predecessors.is_empty() && self.is_gate_executable(node, mapping) {
232 executable.push(node.id);
233 }
234 }
235
236 executable
237 }
238
239 fn find_executable_gates_from_remaining(
241 &self,
242 dag: &CircuitDag,
243 remaining: &HashSet<usize>,
244 mapping: &HashMap<usize, usize>,
245 ) -> Vec<usize> {
246 let mut executable = Vec::new();
247
248 for &gate_id in remaining {
249 let node = &dag.nodes()[gate_id];
250
251 let ready = node
253 .predecessors
254 .iter()
255 .all(|&pred| !remaining.contains(&pred));
256
257 if ready && self.is_gate_executable(node, mapping) {
258 executable.push(gate_id);
259 }
260 }
261
262 executable
263 }
264
265 fn is_gate_executable(&self, node: &DagNode, mapping: &HashMap<usize, usize>) -> bool {
267 let qubits = node.gate.qubits();
268
269 if qubits.len() <= 1 {
270 return true; }
272
273 if qubits.len() == 2 {
274 let q1 = qubits[0].id() as usize;
275 let q2 = qubits[1].id() as usize;
276
277 if let (Some(&p1), Some(&p2)) = (mapping.get(&q1), mapping.get(&q2)) {
278 return self.coupling_map.are_connected(p1, p2);
279 }
280 }
281
282 false
283 }
284
285 fn map_gate_to_physical(
287 &self,
288 node: &DagNode,
289 mapping: &HashMap<usize, usize>,
290 ) -> QuantRS2Result<Box<dyn GateOp>> {
291 let qubits = node.gate.qubits();
292 let mut physical_qubits = Vec::new();
293
294 for qubit in qubits {
295 let logical = qubit.id() as usize;
296 if let Some(&physical) = mapping.get(&logical) {
297 physical_qubits.push(QubitId::new(physical as u32));
298 } else {
299 return Err(QuantRS2Error::RoutingError(format!(
300 "Logical qubit {logical} not mapped to physical qubit"
301 )));
302 }
303 }
304
305 self.clone_gate_with_qubits(node.gate.as_ref(), &physical_qubits)
308 }
309
310 fn clone_gate_with_qubits(
312 &self,
313 gate: &dyn GateOp,
314 new_qubits: &[QubitId],
315 ) -> QuantRS2Result<Box<dyn GateOp>> {
316 use quantrs2_core::gate::{multi, single};
317
318 match (gate.name(), new_qubits.len()) {
319 ("H", 1) => Ok(Box::new(single::Hadamard {
320 target: new_qubits[0],
321 })),
322 ("X", 1) => Ok(Box::new(single::PauliX {
323 target: new_qubits[0],
324 })),
325 ("Y", 1) => Ok(Box::new(single::PauliY {
326 target: new_qubits[0],
327 })),
328 ("Z", 1) => Ok(Box::new(single::PauliZ {
329 target: new_qubits[0],
330 })),
331 ("S", 1) => Ok(Box::new(single::Phase {
332 target: new_qubits[0],
333 })),
334 ("T", 1) => Ok(Box::new(single::T {
335 target: new_qubits[0],
336 })),
337 ("CNOT", 2) => Ok(Box::new(multi::CNOT {
338 control: new_qubits[0],
339 target: new_qubits[1],
340 })),
341 ("CZ", 2) => Ok(Box::new(multi::CZ {
342 control: new_qubits[0],
343 target: new_qubits[1],
344 })),
345 ("SWAP", 2) => Ok(Box::new(multi::SWAP {
346 qubit1: new_qubits[0],
347 qubit2: new_qubits[1],
348 })),
349 ("RZ", 1) => {
350 let theta = gate
352 .as_any()
353 .downcast_ref::<RotationZ>()
354 .map(|g| g.theta)
355 .unwrap_or(0.0);
356 Ok(Box::new(RotationZ {
357 target: new_qubits[0],
358 theta,
359 }))
360 }
361 ("RY", 1) => {
362 let theta = gate
363 .as_any()
364 .downcast_ref::<RotationY>()
365 .map(|g| g.theta)
366 .unwrap_or(0.0);
367 Ok(Box::new(RotationY {
368 target: new_qubits[0],
369 theta,
370 }))
371 }
372 ("RX", 1) => {
373 let theta = gate
374 .as_any()
375 .downcast_ref::<RotationX>()
376 .map(|g| g.theta)
377 .unwrap_or(0.0);
378 Ok(Box::new(RotationX {
379 target: new_qubits[0],
380 theta,
381 }))
382 }
383 _ => Err(QuantRS2Error::UnsupportedOperation(format!(
384 "Cannot route gate {} with {} qubits",
385 gate.name(),
386 new_qubits.len()
387 ))),
388 }
389 }
390
391 fn find_best_swaps(
393 &self,
394 dag: &CircuitDag,
395 remaining_gates: &HashSet<usize>,
396 mapping: &HashMap<usize, usize>,
397 ) -> QuantRS2Result<Vec<(usize, usize)>> {
398 let front_layer = self.get_front_layer(dag, remaining_gates);
399 let extended_set = self.get_extended_set(dag, &front_layer);
400
401 let mut swap_scores = HashMap::new();
402
403 for &p1 in &self.get_mapped_physical_qubits(mapping) {
405 for &p2 in self.coupling_map.neighbors(p1) {
406 if p1 < p2 {
407 let score = self.calculate_swap_score(
409 dag,
410 (p1, p2),
411 &front_layer,
412 &extended_set,
413 mapping,
414 );
415 swap_scores.insert((p1, p2), score);
416 }
417 }
418 }
419
420 if swap_scores.is_empty() {
421 return Ok(Vec::new());
422 }
423
424 let mut sorted_swaps: Vec<_> = swap_scores.into_iter().collect();
426
427 if self.config.stochastic {
428 let mut rng = thread_rng();
430 sorted_swaps.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
431 let top_candidates = sorted_swaps.len().min(5);
432
433 if top_candidates > 0 {
434 let idx = rng.random_range(0..top_candidates);
435 Ok(vec![sorted_swaps[idx].0])
436 } else {
437 Ok(Vec::new())
438 }
439 } else {
440 sorted_swaps.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
442 if sorted_swaps.is_empty() {
443 Ok(Vec::new())
444 } else {
445 Ok(vec![sorted_swaps[0].0])
446 }
447 }
448 }
449
450 fn get_front_layer(&self, dag: &CircuitDag, remaining: &HashSet<usize>) -> HashSet<usize> {
452 let mut front_layer = HashSet::new();
453
454 for &gate_id in remaining {
455 let node = &dag.nodes()[gate_id];
456
457 let ready = node
459 .predecessors
460 .iter()
461 .all(|&pred| !remaining.contains(&pred));
462
463 if ready {
464 front_layer.insert(gate_id);
465 }
466 }
467
468 front_layer
469 }
470
471 fn get_extended_set(&self, dag: &CircuitDag, front_layer: &HashSet<usize>) -> HashSet<usize> {
473 let mut extended_set = front_layer.clone();
474 let mut to_visit = VecDeque::new();
475
476 for &gate_id in front_layer {
477 to_visit.push_back((gate_id, 0));
478 }
479
480 while let Some((gate_id, depth)) = to_visit.pop_front() {
481 if depth >= self.config.lookahead_depth {
482 continue;
483 }
484
485 let node = &dag.nodes()[gate_id];
486 for &succ in &node.successors {
487 if extended_set.insert(succ) {
488 to_visit.push_back((succ, depth + 1));
489 }
490 }
491 }
492
493 extended_set
494 }
495
496 fn get_mapped_physical_qubits(&self, mapping: &HashMap<usize, usize>) -> Vec<usize> {
498 mapping.values().copied().collect()
499 }
500
501 fn calculate_swap_score(
503 &self,
504 dag: &CircuitDag,
505 swap: (usize, usize),
506 front_layer: &HashSet<usize>,
507 extended_set: &HashSet<usize>,
508 mapping: &HashMap<usize, usize>,
509 ) -> f64 {
510 let mut temp_mapping = mapping.clone();
512 let (p1, p2) = swap;
513
514 let mut l1_opt = None;
516 let mut l2_opt = None;
517
518 for (&logical, &physical) in mapping {
519 if physical == p1 {
520 l1_opt = Some(logical);
521 } else if physical == p2 {
522 l2_opt = Some(logical);
523 }
524 }
525
526 if let (Some(l1), Some(l2)) = (l1_opt, l2_opt) {
527 temp_mapping.insert(l1, p2);
528 temp_mapping.insert(l2, p1);
529 } else {
530 return -1.0; }
532
533 let front_newly_executable = front_layer
537 .iter()
538 .filter(|&&gate_id| {
539 let node = &dag.nodes()[gate_id];
540 self.is_gate_executable(node, &temp_mapping)
541 })
542 .count() as f64;
543
544 let extended_newly_executable = extended_set
545 .iter()
546 .filter(|&&gate_id| {
547 if front_layer.contains(&gate_id) {
548 return false; }
550 let node = &dag.nodes()[gate_id];
551 self.is_gate_executable(node, &temp_mapping)
552 })
553 .count() as f64;
554
555 let front_size = front_layer.len().max(1) as f64;
558 let raw_score = (front_newly_executable / front_size)
559 + self.config.extended_set_weight * (extended_newly_executable / front_size);
560
561 let decay = 1.0 - self.config.decay_factor;
564 raw_score * decay
565 }
566
567 fn calculate_depth(&self, gates: &[Box<dyn GateOp>]) -> usize {
569 gates.len()
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::{multi::CNOT, single::Hadamard};

    /// An H followed by a CNOT on a 3-qubit line must route without error.
    #[test]
    fn test_sabre_basic() {
        let router = SabreRouter::new(CouplingMap::linear(3), SabreConfig::basic());

        let mut input = Circuit::<3>::new();
        input
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("add H gate to circuit");
        input
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(2),
            })
            .expect("add CNOT gate to circuit");

        let outcome = router.route(&input);
        assert!(outcome.is_ok());
    }

    /// Every logical qubit used by the circuit receives a physical placement.
    #[test]
    fn test_initial_mapping() {
        let router = SabreRouter::new(CouplingMap::linear(5), SabreConfig::default());

        let mut input = Circuit::<3>::new();
        input
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("add CNOT gate to circuit");

        let layout = router.initial_mapping(&circuit_to_dag(&input));

        assert!(layout.contains_key(&0));
        assert!(layout.contains_key(&1));
    }
}