use super::{BatchGateOp, BatchStateVector};
use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::{single::*, GateOp},
    qubit::QubitId,
};
use ndarray::{s, Array1, Array2, Array3, Axis};
use num_complex::Complex64;
use rayon::prelude::*;

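/// Applies a 2x2 single-qubit gate (row-major: `[g00, g01, g10, g11]`) to
/// every state in the batch.
///
/// Dispatch strategy: an AVX2 SIMD path for batches larger than 32 when the
/// target feature is enabled at compile time, a rayon-parallel path for
/// batches larger than 16, and a serial loop otherwise.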
pub fn apply_single_qubit_gate_batch(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 4],
    target: QubitId,
) -> QuantRS2Result<()> {
    let n_qubits = batch.n_qubits;
    let target_idx = target.0 as usize;

    if target_idx >= n_qubits {
        return Err(QuantRS2Error::InvalidQubitId(target.0));
    }

    let batch_size = batch.batch_size();

    if batch_size > 32 && cfg!(target_feature = "avx2") {
        apply_single_qubit_batch_simd(batch, gate_matrix, target_idx, n_qubits)?;
    } else if batch_size > 16 {
        batch
            .states
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .try_for_each(|mut state_row| -> QuantRS2Result<()> {
                let mut state = state_row.to_owned();
                apply_single_qubit_to_state_optimized(
                    &mut state,
                    gate_matrix,
                    target_idx,
                    n_qubits,
                )?;
                state_row.assign(&state);
                Ok(())
            })?;
    } else {
        for i in 0..batch_size {
            let mut state = batch.states.row(i).to_owned();
            apply_single_qubit_to_state_optimized(&mut state, gate_matrix, target_idx, n_qubits)?;
            batch.states.row_mut(i).assign(&state);
        }
    }

    Ok(())
}

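/// Applies a 4x4 two-qubit gate (row-major, basis order |00⟩, |01⟩, |10⟩,
/// |11⟩ with the control qubit as the high bit) to every state in the batch,
/// in parallel for batches larger than 16.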
pub fn apply_two_qubit_gate_batch(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 16],
    control: QubitId,
    target: QubitId,
) -> QuantRS2Result<()> {
    let n_qubits = batch.n_qubits;
    let control_idx = control.0 as usize;
    let target_idx = target.0 as usize;

    if control_idx >= n_qubits || target_idx >= n_qubits {
        return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
            control.0
        } else {
            target.0
        }));
    }

    if control_idx == target_idx {
        return Err(QuantRS2Error::InvalidInput(
            "Control and target qubits must be different".to_string(),
        ));
    }

    let batch_size = batch.batch_size();

    if batch_size > 16 {
        batch
            .states
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .try_for_each(|mut state_row| -> QuantRS2Result<()> {
                let mut state = state_row.to_owned();
                apply_two_qubit_to_state(
                    &mut state,
                    gate_matrix,
                    control_idx,
                    target_idx,
                    n_qubits,
                )?;
                state_row.assign(&state);
                Ok(())
            })?;
    } else {
        for i in 0..batch_size {
            let mut state = batch.states.row(i).to_owned();
            apply_two_qubit_to_state(&mut state, gate_matrix, control_idx, target_idx, n_qubits)?;
            batch.states.row_mut(i).assign(&state);
        }
    }

    Ok(())
}

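/// Applies a single-qubit gate to one state vector by pairing each basis
/// index `i` with target bit 0 against its partner `i | target_mask`.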
fn apply_single_qubit_to_state_optimized(
    state: &mut Array1<Complex64>,
    gate_matrix: &[Complex64; 4],
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    let state_size = 1 << n_qubits;
    let target_mask = 1 << target_idx;

    for i in 0..state_size {
        if i & target_mask == 0 {
            let j = i | target_mask;

            let a = state[i];
            let b = state[j];

            state[i] = gate_matrix[0] * a + gate_matrix[1] * b;
            state[j] = gate_matrix[2] * a + gate_matrix[3] * b;
        }
    }

    Ok(())
}

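/// AVX2 variant: processes four states of the batch at a time, broadcasting
/// the real and imaginary parts of each gate coefficient into 256-bit
/// vectors and performing the complex multiply-adds lane-wise.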
#[cfg(target_feature = "avx2")]
fn apply_single_qubit_batch_simd(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 4],
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    use std::arch::x86_64::*;

    let batch_size = batch.batch_size();
    let state_size = 1 << n_qubits;
    let target_mask = 1 << target_idx;

    let g00_re = gate_matrix[0].re;
    let g00_im = gate_matrix[0].im;
    let g01_re = gate_matrix[1].re;
    let g01_im = gate_matrix[1].im;
    let g10_re = gate_matrix[2].re;
    let g10_im = gate_matrix[2].im;
    let g11_re = gate_matrix[3].re;
    let g11_im = gate_matrix[3].im;

    unsafe {
        let g00_re_vec = _mm256_set1_pd(g00_re);
        let g00_im_vec = _mm256_set1_pd(g00_im);
        let g01_re_vec = _mm256_set1_pd(g01_re);
        let g01_im_vec = _mm256_set1_pd(g01_im);
        let g10_re_vec = _mm256_set1_pd(g10_re);
        let g10_im_vec = _mm256_set1_pd(g10_im);
        let g11_re_vec = _mm256_set1_pd(g11_re);
        let g11_im_vec = _mm256_set1_pd(g11_im);

        for batch_start in (0..batch_size).step_by(4) {
            let batch_end = (batch_start + 4).min(batch_size);
            let actual_batch_size = batch_end - batch_start;

            for i in 0..state_size {
                if i & target_mask == 0 {
                    let j = i | target_mask;

                    // Gather the |0⟩ and |1⟩ amplitudes for up to four states.
                    let mut a_re = [0.0; 4];
                    let mut a_im = [0.0; 4];
                    let mut b_re = [0.0; 4];
                    let mut b_im = [0.0; 4];

                    for k in 0..actual_batch_size {
                        a_re[k] = batch.states[[batch_start + k, i]].re;
                        a_im[k] = batch.states[[batch_start + k, i]].im;
                        b_re[k] = batch.states[[batch_start + k, j]].re;
                        b_im[k] = batch.states[[batch_start + k, j]].im;
                    }

                    let a_re_vec = _mm256_loadu_pd(a_re.as_ptr());
                    let a_im_vec = _mm256_loadu_pd(a_im.as_ptr());
                    let b_re_vec = _mm256_loadu_pd(b_re.as_ptr());
                    let b_im_vec = _mm256_loadu_pd(b_im.as_ptr());

                    // new_a = g00 * a + g01 * b (complex arithmetic, lane-wise)
                    let new_a_re = _mm256_add_pd(
                        _mm256_sub_pd(
                            _mm256_mul_pd(g00_re_vec, a_re_vec),
                            _mm256_mul_pd(g00_im_vec, a_im_vec),
                        ),
                        _mm256_sub_pd(
                            _mm256_mul_pd(g01_re_vec, b_re_vec),
                            _mm256_mul_pd(g01_im_vec, b_im_vec),
                        ),
                    );

                    let new_a_im = _mm256_add_pd(
                        _mm256_add_pd(
                            _mm256_mul_pd(g00_re_vec, a_im_vec),
                            _mm256_mul_pd(g00_im_vec, a_re_vec),
                        ),
                        _mm256_add_pd(
                            _mm256_mul_pd(g01_re_vec, b_im_vec),
                            _mm256_mul_pd(g01_im_vec, b_re_vec),
                        ),
                    );

                    // new_b = g10 * a + g11 * b
                    let new_b_re = _mm256_add_pd(
                        _mm256_sub_pd(
                            _mm256_mul_pd(g10_re_vec, a_re_vec),
                            _mm256_mul_pd(g10_im_vec, a_im_vec),
                        ),
                        _mm256_sub_pd(
                            _mm256_mul_pd(g11_re_vec, b_re_vec),
                            _mm256_mul_pd(g11_im_vec, b_im_vec),
                        ),
                    );

                    let new_b_im = _mm256_add_pd(
                        _mm256_add_pd(
                            _mm256_mul_pd(g10_re_vec, a_im_vec),
                            _mm256_mul_pd(g10_im_vec, a_re_vec),
                        ),
                        _mm256_add_pd(
                            _mm256_mul_pd(g11_re_vec, b_im_vec),
                            _mm256_mul_pd(g11_im_vec, b_re_vec),
                        ),
                    );

                    let mut result_a_re = [0.0; 4];
                    let mut result_a_im = [0.0; 4];
                    let mut result_b_re = [0.0; 4];
                    let mut result_b_im = [0.0; 4];

                    _mm256_storeu_pd(result_a_re.as_mut_ptr(), new_a_re);
                    _mm256_storeu_pd(result_a_im.as_mut_ptr(), new_a_im);
                    _mm256_storeu_pd(result_b_re.as_mut_ptr(), new_b_re);
                    _mm256_storeu_pd(result_b_im.as_mut_ptr(), new_b_im);

                    // Scatter the updated amplitudes back into the batch.
                    for k in 0..actual_batch_size {
                        batch.states[[batch_start + k, i]] =
                            Complex64::new(result_a_re[k], result_a_im[k]);
                        batch.states[[batch_start + k, j]] =
                            Complex64::new(result_b_re[k], result_b_im[k]);
                    }
                }
            }
        }
    }

    Ok(())
}

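/// Scalar fallback used when AVX2 is not available at compile time: applies
/// the gate to each state in the batch sequentially.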
#[cfg(not(target_feature = "avx2"))]
fn apply_single_qubit_batch_simd(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 4],
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    let batch_size = batch.batch_size();
    for i in 0..batch_size {
        let mut state = batch.states.row(i).to_owned();
        apply_single_qubit_to_state_optimized(&mut state, gate_matrix, target_idx, n_qubits)?;
        batch.states.row_mut(i).assign(&state);
    }
    Ok(())
}

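/// Applies a two-qubit gate to one state vector. Each iteration visits the
/// basis index with both control and target bits clear and updates the four
/// amplitudes of that |00⟩/|01⟩/|10⟩/|11⟩ subspace together.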
fn apply_two_qubit_to_state(
    state: &mut Array1<Complex64>,
    gate_matrix: &[Complex64; 16],
    control_idx: usize,
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    let state_size = 1 << n_qubits;
    let control_mask = 1 << control_idx;
    let target_mask = 1 << target_idx;

    for i in 0..state_size {
        if (i & control_mask == 0) && (i & target_mask == 0) {
            let i00 = i;
            let i01 = i | target_mask;
            let i10 = i | control_mask;
            let i11 = i | control_mask | target_mask;

            let a00 = state[i00];
            let a01 = state[i01];
            let a10 = state[i10];
            let a11 = state[i11];

            state[i00] = gate_matrix[0] * a00
                + gate_matrix[1] * a01
                + gate_matrix[2] * a10
                + gate_matrix[3] * a11;
            state[i01] = gate_matrix[4] * a00
                + gate_matrix[5] * a01
                + gate_matrix[6] * a10
                + gate_matrix[7] * a11;
            state[i10] = gate_matrix[8] * a00
                + gate_matrix[9] * a01
                + gate_matrix[10] * a10
                + gate_matrix[11] * a11;
            state[i11] = gate_matrix[12] * a00
                + gate_matrix[13] * a01
                + gate_matrix[14] * a10
                + gate_matrix[15] * a11;
        }
    }

    Ok(())
}

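/// Marker type for a dedicated batched Hadamard kernel; currently unused,
/// since [`BatchGateOp`] is implemented directly on the standard gate types
/// below.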
pub struct BatchHadamard;

impl BatchGateOp for Hadamard {
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()> {
        if target_qubits.len() != 1 {
            return Err(QuantRS2Error::InvalidInput(
                "Hadamard gate requires exactly one target qubit".to_string(),
            ));
        }

        let gate_matrix = [
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(-1.0 / std::f64::consts::SQRT_2, 0.0),
        ];

        apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
    }
}

impl BatchGateOp for PauliX {
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()> {
        if target_qubits.len() != 1 {
            return Err(QuantRS2Error::InvalidInput(
                "Pauli-X gate requires exactly one target qubit".to_string(),
            ));
        }

        let gate_matrix = [
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
    }
}

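/// Applies a sequence of gates to every state in the batch, obtaining each
/// gate's matrix via [`GateOp::matrix`] and dispatching on the number of
/// qubits it acts on. Gates on more than two qubits are not yet supported.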
pub fn apply_gate_sequence_batch(
    batch: &mut BatchStateVector,
    gates: &[(Box<dyn GateOp>, Vec<QubitId>)],
) -> QuantRS2Result<()> {
    for (gate, qubits) in gates {
        let matrix = gate.matrix()?;

        match qubits.len() {
            1 => {
                let mut gate_array = [Complex64::new(0.0, 0.0); 4];
                gate_array.copy_from_slice(&matrix[..4]);
                apply_single_qubit_gate_batch(batch, &gate_array, qubits[0])?;
            }
            2 => {
                let mut gate_array = [Complex64::new(0.0, 0.0); 16];
                gate_array.copy_from_slice(&matrix[..16]);
                apply_two_qubit_gate_batch(batch, &gate_array, qubits[0], qubits[1])?;
            }
            _ => {
                return Err(QuantRS2Error::InvalidInput(
                    "Batch operations for gates with more than 2 qubits not yet supported"
                        .to_string(),
                ));
            }
        }
    }

    Ok(())
}

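/// Multiplies each state in the batch by its own matrix from a stacked
/// `(batch, rows, cols)` array, returning a new batch. Parallelized for
/// batches larger than 16.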
pub fn batch_state_matrix_multiply(
    batch: &BatchStateVector,
    matrices: &Array3<Complex64>,
) -> QuantRS2Result<BatchStateVector> {
    let batch_size = batch.batch_size();
    let (num_matrices, rows, cols) = matrices.dim();

    if num_matrices != batch_size {
        return Err(QuantRS2Error::InvalidInput(format!(
            "Number of matrices {} doesn't match batch size {}",
            num_matrices, batch_size
        )));
    }

    if cols != batch.states.ncols() {
        return Err(QuantRS2Error::InvalidInput(format!(
            "Matrix columns {} don't match state size {}",
            cols,
            batch.states.ncols()
        )));
    }

    let mut result_states = Array2::zeros((batch_size, rows));

    if batch_size > 16 {
        let results: Vec<_> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let matrix = matrices.slice(s![i, .., ..]);
                let state = batch.states.row(i);
                matrix.dot(&state)
            })
            .collect();

        for (i, result) in results.into_iter().enumerate() {
            result_states.row_mut(i).assign(&result);
        }
    } else {
        for i in 0..batch_size {
            let matrix = matrices.slice(s![i, .., ..]);
            let state = batch.states.row(i);
            let result = matrix.dot(&state);
            result_states.row_mut(i).assign(&result);
        }
    }

    BatchStateVector::from_states(result_states, batch.config.clone())
}

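/// Computes the expectation value ⟨ψ|O|ψ⟩ for every state in the batch
/// against a single observable matrix, in parallel for batches larger
/// than 16.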
pub fn compute_expectation_values_batch(
    batch: &BatchStateVector,
    observable_matrix: &Array2<Complex64>,
) -> QuantRS2Result<Vec<f64>> {
    let batch_size = batch.batch_size();

    if batch_size > 16 {
        let expectations: Vec<f64> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let state = batch.states.row(i);
                compute_expectation_value(&state.to_owned(), observable_matrix)
            })
            .collect();

        Ok(expectations)
    } else {
        let mut expectations = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let state = batch.states.row(i);
            expectations.push(compute_expectation_value(
                &state.to_owned(),
                observable_matrix,
            ));
        }
        Ok(expectations)
    }
}

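/// Computes the real part of ⟨ψ|O|ψ⟩ for a single state: first O|ψ⟩, then
/// the inner product with the conjugated state.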
fn compute_expectation_value(state: &Array1<Complex64>, observable: &Array2<Complex64>) -> f64 {
    let temp = observable.dot(state);
    let expectation = state
        .iter()
        .zip(temp.iter())
        .map(|(a, b)| a.conj() * b)
        .sum::<Complex64>();

    expectation.re
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_batch_hadamard() {
        let mut batch = BatchStateVector::new(3, 1, Default::default()).unwrap();
        let h = Hadamard { target: QubitId(0) };

        h.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        for i in 0..3 {
            let state = batch.get_state(i).unwrap();
            assert!((state[0].re - 1.0 / std::f64::consts::SQRT_2).abs() < 1e-10);
            assert!((state[1].re - 1.0 / std::f64::consts::SQRT_2).abs() < 1e-10);
        }
    }

    #[test]
    fn test_batch_pauli_x() {
        let mut batch = BatchStateVector::new(2, 1, Default::default()).unwrap();
        let x = PauliX { target: QubitId(0) };

        x.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        for i in 0..2 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[0], Complex64::new(0.0, 0.0));
            assert_eq!(state[1], Complex64::new(1.0, 0.0));
        }
    }

    #[test]
    fn test_expectation_values_batch() {
        let batch = BatchStateVector::new(5, 1, Default::default()).unwrap();

        let z_observable = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
        ];

        let expectations = compute_expectation_values_batch(&batch, &z_observable).unwrap();

        for exp in expectations {
            assert!((exp - 1.0).abs() < 1e-10);
        }
    }
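
    // A minimal two-qubit sketch, assuming (as the tests above do) that
    // BatchStateVector::new(batch_size, n_qubits, config) initializes every
    // state to |0...0>: flip qubit 1 with Pauli-X, then a CNOT with
    // control = qubit 1 and target = qubit 0 should map |10> to |11>.
    #[test]
    fn test_batch_cnot() {
        let mut batch = BatchStateVector::new(2, 2, Default::default()).unwrap();

        let x = PauliX { target: QubitId(1) };
        x.apply_batch(&mut batch, &[QubitId(1)]).unwrap();

        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        // CNOT, row-major in the (|00>, |01>, |10>, |11>) basis with the
        // control as the high bit: identity on the control-0 block, X on
        // the control-1 block.
        let cnot = [
            one, zero, zero, zero, //
            zero, one, zero, zero, //
            zero, zero, zero, one, //
            zero, zero, one, zero,
        ];

        apply_two_qubit_gate_batch(&mut batch, &cnot, QubitId(1), QubitId(0)).unwrap();

        for i in 0..2 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[3], one); // amplitude of |11>
            assert_eq!(state[2], zero); // |10> fully transferred
        }
    }

    // A small sketch of batch_state_matrix_multiply with per-state identity
    // matrices; the output batch should equal the input batch.
    #[test]
    fn test_batch_identity_multiply() {
        let batch = BatchStateVector::new(3, 1, Default::default()).unwrap();

        let mut matrices = Array3::<Complex64>::zeros((3, 2, 2));
        for k in 0..3 {
            matrices[[k, 0, 0]] = Complex64::new(1.0, 0.0);
            matrices[[k, 1, 1]] = Complex64::new(1.0, 0.0);
        }

        let result = batch_state_matrix_multiply(&batch, &matrices).unwrap();
        for i in 0..3 {
            assert_eq!(result.get_state(i).unwrap(), batch.get_state(i).unwrap());
        }
    }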
}