1use num_complex::Complex64;
9use scirs2_core::parallel_ops::*;
10use std::f64::consts::{FRAC_PI_2, FRAC_PI_4, PI};
11
12use quantrs2_core::{
13 error::{QuantRS2Error, QuantRS2Result},
14 gate::GateOp,
15 qubit::QubitId,
16};
17
18pub trait SpecializedGate: GateOp {
20 fn apply_specialized(
22 &self,
23 state: &mut [Complex64],
24 n_qubits: usize,
25 parallel: bool,
26 ) -> QuantRS2Result<()>;
27
28 fn can_fuse_with(&self, other: &dyn SpecializedGate) -> bool {
30 false
31 }
32
33 fn fuse_with(&self, other: &dyn SpecializedGate) -> Option<Box<dyn SpecializedGate>> {
35 None
36 }
37}
38
39#[derive(Debug, Clone, Copy)]
43pub struct HadamardSpecialized {
44 pub target: QubitId,
45}
46
47impl SpecializedGate for HadamardSpecialized {
48 fn apply_specialized(
49 &self,
50 state: &mut [Complex64],
51 n_qubits: usize,
52 parallel: bool,
53 ) -> QuantRS2Result<()> {
54 let target_idx = self.target.id() as usize;
55 if target_idx >= n_qubits {
56 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
57 }
58
59 let sqrt2_inv = 1.0 / std::f64::consts::SQRT_2;
60
61 if parallel {
62 let state_copy = state.to_vec();
63 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
64 let bit_val = (idx >> target_idx) & 1;
65 let paired_idx = idx ^ (1 << target_idx);
66
67 let val0 = if bit_val == 0 {
68 state_copy[idx]
69 } else {
70 state_copy[paired_idx]
71 };
72 let val1 = if bit_val == 0 {
73 state_copy[paired_idx]
74 } else {
75 state_copy[idx]
76 };
77
78 *amp = sqrt2_inv
79 * if bit_val == 0 {
80 val0 + val1
81 } else {
82 val0 - val1
83 };
84 });
85 } else {
86 for i in 0..(1 << n_qubits) {
87 if (i >> target_idx) & 1 == 0 {
88 let j = i | (1 << target_idx);
89 let temp0 = state[i];
90 let temp1 = state[j];
91 state[i] = sqrt2_inv * (temp0 + temp1);
92 state[j] = sqrt2_inv * (temp0 - temp1);
93 }
94 }
95 }
96
97 Ok(())
98 }
99}
100
101#[derive(Debug, Clone, Copy)]
103pub struct PauliXSpecialized {
104 pub target: QubitId,
105}
106
107impl SpecializedGate for PauliXSpecialized {
108 fn apply_specialized(
109 &self,
110 state: &mut [Complex64],
111 n_qubits: usize,
112 parallel: bool,
113 ) -> QuantRS2Result<()> {
114 let target_idx = self.target.id() as usize;
115 if target_idx >= n_qubits {
116 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
117 }
118
119 if parallel {
120 let state_copy = state.to_vec();
121 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
122 let flipped_idx = idx ^ (1 << target_idx);
123 *amp = state_copy[flipped_idx];
124 });
125 } else {
126 for i in 0..(1 << n_qubits) {
127 if (i >> target_idx) & 1 == 0 {
128 let j = i | (1 << target_idx);
129 state.swap(i, j);
130 }
131 }
132 }
133
134 Ok(())
135 }
136}
137
138#[derive(Debug, Clone, Copy)]
140pub struct PauliYSpecialized {
141 pub target: QubitId,
142}
143
144impl SpecializedGate for PauliYSpecialized {
145 fn apply_specialized(
146 &self,
147 state: &mut [Complex64],
148 n_qubits: usize,
149 parallel: bool,
150 ) -> QuantRS2Result<()> {
151 let target_idx = self.target.id() as usize;
152 if target_idx >= n_qubits {
153 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
154 }
155
156 let i_unit = Complex64::new(0.0, 1.0);
157
158 if parallel {
159 let state_copy = state.to_vec();
160 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
161 let bit_val = (idx >> target_idx) & 1;
162 let flipped_idx = idx ^ (1 << target_idx);
163 *amp = if bit_val == 0 {
164 i_unit * state_copy[flipped_idx]
165 } else {
166 -i_unit * state_copy[flipped_idx]
167 };
168 });
169 } else {
170 for i in 0..(1 << n_qubits) {
171 if (i >> target_idx) & 1 == 0 {
172 let j = i | (1 << target_idx);
173 let temp0 = state[i];
174 let temp1 = state[j];
175 state[i] = i_unit * temp1;
176 state[j] = -i_unit * temp0;
177 }
178 }
179 }
180
181 Ok(())
182 }
183}
184
185#[derive(Debug, Clone, Copy)]
187pub struct PauliZSpecialized {
188 pub target: QubitId,
189}
190
191impl SpecializedGate for PauliZSpecialized {
192 fn apply_specialized(
193 &self,
194 state: &mut [Complex64],
195 n_qubits: usize,
196 parallel: bool,
197 ) -> QuantRS2Result<()> {
198 let target_idx = self.target.id() as usize;
199 if target_idx >= n_qubits {
200 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
201 }
202
203 if parallel {
204 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
205 if (idx >> target_idx) & 1 == 1 {
206 *amp = -*amp;
207 }
208 });
209 } else {
210 for i in 0..(1 << n_qubits) {
211 if (i >> target_idx) & 1 == 1 {
212 state[i] = -state[i];
213 }
214 }
215 }
216
217 Ok(())
218 }
219}
220
221#[derive(Debug, Clone, Copy)]
223pub struct PhaseSpecialized {
224 pub target: QubitId,
225 pub phase: f64,
226}
227
228impl SpecializedGate for PhaseSpecialized {
229 fn apply_specialized(
230 &self,
231 state: &mut [Complex64],
232 n_qubits: usize,
233 parallel: bool,
234 ) -> QuantRS2Result<()> {
235 let target_idx = self.target.id() as usize;
236 if target_idx >= n_qubits {
237 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
238 }
239
240 let phase_factor = Complex64::from_polar(1.0, self.phase);
241
242 if parallel {
243 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
244 if (idx >> target_idx) & 1 == 1 {
245 *amp *= phase_factor;
246 }
247 });
248 } else {
249 for i in 0..(1 << n_qubits) {
250 if (i >> target_idx) & 1 == 1 {
251 state[i] *= phase_factor;
252 }
253 }
254 }
255
256 Ok(())
257 }
258}
259
260#[derive(Debug, Clone, Copy)]
262pub struct SGateSpecialized {
263 pub target: QubitId,
264}
265
266impl SpecializedGate for SGateSpecialized {
267 fn apply_specialized(
268 &self,
269 state: &mut [Complex64],
270 n_qubits: usize,
271 parallel: bool,
272 ) -> QuantRS2Result<()> {
273 let phase_gate = PhaseSpecialized {
274 target: self.target,
275 phase: FRAC_PI_2,
276 };
277 phase_gate.apply_specialized(state, n_qubits, parallel)
278 }
279}
280
281#[derive(Debug, Clone, Copy)]
283pub struct TGateSpecialized {
284 pub target: QubitId,
285}
286
287impl SpecializedGate for TGateSpecialized {
288 fn apply_specialized(
289 &self,
290 state: &mut [Complex64],
291 n_qubits: usize,
292 parallel: bool,
293 ) -> QuantRS2Result<()> {
294 let phase_gate = PhaseSpecialized {
295 target: self.target,
296 phase: FRAC_PI_4,
297 };
298 phase_gate.apply_specialized(state, n_qubits, parallel)
299 }
300}
301
302#[derive(Debug, Clone, Copy)]
304pub struct RXSpecialized {
305 pub target: QubitId,
306 pub theta: f64,
307}
308
309impl SpecializedGate for RXSpecialized {
310 fn apply_specialized(
311 &self,
312 state: &mut [Complex64],
313 n_qubits: usize,
314 parallel: bool,
315 ) -> QuantRS2Result<()> {
316 let target_idx = self.target.id() as usize;
317 if target_idx >= n_qubits {
318 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
319 }
320
321 let cos_half = (self.theta / 2.0).cos();
322 let sin_half = (self.theta / 2.0).sin();
323 let i_sin = Complex64::new(0.0, -sin_half);
324
325 if parallel {
326 let state_copy = state.to_vec();
327 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
328 let bit_val = (idx >> target_idx) & 1;
329 let paired_idx = idx ^ (1 << target_idx);
330
331 let val0 = if bit_val == 0 {
332 state_copy[idx]
333 } else {
334 state_copy[paired_idx]
335 };
336 let val1 = if bit_val == 0 {
337 state_copy[paired_idx]
338 } else {
339 state_copy[idx]
340 };
341
342 *amp = if bit_val == 0 {
343 cos_half * val0 + i_sin * val1
344 } else {
345 i_sin * val0 + cos_half * val1
346 };
347 });
348 } else {
349 for i in 0..(1 << n_qubits) {
350 if (i >> target_idx) & 1 == 0 {
351 let j = i | (1 << target_idx);
352 let temp0 = state[i];
353 let temp1 = state[j];
354 state[i] = cos_half * temp0 + i_sin * temp1;
355 state[j] = i_sin * temp0 + cos_half * temp1;
356 }
357 }
358 }
359
360 Ok(())
361 }
362}
363
364#[derive(Debug, Clone, Copy)]
366pub struct RYSpecialized {
367 pub target: QubitId,
368 pub theta: f64,
369}
370
371impl SpecializedGate for RYSpecialized {
372 fn apply_specialized(
373 &self,
374 state: &mut [Complex64],
375 n_qubits: usize,
376 parallel: bool,
377 ) -> QuantRS2Result<()> {
378 let target_idx = self.target.id() as usize;
379 if target_idx >= n_qubits {
380 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
381 }
382
383 let cos_half = (self.theta / 2.0).cos();
384 let sin_half = (self.theta / 2.0).sin();
385
386 if parallel {
387 let state_copy = state.to_vec();
388 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
389 let bit_val = (idx >> target_idx) & 1;
390 let paired_idx = idx ^ (1 << target_idx);
391
392 let val0 = if bit_val == 0 {
393 state_copy[idx]
394 } else {
395 state_copy[paired_idx]
396 };
397 let val1 = if bit_val == 0 {
398 state_copy[paired_idx]
399 } else {
400 state_copy[idx]
401 };
402
403 *amp = if bit_val == 0 {
404 cos_half * val0 - sin_half * val1
405 } else {
406 sin_half * val0 + cos_half * val1
407 };
408 });
409 } else {
410 for i in 0..(1 << n_qubits) {
411 if (i >> target_idx) & 1 == 0 {
412 let j = i | (1 << target_idx);
413 let temp0 = state[i];
414 let temp1 = state[j];
415 state[i] = cos_half * temp0 - sin_half * temp1;
416 state[j] = sin_half * temp0 + cos_half * temp1;
417 }
418 }
419 }
420
421 Ok(())
422 }
423}
424
425#[derive(Debug, Clone, Copy)]
427pub struct RZSpecialized {
428 pub target: QubitId,
429 pub theta: f64,
430}
431
432impl SpecializedGate for RZSpecialized {
433 fn apply_specialized(
434 &self,
435 state: &mut [Complex64],
436 n_qubits: usize,
437 parallel: bool,
438 ) -> QuantRS2Result<()> {
439 let target_idx = self.target.id() as usize;
440 if target_idx >= n_qubits {
441 return Err(QuantRS2Error::InvalidQubitId(self.target.id()));
442 }
443
444 let phase_0 = Complex64::from_polar(1.0, -self.theta / 2.0);
445 let phase_1 = Complex64::from_polar(1.0, self.theta / 2.0);
446
447 if parallel {
448 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
449 if (idx >> target_idx) & 1 == 0 {
450 *amp *= phase_0;
451 } else {
452 *amp *= phase_1;
453 }
454 });
455 } else {
456 for i in 0..(1 << n_qubits) {
457 if (i >> target_idx) & 1 == 0 {
458 state[i] *= phase_0;
459 } else {
460 state[i] *= phase_1;
461 }
462 }
463 }
464
465 Ok(())
466 }
467}
468
469#[derive(Debug, Clone, Copy)]
473pub struct CNOTSpecialized {
474 pub control: QubitId,
475 pub target: QubitId,
476}
477
478impl SpecializedGate for CNOTSpecialized {
479 fn apply_specialized(
480 &self,
481 state: &mut [Complex64],
482 n_qubits: usize,
483 parallel: bool,
484 ) -> QuantRS2Result<()> {
485 let control_idx = self.control.id() as usize;
486 let target_idx = self.target.id() as usize;
487
488 if control_idx >= n_qubits || target_idx >= n_qubits {
489 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
490 self.control.id()
491 } else {
492 self.target.id()
493 }));
494 }
495
496 if control_idx == target_idx {
497 return Err(QuantRS2Error::CircuitValidationFailed(
498 "Control and target qubits must be different".into(),
499 ));
500 }
501
502 if parallel {
503 let state_copy = state.to_vec();
504 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
505 if (idx >> control_idx) & 1 == 1 {
506 let flipped_idx = idx ^ (1 << target_idx);
507 *amp = state_copy[flipped_idx];
508 }
509 });
510 } else {
511 for i in 0..(1 << n_qubits) {
512 if (i >> control_idx) & 1 == 1 && (i >> target_idx) & 1 == 0 {
513 let j = i | (1 << target_idx);
514 state.swap(i, j);
515 }
516 }
517 }
518
519 Ok(())
520 }
521
522 fn can_fuse_with(&self, other: &dyn SpecializedGate) -> bool {
523 if let Some(other_cnot) = other.as_any().downcast_ref::<CNOTSpecialized>() {
525 self.control == other_cnot.control && self.target == other_cnot.target
526 } else {
527 false
528 }
529 }
530}
531
532#[derive(Debug, Clone, Copy)]
534pub struct CZSpecialized {
535 pub control: QubitId,
536 pub target: QubitId,
537}
538
539impl SpecializedGate for CZSpecialized {
540 fn apply_specialized(
541 &self,
542 state: &mut [Complex64],
543 n_qubits: usize,
544 parallel: bool,
545 ) -> QuantRS2Result<()> {
546 let control_idx = self.control.id() as usize;
547 let target_idx = self.target.id() as usize;
548
549 if control_idx >= n_qubits || target_idx >= n_qubits {
550 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
551 self.control.id()
552 } else {
553 self.target.id()
554 }));
555 }
556
557 if control_idx == target_idx {
558 return Err(QuantRS2Error::CircuitValidationFailed(
559 "Control and target qubits must be different".into(),
560 ));
561 }
562
563 if parallel {
564 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
565 if (idx >> control_idx) & 1 == 1 && (idx >> target_idx) & 1 == 1 {
566 *amp = -*amp;
567 }
568 });
569 } else {
570 for i in 0..(1 << n_qubits) {
571 if (i >> control_idx) & 1 == 1 && (i >> target_idx) & 1 == 1 {
572 state[i] = -state[i];
573 }
574 }
575 }
576
577 Ok(())
578 }
579}
580
581#[derive(Debug, Clone, Copy)]
583pub struct SWAPSpecialized {
584 pub qubit1: QubitId,
585 pub qubit2: QubitId,
586}
587
588impl SpecializedGate for SWAPSpecialized {
589 fn apply_specialized(
590 &self,
591 state: &mut [Complex64],
592 n_qubits: usize,
593 parallel: bool,
594 ) -> QuantRS2Result<()> {
595 let idx1 = self.qubit1.id() as usize;
596 let idx2 = self.qubit2.id() as usize;
597
598 if idx1 >= n_qubits || idx2 >= n_qubits {
599 return Err(QuantRS2Error::InvalidQubitId(if idx1 >= n_qubits {
600 self.qubit1.id()
601 } else {
602 self.qubit2.id()
603 }));
604 }
605
606 if idx1 == idx2 {
607 return Ok(()); }
609
610 if parallel {
611 let state_copy = state.to_vec();
612 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
613 let bit1 = (idx >> idx1) & 1;
614 let bit2 = (idx >> idx2) & 1;
615
616 if bit1 != bit2 {
617 let swapped_idx =
618 (idx & !(1 << idx1) & !(1 << idx2)) | (bit2 << idx1) | (bit1 << idx2);
619 *amp = state_copy[swapped_idx];
620 }
621 });
622 } else {
623 for i in 0..(1 << n_qubits) {
624 let bit1 = (i >> idx1) & 1;
625 let bit2 = (i >> idx2) & 1;
626
627 if bit1 == 0 && bit2 == 1 {
628 let j = (i | (1 << idx1)) & !(1 << idx2);
629 state.swap(i, j);
630 }
631 }
632 }
633
634 Ok(())
635 }
636}
637
638#[derive(Debug, Clone, Copy)]
640pub struct CPhaseSpecialized {
641 pub control: QubitId,
642 pub target: QubitId,
643 pub phase: f64,
644}
645
646impl SpecializedGate for CPhaseSpecialized {
647 fn apply_specialized(
648 &self,
649 state: &mut [Complex64],
650 n_qubits: usize,
651 parallel: bool,
652 ) -> QuantRS2Result<()> {
653 let control_idx = self.control.id() as usize;
654 let target_idx = self.target.id() as usize;
655
656 if control_idx >= n_qubits || target_idx >= n_qubits {
657 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
658 self.control.id()
659 } else {
660 self.target.id()
661 }));
662 }
663
664 if control_idx == target_idx {
665 return Err(QuantRS2Error::CircuitValidationFailed(
666 "Control and target qubits must be different".into(),
667 ));
668 }
669
670 let phase_factor = Complex64::from_polar(1.0, self.phase);
671
672 if parallel {
673 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
674 if (idx >> control_idx) & 1 == 1 && (idx >> target_idx) & 1 == 1 {
675 *amp *= phase_factor;
676 }
677 });
678 } else {
679 for i in 0..(1 << n_qubits) {
680 if (i >> control_idx) & 1 == 1 && (i >> target_idx) & 1 == 1 {
681 state[i] *= phase_factor;
682 }
683 }
684 }
685
686 Ok(())
687 }
688}
689
690#[derive(Debug, Clone, Copy)]
694pub struct ToffoliSpecialized {
695 pub control1: QubitId,
696 pub control2: QubitId,
697 pub target: QubitId,
698}
699
700impl SpecializedGate for ToffoliSpecialized {
701 fn apply_specialized(
702 &self,
703 state: &mut [Complex64],
704 n_qubits: usize,
705 parallel: bool,
706 ) -> QuantRS2Result<()> {
707 let ctrl1_idx = self.control1.id() as usize;
708 let ctrl2_idx = self.control2.id() as usize;
709 let target_idx = self.target.id() as usize;
710
711 if ctrl1_idx >= n_qubits || ctrl2_idx >= n_qubits || target_idx >= n_qubits {
712 return Err(QuantRS2Error::InvalidQubitId(if ctrl1_idx >= n_qubits {
713 self.control1.id()
714 } else if ctrl2_idx >= n_qubits {
715 self.control2.id()
716 } else {
717 self.target.id()
718 }));
719 }
720
721 if ctrl1_idx == ctrl2_idx || ctrl1_idx == target_idx || ctrl2_idx == target_idx {
722 return Err(QuantRS2Error::CircuitValidationFailed(
723 "All qubits must be different".into(),
724 ));
725 }
726
727 if parallel {
728 let state_copy = state.to_vec();
729 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
730 if (idx >> ctrl1_idx) & 1 == 1 && (idx >> ctrl2_idx) & 1 == 1 {
731 let flipped_idx = idx ^ (1 << target_idx);
732 *amp = state_copy[flipped_idx];
733 }
734 });
735 } else {
736 for i in 0..(1 << n_qubits) {
737 if (i >> ctrl1_idx) & 1 == 1
738 && (i >> ctrl2_idx) & 1 == 1
739 && (i >> target_idx) & 1 == 0
740 {
741 let j = i | (1 << target_idx);
742 state.swap(i, j);
743 }
744 }
745 }
746
747 Ok(())
748 }
749}
750
751#[derive(Debug, Clone, Copy)]
753pub struct FredkinSpecialized {
754 pub control: QubitId,
755 pub target1: QubitId,
756 pub target2: QubitId,
757}
758
759impl SpecializedGate for FredkinSpecialized {
760 fn apply_specialized(
761 &self,
762 state: &mut [Complex64],
763 n_qubits: usize,
764 parallel: bool,
765 ) -> QuantRS2Result<()> {
766 let ctrl_idx = self.control.id() as usize;
767 let tgt1_idx = self.target1.id() as usize;
768 let tgt2_idx = self.target2.id() as usize;
769
770 if ctrl_idx >= n_qubits || tgt1_idx >= n_qubits || tgt2_idx >= n_qubits {
771 return Err(QuantRS2Error::InvalidQubitId(if ctrl_idx >= n_qubits {
772 self.control.id()
773 } else if tgt1_idx >= n_qubits {
774 self.target1.id()
775 } else {
776 self.target2.id()
777 }));
778 }
779
780 if ctrl_idx == tgt1_idx || ctrl_idx == tgt2_idx || tgt1_idx == tgt2_idx {
781 return Err(QuantRS2Error::CircuitValidationFailed(
782 "All qubits must be different".into(),
783 ));
784 }
785
786 if parallel {
787 let state_copy = state.to_vec();
788 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
789 if (idx >> ctrl_idx) & 1 == 1 {
790 let bit1 = (idx >> tgt1_idx) & 1;
791 let bit2 = (idx >> tgt2_idx) & 1;
792
793 if bit1 != bit2 {
794 let swapped_idx = (idx & !(1 << tgt1_idx) & !(1 << tgt2_idx))
795 | (bit2 << tgt1_idx)
796 | (bit1 << tgt2_idx);
797 *amp = state_copy[swapped_idx];
798 }
799 }
800 });
801 } else {
802 for i in 0..(1 << n_qubits) {
803 if (i >> ctrl_idx) & 1 == 1 {
804 let bit1 = (i >> tgt1_idx) & 1;
805 let bit2 = (i >> tgt2_idx) & 1;
806
807 if bit1 == 0 && bit2 == 1 {
808 let j = (i | (1 << tgt1_idx)) & !(1 << tgt2_idx);
809 state.swap(i, j);
810 }
811 }
812 }
813 }
814
815 Ok(())
816 }
817}
818
819pub fn specialize_gate(gate: &dyn GateOp) -> Option<Box<dyn SpecializedGate>> {
823 use quantrs2_core::gate::{multi::*, single::*};
824 use std::any::Any;
825
826 if let Some(h) = gate.as_any().downcast_ref::<Hadamard>() {
828 return Some(Box::new(HadamardSpecialized { target: h.target }));
829 }
830 if let Some(x) = gate.as_any().downcast_ref::<PauliX>() {
831 return Some(Box::new(PauliXSpecialized { target: x.target }));
832 }
833 if let Some(y) = gate.as_any().downcast_ref::<PauliY>() {
834 return Some(Box::new(PauliYSpecialized { target: y.target }));
835 }
836 if let Some(z) = gate.as_any().downcast_ref::<PauliZ>() {
837 return Some(Box::new(PauliZSpecialized { target: z.target }));
838 }
839 if let Some(rx) = gate.as_any().downcast_ref::<RotationX>() {
840 return Some(Box::new(RXSpecialized {
841 target: rx.target,
842 theta: rx.theta,
843 }));
844 }
845 if let Some(ry) = gate.as_any().downcast_ref::<RotationY>() {
846 return Some(Box::new(RYSpecialized {
847 target: ry.target,
848 theta: ry.theta,
849 }));
850 }
851 if let Some(rz) = gate.as_any().downcast_ref::<RotationZ>() {
852 return Some(Box::new(RZSpecialized {
853 target: rz.target,
854 theta: rz.theta,
855 }));
856 }
857 if let Some(s) = gate.as_any().downcast_ref::<Phase>() {
858 return Some(Box::new(SGateSpecialized { target: s.target }));
859 }
860 if let Some(t) = gate.as_any().downcast_ref::<T>() {
861 return Some(Box::new(TGateSpecialized { target: t.target }));
862 }
863
864 if let Some(cnot) = gate.as_any().downcast_ref::<CNOT>() {
866 return Some(Box::new(CNOTSpecialized {
867 control: cnot.control,
868 target: cnot.target,
869 }));
870 }
871 if let Some(cz) = gate.as_any().downcast_ref::<CZ>() {
872 return Some(Box::new(CZSpecialized {
873 control: cz.control,
874 target: cz.target,
875 }));
876 }
877 if let Some(swap) = gate.as_any().downcast_ref::<SWAP>() {
878 return Some(Box::new(SWAPSpecialized {
879 qubit1: swap.qubit1,
880 qubit2: swap.qubit2,
881 }));
882 }
883
884 None
885}
886
/// Generates a `GateOp` implementation for a specialized gate type.
///
/// Arguments, in order:
/// 1. the gate type,
/// 2. the expression returned by `name()`,
/// 3. a callable taking `&Self` and returning the gate's qubits,
/// 4. a callable taking `&Self` and returning the gate's matrix entries.
///
/// `clone_gate` relies on the gate type implementing `Clone`.
macro_rules! impl_gate_op_for_specialized {
    ($gate:ty, $label:expr, $qubits_of:expr, $matrix_of:expr) => {
        impl GateOp for $gate {
            fn name(&self) -> &'static str {
                $label
            }

            fn qubits(&self) -> Vec<QubitId> {
                $qubits_of(self)
            }

            fn matrix(&self) -> QuantRS2Result<Vec<Complex64>> {
                $matrix_of(self)
            }

            fn as_any(&self) -> &dyn Any {
                self
            }

            fn clone_gate(&self) -> Box<dyn GateOp> {
                Box::new(self.clone())
            }
        }
    };
}
914
915impl_gate_op_for_specialized!(
917 HadamardSpecialized,
918 "H",
919 |g: &HadamardSpecialized| vec![g.target],
920 |_: &HadamardSpecialized| {
921 let sqrt2_inv = 1.0 / std::f64::consts::SQRT_2;
922 Ok(vec![
923 Complex64::new(sqrt2_inv, 0.0),
924 Complex64::new(sqrt2_inv, 0.0),
925 Complex64::new(sqrt2_inv, 0.0),
926 Complex64::new(-sqrt2_inv, 0.0),
927 ])
928 }
929);
930
931impl_gate_op_for_specialized!(
932 PauliXSpecialized,
933 "X",
934 |g: &PauliXSpecialized| vec![g.target],
935 |_: &PauliXSpecialized| Ok(vec![
936 Complex64::new(0.0, 0.0),
937 Complex64::new(1.0, 0.0),
938 Complex64::new(1.0, 0.0),
939 Complex64::new(0.0, 0.0),
940 ])
941);
942
943impl_gate_op_for_specialized!(
944 PauliYSpecialized,
945 "Y",
946 |g: &PauliYSpecialized| vec![g.target],
947 |_: &PauliYSpecialized| Ok(vec![
948 Complex64::new(0.0, 0.0),
949 Complex64::new(0.0, -1.0),
950 Complex64::new(0.0, 1.0),
951 Complex64::new(0.0, 0.0),
952 ])
953);
954
955impl_gate_op_for_specialized!(
956 PauliZSpecialized,
957 "Z",
958 |g: &PauliZSpecialized| vec![g.target],
959 |_: &PauliZSpecialized| Ok(vec![
960 Complex64::new(1.0, 0.0),
961 Complex64::new(0.0, 0.0),
962 Complex64::new(0.0, 0.0),
963 Complex64::new(-1.0, 0.0),
964 ])
965);
966
967impl_gate_op_for_specialized!(
969 CNOTSpecialized,
970 "CNOT",
971 |g: &CNOTSpecialized| vec![g.control, g.target],
972 |_: &CNOTSpecialized| Ok(vec![
973 Complex64::new(1.0, 0.0),
974 Complex64::new(0.0, 0.0),
975 Complex64::new(0.0, 0.0),
976 Complex64::new(0.0, 0.0),
977 Complex64::new(0.0, 0.0),
978 Complex64::new(1.0, 0.0),
979 Complex64::new(0.0, 0.0),
980 Complex64::new(0.0, 0.0),
981 Complex64::new(0.0, 0.0),
982 Complex64::new(0.0, 0.0),
983 Complex64::new(0.0, 0.0),
984 Complex64::new(1.0, 0.0),
985 Complex64::new(0.0, 0.0),
986 Complex64::new(0.0, 0.0),
987 Complex64::new(1.0, 0.0),
988 Complex64::new(0.0, 0.0),
989 ])
990);
991
992impl_gate_op_for_specialized!(
994 PhaseSpecialized,
995 "Phase",
996 |g: &PhaseSpecialized| vec![g.target],
997 |g: &PhaseSpecialized| Ok(vec![
998 Complex64::new(1.0, 0.0),
999 Complex64::new(0.0, 0.0),
1000 Complex64::new(0.0, 0.0),
1001 Complex64::from_polar(1.0, g.phase),
1002 ])
1003);
1004
1005impl_gate_op_for_specialized!(
1006 SGateSpecialized,
1007 "S",
1008 |g: &SGateSpecialized| vec![g.target],
1009 |_: &SGateSpecialized| Ok(vec![
1010 Complex64::new(1.0, 0.0),
1011 Complex64::new(0.0, 0.0),
1012 Complex64::new(0.0, 0.0),
1013 Complex64::new(0.0, 1.0),
1014 ])
1015);
1016
1017impl_gate_op_for_specialized!(
1018 TGateSpecialized,
1019 "T",
1020 |g: &TGateSpecialized| vec![g.target],
1021 |_: &TGateSpecialized| {
1022 let phase = Complex64::from_polar(1.0, PI / 4.0);
1023 Ok(vec![
1024 Complex64::new(1.0, 0.0),
1025 Complex64::new(0.0, 0.0),
1026 Complex64::new(0.0, 0.0),
1027 phase,
1028 ])
1029 }
1030);
1031
1032impl_gate_op_for_specialized!(
1033 RXSpecialized,
1034 "RX",
1035 |g: &RXSpecialized| vec![g.target],
1036 |g: &RXSpecialized| {
1037 let cos = (g.theta / 2.0).cos();
1038 let sin = (g.theta / 2.0).sin();
1039 Ok(vec![
1040 Complex64::new(cos, 0.0),
1041 Complex64::new(0.0, -sin),
1042 Complex64::new(0.0, -sin),
1043 Complex64::new(cos, 0.0),
1044 ])
1045 }
1046);
1047
1048impl_gate_op_for_specialized!(
1049 RYSpecialized,
1050 "RY",
1051 |g: &RYSpecialized| vec![g.target],
1052 |g: &RYSpecialized| {
1053 let cos = (g.theta / 2.0).cos();
1054 let sin = (g.theta / 2.0).sin();
1055 Ok(vec![
1056 Complex64::new(cos, 0.0),
1057 Complex64::new(-sin, 0.0),
1058 Complex64::new(sin, 0.0),
1059 Complex64::new(cos, 0.0),
1060 ])
1061 }
1062);
1063
1064impl_gate_op_for_specialized!(
1065 RZSpecialized,
1066 "RZ",
1067 |g: &RZSpecialized| vec![g.target],
1068 |g: &RZSpecialized| {
1069 let phase_pos = Complex64::from_polar(1.0, g.theta / 2.0);
1070 let phase_neg = Complex64::from_polar(1.0, -g.theta / 2.0);
1071 Ok(vec![
1072 phase_neg,
1073 Complex64::new(0.0, 0.0),
1074 Complex64::new(0.0, 0.0),
1075 phase_pos,
1076 ])
1077 }
1078);
1079
1080impl_gate_op_for_specialized!(
1081 CZSpecialized,
1082 "CZ",
1083 |g: &CZSpecialized| vec![g.control, g.target],
1084 |_: &CZSpecialized| Ok(vec![
1085 Complex64::new(1.0, 0.0),
1086 Complex64::new(0.0, 0.0),
1087 Complex64::new(0.0, 0.0),
1088 Complex64::new(0.0, 0.0),
1089 Complex64::new(0.0, 0.0),
1090 Complex64::new(1.0, 0.0),
1091 Complex64::new(0.0, 0.0),
1092 Complex64::new(0.0, 0.0),
1093 Complex64::new(0.0, 0.0),
1094 Complex64::new(0.0, 0.0),
1095 Complex64::new(1.0, 0.0),
1096 Complex64::new(0.0, 0.0),
1097 Complex64::new(0.0, 0.0),
1098 Complex64::new(0.0, 0.0),
1099 Complex64::new(0.0, 0.0),
1100 Complex64::new(-1.0, 0.0),
1101 ])
1102);
1103
1104impl_gate_op_for_specialized!(
1105 SWAPSpecialized,
1106 "SWAP",
1107 |g: &SWAPSpecialized| vec![g.qubit1, g.qubit2],
1108 |_: &SWAPSpecialized| Ok(vec![
1109 Complex64::new(1.0, 0.0),
1110 Complex64::new(0.0, 0.0),
1111 Complex64::new(0.0, 0.0),
1112 Complex64::new(0.0, 0.0),
1113 Complex64::new(0.0, 0.0),
1114 Complex64::new(0.0, 0.0),
1115 Complex64::new(1.0, 0.0),
1116 Complex64::new(0.0, 0.0),
1117 Complex64::new(0.0, 0.0),
1118 Complex64::new(1.0, 0.0),
1119 Complex64::new(0.0, 0.0),
1120 Complex64::new(0.0, 0.0),
1121 Complex64::new(0.0, 0.0),
1122 Complex64::new(0.0, 0.0),
1123 Complex64::new(0.0, 0.0),
1124 Complex64::new(1.0, 0.0),
1125 ])
1126);
1127
1128impl_gate_op_for_specialized!(
1130 CPhaseSpecialized,
1131 "CPhase",
1132 |g: &CPhaseSpecialized| vec![g.control, g.target],
1133 |g: &CPhaseSpecialized| {
1134 let phase = Complex64::from_polar(1.0, g.phase);
1135 Ok(vec![
1136 Complex64::new(1.0, 0.0),
1137 Complex64::new(0.0, 0.0),
1138 Complex64::new(0.0, 0.0),
1139 Complex64::new(0.0, 0.0),
1140 Complex64::new(0.0, 0.0),
1141 Complex64::new(1.0, 0.0),
1142 Complex64::new(0.0, 0.0),
1143 Complex64::new(0.0, 0.0),
1144 Complex64::new(0.0, 0.0),
1145 Complex64::new(0.0, 0.0),
1146 Complex64::new(1.0, 0.0),
1147 Complex64::new(0.0, 0.0),
1148 Complex64::new(0.0, 0.0),
1149 Complex64::new(0.0, 0.0),
1150 Complex64::new(0.0, 0.0),
1151 phase,
1152 ])
1153 }
1154);
1155
1156impl_gate_op_for_specialized!(
1157 ToffoliSpecialized,
1158 "Toffoli",
1159 |g: &ToffoliSpecialized| vec![g.control1, g.control2, g.target],
1160 |_: &ToffoliSpecialized| Ok(vec![
1161 Complex64::new(1.0, 0.0),
1163 Complex64::new(0.0, 0.0),
1164 Complex64::new(0.0, 0.0),
1165 Complex64::new(0.0, 0.0),
1166 Complex64::new(0.0, 0.0),
1167 Complex64::new(0.0, 0.0),
1168 Complex64::new(0.0, 0.0),
1169 Complex64::new(0.0, 0.0),
1170 Complex64::new(0.0, 0.0),
1171 Complex64::new(1.0, 0.0),
1172 Complex64::new(0.0, 0.0),
1173 Complex64::new(0.0, 0.0),
1174 Complex64::new(0.0, 0.0),
1175 Complex64::new(0.0, 0.0),
1176 Complex64::new(0.0, 0.0),
1177 Complex64::new(0.0, 0.0),
1178 Complex64::new(0.0, 0.0),
1179 Complex64::new(0.0, 0.0),
1180 Complex64::new(1.0, 0.0),
1181 Complex64::new(0.0, 0.0),
1182 Complex64::new(0.0, 0.0),
1183 Complex64::new(0.0, 0.0),
1184 Complex64::new(0.0, 0.0),
1185 Complex64::new(0.0, 0.0),
1186 Complex64::new(0.0, 0.0),
1187 Complex64::new(0.0, 0.0),
1188 Complex64::new(0.0, 0.0),
1189 Complex64::new(1.0, 0.0),
1190 Complex64::new(0.0, 0.0),
1191 Complex64::new(0.0, 0.0),
1192 Complex64::new(0.0, 0.0),
1193 Complex64::new(0.0, 0.0),
1194 Complex64::new(0.0, 0.0),
1195 Complex64::new(0.0, 0.0),
1196 Complex64::new(0.0, 0.0),
1197 Complex64::new(0.0, 0.0),
1198 Complex64::new(1.0, 0.0),
1199 Complex64::new(0.0, 0.0),
1200 Complex64::new(0.0, 0.0),
1201 Complex64::new(0.0, 0.0),
1202 Complex64::new(0.0, 0.0),
1203 Complex64::new(0.0, 0.0),
1204 Complex64::new(0.0, 0.0),
1205 Complex64::new(0.0, 0.0),
1206 Complex64::new(0.0, 0.0),
1207 Complex64::new(1.0, 0.0),
1208 Complex64::new(0.0, 0.0),
1209 Complex64::new(0.0, 0.0),
1210 Complex64::new(0.0, 0.0),
1211 Complex64::new(0.0, 0.0),
1212 Complex64::new(0.0, 0.0),
1213 Complex64::new(0.0, 0.0),
1214 Complex64::new(0.0, 0.0),
1215 Complex64::new(0.0, 0.0),
1216 Complex64::new(0.0, 0.0),
1217 Complex64::new(1.0, 0.0),
1218 Complex64::new(0.0, 0.0),
1219 Complex64::new(0.0, 0.0),
1220 Complex64::new(0.0, 0.0),
1221 Complex64::new(0.0, 0.0),
1222 Complex64::new(0.0, 0.0),
1223 Complex64::new(0.0, 0.0),
1224 Complex64::new(1.0, 0.0),
1225 Complex64::new(0.0, 0.0),
1226 ])
1227);
1228
1229impl_gate_op_for_specialized!(
1230 FredkinSpecialized,
1231 "Fredkin",
1232 |g: &FredkinSpecialized| vec![g.control, g.target1, g.target2],
1233 |_: &FredkinSpecialized| Ok(vec![
1234 Complex64::new(1.0, 0.0),
1236 Complex64::new(0.0, 0.0),
1237 Complex64::new(0.0, 0.0),
1238 Complex64::new(0.0, 0.0),
1239 Complex64::new(0.0, 0.0),
1240 Complex64::new(0.0, 0.0),
1241 Complex64::new(0.0, 0.0),
1242 Complex64::new(0.0, 0.0),
1243 Complex64::new(0.0, 0.0),
1244 Complex64::new(1.0, 0.0),
1245 Complex64::new(0.0, 0.0),
1246 Complex64::new(0.0, 0.0),
1247 Complex64::new(0.0, 0.0),
1248 Complex64::new(0.0, 0.0),
1249 Complex64::new(0.0, 0.0),
1250 Complex64::new(0.0, 0.0),
1251 Complex64::new(0.0, 0.0),
1252 Complex64::new(0.0, 0.0),
1253 Complex64::new(1.0, 0.0),
1254 Complex64::new(0.0, 0.0),
1255 Complex64::new(0.0, 0.0),
1256 Complex64::new(0.0, 0.0),
1257 Complex64::new(0.0, 0.0),
1258 Complex64::new(0.0, 0.0),
1259 Complex64::new(0.0, 0.0),
1260 Complex64::new(0.0, 0.0),
1261 Complex64::new(0.0, 0.0),
1262 Complex64::new(1.0, 0.0),
1263 Complex64::new(0.0, 0.0),
1264 Complex64::new(0.0, 0.0),
1265 Complex64::new(0.0, 0.0),
1266 Complex64::new(0.0, 0.0),
1267 Complex64::new(0.0, 0.0),
1268 Complex64::new(0.0, 0.0),
1269 Complex64::new(0.0, 0.0),
1270 Complex64::new(0.0, 0.0),
1271 Complex64::new(1.0, 0.0),
1272 Complex64::new(0.0, 0.0),
1273 Complex64::new(0.0, 0.0),
1274 Complex64::new(0.0, 0.0),
1275 Complex64::new(0.0, 0.0),
1276 Complex64::new(0.0, 0.0),
1277 Complex64::new(0.0, 0.0),
1278 Complex64::new(0.0, 0.0),
1279 Complex64::new(0.0, 0.0),
1280 Complex64::new(0.0, 0.0),
1281 Complex64::new(1.0, 0.0),
1282 Complex64::new(0.0, 0.0),
1283 Complex64::new(0.0, 0.0),
1284 Complex64::new(0.0, 0.0),
1285 Complex64::new(0.0, 0.0),
1286 Complex64::new(0.0, 0.0),
1287 Complex64::new(0.0, 0.0),
1288 Complex64::new(1.0, 0.0),
1289 Complex64::new(0.0, 0.0),
1290 Complex64::new(0.0, 0.0),
1291 Complex64::new(0.0, 0.0),
1292 Complex64::new(0.0, 0.0),
1293 Complex64::new(0.0, 0.0),
1294 Complex64::new(0.0, 0.0),
1295 Complex64::new(0.0, 0.0),
1296 Complex64::new(0.0, 0.0),
1297 Complex64::new(0.0, 0.0),
1298 Complex64::new(1.0, 0.0),
1299 ])
1300);
1301
1302use std::any::Any;
1303
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    /// Largest allowed distance between an expected and an actual amplitude.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!((actual - expected).norm() < TOLERANCE);
    }

    /// H on |0> yields equal amplitudes 1/sqrt(2) on both basis states.
    #[test]
    fn test_hadamard_specialized() {
        let mut state = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let gate = HadamardSpecialized { target: QubitId(0) };

        gate.apply_specialized(&mut state, 1, false).unwrap();

        let expected = Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0);
        assert_close(state[0], expected);
        assert_close(state[1], expected);
    }

    /// CNOT with control 0 and target 1 moves the amplitude at index 1 to index 3.
    #[test]
    fn test_cnot_specialized() {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        let mut state = vec![zero, one, zero, zero];
        let gate = CNOTSpecialized {
            control: QubitId(0),
            target: QubitId(1),
        };

        gate.apply_specialized(&mut state, 2, false).unwrap();

        for (actual, expected) in state.iter().zip([zero, zero, zero, one].iter()) {
            assert_close(*actual, *expected);
        }
    }
}