1use crate::builder::Circuit;
8use crate::dag::{circuit_to_dag, CircuitDag, DagNode};
9use crate::routing::{CouplingMap, RoutedCircuit, RoutingResult};
10use quantrs2_core::{
11 error::{QuantRS2Error, QuantRS2Result},
12 gate::{multi::SWAP, GateOp},
13 qubit::QubitId,
14};
15use scirs2_core::random::{seq::SliceRandom, thread_rng, Rng};
16use std::collections::{HashMap, HashSet, VecDeque};
17
/// Tuning knobs for [`SabreRouter`].
///
/// Use [`SabreConfig::default`] for general-purpose routing, [`SabreConfig::basic`] for a
/// cheaper, shallower search, or [`SabreConfig::stochastic`] for randomised SWAP selection.
#[derive(Debug, Clone, PartialEq)]
pub struct SabreConfig {
    /// Upper bound on routing iterations; routing fails if gates remain after this many.
    pub max_iterations: usize,
    /// How many DAG successor levels beyond the front layer make up the lookahead
    /// (extended) set.
    pub lookahead_depth: usize,
    /// Decay increment of the SABRE heuristic, intended to discourage repeatedly swapping
    /// the same physical qubits.
    pub decay_factor: f64,
    /// Weight of the lookahead (extended-set) term relative to the front-layer term when
    /// scoring candidate SWAPs.
    pub extended_set_weight: f64,
    /// Intended cap on SWAPs inserted per routing iteration.
    /// NOTE(review): not read by `SabreRouter` in this module — confirm before relying on it.
    pub max_swaps_per_iteration: usize,
    /// When `true`, a SWAP is picked at random among the best-scoring candidates instead of
    /// always taking the single best one.
    pub stochastic: bool,
}
34
35impl Default for SabreConfig {
36 fn default() -> Self {
37 Self {
38 max_iterations: 1000,
39 lookahead_depth: 20,
40 decay_factor: 0.001,
41 extended_set_weight: 0.5,
42 max_swaps_per_iteration: 10,
43 stochastic: false,
44 }
45 }
46}
47
48impl SabreConfig {
49 #[must_use]
51 pub const fn basic() -> Self {
52 Self {
53 max_iterations: 100,
54 lookahead_depth: 5,
55 decay_factor: 0.01,
56 extended_set_weight: 0.3,
57 max_swaps_per_iteration: 5,
58 stochastic: false,
59 }
60 }
61
62 #[must_use]
64 pub fn stochastic() -> Self {
65 Self {
66 stochastic: true,
67 ..Default::default()
68 }
69 }
70}
71
72pub struct SabreRouter {
74 coupling_map: CouplingMap,
75 config: SabreConfig,
76}
77
78impl SabreRouter {
79 #[must_use]
81 pub const fn new(coupling_map: CouplingMap, config: SabreConfig) -> Self {
82 Self {
83 coupling_map,
84 config,
85 }
86 }
87
88 pub fn route<const N: usize>(&self, circuit: &Circuit<N>) -> QuantRS2Result<RoutedCircuit<N>> {
90 let dag = circuit_to_dag(circuit);
91 let mut logical_to_physical = self.initial_mapping(&dag);
92 let mut physical_to_logical: HashMap<usize, usize> = logical_to_physical
93 .iter()
94 .map(|(&logical, &physical)| (physical, logical))
95 .collect();
96
97 let mut routed_gates = Vec::new();
98 let mut executable = self.find_executable_gates(&dag, &logical_to_physical);
99 let mut remaining_gates: HashSet<usize> = (0..dag.nodes().len()).collect();
100 let mut iteration = 0;
101
102 while !remaining_gates.is_empty() && iteration < self.config.max_iterations {
103 iteration += 1;
104
105 while let Some(gate_id) = executable.pop() {
107 if remaining_gates.contains(&gate_id) {
108 let node = &dag.nodes()[gate_id];
109 let routed_gate = self.map_gate_to_physical(node, &logical_to_physical)?;
110 routed_gates.push(routed_gate);
111 remaining_gates.remove(&gate_id);
112
113 for &succ in &node.successors {
115 if remaining_gates.contains(&succ)
116 && self.is_gate_executable(&dag.nodes()[succ], &logical_to_physical)
117 {
118 executable.push(succ);
119 }
120 }
121 }
122 }
123
124 if !remaining_gates.is_empty() {
126 let swaps = self.find_best_swaps(&dag, &remaining_gates, &logical_to_physical)?;
127
128 if swaps.is_empty() {
129 return Err(QuantRS2Error::RoutingError(
130 "Cannot find valid SWAP operations".to_string(),
131 ));
132 }
133
134 for (p1, p2) in swaps {
136 let swap_gate = Box::new(SWAP {
138 qubit1: QubitId::new(p1 as u32),
139 qubit2: QubitId::new(p2 as u32),
140 }) as Box<dyn GateOp>;
141 routed_gates.push(swap_gate);
142
143 let l1 = physical_to_logical[&p1];
145 let l2 = physical_to_logical[&p2];
146
147 logical_to_physical.insert(l1, p2);
148 logical_to_physical.insert(l2, p1);
149 physical_to_logical.insert(p1, l2);
150 physical_to_logical.insert(p2, l1);
151 }
152
153 executable = self.find_executable_gates_from_remaining(
155 &dag,
156 &remaining_gates,
157 &logical_to_physical,
158 );
159 }
160 }
161
162 if !remaining_gates.is_empty() {
163 return Err(QuantRS2Error::RoutingError(format!(
164 "Routing failed: {} gates remaining after {} iterations",
165 remaining_gates.len(),
166 iteration
167 )));
168 }
169
170 let total_swaps = routed_gates.iter().filter(|g| g.name() == "SWAP").count();
171 let circuit_depth = self.calculate_depth(&routed_gates);
172
173 Ok(RoutedCircuit::new(
174 routed_gates,
175 logical_to_physical,
176 RoutingResult {
177 total_swaps,
178 circuit_depth,
179 routing_overhead: if circuit_depth > 0 {
180 total_swaps as f64 / circuit_depth as f64
181 } else {
182 0.0
183 },
184 },
185 ))
186 }
187
188 fn initial_mapping(&self, dag: &CircuitDag) -> HashMap<usize, usize> {
190 let mut mapping = HashMap::new();
191 let logical_qubits = self.extract_logical_qubits(dag);
192
193 for (i, &logical) in logical_qubits.iter().enumerate() {
195 if i < self.coupling_map.num_qubits() {
196 mapping.insert(logical, i);
197 }
198 }
199
200 mapping
201 }
202
203 fn extract_logical_qubits(&self, dag: &CircuitDag) -> Vec<usize> {
205 let mut qubits = HashSet::new();
206
207 for node in dag.nodes() {
208 for qubit in node.gate.qubits() {
209 qubits.insert(qubit.id() as usize);
210 }
211 }
212
213 let mut qubit_vec: Vec<usize> = qubits.into_iter().collect();
214 qubit_vec.sort_unstable();
215 qubit_vec
216 }
217
218 fn find_executable_gates(
220 &self,
221 dag: &CircuitDag,
222 mapping: &HashMap<usize, usize>,
223 ) -> Vec<usize> {
224 let mut executable = Vec::new();
225
226 for node in dag.nodes() {
227 if node.predecessors.is_empty() && self.is_gate_executable(node, mapping) {
228 executable.push(node.id);
229 }
230 }
231
232 executable
233 }
234
235 fn find_executable_gates_from_remaining(
237 &self,
238 dag: &CircuitDag,
239 remaining: &HashSet<usize>,
240 mapping: &HashMap<usize, usize>,
241 ) -> Vec<usize> {
242 let mut executable = Vec::new();
243
244 for &gate_id in remaining {
245 let node = &dag.nodes()[gate_id];
246
247 let ready = node
249 .predecessors
250 .iter()
251 .all(|&pred| !remaining.contains(&pred));
252
253 if ready && self.is_gate_executable(node, mapping) {
254 executable.push(gate_id);
255 }
256 }
257
258 executable
259 }
260
261 fn is_gate_executable(&self, node: &DagNode, mapping: &HashMap<usize, usize>) -> bool {
263 let qubits = node.gate.qubits();
264
265 if qubits.len() <= 1 {
266 return true; }
268
269 if qubits.len() == 2 {
270 let q1 = qubits[0].id() as usize;
271 let q2 = qubits[1].id() as usize;
272
273 if let (Some(&p1), Some(&p2)) = (mapping.get(&q1), mapping.get(&q2)) {
274 return self.coupling_map.are_connected(p1, p2);
275 }
276 }
277
278 false
279 }
280
281 fn map_gate_to_physical(
283 &self,
284 node: &DagNode,
285 mapping: &HashMap<usize, usize>,
286 ) -> QuantRS2Result<Box<dyn GateOp>> {
287 let qubits = node.gate.qubits();
288 let mut physical_qubits = Vec::new();
289
290 for qubit in qubits {
291 let logical = qubit.id() as usize;
292 if let Some(&physical) = mapping.get(&logical) {
293 physical_qubits.push(QubitId::new(physical as u32));
294 } else {
295 return Err(QuantRS2Error::RoutingError(format!(
296 "Logical qubit {logical} not mapped to physical qubit"
297 )));
298 }
299 }
300
301 self.clone_gate_with_qubits(node.gate.as_ref(), &physical_qubits)
304 }
305
306 fn clone_gate_with_qubits(
308 &self,
309 gate: &dyn GateOp,
310 new_qubits: &[QubitId],
311 ) -> QuantRS2Result<Box<dyn GateOp>> {
312 use quantrs2_core::gate::{multi, single};
313
314 match (gate.name(), new_qubits.len()) {
315 ("H", 1) => Ok(Box::new(single::Hadamard {
316 target: new_qubits[0],
317 })),
318 ("X", 1) => Ok(Box::new(single::PauliX {
319 target: new_qubits[0],
320 })),
321 ("Y", 1) => Ok(Box::new(single::PauliY {
322 target: new_qubits[0],
323 })),
324 ("Z", 1) => Ok(Box::new(single::PauliZ {
325 target: new_qubits[0],
326 })),
327 ("S", 1) => Ok(Box::new(single::Phase {
328 target: new_qubits[0],
329 })),
330 ("T", 1) => Ok(Box::new(single::T {
331 target: new_qubits[0],
332 })),
333 ("CNOT", 2) => Ok(Box::new(multi::CNOT {
334 control: new_qubits[0],
335 target: new_qubits[1],
336 })),
337 ("CZ", 2) => Ok(Box::new(multi::CZ {
338 control: new_qubits[0],
339 target: new_qubits[1],
340 })),
341 ("SWAP", 2) => Ok(Box::new(multi::SWAP {
342 qubit1: new_qubits[0],
343 qubit2: new_qubits[1],
344 })),
345 _ => Err(QuantRS2Error::UnsupportedOperation(format!(
346 "Cannot route gate {} with {} qubits",
347 gate.name(),
348 new_qubits.len()
349 ))),
350 }
351 }
352
353 fn find_best_swaps(
355 &self,
356 dag: &CircuitDag,
357 remaining_gates: &HashSet<usize>,
358 mapping: &HashMap<usize, usize>,
359 ) -> QuantRS2Result<Vec<(usize, usize)>> {
360 let front_layer = self.get_front_layer(dag, remaining_gates);
361 let extended_set = self.get_extended_set(dag, &front_layer);
362
363 let mut swap_scores = HashMap::new();
364
365 for &p1 in &self.get_mapped_physical_qubits(mapping) {
367 for &p2 in self.coupling_map.neighbors(p1) {
368 if p1 < p2 {
369 let score =
371 self.calculate_swap_score((p1, p2), &front_layer, &extended_set, mapping);
372 swap_scores.insert((p1, p2), score);
373 }
374 }
375 }
376
377 if swap_scores.is_empty() {
378 return Ok(Vec::new());
379 }
380
381 let mut sorted_swaps: Vec<_> = swap_scores.into_iter().collect();
383
384 if self.config.stochastic {
385 let mut rng = thread_rng();
387 sorted_swaps.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
388 let top_candidates = sorted_swaps.len().min(5);
389
390 if top_candidates > 0 {
391 let idx = rng.gen_range(0..top_candidates);
392 Ok(vec![sorted_swaps[idx].0])
393 } else {
394 Ok(Vec::new())
395 }
396 } else {
397 sorted_swaps.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
399 if sorted_swaps.is_empty() {
400 Ok(Vec::new())
401 } else {
402 Ok(vec![sorted_swaps[0].0])
403 }
404 }
405 }
406
407 fn get_front_layer(&self, dag: &CircuitDag, remaining: &HashSet<usize>) -> HashSet<usize> {
409 let mut front_layer = HashSet::new();
410
411 for &gate_id in remaining {
412 let node = &dag.nodes()[gate_id];
413
414 let ready = node
416 .predecessors
417 .iter()
418 .all(|&pred| !remaining.contains(&pred));
419
420 if ready {
421 front_layer.insert(gate_id);
422 }
423 }
424
425 front_layer
426 }
427
428 fn get_extended_set(&self, dag: &CircuitDag, front_layer: &HashSet<usize>) -> HashSet<usize> {
430 let mut extended_set = front_layer.clone();
431 let mut to_visit = VecDeque::new();
432
433 for &gate_id in front_layer {
434 to_visit.push_back((gate_id, 0));
435 }
436
437 while let Some((gate_id, depth)) = to_visit.pop_front() {
438 if depth >= self.config.lookahead_depth {
439 continue;
440 }
441
442 let node = &dag.nodes()[gate_id];
443 for &succ in &node.successors {
444 if extended_set.insert(succ) {
445 to_visit.push_back((succ, depth + 1));
446 }
447 }
448 }
449
450 extended_set
451 }
452
453 fn get_mapped_physical_qubits(&self, mapping: &HashMap<usize, usize>) -> Vec<usize> {
455 mapping.values().copied().collect()
456 }
457
458 fn calculate_swap_score(
460 &self,
461 swap: (usize, usize),
462 front_layer: &HashSet<usize>,
463 extended_set: &HashSet<usize>,
464 mapping: &HashMap<usize, usize>,
465 ) -> f64 {
466 let mut temp_mapping = mapping.clone();
468 let (p1, p2) = swap;
469
470 let mut l1_opt = None;
472 let mut l2_opt = None;
473
474 for (&logical, &physical) in mapping {
475 if physical == p1 {
476 l1_opt = Some(logical);
477 } else if physical == p2 {
478 l2_opt = Some(logical);
479 }
480 }
481
482 if let (Some(l1), Some(l2)) = (l1_opt, l2_opt) {
483 temp_mapping.insert(l1, p2);
484 temp_mapping.insert(l2, p1);
485 } else {
486 return -1.0; }
488
489 0.0
495 }
496
497 fn calculate_depth(&self, gates: &[Box<dyn GateOp>]) -> usize {
499 gates.len()
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::{multi::CNOT, single::Hadamard};

    /// H on q0 followed by CNOT(q0, q2) on a 3-qubit chain must route successfully.
    #[test]
    fn test_sabre_basic() {
        let router = SabreRouter::new(CouplingMap::linear(3), SabreConfig::basic());

        let mut circuit = Circuit::<3>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("add H gate to circuit");
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(2),
            })
            .expect("add CNOT gate to circuit");

        assert!(router.route(&circuit).is_ok());
    }

    /// Every logical qubit touched by a gate receives an initial physical placement.
    #[test]
    fn test_initial_mapping() {
        let router = SabreRouter::new(CouplingMap::linear(5), SabreConfig::default());

        let mut circuit = Circuit::<3>::new();
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("add CNOT gate to circuit");

        let mapping = router.initial_mapping(&circuit_to_dag(&circuit));
        for logical in [0_usize, 1] {
            assert!(mapping.contains_key(&logical));
        }
    }
}
551}