1use crate::error::{QuantRS2Error, QuantRS2Result};
7use crate::platform::PlatformCapabilities;
8use crate::simd_ops_stubs::{SimdComplex64, SimdF64};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1};
10use scirs2_core::Complex64;
11
12pub struct SimdGateEngine {
14 capabilities: PlatformCapabilities,
16 simd_width: usize,
18 cache_line_size: usize,
20}
21
22impl Default for SimdGateEngine {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl SimdGateEngine {
29 pub fn new() -> Self {
31 let capabilities = PlatformCapabilities::detect();
32 let simd_width = capabilities.optimal_simd_width_f64();
33 let cache_line_size = capabilities.cpu.cache.line_size.unwrap_or(64);
34
35 Self {
36 capabilities,
37 simd_width,
38 cache_line_size,
39 }
40 }
41
42 pub fn capabilities(&self) -> &PlatformCapabilities {
44 &self.capabilities
45 }
46
47 pub fn simd_width(&self) -> usize {
49 self.simd_width
50 }
51
52 pub fn apply_rotation_gate(
56 &self,
57 amplitudes: &mut [Complex64],
58 qubit: usize,
59 axis: RotationAxis,
60 angle: f64,
61 ) -> QuantRS2Result<()> {
62 let num_qubits = (amplitudes.len() as f64).log2() as usize;
63 if qubit >= num_qubits {
64 return Err(QuantRS2Error::InvalidInput(
65 "Qubit index out of range".to_string(),
66 ));
67 }
68
69 match axis {
70 RotationAxis::X => self.apply_rx(amplitudes, qubit, angle),
71 RotationAxis::Y => self.apply_ry(amplitudes, qubit, angle),
72 RotationAxis::Z => self.apply_rz(amplitudes, qubit, angle),
73 }
74 }
75
76 fn apply_rx(
78 &self,
79 amplitudes: &mut [Complex64],
80 qubit: usize,
81 angle: f64,
82 ) -> QuantRS2Result<()> {
83 let half_angle = angle / 2.0;
84 let cos_half = half_angle.cos();
85 let sin_half = half_angle.sin();
86
87 let qubit_mask = 1 << qubit;
88 let mut idx0_list = Vec::new();
89 let mut idx1_list = Vec::new();
90
91 for i in 0..(amplitudes.len() / 2) {
93 let idx0 = (i & !(qubit_mask >> 1)) | ((i & (qubit_mask >> 1)) << 1);
94 let idx1 = idx0 | qubit_mask;
95
96 if idx1 < amplitudes.len() {
97 idx0_list.push(idx0);
98 idx1_list.push(idx1);
99 }
100 }
101
102 let pair_count = idx0_list.len();
103 if pair_count == 0 {
104 return Ok(());
105 }
106
107 let mut a0_real = Vec::with_capacity(pair_count);
109 let mut a0_imag = Vec::with_capacity(pair_count);
110 let mut a1_real = Vec::with_capacity(pair_count);
111 let mut a1_imag = Vec::with_capacity(pair_count);
112
113 for i in 0..pair_count {
114 let a0 = amplitudes[idx0_list[i]];
115 let a1 = amplitudes[idx1_list[i]];
116 a0_real.push(a0.re);
117 a0_imag.push(a0.im);
118 a1_real.push(a1.re);
119 a1_imag.push(a1.im);
120 }
121
122 let a0_real_view = ArrayView1::from(&a0_real);
124 let a0_imag_view = ArrayView1::from(&a0_imag);
125 let a1_real_view = ArrayView1::from(&a1_real);
126 let a1_imag_view = ArrayView1::from(&a1_imag);
127
128 let cos_a0_r = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, cos_half);
133 let cos_a0_i = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, cos_half);
134 let sin_a1_i = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, sin_half);
135 let sin_a1_r = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, sin_half);
136
137 let new_a0_r = <f64 as SimdF64>::simd_add_arrays(&cos_a0_r.view(), &sin_a1_i.view());
138 let new_a0_i = <f64 as SimdF64>::simd_sub_arrays(&cos_a0_i.view(), &sin_a1_r.view());
139
140 let sin_a0_i = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, sin_half);
141 let sin_a0_r = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, sin_half);
142 let cos_a1_r = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, cos_half);
143 let cos_a1_i = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, cos_half);
144
145 let new_a1_r = <f64 as SimdF64>::simd_add_arrays(&sin_a0_i.view(), &cos_a1_r.view());
146 let new_a1_i = <f64 as SimdF64>::simd_sub_arrays(&cos_a1_i.view(), &sin_a0_r.view());
147
148 for i in 0..pair_count {
150 amplitudes[idx0_list[i]] = Complex64::new(new_a0_r[i], new_a0_i[i]);
151 amplitudes[idx1_list[i]] = Complex64::new(new_a1_r[i], new_a1_i[i]);
152 }
153
154 Ok(())
155 }
156
157 fn apply_ry(
159 &self,
160 amplitudes: &mut [Complex64],
161 qubit: usize,
162 angle: f64,
163 ) -> QuantRS2Result<()> {
164 let half_angle = angle / 2.0;
165 let cos_half = half_angle.cos();
166 let sin_half = half_angle.sin();
167
168 let qubit_mask = 1 << qubit;
169 let mut idx0_list = Vec::new();
170 let mut idx1_list = Vec::new();
171
172 for i in 0..(amplitudes.len() / 2) {
174 let idx0 = (i & !(qubit_mask >> 1)) | ((i & (qubit_mask >> 1)) << 1);
175 let idx1 = idx0 | qubit_mask;
176
177 if idx1 < amplitudes.len() {
178 idx0_list.push(idx0);
179 idx1_list.push(idx1);
180 }
181 }
182
183 let pair_count = idx0_list.len();
184 if pair_count == 0 {
185 return Ok(());
186 }
187
188 let mut a0_real = Vec::with_capacity(pair_count);
190 let mut a0_imag = Vec::with_capacity(pair_count);
191 let mut a1_real = Vec::with_capacity(pair_count);
192 let mut a1_imag = Vec::with_capacity(pair_count);
193
194 for i in 0..pair_count {
195 let a0 = amplitudes[idx0_list[i]];
196 let a1 = amplitudes[idx1_list[i]];
197 a0_real.push(a0.re);
198 a0_imag.push(a0.im);
199 a1_real.push(a1.re);
200 a1_imag.push(a1.im);
201 }
202
203 let a0_real_view = ArrayView1::from(&a0_real);
205 let a0_imag_view = ArrayView1::from(&a0_imag);
206 let a1_real_view = ArrayView1::from(&a1_real);
207 let a1_imag_view = ArrayView1::from(&a1_imag);
208
209 let cos_a0_r = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, cos_half);
211 let cos_a0_i = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, cos_half);
212 let sin_a1_r = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, sin_half);
213 let sin_a1_i = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, sin_half);
214
215 let new_a0_r = <f64 as SimdF64>::simd_sub_arrays(&cos_a0_r.view(), &sin_a1_r.view());
216 let new_a0_i = <f64 as SimdF64>::simd_sub_arrays(&cos_a0_i.view(), &sin_a1_i.view());
217
218 let sin_a0_r = <f64 as SimdF64>::simd_scalar_mul(&a0_real_view, sin_half);
219 let sin_a0_i = <f64 as SimdF64>::simd_scalar_mul(&a0_imag_view, sin_half);
220 let cos_a1_r = <f64 as SimdF64>::simd_scalar_mul(&a1_real_view, cos_half);
221 let cos_a1_i = <f64 as SimdF64>::simd_scalar_mul(&a1_imag_view, cos_half);
222
223 let new_a1_r = <f64 as SimdF64>::simd_add_arrays(&sin_a0_r.view(), &cos_a1_r.view());
224 let new_a1_i = <f64 as SimdF64>::simd_add_arrays(&sin_a0_i.view(), &cos_a1_i.view());
225
226 for i in 0..pair_count {
228 amplitudes[idx0_list[i]] = Complex64::new(new_a0_r[i], new_a0_i[i]);
229 amplitudes[idx1_list[i]] = Complex64::new(new_a1_r[i], new_a1_i[i]);
230 }
231
232 Ok(())
233 }
234
235 fn apply_rz(
237 &self,
238 amplitudes: &mut [Complex64],
239 qubit: usize,
240 angle: f64,
241 ) -> QuantRS2Result<()> {
242 let half_angle = angle / 2.0;
243 let cos_half = half_angle.cos();
244 let sin_half = half_angle.sin();
245
246 let qubit_mask = 1 << qubit;
247
248 let mut idx0_list = Vec::new(); let mut idx1_list = Vec::new(); for i in 0..amplitudes.len() {
253 if (i & qubit_mask) == 0 {
254 idx0_list.push(i);
255 } else {
256 idx1_list.push(i);
257 }
258 }
259
260 if !idx0_list.is_empty() {
262 let mut real_parts = Vec::with_capacity(idx0_list.len());
263 let mut imag_parts = Vec::with_capacity(idx0_list.len());
264
265 for &idx in &idx0_list {
266 real_parts.push(amplitudes[idx].re);
267 imag_parts.push(amplitudes[idx].im);
268 }
269
270 let real_view = ArrayView1::from(&real_parts);
271 let imag_view = ArrayView1::from(&imag_parts);
272
273 let real_cos = <f64 as SimdF64>::simd_scalar_mul(&real_view, cos_half);
275 let imag_sin = <f64 as SimdF64>::simd_scalar_mul(&imag_view, sin_half);
276 let new_real = <f64 as SimdF64>::simd_add_arrays(&real_cos.view(), &imag_sin.view());
277
278 let real_sin = <f64 as SimdF64>::simd_scalar_mul(&real_view, -sin_half);
279 let imag_cos = <f64 as SimdF64>::simd_scalar_mul(&imag_view, cos_half);
280 let new_imag = <f64 as SimdF64>::simd_add_arrays(&real_sin.view(), &imag_cos.view());
281
282 for (i, &idx) in idx0_list.iter().enumerate() {
283 amplitudes[idx] = Complex64::new(new_real[i], new_imag[i]);
284 }
285 }
286
287 if !idx1_list.is_empty() {
289 let mut real_parts = Vec::with_capacity(idx1_list.len());
290 let mut imag_parts = Vec::with_capacity(idx1_list.len());
291
292 for &idx in &idx1_list {
293 real_parts.push(amplitudes[idx].re);
294 imag_parts.push(amplitudes[idx].im);
295 }
296
297 let real_view = ArrayView1::from(&real_parts);
298 let imag_view = ArrayView1::from(&imag_parts);
299
300 let real_cos = <f64 as SimdF64>::simd_scalar_mul(&real_view, cos_half);
302 let imag_sin = <f64 as SimdF64>::simd_scalar_mul(&imag_view, sin_half);
303 let new_real = <f64 as SimdF64>::simd_sub_arrays(&real_cos.view(), &imag_sin.view());
304
305 let real_sin = <f64 as SimdF64>::simd_scalar_mul(&real_view, sin_half);
306 let imag_cos = <f64 as SimdF64>::simd_scalar_mul(&imag_view, cos_half);
307 let new_imag = <f64 as SimdF64>::simd_add_arrays(&real_sin.view(), &imag_cos.view());
308
309 for (i, &idx) in idx1_list.iter().enumerate() {
310 amplitudes[idx] = Complex64::new(new_real[i], new_imag[i]);
311 }
312 }
313
314 Ok(())
315 }
316
317 pub fn apply_cnot(
319 &self,
320 amplitudes: &mut [Complex64],
321 control: usize,
322 target: usize,
323 ) -> QuantRS2Result<()> {
324 let num_qubits = (amplitudes.len() as f64).log2() as usize;
325 if control >= num_qubits || target >= num_qubits {
326 return Err(QuantRS2Error::InvalidInput(
327 "Qubit index out of range".to_string(),
328 ));
329 }
330
331 if control == target {
332 return Err(QuantRS2Error::InvalidInput(
333 "Control and target must be different qubits".to_string(),
334 ));
335 }
336
337 let control_mask = 1 << control;
338 let target_mask = 1 << target;
339
340 let mut idx0_list = Vec::new();
342 let mut idx1_list = Vec::new();
343
344 for i in 0..amplitudes.len() {
347 if (i & control_mask) != 0 {
349 if (i & target_mask) == 0 {
351 let idx0 = i; let idx1 = i ^ target_mask; idx0_list.push(idx0);
354 idx1_list.push(idx1);
355 }
356 }
357 }
358
359 let pair_count = idx0_list.len();
361 if pair_count == 0 {
362 return Ok(());
363 }
364
365 if pair_count < 4 {
367 for i in 0..pair_count {
368 amplitudes.swap(idx0_list[i], idx1_list[i]);
369 }
370 return Ok(());
371 }
372
373 let mut a0_real = Vec::with_capacity(pair_count);
375 let mut a0_imag = Vec::with_capacity(pair_count);
376 let mut a1_real = Vec::with_capacity(pair_count);
377 let mut a1_imag = Vec::with_capacity(pair_count);
378
379 for i in 0..pair_count {
380 let a0 = amplitudes[idx0_list[i]];
381 let a1 = amplitudes[idx1_list[i]];
382 a0_real.push(a0.re);
383 a0_imag.push(a0.im);
384 a1_real.push(a1.re);
385 a1_imag.push(a1.im);
386 }
387
388 for i in 0..pair_count {
390 amplitudes[idx0_list[i]] = Complex64::new(a1_real[i], a1_imag[i]);
391 amplitudes[idx1_list[i]] = Complex64::new(a0_real[i], a0_imag[i]);
392 }
393
394 Ok(())
395 }
396
397 pub fn batch_apply_single_qubit(
402 &self,
403 amplitudes: &mut [Complex64],
404 gates: &[(usize, RotationAxis, f64)],
405 ) -> QuantRS2Result<()> {
406 let mut sorted_gates = gates.to_vec();
408 sorted_gates.sort_by_key(|(qubit, _, _)| *qubit);
409
410 for (qubit, axis, angle) in sorted_gates {
411 self.apply_rotation_gate(amplitudes, qubit, axis, angle)?;
412 }
413
414 Ok(())
415 }
416
417 pub fn fidelity(&self, state1: &[Complex64], state2: &[Complex64]) -> QuantRS2Result<f64> {
419 if state1.len() != state2.len() {
420 return Err(QuantRS2Error::InvalidInput(
421 "States must have the same length".to_string(),
422 ));
423 }
424
425 let len = state1.len();
427
428 let mut state1_real = Vec::with_capacity(len);
429 let mut state1_imag = Vec::with_capacity(len);
430 let mut state2_real = Vec::with_capacity(len);
431 let mut state2_imag = Vec::with_capacity(len);
432
433 for (a, b) in state1.iter().zip(state2.iter()) {
434 state1_real.push(a.re);
435 state1_imag.push(a.im);
436 state2_real.push(b.re);
437 state2_imag.push(b.im);
438 }
439
440 let state1_real_view = ArrayView1::from(&state1_real);
441 let state1_imag_view = ArrayView1::from(&state1_imag);
442 let state2_real_view = ArrayView1::from(&state2_real);
443 let state2_imag_view = ArrayView1::from(&state2_imag);
444
445 let rr = <f64 as SimdF64>::simd_mul_arrays(&state1_real_view, &state2_real_view);
447 let ii = <f64 as SimdF64>::simd_mul_arrays(&state1_imag_view, &state2_imag_view);
448 let real_sum = <f64 as SimdF64>::simd_add_arrays(&rr.view(), &ii.view());
449 let real_part = <f64 as SimdF64>::simd_sum_array(&real_sum.view());
450
451 let ri = <f64 as SimdF64>::simd_mul_arrays(&state1_real_view, &state2_imag_view);
452 let ir = <f64 as SimdF64>::simd_mul_arrays(&state1_imag_view, &state2_real_view);
453 let imag_diff = <f64 as SimdF64>::simd_sub_arrays(&ri.view(), &ir.view());
454 let imag_part = <f64 as SimdF64>::simd_sum_array(&imag_diff.view());
455
456 let fidelity = real_part * real_part + imag_part * imag_part;
458 Ok(fidelity)
459 }
460}
461
/// Axis of a single-qubit rotation; selects which rotation routine
/// `SimdGateEngine::apply_rotation_gate` dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RotationAxis {
    /// Rotation about the X axis (handled by `apply_rx`).
    X,
    /// Rotation about the Y axis (handled by `apply_ry`).
    Y,
    /// Rotation about the Z axis (handled by `apply_rz`).
    Z,
}
469
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_1_SQRT_2, FRAC_PI_2, FRAC_PI_4, PI};

    /// Absolute tolerance for amplitude comparisons.
    const EPS: f64 = 1e-10;

    /// Single-qubit |0> state.
    fn ket_zero() -> Vec<Complex64> {
        vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]
    }

    #[test]
    fn test_simd_engine_creation() {
        let engine = SimdGateEngine::new();
        assert!(engine.simd_width() >= 1);
        assert!(engine.simd_width() <= 8);
    }

    #[test]
    fn test_rx_gate() {
        let engine = SimdGateEngine::new();
        let mut state = ket_zero();

        engine
            .apply_rotation_gate(&mut state, 0, RotationAxis::X, PI)
            .unwrap();

        // RX(pi) maps |0> to -i|1>.
        assert!(state[0].norm() < EPS);
        assert!((state[1].norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_ry_gate() {
        let engine = SimdGateEngine::new();
        let mut state = ket_zero();

        engine
            .apply_rotation_gate(&mut state, 0, RotationAxis::Y, FRAC_PI_2)
            .unwrap();

        // RY(pi/2) maps |0> to an equal superposition.
        assert!((state[0].norm() - FRAC_1_SQRT_2).abs() < EPS);
        assert!((state[1].norm() - FRAC_1_SQRT_2).abs() < EPS);
    }

    #[test]
    fn test_rz_gate() {
        let engine = SimdGateEngine::new();
        let mut state = vec![
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
        ];

        engine
            .apply_rotation_gate(&mut state, 0, RotationAxis::Z, FRAC_PI_4)
            .unwrap();

        // RZ only changes phases, so magnitudes are preserved.
        assert!((state[0].norm() - FRAC_1_SQRT_2).abs() < EPS);
        assert!((state[1].norm() - FRAC_1_SQRT_2).abs() < EPS);
    }

    #[test]
    fn test_rotation_rejects_out_of_range_qubit() {
        let engine = SimdGateEngine::new();
        let mut state = ket_zero();

        assert!(engine
            .apply_rotation_gate(&mut state, 1, RotationAxis::X, PI)
            .is_err());
    }

    #[test]
    fn test_cnot_gate() {
        let engine = SimdGateEngine::new();

        // Index 2 (0b10): control qubit 1 set, target qubit 0 clear.
        let mut state = vec![
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        engine.apply_cnot(&mut state, 1, 0).unwrap();

        assert!(state[0].norm() < EPS);
        assert!(state[1].norm() < EPS);
        assert!(state[2].norm() < EPS);
        assert!((state[3].norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_cnot_gate_large_state() {
        let engine = SimdGateEngine::new();

        // 4 qubits: index 8 (0b1000) has control qubit 3 set, target 0 clear.
        let mut state = vec![Complex64::new(0.0, 0.0); 16];
        state[8] = Complex64::new(1.0, 0.0);

        engine.apply_cnot(&mut state, 3, 0).unwrap();

        for (index, amplitude) in state.iter().enumerate() {
            let expected = if index == 9 { 1.0 } else { 0.0 };
            assert!(
                (amplitude.norm() - expected).abs() < EPS,
                "unexpected amplitude at index {}",
                index
            );
        }
    }

    #[test]
    fn test_cnot_rejects_invalid_qubits() {
        let engine = SimdGateEngine::new();
        let mut state = vec![Complex64::new(0.0, 0.0); 4];
        state[0] = Complex64::new(1.0, 0.0);

        assert!(engine.apply_cnot(&mut state, 0, 0).is_err());
        assert!(engine.apply_cnot(&mut state, 0, 2).is_err());
        assert!(engine.apply_cnot(&mut state, 2, 0).is_err());
    }

    #[test]
    fn test_fidelity() {
        let engine = SimdGateEngine::new();

        let state1 = ket_zero();
        let state2 = ket_zero();

        let fid = engine.fidelity(&state1, &state2).unwrap();
        assert!((fid - 1.0).abs() < EPS);

        // Orthogonal states have zero overlap.
        let state3 = vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        let fid2 = engine.fidelity(&state1, &state3).unwrap();
        assert!(fid2.abs() < EPS);
    }

    #[test]
    fn test_fidelity_rejects_length_mismatch() {
        let engine = SimdGateEngine::new();
        let short = ket_zero();
        let long = vec![Complex64::new(0.5, 0.0); 4];

        assert!(engine.fidelity(&short, &long).is_err());
    }

    #[test]
    fn test_batch_gates() {
        let engine = SimdGateEngine::new();
        let mut state = ket_zero();

        let gates = vec![(0, RotationAxis::X, FRAC_PI_2)];

        engine.batch_apply_single_qubit(&mut state, &gates).unwrap();

        // Rotations are unitary, so the total probability must stay 1.
        let norm_sqr: f64 = state.iter().map(|c| c.norm_sqr()).sum();
        assert!((norm_sqr - 1.0).abs() < 1e-8, "Norm squared: {}", norm_sqr);
    }
}