use crate::complex::Complex64;
use crate::gate::{Gate1, Gate2};
/// Maps the `i`-th amplitude pair for qubit `k` to its two state-vector indices.
///
/// The result `(i0, i1)` is `i` with a zero bit inserted at position `k` (higher bits shifted
/// up by one), paired with the same index with bit `k` set. As `i` ranges over
/// `0..n / 2` this enumerates every pair of basis states that differ only in qubit `k`.
#[inline]
#[must_use]
pub(crate) fn index_pair(k: usize, i: usize) -> (usize, usize) {
    let bit = 1usize << k;
    let below = i & (bit - 1);
    let above = (i >> k) << (k + 1);
    let zero_at_k = above | below;
    (zero_at_k, zero_at_k | bit)
}
/// Maps the `i`-th group of four amplitudes for qubits `a` and `b` to state-vector indices.
///
/// Zero bits are inserted into `i` at the lower and then the higher of the two qubit
/// positions, giving the `base` index with both bits clear. The returned array is ordered
/// `[a=0 b=0, a=0 b=1, a=1 b=0, a=1 b=1]`. If `a == b` the last three entries coincide, so
/// callers are expected to pass distinct qubits.
#[inline]
#[must_use]
pub(crate) fn index_quad(a: usize, b: usize, i: usize) -> [usize; 4] {
    // Insert a zero bit at `pos`, shifting every bit at or above `pos` up by one.
    let insert_zero = |x: usize, pos: usize| -> usize {
        let mask = (1usize << pos) - 1;
        ((x >> pos) << (pos + 1)) | (x & mask)
    };
    let (lower, upper) = (a.min(b), a.max(b));
    let base = insert_zero(insert_zero(i, lower), upper);
    let (bit_a, bit_b) = (1usize << a, 1usize << b);
    [base, base | bit_b, base | bit_a, base | bit_a | bit_b]
}
/// Applies the single-qubit gate `gate` to qubit `k` of the state vector `amps`.
///
/// Qubit `k` corresponds to bit `k` of an amplitude's index. `gate.m` is applied as a
/// row-major 2x2 matrix to every pair of amplitudes differing only in that bit. On x86_64
/// (outside Miri), an AVX2+FMA kernel is used when the CPU supports it and `k >= 1`. The
/// SIMD kernel handles two adjacent amplitudes per vector, which requires a pair stride of at
/// least 2. Otherwise a scalar loop is used.
///
/// # Panics
///
/// Panics if `amps.len()` is not a power of two or if `k >= log2(amps.len())`. Both kernels
/// index `amps` without bounds checks. These are therefore hard asserts rather than
/// `debug_assert!`s: violating them would otherwise be undefined behavior in release builds.
pub fn apply_1q(amps: &mut [Complex64], k: usize, gate: &Gate1) {
    let n = amps.len();
    assert!(
        n.is_power_of_two(),
        "amplitude count must be a power of two"
    );
    // Compare against log2(n) directly: `1 << k` would wrap for k >= usize::BITS in release.
    assert!(
        k < n.trailing_zeros() as usize,
        "qubit index k must be < log2(n)"
    );
    #[cfg(all(target_arch = "x86_64", not(miri)))]
    {
        if k >= 1 && is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: AVX2 and FMA were detected just above. `n` is a power of two and
            // `1 <= k < log2(n)` (asserted above), which is the kernel's bounds contract.
            unsafe { apply_1q_avx2(amps, k, gate) };
            return;
        }
    }
    apply_1q_scalar(amps, k, gate);
}
/// Portable kernel behind [`apply_1q`].
///
/// For every index pair `(i0, i1)` that differs only in bit `k`, it replaces
/// `(amps[i0], amps[i1])` with `gate.m` applied as a row-major 2x2 matrix.
///
/// Indexing is unchecked. The caller must guarantee that `amps.len()` is a power of two and
/// `(1 << k) < amps.len()`.
fn apply_1q_scalar(amps: &mut [Complex64], k: usize, gate: &Gate1) {
    let len = amps.len();
    let m = &gate.m;
    for (lo, hi) in (0..len / 2).map(|pair| index_pair(k, pair)) {
        debug_assert!(lo < len && hi < len);
        // SAFETY: for a power-of-two `len` and `k < log2(len)`, `index_pair` sends every
        // `pair < len / 2` to two indices below `len`.
        unsafe {
            let (x0, x1) = (*amps.get_unchecked(lo), *amps.get_unchecked(hi));
            *amps.get_unchecked_mut(lo) = m[0] * x0 + m[1] * x1;
            *amps.get_unchecked_mut(hi) = m[2] * x0 + m[3] * x1;
        }
    }
}
/// AVX2+FMA kernel behind [`apply_1q`].
///
/// It applies the row-major 2x2 matrix `gate.m` to every amplitude pair `(i, i + 2^k)` whose
/// index `i` has bit `k` clear, processing two consecutive pairs per iteration.
///
/// `amps` is reinterpreted as a flat `f64` buffer, where `ptr.add(2 * i)` addresses amplitude
/// `i`. Each 256-bit load or store therefore covers the two adjacent amplitudes `i` and `i + 1`.
/// NOTE(review): this assumes `Complex64` is laid out as `{ re: f64, im: f64 }` with no
/// padding (e.g. `#[repr(C)]`). Confirm this against the type's definition.
///
/// # Safety
///
/// The caller must ensure that:
/// - the running CPU supports the `avx2` and `fma` target features;
/// - `amps.len()` is a power of two;
/// - `1 <= k` and `(1 << k) < amps.len()`.
///
/// The last two conditions make the stride at least 2, so a vector never straddles a pair
/// boundary. They also keep every pointer formed below inside `amps`.
#[cfg(all(target_arch = "x86_64", not(miri)))]
#[target_feature(enable = "avx2,fma")]
// The broadcast matrix entries are named m00r/m00i/m01r/... after their (row, column,
// re/im) position, which trips `similar_names`.
#[allow(clippy::similar_names)]
unsafe fn apply_1q_avx2(amps: &mut [Complex64], k: usize, gate: &Gate1) {
    use std::arch::x86_64::{
        _mm256_add_pd, _mm256_fmaddsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_permute_pd,
        _mm256_set1_pd, _mm256_storeu_pd,
    };
    let n = amps.len();
    // `stride` is the index distance between the two amplitudes of a pair (bit `k`).
    let stride = 1usize << k; let m = &gate.m;
    // SAFETY: the target features are guaranteed by this function's contract. Every pointer
    // offset below is `2 * i` for an index `i < n`, given the loop bounds and the caller's
    // guarantees on `n` and `k`.
    unsafe {
        // Broadcast the real and imaginary part of each matrix entry to all four lanes.
        let m00r = _mm256_set1_pd(m[0].re);
        let m00i = _mm256_set1_pd(m[0].im);
        let m01r = _mm256_set1_pd(m[1].re);
        let m01i = _mm256_set1_pd(m[1].im);
        let m10r = _mm256_set1_pd(m[2].re);
        let m10i = _mm256_set1_pd(m[2].im);
        let m11r = _mm256_set1_pd(m[3].re);
        let m11i = _mm256_set1_pd(m[3].im);
        let ptr = amps.as_mut_ptr().cast::<f64>();
        let mut base = 0;
        // Outer loop: blocks of `2 * stride` amplitudes. In the lower half of each block,
        // bit `k` is clear.
        while base < n {
            let mut off = 0;
            // Inner loop: the pairs starting at i0 and i0 + 1 are handled together.
            while off < stride {
                let i0 = base + off;
                let i1 = i0 + stride;
                let v0 = _mm256_loadu_pd(ptr.add(2 * i0));
                let v1 = _mm256_loadu_pd(ptr.add(2 * i1));
                // Swap the two f64s within each 128-bit lane: [re, im] -> [im, re].
                let v0s = _mm256_permute_pd(v0, 0b0101);
                let v1s = _mm256_permute_pd(v1, 0b0101);
                // Complex product m * v. fmaddsub subtracts in even lanes and adds in odd
                // lanes, giving (mr*re - mi*im, mr*im + mi*re).
                let m00v0 = _mm256_fmaddsub_pd(m00r, v0, _mm256_mul_pd(m00i, v0s));
                let m01v1 = _mm256_fmaddsub_pd(m01r, v1, _mm256_mul_pd(m01i, v1s));
                let m10v0 = _mm256_fmaddsub_pd(m10r, v0, _mm256_mul_pd(m10i, v0s));
                let m11v1 = _mm256_fmaddsub_pd(m11r, v1, _mm256_mul_pd(m11i, v1s));
                // Matrix row 0 produces the new i0 amplitudes, row 1 the new i1 amplitudes.
                let new0 = _mm256_add_pd(m00v0, m01v1);
                let new1 = _mm256_add_pd(m10v0, m11v1);
                _mm256_storeu_pd(ptr.add(2 * i0), new0);
                _mm256_storeu_pd(ptr.add(2 * i1), new1);
                off += 2;
            }
            base += 2 * stride;
        }
    }
}
/// Applies the two-qubit gate `gate` to qubits `a` and `b` of the state vector `amps`.
///
/// `gate.m` is read as a row-major 4x4 matrix over the basis ordered
/// `|a b> = |00>, |01>, |10>, |11>`, with qubit `a` as the high bit. This matches the index
/// order produced by `index_quad`.
///
/// Gates equal to `Gate2::cnot()`, `Gate2::cz()` or `Gate2::swap()` are routed to dedicated
/// kernels that avoid the dense 4x4 multiply. For CNOT, `a` is the control and `b` the target.
///
/// # Panics
///
/// Panics if `amps.len()` is not a power of two, if `a == b`, or if either qubit index is
/// `>= log2(amps.len())`. Every kernel indexes `amps` without bounds checks. These are
/// therefore hard asserts rather than `debug_assert!`s: an out-of-range qubit would otherwise
/// be undefined behavior in release builds.
pub fn apply_2q(amps: &mut [Complex64], a: usize, b: usize, gate: &Gate2) {
    let n = amps.len();
    assert!(a != b, "two-qubit gate requires distinct qubits a and b");
    assert!(n.is_power_of_two(), "amplitude count must be a power of two");
    // Compare against log2(n) directly: `1 << a` would wrap for a >= usize::BITS in release.
    let num_qubits = n.trailing_zeros() as usize;
    assert!(
        a < num_qubits && b < num_qubits,
        "qubit indices a and b must be < log2(n)"
    );
    if *gate == Gate2::cnot() {
        apply_2q_cnot(amps, a, b);
        return;
    }
    if *gate == Gate2::cz() {
        apply_2q_cz(amps, a, b);
        return;
    }
    if *gate == Gate2::swap() {
        apply_2q_swap(amps, a, b);
        return;
    }
    let mat = &gate.m;
    for i in 0..n / 4 {
        let idx = index_quad(a, b, i);
        debug_assert!(idx.iter().all(|&j| j < n));
        // SAFETY: `n` is a power of two and `a`, `b` are distinct qubits below `log2(n)`
        // (asserted above). So for `i < n / 4`, every index from `index_quad` is `< n`.
        let amp = unsafe {
            [
                *amps.get_unchecked(idx[0]),
                *amps.get_unchecked(idx[1]),
                *amps.get_unchecked(idx[2]),
                *amps.get_unchecked(idx[3]),
            ]
        };
        // All four inputs are copied out before any write, so later rows never see
        // partially updated amplitudes.
        for (row, &out) in idx.iter().enumerate() {
            let base = row * 4;
            let new = mat[base] * amp[0]
                + mat[base + 1] * amp[1]
                + mat[base + 2] * amp[2]
                + mat[base + 3] * amp[3];
            // SAFETY: `out` is one of the in-bounds indices checked above.
            unsafe {
                *amps.get_unchecked_mut(out) = new;
            }
        }
    }
}
/// CNOT kernel with `a` as control and `b` as target.
///
/// In each `index_quad` group it exchanges the `|a=1,b=0>` and `|a=1,b=1>` amplitudes.
/// Indexing is unchecked; the caller must pass valid, distinct qubits for a
/// power-of-two `amps`.
fn apply_2q_cnot(amps: &mut [Complex64], a: usize, b: usize) {
    let len = amps.len();
    for group in 0..len / 4 {
        let quad = index_quad(a, b, group);
        debug_assert!(quad.iter().all(|&j| j < len));
        let (a_only, both) = (quad[2], quad[3]);
        // SAFETY: for valid qubits, every `index_quad` index with `group < len / 4` is `< len`.
        unsafe {
            let from_both = *amps.get_unchecked(both);
            let from_a_only = std::mem::replace(amps.get_unchecked_mut(a_only), from_both);
            *amps.get_unchecked_mut(both) = from_a_only;
        }
    }
}
/// CZ kernel: negates the `|a=1,b=1>` amplitude of every `index_quad` group.
///
/// Indexing is unchecked; the caller must pass valid, distinct qubits for a
/// power-of-two `amps`.
fn apply_2q_cz(amps: &mut [Complex64], a: usize, b: usize) {
    let len = amps.len();
    (0..len / 4).for_each(|group| {
        let both_set = index_quad(a, b, group)[3];
        debug_assert!(both_set < len);
        // SAFETY: for valid qubits, every `index_quad` index with `group < len / 4` is `< len`.
        let amp = unsafe { amps.get_unchecked_mut(both_set) };
        *amp = -*amp;
    });
}
/// SWAP kernel: exchanges the `|a=0,b=1>` and `|a=1,b=0>` amplitudes of every
/// `index_quad` group.
///
/// Indexing is unchecked; the caller must pass valid, distinct qubits for a
/// power-of-two `amps`.
fn apply_2q_swap(amps: &mut [Complex64], a: usize, b: usize) {
    let len = amps.len();
    for group in 0..len / 4 {
        let quad = index_quad(a, b, group);
        debug_assert!(quad.iter().all(|&j| j < len));
        let (b_only, a_only) = (quad[1], quad[2]);
        // SAFETY: for valid qubits, every `index_quad` index with `group < len / 4` is `< len`.
        unsafe {
            let from_a_only = *amps.get_unchecked(a_only);
            let from_b_only = std::mem::replace(amps.get_unchecked_mut(b_only), from_a_only);
            *amps.get_unchecked_mut(a_only) = from_b_only;
        }
    }
}
/// Applies the single-qubit gate `gate` to qubit `target`, conditioned on every qubit in
/// `controls` being 1.
///
/// An amplitude pair is updated only when all control bits are set in its index. An empty
/// `controls` slice therefore applies `gate` unconditionally, and duplicate controls are
/// harmless.
///
/// On x86_64 (outside Miri), an AVX2+FMA kernel is used when the CPU supports it, provided
/// that `target >= 1` and qubit 0 is not a control. That kernel updates the pairs starting at
/// `i` and `i + 1` together, so both must share the same target and control bits. Otherwise
/// a scalar loop is used.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `amps.len()` is not a power of two;
/// - `target` or any control is `>= log2(amps.len())`;
/// - `target` appears in `controls`.
///
/// Both kernels index `amps` without bounds checks (the SIMD one through raw pointers). An
/// out-of-range `target` would otherwise be undefined behavior, even in debug builds.
pub fn apply_controlled_1q(
    amps: &mut [Complex64],
    controls: &[usize],
    target: usize,
    gate: &Gate1,
) {
    let n = amps.len();
    assert!(n.is_power_of_two(), "amplitude count must be a power of two");
    // Compare against log2(n) directly: `1 << q` would wrap for q >= usize::BITS in release.
    let num_qubits = n.trailing_zeros() as usize;
    assert!(target < num_qubits, "target qubit index must be < log2(n)");
    assert!(
        !controls.contains(&target),
        "target qubit must not also be a control"
    );
    let control_mask = controls.iter().fold(0usize, |mask, &c| {
        assert!(c < num_qubits, "control qubit index must be < log2(n)");
        mask | (1usize << c)
    });
    #[cfg(all(target_arch = "x86_64", not(miri)))]
    {
        if target >= 1
            && (control_mask & 1) == 0
            && is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
        {
            // SAFETY: AVX2 and FMA were detected just above. `n` is a power of two and
            // `1 <= target < log2(n)` (asserted above), which keeps every pointer the kernel
            // forms inside `amps`. Bit 0 is not a control, so both amplitudes in each vector
            // share one control outcome.
            unsafe { apply_controlled_1q_avx2(amps, control_mask, target, gate) };
            return;
        }
    }
    apply_controlled_1q_scalar(amps, control_mask, target, gate);
}
/// Portable kernel behind [`apply_controlled_1q`].
///
/// For every pair `(i0, i1)` differing only in bit `target` where `i0` has all
/// `control_mask` bits set, it applies `gate.m` as a row-major 2x2 matrix.
///
/// Indexing is unchecked. The caller must guarantee that `amps.len()` is a power of two and
/// `(1 << target) < amps.len()`.
fn apply_controlled_1q_scalar(
    amps: &mut [Complex64],
    control_mask: usize,
    target: usize,
    gate: &Gate1,
) {
    let len = amps.len();
    let m = &gate.m;
    let active_pairs = (0..len / 2)
        .map(|pair| index_pair(target, pair))
        .filter(|&(lo, _)| (lo & control_mask) == control_mask);
    for (lo, hi) in active_pairs {
        debug_assert!(lo < len && hi < len);
        // SAFETY: for a power-of-two `len` and `target < log2(len)`, `index_pair` sends every
        // `pair < len / 2` to two indices below `len`.
        unsafe {
            let (x0, x1) = (*amps.get_unchecked(lo), *amps.get_unchecked(hi));
            *amps.get_unchecked_mut(lo) = m[0] * x0 + m[1] * x1;
            *amps.get_unchecked_mut(hi) = m[2] * x0 + m[3] * x1;
        }
    }
}
/// AVX2+FMA kernel behind [`apply_controlled_1q`].
///
/// It applies the row-major 2x2 matrix `gate.m` to each amplitude pair
/// `(i, i + 2^target)` whose index `i` has bit `target` clear and all `control_mask` bits set.
/// Two consecutive pairs are processed per iteration.
///
/// `amps` is reinterpreted as a flat `f64` buffer, where `ptr.add(2 * i)` addresses amplitude
/// `i`. Each 256-bit load or store therefore covers amplitudes `i` and `i + 1`.
/// NOTE(review): this assumes `Complex64` is laid out as `{ re: f64, im: f64 }` with no
/// padding (e.g. `#[repr(C)]`). Confirm this against the type's definition.
///
/// # Safety
///
/// The caller must ensure that:
/// - the running CPU supports the `avx2` and `fma` target features;
/// - `amps.len()` is a power of two;
/// - `1 <= target` and `(1 << target) < amps.len()`, which keeps every pointer formed below
///   inside `amps`;
/// - bit 0 is not set in `control_mask`.
///
/// The control test is made only on the even index `i0a`. It also stands for `i0a + 1`, which
/// is valid only when bit 0 is not a control.
#[cfg(all(target_arch = "x86_64", not(miri)))]
#[target_feature(enable = "avx2,fma")]
// The broadcast matrix entries are named m00r/m00i/m01r/... after their (row, column,
// re/im) position, which trips `similar_names`.
#[allow(clippy::similar_names)]
unsafe fn apply_controlled_1q_avx2(
    amps: &mut [Complex64],
    control_mask: usize,
    target: usize,
    gate: &Gate1,
) {
    use std::arch::x86_64::{
        _mm256_add_pd, _mm256_fmaddsub_pd, _mm256_loadu_pd, _mm256_mul_pd, _mm256_permute_pd,
        _mm256_set1_pd, _mm256_storeu_pd,
    };
    let n = amps.len();
    // `stride` is the index distance between the two amplitudes of a pair (bit `target`).
    let stride = 1usize << target; let m = &gate.m;
    // SAFETY: the target features are guaranteed by this function's contract. Every pointer
    // offset below is `2 * i` for an index `i < n`, given the loop bounds and the caller's
    // guarantees on `n` and `target`.
    unsafe {
        // Broadcast the real and imaginary part of each matrix entry to all four lanes.
        let m00r = _mm256_set1_pd(m[0].re);
        let m00i = _mm256_set1_pd(m[0].im);
        let m01r = _mm256_set1_pd(m[1].re);
        let m01i = _mm256_set1_pd(m[1].im);
        let m10r = _mm256_set1_pd(m[2].re);
        let m10i = _mm256_set1_pd(m[2].im);
        let m11r = _mm256_set1_pd(m[3].re);
        let m11i = _mm256_set1_pd(m[3].im);
        let ptr = amps.as_mut_ptr().cast::<f64>();
        let mut base = 0;
        // Outer loop: blocks of `2 * stride` amplitudes. In the lower half of each block,
        // bit `target` is clear.
        while base < n {
            let mut off = 0;
            // Inner loop: the pairs starting at i0a and i0a + 1 are handled together.
            while off < stride {
                let i0a = base + off;
                // Skip both pairs unless every control bit is set. `i0a` is even, so this
                // also decides for `i0a + 1` when bit 0 is not a control.
                if i0a & control_mask == control_mask {
                    let i1a = i0a + stride;
                    let v0 = _mm256_loadu_pd(ptr.add(2 * i0a));
                    let v1 = _mm256_loadu_pd(ptr.add(2 * i1a));
                    // Swap the two f64s within each 128-bit lane: [re, im] -> [im, re].
                    let v0s = _mm256_permute_pd(v0, 0b0101);
                    let v1s = _mm256_permute_pd(v1, 0b0101);
                    // Complex product m * v. fmaddsub subtracts in even lanes and adds in
                    // odd lanes, giving (mr*re - mi*im, mr*im + mi*re).
                    let m00v0 = _mm256_fmaddsub_pd(m00r, v0, _mm256_mul_pd(m00i, v0s));
                    let m01v1 = _mm256_fmaddsub_pd(m01r, v1, _mm256_mul_pd(m01i, v1s));
                    let m10v0 = _mm256_fmaddsub_pd(m10r, v0, _mm256_mul_pd(m10i, v0s));
                    let m11v1 = _mm256_fmaddsub_pd(m11r, v1, _mm256_mul_pd(m11i, v1s));
                    // Matrix row 0 produces the new i0a amplitudes, row 1 the new i1a ones.
                    let new0 = _mm256_add_pd(m00v0, m01v1);
                    let new1 = _mm256_add_pd(m10v0, m11v1);
                    _mm256_storeu_pd(ptr.add(2 * i0a), new0);
                    _mm256_storeu_pd(ptr.add(2 * i1a), new1);
                }
                off += 2;
            }
            base += 2 * stride;
        }
    }
}
/// Kani model-checking harnesses for `index_pair` and `index_quad`, whose outputs the
/// kernels use for unchecked indexing.
///
/// Each harness bounds the register size: `1 <= q <= 12` for `index_pair` and
/// `2 <= q <= 10` for `index_quad`. This is presumably to keep the proofs tractable, so the
/// properties are proven only for those sizes.
#[cfg(kani)]
mod proofs {
    use super::{index_pair, index_quad};
    /// For `k < q` and `i < 2^q / 2`, `index_pair(k, i)` yields two distinct indices below
    /// `2^q` that differ exactly in bit `k`, with bit `k` clear in the first.
    #[kani::proof]
    fn index_pair_in_bounds() {
        let q: usize = kani::any();
        kani::assume((1..=12).contains(&q));
        let n: usize = 1 << q;
        let k: usize = kani::any();
        kani::assume(k < q);
        let i: usize = kani::any();
        kani::assume(i < n / 2);
        let (i0, i1) = index_pair(k, i);
        assert!(i0 < n);
        assert!(i1 < n);
        assert!(i0 != i1);
        assert!(i1 == i0 | (1 << k));
        assert!(i0 & (1 << k) == 0);
    }
    /// Distinct pair numbers `i != j` produce distinct first indices.
    ///
    /// Combined with `index_pair_in_bounds` (bit `k` clear in `i0`, set in `i1`), this
    /// implies the pairs cover `0..2^q` with no index visited twice.
    #[kani::proof]
    fn index_pair_injective() {
        let q: usize = kani::any();
        kani::assume((1..=12).contains(&q));
        let k: usize = kani::any();
        kani::assume(k < q);
        let half: usize = (1usize << q) / 2;
        let i: usize = kani::any();
        let j: usize = kani::any();
        kani::assume(i < half && j < half && i != j);
        let (i0, _) = index_pair(k, i);
        let (j0, _) = index_pair(k, j);
        assert!(i0 != j0);
    }
    /// For distinct qubits `a, b < q` and `i < 2^q / 4`, all four indices returned by
    /// `index_quad(a, b, i)` are below `2^q` and pairwise distinct within the group.
    #[kani::proof]
    fn index_quad_in_bounds() {
        let q: usize = kani::any();
        kani::assume((2..=10).contains(&q));
        let n: usize = 1 << q;
        let a: usize = kani::any();
        let b: usize = kani::any();
        kani::assume(a < q && b < q && a != b);
        let i: usize = kani::any();
        kani::assume(i < n / 4);
        let idx = index_quad(a, b, i);
        // Check each index and every unordered pair (x, y) with x < y.
        // NOTE(review): `while` instead of `for` may be deliberate to simplify Kani's loop
        // unwinding; confirm before restyling.
        let mut x = 0;
        while x < 4 {
            assert!(idx[x] < n);
            let mut y = x + 1;
            while y < 4 {
                assert!(idx[x] != idx[y]);
                y += 1;
            }
            x += 1;
        }
    }
}
/// Unit tests for the index helpers and gate kernels.
///
/// The dispatching entry points are compared against their scalar or dense reference
/// implementations. On AVX2+FMA hardware this exercises the SIMD kernels.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn index_pair_inserts_zero_bit() {
        for i in 0..8 {
            let (i0, i1) = index_pair(1, i);
            assert_eq!(i0 & (1 << 1), 0);
            assert_eq!(i1, i0 | (1 << 1));
        }
    }

    #[test]
    fn index_pair_covers_all_indices_once() {
        let mut seen = [false; 8];
        for i in 0..4 {
            let (i0, i1) = index_pair(1, i);
            assert!(!std::mem::replace(&mut seen[i0], true));
            assert!(!std::mem::replace(&mut seen[i1], true));
        }
        assert!(seen.iter().all(|&b| b));
    }

    #[test]
    fn x_gate_flips_single_qubit() {
        let mut amps = vec![Complex64::ONE, Complex64::ZERO];
        apply_1q(&mut amps, 0, &Gate1::x());
        assert_eq!(amps, vec![Complex64::ZERO, Complex64::ONE]);
    }

    #[test]
    fn hadamard_then_hadamard_is_identity() {
        let mut amps = vec![Complex64::ONE, Complex64::ZERO];
        apply_1q(&mut amps, 0, &Gate1::h());
        apply_1q(&mut amps, 0, &Gate1::h());
        assert!((amps[0] - Complex64::ONE).norm() < 1e-12);
        assert!(amps[1].norm() < 1e-12);
    }

    /// Every ordered pair of distinct qubits in a 3-qubit register must partition the
    /// 8 indices, with each group in `[00, 0b, a0, ab]` order.
    #[test]
    fn index_quad_partitions_indices() {
        for a in 0..3 {
            for b in 0..3 {
                if a == b {
                    continue;
                }
                let (bit_a, bit_b) = (1usize << a, 1usize << b);
                let mut seen = [false; 8];
                for i in 0..2 {
                    let quad = index_quad(a, b, i);
                    assert_eq!(quad[0] & (bit_a | bit_b), 0, "a={a} b={b} i={i}");
                    assert_eq!(quad[1], quad[0] | bit_b, "a={a} b={b} i={i}");
                    assert_eq!(quad[2], quad[0] | bit_a, "a={a} b={b} i={i}");
                    assert_eq!(quad[3], quad[0] | bit_a | bit_b, "a={a} b={b} i={i}");
                    for &idx in &quad {
                        assert!(!std::mem::replace(&mut seen[idx], true));
                    }
                }
                assert!(seen.iter().all(|&s| s), "a={a} b={b}");
            }
        }
    }

    /// |10> with control qubit 1 set must become exactly |11>: the source amplitude is
    /// vacated and no other basis state gains weight.
    #[test]
    fn cnot_flips_target_when_control_set() {
        let mut amps = vec![Complex64::ZERO; 4];
        amps[2] = Complex64::ONE;
        apply_2q(&mut amps, 1, 0, &Gate2::cnot());
        assert!((amps[3] - Complex64::ONE).norm() < 1e-12);
        assert!(amps[..3].iter().all(|z| z.norm() < 1e-12));
    }

    #[test]
    fn controlled_x_matches_cnot() {
        let mut a = vec![
            Complex64::new(0.5, 0.1),
            Complex64::new(0.2, -0.3),
            Complex64::new(-0.4, 0.2),
            Complex64::new(0.1, 0.6),
        ];
        let mut b = a.clone();
        apply_2q(&mut a, 1, 0, &Gate2::cnot());
        apply_controlled_1q(&mut b, &[1], 0, &Gate1::x());
        for (x, y) in a.iter().zip(&b) {
            assert!((*x - *y).norm() < 1e-12);
        }
    }

    /// A doubly-controlled X must flip the target exactly when both controls are set.
    /// The first configuration always uses the scalar kernel (target 0). The second may
    /// take the SIMD kernel on AVX2+FMA hardware (target >= 1, bit 0 not a control).
    #[test]
    fn double_controlled_x_is_toffoli() {
        for (controls, target) in [([1usize, 2usize], 0usize), ([1, 3], 2)] {
            let mask = (1usize << controls[0]) | (1usize << controls[1]);
            for input in 0..16usize {
                let mut amps = vec![Complex64::ZERO; 16];
                amps[input] = Complex64::ONE;
                apply_controlled_1q(&mut amps, &controls, target, &Gate1::x());
                let expected = if (input & mask) == mask {
                    input ^ (1 << target)
                } else {
                    input
                };
                for (j, amp) in amps.iter().enumerate() {
                    let want = if j == expected {
                        Complex64::ONE
                    } else {
                        Complex64::ZERO
                    };
                    assert!(
                        (*amp - want).norm() < 1e-12,
                        "controls={controls:?} target={target} input={input} j={j}"
                    );
                }
            }
        }
    }

    #[test]
    fn miri_apply_1q_no_ub_small() {
        for n_qubits in 1..=4 {
            let dim = 1usize << n_qubits;
            for k in 0..n_qubits {
                let mut amps = vec![Complex64::ZERO; dim];
                amps[0] = Complex64::ONE;
                apply_1q(&mut amps, k, &Gate1::h());
                let norm: f64 = amps.iter().map(|a| a.norm_sqr()).sum();
                assert!((norm - 1.0).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn miri_apply_2q_and_controlled_no_ub_small() {
        for n_qubits in 2..=4 {
            let dim = 1usize << n_qubits;
            let mut amps = vec![Complex64::ZERO; dim];
            amps[0] = Complex64::ONE;
            apply_1q(&mut amps, 0, &Gate1::h());
            apply_2q(&mut amps, 0, 1, &Gate2::cnot());
            apply_controlled_1q(&mut amps, &[0], 1, &Gate1::z());
            let norm: f64 = amps.iter().map(|a| a.norm_sqr()).sum();
            assert!((norm - 1.0).abs() < 1e-12);
        }
    }

    /// Deterministic, non-normalized state with distinct amplitudes, so index mix-ups are
    /// visible in comparisons.
    fn sample_state(n: usize) -> Vec<Complex64> {
        (0..(1usize << n))
            .map(|j| Complex64::new(0.1 + j as f64 * 0.03, -0.2 + j as f64 * 0.017))
            .collect()
    }

    #[test]
    fn specialized_2q_matches_dense() {
        for n in 2..=10 {
            for a in 0..n {
                for b in 0..n {
                    if a == b {
                        continue;
                    }
                    let base = sample_state(n);
                    for gate in [Gate2::cnot(), Gate2::cz(), Gate2::swap()] {
                        let mut via_dispatch = base.clone();
                        apply_2q(&mut via_dispatch, a, b, &gate);
                        let mut via_dense = base.clone();
                        apply_2q_dense_ref(&mut via_dense, a, b, &gate);
                        for (x, y) in via_dispatch.iter().zip(&via_dense) {
                            assert!(
                                (*x - *y).norm() < 1e-12,
                                "specialized/dense mismatch n={n} a={a} b={b}"
                            );
                        }
                    }
                }
            }
        }
    }

    /// Bounds-checked dense 4x4 reference, independent of the specialised kernels.
    fn apply_2q_dense_ref(amps: &mut [Complex64], a: usize, b: usize, gate: &Gate2) {
        let n = amps.len();
        let mat = &gate.m;
        let groups = n / 4;
        for i in 0..groups {
            let idx = index_quad(a, b, i);
            let amp = [amps[idx[0]], amps[idx[1]], amps[idx[2]], amps[idx[3]]];
            for (row, &out) in idx.iter().enumerate() {
                let base = row * 4;
                amps[out] = mat[base] * amp[0]
                    + mat[base + 1] * amp[1]
                    + mat[base + 2] * amp[2]
                    + mat[base + 3] * amp[3];
            }
        }
    }

    #[test]
    fn simd_controlled_matches_scalar() {
        let gates = [
            Gate1::x(),
            Gate1::y(),
            Gate1::z(),
            Gate1::h(),
            Gate1::s(),
            Gate1::rx(1.1),
            Gate1::ry(-0.7),
            Gate1::rz(2.3),
        ];
        for n in 2..=10 {
            for target in 0..n {
                for ctrl in 0..n {
                    if ctrl == target {
                        continue;
                    }
                    let control_mask = 1usize << ctrl;
                    let base = sample_state(n);
                    for g in &gates {
                        let mut via_dispatch = base.clone();
                        apply_controlled_1q(&mut via_dispatch, &[ctrl], target, g);
                        let mut via_scalar = base.clone();
                        apply_controlled_1q_scalar(&mut via_scalar, control_mask, target, g);
                        for (x, y) in via_dispatch.iter().zip(&via_scalar) {
                            assert!(
                                (*x - *y).norm() < 1e-12,
                                "controlled SIMD/scalar mismatch n={n} target={target} ctrl={ctrl}"
                            );
                        }
                    }
                }
            }
        }
    }

    /// Same comparison as above, but with two controls, i.e. a multi-bit `control_mask`.
    /// The single-control test never exercises this case on the SIMD path.
    #[test]
    fn simd_multi_controlled_matches_scalar() {
        let gates = [Gate1::x(), Gate1::h(), Gate1::ry(0.4), Gate1::rz(-1.9)];
        for n in 3..=8 {
            let base = sample_state(n);
            for target in 0..n {
                for c1 in 0..n {
                    for c2 in (c1 + 1)..n {
                        if c1 == target || c2 == target {
                            continue;
                        }
                        let control_mask = (1usize << c1) | (1usize << c2);
                        for g in &gates {
                            let mut via_dispatch = base.clone();
                            apply_controlled_1q(&mut via_dispatch, &[c1, c2], target, g);
                            let mut via_scalar = base.clone();
                            apply_controlled_1q_scalar(
                                &mut via_scalar,
                                control_mask,
                                target,
                                g,
                            );
                            for (x, y) in via_dispatch.iter().zip(&via_scalar) {
                                assert!(
                                    (*x - *y).norm() < 1e-12,
                                    "multi-controlled SIMD/scalar mismatch n={n} target={target} controls=({c1}, {c2})"
                                );
                            }
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn simd_path_matches_scalar() {
        let gates = [
            Gate1::h(),
            Gate1::x(),
            Gate1::y(),
            Gate1::s(),
            Gate1::t(),
            Gate1::rx(0.7),
            Gate1::ry(-1.3),
            Gate1::rz(2.1),
            Gate1::phase(0.9),
        ];
        for n in 1..=10 {
            for k in 0..n {
                for g in &gates {
                    let mut via_dispatch = sample_state(n);
                    let mut via_scalar = via_dispatch.clone();
                    apply_1q(&mut via_dispatch, k, g);
                    apply_1q_scalar(&mut via_scalar, k, g);
                    for (a, b) in via_dispatch.iter().zip(&via_scalar) {
                        assert!(
                            (*a - *b).norm() < 1e-12,
                            "mismatch at n={n} k={k}: {a} vs {b}"
                        );
                    }
                }
            }
        }
    }
}