use crate::error::{QuantRS2Error, QuantRS2Result};
use crate::platform::PlatformCapabilities;
use crate::simd_ops_stubs::{SimdComplex64, SimdF64};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1};
use scirs2_core::Complex64;
/// Gate-application engine for state vectors stored as slices of `Complex64`.
///
/// Construction queries the host through `PlatformCapabilities::detect()` and
/// caches a few derived values; the gate methods split amplitudes into
/// real/imaginary buffers and run the `SimdF64` array kernels on them.
pub struct SimdGateEngine {
/// Capabilities reported by `PlatformCapabilities::detect()` at construction.
capabilities: PlatformCapabilities,
/// Value of `optimal_simd_width_f64()` for the detected platform.
simd_width: usize,
/// Detected cache line size, or 64 when the platform does not report one.
/// NOTE(review): not read by any method visible in this section.
cache_line_size: usize,
}
impl Default for SimdGateEngine {
/// Equivalent to [`SimdGateEngine::new`]: runs platform detection and builds
/// an engine from the result.
fn default() -> Self {
Self::new()
}
}
impl SimdGateEngine {
/// Creates an engine configured for the current host.
///
/// Platform detection runs once here; the preferred `f64` SIMD width and the
/// cache line size (defaulting to 64 when the platform does not report one)
/// are cached alongside the detected capabilities.
pub fn new() -> Self {
    let detected = PlatformCapabilities::detect();
    // Field initialisers are evaluated top to bottom, so the two derived values
    // are read from `detected` before it is moved into `capabilities`.
    Self {
        simd_width: detected.optimal_simd_width_f64(),
        cache_line_size: detected.cpu.cache.line_size.unwrap_or(64),
        capabilities: detected,
    }
}
/// Returns the platform capabilities captured when the engine was constructed.
pub const fn capabilities(&self) -> &PlatformCapabilities {
&self.capabilities
}
/// Returns the `f64` SIMD width cached at construction
/// (the result of `optimal_simd_width_f64()` on the detected capabilities).
pub const fn simd_width(&self) -> usize {
self.simd_width
}
/// Applies a single-qubit rotation about `axis` by `angle` radians to `qubit`, in place.
///
/// `amplitudes` is treated as the state vector of an n-qubit register, so its
/// length must be `2^n`; `qubit` addresses bit `qubit` of the amplitude index.
///
/// # Errors
///
/// Returns [`QuantRS2Error::InvalidInput`] when
/// * `qubit` is not smaller than the number of qubits implied by
///   `amplitudes.len()` (this covers empty and single-element slices), or
/// * `amplitudes.len()` is not a power of two.
pub fn apply_rotation_gate(
    &self,
    amplitudes: &mut [Complex64],
    qubit: usize,
    axis: RotationAxis,
    angle: f64,
) -> QuantRS2Result<()> {
    let len = amplitudes.len();
    // floor(log2(len)) via bit counting: exact for every length, unlike a
    // round-trip through `f64::log2` and truncating casts.
    let num_qubits = if len == 0 {
        0
    } else {
        (usize::BITS - 1 - len.leading_zeros()) as usize
    };
    if qubit >= num_qubits {
        return Err(QuantRS2Error::InvalidInput(
            "Qubit index out of range".to_string(),
        ));
    }
    // A length that is not 2^n leaves some amplitudes without a partner, so the
    // kernels could only apply a partial (non-unitary) update. Reject it instead.
    if !len.is_power_of_two() {
        return Err(QuantRS2Error::InvalidInput(
            "State vector length must be a power of two".to_string(),
        ));
    }
    match axis {
        RotationAxis::X => self.apply_rx(amplitudes, qubit, angle),
        RotationAxis::Y => self.apply_ry(amplitudes, qubit, angle),
        RotationAxis::Z => self.apply_rz(amplitudes, qubit, angle),
    }
}
/// Applies RX(`angle`) to `qubit`, in place.
///
/// Every pair of amplitudes whose indices differ only in bit `qubit` is
/// updated as
///
/// ```text
/// a0' = cos(angle/2) * a0 - i * sin(angle/2) * a1
/// a1' = cos(angle/2) * a1 - i * sin(angle/2) * a0
/// ```
///
/// The pairs are gathered into split real/imaginary buffers, combined with the
/// `SimdF64` array kernels and scattered back. Pairs whose upper index would
/// fall outside the slice (only possible when the length is not a power of
/// two) are skipped.
fn apply_rx(
    &self,
    amplitudes: &mut [Complex64],
    qubit: usize,
    angle: f64,
) -> QuantRS2Result<()> {
    let half_angle = angle / 2.0;
    let cos_half = half_angle.cos();
    let sin_half = half_angle.sin();

    let len = amplitudes.len();
    let qubit_mask = 1usize << qubit;
    let low_mask = qubit_mask - 1;

    // Enumerate the len/2 pairs: pair number `i` maps to the index obtained by
    // inserting a 0 bit at position `qubit` (bits below stay put, bits at or
    // above shift up by one); its partner has that bit set.
    let (idx0_list, idx1_list): (Vec<usize>, Vec<usize>) = (0..len / 2)
        .map(|i| {
            let idx0 = ((i & !low_mask) << 1) | (i & low_mask);
            (idx0, idx0 | qubit_mask)
        })
        .filter(|&(_, idx1)| idx1 < len)
        .unzip();

    let pair_count = idx0_list.len();
    if pair_count == 0 {
        return Ok(());
    }

    // Gather into structure-of-arrays form for the array kernels.
    let a0_real: Vec<f64> = idx0_list.iter().map(|&k| amplitudes[k].re).collect();
    let a0_imag: Vec<f64> = idx0_list.iter().map(|&k| amplitudes[k].im).collect();
    let a1_real: Vec<f64> = idx1_list.iter().map(|&k| amplitudes[k].re).collect();
    let a1_imag: Vec<f64> = idx1_list.iter().map(|&k| amplitudes[k].im).collect();

    let a0_real_view = ArrayView1::from(&a0_real);
    let a0_imag_view = ArrayView1::from(&a0_imag);
    let a1_real_view = ArrayView1::from(&a1_real);
    let a1_imag_view = ArrayView1::from(&a1_imag);

    // a0' = cos*a0 - i*sin*a1  =>  re: cos*a0.re + sin*a1.im, im: cos*a0.im - sin*a1.re
    let cos_a0_r = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, cos_half);
    let cos_a0_i = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, cos_half);
    let sin_a1_i = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, sin_half);
    let sin_a1_r = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, sin_half);
    let new_a0_r = <f64 as SimdF64>::simd_add_arrays(&cos_a0_r.view(), &sin_a1_i.view());
    let new_a0_i = <f64 as SimdF64>::simd_sub_arrays(&cos_a0_i.view(), &sin_a1_r.view());

    // a1' = cos*a1 - i*sin*a0  =>  re: sin*a0.im + cos*a1.re, im: cos*a1.im - sin*a0.re
    let sin_a0_i = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, sin_half);
    let sin_a0_r = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, sin_half);
    let cos_a1_r = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, cos_half);
    let cos_a1_i = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, cos_half);
    let new_a1_r = <f64 as SimdF64>::simd_add_arrays(&sin_a0_i.view(), &cos_a1_r.view());
    let new_a1_i = <f64 as SimdF64>::simd_sub_arrays(&cos_a1_i.view(), &sin_a0_r.view());

    // Scatter the results back to their original positions.
    for (slot, (&lo, &hi)) in idx0_list.iter().zip(&idx1_list).enumerate() {
        amplitudes[lo] = Complex64::new(new_a0_r[slot], new_a0_i[slot]);
        amplitudes[hi] = Complex64::new(new_a1_r[slot], new_a1_i[slot]);
    }
    Ok(())
}
/// Applies RY(`angle`) to `qubit`, in place.
///
/// Every pair of amplitudes whose indices differ only in bit `qubit` is
/// updated as
///
/// ```text
/// a0' = cos(angle/2) * a0 - sin(angle/2) * a1
/// a1' = sin(angle/2) * a0 + cos(angle/2) * a1
/// ```
///
/// The pairs are gathered into split real/imaginary buffers, combined with the
/// `SimdF64` array kernels and scattered back. Pairs whose upper index would
/// fall outside the slice (only possible when the length is not a power of
/// two) are skipped.
fn apply_ry(
    &self,
    amplitudes: &mut [Complex64],
    qubit: usize,
    angle: f64,
) -> QuantRS2Result<()> {
    let half_angle = angle / 2.0;
    let cos_half = half_angle.cos();
    let sin_half = half_angle.sin();

    let len = amplitudes.len();
    let qubit_mask = 1usize << qubit;
    let low_mask = qubit_mask - 1;

    // Pair number `i` maps to the index with a 0 bit inserted at position
    // `qubit`: bits below it are kept, bits at or above it move up by one.
    // The partner index is the same value with that bit set.
    let (idx0_list, idx1_list): (Vec<usize>, Vec<usize>) = (0..len / 2)
        .map(|i| {
            let idx0 = ((i & !low_mask) << 1) | (i & low_mask);
            (idx0, idx0 | qubit_mask)
        })
        .filter(|&(_, idx1)| idx1 < len)
        .unzip();

    let pair_count = idx0_list.len();
    if pair_count == 0 {
        return Ok(());
    }

    // Gather into structure-of-arrays form for the array kernels.
    let a0_real: Vec<f64> = idx0_list.iter().map(|&k| amplitudes[k].re).collect();
    let a0_imag: Vec<f64> = idx0_list.iter().map(|&k| amplitudes[k].im).collect();
    let a1_real: Vec<f64> = idx1_list.iter().map(|&k| amplitudes[k].re).collect();
    let a1_imag: Vec<f64> = idx1_list.iter().map(|&k| amplitudes[k].im).collect();

    let a0_real_view = ArrayView1::from(&a0_real);
    let a0_imag_view = ArrayView1::from(&a0_imag);
    let a1_real_view = ArrayView1::from(&a1_real);
    let a1_imag_view = ArrayView1::from(&a1_imag);

    // a0' = cos*a0 - sin*a1 (component-wise on real and imaginary parts).
    let cos_a0_r = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, cos_half);
    let cos_a0_i = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, cos_half);
    let sin_a1_r = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, sin_half);
    let sin_a1_i = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, sin_half);
    let new_a0_r = <f64 as SimdF64>::simd_sub_arrays(&cos_a0_r.view(), &sin_a1_r.view());
    let new_a0_i = <f64 as SimdF64>::simd_sub_arrays(&cos_a0_i.view(), &sin_a1_i.view());

    // a1' = sin*a0 + cos*a1 (component-wise on real and imaginary parts).
    let sin_a0_r = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, sin_half);
    let sin_a0_i = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, sin_half);
    let cos_a1_r = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, cos_half);
    let cos_a1_i = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, cos_half);
    let new_a1_r = <f64 as SimdF64>::simd_add_arrays(&sin_a0_r.view(), &cos_a1_r.view());
    let new_a1_i = <f64 as SimdF64>::simd_add_arrays(&sin_a0_i.view(), &cos_a1_i.view());

    // Scatter the results back to their original positions.
    for (slot, (&lo, &hi)) in idx0_list.iter().zip(&idx1_list).enumerate() {
        amplitudes[lo] = Complex64::new(new_a0_r[slot], new_a0_i[slot]);
        amplitudes[hi] = Complex64::new(new_a1_r[slot], new_a1_i[slot]);
    }
    Ok(())
}
/// Applies RZ(`angle`) to `qubit`, in place.
///
/// RZ is diagonal: amplitudes whose index has bit `qubit` clear are multiplied
/// by `cos(angle/2) - i*sin(angle/2)`, those with the bit set by
/// `cos(angle/2) + i*sin(angle/2)`. Each group is gathered into real/imaginary
/// buffers, run through the `SimdF64` array kernels and written back.
fn apply_rz(
    &self,
    amplitudes: &mut [Complex64],
    qubit: usize,
    angle: f64,
) -> QuantRS2Result<()> {
    let theta = angle / 2.0;
    let c = theta.cos();
    let s = theta.sin();
    let bit = 1usize << qubit;

    // Split the index space by the value of the target bit, keeping ascending order.
    let (clear_idx, set_idx): (Vec<usize>, Vec<usize>) =
        (0..amplitudes.len()).partition(|&i| (i & bit) == 0);

    // Bit clear: (re + i*im) * (c - i*s) = (re*c + im*s) + i*(-re*s + im*c)
    if !clear_idx.is_empty() {
        let re: Vec<f64> = clear_idx.iter().map(|&k| amplitudes[k].re).collect();
        let im: Vec<f64> = clear_idx.iter().map(|&k| amplitudes[k].im).collect();
        let re_view = ArrayView1::from(&re);
        let im_view = ArrayView1::from(&im);

        let re_c = <f64 as SimdF64>::simd_scalar_mul(&re_view, c);
        let im_s = <f64 as SimdF64>::simd_scalar_mul(&im_view, s);
        let out_re = <f64 as SimdF64>::simd_add_arrays(&re_c.view(), &im_s.view());

        let re_neg_s = <f64 as SimdF64>::simd_scalar_mul(&re_view, -s);
        let im_c = <f64 as SimdF64>::simd_scalar_mul(&im_view, c);
        let out_im = <f64 as SimdF64>::simd_add_arrays(&re_neg_s.view(), &im_c.view());

        for (slot, &k) in clear_idx.iter().enumerate() {
            amplitudes[k] = Complex64::new(out_re[slot], out_im[slot]);
        }
    }

    // Bit set: (re + i*im) * (c + i*s) = (re*c - im*s) + i*(re*s + im*c)
    if !set_idx.is_empty() {
        let re: Vec<f64> = set_idx.iter().map(|&k| amplitudes[k].re).collect();
        let im: Vec<f64> = set_idx.iter().map(|&k| amplitudes[k].im).collect();
        let re_view = ArrayView1::from(&re);
        let im_view = ArrayView1::from(&im);

        let re_c = <f64 as SimdF64>::simd_scalar_mul(&re_view, c);
        let im_s = <f64 as SimdF64>::simd_scalar_mul(&im_view, s);
        let out_re = <f64 as SimdF64>::simd_sub_arrays(&re_c.view(), &im_s.view());

        let re_s = <f64 as SimdF64>::simd_scalar_mul(&re_view, s);
        let im_c = <f64 as SimdF64>::simd_scalar_mul(&im_view, c);
        let out_im = <f64 as SimdF64>::simd_add_arrays(&re_s.view(), &im_c.view());

        for (slot, &k) in set_idx.iter().enumerate() {
            amplitudes[k] = Complex64::new(out_re[slot], out_im[slot]);
        }
    }

    Ok(())
}
/// Applies a CNOT gate in place: for every basis index whose `control` bit is
/// set, the amplitudes of the two states differing only in the `target` bit
/// are exchanged.
///
/// `amplitudes` is treated as the state vector of an n-qubit register, so its
/// length must be `2^n`.
///
/// # Errors
///
/// Returns [`QuantRS2Error::InvalidInput`] when
/// * `control` or `target` is not smaller than the number of qubits implied by
///   `amplitudes.len()`,
/// * `control == target`, or
/// * `amplitudes.len()` is not a power of two.
pub fn apply_cnot(
    &self,
    amplitudes: &mut [Complex64],
    control: usize,
    target: usize,
) -> QuantRS2Result<()> {
    let len = amplitudes.len();
    // floor(log2(len)) via bit counting: exact for every length, unlike a
    // round-trip through `f64::log2` and truncating casts.
    let num_qubits = if len == 0 {
        0
    } else {
        (usize::BITS - 1 - len.leading_zeros()) as usize
    };
    if control >= num_qubits || target >= num_qubits {
        return Err(QuantRS2Error::InvalidInput(
            "Qubit index out of range".to_string(),
        ));
    }
    if control == target {
        return Err(QuantRS2Error::InvalidInput(
            "Control and target must be different qubits".to_string(),
        ));
    }
    // With a length that is not 2^n the partner index `i | target_mask` can lie
    // past the end of the slice; report that instead of panicking on the swap.
    if !len.is_power_of_two() {
        return Err(QuantRS2Error::InvalidInput(
            "State vector length must be a power of two".to_string(),
        ));
    }

    let control_mask = 1usize << control;
    let target_mask = 1usize << target;

    // Visit each pair exactly once from its target-bit-clear member. The partner
    // has the target bit set, so it never starts a swap of its own and no pair
    // is exchanged twice. A CNOT is a pure permutation, so swapping in place
    // needs no scratch buffers.
    for i in 0..len {
        if (i & control_mask) != 0 && (i & target_mask) == 0 {
            amplitudes.swap(i, i | target_mask);
        }
    }
    Ok(())
}
/// Applies a list of single-qubit rotations, each given as `(qubit, axis, angle)`.
///
/// The gates are applied in ascending qubit order. The sort is stable, so gates
/// that target the same qubit keep the relative order they had in `gates`.
///
/// # Errors
///
/// Stops at the first gate rejected by [`Self::apply_rotation_gate`] and returns
/// its error; gates applied before that point remain applied.
pub fn batch_apply_single_qubit(
    &self,
    amplitudes: &mut [Complex64],
    gates: &[(usize, RotationAxis, f64)],
) -> QuantRS2Result<()> {
    let mut ordered: Vec<(usize, RotationAxis, f64)> = gates.to_vec();
    // Must stay a stable sort: rotations on the same qubit do not commute in general.
    ordered.sort_by_key(|gate| gate.0);
    ordered.into_iter().try_for_each(|(qubit, axis, angle)| {
        self.apply_rotation_gate(amplitudes, qubit, axis, angle)
    })
}
/// Computes the squared magnitude of the inner product of two state vectors,
/// `|<state1|state2>|^2`, with `state1` conjugated.
///
/// The inputs are used as given; no normalisation is performed.
///
/// # Errors
///
/// Returns [`QuantRS2Error::InvalidInput`] when the slices differ in length.
pub fn fidelity(&self, state1: &[Complex64], state2: &[Complex64]) -> QuantRS2Result<f64> {
    if state1.len() != state2.len() {
        return Err(QuantRS2Error::InvalidInput(
            "States must have the same length".to_string(),
        ));
    }

    // Split both states into real/imaginary buffers for the array kernels.
    let re1: Vec<f64> = state1.iter().map(|amp| amp.re).collect();
    let im1: Vec<f64> = state1.iter().map(|amp| amp.im).collect();
    let re2: Vec<f64> = state2.iter().map(|amp| amp.re).collect();
    let im2: Vec<f64> = state2.iter().map(|amp| amp.im).collect();

    let re1_view = ArrayView1::from(&re1);
    let im1_view = ArrayView1::from(&im1);
    let re2_view = ArrayView1::from(&re2);
    let im2_view = ArrayView1::from(&im2);

    // Re<a|b> = sum(re1*re2 + im1*im2)
    let re_re = <f64 as SimdF64>::simd_mul_arrays(&re1_view, &re2_view);
    let im_im = <f64 as SimdF64>::simd_mul_arrays(&im1_view, &im2_view);
    let re_terms = <f64 as SimdF64>::simd_add_arrays(&re_re.view(), &im_im.view());
    let overlap_re = <f64 as SimdF64>::simd_sum_array(&re_terms.view());

    // Im<a|b> = sum(re1*im2 - im1*re2)
    let re_im = <f64 as SimdF64>::simd_mul_arrays(&re1_view, &im2_view);
    let im_re = <f64 as SimdF64>::simd_mul_arrays(&im1_view, &re2_view);
    let im_terms = <f64 as SimdF64>::simd_sub_arrays(&re_im.view(), &im_re.view());
    let overlap_im = <f64 as SimdF64>::simd_sum_array(&im_terms.view());

    // |z|^2 = re^2 + im^2
    Ok(overlap_re.mul_add(overlap_re, overlap_im * overlap_im))
}
}
/// Rotation axis selector for [`SimdGateEngine::apply_rotation_gate`] and
/// [`SimdGateEngine::batch_apply_single_qubit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RotationAxis {
/// Rotation about the X axis (dispatched to the RX kernel).
X,
/// Rotation about the Y axis (dispatched to the RY kernel).
Y,
/// Rotation about the Z axis (dispatched to the RZ kernel).
Z,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_1_SQRT_2, PI};

    /// Absolute tolerance used for all amplitude comparisons.
    const TOL: f64 = 1e-10;

    /// Asserts that two complex amplitudes agree in both magnitude and phase.
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!(
            (actual - expected).norm() < TOL,
            "expected {:?}, got {:?}",
            expected,
            actual
        );
    }

    #[test]
    fn test_simd_engine_creation() {
        let engine = SimdGateEngine::new();
        assert!(engine.simd_width() >= 1);
        assert!(engine.simd_width() <= 8);
    }

    #[test]
    fn test_rx_gate() {
        let engine = SimdGateEngine::new();
        let mut state = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        engine
            .apply_rotation_gate(&mut state, 0, RotationAxis::X, PI)
            .expect("Failed to apply RX gate");
        assert!(state[0].norm() < TOL);
        assert!((state[1].norm() - 1.0).abs() < TOL);
        // RX(pi)|0> = -i|1>: check the phase, not just the magnitude.
        assert_close(state[1], Complex64::new(0.0, -1.0));
    }

    #[test]
    fn test_ry_gate() {
        let engine = SimdGateEngine::new();
        let mut state = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        engine
            .apply_rotation_gate(&mut state, 0, RotationAxis::Y, PI / 2.0)
            .expect("Failed to apply RY gate");
        assert!((state[0].norm() - FRAC_1_SQRT_2).abs() < TOL);
        assert!((state[1].norm() - FRAC_1_SQRT_2).abs() < TOL);
        // RY(pi/2)|0> = (|0> + |1>)/sqrt(2) with real, positive amplitudes.
        assert_close(state[0], Complex64::new(FRAC_1_SQRT_2, 0.0));
        assert_close(state[1], Complex64::new(FRAC_1_SQRT_2, 0.0));
    }

    #[test]
    fn test_rz_gate() {
        let engine = SimdGateEngine::new();
        let mut state = vec![
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
        ];
        engine
            .apply_rotation_gate(&mut state, 0, RotationAxis::Z, PI / 4.0)
            .expect("Failed to apply RZ gate");
        assert!((state[0].norm() - FRAC_1_SQRT_2).abs() < TOL);
        assert!((state[1].norm() - FRAC_1_SQRT_2).abs() < TOL);
        // RZ(theta) multiplies |0> by e^{-i theta/2} and |1> by e^{+i theta/2}.
        let (sin_h, cos_h) = ((PI / 8.0).sin(), (PI / 8.0).cos());
        assert_close(
            state[0],
            Complex64::new(FRAC_1_SQRT_2 * cos_h, -FRAC_1_SQRT_2 * sin_h),
        );
        assert_close(
            state[1],
            Complex64::new(FRAC_1_SQRT_2 * cos_h, FRAC_1_SQRT_2 * sin_h),
        );
    }

    #[test]
    fn test_cnot_gate() {
        let engine = SimdGateEngine::new();
        let mut state = vec![
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];
        engine
            .apply_cnot(&mut state, 1, 0)
            .expect("Failed to apply CNOT gate");
        assert!(state[0].norm() < TOL);
        assert!(state[1].norm() < TOL);
        assert!(state[2].norm() < TOL);
        assert!((state[3].norm() - 1.0).abs() < TOL);
        // A CNOT only permutes amplitudes, so the value must arrive unchanged.
        assert_close(state[3], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_fidelity() {
        let engine = SimdGateEngine::new();
        let state1 = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let state2 = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let fid = engine
            .fidelity(&state1, &state2)
            .expect("Failed to compute fidelity");
        assert!((fid - 1.0).abs() < TOL);
        let state3 = vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        let fid2 = engine
            .fidelity(&state1, &state3)
            .expect("Failed to compute fidelity for orthogonal states");
        assert!(fid2.abs() < TOL);
    }

    #[test]
    fn test_batch_gates() {
        let engine = SimdGateEngine::new();
        let mut state = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        let gates = vec![(0, RotationAxis::X, PI / 2.0)];
        engine
            .batch_apply_single_qubit(&mut state, &gates)
            .expect("Failed to apply batch gates");
        let norm_sqr: f64 = state.iter().map(|c| c.norm_sqr()).sum();
        assert!((norm_sqr - 1.0).abs() < 1e-8, "Norm squared: {}", norm_sqr);
        // RX(pi/2)|0> = (|0> - i|1>)/sqrt(2).
        assert_close(state[0], Complex64::new(FRAC_1_SQRT_2, 0.0));
        assert_close(state[1], Complex64::new(0.0, -FRAC_1_SQRT_2));
    }

    #[test]
    fn test_invalid_inputs_are_rejected() {
        let engine = SimdGateEngine::new();

        // A two-amplitude state has exactly one qubit, so index 1 is out of range.
        let mut one_qubit = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        assert!(engine
            .apply_rotation_gate(&mut one_qubit, 1, RotationAxis::Z, 0.5)
            .is_err());

        let mut two_qubits = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];
        // Control and target must differ.
        assert!(engine.apply_cnot(&mut two_qubits, 0, 0).is_err());
        // Control index beyond the register.
        assert!(engine.apply_cnot(&mut two_qubits, 2, 0).is_err());

        // Fidelity requires equally sized states.
        assert!(engine.fidelity(&one_qubit, &two_qubits).is_err());
    }
}