1use crate::builder::Circuit;
8use crate::dag::{circuit_to_dag, CircuitDag, DagNode};
9use crate::routing::{CouplingMap, RoutedCircuit, RoutingResult};
10use quantrs2_core::{
11 error::{QuantRS2Error, QuantRS2Result},
12 gate::{multi::SWAP, GateOp},
13 qubit::QubitId,
14};
15use scirs2_core::random::{seq::SliceRandom, thread_rng, Rng};
16use std::collections::{HashMap, HashSet, VecDeque};
17
/// Tuning parameters for [`SabreRouter`].
#[derive(Debug, Clone)]
pub struct SabreConfig {
    /// Upper bound on routing iterations; routing fails with an error if gates
    /// are still unrouted when this bound is reached.
    pub max_iterations: usize,
    /// How many DAG layers beyond the front layer are explored for the
    /// look-ahead (extended set) part of SWAP scoring.
    pub lookahead_depth: usize,
    /// Decay parameter used by the router's SWAP-scoring heuristic.
    pub decay_factor: f64,
    /// Weight of the look-ahead (extended set) contribution in SWAP scoring.
    pub extended_set_weight: f64,
    /// NOTE(review): not read by the router code in this file — confirm whether
    /// anything else uses it before relying on it.
    pub max_swaps_per_iteration: usize,
    /// When true, the SWAP is picked at random among the top-ranked candidates
    /// instead of always taking the first one.
    pub stochastic: bool,
}
34
35impl Default for SabreConfig {
36 fn default() -> Self {
37 Self {
38 max_iterations: 1000,
39 lookahead_depth: 20,
40 decay_factor: 0.001,
41 extended_set_weight: 0.5,
42 max_swaps_per_iteration: 10,
43 stochastic: false,
44 }
45 }
46}
47
48impl SabreConfig {
49 #[must_use]
51 pub const fn basic() -> Self {
52 Self {
53 max_iterations: 100,
54 lookahead_depth: 5,
55 decay_factor: 0.01,
56 extended_set_weight: 0.3,
57 max_swaps_per_iteration: 5,
58 stochastic: false,
59 }
60 }
61
62 #[must_use]
64 pub fn stochastic() -> Self {
65 Self {
66 stochastic: true,
67 ..Default::default()
68 }
69 }
70}
71
72pub struct SabreRouter {
74 coupling_map: CouplingMap,
75 config: SabreConfig,
76}
77
78impl SabreRouter {
79 #[must_use]
81 pub const fn new(coupling_map: CouplingMap, config: SabreConfig) -> Self {
82 Self {
83 coupling_map,
84 config,
85 }
86 }
87
88 pub fn route<const N: usize>(&self, circuit: &Circuit<N>) -> QuantRS2Result<RoutedCircuit<N>> {
90 let dag = circuit_to_dag(circuit);
91 let mut logical_to_physical = self.initial_mapping(&dag);
92 let mut physical_to_logical: HashMap<usize, usize> = logical_to_physical
93 .iter()
94 .map(|(&logical, &physical)| (physical, logical))
95 .collect();
96
97 let mut routed_gates = Vec::new();
98 let mut executable = self.find_executable_gates(&dag, &logical_to_physical);
99 let mut remaining_gates: HashSet<usize> = (0..dag.nodes().len()).collect();
100 let mut iteration = 0;
101
102 while !remaining_gates.is_empty() && iteration < self.config.max_iterations {
103 iteration += 1;
104
105 while let Some(gate_id) = executable.pop() {
107 if remaining_gates.contains(&gate_id) {
108 let node = &dag.nodes()[gate_id];
109 let routed_gate = self.map_gate_to_physical(node, &logical_to_physical)?;
110 routed_gates.push(routed_gate);
111 remaining_gates.remove(&gate_id);
112
113 for &succ in &node.successors {
115 if remaining_gates.contains(&succ)
116 && self.is_gate_executable(&dag.nodes()[succ], &logical_to_physical)
117 {
118 executable.push(succ);
119 }
120 }
121 }
122 }
123
124 if !remaining_gates.is_empty() {
126 let swaps = self.find_best_swaps(&dag, &remaining_gates, &logical_to_physical)?;
127
128 if swaps.is_empty() {
129 return Err(QuantRS2Error::RoutingError(
130 "Cannot find valid SWAP operations".to_string(),
131 ));
132 }
133
134 for (p1, p2) in swaps {
136 let swap_gate = Box::new(SWAP {
138 qubit1: QubitId::new(p1 as u32),
139 qubit2: QubitId::new(p2 as u32),
140 }) as Box<dyn GateOp>;
141 routed_gates.push(swap_gate);
142
143 let l1 = physical_to_logical[&p1];
145 let l2 = physical_to_logical[&p2];
146
147 logical_to_physical.insert(l1, p2);
148 logical_to_physical.insert(l2, p1);
149 physical_to_logical.insert(p1, l2);
150 physical_to_logical.insert(p2, l1);
151 }
152
153 executable = self.find_executable_gates_from_remaining(
155 &dag,
156 &remaining_gates,
157 &logical_to_physical,
158 );
159 }
160 }
161
162 if !remaining_gates.is_empty() {
163 return Err(QuantRS2Error::RoutingError(format!(
164 "Routing failed: {} gates remaining after {} iterations",
165 remaining_gates.len(),
166 iteration
167 )));
168 }
169
170 let total_swaps = routed_gates.iter().filter(|g| g.name() == "SWAP").count();
171 let circuit_depth = self.calculate_depth(&routed_gates);
172
173 Ok(RoutedCircuit::new(
174 routed_gates,
175 logical_to_physical,
176 RoutingResult {
177 total_swaps,
178 circuit_depth,
179 routing_overhead: if circuit_depth > 0 {
180 total_swaps as f64 / circuit_depth as f64
181 } else {
182 0.0
183 },
184 },
185 ))
186 }
187
188 fn initial_mapping(&self, dag: &CircuitDag) -> HashMap<usize, usize> {
190 let mut mapping = HashMap::new();
191 let logical_qubits = self.extract_logical_qubits(dag);
192
193 for (i, &logical) in logical_qubits.iter().enumerate() {
195 if i < self.coupling_map.num_qubits() {
196 mapping.insert(logical, i);
197 }
198 }
199
200 mapping
201 }
202
203 fn extract_logical_qubits(&self, dag: &CircuitDag) -> Vec<usize> {
205 let mut qubits = HashSet::new();
206
207 for node in dag.nodes() {
208 for qubit in node.gate.qubits() {
209 qubits.insert(qubit.id() as usize);
210 }
211 }
212
213 let mut qubit_vec: Vec<usize> = qubits.into_iter().collect();
214 qubit_vec.sort_unstable();
215 qubit_vec
216 }
217
218 fn find_executable_gates(
220 &self,
221 dag: &CircuitDag,
222 mapping: &HashMap<usize, usize>,
223 ) -> Vec<usize> {
224 let mut executable = Vec::new();
225
226 for node in dag.nodes() {
227 if node.predecessors.is_empty() && self.is_gate_executable(node, mapping) {
228 executable.push(node.id);
229 }
230 }
231
232 executable
233 }
234
235 fn find_executable_gates_from_remaining(
237 &self,
238 dag: &CircuitDag,
239 remaining: &HashSet<usize>,
240 mapping: &HashMap<usize, usize>,
241 ) -> Vec<usize> {
242 let mut executable = Vec::new();
243
244 for &gate_id in remaining {
245 let node = &dag.nodes()[gate_id];
246
247 let ready = node
249 .predecessors
250 .iter()
251 .all(|&pred| !remaining.contains(&pred));
252
253 if ready && self.is_gate_executable(node, mapping) {
254 executable.push(gate_id);
255 }
256 }
257
258 executable
259 }
260
261 fn is_gate_executable(&self, node: &DagNode, mapping: &HashMap<usize, usize>) -> bool {
263 let qubits = node.gate.qubits();
264
265 if qubits.len() <= 1 {
266 return true; }
268
269 if qubits.len() == 2 {
270 let q1 = qubits[0].id() as usize;
271 let q2 = qubits[1].id() as usize;
272
273 if let (Some(&p1), Some(&p2)) = (mapping.get(&q1), mapping.get(&q2)) {
274 return self.coupling_map.are_connected(p1, p2);
275 }
276 }
277
278 false
279 }
280
281 fn map_gate_to_physical(
283 &self,
284 node: &DagNode,
285 mapping: &HashMap<usize, usize>,
286 ) -> QuantRS2Result<Box<dyn GateOp>> {
287 let qubits = node.gate.qubits();
288 let mut physical_qubits = Vec::new();
289
290 for qubit in qubits {
291 let logical = qubit.id() as usize;
292 if let Some(&physical) = mapping.get(&logical) {
293 physical_qubits.push(QubitId::new(physical as u32));
294 } else {
295 return Err(QuantRS2Error::RoutingError(format!(
296 "Logical qubit {logical} not mapped to physical qubit"
297 )));
298 }
299 }
300
301 self.clone_gate_with_qubits(node.gate.as_ref(), &physical_qubits)
304 }
305
306 fn clone_gate_with_qubits(
308 &self,
309 gate: &dyn GateOp,
310 new_qubits: &[QubitId],
311 ) -> QuantRS2Result<Box<dyn GateOp>> {
312 use quantrs2_core::gate::{multi, single};
313
314 match (gate.name(), new_qubits.len()) {
315 ("H", 1) => Ok(Box::new(single::Hadamard {
316 target: new_qubits[0],
317 })),
318 ("X", 1) => Ok(Box::new(single::PauliX {
319 target: new_qubits[0],
320 })),
321 ("Y", 1) => Ok(Box::new(single::PauliY {
322 target: new_qubits[0],
323 })),
324 ("Z", 1) => Ok(Box::new(single::PauliZ {
325 target: new_qubits[0],
326 })),
327 ("S", 1) => Ok(Box::new(single::Phase {
328 target: new_qubits[0],
329 })),
330 ("T", 1) => Ok(Box::new(single::T {
331 target: new_qubits[0],
332 })),
333 ("CNOT", 2) => Ok(Box::new(multi::CNOT {
334 control: new_qubits[0],
335 target: new_qubits[1],
336 })),
337 ("CZ", 2) => Ok(Box::new(multi::CZ {
338 control: new_qubits[0],
339 target: new_qubits[1],
340 })),
341 ("SWAP", 2) => Ok(Box::new(multi::SWAP {
342 qubit1: new_qubits[0],
343 qubit2: new_qubits[1],
344 })),
345 _ => Err(QuantRS2Error::UnsupportedOperation(format!(
346 "Cannot route gate {} with {} qubits",
347 gate.name(),
348 new_qubits.len()
349 ))),
350 }
351 }
352
353 fn find_best_swaps(
355 &self,
356 dag: &CircuitDag,
357 remaining_gates: &HashSet<usize>,
358 mapping: &HashMap<usize, usize>,
359 ) -> QuantRS2Result<Vec<(usize, usize)>> {
360 let front_layer = self.get_front_layer(dag, remaining_gates);
361 let extended_set = self.get_extended_set(dag, &front_layer);
362
363 let mut swap_scores = HashMap::new();
364
365 for &p1 in &self.get_mapped_physical_qubits(mapping) {
367 for &p2 in self.coupling_map.neighbors(p1) {
368 if p1 < p2 {
369 let score = self.calculate_swap_score(
371 dag,
372 (p1, p2),
373 &front_layer,
374 &extended_set,
375 mapping,
376 );
377 swap_scores.insert((p1, p2), score);
378 }
379 }
380 }
381
382 if swap_scores.is_empty() {
383 return Ok(Vec::new());
384 }
385
386 let mut sorted_swaps: Vec<_> = swap_scores.into_iter().collect();
388
389 if self.config.stochastic {
390 let mut rng = thread_rng();
392 sorted_swaps.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
393 let top_candidates = sorted_swaps.len().min(5);
394
395 if top_candidates > 0 {
396 let idx = rng.random_range(0..top_candidates);
397 Ok(vec![sorted_swaps[idx].0])
398 } else {
399 Ok(Vec::new())
400 }
401 } else {
402 sorted_swaps.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
404 if sorted_swaps.is_empty() {
405 Ok(Vec::new())
406 } else {
407 Ok(vec![sorted_swaps[0].0])
408 }
409 }
410 }
411
412 fn get_front_layer(&self, dag: &CircuitDag, remaining: &HashSet<usize>) -> HashSet<usize> {
414 let mut front_layer = HashSet::new();
415
416 for &gate_id in remaining {
417 let node = &dag.nodes()[gate_id];
418
419 let ready = node
421 .predecessors
422 .iter()
423 .all(|&pred| !remaining.contains(&pred));
424
425 if ready {
426 front_layer.insert(gate_id);
427 }
428 }
429
430 front_layer
431 }
432
433 fn get_extended_set(&self, dag: &CircuitDag, front_layer: &HashSet<usize>) -> HashSet<usize> {
435 let mut extended_set = front_layer.clone();
436 let mut to_visit = VecDeque::new();
437
438 for &gate_id in front_layer {
439 to_visit.push_back((gate_id, 0));
440 }
441
442 while let Some((gate_id, depth)) = to_visit.pop_front() {
443 if depth >= self.config.lookahead_depth {
444 continue;
445 }
446
447 let node = &dag.nodes()[gate_id];
448 for &succ in &node.successors {
449 if extended_set.insert(succ) {
450 to_visit.push_back((succ, depth + 1));
451 }
452 }
453 }
454
455 extended_set
456 }
457
458 fn get_mapped_physical_qubits(&self, mapping: &HashMap<usize, usize>) -> Vec<usize> {
460 mapping.values().copied().collect()
461 }
462
463 fn calculate_swap_score(
465 &self,
466 dag: &CircuitDag,
467 swap: (usize, usize),
468 front_layer: &HashSet<usize>,
469 extended_set: &HashSet<usize>,
470 mapping: &HashMap<usize, usize>,
471 ) -> f64 {
472 let mut temp_mapping = mapping.clone();
474 let (p1, p2) = swap;
475
476 let mut l1_opt = None;
478 let mut l2_opt = None;
479
480 for (&logical, &physical) in mapping {
481 if physical == p1 {
482 l1_opt = Some(logical);
483 } else if physical == p2 {
484 l2_opt = Some(logical);
485 }
486 }
487
488 if let (Some(l1), Some(l2)) = (l1_opt, l2_opt) {
489 temp_mapping.insert(l1, p2);
490 temp_mapping.insert(l2, p1);
491 } else {
492 return -1.0; }
494
495 let front_newly_executable = front_layer
499 .iter()
500 .filter(|&&gate_id| {
501 let node = &dag.nodes()[gate_id];
502 self.is_gate_executable(node, &temp_mapping)
503 })
504 .count() as f64;
505
506 let extended_newly_executable = extended_set
507 .iter()
508 .filter(|&&gate_id| {
509 if front_layer.contains(&gate_id) {
510 return false; }
512 let node = &dag.nodes()[gate_id];
513 self.is_gate_executable(node, &temp_mapping)
514 })
515 .count() as f64;
516
517 let front_size = front_layer.len().max(1) as f64;
520 let raw_score = (front_newly_executable / front_size)
521 + self.config.extended_set_weight * (extended_newly_executable / front_size);
522
523 let decay = 1.0 - self.config.decay_factor;
526 raw_score * decay
527 }
528
529 fn calculate_depth(&self, gates: &[Box<dyn GateOp>]) -> usize {
531 gates.len()
534 }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::{multi::CNOT, single::Hadamard};

    /// An H followed by a CNOT on a 3-qubit line must route successfully.
    #[test]
    fn test_sabre_basic() {
        let router = SabreRouter::new(CouplingMap::linear(3), SabreConfig::basic());

        let mut circuit = Circuit::<3>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("add H gate to circuit");
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(2),
            })
            .expect("add CNOT gate to circuit");

        assert!(router.route(&circuit).is_ok());
    }

    /// Every logical qubit used by the circuit receives an initial placement.
    #[test]
    fn test_initial_mapping() {
        let router = SabreRouter::new(CouplingMap::linear(5), SabreConfig::default());

        let mut circuit = Circuit::<3>::new();
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("add CNOT gate to circuit");

        let mapping = router.initial_mapping(&circuit_to_dag(&circuit));
        for logical in [0_usize, 1] {
            assert!(mapping.contains_key(&logical));
        }
    }
}