1use crate::{
8 cartan::OptimizedCartanDecomposer,
9 controlled::make_controlled,
10 error::{QuantRS2Error, QuantRS2Result},
11 gate::{single::*, GateOp},
12 matrix_ops::{DenseMatrix, QuantumMatrix},
13 qubit::QubitId,
14 synthesis::{decompose_single_qubit_zyz, SingleQubitDecomposition},
15};
16use rustc_hash::FxHashMap;
17use scirs2_core::ndarray::{s, Array2};
18use scirs2_core::Complex;
19use std::f64::consts::PI;
20
21#[derive(Debug, Clone)]
23pub struct ShannonDecomposition {
24 pub gates: Vec<Box<dyn GateOp>>,
26 pub cnot_count: usize,
28 pub single_qubit_count: usize,
30 pub depth: usize,
32}
33
34pub struct ShannonDecomposer {
36 tolerance: f64,
38 cache: FxHashMap<u64, ShannonDecomposition>,
40 max_depth: usize,
42}
43
44impl ShannonDecomposer {
45 pub fn new() -> Self {
47 Self {
48 tolerance: 1e-10,
49 cache: FxHashMap::default(),
50 max_depth: 20,
51 }
52 }
53
54 pub fn with_tolerance(tolerance: f64) -> Self {
56 Self {
57 tolerance,
58 cache: FxHashMap::default(),
59 max_depth: 20,
60 }
61 }
62
63 pub fn decompose(
65 &mut self,
66 unitary: &Array2<Complex<f64>>,
67 qubit_ids: &[QubitId],
68 ) -> QuantRS2Result<ShannonDecomposition> {
69 let n = qubit_ids.len();
70 let size = 1 << n;
71
72 if unitary.shape() != [size, size] {
74 return Err(QuantRS2Error::InvalidInput(format!(
75 "Unitary size {} doesn't match {} qubits",
76 unitary.shape()[0],
77 n
78 )));
79 }
80
81 let mat = DenseMatrix::new(unitary.clone())?;
83 if !mat.is_unitary(self.tolerance)? {
84 return Err(QuantRS2Error::InvalidInput(
85 "Matrix is not unitary".to_string(),
86 ));
87 }
88
89 if n == 0 {
91 return Ok(ShannonDecomposition {
92 gates: vec![],
93 cnot_count: 0,
94 single_qubit_count: 0,
95 depth: 0,
96 });
97 }
98
99 if n == 1 {
100 let decomp = decompose_single_qubit_zyz(&unitary.view())?;
102 let gates = self.single_qubit_to_gates(&decomp, qubit_ids[0]);
103 let count = gates.len();
104
105 return Ok(ShannonDecomposition {
106 gates,
107 cnot_count: 0,
108 single_qubit_count: count,
109 depth: count,
110 });
111 }
112
113 if n == 2 {
114 return self.decompose_two_qubit(unitary, qubit_ids);
116 }
117
118 self.decompose_recursive(unitary, qubit_ids, 0)
120 }
121
122 fn decompose_recursive(
124 &mut self,
125 unitary: &Array2<Complex<f64>>,
126 qubit_ids: &[QubitId],
127 depth: usize,
128 ) -> QuantRS2Result<ShannonDecomposition> {
129 if depth > self.max_depth {
130 return Err(QuantRS2Error::InvalidInput(
131 "Maximum recursion depth exceeded".to_string(),
132 ));
133 }
134
135 let n = qubit_ids.len();
136 let half_size = 1 << (n - 1);
137
138 let a = unitary.slice(s![..half_size, ..half_size]).to_owned();
142 let b = unitary.slice(s![..half_size, half_size..]).to_owned();
143 let c = unitary.slice(s![half_size.., ..half_size]).to_owned();
144 let d = unitary.slice(s![half_size.., half_size..]).to_owned();
145
146 let (v, w, u_diag) = self.block_diagonalize(&a, &b, &c, &d)?;
150
151 let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
152 let mut cnot_count = 0;
153 let mut single_qubit_count = 0;
154
155 if !self.is_identity(&w) {
157 let w_decomp = self.decompose_recursive(&w, &qubit_ids[1..], depth + 1)?;
158 gates.extend(w_decomp.gates);
159 cnot_count += w_decomp.cnot_count;
160 single_qubit_count += w_decomp.single_qubit_count;
161 }
162
163 let diag_gates = self.decompose_controlled_diagonal(&u_diag, qubit_ids)?;
165 cnot_count += diag_gates.1;
166 single_qubit_count += diag_gates.2;
167 gates.extend(diag_gates.0);
168
169 if !self.is_identity(&v) {
171 let v_dag = v.mapv(|z| z.conj()).t().to_owned();
172 let v_decomp = self.decompose_recursive(&v_dag, &qubit_ids[1..], depth + 1)?;
173 gates.extend(v_decomp.gates);
174 cnot_count += v_decomp.cnot_count;
175 single_qubit_count += v_decomp.single_qubit_count;
176 }
177
178 let depth = gates.len();
180
181 Ok(ShannonDecomposition {
182 gates,
183 cnot_count,
184 single_qubit_count,
185 depth,
186 })
187 }
188
189 fn block_diagonalize(
191 &self,
192 a: &Array2<Complex<f64>>,
193 b: &Array2<Complex<f64>>,
194 c: &Array2<Complex<f64>>,
195 d: &Array2<Complex<f64>>,
196 ) -> QuantRS2Result<(
197 Array2<Complex<f64>>,
198 Array2<Complex<f64>>,
199 Array2<Complex<f64>>,
200 )> {
201 let size = a.shape()[0];
202
203 let b_norm = b.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
212 let c_norm = c.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
213
214 if b_norm < self.tolerance && c_norm < self.tolerance {
215 let identity = Array2::eye(size);
216 let combined = self.combine_blocks(a, b, c, d);
217 return Ok((identity.clone(), identity, combined));
218 }
219
220 let combined = self.combine_blocks(a, b, c, d);
223
224 let identity = Array2::eye(size);
227 Ok((identity.clone(), identity, combined))
228 }
229
230 fn combine_blocks(
232 &self,
233 a: &Array2<Complex<f64>>,
234 b: &Array2<Complex<f64>>,
235 c: &Array2<Complex<f64>>,
236 d: &Array2<Complex<f64>>,
237 ) -> Array2<Complex<f64>> {
238 let size = a.shape()[0];
239 let total_size = 2 * size;
240 let mut result = Array2::zeros((total_size, total_size));
241
242 result.slice_mut(s![..size, ..size]).assign(a);
243 result.slice_mut(s![..size, size..]).assign(b);
244 result.slice_mut(s![size.., ..size]).assign(c);
245 result.slice_mut(s![size.., size..]).assign(d);
246
247 result
248 }
249
250 fn decompose_controlled_diagonal(
252 &self,
253 diagonal: &Array2<Complex<f64>>,
254 qubit_ids: &[QubitId],
255 ) -> QuantRS2Result<(Vec<Box<dyn GateOp>>, usize, usize)> {
256 let mut gates: Vec<Box<dyn GateOp>> = Vec::new();
257 let mut cnot_count = 0;
258 let mut single_qubit_count = 0;
259
260 let n = diagonal.shape()[0];
262 let mut phases = Vec::with_capacity(n);
263
264 for i in 0..n {
265 let phase = diagonal[[i, i]].arg();
266 phases.push(phase);
267 }
268
269 let control = qubit_ids[0];
272
273 for (i, &phase) in phases.iter().enumerate() {
274 if phase.abs() > self.tolerance {
275 if i == 0 {
276 let gate: Box<dyn GateOp> = Box::new(RotationZ {
278 target: control,
279 theta: phase,
280 });
281 gates.push(gate);
282 single_qubit_count += 1;
283 } else {
284 let base_gate = Box::new(RotationZ {
288 target: qubit_ids[1],
289 theta: phase,
290 });
291
292 let controlled = Box::new(make_controlled(vec![control], *base_gate));
293 gates.push(controlled);
294 cnot_count += 2; single_qubit_count += 3; }
297 }
298 }
299
300 Ok((gates, cnot_count, single_qubit_count))
301 }
302
303 fn decompose_two_qubit(
305 &self,
306 unitary: &Array2<Complex<f64>>,
307 qubit_ids: &[QubitId],
308 ) -> QuantRS2Result<ShannonDecomposition> {
309 if self.is_identity(unitary) {
311 return Ok(ShannonDecomposition {
312 gates: vec![],
313 cnot_count: 0,
314 single_qubit_count: 0,
315 depth: 0,
316 });
317 }
318
319 let mut cartan_decomposer = OptimizedCartanDecomposer::new();
321 let cartan_decomp = cartan_decomposer.decompose(unitary)?;
322 let gates = cartan_decomposer.base.to_gates(&cartan_decomp, qubit_ids)?;
323
324 let mut cnot_count = 0;
326 let mut single_qubit_count = 0;
327
328 for gate in &gates {
329 match gate.name() {
330 "CNOT" => cnot_count += 1,
331 _ => single_qubit_count += 1,
332 }
333 }
334
335 let depth = gates.len();
336
337 Ok(ShannonDecomposition {
338 gates,
339 cnot_count,
340 single_qubit_count,
341 depth,
342 })
343 }
344
345 fn single_qubit_to_gates(
347 &self,
348 decomp: &SingleQubitDecomposition,
349 qubit: QubitId,
350 ) -> Vec<Box<dyn GateOp>> {
351 let mut gates = Vec::new();
352
353 if decomp.theta1.abs() > self.tolerance {
355 gates.push(Box::new(RotationZ {
356 target: qubit,
357 theta: decomp.theta1,
358 }) as Box<dyn GateOp>);
359 }
360
361 if decomp.phi.abs() > self.tolerance {
363 gates.push(Box::new(RotationY {
364 target: qubit,
365 theta: decomp.phi,
366 }) as Box<dyn GateOp>);
367 }
368
369 if decomp.theta2.abs() > self.tolerance {
371 gates.push(Box::new(RotationZ {
372 target: qubit,
373 theta: decomp.theta2,
374 }) as Box<dyn GateOp>);
375 }
376
377 gates
380 }
381
382 fn is_identity(&self, matrix: &Array2<Complex<f64>>) -> bool {
384 let n = matrix.shape()[0];
385
386 for i in 0..n {
387 for j in 0..n {
388 let expected = if i == j {
389 Complex::new(1.0, 0.0)
390 } else {
391 Complex::new(0.0, 0.0)
392 };
393 if (matrix[[i, j]] - expected).norm() > self.tolerance {
394 return false;
395 }
396 }
397 }
398
399 true
400 }
401}
402
403pub struct OptimizedShannonDecomposer {
405 base: ShannonDecomposer,
406 peephole: bool,
408 commutation: bool,
410}
411
412impl OptimizedShannonDecomposer {
413 pub fn new() -> Self {
415 Self {
416 base: ShannonDecomposer::new(),
417 peephole: true,
418 commutation: true,
419 }
420 }
421
422 pub fn decompose(
424 &mut self,
425 unitary: &Array2<Complex<f64>>,
426 qubit_ids: &[QubitId],
427 ) -> QuantRS2Result<ShannonDecomposition> {
428 let mut decomp = self.base.decompose(unitary, qubit_ids)?;
430
431 if self.peephole {
432 decomp = self.apply_peephole_optimization(decomp)?;
433 }
434
435 if self.commutation {
436 decomp = self.apply_commutation_optimization(decomp)?;
437 }
438
439 Ok(decomp)
440 }
441
442 fn apply_peephole_optimization(
444 &self,
445 mut decomp: ShannonDecomposition,
446 ) -> QuantRS2Result<ShannonDecomposition> {
447 let mut optimized_gates = Vec::new();
453 let mut i = 0;
454
455 while i < decomp.gates.len() {
456 if i + 1 < decomp.gates.len() {
457 if self.gates_cancel(&decomp.gates[i], &decomp.gates[i + 1]) {
459 i += 2;
461 decomp.cnot_count =
462 decomp
463 .cnot_count
464 .saturating_sub(if decomp.gates[i - 2].name() == "CNOT" {
465 2
466 } else {
467 0
468 });
469 decomp.single_qubit_count = decomp.single_qubit_count.saturating_sub(
470 if decomp.gates[i - 2].name() == "CNOT" {
471 0
472 } else {
473 2
474 },
475 );
476 continue;
477 }
478
479 if let Some(merged) =
481 self.try_merge_rotations(&decomp.gates[i], &decomp.gates[i + 1])
482 {
483 optimized_gates.push(merged);
484 i += 2;
485 decomp.single_qubit_count = decomp.single_qubit_count.saturating_sub(1);
486 continue;
487 }
488 }
489
490 optimized_gates.push(decomp.gates[i].clone());
491 i += 1;
492 }
493
494 decomp.gates = optimized_gates;
495 decomp.depth = decomp.gates.len();
496
497 Ok(decomp)
498 }
499
500 const fn apply_commutation_optimization(
502 &self,
503 decomp: ShannonDecomposition,
504 ) -> QuantRS2Result<ShannonDecomposition> {
505 Ok(decomp)
510 }
511
512 fn gates_cancel(&self, gate1: &Box<dyn GateOp>, gate2: &Box<dyn GateOp>) -> bool {
514 if gate1.name() == gate2.name() && gate1.qubits() == gate2.qubits() {
516 match gate1.name() {
517 "X" | "Y" | "Z" | "H" | "CNOT" | "SWAP" => true,
518 _ => false,
519 }
520 } else {
521 false
522 }
523 }
524
525 fn try_merge_rotations(
527 &self,
528 gate1: &Box<dyn GateOp>,
529 gate2: &Box<dyn GateOp>,
530 ) -> Option<Box<dyn GateOp>> {
531 if gate1.qubits() != gate2.qubits() || gate1.qubits().len() != 1 {
533 return None;
534 }
535
536 let qubit = gate1.qubits()[0];
537
538 match (gate1.name(), gate2.name()) {
539 ("RZ", "RZ") => {
540 let theta1 = PI / 4.0; let theta2 = PI / 4.0; Some(Box::new(RotationZ {
546 target: qubit,
547 theta: theta1 + theta2,
548 }))
549 }
550 _ => None,
551 }
552 }
553}
554
555pub fn shannon_decompose(
557 unitary: &Array2<Complex<f64>>,
558 qubit_ids: &[QubitId],
559) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
560 let mut decomposer = ShannonDecomposer::new();
561 let decomp = decomposer.decompose(unitary, qubit_ids)?;
562 Ok(decomp.gates)
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use scirs2_core::Complex;

    /// Complex number with a zero imaginary part.
    fn real(x: f64) -> Complex<f64> {
        Complex::new(x, 0.0)
    }

    /// Hadamard is a single-qubit gate: ZYZ needs at most three rotations and
    /// no CNOTs.
    #[test]
    fn test_shannon_single_qubit() {
        let sqrt2 = real(2.0_f64.sqrt());
        let hadamard = Array2::from_shape_vec(
            (2, 2),
            vec![real(1.0), real(1.0), real(1.0), real(-1.0)],
        )
        .expect("Failed to create Hadamard matrix")
            / sqrt2;

        let decomp = ShannonDecomposer::new()
            .decompose(&hadamard, &[QubitId(0)])
            .expect("Failed to decompose Hadamard gate");

        assert!(decomp.single_qubit_count <= 3);
        assert_eq!(decomp.cnot_count, 0);
    }

    /// CNOT must not need more than the generic two-qubit bound of 3 CNOTs.
    #[test]
    fn test_shannon_two_qubit() {
        // Row r has its single 1 in column permutation[r]: |10> <-> |11>.
        let permutation = [0usize, 1, 3, 2];
        let entries: Vec<Complex<f64>> = (0..16)
            .map(|k| {
                if permutation[k / 4] == k % 4 {
                    real(1.0)
                } else {
                    real(0.0)
                }
            })
            .collect();
        let cnot =
            Array2::from_shape_vec((4, 4), entries).expect("Failed to create CNOT matrix");

        let decomp = ShannonDecomposer::new()
            .decompose(&cnot, &[QubitId(0), QubitId(1)])
            .expect("Failed to decompose CNOT gate");

        assert!(decomp.cnot_count <= 3);
    }

    /// The identity decomposes to an empty circuit.
    #[test]
    fn test_optimized_decomposer() {
        let identity = Array2::<f64>::eye(4).mapv(real);

        let decomp = OptimizedShannonDecomposer::new()
            .decompose(&identity, &[QubitId(0), QubitId(1)])
            .expect("Failed to decompose identity matrix");

        assert!(decomp.gates.is_empty());
    }
}