1use super::{BatchGateOp, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 gate::{single::*, GateOp},
7 qubit::QubitId,
8};
9use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
10use scirs2_core::Complex64;
11use crate::parallel_ops_stubs::*;
13use crate::simd_ops_stubs::{SimdComplex64, SimdF64};
15
16pub fn apply_single_qubit_gate_batch(
18 batch: &mut BatchStateVector,
19 gate_matrix: &[Complex64; 4],
20 target: QubitId,
21) -> QuantRS2Result<()> {
22 let n_qubits = batch.n_qubits;
23 let target_idx = target.0 as usize;
24
25 if target_idx >= n_qubits {
26 return Err(QuantRS2Error::InvalidQubitId(target.0));
27 }
28
29 let batch_size = batch.batch_size();
30 if batch_size > 32 {
34 apply_single_qubit_batch_simd(batch, gate_matrix, target_idx, n_qubits)?;
35 } else if batch_size > 16 {
36 batch
38 .states
39 .axis_iter_mut(Axis(0))
40 .into_par_iter()
41 .try_for_each(|mut state_row| -> QuantRS2Result<()> {
42 let mut state = state_row.to_owned();
43 apply_single_qubit_to_state_optimized(
44 &mut state,
45 gate_matrix,
46 target_idx,
47 n_qubits,
48 )?;
49 state_row.assign(&state);
50 Ok(())
51 })?;
52 } else {
53 for i in 0..batch_size {
55 let mut state = batch.states.row(i).to_owned();
56 apply_single_qubit_to_state_optimized(&mut state, gate_matrix, target_idx, n_qubits)?;
57 batch.states.row_mut(i).assign(&state);
58 }
59 }
60
61 Ok(())
62}
63
64pub fn apply_two_qubit_gate_batch(
66 batch: &mut BatchStateVector,
67 gate_matrix: &[Complex64; 16],
68 control: QubitId,
69 target: QubitId,
70) -> QuantRS2Result<()> {
71 let n_qubits = batch.n_qubits;
72 let control_idx = control.0 as usize;
73 let target_idx = target.0 as usize;
74
75 if control_idx >= n_qubits || target_idx >= n_qubits {
76 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
77 control.0
78 } else {
79 target.0
80 }));
81 }
82
83 if control_idx == target_idx {
84 return Err(QuantRS2Error::InvalidInput(
85 "Control and target qubits must be different".to_string(),
86 ));
87 }
88
89 let batch_size = batch.batch_size();
90
91 if batch_size > 16 {
93 batch
94 .states
95 .axis_iter_mut(Axis(0))
96 .into_par_iter()
97 .try_for_each(|mut state_row| -> QuantRS2Result<()> {
98 let mut state = state_row.to_owned();
99 apply_two_qubit_to_state(
100 &mut state,
101 gate_matrix,
102 control_idx,
103 target_idx,
104 n_qubits,
105 )?;
106 state_row.assign(&state);
107 Ok(())
108 })?;
109 } else {
110 for i in 0..batch_size {
112 let mut state = batch.states.row(i).to_owned();
113 apply_two_qubit_to_state(&mut state, gate_matrix, control_idx, target_idx, n_qubits)?;
114 batch.states.row_mut(i).assign(&state);
115 }
116 }
117
118 Ok(())
119}
120
121fn apply_single_qubit_to_state_optimized(
123 state: &mut Array1<Complex64>,
124 gate_matrix: &[Complex64; 4],
125 target_idx: usize,
126 n_qubits: usize,
127) -> QuantRS2Result<()> {
128 let state_size = 1 << n_qubits;
129 let target_mask = 1 << target_idx;
130
131 for i in 0..state_size {
132 if i & target_mask == 0 {
133 let j = i | target_mask;
134
135 let a = state[i];
136 let b = state[j];
137
138 state[i] = gate_matrix[0] * a + gate_matrix[1] * b;
139 state[j] = gate_matrix[2] * a + gate_matrix[3] * b;
140 }
141 }
142
143 Ok(())
144}
145
146fn apply_single_qubit_batch_simd(
148 batch: &mut BatchStateVector,
149 gate_matrix: &[Complex64; 4],
150 target_idx: usize,
151 n_qubits: usize,
152) -> QuantRS2Result<()> {
153 use scirs2_core::ndarray::ArrayView1;
155
156 let batch_size = batch.batch_size();
157 let state_size = 1 << n_qubits;
158 let target_mask = 1 << target_idx;
159
160 let g00 = gate_matrix[0];
162 let g01 = gate_matrix[1];
163 let g10 = gate_matrix[2];
164 let g11 = gate_matrix[3];
165
166 let _pairs_per_batch = state_size / 2;
170 let total_pairs = batch_size * _pairs_per_batch;
171
172 for batch_idx in 0..batch_size {
175 let mut idx_pairs = Vec::new();
177 let mut a_values = Vec::new();
178 let mut b_values = Vec::new();
179
180 for i in 0..state_size {
181 if i & target_mask == 0 {
182 let j = i | target_mask;
183 idx_pairs.push((i, j));
184 a_values.push(batch.states[[batch_idx, i]]);
185 b_values.push(batch.states[[batch_idx, j]]);
186 }
187 }
188
189 if idx_pairs.is_empty() {
190 continue;
191 }
192
193 let _len = a_values.len();
199 let a_real: Vec<f64> = a_values.iter().map(|c| c.re).collect();
200 let a_imag: Vec<f64> = a_values.iter().map(|c| c.im).collect();
201 let b_real: Vec<f64> = b_values.iter().map(|c| c.re).collect();
202 let b_imag: Vec<f64> = b_values.iter().map(|c| c.im).collect();
203
204 let a_real_view = ArrayView1::from(&a_real);
206 let a_imag_view = ArrayView1::from(&a_imag);
207 let b_real_view = ArrayView1::from(&b_real);
208 let b_imag_view = ArrayView1::from(&b_imag);
209
210 let term1 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g00.re);
212 let term2 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g00.im);
213 let term3 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g01.re);
214 let term4 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g01.im);
215
216 let temp1 = <f64 as SimdF64>::simd_sub_arrays(&term1.view(), &term2.view());
217 let temp2 = <f64 as SimdF64>::simd_sub_arrays(&term3.view(), &term4.view());
218 let new_a_real = <f64 as SimdF64>::simd_add_arrays(&temp1.view(), &temp2.view());
219
220 let term5 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g00.re);
222 let term6 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g00.im);
223 let term7 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g01.re);
224 let term8 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g01.im);
225
226 let temp3 = <f64 as SimdF64>::simd_add_arrays(&term5.view(), &term6.view());
227 let temp4 = <f64 as SimdF64>::simd_add_arrays(&term7.view(), &term8.view());
228 let new_a_imag = <f64 as SimdF64>::simd_add_arrays(&temp3.view(), &temp4.view());
229
230 let term9 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g10.re);
232 let term10 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g10.im);
233 let term11 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g11.re);
234 let term12 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g11.im);
235
236 let temp5 = <f64 as SimdF64>::simd_sub_arrays(&term9.view(), &term10.view());
237 let temp6 = <f64 as SimdF64>::simd_sub_arrays(&term11.view(), &term12.view());
238 let new_b_real = <f64 as SimdF64>::simd_add_arrays(&temp5.view(), &temp6.view());
239
240 let term13 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g10.re);
241 let term14 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g10.im);
242 let term15 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g11.re);
243 let term16 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g11.im);
244
245 let temp7 = <f64 as SimdF64>::simd_add_arrays(&term13.view(), &term14.view());
246 let temp8 = <f64 as SimdF64>::simd_add_arrays(&term15.view(), &term16.view());
247 let new_b_imag = <f64 as SimdF64>::simd_add_arrays(&temp7.view(), &temp8.view());
248
249 for (idx, &(i, j)) in idx_pairs.iter().enumerate() {
251 batch.states[[batch_idx, i]] = Complex64::new(new_a_real[idx], new_a_imag[idx]);
252 batch.states[[batch_idx, j]] = Complex64::new(new_b_real[idx], new_b_imag[idx]);
253 }
254 }
255
256 Ok(())
257}
258
259fn apply_two_qubit_to_state(
261 state: &mut Array1<Complex64>,
262 gate_matrix: &[Complex64; 16],
263 control_idx: usize,
264 target_idx: usize,
265 n_qubits: usize,
266) -> QuantRS2Result<()> {
267 let state_size = 1 << n_qubits;
268 let control_mask = 1 << control_idx;
269 let target_mask = 1 << target_idx;
270
271 for i in 0..state_size {
272 if (i & control_mask == 0) && (i & target_mask == 0) {
273 let i00 = i;
274 let i01 = i | target_mask;
275 let i10 = i | control_mask;
276 let i11 = i | control_mask | target_mask;
277
278 let a00 = state[i00];
279 let a01 = state[i01];
280 let a10 = state[i10];
281 let a11 = state[i11];
282
283 state[i00] = gate_matrix[0] * a00
284 + gate_matrix[1] * a01
285 + gate_matrix[2] * a10
286 + gate_matrix[3] * a11;
287 state[i01] = gate_matrix[4] * a00
288 + gate_matrix[5] * a01
289 + gate_matrix[6] * a10
290 + gate_matrix[7] * a11;
291 state[i10] = gate_matrix[8] * a00
292 + gate_matrix[9] * a01
293 + gate_matrix[10] * a10
294 + gate_matrix[11] * a11;
295 state[i11] = gate_matrix[12] * a00
296 + gate_matrix[13] * a01
297 + gate_matrix[14] * a10
298 + gate_matrix[15] * a11;
299 }
300 }
301
302 Ok(())
303}
304
/// Unit marker type for batched Hadamard application.
///
/// NOTE(review): nothing in this file uses `BatchHadamard`. Batched Hadamard application is
/// provided by the `BatchGateOp` implementation for `Hadamard`. Confirm whether other modules
/// depend on this type before removing it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct BatchHadamard;
307
308impl BatchGateOp for Hadamard {
309 fn apply_batch(
310 &self,
311 batch: &mut BatchStateVector,
312 target_qubits: &[QubitId],
313 ) -> QuantRS2Result<()> {
314 if target_qubits.len() != 1 {
315 return Err(QuantRS2Error::InvalidInput(
316 "Hadamard gate requires exactly one target qubit".to_string(),
317 ));
318 }
319
320 let gate_matrix = [
321 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
322 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
323 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
324 Complex64::new(-1.0 / std::f64::consts::SQRT_2, 0.0),
325 ];
326
327 apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
328 }
329}
330
331impl BatchGateOp for PauliX {
333 fn apply_batch(
334 &self,
335 batch: &mut BatchStateVector,
336 target_qubits: &[QubitId],
337 ) -> QuantRS2Result<()> {
338 if target_qubits.len() != 1 {
339 return Err(QuantRS2Error::InvalidInput(
340 "Pauli-X gate requires exactly one target qubit".to_string(),
341 ));
342 }
343
344 let gate_matrix = [
345 Complex64::new(0.0, 0.0),
346 Complex64::new(1.0, 0.0),
347 Complex64::new(1.0, 0.0),
348 Complex64::new(0.0, 0.0),
349 ];
350
351 apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
352 }
353}
354
355pub fn apply_gate_sequence_batch(
357 batch: &mut BatchStateVector,
358 gates: &[(Box<dyn GateOp>, Vec<QubitId>)],
359) -> QuantRS2Result<()> {
360 for (gate, qubits) in gates {
364 {
367 let matrix = gate.matrix()?;
369
370 match qubits.len() {
371 1 => {
372 let mut gate_array = [Complex64::new(0.0, 0.0); 4];
373 gate_array.copy_from_slice(&matrix[..4]);
374 apply_single_qubit_gate_batch(batch, &gate_array, qubits[0])?;
375 }
376 2 => {
377 let mut gate_array = [Complex64::new(0.0, 0.0); 16];
378 gate_array.copy_from_slice(&matrix[..16]);
379 apply_two_qubit_gate_batch(batch, &gate_array, qubits[0], qubits[1])?;
380 }
381 _ => {
382 return Err(QuantRS2Error::InvalidInput(
383 "Batch operations for gates with more than 2 qubits not yet supported"
384 .to_string(),
385 ));
386 }
387 }
388 }
389 }
390
391 Ok(())
392}
393
394pub fn batch_state_matrix_multiply(
397 batch: &BatchStateVector,
398 matrices: &Array3<Complex64>,
399) -> QuantRS2Result<BatchStateVector> {
400 let batch_size = batch.batch_size();
401 let (num_matrices, rows, cols) = matrices.dim();
402
403 if num_matrices != batch_size {
404 return Err(QuantRS2Error::InvalidInput(format!(
405 "Number of matrices {num_matrices} doesn't match batch size {batch_size}"
406 )));
407 }
408
409 if cols != batch.states.ncols() {
410 return Err(QuantRS2Error::InvalidInput(format!(
411 "Matrix columns {} don't match state size {}",
412 cols,
413 batch.states.ncols()
414 )));
415 }
416
417 let mut result_states = Array2::zeros((batch_size, rows));
419
420 if batch_size > 16 {
422 use crate::parallel_ops_stubs::*;
424
425 let results: Vec<_> = (0..batch_size)
426 .into_par_iter()
427 .map(|i| {
428 let matrix = matrices.slice(s![i, .., ..]);
429 let state = batch.states.row(i);
430 matrix.dot(&state)
431 })
432 .collect();
433
434 for (i, result) in results.into_iter().enumerate() {
435 result_states.row_mut(i).assign(&result);
436 }
437 } else {
438 for i in 0..batch_size {
440 let matrix = matrices.slice(s![i, .., ..]);
441 let state = batch.states.row(i);
442 let result = matrix.dot(&state);
443 result_states.row_mut(i).assign(&result);
444 }
445 }
446
447 BatchStateVector::from_states(result_states, batch.config.clone())
448}
449
450pub fn compute_expectation_values_batch(
452 batch: &BatchStateVector,
453 observable_matrix: &Array2<Complex64>,
454) -> QuantRS2Result<Vec<f64>> {
455 let batch_size = batch.batch_size();
456
457 if batch_size > 16 {
459 let expectations: Vec<f64> = (0..batch_size)
460 .into_par_iter()
461 .map(|i| {
462 let state = batch.states.row(i);
463 compute_expectation_value(&state.to_owned(), observable_matrix)
464 })
465 .collect();
466
467 Ok(expectations)
468 } else {
469 let mut expectations = Vec::with_capacity(batch_size);
471 for i in 0..batch_size {
472 let state = batch.states.row(i);
473 expectations.push(compute_expectation_value(
474 &state.to_owned(),
475 observable_matrix,
476 ));
477 }
478 Ok(expectations)
479 }
480}
481
482fn compute_expectation_value(state: &Array1<Complex64>, observable: &Array2<Complex64>) -> f64 {
484 let temp = observable.dot(state);
486 let expectation = state
487 .iter()
488 .zip(temp.iter())
489 .map(|(a, b)| a.conj() * b)
490 .sum::<Complex64>();
491
492 expectation.re
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Tolerance used when comparing floating-point amplitudes.
    const EPS: f64 = 1e-10;

    /// Asserts that two complex amplitudes agree to within `EPS` in both components.
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!(
            (actual.re - expected.re).abs() < EPS && (actual.im - expected.im).abs() < EPS,
            "expected {:?}, got {:?}",
            expected,
            actual
        );
    }

    #[test]
    fn test_batch_hadamard() {
        let mut batch = BatchStateVector::new(3, 1, Default::default())
            .expect("Failed to create batch state vector for Hadamard test");
        let h = Hadamard { target: QubitId(0) };

        h.apply_batch(&mut batch, &[QubitId(0)])
            .expect("Failed to apply Hadamard gate to batch");

        for i in 0..3 {
            let state = batch.get_state(i).expect("Failed to get state from batch");
            assert!((state[0].re - 1.0 / std::f64::consts::SQRT_2).abs() < 1e-10);
            assert!((state[1].re - 1.0 / std::f64::consts::SQRT_2).abs() < 1e-10);
        }
    }

    #[test]
    fn test_batch_pauli_x() {
        let mut batch = BatchStateVector::new(2, 1, Default::default())
            .expect("Failed to create batch state vector for Pauli X test");
        let x = PauliX { target: QubitId(0) };

        x.apply_batch(&mut batch, &[QubitId(0)])
            .expect("Failed to apply Pauli X gate to batch");

        for i in 0..2 {
            let state = batch.get_state(i).expect("Failed to get state from batch");
            assert_eq!(state[0], Complex64::new(0.0, 0.0));
            assert_eq!(state[1], Complex64::new(1.0, 0.0));
        }
    }

    #[test]
    fn test_expectation_values_batch() {
        let batch = BatchStateVector::new(5, 1, Default::default())
            .expect("Failed to create batch state vector for expectation test");

        let z_observable = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
        ];

        let expectations = compute_expectation_values_batch(&batch, &z_observable)
            .expect("Failed to compute expectation values");

        for exp in expectations {
            assert!((exp - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_single_qubit_paths_agree() {
        let h = 1.0 / std::f64::consts::SQRT_2;
        let hadamard = [
            Complex64::new(h, 0.0),
            Complex64::new(h, 0.0),
            Complex64::new(h, 0.0),
            Complex64::new(-h, 0.0),
        ];
        // [[0.6, 0.8i], [0.8i, 0.6]]: unitary with imaginary off-diagonal entries.
        let rotation = [
            Complex64::new(0.6, 0.0),
            Complex64::new(0.0, 0.8),
            Complex64::new(0.0, 0.8),
            Complex64::new(0.6, 0.0),
        ];
        // diag(i, -i): imaginary diagonal entries exercise the g00.im / g11.im terms.
        let phase = [
            Complex64::new(0.0, 1.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, -1.0),
        ];
        // Starting from |00>: H on qubit 1, rotation twice on qubit 0, phase on qubit 1.
        let expected = [
            Complex64::new(0.0, -0.28 * h),
            Complex64::new(-0.96 * h, 0.0),
            Complex64::new(0.0, 0.28 * h),
            Complex64::new(0.96 * h, 0.0),
        ];

        // Batch sizes 1, 20 and 40 select the sequential, parallel and SIMD paths.
        for &size in &[1usize, 20, 40] {
            let mut batch = BatchStateVector::new(size, 2, Default::default())
                .expect("Failed to create batch state vector for path test");
            for (gate, qubit) in [(&hadamard, 1), (&rotation, 0), (&rotation, 0), (&phase, 1)] {
                apply_single_qubit_gate_batch(&mut batch, gate, QubitId(qubit))
                    .expect("Failed to apply single-qubit gate to batch");
            }
            for row in 0..size {
                for (k, &want) in expected.iter().enumerate() {
                    assert_close(batch.states[[row, k]], want);
                }
            }
        }
    }

    #[test]
    fn test_two_qubit_cnot_batch() {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        let pauli_x = [zero, one, one, zero];
        // CNOT over the basis |control target> = |00>, |01>, |10>, |11>.
        let mut cnot = [zero; 16];
        cnot[0] = one;
        cnot[5] = one;
        cnot[11] = one;
        cnot[14] = one;

        // Batch sizes 3 and 20 select the sequential and parallel paths.
        for &size in &[3usize, 20] {
            let mut batch = BatchStateVector::new(size, 2, Default::default())
                .expect("Failed to create batch state vector for CNOT test");
            apply_single_qubit_gate_batch(&mut batch, &pauli_x, QubitId(0))
                .expect("Failed to apply Pauli X gate to batch");
            apply_two_qubit_gate_batch(&mut batch, &cnot, QubitId(0), QubitId(1))
                .expect("Failed to apply CNOT gate to batch");
            for row in 0..size {
                for k in 0..4 {
                    let want = if k == 3 { one } else { zero };
                    assert_close(batch.states[[row, k]], want);
                }
            }
        }
    }

    #[test]
    fn test_invalid_qubits_rejected() {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        let identity = [one, zero, zero, one];
        let mut identity4 = [zero; 16];
        for k in 0..4 {
            identity4[5 * k] = one;
        }
        let mut batch = BatchStateVector::new(2, 2, Default::default())
            .expect("Failed to create batch state vector for validation test");

        assert!(matches!(
            apply_single_qubit_gate_batch(&mut batch, &identity, QubitId(2)),
            Err(QuantRS2Error::InvalidQubitId(2))
        ));
        assert!(matches!(
            apply_two_qubit_gate_batch(&mut batch, &identity4, QubitId(5), QubitId(0)),
            Err(QuantRS2Error::InvalidQubitId(5))
        ));
        assert!(matches!(
            apply_two_qubit_gate_batch(&mut batch, &identity4, QubitId(1), QubitId(1)),
            Err(QuantRS2Error::InvalidInput(_))
        ));

        let h = Hadamard { target: QubitId(0) };
        assert!(matches!(
            h.apply_batch(&mut batch, &[QubitId(0), QubitId(1)]),
            Err(QuantRS2Error::InvalidInput(_))
        ));
    }
}