Skip to main content

quantrs2_circuit/routing/
sabre.rs

1//! SABRE (SWAP-based `BidiREctional`) routing algorithm
2//!
3//! Based on the paper "Tackling the Qubit Mapping Problem for NISQ-Era Quantum Devices"
4//! by Gushu Li et al. This implementation provides efficient routing for quantum circuits
5//! on limited connectivity devices.
6
7use crate::builder::Circuit;
8use crate::dag::{circuit_to_dag, CircuitDag, DagNode};
9use crate::routing::{CouplingMap, RoutedCircuit, RoutingResult};
10use quantrs2_core::{
11    error::{QuantRS2Error, QuantRS2Result},
12    gate::{
13        multi::SWAP,
14        single::{RotationX, RotationY, RotationZ},
15        GateOp,
16    },
17    qubit::QubitId,
18};
19use scirs2_core::random::{seq::SliceRandom, thread_rng, Rng};
20use std::collections::{HashMap, HashSet, VecDeque};
21
/// Tuning knobs for [`SabreRouter`].
///
/// Use [`SabreConfig::default`] for balanced settings, [`SabreConfig::basic`] for a
/// cheaper preset, or [`SabreConfig::stochastic`] for randomized trials.
#[derive(Clone, Debug)]
pub struct SabreConfig {
    /// Upper bound on routing iterations (SWAP-insertion rounds) before the router
    /// gives up with an error.
    pub max_iterations: usize,
    /// Number of DAG layers beyond the front layer that are examined as lookahead
    /// when scoring candidate SWAPs.
    pub lookahead_depth: usize,
    /// Decay parameter used by the SWAP-scoring heuristic.
    pub decay_factor: f64,
    /// Weight of the lookahead (extended-set) contribution relative to the front
    /// layer when scoring SWAPs.
    pub extended_set_weight: f64,
    /// Intended cap on SWAP insertions per iteration.
    ///
    /// NOTE: `SabreRouter` currently inserts exactly one SWAP per iteration and does
    /// not read this value.
    pub max_swaps_per_iteration: usize,
    /// When `true`, a random choice among the five best-scoring candidate SWAPs is
    /// made instead of always taking the single best one.
    pub stochastic: bool,
}
38
39impl Default for SabreConfig {
40    fn default() -> Self {
41        Self {
42            max_iterations: 1000,
43            lookahead_depth: 20,
44            decay_factor: 0.001,
45            extended_set_weight: 0.5,
46            max_swaps_per_iteration: 10,
47            stochastic: false,
48        }
49    }
50}
51
52impl SabreConfig {
53    /// Create a basic configuration with minimal overhead
54    #[must_use]
55    pub const fn basic() -> Self {
56        Self {
57            max_iterations: 100,
58            lookahead_depth: 5,
59            decay_factor: 0.01,
60            extended_set_weight: 0.3,
61            max_swaps_per_iteration: 5,
62            stochastic: false,
63        }
64    }
65
66    /// Create a stochastic configuration for multiple trials
67    #[must_use]
68    pub fn stochastic() -> Self {
69        Self {
70            stochastic: true,
71            ..Default::default()
72        }
73    }
74}
75
76/// SABRE routing algorithm implementation
77pub struct SabreRouter {
78    coupling_map: CouplingMap,
79    config: SabreConfig,
80}
81
82impl SabreRouter {
83    /// Create a new SABRE router
84    #[must_use]
85    pub const fn new(coupling_map: CouplingMap, config: SabreConfig) -> Self {
86        Self {
87            coupling_map,
88            config,
89        }
90    }
91
92    /// Route a circuit using the SABRE algorithm
93    pub fn route<const N: usize>(&self, circuit: &Circuit<N>) -> QuantRS2Result<RoutedCircuit<N>> {
94        let dag = circuit_to_dag(circuit);
95        let mut logical_to_physical = self.initial_mapping(&dag);
96        let mut physical_to_logical: HashMap<usize, usize> = logical_to_physical
97            .iter()
98            .map(|(&logical, &physical)| (physical, logical))
99            .collect();
100
101        let mut routed_gates = Vec::new();
102        let mut executable = self.find_executable_gates(&dag, &logical_to_physical);
103        let mut remaining_gates: HashSet<usize> = (0..dag.nodes().len()).collect();
104        let mut iteration = 0;
105
106        while !remaining_gates.is_empty() && iteration < self.config.max_iterations {
107            iteration += 1;
108
109            // Execute all possible gates
110            while let Some(gate_id) = executable.pop() {
111                if remaining_gates.contains(&gate_id) {
112                    let node = &dag.nodes()[gate_id];
113                    let routed_gate = self.map_gate_to_physical(node, &logical_to_physical)?;
114                    routed_gates.push(routed_gate);
115                    remaining_gates.remove(&gate_id);
116
117                    // Update executable gates
118                    for &succ in &node.successors {
119                        if remaining_gates.contains(&succ)
120                            && self.is_gate_executable(&dag.nodes()[succ], &logical_to_physical)
121                        {
122                            executable.push(succ);
123                        }
124                    }
125                }
126            }
127
128            // If no more gates can be executed, insert SWAPs
129            if !remaining_gates.is_empty() {
130                let swaps = self.find_best_swaps(&dag, &remaining_gates, &logical_to_physical)?;
131
132                if swaps.is_empty() {
133                    return Err(QuantRS2Error::RoutingError(
134                        "Cannot find valid SWAP operations".to_string(),
135                    ));
136                }
137
138                // Apply SWAPs
139                for (p1, p2) in swaps {
140                    // Add SWAP gate to routed circuit
141                    let swap_gate = Box::new(SWAP {
142                        qubit1: QubitId::new(p1 as u32),
143                        qubit2: QubitId::new(p2 as u32),
144                    }) as Box<dyn GateOp>;
145                    routed_gates.push(swap_gate);
146
147                    // Update mappings
148                    let l1 = physical_to_logical[&p1];
149                    let l2 = physical_to_logical[&p2];
150
151                    logical_to_physical.insert(l1, p2);
152                    logical_to_physical.insert(l2, p1);
153                    physical_to_logical.insert(p1, l2);
154                    physical_to_logical.insert(p2, l1);
155                }
156
157                // Update executable gates after SWAP
158                executable = self.find_executable_gates_from_remaining(
159                    &dag,
160                    &remaining_gates,
161                    &logical_to_physical,
162                );
163            }
164        }
165
166        if !remaining_gates.is_empty() {
167            return Err(QuantRS2Error::RoutingError(format!(
168                "Routing failed: {} gates remaining after {} iterations",
169                remaining_gates.len(),
170                iteration
171            )));
172        }
173
174        let total_swaps = routed_gates.iter().filter(|g| g.name() == "SWAP").count();
175        let circuit_depth = self.calculate_depth(&routed_gates);
176
177        Ok(RoutedCircuit::new(
178            routed_gates,
179            logical_to_physical,
180            RoutingResult {
181                total_swaps,
182                circuit_depth,
183                routing_overhead: if circuit_depth > 0 {
184                    total_swaps as f64 / circuit_depth as f64
185                } else {
186                    0.0
187                },
188            },
189        ))
190    }
191
192    /// Create initial mapping using a simple heuristic
193    fn initial_mapping(&self, dag: &CircuitDag) -> HashMap<usize, usize> {
194        let mut mapping = HashMap::new();
195        let logical_qubits = self.extract_logical_qubits(dag);
196
197        // Simple strategy: map to the first available physical qubits
198        for (i, &logical) in logical_qubits.iter().enumerate() {
199            if i < self.coupling_map.num_qubits() {
200                mapping.insert(logical, i);
201            }
202        }
203
204        mapping
205    }
206
207    /// Extract logical qubits from the DAG
208    fn extract_logical_qubits(&self, dag: &CircuitDag) -> Vec<usize> {
209        let mut qubits = HashSet::new();
210
211        for node in dag.nodes() {
212            for qubit in node.gate.qubits() {
213                qubits.insert(qubit.id() as usize);
214            }
215        }
216
217        let mut qubit_vec: Vec<usize> = qubits.into_iter().collect();
218        qubit_vec.sort_unstable();
219        qubit_vec
220    }
221
222    /// Find gates that can be executed with current mapping
223    fn find_executable_gates(
224        &self,
225        dag: &CircuitDag,
226        mapping: &HashMap<usize, usize>,
227    ) -> Vec<usize> {
228        let mut executable = Vec::new();
229
230        for node in dag.nodes() {
231            if node.predecessors.is_empty() && self.is_gate_executable(node, mapping) {
232                executable.push(node.id);
233            }
234        }
235
236        executable
237    }
238
239    /// Find executable gates from remaining set
240    fn find_executable_gates_from_remaining(
241        &self,
242        dag: &CircuitDag,
243        remaining: &HashSet<usize>,
244        mapping: &HashMap<usize, usize>,
245    ) -> Vec<usize> {
246        let mut executable = Vec::new();
247
248        for &gate_id in remaining {
249            let node = &dag.nodes()[gate_id];
250
251            // Check if all predecessors are executed
252            let ready = node
253                .predecessors
254                .iter()
255                .all(|&pred| !remaining.contains(&pred));
256
257            if ready && self.is_gate_executable(node, mapping) {
258                executable.push(gate_id);
259            }
260        }
261
262        executable
263    }
264
265    /// Check if a gate can be executed with current mapping
266    fn is_gate_executable(&self, node: &DagNode, mapping: &HashMap<usize, usize>) -> bool {
267        let qubits = node.gate.qubits();
268
269        if qubits.len() <= 1 {
270            return true; // Single-qubit gates are always executable
271        }
272
273        if qubits.len() == 2 {
274            let q1 = qubits[0].id() as usize;
275            let q2 = qubits[1].id() as usize;
276
277            if let (Some(&p1), Some(&p2)) = (mapping.get(&q1), mapping.get(&q2)) {
278                return self.coupling_map.are_connected(p1, p2);
279            }
280        }
281
282        false
283    }
284
285    /// Map a logical gate to physical qubits
286    fn map_gate_to_physical(
287        &self,
288        node: &DagNode,
289        mapping: &HashMap<usize, usize>,
290    ) -> QuantRS2Result<Box<dyn GateOp>> {
291        let qubits = node.gate.qubits();
292        let mut physical_qubits = Vec::new();
293
294        for qubit in qubits {
295            let logical = qubit.id() as usize;
296            if let Some(&physical) = mapping.get(&logical) {
297                physical_qubits.push(QubitId::new(physical as u32));
298            } else {
299                return Err(QuantRS2Error::RoutingError(format!(
300                    "Logical qubit {logical} not mapped to physical qubit"
301                )));
302            }
303        }
304
305        // Clone the gate with new physical qubits
306        // This is a simplified implementation - in practice, we'd need to handle each gate type
307        self.clone_gate_with_qubits(node.gate.as_ref(), &physical_qubits)
308    }
309
310    /// Clone a gate with new qubits (simplified implementation)
311    fn clone_gate_with_qubits(
312        &self,
313        gate: &dyn GateOp,
314        new_qubits: &[QubitId],
315    ) -> QuantRS2Result<Box<dyn GateOp>> {
316        use quantrs2_core::gate::{multi, single};
317
318        match (gate.name(), new_qubits.len()) {
319            ("H", 1) => Ok(Box::new(single::Hadamard {
320                target: new_qubits[0],
321            })),
322            ("X", 1) => Ok(Box::new(single::PauliX {
323                target: new_qubits[0],
324            })),
325            ("Y", 1) => Ok(Box::new(single::PauliY {
326                target: new_qubits[0],
327            })),
328            ("Z", 1) => Ok(Box::new(single::PauliZ {
329                target: new_qubits[0],
330            })),
331            ("S", 1) => Ok(Box::new(single::Phase {
332                target: new_qubits[0],
333            })),
334            ("T", 1) => Ok(Box::new(single::T {
335                target: new_qubits[0],
336            })),
337            ("CNOT", 2) => Ok(Box::new(multi::CNOT {
338                control: new_qubits[0],
339                target: new_qubits[1],
340            })),
341            ("CZ", 2) => Ok(Box::new(multi::CZ {
342                control: new_qubits[0],
343                target: new_qubits[1],
344            })),
345            ("SWAP", 2) => Ok(Box::new(multi::SWAP {
346                qubit1: new_qubits[0],
347                qubit2: new_qubits[1],
348            })),
349            ("RZ", 1) => {
350                // Parameterized single-qubit gate: extract angle via downcast
351                let theta = gate
352                    .as_any()
353                    .downcast_ref::<RotationZ>()
354                    .map(|g| g.theta)
355                    .unwrap_or(0.0);
356                Ok(Box::new(RotationZ {
357                    target: new_qubits[0],
358                    theta,
359                }))
360            }
361            ("RY", 1) => {
362                let theta = gate
363                    .as_any()
364                    .downcast_ref::<RotationY>()
365                    .map(|g| g.theta)
366                    .unwrap_or(0.0);
367                Ok(Box::new(RotationY {
368                    target: new_qubits[0],
369                    theta,
370                }))
371            }
372            ("RX", 1) => {
373                let theta = gate
374                    .as_any()
375                    .downcast_ref::<RotationX>()
376                    .map(|g| g.theta)
377                    .unwrap_or(0.0);
378                Ok(Box::new(RotationX {
379                    target: new_qubits[0],
380                    theta,
381                }))
382            }
383            _ => Err(QuantRS2Error::UnsupportedOperation(format!(
384                "Cannot route gate {} with {} qubits",
385                gate.name(),
386                new_qubits.len()
387            ))),
388        }
389    }
390
391    /// Find the best SWAP operations to enable more gates
392    fn find_best_swaps(
393        &self,
394        dag: &CircuitDag,
395        remaining_gates: &HashSet<usize>,
396        mapping: &HashMap<usize, usize>,
397    ) -> QuantRS2Result<Vec<(usize, usize)>> {
398        let front_layer = self.get_front_layer(dag, remaining_gates);
399        let extended_set = self.get_extended_set(dag, &front_layer);
400
401        let mut swap_scores = HashMap::new();
402
403        // Score all possible SWAPs
404        for &p1 in &self.get_mapped_physical_qubits(mapping) {
405            for &p2 in self.coupling_map.neighbors(p1) {
406                if p1 < p2 {
407                    // Avoid duplicate pairs
408                    let score = self.calculate_swap_score(
409                        dag,
410                        (p1, p2),
411                        &front_layer,
412                        &extended_set,
413                        mapping,
414                    );
415                    swap_scores.insert((p1, p2), score);
416                }
417            }
418        }
419
420        if swap_scores.is_empty() {
421            return Ok(Vec::new());
422        }
423
424        // Select best SWAP(s)
425        let mut sorted_swaps: Vec<_> = swap_scores.into_iter().collect();
426
427        if self.config.stochastic {
428            // Stochastic selection from top candidates
429            let mut rng = thread_rng();
430            sorted_swaps.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
431            let top_candidates = sorted_swaps.len().min(5);
432
433            if top_candidates > 0 {
434                let idx = rng.random_range(0..top_candidates);
435                Ok(vec![sorted_swaps[idx].0])
436            } else {
437                Ok(Vec::new())
438            }
439        } else {
440            // Deterministic selection of best SWAP
441            sorted_swaps.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
442            if sorted_swaps.is_empty() {
443                Ok(Vec::new())
444            } else {
445                Ok(vec![sorted_swaps[0].0])
446            }
447        }
448    }
449
450    /// Get the front layer of executable gates
451    fn get_front_layer(&self, dag: &CircuitDag, remaining: &HashSet<usize>) -> HashSet<usize> {
452        let mut front_layer = HashSet::new();
453
454        for &gate_id in remaining {
455            let node = &dag.nodes()[gate_id];
456
457            // Check if all predecessors are executed
458            let ready = node
459                .predecessors
460                .iter()
461                .all(|&pred| !remaining.contains(&pred));
462
463            if ready {
464                front_layer.insert(gate_id);
465            }
466        }
467
468        front_layer
469    }
470
471    /// Get extended set for lookahead
472    fn get_extended_set(&self, dag: &CircuitDag, front_layer: &HashSet<usize>) -> HashSet<usize> {
473        let mut extended_set = front_layer.clone();
474        let mut to_visit = VecDeque::new();
475
476        for &gate_id in front_layer {
477            to_visit.push_back((gate_id, 0));
478        }
479
480        while let Some((gate_id, depth)) = to_visit.pop_front() {
481            if depth >= self.config.lookahead_depth {
482                continue;
483            }
484
485            let node = &dag.nodes()[gate_id];
486            for &succ in &node.successors {
487                if extended_set.insert(succ) {
488                    to_visit.push_back((succ, depth + 1));
489                }
490            }
491        }
492
493        extended_set
494    }
495
496    /// Get currently mapped physical qubits
497    fn get_mapped_physical_qubits(&self, mapping: &HashMap<usize, usize>) -> Vec<usize> {
498        mapping.values().copied().collect()
499    }
500
501    /// Calculate score for a SWAP operation
502    fn calculate_swap_score(
503        &self,
504        dag: &CircuitDag,
505        swap: (usize, usize),
506        front_layer: &HashSet<usize>,
507        extended_set: &HashSet<usize>,
508        mapping: &HashMap<usize, usize>,
509    ) -> f64 {
510        // Create temporary mapping with the SWAP applied
511        let mut temp_mapping = mapping.clone();
512        let (p1, p2) = swap;
513
514        // Find logical qubits mapped to these physical qubits
515        let mut l1_opt = None;
516        let mut l2_opt = None;
517
518        for (&logical, &physical) in mapping {
519            if physical == p1 {
520                l1_opt = Some(logical);
521            } else if physical == p2 {
522                l2_opt = Some(logical);
523            }
524        }
525
526        if let (Some(l1), Some(l2)) = (l1_opt, l2_opt) {
527            temp_mapping.insert(l1, p2);
528            temp_mapping.insert(l2, p1);
529        } else {
530            return -1.0; // Invalid SWAP
531        }
532
533        // Count newly executable gates in front layer with the updated mapping.
534        // We also consider extended-set gates with a reduced weight to encourage
535        // SWAPs that unlock future gates, not just immediate ones.
536        let front_newly_executable = front_layer
537            .iter()
538            .filter(|&&gate_id| {
539                let node = &dag.nodes()[gate_id];
540                self.is_gate_executable(node, &temp_mapping)
541            })
542            .count() as f64;
543
544        let extended_newly_executable = extended_set
545            .iter()
546            .filter(|&&gate_id| {
547                if front_layer.contains(&gate_id) {
548                    return false; // Already counted
549                }
550                let node = &dag.nodes()[gate_id];
551                self.is_gate_executable(node, &temp_mapping)
552            })
553            .count() as f64;
554
555        // Score = front-layer executable gates + decay-weighted extended-set gains.
556        // Subtract 1.0 per inserted SWAP (normalised by |front_layer|) to penalise overhead.
557        let front_size = front_layer.len().max(1) as f64;
558        let raw_score = (front_newly_executable / front_size)
559            + self.config.extended_set_weight * (extended_newly_executable / front_size);
560
561        // Apply decay penalty based on how far p1 and p2 are from the front-layer qubits
562        // (discourages SWAPs on idle qubits).
563        let decay = 1.0 - self.config.decay_factor;
564        raw_score * decay
565    }
566
567    /// Calculate circuit depth
568    fn calculate_depth(&self, gates: &[Box<dyn GateOp>]) -> usize {
569        // Simplified depth calculation
570        // In practice, would need to track dependencies properly
571        gates.len()
572    }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::{multi::CNOT, single::Hadamard};

    /// H(0) followed by CNOT(0, 2) on a 3-qubit line must route without error.
    #[test]
    fn test_sabre_basic() {
        let router = SabreRouter::new(CouplingMap::linear(3), SabreConfig::basic());

        let mut circuit = Circuit::<3>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("add H gate to circuit");
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(2),
            })
            .expect("add CNOT gate to circuit");

        let routed = router.route(&circuit);
        assert!(routed.is_ok());
    }

    /// Every logical qubit touched by the circuit receives a physical placement.
    #[test]
    fn test_initial_mapping() {
        let router = SabreRouter::new(CouplingMap::linear(5), SabreConfig::default());

        let mut circuit = Circuit::<3>::new();
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("add CNOT gate to circuit");

        let mapping = router.initial_mapping(&circuit_to_dag(&circuit));
        for logical in [0_usize, 1] {
            assert!(mapping.contains_key(&logical));
        }
    }
}