use crate::types::Complex;
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
use std::arch::x86_64::*;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Minimum state-vector length (number of amplitudes, 2^16) at which the
/// `*_parallel` kernels and the `*_best` dispatchers use rayon; shorter
/// vectors are handed to the serial scalar kernels.
/// NOTE(review): this is a tuning constant — confirm the value against benchmarks.
#[cfg(feature = "parallel")]
const PARALLEL_THRESHOLD: usize = 65_536;
/// Applies the 2x2 gate `matrix` to target `qubit` of the state vector
/// `amplitudes`, in place.
///
/// Amplitudes are paired by index: for every index `i` whose bit `qubit` is
/// clear and its partner `j = i + (1 << qubit)`, the pair `(a_i, a_j)` is
/// replaced by `(m00·a_i + m01·a_j, m10·a_i + m11·a_j)`.
///
/// # Panics
///
/// Panics if `amplitudes.len()` is not a multiple of `2 << qubit` (the qubit
/// is out of range for this state vector), or if that shift overflows
/// `usize`. The check runs before any amplitude is written, so a rejected
/// call leaves the state untouched.
#[inline]
pub fn apply_single_qubit_gate_scalar(
    amplitudes: &mut [Complex],
    qubit: u32,
    matrix: &[[Complex; 2]; 2],
) {
    let n = amplitudes.len();
    // `2 << qubit` with overflow detection: a plain shift wraps in release
    // builds and would silently act on the wrong qubit.
    let block_size = qubit
        .checked_add(1)
        .and_then(|shift| 1usize.checked_shl(shift))
        .expect("qubit index too large for usize");
    assert!(
        n % block_size == 0,
        "state vector length {} is not a multiple of {} (qubit {} out of range)",
        n,
        block_size,
        qubit
    );
    let step = block_size / 2;
    for block in amplitudes.chunks_exact_mut(block_size) {
        // The lower half has bit `qubit` clear and the upper half has it set;
        // element k of each half forms one pair.
        let (lower, upper) = block.split_at_mut(step);
        for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
            let (a, b) = (*lo, *hi);
            *lo = matrix[0][0] * a + matrix[0][1] * b;
            *hi = matrix[1][0] * a + matrix[1][1] * b;
        }
    }
}
/// Applies the 4x4 gate `matrix` to qubits `q1` and `q2` of the state vector
/// `amplitudes`, in place.
///
/// The matrix is indexed in the basis |q1 q2⟩ with `q1` as the more
/// significant bit: row/column 0 ↔ both bits clear, 1 ↔ only `q2` set,
/// 2 ↔ only `q1` set, 3 ↔ both set.
///
/// # Panics
///
/// Panics if `q1 == q2`, or if `amplitudes.len()` is not a multiple of
/// `2 << max(q1, q2)` (a qubit is out of range), including when that shift
/// overflows `usize`. All checks run before any amplitude is written.
#[inline]
pub fn apply_two_qubit_gate_scalar(
    amplitudes: &mut [Complex],
    q1: u32,
    q2: u32,
    matrix: &[[Complex; 4]; 4],
) {
    // With equal qubits the four group indices collapse onto two and the
    // gate would silently corrupt the state.
    assert_ne!(q1, q2, "two-qubit gate requires two distinct qubits");
    let n = amplitudes.len();
    let outer_block = q1
        .max(q2)
        .checked_add(1)
        .and_then(|shift| 1usize.checked_shl(shift));
    assert!(
        matches!(outer_block, Some(size) if n % size == 0),
        "state vector length {} is incompatible with qubits {} and {}",
        n,
        q1,
        q2
    );
    let q1_bit = 1usize << q1;
    let q2_bit = 1usize << q2;
    let both = q1_bit | q2_bit;
    // Each base with both target bits clear anchors one 4-amplitude group.
    for base in (0..n).filter(|&b| b & both == 0) {
        let idxs = [base, base | q2_bit, base | q1_bit, base | both];
        let vals = [
            amplitudes[idxs[0]],
            amplitudes[idxs[1]],
            amplitudes[idxs[2]],
            amplitudes[idxs[3]],
        ];
        for (row, &idx) in matrix.iter().zip(idxs.iter()) {
            amplitudes[idx] =
                row[0] * vals[0] + row[1] * vals[1] + row[2] * vals[2] + row[3] * vals[3];
        }
    }
}
/// AVX2 version of [`apply_single_qubit_gate_scalar`]. It applies the 2x2
/// gate `matrix` to target `qubit` of `amplitudes` in place, handling two
/// amplitude pairs per iteration.
///
/// Each 256-bit load covers two consecutive amplitudes. This requires every
/// `Complex` to be exactly two `f64`s, which is checked at runtime.
/// NOTE(review): the kernel also assumes the in-memory field order is `re`
/// then `im` (e.g. `#[repr(C)]`); confirm against `crate::types::Complex`.
///
/// For `qubit == 0` the two members of a pair are adjacent amplitudes, so the
/// scalar kernel is used instead.
///
/// # Safety
///
/// The CPU must support AVX2 (check with `is_x86_feature_detected!("avx2")`).
/// Slice bounds are validated by the checks listed under *Panics* before any
/// vector memory access.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `Complex` is not the size of two `f64`s.
/// - `2 << qubit` overflows `usize`.
/// - `amplitudes.len()` is not a multiple of `2 << qubit` (the qubit is out
///   of range).
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
#[target_feature(enable = "avx2")]
pub unsafe fn apply_single_qubit_gate_simd(
    amplitudes: &mut [Complex],
    qubit: u32,
    matrix: &[[Complex; 2]; 2],
) {
    assert_eq!(
        std::mem::size_of::<Complex>(),
        2 * std::mem::size_of::<f64>(),
        "AVX2 kernel requires Complex to be a pair of f64"
    );
    let step = 1usize
        .checked_shl(qubit)
        .expect("qubit index too large for usize");
    if step < 2 {
        apply_single_qubit_gate_scalar(amplitudes, qubit, matrix);
        return;
    }
    let n = amplitudes.len();
    let block_size = step
        .checked_mul(2)
        .expect("qubit index too large for usize");
    // Every access below goes through a raw pointer and also touches
    // amplitudes `i + 1` / `j + 1`. Bounds are therefore established once,
    // up front, instead of per element.
    assert!(
        n % block_size == 0,
        "state vector length {} is not a multiple of {} (qubit {} out of range)",
        n,
        block_size,
        qubit
    );

    let m00_re = _mm256_set1_pd(matrix[0][0].re);
    let m00_im = _mm256_set1_pd(matrix[0][0].im);
    let m01_re = _mm256_set1_pd(matrix[0][1].re);
    let m01_im = _mm256_set1_pd(matrix[0][1].im);
    let m10_re = _mm256_set1_pd(matrix[1][0].re);
    let m10_im = _mm256_set1_pd(matrix[1][0].im);
    let m11_re = _mm256_set1_pd(matrix[1][1].re);
    let m11_im = _mm256_set1_pd(matrix[1][1].im);
    // The mask is [-1, +1, -1, +1] in lane (memory) order. It negates the
    // ma_im*im cross term in each real slot, which is what a true complex
    // product needs. `_mm256_setr_pd` lists lanes in memory order, while
    // `_mm256_set_pd` lists them highest lane first.
    let neg_mask = _mm256_setr_pd(-1.0, 1.0, -1.0, 1.0);

    // Derive every access from one pointer to the whole slice, so each
    // 4 x f64 window stays within the provenance of the original borrow.
    let data = amplitudes.as_mut_ptr() as *mut f64;
    for block_start in (0..n).step_by(block_size) {
        // `step` is an even power of two, so the pairs (i, i + 1) tile the
        // lower half of the block exactly.
        for i in (block_start..block_start + step).step_by(2) {
            let j = i + step;
            // SAFETY: i + 1 < block_start + step and
            // j + 1 < block_start + block_size <= n. Both 4 x f64 windows
            // therefore lie inside the slice, given that Complex is two f64s
            // (asserted above). The windows are disjoint because step >= 2.
            let lo = data.add(2 * i);
            let hi = data.add(2 * j);
            let a_vec = _mm256_loadu_pd(lo);
            let b_vec = _mm256_loadu_pd(hi);
            let out_i = complex_mul_add_avx2(
                a_vec, m00_re, m00_im, b_vec, m01_re, m01_im, neg_mask,
            );
            let out_j = complex_mul_add_avx2(
                a_vec, m10_re, m10_im, b_vec, m11_re, m11_im, neg_mask,
            );
            _mm256_storeu_pd(lo, out_i);
            _mm256_storeu_pd(hi, out_j);
        }
    }
}
/// Computes `ma * a + mb * b` for the two complex values packed in each operand.
///
/// Every `__m256d` operand is laid out as interleaved `[re0, im0, re1, im1]`
/// in lane (memory) order. `ma_re`/`ma_im` and `mb_re`/`mb_im` carry the real
/// and imaginary parts of the multipliers per lane. The caller in this file
/// broadcasts one matrix entry into all lanes with `_mm256_set1_pd`.
///
/// `neg_mask` supplies the sign of the `ma_im * swapped(a)` term. For a true
/// complex product it must be `[-1, +1, -1, +1]` in lane order, giving
/// `re = ma_re*a_re - ma_im*a_im` and `im = ma_re*a_im + ma_im*a_re`.
/// With `[+1, -1, +1, -1]` the function instead returns
/// `conj(ma) * a + conj(mb) * b`.
///
/// `_mm256_set_pd` takes its arguments highest lane first, so the correct
/// mask is `_mm256_setr_pd(-1.0, 1.0, -1.0, 1.0)`, or equivalently
/// `_mm256_set_pd(1.0, -1.0, 1.0, -1.0)`.
///
/// The order of the floating-point operations determines rounding; do not
/// reassociate it.
///
/// # Safety
///
/// The CPU must support AVX2, the target feature this function is compiled with.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn complex_mul_add_avx2(
a: __m256d,
ma_re: __m256d,
ma_im: __m256d,
b: __m256d,
mb_re: __m256d,
mb_im: __m256d,
neg_mask: __m256d,
) -> __m256d {
// Swap re/im inside each complex value: [im0, re0, im1, re1].
let a_swap = _mm256_permute_pd(a, 0b0101);
// [ma_re*re0, ma_re*im0, ma_re*re1, ma_re*im1]
let prod_a_re = _mm256_mul_pd(ma_re, a);
// [ma_im*im0, ma_im*re0, ma_im*im1, ma_im*re1]
let prod_a_im = _mm256_mul_pd(ma_im, a_swap);
// Apply the cross-term signs; with a [-1, +1, -1, +1] mask the real slots get -ma_im*im.
let prod_a_im_signed = _mm256_mul_pd(prod_a_im, neg_mask);
let result_a = _mm256_add_pd(prod_a_re, prod_a_im_signed);
// Same four steps for the `mb * b` term.
let b_swap = _mm256_permute_pd(b, 0b0101);
let prod_b_re = _mm256_mul_pd(mb_re, b);
let prod_b_im = _mm256_mul_pd(mb_im, b_swap);
let prod_b_im_signed = _mm256_mul_pd(prod_b_im, neg_mask);
let result_b = _mm256_add_pd(prod_b_re, prod_b_im_signed);
_mm256_add_pd(result_a, result_b)
}
/// Two-qubit counterpart of `apply_single_qubit_gate_simd`.
///
/// No vectorised two-qubit kernel exists yet. This function forwards
/// unchanged to [`apply_two_qubit_gate_scalar`], so it has exactly that
/// function's matrix convention, panics and results. Unlike the single-qubit
/// SIMD kernel it is a safe function with no target-feature requirement.
#[cfg(all(target_arch = "x86_64", feature = "simd"))]
#[inline]
pub fn apply_two_qubit_gate_simd(
    amplitudes: &mut [Complex],
    q1: u32,
    q2: u32,
    matrix: &[[Complex; 4]; 4],
) {
    // Placeholder: delegate to the portable kernel.
    apply_two_qubit_gate_scalar(amplitudes, q1, q2, matrix)
}
/// Rayon-parallel version of [`apply_single_qubit_gate_scalar`].
///
/// Vectors shorter than `PARALLEL_THRESHOLD` are delegated to the scalar
/// kernel. Larger vectors are cut into tasks of roughly `GRAIN` amplitudes:
///
/// * When a pair block (`2 << qubit` amplitudes) is no larger than `GRAIN`,
///   several whole blocks are grouped into one task.
/// * Otherwise the blocks are processed in parallel, and each block's lower
///   and upper halves are further split into matching `GRAIN`-sized pieces.
///   Without this, gates on the most significant qubits (which have only one
///   or two blocks) would run essentially serially.
///
/// # Panics
///
/// Panics if `amplitudes.len()` is not a multiple of `2 << qubit` (the qubit
/// is out of range), or if that shift overflows `usize`.
#[cfg(feature = "parallel")]
pub fn apply_single_qubit_gate_parallel(
    amplitudes: &mut [Complex],
    qubit: u32,
    matrix: &[[Complex; 2]; 2],
) {
    // Target number of amplitudes handled by one rayon task.
    const GRAIN: usize = 4096;

    let n = amplitudes.len();
    let block_size = qubit
        .checked_add(1)
        .and_then(|shift| 1usize.checked_shl(shift))
        .expect("qubit index too large for usize");
    // Reject malformed input explicitly. Without this check, partial blocks
    // would be skipped silently and the gate would become a no-op.
    assert!(
        n % block_size == 0,
        "state vector length {} is not a multiple of {} (qubit {} out of range)",
        n,
        block_size,
        qubit
    );
    if n < PARALLEL_THRESHOLD {
        apply_single_qubit_gate_scalar(amplitudes, qubit, matrix);
        return;
    }
    let step = block_size / 2;
    let m = *matrix;
    // Applies the gate to each (lower, upper) pair of two equal-length halves.
    let apply_pairs = |lower: &mut [Complex], upper: &mut [Complex]| {
        for (lo, hi) in lower.iter_mut().zip(upper.iter_mut()) {
            let (a, b) = (*lo, *hi);
            *lo = m[0][0] * a + m[0][1] * b;
            *hi = m[1][0] * a + m[1][1] * b;
        }
    };

    if step >= GRAIN {
        amplitudes.par_chunks_mut(block_size).for_each(|block| {
            let (lower, upper) = block.split_at_mut(step);
            lower
                .par_chunks_mut(GRAIN)
                .zip(upper.par_chunks_mut(GRAIN))
                .for_each(|(l, u)| apply_pairs(l, u));
        });
    } else {
        // Smallest multiple of block_size that is >= GRAIN, so every task
        // holds only whole blocks.
        let chunk_size = ((GRAIN + block_size - 1) / block_size) * block_size;
        amplitudes.par_chunks_mut(chunk_size).for_each(|chunk| {
            for block in chunk.chunks_exact_mut(block_size) {
                let (lower, upper) = block.split_at_mut(step);
                apply_pairs(lower, upper);
            }
        });
    }
}
/// Rayon-parallel version of [`apply_two_qubit_gate_scalar`]. It uses the
/// same matrix convention: basis |q1 q2⟩ with `q1` as the more significant bit.
///
/// Vectors shorter than `PARALLEL_THRESHOLD` are delegated to the scalar
/// kernel. Otherwise the `n / 4` index groups
/// `{base, base|q2, base|q1, base|q1|q2}` are generated on the fly (no index
/// buffer is allocated) and processed in parallel.
///
/// # Panics
///
/// Panics if `q1 == q2`, or if `amplitudes.len()` is not a multiple of
/// `2 << max(q1, q2)` (a qubit is out of range), including when that shift
/// overflows `usize`. These checks keep the raw-pointer accesses in bounds.
#[cfg(feature = "parallel")]
pub fn apply_two_qubit_gate_parallel(
    amplitudes: &mut [Complex],
    q1: u32,
    q2: u32,
    matrix: &[[Complex; 4]; 4],
) {
    /// Returns `x` with a zero bit inserted at position `bit`; higher bits move up by one.
    fn insert_zero_bit(x: usize, bit: u32) -> usize {
        let low_mask = (1usize << bit) - 1;
        ((x & !low_mask) << 1) | (x & low_mask)
    }

    let n = amplitudes.len();
    assert_ne!(q1, q2, "two-qubit gate requires two distinct qubits");
    let (lo, hi) = if q1 < q2 { (q1, q2) } else { (q2, q1) };
    let outer_block = hi
        .checked_add(1)
        .and_then(|shift| 1usize.checked_shl(shift));
    assert!(
        matches!(outer_block, Some(size) if n % size == 0),
        "state vector length {} is incompatible with qubits {} and {}",
        n,
        q1,
        q2
    );
    if n < PARALLEL_THRESHOLD {
        apply_two_qubit_gate_scalar(amplitudes, q1, q2, matrix);
        return;
    }
    let q1_bit = 1usize << q1;
    let q2_bit = 1usize << q2;
    let m = *matrix;
    // Pass the base address as an integer so the closure is Send + Sync. The
    // disjointness of the groups is what makes the concurrent writes sound.
    let amp_addr = amplitudes.as_mut_ptr() as usize;
    (0..n / 4).into_par_iter().for_each(move |k| {
        // Maps k bijectively onto the indices below n whose bits `lo` and
        // `hi` are both clear.
        let base = insert_zero_bit(insert_zero_bit(k, lo), hi);
        let idxs = [base, base | q2_bit, base | q1_bit, base | q1_bit | q2_bit];
        // SAFETY:
        // - In bounds: `n` is a multiple of `2 << hi`, so every index in
        //   `idxs` is < n and lies inside the slice behind `amp_addr`.
        // - No aliasing: that slice is not touched elsewhere for the
        //   duration of this call.
        // - No data races: distinct `k` give distinct bases, and their
        //   four-element groups are pairwise disjoint, so no two tasks touch
        //   the same amplitude.
        unsafe {
            let ptr = amp_addr as *mut Complex;
            let vals = [
                *ptr.add(idxs[0]),
                *ptr.add(idxs[1]),
                *ptr.add(idxs[2]),
                *ptr.add(idxs[3]),
            ];
            for (row, &idx) in m.iter().zip(idxs.iter()) {
                *ptr.add(idx) =
                    row[0] * vals[0] + row[1] * vals[1] + row[2] * vals[2] + row[3] * vals[3];
            }
        }
    });
}
/// Applies a single-qubit gate using the fastest kernel available.
///
/// Dispatch order:
/// 1. With the `parallel` feature, state vectors of at least
///    `PARALLEL_THRESHOLD` amplitudes go to the rayon kernel.
/// 2. Otherwise, with the `simd` feature on x86_64, the AVX2 kernel is used
///    when the CPU reports AVX2 at runtime.
/// 3. Everything else uses the portable scalar kernel.
pub fn apply_single_qubit_gate_best(
    amplitudes: &mut [Complex],
    qubit: u32,
    matrix: &[[Complex; 2]; 2],
) {
    #[cfg(feature = "parallel")]
    {
        let large_enough = amplitudes.len() >= PARALLEL_THRESHOLD;
        if large_enough {
            return apply_single_qubit_gate_parallel(amplitudes, qubit, matrix);
        }
    }
    #[cfg(all(target_arch = "x86_64", feature = "simd"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just confirmed at runtime, which is
            // what the kernel's `#[target_feature(enable = "avx2")]` requires.
            return unsafe { apply_single_qubit_gate_simd(amplitudes, qubit, matrix) };
        }
    }
    apply_single_qubit_gate_scalar(amplitudes, qubit, matrix)
}
/// Applies a two-qubit gate using the fastest kernel compiled into this build.
///
/// With the `parallel` feature, state vectors of at least
/// `PARALLEL_THRESHOLD` amplitudes use the rayon kernel. All other cases use
/// the portable scalar kernel.
pub fn apply_two_qubit_gate_best(
    amplitudes: &mut [Complex],
    q1: u32,
    q2: u32,
    matrix: &[[Complex; 4]; 4],
) {
    #[cfg(feature = "parallel")]
    {
        let large_enough = amplitudes.len() >= PARALLEL_THRESHOLD;
        if large_enough {
            return apply_two_qubit_gate_parallel(amplitudes, q1, q2, matrix);
        }
    }
    apply_two_qubit_gate_scalar(amplitudes, q1, q2, matrix)
}