quantrs2_circuit/
commutation.rs

//! Commutation analysis for quantum gate reordering.
//!
//! This module determines which quantum gates commute with each other, enabling
//! optimizations such as gate reordering and parallel scheduling.
5
6use scirs2_core::ndarray::Array2;
7use scirs2_core::Complex64;
8use std::collections::{HashMap, HashSet};
9
10use quantrs2_core::gate::GateOp;
11use quantrs2_core::qubit::QubitId;
12
/// Gate kinds distinguished by the commutation rule database.
///
/// Parameterised rotations keep their parameter as text so that the whole
/// enum can derive `Eq` and `Hash` and serve as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum GateType {
    /// Rotation about the X axis; the payload is the parameter rendered as text.
    Rx(String),
    /// Rotation about the Y axis; the payload is the parameter rendered as text.
    Ry(String),
    /// Rotation about the Z axis; the payload is the parameter rendered as text.
    Rz(String),
    /// Hadamard.
    H,
    /// Pauli X.
    X,
    /// Pauli Y.
    Y,
    /// Pauli Z.
    Z,
    /// S (phase) gate.
    S,
    /// T gate.
    T,
    /// Controlled-NOT.
    CNOT,
    /// Controlled-Z.
    CZ,
    /// Two-qubit swap.
    SWAP,
    /// Toffoli (controlled-controlled-NOT).
    Toffoli,
    /// Measurement operation.
    Measure,
    /// Any other gate, identified by its name.
    Custom(String),
}
47
48/// Result of commutation check
49#[derive(Debug, Clone, PartialEq)]
50pub enum CommutationResult {
51    /// Gates commute exactly
52    Commute,
53    /// Gates anti-commute (commute up to a phase)
54    AntiCommute(Complex64),
55    /// Gates don't commute
56    NonCommute,
57    /// Gates commute under certain conditions
58    ConditionalCommute(String),
59}
60
61/// Commutation rules database
62pub struct CommutationRules {
63    /// Cached commutation results
64    cache: HashMap<(GateType, GateType), CommutationResult>,
65    /// Custom commutation rules
66    custom_rules: HashMap<(String, String), CommutationResult>,
67}
68
69impl CommutationRules {
70    /// Create a new commutation rules database with standard rules
71    #[must_use]
72    pub fn new() -> Self {
73        let mut rules = Self {
74            cache: HashMap::new(),
75            custom_rules: HashMap::new(),
76        };
77        rules.initialize_standard_rules();
78        rules
79    }
80
81    /// Initialize standard commutation rules
82    fn initialize_standard_rules(&mut self) {
83        use CommutationResult::{Commute, ConditionalCommute, NonCommute};
84        use GateType::{Measure, Rz, CNOT, CZ, H, S, T, X, Y, Z};
85
86        // Pauli commutation rules
87        self.add_rule(X, X, Commute);
88        self.add_rule(Y, Y, Commute);
89        self.add_rule(Z, Z, Commute);
90        self.add_rule(X, Y, NonCommute);
91        self.add_rule(X, Z, NonCommute);
92        self.add_rule(Y, Z, NonCommute);
93
94        // Hadamard commutation
95        self.add_rule(H, H, Commute);
96        self.add_rule(H, X, NonCommute);
97        self.add_rule(H, Y, NonCommute);
98        self.add_rule(H, Z, NonCommute);
99
100        // Phase gates
101        self.add_rule(S, S, Commute);
102        self.add_rule(T, T, Commute);
103        self.add_rule(S, T, Commute);
104        self.add_rule(S, Z, Commute);
105        self.add_rule(T, Z, Commute);
106
107        // Z-basis rotations commute
108        self.add_rule(Z, Rz("any".to_string()), Commute);
109        self.add_rule(S, Rz("any".to_string()), Commute);
110        self.add_rule(T, Rz("any".to_string()), Commute);
111        self.add_rule(Rz("any1".to_string()), Rz("any2".to_string()), Commute);
112
113        // CNOT commutation rules
114        self.add_rule(
115            CNOT,
116            CNOT,
117            ConditionalCommute("Same control and target".to_string()),
118        );
119        self.add_rule(CZ, CZ, ConditionalCommute("Same qubits".to_string()));
120
121        // Measurements don't commute with most gates
122        self.add_rule(Measure, X, NonCommute);
123        self.add_rule(Measure, Y, NonCommute);
124        self.add_rule(Measure, H, NonCommute);
125        self.add_rule(Measure, Z, Commute); // Z-basis measurement commutes with Z
126    }
127
128    /// Add a commutation rule
129    pub fn add_rule(&mut self, gate1: GateType, gate2: GateType, result: CommutationResult) {
130        self.cache
131            .insert((gate1.clone(), gate2.clone()), result.clone());
132        // Commutation is symmetric for most cases
133        if matches!(
134            result,
135            CommutationResult::Commute | CommutationResult::NonCommute
136        ) {
137            self.cache.insert((gate2, gate1), result);
138        }
139    }
140
141    /// Add a custom commutation rule
142    pub fn add_custom_rule(&mut self, gate1: String, gate2: String, result: CommutationResult) {
143        self.custom_rules
144            .insert((gate1.clone(), gate2.clone()), result.clone());
145        if matches!(
146            result,
147            CommutationResult::Commute | CommutationResult::NonCommute
148        ) {
149            self.custom_rules.insert((gate2, gate1), result);
150        }
151    }
152
153    /// Check if two gate types commute
154    #[must_use]
155    pub fn check_commutation(&self, gate1: &GateType, gate2: &GateType) -> CommutationResult {
156        // Check cache first
157        if let Some(result) = self.cache.get(&(gate1.clone(), gate2.clone())) {
158            return result.clone();
159        }
160
161        // Check custom rules
162        if let (GateType::Custom(name1), GateType::Custom(name2)) = (gate1, gate2) {
163            if let Some(result) = self.custom_rules.get(&(name1.clone(), name2.clone())) {
164                return result.clone();
165            }
166        }
167
168        // Default: assume non-commuting
169        CommutationResult::NonCommute
170    }
171}
172
173impl Default for CommutationRules {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179/// Analyzer for gate commutation in circuits
180pub struct CommutationAnalyzer {
181    rules: CommutationRules,
182}
183
184impl CommutationAnalyzer {
185    /// Create a new commutation analyzer
186    #[must_use]
187    pub fn new() -> Self {
188        Self {
189            rules: CommutationRules::new(),
190        }
191    }
192
193    /// Create with custom rules
194    #[must_use]
195    pub const fn with_rules(rules: CommutationRules) -> Self {
196        Self { rules }
197    }
198
199    /// Convert a gate operation to a gate type
200    pub fn gate_to_type(gate: &dyn GateOp) -> GateType {
201        match gate.name() {
202            "H" => GateType::H,
203            "X" => GateType::X,
204            "Y" => GateType::Y,
205            "Z" => GateType::Z,
206            "S" => GateType::S,
207            "T" => GateType::T,
208            "RX" => GateType::Rx("generic".to_string()),
209            "RY" => GateType::Ry("generic".to_string()),
210            "RZ" => GateType::Rz("generic".to_string()),
211            "CNOT" => GateType::CNOT,
212            "CZ" => GateType::CZ,
213            "SWAP" => GateType::SWAP,
214            "Toffoli" => GateType::Toffoli,
215            "Measure" => GateType::Measure,
216            name => GateType::Custom(name.to_string()),
217        }
218    }
219
220    /// Check if two gates commute considering their qubit assignments
221    pub fn gates_commute(&self, gate1: &dyn GateOp, gate2: &dyn GateOp) -> bool {
222        let qubits1: HashSet<_> = gate1
223            .qubits()
224            .iter()
225            .map(quantrs2_core::QubitId::id)
226            .collect();
227        let qubits2: HashSet<_> = gate2
228            .qubits()
229            .iter()
230            .map(quantrs2_core::QubitId::id)
231            .collect();
232
233        // Gates on disjoint qubits always commute
234        if qubits1.is_disjoint(&qubits2) {
235            return true;
236        }
237
238        // Check gate types
239        let type1 = Self::gate_to_type(gate1);
240        let type2 = Self::gate_to_type(gate2);
241
242        match self.rules.check_commutation(&type1, &type2) {
243            CommutationResult::Commute | CommutationResult::AntiCommute(_) => true, // Commute (with or without phase)
244            CommutationResult::NonCommute => false,
245            CommutationResult::ConditionalCommute(condition) => {
246                // Check specific conditions
247                self.check_conditional_commutation(gate1, gate2, &condition)
248            }
249        }
250    }
251
252    /// Check conditional commutation
253    fn check_conditional_commutation(
254        &self,
255        gate1: &dyn GateOp,
256        gate2: &dyn GateOp,
257        condition: &str,
258    ) -> bool {
259        match condition {
260            "Same control and target" => {
261                // For CNOT gates
262                if gate1.name() == "CNOT" && gate2.name() == "CNOT" {
263                    let qubits1 = gate1.qubits();
264                    let qubits2 = gate2.qubits();
265                    return qubits1[0] == qubits2[0] && qubits1[1] == qubits2[1];
266                }
267                false
268            }
269            "Same qubits" => {
270                // Check if gates operate on exactly the same qubits
271                let qubits1: HashSet<_> = gate1
272                    .qubits()
273                    .iter()
274                    .map(quantrs2_core::QubitId::id)
275                    .collect();
276                let qubits2: HashSet<_> = gate2
277                    .qubits()
278                    .iter()
279                    .map(quantrs2_core::QubitId::id)
280                    .collect();
281                qubits1 == qubits2
282            }
283            _ => false,
284        }
285    }
286
287    /// Find all gates that commute with a given gate in a list
288    pub fn find_commuting_gates(
289        &self,
290        target_gate: &dyn GateOp,
291        gates: &[Box<dyn GateOp>],
292    ) -> Vec<usize> {
293        gates
294            .iter()
295            .enumerate()
296            .filter(|(_, gate)| self.gates_commute(target_gate, gate.as_ref()))
297            .map(|(idx, _)| idx)
298            .collect()
299    }
300
301    /// Build a commutation matrix for a list of gates
302    #[must_use]
303    pub fn build_commutation_matrix(&self, gates: &[Box<dyn GateOp>]) -> Array2<bool> {
304        let n = gates.len();
305        let mut matrix = Array2::from_elem((n, n), false);
306
307        for i in 0..n {
308            for j in 0..n {
309                if i == j {
310                    matrix[[i, j]] = true; // Gate commutes with itself
311                } else {
312                    matrix[[i, j]] = self.gates_commute(gates[i].as_ref(), gates[j].as_ref());
313                }
314            }
315        }
316
317        matrix
318    }
319
320    /// Find independent gate sets that can be executed in parallel
321    #[must_use]
322    pub fn find_parallel_sets(&self, gates: &[Box<dyn GateOp>]) -> Vec<Vec<usize>> {
323        let n = gates.len();
324        let mut remaining: HashSet<usize> = (0..n).collect();
325        let mut parallel_sets = Vec::new();
326
327        while !remaining.is_empty() {
328            let mut current_set = Vec::new();
329            let mut current_qubits = HashSet::new();
330
331            let mut indices_to_check: Vec<usize> = remaining.iter().copied().collect();
332            indices_to_check.sort_unstable(); // Process in order for deterministic results
333
334            for idx in indices_to_check {
335                let gate_qubits: HashSet<_> = gates[idx]
336                    .qubits()
337                    .iter()
338                    .map(quantrs2_core::QubitId::id)
339                    .collect();
340
341                // Check if this gate can be added to current set
342                let can_add = if current_set.is_empty() {
343                    true
344                } else if !current_qubits.is_disjoint(&gate_qubits) {
345                    false
346                } else {
347                    // Check commutation with all gates in current set
348                    current_set.iter().all(|&other_idx| {
349                        let gate1: &Box<dyn GateOp> = &gates[idx];
350                        let gate2: &Box<dyn GateOp> = &gates[other_idx];
351                        self.gates_commute(gate1.as_ref(), gate2.as_ref())
352                    })
353                };
354
355                if can_add {
356                    current_set.push(idx);
357                    current_qubits.extend(gate_qubits);
358                    remaining.remove(&idx);
359                }
360            }
361
362            if !current_set.is_empty() {
363                parallel_sets.push(current_set);
364            }
365        }
366
367        parallel_sets
368    }
369}
370
371impl Default for CommutationAnalyzer {
372    fn default() -> Self {
373        Self::new()
374    }
375}
376
377/// Extension methods for circuit optimization using commutation
378pub trait CommutationOptimization {
379    /// Reorder gates to maximize parallelism
380    fn optimize_gate_order(&mut self, analyzer: &CommutationAnalyzer);
381
382    /// Group commuting gates together
383    fn group_commuting_gates(&mut self, analyzer: &CommutationAnalyzer);
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX, PauliZ};

    #[test]
    fn test_basic_commutation() {
        let analyzer = CommutationAnalyzer::new();

        let first_x = PauliX { target: QubitId(0) };
        let second_x = PauliX { target: QubitId(0) };
        let z_gate = PauliZ { target: QubitId(0) };

        // Identical Paulis commute; X and Z on the same qubit are reported as non-commuting.
        assert!(analyzer.gates_commute(&first_x, &second_x));
        assert!(!analyzer.gates_commute(&first_x, &z_gate));
    }

    #[test]
    fn test_disjoint_qubits() {
        let analyzer = CommutationAnalyzer::new();

        // Gates sharing no qubit commute regardless of their type.
        let on_q0 = Hadamard { target: QubitId(0) };
        let on_q1 = Hadamard { target: QubitId(1) };

        assert!(analyzer.gates_commute(&on_q0, &on_q1));
    }

    #[test]
    fn test_cnot_commutation() {
        let analyzer = CommutationAnalyzer::new();

        let forward = CNOT {
            control: QubitId(0),
            target: QubitId(1),
        };
        let forward_again = CNOT {
            control: QubitId(0),
            target: QubitId(1),
        };
        let reversed = CNOT {
            control: QubitId(1),
            target: QubitId(0),
        };

        // Two copies of the same CNOT commute...
        assert!(analyzer.gates_commute(&forward, &forward_again));
        // ...but exchanging control and target breaks commutation.
        assert!(!analyzer.gates_commute(&forward, &reversed));
    }

    #[test]
    fn test_commutation_matrix() {
        let analyzer = CommutationAnalyzer::new();

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: QubitId(0) }),
            Box::new(Hadamard { target: QubitId(1) }),
            Box::new(PauliX { target: QubitId(0) }),
        ];
        let matrix = analyzer.build_commutation_matrix(&gates);

        // (row, col) -> expected: H0 with itself, H0 with H1 (disjoint), H0 with X0 (clash).
        let expectations: [((usize, usize), bool); 3] =
            [((0, 0), true), ((0, 1), true), ((0, 2), false)];
        for &((row, col), expected) in expectations.iter() {
            assert_eq!(matrix[[row, col]], expected, "entry ({}, {})", row, col);
        }
    }

    #[test]
    fn test_parallel_sets() {
        let analyzer = CommutationAnalyzer::new();

        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: QubitId(0) }),
            Box::new(Hadamard { target: QubitId(1) }),
            Box::new(Hadamard { target: QubitId(2) }),
            Box::new(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            }),
        ];

        // The three Hadamards share one layer; the CNOT needs a layer of its own.
        let layer_sizes: Vec<usize> = analyzer
            .find_parallel_sets(&gates)
            .iter()
            .map(Vec::len)
            .collect();
        assert_eq!(layer_sizes, vec![3, 1]);
    }
}