1use scirs2_core::Complex64;
33
34use crate::optimized_simd::{
35 apply_h_gate_simd, apply_rx_gate_simd, apply_ry_gate_simd, apply_rz_gate_simd,
36 apply_s_gate_simd, apply_single_qubit_gate_optimized, apply_t_gate_simd, apply_x_gate_simd,
37 apply_y_gate_simd, apply_z_gate_simd,
38};
39
40fn gather_pairs(
47 state: &[Complex64],
48 target: usize,
49 n_qubits: usize,
50 out0: &mut Vec<Complex64>,
51 out1: &mut Vec<Complex64>,
52) -> usize {
53 let n_states = 1usize << n_qubits;
54 let stride = 1usize << target;
55 let total_pairs = n_states / 2;
56
57 out0.clear();
58 out1.clear();
59 out0.reserve(total_pairs);
60 out1.reserve(total_pairs);
61
62 let mut i = 0usize;
63 while i < n_states {
64 for j in i..(i + stride) {
65 out0.push(state[j]);
66 out1.push(state[j + stride]);
67 }
68 i += 2 * stride;
69 }
70
71 total_pairs
72}
73
74fn scatter_pairs(
76 state: &mut [Complex64],
77 target: usize,
78 n_qubits: usize,
79 src0: &[Complex64],
80 src1: &[Complex64],
81) {
82 let n_states = 1usize << n_qubits;
83 let stride = 1usize << target;
84
85 let mut pair_idx = 0usize;
86 let mut i = 0usize;
87 while i < n_states {
88 for j in i..(i + stride) {
89 state[j] = src0[pair_idx];
90 state[j + stride] = src1[pair_idx];
91 pair_idx += 1;
92 }
93 i += 2 * stride;
94 }
95}
96
97const SIMD_THRESHOLD: usize = 256; pub fn apply_gate_2x2_scalar(
114 state: &mut [Complex64],
115 matrix: [[Complex64; 2]; 2],
116 target: usize,
117 n_qubits: usize,
118) {
119 let stride = 1usize << target;
120 let n_states = 1usize << n_qubits;
121 let [[a, b], [c, d]] = matrix;
122
123 let mut i = 0usize;
124 while i < n_states {
125 for j in i..(i + stride) {
126 let zero = state[j];
127 let one = state[j + stride];
128 state[j] = a * zero + b * one;
129 state[j + stride] = c * zero + d * one;
130 }
131 i += 2 * stride;
132 }
133}
134
135pub fn apply_gate_2x2_simd(
147 state: &mut Vec<Complex64>,
148 matrix: [[Complex64; 2]; 2],
149 target: usize,
150 n_qubits: usize,
151) {
152 if state.len() < SIMD_THRESHOLD {
153 apply_gate_2x2_scalar(state, matrix, target, n_qubits);
154 return;
155 }
156
157 let flat = [matrix[0][0], matrix[0][1], matrix[1][0], matrix[1][1]];
158
159 let mut amps0 = Vec::with_capacity(state.len() / 2);
160 let mut amps1 = Vec::with_capacity(state.len() / 2);
161 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
162
163 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
164 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
165
166 apply_single_qubit_gate_optimized(&flat, &s0, &s1, &mut out0, &mut out1);
167 scatter_pairs(state, target, n_qubits, &out0, &out1);
168}
169
170pub fn apply_h_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
178 if state.len() < SIMD_THRESHOLD {
179 use std::f64::consts::FRAC_1_SQRT_2;
180 let h = [
181 [
182 Complex64::new(FRAC_1_SQRT_2, 0.0),
183 Complex64::new(FRAC_1_SQRT_2, 0.0),
184 ],
185 [
186 Complex64::new(FRAC_1_SQRT_2, 0.0),
187 Complex64::new(-FRAC_1_SQRT_2, 0.0),
188 ],
189 ];
190 apply_gate_2x2_scalar(state, h, target, n_qubits);
191 return;
192 }
193
194 let mut amps0 = Vec::with_capacity(state.len() / 2);
195 let mut amps1 = Vec::with_capacity(state.len() / 2);
196 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
197
198 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
199 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
200
201 apply_h_gate_simd(&s0, &s1, &mut out0, &mut out1);
202 scatter_pairs(state, target, n_qubits, &out0, &out1);
203}
204
205pub fn apply_x_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
209 if state.len() < SIMD_THRESHOLD {
210 apply_gate_2x2_scalar(
211 state,
212 [
213 [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
214 [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
215 ],
216 target,
217 n_qubits,
218 );
219 return;
220 }
221
222 let mut amps0 = Vec::with_capacity(state.len() / 2);
223 let mut amps1 = Vec::with_capacity(state.len() / 2);
224 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
225
226 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
227 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
228
229 apply_x_gate_simd(&s0, &s1, &mut out0, &mut out1);
230 scatter_pairs(state, target, n_qubits, &out0, &out1);
231}
232
233pub fn apply_y_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
237 if state.len() < SIMD_THRESHOLD {
238 apply_gate_2x2_scalar(
239 state,
240 [
241 [Complex64::new(0.0, 0.0), Complex64::new(0.0, -1.0)],
242 [Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)],
243 ],
244 target,
245 n_qubits,
246 );
247 return;
248 }
249
250 let mut amps0 = Vec::with_capacity(state.len() / 2);
251 let mut amps1 = Vec::with_capacity(state.len() / 2);
252 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
253
254 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
255 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
256
257 apply_y_gate_simd(&s0, &s1, &mut out0, &mut out1);
258 scatter_pairs(state, target, n_qubits, &out0, &out1);
259}
260
261pub fn apply_z_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
265 if state.len() < SIMD_THRESHOLD {
266 apply_gate_2x2_scalar(
267 state,
268 [
269 [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
270 [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)],
271 ],
272 target,
273 n_qubits,
274 );
275 return;
276 }
277
278 let mut amps0 = Vec::with_capacity(state.len() / 2);
279 let mut amps1 = Vec::with_capacity(state.len() / 2);
280 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
281
282 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
283 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
284
285 apply_z_gate_simd(&s0, &s1, &mut out0, &mut out1);
286 scatter_pairs(state, target, n_qubits, &out0, &out1);
287}
288
289pub fn apply_s_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
293 if state.len() < SIMD_THRESHOLD {
294 apply_gate_2x2_scalar(
295 state,
296 [
297 [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
298 [Complex64::new(0.0, 0.0), Complex64::new(0.0, 1.0)],
299 ],
300 target,
301 n_qubits,
302 );
303 return;
304 }
305
306 let mut amps0 = Vec::with_capacity(state.len() / 2);
307 let mut amps1 = Vec::with_capacity(state.len() / 2);
308 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
309
310 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
311 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
312
313 apply_s_gate_simd(&s0, &s1, &mut out0, &mut out1);
314 scatter_pairs(state, target, n_qubits, &out0, &out1);
315}
316
317pub fn apply_t_simd(state: &mut Vec<Complex64>, target: usize, n_qubits: usize) {
321 if state.len() < SIMD_THRESHOLD {
322 use std::f64::consts::FRAC_PI_4;
323 apply_gate_2x2_scalar(
324 state,
325 [
326 [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
327 [
328 Complex64::new(0.0, 0.0),
329 Complex64::new(FRAC_PI_4.cos(), FRAC_PI_4.sin()),
330 ],
331 ],
332 target,
333 n_qubits,
334 );
335 return;
336 }
337
338 let mut amps0 = Vec::with_capacity(state.len() / 2);
339 let mut amps1 = Vec::with_capacity(state.len() / 2);
340 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
341
342 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
343 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
344
345 apply_t_gate_simd(&s0, &s1, &mut out0, &mut out1);
346 scatter_pairs(state, target, n_qubits, &out0, &out1);
347}
348
349pub fn apply_rx_simd(state: &mut Vec<Complex64>, theta: f64, target: usize, n_qubits: usize) {
353 if state.len() < SIMD_THRESHOLD {
354 let h = theta / 2.0;
355 apply_gate_2x2_scalar(
356 state,
357 [
358 [Complex64::new(h.cos(), 0.0), Complex64::new(0.0, -h.sin())],
359 [Complex64::new(0.0, -h.sin()), Complex64::new(h.cos(), 0.0)],
360 ],
361 target,
362 n_qubits,
363 );
364 return;
365 }
366
367 let mut amps0 = Vec::with_capacity(state.len() / 2);
368 let mut amps1 = Vec::with_capacity(state.len() / 2);
369 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
370
371 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
372 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
373
374 apply_rx_gate_simd(theta, &s0, &s1, &mut out0, &mut out1);
375 scatter_pairs(state, target, n_qubits, &out0, &out1);
376}
377
378pub fn apply_ry_simd(state: &mut Vec<Complex64>, theta: f64, target: usize, n_qubits: usize) {
382 if state.len() < SIMD_THRESHOLD {
383 let h = theta / 2.0;
384 apply_gate_2x2_scalar(
385 state,
386 [
387 [Complex64::new(h.cos(), 0.0), Complex64::new(-h.sin(), 0.0)],
388 [Complex64::new(h.sin(), 0.0), Complex64::new(h.cos(), 0.0)],
389 ],
390 target,
391 n_qubits,
392 );
393 return;
394 }
395
396 let mut amps0 = Vec::with_capacity(state.len() / 2);
397 let mut amps1 = Vec::with_capacity(state.len() / 2);
398 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
399
400 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
401 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
402
403 apply_ry_gate_simd(theta, &s0, &s1, &mut out0, &mut out1);
404 scatter_pairs(state, target, n_qubits, &out0, &out1);
405}
406
407pub fn apply_rz_simd(state: &mut Vec<Complex64>, theta: f64, target: usize, n_qubits: usize) {
411 if state.len() < SIMD_THRESHOLD {
412 let h = theta / 2.0;
413 apply_gate_2x2_scalar(
414 state,
415 [
416 [Complex64::new(h.cos(), -h.sin()), Complex64::new(0.0, 0.0)],
417 [Complex64::new(0.0, 0.0), Complex64::new(h.cos(), h.sin())],
418 ],
419 target,
420 n_qubits,
421 );
422 return;
423 }
424
425 let mut amps0 = Vec::with_capacity(state.len() / 2);
426 let mut amps1 = Vec::with_capacity(state.len() / 2);
427 let n_pairs = gather_pairs(state, target, n_qubits, &mut amps0, &mut amps1);
428
429 let mut out0 = vec![Complex64::new(0.0, 0.0); n_pairs];
430 let mut out1 = vec![Complex64::new(0.0, 0.0); n_pairs];
431
432 apply_rz_gate_simd(theta, &s0, &s1, &mut out0, &mut out1);
433 scatter_pairs(state, target, n_qubits, &out0, &out1);
434}
435
/// Reports whether this build target is treated as SIMD-capable.
///
/// * `x86_64`: the result of the runtime check `is_x86_feature_detected!("avx2")`.
/// * `aarch64`: always `true`.
/// * any other architecture: always `false`.
pub fn simd_available() -> bool {
    // Exactly one of the three bindings below survives conditional compilation.
    #[cfg(target_arch = "x86_64")]
    let available = std::arch::is_x86_feature_detected!("avx2");

    #[cfg(target_arch = "aarch64")]
    let available = true;

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    let available = false;

    available
}
458
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_1_SQRT_2, PI};

    /// Returns the `n`-qubit state vector with amplitude 1 at index 0 and 0 elsewhere.
    fn zero_state(n: usize) -> Vec<Complex64> {
        let mut s = vec![Complex64::new(0.0, 0.0); 1 << n];
        s[0] = Complex64::new(1.0, 0.0);
        s
    }

    /// Largest element-wise distance `|a[k] - b[k]|` over the common prefix of `a` and `b`.
    fn max_diff(a: &[Complex64], b: &[Complex64]) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).norm())
            .fold(0.0_f64, f64::max)
    }

    /// Smallest qubit count (at least 1) whose state vector holds `SIMD_THRESHOLD` or more
    /// amplitudes, i.e. the smallest register for which the `*_simd` wrappers skip the
    /// scalar fallback.
    fn simd_path_qubits() -> usize {
        (SIMD_THRESHOLD.next_power_of_two().trailing_zeros() as usize).max(1)
    }

    /// Deterministic pseudo-random normalized state on `n_qubits`, generated from `seed`
    /// with a 64-bit linear congruential generator (no external RNG needed).
    fn lcg_random_state(n_qubits: usize, seed: u64) -> Vec<Complex64> {
        let mut rng = seed;
        let mut state: Vec<Complex64> = (0..(1usize << n_qubits))
            .map(|_| {
                rng = rng
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                let re = (rng as f64) / (u64::MAX as f64) * 2.0 - 1.0;
                rng = rng
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1_442_695_040_888_963_407);
                let im = (rng as f64) / (u64::MAX as f64) * 2.0 - 1.0;
                Complex64::new(re, im)
            })
            .collect();

        let norm: f64 = state.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt();
        state.iter_mut().for_each(|c| *c /= norm);
        state
    }

    /// Checks a gate wrapper against `apply_gate_2x2_scalar` with the reference `matrix`,
    /// for every target qubit, on two register sizes:
    ///
    /// * 6 qubits (64 amplitudes): below `SIMD_THRESHOLD` as defined in this file, so the
    ///   wrapper's scalar fallback matrix is what gets checked;
    /// * `simd_path_qubits()` qubits: large enough that the wrapper goes through
    ///   gather → SIMD kernel → scatter.
    fn assert_matches_scalar(
        gate: &str,
        seed: u64,
        matrix: [[Complex64; 2]; 2],
        apply: impl Fn(&mut Vec<Complex64>, usize, usize),
    ) {
        let sizes = [6usize, simd_path_qubits()];
        assert!(
            sizes.iter().any(|&n| (1usize << n) >= SIMD_THRESHOLD),
            "no comparison size reaches SIMD_THRESHOLD={}",
            SIMD_THRESHOLD
        );

        for &n in sizes.iter() {
            let base = lcg_random_state(n, seed);

            for target in 0..n {
                let mut simd_state = base.clone();
                let mut scalar_state = base.clone();

                apply(&mut simd_state, target, n);
                apply_gate_2x2_scalar(&mut scalar_state, matrix, target, n);

                let diff = max_diff(&simd_state, &scalar_state);
                assert!(
                    diff < 1e-12,
                    "SIMD vs scalar {} mismatch at n_qubits={}, target={}: max_diff={}",
                    gate,
                    n,
                    target,
                    diff
                );
            }
        }
    }

    /// H|0> = (|0> + |1>)/√2.
    #[test]
    fn test_h_gate_zero_state() {
        let mut state = zero_state(1);
        apply_h_simd(&mut state, 0, 1);

        assert!(
            (state[0] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-12,
            "H|0> amplitude of |0> wrong: {:?}",
            state[0]
        );
        assert!(
            (state[1] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-12,
            "H|0> amplitude of |1> wrong: {:?}",
            state[1]
        );
    }

    /// X on qubit 0 of |00> moves the amplitude from index 0 to index 1.
    #[test]
    fn test_x_gate() {
        let mut state = zero_state(2);
        apply_x_simd(&mut state, 0, 2);

        assert!(
            (state[0] - Complex64::new(0.0, 0.0)).norm() < 1e-12,
            "X|0>: state[0] should be 0"
        );
        assert!(
            (state[1] - Complex64::new(1.0, 0.0)).norm() < 1e-12,
            "X|0>: state[1] should be 1"
        );
    }

    /// Z|+> = |->: the |1> amplitude changes sign.
    #[test]
    fn test_z_gate_on_plus_state() {
        let mut state = zero_state(1);
        apply_h_simd(&mut state, 0, 1);
        apply_z_simd(&mut state, 0, 1);

        assert!(
            (state[0] - Complex64::new(FRAC_1_SQRT_2, 0.0)).norm() < 1e-12,
            "Z|+>: state[0] wrong"
        );
        assert!(
            (state[1] - Complex64::new(-FRAC_1_SQRT_2, 0.0)).norm() < 1e-12,
            "Z|+>: state[1] wrong"
        );
    }

    /// RX(π/2)|0> = cos(π/4)|0> - i·sin(π/4)|1>.
    #[test]
    fn test_rx_half_pi() {
        let theta = PI / 2.0;
        let mut state = zero_state(1);
        apply_rx_simd(&mut state, theta, 0, 1);

        let expected0 = Complex64::new((theta / 2.0).cos(), 0.0);
        let expected1 = Complex64::new(0.0, -(theta / 2.0).sin());

        assert!(
            (state[0] - expected0).norm() < 1e-12,
            "RX(π/2)|0>: state[0] wrong: {:?}",
            state[0]
        );
        assert!(
            (state[1] - expected1).norm() < 1e-12,
            "RX(π/2)|0>: state[1] wrong: {:?}",
            state[1]
        );
    }

    /// RY(π)|0> = |1>.
    #[test]
    fn test_ry_pi() {
        let mut state = zero_state(1);
        apply_ry_simd(&mut state, PI, 0, 1);

        assert!(
            state[0].norm() < 1e-12,
            "RY(π)|0>: state[0] should be ~0, got {:?}",
            state[0]
        );
        assert!(
            (state[1] - Complex64::new(1.0, 0.0)).norm() < 1e-12,
            "RY(π)|0>: state[1] should be ~1, got {:?}",
            state[1]
        );
    }

    /// S|1> = i|1>; |1> is prepared with an X gate first.
    #[test]
    fn test_s_gate() {
        let mut state = zero_state(1);
        apply_x_simd(&mut state, 0, 1);
        apply_s_simd(&mut state, 0, 1);

        assert!(state[0].norm() < 1e-12, "S|1>: state[0] should be 0");
        assert!(
            (state[1] - Complex64::new(0.0, 1.0)).norm() < 1e-12,
            "S|1>: state[1] should be i"
        );
    }

    /// T|1> = e^{iπ/4}|1>; |1> is prepared with an X gate first.
    #[test]
    fn test_t_gate() {
        use std::f64::consts::FRAC_PI_4;
        let mut state = zero_state(1);
        apply_x_simd(&mut state, 0, 1);
        apply_t_simd(&mut state, 0, 1);

        let expected = Complex64::new(FRAC_PI_4.cos(), FRAC_PI_4.sin());
        assert!(state[0].norm() < 1e-12, "T|1>: state[0] should be 0");
        assert!((state[1] - expected).norm() < 1e-12, "T|1>: state[1] wrong");
    }

    #[test]
    fn test_simd_vs_scalar_h() {
        let h = [
            [
                Complex64::new(FRAC_1_SQRT_2, 0.0),
                Complex64::new(FRAC_1_SQRT_2, 0.0),
            ],
            [
                Complex64::new(FRAC_1_SQRT_2, 0.0),
                Complex64::new(-FRAC_1_SQRT_2, 0.0),
            ],
        ];
        assert_matches_scalar("H", 42, h, apply_h_simd);
    }

    #[test]
    fn test_simd_vs_scalar_x() {
        let x_mat = [
            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
        ];
        assert_matches_scalar("X", 123, x_mat, apply_x_simd);
    }

    #[test]
    fn test_simd_vs_scalar_rz() {
        let theta = 1.23456_f64;
        let h = theta / 2.0;
        let rz_mat = [
            [Complex64::new(h.cos(), -h.sin()), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(h.cos(), h.sin())],
        ];
        assert_matches_scalar("RZ", 999, rz_mat, |state, target, n| {
            apply_rz_simd(state, theta, target, n)
        });
    }

    /// The identity matrix must leave the state unchanged, both through the scalar
    /// routine and through `apply_gate_2x2_simd` on a state large enough to use the kernel.
    #[test]
    fn test_gate_2x2_simd_identity() {
        let id = [
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
        ];

        // Scalar routine on a small state.
        let n = 4usize;
        let mut state = lcg_random_state(n, 7);
        let original = state.clone();
        apply_gate_2x2_scalar(&mut state, id, 0, n);
        let diff = max_diff(&state, &original);
        assert!(
            diff < 1e-15,
            "Identity gate altered state: max_diff={}",
            diff
        );

        // SIMD wrapper on a state that is not below SIMD_THRESHOLD.
        let n = simd_path_qubits();
        let original = lcg_random_state(n, 7);
        for target in 0..n {
            let mut state = original.clone();
            apply_gate_2x2_simd(&mut state, id, target, n);
            let diff = max_diff(&state, &original);
            assert!(
                diff < 1e-12,
                "Identity gate (SIMD path) altered state at target={}: max_diff={}",
                target,
                diff
            );
        }
    }

    /// (|0> + i|1>)/√2 is an eigenstate of Y with eigenvalue +1, so Y leaves it unchanged.
    #[test]
    fn test_y_gate_eigenvalue() {
        let mut state = vec![
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(0.0, FRAC_1_SQRT_2),
        ];
        let original = state.clone();
        apply_y_simd(&mut state, 0, 1);

        let diff = max_diff(&state, &original);
        assert!(
            diff < 1e-12,
            "Y eigenstate property failed: max_diff={}",
            diff
        );
    }
}