1use scirs2_core::ndarray::Array2;
7use scirs2_core::Complex64;
8use std::collections::{HashMap, HashSet};
9
10use quantrs2_core::gate::GateOp;
11use quantrs2_core::qubit::QubitId;
12
/// Coarse classification of a gate, used as the key when looking up
/// commutation rules.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum GateType {
    /// X-axis rotation, tagged with a string label for its parameter.
    Rx(String),
    /// Y-axis rotation, tagged with a string label for its parameter.
    Ry(String),
    /// Z-axis rotation, tagged with a string label for its parameter.
    Rz(String),
    /// Hadamard gate.
    H,
    /// Pauli-X gate.
    X,
    /// Pauli-Y gate.
    Y,
    /// Pauli-Z gate.
    Z,
    /// S (phase) gate.
    S,
    /// T gate.
    T,
    /// Controlled-NOT gate.
    CNOT,
    /// Controlled-Z gate.
    CZ,
    /// SWAP gate.
    SWAP,
    /// Toffoli gate.
    Toffoli,
    /// Measurement.
    Measure,
    /// Any other gate, identified by its name.
    Custom(String),
}
47
48#[derive(Debug, Clone, PartialEq)]
50pub enum CommutationResult {
51 Commute,
53 AntiCommute(Complex64),
55 NonCommute,
57 ConditionalCommute(String),
59}
60
61pub struct CommutationRules {
63 cache: HashMap<(GateType, GateType), CommutationResult>,
65 custom_rules: HashMap<(String, String), CommutationResult>,
67}
68
69impl CommutationRules {
70 #[must_use]
72 pub fn new() -> Self {
73 let mut rules = Self {
74 cache: HashMap::new(),
75 custom_rules: HashMap::new(),
76 };
77 rules.initialize_standard_rules();
78 rules
79 }
80
81 fn initialize_standard_rules(&mut self) {
83 use CommutationResult::{Commute, ConditionalCommute, NonCommute};
84 use GateType::{Measure, Rz, CNOT, CZ, H, S, T, X, Y, Z};
85
86 self.add_rule(X, X, Commute);
88 self.add_rule(Y, Y, Commute);
89 self.add_rule(Z, Z, Commute);
90 self.add_rule(X, Y, NonCommute);
91 self.add_rule(X, Z, NonCommute);
92 self.add_rule(Y, Z, NonCommute);
93
94 self.add_rule(H, H, Commute);
96 self.add_rule(H, X, NonCommute);
97 self.add_rule(H, Y, NonCommute);
98 self.add_rule(H, Z, NonCommute);
99
100 self.add_rule(S, S, Commute);
102 self.add_rule(T, T, Commute);
103 self.add_rule(S, T, Commute);
104 self.add_rule(S, Z, Commute);
105 self.add_rule(T, Z, Commute);
106
107 self.add_rule(Z, Rz("any".to_string()), Commute);
109 self.add_rule(S, Rz("any".to_string()), Commute);
110 self.add_rule(T, Rz("any".to_string()), Commute);
111 self.add_rule(Rz("any1".to_string()), Rz("any2".to_string()), Commute);
112
113 self.add_rule(
115 CNOT,
116 CNOT,
117 ConditionalCommute("Same control and target".to_string()),
118 );
119 self.add_rule(CZ, CZ, ConditionalCommute("Same qubits".to_string()));
120
121 self.add_rule(Measure, X, NonCommute);
123 self.add_rule(Measure, Y, NonCommute);
124 self.add_rule(Measure, H, NonCommute);
125 self.add_rule(Measure, Z, Commute); }
127
128 pub fn add_rule(&mut self, gate1: GateType, gate2: GateType, result: CommutationResult) {
130 self.cache
131 .insert((gate1.clone(), gate2.clone()), result.clone());
132 if matches!(
134 result,
135 CommutationResult::Commute | CommutationResult::NonCommute
136 ) {
137 self.cache.insert((gate2, gate1), result);
138 }
139 }
140
141 pub fn add_custom_rule(&mut self, gate1: String, gate2: String, result: CommutationResult) {
143 self.custom_rules
144 .insert((gate1.clone(), gate2.clone()), result.clone());
145 if matches!(
146 result,
147 CommutationResult::Commute | CommutationResult::NonCommute
148 ) {
149 self.custom_rules.insert((gate2, gate1), result);
150 }
151 }
152
153 #[must_use]
155 pub fn check_commutation(&self, gate1: &GateType, gate2: &GateType) -> CommutationResult {
156 if let Some(result) = self.cache.get(&(gate1.clone(), gate2.clone())) {
158 return result.clone();
159 }
160
161 if let (GateType::Custom(name1), GateType::Custom(name2)) = (gate1, gate2) {
163 if let Some(result) = self.custom_rules.get(&(name1.clone(), name2.clone())) {
164 return result.clone();
165 }
166 }
167
168 CommutationResult::NonCommute
170 }
171}
172
173impl Default for CommutationRules {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
179pub struct CommutationAnalyzer {
181 rules: CommutationRules,
182}
183
184impl CommutationAnalyzer {
185 #[must_use]
187 pub fn new() -> Self {
188 Self {
189 rules: CommutationRules::new(),
190 }
191 }
192
193 #[must_use]
195 pub const fn with_rules(rules: CommutationRules) -> Self {
196 Self { rules }
197 }
198
199 pub fn gate_to_type(gate: &dyn GateOp) -> GateType {
201 match gate.name() {
202 "H" => GateType::H,
203 "X" => GateType::X,
204 "Y" => GateType::Y,
205 "Z" => GateType::Z,
206 "S" => GateType::S,
207 "T" => GateType::T,
208 "RX" => GateType::Rx("generic".to_string()),
209 "RY" => GateType::Ry("generic".to_string()),
210 "RZ" => GateType::Rz("generic".to_string()),
211 "CNOT" => GateType::CNOT,
212 "CZ" => GateType::CZ,
213 "SWAP" => GateType::SWAP,
214 "Toffoli" => GateType::Toffoli,
215 "Measure" => GateType::Measure,
216 name => GateType::Custom(name.to_string()),
217 }
218 }
219
220 pub fn gates_commute(&self, gate1: &dyn GateOp, gate2: &dyn GateOp) -> bool {
222 let qubits1: HashSet<_> = gate1
223 .qubits()
224 .iter()
225 .map(quantrs2_core::QubitId::id)
226 .collect();
227 let qubits2: HashSet<_> = gate2
228 .qubits()
229 .iter()
230 .map(quantrs2_core::QubitId::id)
231 .collect();
232
233 if qubits1.is_disjoint(&qubits2) {
235 return true;
236 }
237
238 let type1 = Self::gate_to_type(gate1);
240 let type2 = Self::gate_to_type(gate2);
241
242 match self.rules.check_commutation(&type1, &type2) {
243 CommutationResult::Commute | CommutationResult::AntiCommute(_) => true, CommutationResult::NonCommute => false,
245 CommutationResult::ConditionalCommute(condition) => {
246 self.check_conditional_commutation(gate1, gate2, &condition)
248 }
249 }
250 }
251
252 fn check_conditional_commutation(
254 &self,
255 gate1: &dyn GateOp,
256 gate2: &dyn GateOp,
257 condition: &str,
258 ) -> bool {
259 match condition {
260 "Same control and target" => {
261 if gate1.name() == "CNOT" && gate2.name() == "CNOT" {
263 let qubits1 = gate1.qubits();
264 let qubits2 = gate2.qubits();
265 return qubits1[0] == qubits2[0] && qubits1[1] == qubits2[1];
266 }
267 false
268 }
269 "Same qubits" => {
270 let qubits1: HashSet<_> = gate1
272 .qubits()
273 .iter()
274 .map(quantrs2_core::QubitId::id)
275 .collect();
276 let qubits2: HashSet<_> = gate2
277 .qubits()
278 .iter()
279 .map(quantrs2_core::QubitId::id)
280 .collect();
281 qubits1 == qubits2
282 }
283 _ => false,
284 }
285 }
286
287 pub fn find_commuting_gates(
289 &self,
290 target_gate: &dyn GateOp,
291 gates: &[Box<dyn GateOp>],
292 ) -> Vec<usize> {
293 gates
294 .iter()
295 .enumerate()
296 .filter(|(_, gate)| self.gates_commute(target_gate, gate.as_ref()))
297 .map(|(idx, _)| idx)
298 .collect()
299 }
300
301 #[must_use]
303 pub fn build_commutation_matrix(&self, gates: &[Box<dyn GateOp>]) -> Array2<bool> {
304 let n = gates.len();
305 let mut matrix = Array2::from_elem((n, n), false);
306
307 for i in 0..n {
308 for j in 0..n {
309 if i == j {
310 matrix[[i, j]] = true; } else {
312 matrix[[i, j]] = self.gates_commute(gates[i].as_ref(), gates[j].as_ref());
313 }
314 }
315 }
316
317 matrix
318 }
319
320 #[must_use]
322 pub fn find_parallel_sets(&self, gates: &[Box<dyn GateOp>]) -> Vec<Vec<usize>> {
323 let n = gates.len();
324 let mut remaining: HashSet<usize> = (0..n).collect();
325 let mut parallel_sets = Vec::new();
326
327 while !remaining.is_empty() {
328 let mut current_set = Vec::new();
329 let mut current_qubits = HashSet::new();
330
331 let mut indices_to_check: Vec<usize> = remaining.iter().copied().collect();
332 indices_to_check.sort_unstable(); for idx in indices_to_check {
335 let gate_qubits: HashSet<_> = gates[idx]
336 .qubits()
337 .iter()
338 .map(quantrs2_core::QubitId::id)
339 .collect();
340
341 let can_add = if current_set.is_empty() {
343 true
344 } else if !current_qubits.is_disjoint(&gate_qubits) {
345 false
346 } else {
347 current_set.iter().all(|&other_idx| {
349 let gate1: &Box<dyn GateOp> = &gates[idx];
350 let gate2: &Box<dyn GateOp> = &gates[other_idx];
351 self.gates_commute(gate1.as_ref(), gate2.as_ref())
352 })
353 };
354
355 if can_add {
356 current_set.push(idx);
357 current_qubits.extend(gate_qubits);
358 remaining.remove(&idx);
359 }
360 }
361
362 if !current_set.is_empty() {
363 parallel_sets.push(current_set);
364 }
365 }
366
367 parallel_sets
368 }
369}
370
371impl Default for CommutationAnalyzer {
372 fn default() -> Self {
373 Self::new()
374 }
375}
376
377pub trait CommutationOptimization {
379 fn optimize_gate_order(&mut self, analyzer: &CommutationAnalyzer);
381
382 fn group_commuting_gates(&mut self, analyzer: &CommutationAnalyzer);
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::{Hadamard, PauliX, PauliZ};

    /// Two X gates on one qubit commute; X and Z on the same qubit do not.
    #[test]
    fn test_basic_commutation() {
        let first_x = PauliX { target: QubitId(0) };
        let second_x = PauliX { target: QubitId(0) };
        let pauli_z = PauliZ { target: QubitId(0) };

        let analyzer = CommutationAnalyzer::new();

        assert!(analyzer.gates_commute(&first_x, &second_x));
        assert!(!analyzer.gates_commute(&first_x, &pauli_z));
    }

    /// Gates on different qubits always commute.
    #[test]
    fn test_disjoint_qubits() {
        let on_qubit_0 = Hadamard { target: QubitId(0) };
        let on_qubit_1 = Hadamard { target: QubitId(1) };

        let analyzer = CommutationAnalyzer::new();

        assert!(analyzer.gates_commute(&on_qubit_0, &on_qubit_1));
    }

    /// CNOTs commute when control and target match, not when they are swapped.
    #[test]
    fn test_cnot_commutation() {
        let forward = CNOT {
            control: QubitId(0),
            target: QubitId(1),
        };
        let forward_again = CNOT {
            control: QubitId(0),
            target: QubitId(1),
        };
        let reversed = CNOT {
            control: QubitId(1),
            target: QubitId(0),
        };

        let analyzer = CommutationAnalyzer::new();

        assert!(analyzer.gates_commute(&forward, &forward_again));
        assert!(!analyzer.gates_commute(&forward, &reversed));
    }

    /// Checks the diagonal, a disjoint pair and a non-commuting pair.
    #[test]
    fn test_commutation_matrix() {
        let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
        gates.push(Box::new(Hadamard { target: QubitId(0) }));
        gates.push(Box::new(Hadamard { target: QubitId(1) }));
        gates.push(Box::new(PauliX { target: QubitId(0) }));

        let matrix = CommutationAnalyzer::new().build_commutation_matrix(&gates);

        assert!(matrix[[0, 0]]);
        assert!(matrix[[0, 1]]);
        assert!(!matrix[[0, 2]]);
    }

    /// Three Hadamards on distinct qubits form one set; the CNOT that overlaps
    /// two of them lands in a second set.
    #[test]
    fn test_parallel_sets() {
        let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
        gates.push(Box::new(Hadamard { target: QubitId(0) }));
        gates.push(Box::new(Hadamard { target: QubitId(1) }));
        gates.push(Box::new(Hadamard { target: QubitId(2) }));
        gates.push(Box::new(CNOT {
            control: QubitId(0),
            target: QubitId(1),
        }));

        let parallel_sets = CommutationAnalyzer::new().find_parallel_sets(&gates);
        let set_sizes: Vec<usize> = parallel_sets.iter().map(Vec::len).collect();

        assert_eq!(set_sizes, [3, 1]);
    }
}