1use scirs2_core::parallel_ops::{
9 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
10};
11use scirs2_core::Complex64;
12use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};
13
14use quantrs2_core::{
15 error::{QuantRS2Error, QuantRS2Result},
16 gate::GateOp,
17 qubit::QubitId,
18};
19
20pub trait SpecializedGate: GateOp {
22 fn apply_specialized(
24 &self,
25 state: &mut [Complex64],
26 n_qubits: usize,
27 parallel: bool,
28 ) -> QuantRS2Result<()>;
29
30 fn can_fuse_with(&self, other: &dyn SpecializedGate) -> bool {
32 false
33 }
34
35 fn fuse_with(&self, other: &dyn SpecializedGate) -> Option<Box<dyn SpecializedGate>> {
37 None
38 }
39}
40
41#[derive(Debug, Clone, Copy)]
45pub struct HadamardSpecialized {
46 pub target: QubitId,
47}
48
49impl SpecializedGate for HadamardSpecialized {
50 fn apply_specialized(
51 &self,
52 state: &mut [Complex64],
53 n_qubits: usize,
54 parallel: bool,
55 ) -> QuantRS2Result<()> {
56 let target_idx = self.target.id() as usize;
57 if target_idx >= n_qubits {
58 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
59 }
60
61 let sqrt2_inv = 1.0 / std::f64::consts::SQRT_2;
62
63 if parallel {
64 let state_copy = state.to_vec();
65 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
66 let bit_val = (idx >> target_idx) & 1;
67 let paired_idx = idx ^ (1 << target_idx);
68
69 let val0 = if bit_val == 0 {
70 state_copy[idx]
71 } else {
72 state_copy[paired_idx]
73 };
74 let val1 = if bit_val == 0 {
75 state_copy[paired_idx]
76 } else {
77 state_copy[idx]
78 };
79
80 *amp = sqrt2_inv
81 * if bit_val == 0 {
82 val0 + val1
83 } else {
84 val0 - val1
85 };
86 });
87 } else {
88 for i in 0..(1 << n_qubits) {
89 if (i >> target_idx) & 1 == 0 {
90 let j = i | (1 << target_idx);
91 let temp0 = state[i];
92 let temp1 = state[j];
93 state[i] = sqrt2_inv * (temp0 + temp1);
94 state[j] = sqrt2_inv * (temp0 - temp1);
95 }
96 }
97 }
98
99 Ok(())
100 }
101}
102
103#[derive(Debug, Clone, Copy)]
105pub struct PauliXSpecialized {
106 pub target: QubitId,
107}
108
109impl SpecializedGate for PauliXSpecialized {
110 fn apply_specialized(
111 &self,
112 state: &mut [Complex64],
113 n_qubits: usize,
114 parallel: bool,
115 ) -> QuantRS2Result<()> {
116 let target_idx = self.target.id() as usize;
117 if target_idx >= n_qubits {
118 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
119 }
120
121 if parallel {
122 let state_copy = state.to_vec();
123 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
124 let flipped_idx = idx ^ (1 << target_idx);
125 *amp = state_copy[flipped_idx];
126 });
127 } else {
128 for i in 0..(1 << n_qubits) {
129 if (i >> target_idx) & 1 == 0 {
130 let j = i | (1 << target_idx);
131 state.swap(i, j);
132 }
133 }
134 }
135
136 Ok(())
137 }
138}
139
140#[derive(Debug, Clone, Copy)]
142pub struct PauliYSpecialized {
143 pub target: QubitId,
144}
145
146impl SpecializedGate for PauliYSpecialized {
147 fn apply_specialized(
148 &self,
149 state: &mut [Complex64],
150 n_qubits: usize,
151 parallel: bool,
152 ) -> QuantRS2Result<()> {
153 let target_idx = self.target.id() as usize;
154 if target_idx >= n_qubits {
155 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
156 }
157
158 let i_unit = Complex64::new(0.0, 1.0);
159
160 if parallel {
161 let state_copy = state.to_vec();
162 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
163 let bit_val = (idx >> target_idx) & 1;
164 let flipped_idx = idx ^ (1 << target_idx);
165 *amp = if bit_val == 0 {
166 i_unit * state_copy[flipped_idx]
167 } else {
168 -i_unit * state_copy[flipped_idx]
169 };
170 });
171 } else {
172 for i in 0..(1 << n_qubits) {
173 if (i >> target_idx) & 1 == 0 {
174 let j = i | (1 << target_idx);
175 let temp0 = state[i];
176 let temp1 = state[j];
177 state[i] = i_unit * temp1;
178 state[j] = -i_unit * temp0;
179 }
180 }
181 }
182
183 Ok(())
184 }
185}
186
187#[derive(Debug, Clone, Copy)]
189pub struct PauliZSpecialized {
190 pub target: QubitId,
191}
192
193impl SpecializedGate for PauliZSpecialized {
194 fn apply_specialized(
195 &self,
196 state: &mut [Complex64],
197 n_qubits: usize,
198 parallel: bool,
199 ) -> QuantRS2Result<()> {
200 let target_idx = self.target.id() as usize;
201 if target_idx >= n_qubits {
202 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
203 }
204
205 if parallel {
206 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
207 if (idx >> target_idx) & 1 == 1 {
208 *amp = -*amp;
209 }
210 });
211 } else {
212 for i in 0..(1 << n_qubits) {
213 if (i >> target_idx) & 1 == 1 {
214 state[i] = -state[i];
215 }
216 }
217 }
218
219 Ok(())
220 }
221}
222
223#[derive(Debug, Clone, Copy)]
225pub struct PhaseSpecialized {
226 pub target: QubitId,
227 pub phase: f64,
228}
229
230impl SpecializedGate for PhaseSpecialized {
231 fn apply_specialized(
232 &self,
233 state: &mut [Complex64],
234 n_qubits: usize,
235 parallel: bool,
236 ) -> QuantRS2Result<()> {
237 let target_idx = self.target.id() as usize;
238 if target_idx >= n_qubits {
239 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
240 }
241
242 let phase_factor = Complex64::from_polar(1.0, self.phase);
243
244 if parallel {
245 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
246 if (idx >> target_idx) & 1 == 1 {
247 *amp *= phase_factor;
248 }
249 });
250 } else {
251 for i in 0..(1 << n_qubits) {
252 if (i >> target_idx) & 1 == 1 {
253 state[i] *= phase_factor;
254 }
255 }
256 }
257
258 Ok(())
259 }
260}
261
262#[derive(Debug, Clone, Copy)]
264pub struct SGateSpecialized {
265 pub target: QubitId,
266}
267
268impl SpecializedGate for SGateSpecialized {
269 fn apply_specialized(
270 &self,
271 state: &mut [Complex64],
272 n_qubits: usize,
273 parallel: bool,
274 ) -> QuantRS2Result<()> {
275 let phase_gate = PhaseSpecialized {
276 target: self.target,
277 phase: FRAC_PI_2,
278 };
279 phase_gate.apply_specialized(state, n_qubits, parallel)
280 }
281}
282
283#[derive(Debug, Clone, Copy)]
285pub struct TGateSpecialized {
286 pub target: QubitId,
287}
288
289impl SpecializedGate for TGateSpecialized {
290 fn apply_specialized(
291 &self,
292 state: &mut [Complex64],
293 n_qubits: usize,
294 parallel: bool,
295 ) -> QuantRS2Result<()> {
296 let phase_gate = PhaseSpecialized {
297 target: self.target,
298 phase: FRAC_PI_4,
299 };
300 phase_gate.apply_specialized(state, n_qubits, parallel)
301 }
302}
303
304#[derive(Debug, Clone, Copy)]
306pub struct RXSpecialized {
307 pub target: QubitId,
308 pub theta: f64,
309}
310
311impl SpecializedGate for RXSpecialized {
312 fn apply_specialized(
313 &self,
314 state: &mut [Complex64],
315 n_qubits: usize,
316 parallel: bool,
317 ) -> QuantRS2Result<()> {
318 let target_idx = self.target.id() as usize;
319 if target_idx >= n_qubits {
320 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
321 }
322
323 let cos_half = (self.theta / 2.0).cos();
324 let sin_half = (self.theta / 2.0).sin();
325 let i_sin = Complex64::new(0.0, -sin_half);
326
327 if parallel {
328 let state_copy = state.to_vec();
329 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
330 let bit_val = (idx >> target_idx) & 1;
331 let paired_idx = idx ^ (1 << target_idx);
332
333 let val0 = if bit_val == 0 {
334 state_copy[idx]
335 } else {
336 state_copy[paired_idx]
337 };
338 let val1 = if bit_val == 0 {
339 state_copy[paired_idx]
340 } else {
341 state_copy[idx]
342 };
343
344 *amp = if bit_val == 0 {
345 cos_half * val0 + i_sin * val1
346 } else {
347 i_sin * val0 + cos_half * val1
348 };
349 });
350 } else {
351 for i in 0..(1 << n_qubits) {
352 if (i >> target_idx) & 1 == 0 {
353 let j = i | (1 << target_idx);
354 let temp0 = state[i];
355 let temp1 = state[j];
356 state[i] = cos_half * temp0 + i_sin * temp1;
357 state[j] = i_sin * temp0 + cos_half * temp1;
358 }
359 }
360 }
361
362 Ok(())
363 }
364}
365
366#[derive(Debug, Clone, Copy)]
368pub struct RYSpecialized {
369 pub target: QubitId,
370 pub theta: f64,
371}
372
373impl SpecializedGate for RYSpecialized {
374 fn apply_specialized(
375 &self,
376 state: &mut [Complex64],
377 n_qubits: usize,
378 parallel: bool,
379 ) -> QuantRS2Result<()> {
380 let target_idx = self.target.id() as usize;
381 if target_idx >= n_qubits {
382 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
383 }
384
385 let cos_half = (self.theta / 2.0).cos();
386 let sin_half = (self.theta / 2.0).sin();
387
388 if parallel {
389 let state_copy = state.to_vec();
390 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
391 let bit_val = (idx >> target_idx) & 1;
392 let paired_idx = idx ^ (1 << target_idx);
393
394 let val0 = if bit_val == 0 {
395 state_copy[idx]
396 } else {
397 state_copy[paired_idx]
398 };
399 let val1 = if bit_val == 0 {
400 state_copy[paired_idx]
401 } else {
402 state_copy[idx]
403 };
404
405 *amp = if bit_val == 0 {
406 cos_half * val0 - sin_half * val1
407 } else {
408 sin_half * val0 + cos_half * val1
409 };
410 });
411 } else {
412 for i in 0..(1 << n_qubits) {
413 if (i >> target_idx) & 1 == 0 {
414 let j = i | (1 << target_idx);
415 let temp0 = state[i];
416 let temp1 = state[j];
417 state[i] = cos_half * temp0 - sin_half * temp1;
418 state[j] = sin_half * temp0 + cos_half * temp1;
419 }
420 }
421 }
422
423 Ok(())
424 }
425}
426
427#[derive(Debug, Clone, Copy)]
429pub struct RZSpecialized {
430 pub target: QubitId,
431 pub theta: f64,
432}
433
434impl SpecializedGate for RZSpecialized {
435 fn apply_specialized(
436 &self,
437 state: &mut [Complex64],
438 n_qubits: usize,
439 parallel: bool,
440 ) -> QuantRS2Result<()> {
441 let target_idx = self.target.id() as usize;
442 if target_idx >= n_qubits {
443 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
444 }
445
446 let phase_0 = Complex64::from_polar(1.0, -self.theta / 2.0);
447 let phase_1 = Complex64::from_polar(1.0, self.theta / 2.0);
448
449 if parallel {
450 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
451 if (idx >> target_idx) & 1 == 0 {
452 *amp *= phase_0;
453 } else {
454 *amp *= phase_1;
455 }
456 });
457 } else {
458 for i in 0..(1 << n_qubits) {
459 if (i >> target_idx) & 1 == 0 {
460 state[i] *= phase_0;
461 } else {
462 state[i] *= phase_1;
463 }
464 }
465 }
466
467 Ok(())
468 }
469}
470
471#[derive(Debug, Clone, Copy)]
475pub struct CNOTSpecialized {
476 pub control: QubitId,
477 pub target: QubitId,
478}
479
480impl SpecializedGate for CNOTSpecialized {
481 fn apply_specialized(
482 &self,
483 state: &mut [Complex64],
484 n_qubits: usize,
485 parallel: bool,
486 ) -> QuantRS2Result<()> {
487 let control_idx = self.control.id() as usize;
488 let target_idx = self.target.id() as usize;
489
490 if control_idx >= n_qubits || target_idx >= n_qubits {
491 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
492 self.control.id()
493 } else {
494 self.target.id()
495 }));
496 }
497
498 if control_idx == target_idx {
499 return Err(QuantRS2Error::CircuitValidationFailed(
500 "Control and target qubits must be different".into(),
501 ));
502 }
503
504 if parallel {
505 let state_copy = state.to_vec();
506 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
507 if (idx >> control_idx) & 1 == 1 {
508 let flipped_idx = idx ^ (1 << target_idx);
509 *amp = state_copy[flipped_idx];
510 }
511 });
512 } else {
513 for i in 0..(1 << n_qubits) {
514 if (i >> control_idx) & 1 == 1 && (i >> target_idx) & 1 == 0 {
515 let j = i | (1 << target_idx);
516 state.swap(i, j);
517 }
518 }
519 }
520
521 Ok(())
522 }
523
524 fn can_fuse_with(&self, other: &dyn SpecializedGate) -> bool {
525 if let Some(other_cnot) = other.as_any().downcast_ref::<Self>() {
527 self.control == other_cnot.control && self.target == other_cnot.target
528 } else {
529 false
530 }
531 }
532}
533
534#[derive(Debug, Clone, Copy)]
536pub struct CZSpecialized {
537 pub control: QubitId,
538 pub target: QubitId,
539}
540
541impl SpecializedGate for CZSpecialized {
542 fn apply_specialized(
543 &self,
544 state: &mut [Complex64],
545 n_qubits: usize,
546 parallel: bool,
547 ) -> QuantRS2Result<()> {
548 let control_idx = self.control.id() as usize;
549 let target_idx = self.target.id() as usize;
550
551 if control_idx >= n_qubits || target_idx >= n_qubits {
552 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
553 self.control.id()
554 } else {
555 self.target.id()
556 }));
557 }
558
559 if control_idx == target_idx {
560 return Err(QuantRS2Error::CircuitValidationFailed(
561 "Control and target qubits must be different".into(),
562 ));
563 }
564
565 if parallel {
566 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
567 if (idx >> control_idx) & 1 == 1 && (idx >> target_idx) & 1 == 1 {
568 *amp = -*amp;
569 }
570 });
571 } else {
572 for i in 0..(1 << n_qubits) {
573 if (i >> control_idx) & 1 == 1 && (i >> target_idx) & 1 == 1 {
574 state[i] = -state[i];
575 }
576 }
577 }
578
579 Ok(())
580 }
581}
582
583#[derive(Debug, Clone, Copy)]
585pub struct SWAPSpecialized {
586 pub qubit1: QubitId,
587 pub qubit2: QubitId,
588}
589
590impl SpecializedGate for SWAPSpecialized {
591 fn apply_specialized(
592 &self,
593 state: &mut [Complex64],
594 n_qubits: usize,
595 parallel: bool,
596 ) -> QuantRS2Result<()> {
597 let idx1 = self.qubit1.id() as usize;
598 let idx2 = self.qubit2.id() as usize;
599
600 if idx1 >= n_qubits || idx2 >= n_qubits {
601 return Err(QuantRS2Error::InvalidQubitId(if idx1 >= n_qubits {
602 self.qubit1.id()
603 } else {
604 self.qubit2.id()
605 }));
606 }
607
608 if idx1 == idx2 {
609 return Ok(()); }
611
612 if parallel {
613 let state_copy = state.to_vec();
614 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
615 let bit1 = (idx >> idx1) & 1;
616 let bit2 = (idx >> idx2) & 1;
617
618 if bit1 != bit2 {
619 let swapped_idx =
620 (idx & !(1 << idx1) & !(1 << idx2)) | (bit2 << idx1) | (bit1 << idx2);
621 *amp = state_copy[swapped_idx];
622 }
623 });
624 } else {
625 for i in 0..(1 << n_qubits) {
626 let bit1 = (i >> idx1) & 1;
627 let bit2 = (i >> idx2) & 1;
628
629 if bit1 == 0 && bit2 == 1 {
630 let j = (i | (1 << idx1)) & !(1 << idx2);
631 state.swap(i, j);
632 }
633 }
634 }
635
636 Ok(())
637 }
638}
639
640#[derive(Debug, Clone, Copy)]
642pub struct CPhaseSpecialized {
643 pub control: QubitId,
644 pub target: QubitId,
645 pub phase: f64,
646}
647
648impl SpecializedGate for CPhaseSpecialized {
649 fn apply_specialized(
650 &self,
651 state: &mut [Complex64],
652 n_qubits: usize,
653 parallel: bool,
654 ) -> QuantRS2Result<()> {
655 let control_idx = self.control.id() as usize;
656 let target_idx = self.target.id() as usize;
657
658 if control_idx >= n_qubits || target_idx >= n_qubits {
659 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
660 self.control.id()
661 } else {
662 self.target.id()
663 }));
664 }
665
666 if control_idx == target_idx {
667 return Err(QuantRS2Error::CircuitValidationFailed(
668 "Control and target qubits must be different".into(),
669 ));
670 }
671
672 let phase_factor = Complex64::from_polar(1.0, self.phase);
673
674 if parallel {
675 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
676 if (idx >> control_idx) & 1 == 1 && (idx >> target_idx) & 1 == 1 {
677 *amp *= phase_factor;
678 }
679 });
680 } else {
681 for i in 0..(1 << n_qubits) {
682 if (i >> control_idx) & 1 == 1 && (i >> target_idx) & 1 == 1 {
683 state[i] *= phase_factor;
684 }
685 }
686 }
687
688 Ok(())
689 }
690}
691
692#[derive(Debug, Clone, Copy)]
696pub struct ToffoliSpecialized {
697 pub control1: QubitId,
698 pub control2: QubitId,
699 pub target: QubitId,
700}
701
702impl SpecializedGate for ToffoliSpecialized {
703 fn apply_specialized(
704 &self,
705 state: &mut [Complex64],
706 n_qubits: usize,
707 parallel: bool,
708 ) -> QuantRS2Result<()> {
709 let ctrl1_idx = self.control1.id() as usize;
710 let ctrl2_idx = self.control2.id() as usize;
711 let target_idx = self.target.id() as usize;
712
713 if ctrl1_idx >= n_qubits || ctrl2_idx >= n_qubits || target_idx >= n_qubits {
714 return Err(QuantRS2Error::InvalidQubitId(if ctrl1_idx >= n_qubits {
715 self.control1.id()
716 } else if ctrl2_idx >= n_qubits {
717 self.control2.id()
718 } else {
719 self.target.id()
720 }));
721 }
722
723 if ctrl1_idx == ctrl2_idx || ctrl1_idx == target_idx || ctrl2_idx == target_idx {
724 return Err(QuantRS2Error::CircuitValidationFailed(
725 "All qubits must be different".into(),
726 ));
727 }
728
729 if parallel {
730 let state_copy = state.to_vec();
731 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
732 if (idx >> ctrl1_idx) & 1 == 1 && (idx >> ctrl2_idx) & 1 == 1 {
733 let flipped_idx = idx ^ (1 << target_idx);
734 *amp = state_copy[flipped_idx];
735 }
736 });
737 } else {
738 for i in 0..(1 << n_qubits) {
739 if (i >> ctrl1_idx) & 1 == 1
740 && (i >> ctrl2_idx) & 1 == 1
741 && (i >> target_idx) & 1 == 0
742 {
743 let j = i | (1 << target_idx);
744 state.swap(i, j);
745 }
746 }
747 }
748
749 Ok(())
750 }
751}
752
753#[derive(Debug, Clone, Copy)]
755pub struct FredkinSpecialized {
756 pub control: QubitId,
757 pub target1: QubitId,
758 pub target2: QubitId,
759}
760
761impl SpecializedGate for FredkinSpecialized {
762 fn apply_specialized(
763 &self,
764 state: &mut [Complex64],
765 n_qubits: usize,
766 parallel: bool,
767 ) -> QuantRS2Result<()> {
768 let ctrl_idx = self.control.id() as usize;
769 let tgt1_idx = self.target1.id() as usize;
770 let tgt2_idx = self.target2.id() as usize;
771
772 if ctrl_idx >= n_qubits || tgt1_idx >= n_qubits || tgt2_idx >= n_qubits {
773 return Err(QuantRS2Error::InvalidQubitId(if ctrl_idx >= n_qubits {
774 self.control.id()
775 } else if tgt1_idx >= n_qubits {
776 self.target1.id()
777 } else {
778 self.target2.id()
779 }));
780 }
781
782 if ctrl_idx == tgt1_idx || ctrl_idx == tgt2_idx || tgt1_idx == tgt2_idx {
783 return Err(QuantRS2Error::CircuitValidationFailed(
784 "All qubits must be different".into(),
785 ));
786 }
787
788 if parallel {
789 let state_copy = state.to_vec();
790 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
791 if (idx >> ctrl_idx) & 1 == 1 {
792 let bit1 = (idx >> tgt1_idx) & 1;
793 let bit2 = (idx >> tgt2_idx) & 1;
794
795 if bit1 != bit2 {
796 let swapped_idx = (idx & !(1 << tgt1_idx) & !(1 << tgt2_idx))
797 | (bit2 << tgt1_idx)
798 | (bit1 << tgt2_idx);
799 *amp = state_copy[swapped_idx];
800 }
801 }
802 });
803 } else {
804 for i in 0..(1 << n_qubits) {
805 if (i >> ctrl_idx) & 1 == 1 {
806 let bit1 = (i >> tgt1_idx) & 1;
807 let bit2 = (i >> tgt2_idx) & 1;
808
809 if bit1 == 0 && bit2 == 1 {
810 let j = (i | (1 << tgt1_idx)) & !(1 << tgt2_idx);
811 state.swap(i, j);
812 }
813 }
814 }
815 }
816
817 Ok(())
818 }
819}
820
821pub fn specialize_gate(gate: &dyn GateOp) -> Option<Box<dyn SpecializedGate>> {
825 use quantrs2_core::gate::{
826 multi::{CNOT, CZ, SWAP},
827 single::{Hadamard, PauliX, PauliY, PauliZ, Phase, RotationX, RotationY, RotationZ, T},
828 };
829 use std::any::Any;
830
831 if let Some(h) = gate.as_any().downcast_ref::<Hadamard>() {
833 return Some(Box::new(HadamardSpecialized { target: h.target }));
834 }
835 if let Some(x) = gate.as_any().downcast_ref::<PauliX>() {
836 return Some(Box::new(PauliXSpecialized { target: x.target }));
837 }
838 if let Some(y) = gate.as_any().downcast_ref::<PauliY>() {
839 return Some(Box::new(PauliYSpecialized { target: y.target }));
840 }
841 if let Some(z) = gate.as_any().downcast_ref::<PauliZ>() {
842 return Some(Box::new(PauliZSpecialized { target: z.target }));
843 }
844 if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
845 return Some(Box::new(RXSpecialized {
846 target: rx.target,
847 theta: rx.theta,
848 }));
849 }
850 if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
851 return Some(Box::new(RYSpecialized {
852 target: ry.target,
853 theta: ry.theta,
854 }));
855 }
856 if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
857 return Some(Box::new(RZSpecialized {
858 target: rz.target,
859 theta: rz.theta,
860 }));
861 }
862 if let Some(s) = gate.as_any().downcast_ref::<Phase>() {
863 return Some(Box::new(SGateSpecialized { target: s.target }));
864 }
865 if let Some(t) = gate.as_any().downcast_ref::<T>() {
866 return Some(Box::new(TGateSpecialized { target: t.target }));
867 }
868
869 if let Some(cnot) = gate.as_any().downcast_ref::<CNOT>() {
871 return Some(Box::new(CNOTSpecialized {
872 control: cnot.control,
873 target: cnot.target,
874 }));
875 }
876 if let Some(cz) = gate.as_any().downcast_ref::<CZ>() {
877 return Some(Box::new(CZSpecialized {
878 control: cz.control,
879 target: cz.target,
880 }));
881 }
882 if let Some(swap) = gate.as_any().downcast_ref::<SWAP>() {
883 return Some(Box::new(SWAPSpecialized {
884 qubit1: swap.qubit1,
885 qubit2: swap.qubit2,
886 }));
887 }
888
889 None
890}
891
/// Implements `GateOp` for a specialized gate type.
///
/// Arguments, in order:
/// * the gate type;
/// * the `&'static str` returned by `name`;
/// * a callable taking `&Self` and returning the gate's qubits as `Vec<QubitId>`;
/// * a callable taking `&Self` and returning its matrix as
///   `QuantRS2Result<Vec<Complex64>>`.
///
/// `as_any` returns `self`, and `clone_gate` boxes a clone, so the type must
/// implement `Clone`.
macro_rules! impl_gate_op_for_specialized {
    ($ty:ty, $label:expr, $qubits_of:expr, $matrix_of:expr) => {
        impl GateOp for $ty {
            fn name(&self) -> &'static str {
                $label
            }

            fn qubits(&self) -> Vec<QubitId> {
                ($qubits_of)(self)
            }

            fn matrix(&self) -> QuantRS2Result<Vec<Complex64>> {
                ($matrix_of)(self)
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn clone_gate(&self) -> Box<dyn GateOp> {
                Box::new(self.clone())
            }
        }
    };
}
919
920impl_gate_op_for_specialized!(
922 HadamardSpecialized,
923 "H",
924 |g: &HadamardSpecialized| vec![g.target],
925 |_: &HadamardSpecialized| {
926 let sqrt2_inv = 1.0 / std::f64::consts::SQRT_2;
927 Ok(vec![
928 Complex64::new(sqrt2_inv, 0.0),
929 Complex64::new(sqrt2_inv, 0.0),
930 Complex64::new(sqrt2_inv, 0.0),
931 Complex64::new(-sqrt2_inv, 0.0),
932 ])
933 }
934);
935
936impl_gate_op_for_specialized!(
937 PauliXSpecialized,
938 "X",
939 |g: &PauliXSpecialized| vec![g.target],
940 |_: &PauliXSpecialized| Ok(vec![
941 Complex64::new(0.0, 0.0),
942 Complex64::new(1.0, 0.0),
943 Complex64::new(1.0, 0.0),
944 Complex64::new(0.0, 0.0),
945 ])
946);
947
948impl_gate_op_for_specialized!(
949 PauliYSpecialized,
950 "Y",
951 |g: &PauliYSpecialized| vec![g.target],
952 |_: &PauliYSpecialized| Ok(vec![
953 Complex64::new(0.0, 0.0),
954 Complex64::new(0.0, -1.0),
955 Complex64::new(0.0, 1.0),
956 Complex64::new(0.0, 0.0),
957 ])
958);
959
960impl_gate_op_for_specialized!(
961 PauliZSpecialized,
962 "Z",
963 |g: &PauliZSpecialized| vec![g.target],
964 |_: &PauliZSpecialized| Ok(vec![
965 Complex64::new(1.0, 0.0),
966 Complex64::new(0.0, 0.0),
967 Complex64::new(0.0, 0.0),
968 Complex64::new(-1.0, 0.0),
969 ])
970);
971
972impl_gate_op_for_specialized!(
974 CNOTSpecialized,
975 "CNOT",
976 |g: &CNOTSpecialized| vec![g.control, g.target],
977 |_: &CNOTSpecialized| Ok(vec![
978 Complex64::new(1.0, 0.0),
979 Complex64::new(0.0, 0.0),
980 Complex64::new(0.0, 0.0),
981 Complex64::new(0.0, 0.0),
982 Complex64::new(0.0, 0.0),
983 Complex64::new(1.0, 0.0),
984 Complex64::new(0.0, 0.0),
985 Complex64::new(0.0, 0.0),
986 Complex64::new(0.0, 0.0),
987 Complex64::new(0.0, 0.0),
988 Complex64::new(0.0, 0.0),
989 Complex64::new(1.0, 0.0),
990 Complex64::new(0.0, 0.0),
991 Complex64::new(0.0, 0.0),
992 Complex64::new(1.0, 0.0),
993 Complex64::new(0.0, 0.0),
994 ])
995);
996
997impl_gate_op_for_specialized!(
999 PhaseSpecialized,
1000 "Phase",
1001 |g: &PhaseSpecialized| vec![g.target],
1002 |g: &PhaseSpecialized| Ok(vec![
1003 Complex64::new(1.0, 0.0),
1004 Complex64::new(0.0, 0.0),
1005 Complex64::new(0.0, 0.0),
1006 Complex64::from_polar(1.0, g.phase),
1007 ])
1008);
1009
1010impl_gate_op_for_specialized!(
1011 SGateSpecialized,
1012 "S",
1013 |g: &SGateSpecialized| vec![g.target],
1014 |_: &SGateSpecialized| Ok(vec![
1015 Complex64::new(1.0, 0.0),
1016 Complex64::new(0.0, 0.0),
1017 Complex64::new(0.0, 0.0),
1018 Complex64::new(0.0, 1.0),
1019 ])
1020);
1021
1022impl_gate_op_for_specialized!(
1023 TGateSpecialized,
1024 "T",
1025 |g: &TGateSpecialized| vec![g.target],
1026 |_: &TGateSpecialized| {
1027 let phase = Complex64::from_polar(1.0, PI / 4.0);
1028 Ok(vec![
1029 Complex64::new(1.0, 0.0),
1030 Complex64::new(0.0, 0.0),
1031 Complex64::new(0.0, 0.0),
1032 phase,
1033 ])
1034 }
1035);
1036
1037impl_gate_op_for_specialized!(
1038 RXSpecialized,
1039 "RX",
1040 |g: &RXSpecialized| vec![g.target],
1041 |g: &RXSpecialized| {
1042 let cos = (g.theta / 2.0).cos();
1043 let sin = (g.theta / 2.0).sin();
1044 Ok(vec![
1045 Complex64::new(cos, 0.0),
1046 Complex64::new(0.0, -sin),
1047 Complex64::new(0.0, -sin),
1048 Complex64::new(cos, 0.0),
1049 ])
1050 }
1051);
1052
1053impl_gate_op_for_specialized!(
1054 RYSpecialized,
1055 "RY",
1056 |g: &RYSpecialized| vec![g.target],
1057 |g: &RYSpecialized| {
1058 let cos = (g.theta / 2.0).cos();
1059 let sin = (g.theta / 2.0).sin();
1060 Ok(vec![
1061 Complex64::new(cos, 0.0),
1062 Complex64::new(-sin, 0.0),
1063 Complex64::new(sin, 0.0),
1064 Complex64::new(cos, 0.0),
1065 ])
1066 }
1067);
1068
1069impl_gate_op_for_specialized!(
1070 RZSpecialized,
1071 "RZ",
1072 |g: &RZSpecialized| vec![g.target],
1073 |g: &RZSpecialized| {
1074 let phase_pos = Complex64::from_polar(1.0, g.theta / 2.0);
1075 let phase_neg = Complex64::from_polar(1.0, -g.theta / 2.0);
1076 Ok(vec![
1077 phase_neg,
1078 Complex64::new(0.0, 0.0),
1079 Complex64::new(0.0, 0.0),
1080 phase_pos,
1081 ])
1082 }
1083);
1084
1085impl_gate_op_for_specialized!(
1086 CZSpecialized,
1087 "CZ",
1088 |g: &CZSpecialized| vec![g.control, g.target],
1089 |_: &CZSpecialized| Ok(vec![
1090 Complex64::new(1.0, 0.0),
1091 Complex64::new(0.0, 0.0),
1092 Complex64::new(0.0, 0.0),
1093 Complex64::new(0.0, 0.0),
1094 Complex64::new(0.0, 0.0),
1095 Complex64::new(1.0, 0.0),
1096 Complex64::new(0.0, 0.0),
1097 Complex64::new(0.0, 0.0),
1098 Complex64::new(0.0, 0.0),
1099 Complex64::new(0.0, 0.0),
1100 Complex64::new(1.0, 0.0),
1101 Complex64::new(0.0, 0.0),
1102 Complex64::new(0.0, 0.0),
1103 Complex64::new(0.0, 0.0),
1104 Complex64::new(0.0, 0.0),
1105 Complex64::new(-1.0, 0.0),
1106 ])
1107);
1108
1109impl_gate_op_for_specialized!(
1110 SWAPSpecialized,
1111 "SWAP",
1112 |g: &SWAPSpecialized| vec![g.qubit1, g.qubit2],
1113 |_: &SWAPSpecialized| Ok(vec![
1114 Complex64::new(1.0, 0.0),
1115 Complex64::new(0.0, 0.0),
1116 Complex64::new(0.0, 0.0),
1117 Complex64::new(0.0, 0.0),
1118 Complex64::new(0.0, 0.0),
1119 Complex64::new(0.0, 0.0),
1120 Complex64::new(1.0, 0.0),
1121 Complex64::new(0.0, 0.0),
1122 Complex64::new(0.0, 0.0),
1123 Complex64::new(1.0, 0.0),
1124 Complex64::new(0.0, 0.0),
1125 Complex64::new(0.0, 0.0),
1126 Complex64::new(0.0, 0.0),
1127 Complex64::new(0.0, 0.0),
1128 Complex64::new(0.0, 0.0),
1129 Complex64::new(1.0, 0.0),
1130 ])
1131);
1132
1133impl_gate_op_for_specialized!(
1135 CPhaseSpecialized,
1136 "CPhase",
1137 |g: &CPhaseSpecialized| vec![g.control, g.target],
1138 |g: &CPhaseSpecialized| {
1139 let phase = Complex64::from_polar(1.0, g.phase);
1140 Ok(vec![
1141 Complex64::new(1.0, 0.0),
1142 Complex64::new(0.0, 0.0),
1143 Complex64::new(0.0, 0.0),
1144 Complex64::new(0.0, 0.0),
1145 Complex64::new(0.0, 0.0),
1146 Complex64::new(1.0, 0.0),
1147 Complex64::new(0.0, 0.0),
1148 Complex64::new(0.0, 0.0),
1149 Complex64::new(0.0, 0.0),
1150 Complex64::new(0.0, 0.0),
1151 Complex64::new(1.0, 0.0),
1152 Complex64::new(0.0, 0.0),
1153 Complex64::new(0.0, 0.0),
1154 Complex64::new(0.0, 0.0),
1155 Complex64::new(0.0, 0.0),
1156 phase,
1157 ])
1158 }
1159);
1160
1161impl_gate_op_for_specialized!(
1162 ToffoliSpecialized,
1163 "Toffoli",
1164 |g: &ToffoliSpecialized| vec![g.control1, g.control2, g.target],
1165 |_: &ToffoliSpecialized| Ok(vec![
1166 Complex64::new(1.0, 0.0),
1168 Complex64::new(0.0, 0.0),
1169 Complex64::new(0.0, 0.0),
1170 Complex64::new(0.0, 0.0),
1171 Complex64::new(0.0, 0.0),
1172 Complex64::new(0.0, 0.0),
1173 Complex64::new(0.0, 0.0),
1174 Complex64::new(0.0, 0.0),
1175 Complex64::new(0.0, 0.0),
1176 Complex64::new(1.0, 0.0),
1177 Complex64::new(0.0, 0.0),
1178 Complex64::new(0.0, 0.0),
1179 Complex64::new(0.0, 0.0),
1180 Complex64::new(0.0, 0.0),
1181 Complex64::new(0.0, 0.0),
1182 Complex64::new(0.0, 0.0),
1183 Complex64::new(0.0, 0.0),
1184 Complex64::new(0.0, 0.0),
1185 Complex64::new(1.0, 0.0),
1186 Complex64::new(0.0, 0.0),
1187 Complex64::new(0.0, 0.0),
1188 Complex64::new(0.0, 0.0),
1189 Complex64::new(0.0, 0.0),
1190 Complex64::new(0.0, 0.0),
1191 Complex64::new(0.0, 0.0),
1192 Complex64::new(0.0, 0.0),
1193 Complex64::new(0.0, 0.0),
1194 Complex64::new(1.0, 0.0),
1195 Complex64::new(0.0, 0.0),
1196 Complex64::new(0.0, 0.0),
1197 Complex64::new(0.0, 0.0),
1198 Complex64::new(0.0, 0.0),
1199 Complex64::new(0.0, 0.0),
1200 Complex64::new(0.0, 0.0),
1201 Complex64::new(0.0, 0.0),
1202 Complex64::new(0.0, 0.0),
1203 Complex64::new(1.0, 0.0),
1204 Complex64::new(0.0, 0.0),
1205 Complex64::new(0.0, 0.0),
1206 Complex64::new(0.0, 0.0),
1207 Complex64::new(0.0, 0.0),
1208 Complex64::new(0.0, 0.0),
1209 Complex64::new(0.0, 0.0),
1210 Complex64::new(0.0, 0.0),
1211 Complex64::new(0.0, 0.0),
1212 Complex64::new(1.0, 0.0),
1213 Complex64::new(0.0, 0.0),
1214 Complex64::new(0.0, 0.0),
1215 Complex64::new(0.0, 0.0),
1216 Complex64::new(0.0, 0.0),
1217 Complex64::new(0.0, 0.0),
1218 Complex64::new(0.0, 0.0),
1219 Complex64::new(0.0, 0.0),
1220 Complex64::new(0.0, 0.0),
1221 Complex64::new(0.0, 0.0),
1222 Complex64::new(1.0, 0.0),
1223 Complex64::new(0.0, 0.0),
1224 Complex64::new(0.0, 0.0),
1225 Complex64::new(0.0, 0.0),
1226 Complex64::new(0.0, 0.0),
1227 Complex64::new(0.0, 0.0),
1228 Complex64::new(0.0, 0.0),
1229 Complex64::new(1.0, 0.0),
1230 Complex64::new(0.0, 0.0),
1231 ])
1232);
1233
1234impl_gate_op_for_specialized!(
1235 FredkinSpecialized,
1236 "Fredkin",
1237 |g: &FredkinSpecialized| vec![g.control, g.target1, g.target2],
1238 |_: &FredkinSpecialized| Ok(vec![
1239 Complex64::new(1.0, 0.0),
1241 Complex64::new(0.0, 0.0),
1242 Complex64::new(0.0, 0.0),
1243 Complex64::new(0.0, 0.0),
1244 Complex64::new(0.0, 0.0),
1245 Complex64::new(0.0, 0.0),
1246 Complex64::new(0.0, 0.0),
1247 Complex64::new(0.0, 0.0),
1248 Complex64::new(0.0, 0.0),
1249 Complex64::new(1.0, 0.0),
1250 Complex64::new(0.0, 0.0),
1251 Complex64::new(0.0, 0.0),
1252 Complex64::new(0.0, 0.0),
1253 Complex64::new(0.0, 0.0),
1254 Complex64::new(0.0, 0.0),
1255 Complex64::new(0.0, 0.0),
1256 Complex64::new(0.0, 0.0),
1257 Complex64::new(0.0, 0.0),
1258 Complex64::new(1.0, 0.0),
1259 Complex64::new(0.0, 0.0),
1260 Complex64::new(0.0, 0.0),
1261 Complex64::new(0.0, 0.0),
1262 Complex64::new(0.0, 0.0),
1263 Complex64::new(0.0, 0.0),
1264 Complex64::new(0.0, 0.0),
1265 Complex64::new(0.0, 0.0),
1266 Complex64::new(0.0, 0.0),
1267 Complex64::new(1.0, 0.0),
1268 Complex64::new(0.0, 0.0),
1269 Complex64::new(0.0, 0.0),
1270 Complex64::new(0.0, 0.0),
1271 Complex64::new(0.0, 0.0),
1272 Complex64::new(0.0, 0.0),
1273 Complex64::new(0.0, 0.0),
1274 Complex64::new(0.0, 0.0),
1275 Complex64::new(0.0, 0.0),
1276 Complex64::new(1.0, 0.0),
1277 Complex64::new(0.0, 0.0),
1278 Complex64::new(0.0, 0.0),
1279 Complex64::new(0.0, 0.0),
1280 Complex64::new(0.0, 0.0),
1281 Complex64::new(0.0, 0.0),
1282 Complex64::new(0.0, 0.0),
1283 Complex64::new(0.0, 0.0),
1284 Complex64::new(0.0, 0.0),
1285 Complex64::new(0.0, 0.0),
1286 Complex64::new(1.0, 0.0),
1287 Complex64::new(0.0, 0.0),
1288 Complex64::new(0.0, 0.0),
1289 Complex64::new(0.0, 0.0),
1290 Complex64::new(0.0, 0.0),
1291 Complex64::new(0.0, 0.0),
1292 Complex64::new(0.0, 0.0),
1293 Complex64::new(1.0, 0.0),
1294 Complex64::new(0.0, 0.0),
1295 Complex64::new(0.0, 0.0),
1296 Complex64::new(0.0, 0.0),
1297 Complex64::new(0.0, 0.0),
1298 Complex64::new(0.0, 0.0),
1299 Complex64::new(0.0, 0.0),
1300 Complex64::new(0.0, 0.0),
1301 Complex64::new(0.0, 0.0),
1302 Complex64::new(0.0, 0.0),
1303 Complex64::new(1.0, 0.0),
1304 ])
1305);
1306
1307use std::any::Any;
1308
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Absolute tolerance for amplitude comparisons.
    const TOL: f64 = 1e-10;

    /// Asserts that two state vectors agree element-wise within `TOL`.
    fn assert_states_close(actual: &[Complex64], expected: &[Complex64]) {
        assert_eq!(actual.len(), expected.len(), "state lengths differ");
        for (k, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (*a - *e).norm() < TOL,
                "amplitude {} differs: got {:?}, expected {:?}",
                k,
                a,
                e
            );
        }
    }

    /// Deterministic, asymmetric test state with `2^n_qubits` amplitudes.
    /// It is not normalized; the kernels are linear, so that does not matter here.
    fn sample_state(n_qubits: usize) -> Vec<Complex64> {
        (0..(1usize << n_qubits))
            .map(|k| Complex64::new(1.0 + k as f64, 0.25 * k as f64 - 0.5))
            .collect()
    }

    #[test]
    fn test_hadamard_specialized() {
        let mut state = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let gate = HadamardSpecialized { target: QubitId(0) };

        gate.apply_specialized(&mut state, 1, false)
            .expect("Hadamard gate application should succeed");

        let sqrt2_inv = 1.0 / std::f64::consts::SQRT_2;
        assert!((state[0] - Complex64::new(sqrt2_inv, 0.0)).norm() < 1e-10);
        assert!((state[1] - Complex64::new(sqrt2_inv, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_cnot_specialized() {
        let mut state = vec![
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];
        let gate = CNOTSpecialized {
            control: QubitId(0),
            target: QubitId(1),
        };

        gate.apply_specialized(&mut state, 2, false)
            .expect("CNOT gate application should succeed");

        assert!((state[0] - Complex64::new(0.0, 0.0)).norm() < 1e-10);
        assert!((state[1] - Complex64::new(0.0, 0.0)).norm() < 1e-10);
        assert!((state[2] - Complex64::new(0.0, 0.0)).norm() < 1e-10);
        assert!((state[3] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
    }

    /// The parallel and sequential kernels of every gate must produce the same state.
    #[test]
    fn test_parallel_matches_sequential() {
        let n_qubits = 3;
        let gates: Vec<Box<dyn SpecializedGate>> = vec![
            Box::new(HadamardSpecialized { target: QubitId(1) }),
            Box::new(PauliXSpecialized { target: QubitId(2) }),
            Box::new(PauliYSpecialized { target: QubitId(0) }),
            Box::new(PauliZSpecialized { target: QubitId(1) }),
            Box::new(PhaseSpecialized {
                target: QubitId(2),
                phase: 0.3,
            }),
            Box::new(RXSpecialized {
                target: QubitId(0),
                theta: 0.7,
            }),
            Box::new(RYSpecialized {
                target: QubitId(1),
                theta: 1.3,
            }),
            Box::new(RZSpecialized {
                target: QubitId(2),
                theta: -0.4,
            }),
            Box::new(CNOTSpecialized {
                control: QubitId(2),
                target: QubitId(0),
            }),
            Box::new(CZSpecialized {
                control: QubitId(0),
                target: QubitId(1),
            }),
            Box::new(SWAPSpecialized {
                qubit1: QubitId(0),
                qubit2: QubitId(2),
            }),
            Box::new(CPhaseSpecialized {
                control: QubitId(1),
                target: QubitId(2),
                phase: 1.1,
            }),
            Box::new(ToffoliSpecialized {
                control1: QubitId(0),
                control2: QubitId(1),
                target: QubitId(2),
            }),
            Box::new(FredkinSpecialized {
                control: QubitId(1),
                target1: QubitId(0),
                target2: QubitId(2),
            }),
        ];

        for gate in &gates {
            let mut sequential = sample_state(n_qubits);
            let mut parallel = sequential.clone();
            assert!(
                gate.apply_specialized(&mut sequential, n_qubits, false).is_ok(),
                "{} (sequential) failed",
                gate.name()
            );
            assert!(
                gate.apply_specialized(&mut parallel, n_qubits, true).is_ok(),
                "{} (parallel) failed",
                gate.name()
            );
            assert_states_close(&parallel, &sequential);
        }
    }

    /// Out-of-range and duplicated qubits are rejected before the state is touched.
    #[test]
    fn test_invalid_qubits_are_rejected() {
        let mut state = sample_state(2);

        let out_of_range = PauliXSpecialized { target: QubitId(2) };
        assert!(out_of_range.apply_specialized(&mut state, 2, false).is_err());

        let same_qubits = CNOTSpecialized {
            control: QubitId(1),
            target: QubitId(1),
        };
        assert!(same_qubits.apply_specialized(&mut state, 2, true).is_err());

        assert_states_close(&state, &sample_state(2));
    }

    /// Swapping a qubit with itself leaves the state unchanged.
    #[test]
    fn test_swap_same_qubit_is_noop() {
        let mut state = sample_state(2);
        let gate = SWAPSpecialized {
            qubit1: QubitId(1),
            qubit2: QubitId(1),
        };

        gate.apply_specialized(&mut state, 2, false)
            .expect("SWAP of a qubit with itself should succeed");

        assert_states_close(&state, &sample_state(2));
    }
}