1use super::{BatchGateOp, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 gate::{single::*, GateOp},
7 qubit::QubitId,
8};
9use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
10use scirs2_core::Complex64;
11use crate::parallel_ops_stubs::*;
13use crate::simd_ops_stubs::{SimdComplex64, SimdF64};
15
16pub fn apply_single_qubit_gate_batch(
18 batch: &mut BatchStateVector,
19 gate_matrix: &[Complex64; 4],
20 target: QubitId,
21) -> QuantRS2Result<()> {
22 let n_qubits = batch.n_qubits;
23 let target_idx = target.0 as usize;
24
25 if target_idx >= n_qubits {
26 return Err(QuantRS2Error::InvalidQubitId(target.0));
27 }
28
29 let batch_size = batch.batch_size();
30 let _state_size = 1 << n_qubits;
31
32 if batch_size > 32 {
34 apply_single_qubit_batch_simd(batch, gate_matrix, target_idx, n_qubits)?;
35 } else if batch_size > 16 {
36 batch
38 .states
39 .axis_iter_mut(Axis(0))
40 .into_par_iter()
41 .try_for_each(|mut state_row| -> QuantRS2Result<()> {
42 let mut state = state_row.to_owned();
43 apply_single_qubit_to_state_optimized(
44 &mut state,
45 gate_matrix,
46 target_idx,
47 n_qubits,
48 )?;
49 state_row.assign(&state);
50 Ok(())
51 })?;
52 } else {
53 for i in 0..batch_size {
55 let mut state = batch.states.row(i).to_owned();
56 apply_single_qubit_to_state_optimized(&mut state, gate_matrix, target_idx, n_qubits)?;
57 batch.states.row_mut(i).assign(&state);
58 }
59 }
60
61 Ok(())
62}
63
64pub fn apply_two_qubit_gate_batch(
66 batch: &mut BatchStateVector,
67 gate_matrix: &[Complex64; 16],
68 control: QubitId,
69 target: QubitId,
70) -> QuantRS2Result<()> {
71 let n_qubits = batch.n_qubits;
72 let control_idx = control.0 as usize;
73 let target_idx = target.0 as usize;
74
75 if control_idx >= n_qubits || target_idx >= n_qubits {
76 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
77 control.0
78 } else {
79 target.0
80 }));
81 }
82
83 if control_idx == target_idx {
84 return Err(QuantRS2Error::InvalidInput(
85 "Control and target qubits must be different".to_string(),
86 ));
87 }
88
89 let batch_size = batch.batch_size();
90
91 if batch_size > 16 {
93 batch
94 .states
95 .axis_iter_mut(Axis(0))
96 .into_par_iter()
97 .try_for_each(|mut state_row| -> QuantRS2Result<()> {
98 let mut state = state_row.to_owned();
99 apply_two_qubit_to_state(
100 &mut state,
101 gate_matrix,
102 control_idx,
103 target_idx,
104 n_qubits,
105 )?;
106 state_row.assign(&state);
107 Ok(())
108 })?;
109 } else {
110 for i in 0..batch_size {
112 let mut state = batch.states.row(i).to_owned();
113 apply_two_qubit_to_state(&mut state, gate_matrix, control_idx, target_idx, n_qubits)?;
114 batch.states.row_mut(i).assign(&state);
115 }
116 }
117
118 Ok(())
119}
120
121fn apply_single_qubit_to_state_optimized(
123 state: &mut Array1<Complex64>,
124 gate_matrix: &[Complex64; 4],
125 target_idx: usize,
126 n_qubits: usize,
127) -> QuantRS2Result<()> {
128 let state_size = 1 << n_qubits;
129 let target_mask = 1 << target_idx;
130
131 for i in 0..state_size {
132 if i & target_mask == 0 {
133 let j = i | target_mask;
134
135 let a = state[i];
136 let b = state[j];
137
138 state[i] = gate_matrix[0] * a + gate_matrix[1] * b;
139 state[j] = gate_matrix[2] * a + gate_matrix[3] * b;
140 }
141 }
142
143 Ok(())
144}
145
146fn apply_single_qubit_batch_simd(
148 batch: &mut BatchStateVector,
149 gate_matrix: &[Complex64; 4],
150 target_idx: usize,
151 n_qubits: usize,
152) -> QuantRS2Result<()> {
153 use scirs2_core::ndarray::ArrayView1;
155
156 let batch_size = batch.batch_size();
157 let state_size = 1 << n_qubits;
158 let target_mask = 1 << target_idx;
159
160 let g00 = gate_matrix[0];
162 let g01 = gate_matrix[1];
163 let g10 = gate_matrix[2];
164 let g11 = gate_matrix[3];
165
166 let _pairs_per_batch = state_size / 2;
170 let _total_pairs = batch_size * _pairs_per_batch;
171
172 for batch_idx in 0..batch_size {
175 let mut idx_pairs = Vec::new();
177 let mut a_values = Vec::new();
178 let mut b_values = Vec::new();
179
180 for i in 0..state_size {
181 if i & target_mask == 0 {
182 let j = i | target_mask;
183 idx_pairs.push((i, j));
184 a_values.push(batch.states[[batch_idx, i]]);
185 b_values.push(batch.states[[batch_idx, j]]);
186 }
187 }
188
189 if idx_pairs.is_empty() {
190 continue;
191 }
192
193 let _len = a_values.len();
199 let a_real: Vec<f64> = a_values.iter().map(|c| c.re).collect();
200 let a_imag: Vec<f64> = a_values.iter().map(|c| c.im).collect();
201 let b_real: Vec<f64> = b_values.iter().map(|c| c.re).collect();
202 let b_imag: Vec<f64> = b_values.iter().map(|c| c.im).collect();
203
204 let a_real_view = ArrayView1::from(&a_real);
206 let a_imag_view = ArrayView1::from(&a_imag);
207 let b_real_view = ArrayView1::from(&b_real);
208 let b_imag_view = ArrayView1::from(&b_imag);
209
210 let term1 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g00.re);
212 let term2 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g00.im);
213 let term3 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g01.re);
214 let term4 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g01.im);
215
216 let temp1 = <f64 as SimdF64>::simd_sub_arrays(&term1.view(), &term2.view());
217 let temp2 = <f64 as SimdF64>::simd_sub_arrays(&term3.view(), &term4.view());
218 let new_a_real = <f64 as SimdF64>::simd_add_arrays(&temp1.view(), &temp2.view());
219
220 let term5 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g00.re);
222 let term6 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g00.im);
223 let term7 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g01.re);
224 let term8 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g01.im);
225
226 let temp3 = <f64 as SimdF64>::simd_add_arrays(&term5.view(), &term6.view());
227 let temp4 = <f64 as SimdF64>::simd_add_arrays(&term7.view(), &term8.view());
228 let new_a_imag = <f64 as SimdF64>::simd_add_arrays(&temp3.view(), &temp4.view());
229
230 let term9 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g10.re);
232 let term10 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g10.im);
233 let term11 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g11.re);
234 let term12 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g11.im);
235
236 let temp5 = <f64 as SimdF64>::simd_sub_arrays(&term9.view(), &term10.view());
237 let temp6 = <f64 as SimdF64>::simd_sub_arrays(&term11.view(), &term12.view());
238 let new_b_real = <f64 as SimdF64>::simd_add_arrays(&temp5.view(), &temp6.view());
239
240 let term13 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g10.re);
241 let term14 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g10.im);
242 let term15 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g11.re);
243 let term16 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g11.im);
244
245 let temp7 = <f64 as SimdF64>::simd_add_arrays(&term13.view(), &term14.view());
246 let temp8 = <f64 as SimdF64>::simd_add_arrays(&term15.view(), &term16.view());
247 let new_b_imag = <f64 as SimdF64>::simd_add_arrays(&temp7.view(), &temp8.view());
248
249 for (idx, &(i, j)) in idx_pairs.iter().enumerate() {
251 batch.states[[batch_idx, i]] = Complex64::new(new_a_real[idx], new_a_imag[idx]);
252 batch.states[[batch_idx, j]] = Complex64::new(new_b_real[idx], new_b_imag[idx]);
253 }
254 }
255
256 Ok(())
257}
258
259fn apply_two_qubit_to_state(
261 state: &mut Array1<Complex64>,
262 gate_matrix: &[Complex64; 16],
263 control_idx: usize,
264 target_idx: usize,
265 n_qubits: usize,
266) -> QuantRS2Result<()> {
267 let state_size = 1 << n_qubits;
268 let control_mask = 1 << control_idx;
269 let target_mask = 1 << target_idx;
270
271 for i in 0..state_size {
272 if (i & control_mask == 0) && (i & target_mask == 0) {
273 let i00 = i;
274 let i01 = i | target_mask;
275 let i10 = i | control_mask;
276 let i11 = i | control_mask | target_mask;
277
278 let a00 = state[i00];
279 let a01 = state[i01];
280 let a10 = state[i10];
281 let a11 = state[i11];
282
283 state[i00] = gate_matrix[0] * a00
284 + gate_matrix[1] * a01
285 + gate_matrix[2] * a10
286 + gate_matrix[3] * a11;
287 state[i01] = gate_matrix[4] * a00
288 + gate_matrix[5] * a01
289 + gate_matrix[6] * a10
290 + gate_matrix[7] * a11;
291 state[i10] = gate_matrix[8] * a00
292 + gate_matrix[9] * a01
293 + gate_matrix[10] * a10
294 + gate_matrix[11] * a11;
295 state[i11] = gate_matrix[12] * a00
296 + gate_matrix[13] * a01
297 + gate_matrix[14] * a10
298 + gate_matrix[15] * a11;
299 }
300 }
301
302 Ok(())
303}
304
/// Unit marker type, presumably intended to represent a batched Hadamard operation (inferred
/// from the name).
///
/// NOTE(review): nothing in this module references `BatchHadamard`. Batched Hadamard gates are
/// applied through the `BatchGateOp` impl for `Hadamard`. Confirm whether this type is used
/// elsewhere before relying on or removing it.
#[derive(Debug, Clone, Copy, Default)]
pub struct BatchHadamard;
307
308impl BatchGateOp for Hadamard {
309 fn apply_batch(
310 &self,
311 batch: &mut BatchStateVector,
312 target_qubits: &[QubitId],
313 ) -> QuantRS2Result<()> {
314 if target_qubits.len() != 1 {
315 return Err(QuantRS2Error::InvalidInput(
316 "Hadamard gate requires exactly one target qubit".to_string(),
317 ));
318 }
319
320 let gate_matrix = [
321 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
322 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
323 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
324 Complex64::new(-1.0 / std::f64::consts::SQRT_2, 0.0),
325 ];
326
327 apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
328 }
329}
330
331impl BatchGateOp for PauliX {
333 fn apply_batch(
334 &self,
335 batch: &mut BatchStateVector,
336 target_qubits: &[QubitId],
337 ) -> QuantRS2Result<()> {
338 if target_qubits.len() != 1 {
339 return Err(QuantRS2Error::InvalidInput(
340 "Pauli-X gate requires exactly one target qubit".to_string(),
341 ));
342 }
343
344 let gate_matrix = [
345 Complex64::new(0.0, 0.0),
346 Complex64::new(1.0, 0.0),
347 Complex64::new(1.0, 0.0),
348 Complex64::new(0.0, 0.0),
349 ];
350
351 apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
352 }
353}
354
355pub fn apply_gate_sequence_batch(
357 batch: &mut BatchStateVector,
358 gates: &[(Box<dyn GateOp>, Vec<QubitId>)],
359) -> QuantRS2Result<()> {
360 for (gate, qubits) in gates {
364 {
367 let matrix = gate.matrix()?;
369
370 match qubits.len() {
371 1 => {
372 let mut gate_array = [Complex64::new(0.0, 0.0); 4];
373 gate_array.copy_from_slice(&matrix[..4]);
374 apply_single_qubit_gate_batch(batch, &gate_array, qubits[0])?;
375 }
376 2 => {
377 let mut gate_array = [Complex64::new(0.0, 0.0); 16];
378 gate_array.copy_from_slice(&matrix[..16]);
379 apply_two_qubit_gate_batch(batch, &gate_array, qubits[0], qubits[1])?;
380 }
381 _ => {
382 return Err(QuantRS2Error::InvalidInput(
383 "Batch operations for gates with more than 2 qubits not yet supported"
384 .to_string(),
385 ));
386 }
387 }
388 }
389 }
390
391 Ok(())
392}
393
394pub fn batch_state_matrix_multiply(
397 batch: &BatchStateVector,
398 matrices: &Array3<Complex64>,
399) -> QuantRS2Result<BatchStateVector> {
400 let batch_size = batch.batch_size();
401 let (num_matrices, rows, cols) = matrices.dim();
402
403 if num_matrices != batch_size {
404 return Err(QuantRS2Error::InvalidInput(format!(
405 "Number of matrices {} doesn't match batch size {}",
406 num_matrices, batch_size
407 )));
408 }
409
410 if cols != batch.states.ncols() {
411 return Err(QuantRS2Error::InvalidInput(format!(
412 "Matrix columns {} don't match state size {}",
413 cols,
414 batch.states.ncols()
415 )));
416 }
417
418 let mut result_states = Array2::zeros((batch_size, rows));
420
421 if batch_size > 16 {
423 use crate::parallel_ops_stubs::*;
425
426 let results: Vec<_> = (0..batch_size)
427 .into_par_iter()
428 .map(|i| {
429 let matrix = matrices.slice(s![i, .., ..]);
430 let state = batch.states.row(i);
431 matrix.dot(&state)
432 })
433 .collect();
434
435 for (i, result) in results.into_iter().enumerate() {
436 result_states.row_mut(i).assign(&result);
437 }
438 } else {
439 for i in 0..batch_size {
441 let matrix = matrices.slice(s![i, .., ..]);
442 let state = batch.states.row(i);
443 let result = matrix.dot(&state);
444 result_states.row_mut(i).assign(&result);
445 }
446 }
447
448 BatchStateVector::from_states(result_states, batch.config.clone())
449}
450
451pub fn compute_expectation_values_batch(
453 batch: &BatchStateVector,
454 observable_matrix: &Array2<Complex64>,
455) -> QuantRS2Result<Vec<f64>> {
456 let batch_size = batch.batch_size();
457
458 if batch_size > 16 {
460 let expectations: Vec<f64> = (0..batch_size)
461 .into_par_iter()
462 .map(|i| {
463 let state = batch.states.row(i);
464 compute_expectation_value(&state.to_owned(), observable_matrix)
465 })
466 .collect();
467
468 Ok(expectations)
469 } else {
470 let mut expectations = Vec::with_capacity(batch_size);
472 for i in 0..batch_size {
473 let state = batch.states.row(i);
474 expectations.push(compute_expectation_value(
475 &state.to_owned(),
476 observable_matrix,
477 ));
478 }
479 Ok(expectations)
480 }
481}
482
483fn compute_expectation_value(state: &Array1<Complex64>, observable: &Array2<Complex64>) -> f64 {
485 let temp = observable.dot(state);
487 let expectation = state
488 .iter()
489 .zip(temp.iter())
490 .map(|(a, b)| a.conj() * b)
491 .sum::<Complex64>();
492
493 expectation.re
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// 1/√2, computed the same way as in the Hadamard gate matrix.
    fn inv_sqrt2() -> f64 {
        1.0 / std::f64::consts::SQRT_2
    }

    /// Pauli-Z observable diag(1, −1).
    fn pauli_z() -> Array2<Complex64> {
        array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
        ]
    }

    /// Row-major CNOT over the basis |control, target⟩ = |00⟩, |01⟩, |10⟩, |11⟩.
    fn cnot_matrix() -> [Complex64; 16] {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        [
            one, zero, zero, zero, //
            zero, one, zero, zero, //
            zero, zero, zero, one, //
            zero, zero, one, zero,
        ]
    }

    #[test]
    fn test_batch_hadamard() {
        let mut batch = BatchStateVector::new(3, 1, Default::default()).unwrap();
        let h = Hadamard { target: QubitId(0) };

        h.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        for i in 0..3 {
            let state = batch.get_state(i).unwrap();
            assert!((state[0].re - inv_sqrt2()).abs() < 1e-10);
            assert!((state[1].re - inv_sqrt2()).abs() < 1e-10);
        }
    }

    #[test]
    fn test_batch_hadamard_parallel_path() {
        // 20 states exercise the `batch_size > 16` parallel branch.
        let mut batch = BatchStateVector::new(20, 1, Default::default()).unwrap();
        let h = Hadamard { target: QubitId(0) };

        h.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        for i in 0..20 {
            let state = batch.get_state(i).unwrap();
            assert!((state[0].re - inv_sqrt2()).abs() < 1e-10);
            assert!((state[1].re - inv_sqrt2()).abs() < 1e-10);
        }
    }

    #[test]
    fn test_simd_path_matches_sequential_path() {
        // A gate with real and imaginary entries, applied to a non-zero target qubit, so every
        // term of the SIMD expansion contributes. The sizes cover the sequential (1) and SIMD
        // (40 > 32) branches.
        let s = inv_sqrt2();
        let gate = [
            Complex64::new(0.6, 0.0),
            Complex64::new(0.0, 0.8),
            Complex64::new(0.0, 0.8),
            Complex64::new(0.6, 0.0),
        ];
        // H on qubit 0 gives [s, s, 0, 0]. The gate on qubit 1 then maps pairs (0, 2) and (1, 3).
        let expected = [
            Complex64::new(0.6 * s, 0.0),
            Complex64::new(0.6 * s, 0.0),
            Complex64::new(0.0, 0.8 * s),
            Complex64::new(0.0, 0.8 * s),
        ];

        for &size in &[1usize, 40] {
            let mut batch = BatchStateVector::new(size, 2, Default::default()).unwrap();
            let h = Hadamard { target: QubitId(0) };
            h.apply_batch(&mut batch, &[QubitId(0)]).unwrap();
            apply_single_qubit_gate_batch(&mut batch, &gate, QubitId(1)).unwrap();

            for i in 0..size {
                let state = batch.get_state(i).unwrap();
                for (k, want) in expected.iter().enumerate() {
                    assert!((state[k].re - want.re).abs() < 1e-10);
                    assert!((state[k].im - want.im).abs() < 1e-10);
                }
            }
        }
    }

    #[test]
    fn test_batch_pauli_x() {
        let mut batch = BatchStateVector::new(2, 1, Default::default()).unwrap();
        let x = PauliX { target: QubitId(0) };

        x.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        for i in 0..2 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[0], Complex64::new(0.0, 0.0));
            assert_eq!(state[1], Complex64::new(1.0, 0.0));
        }
    }

    #[test]
    fn test_two_qubit_cnot_batch() {
        // Sizes cover the sequential (<= 16) and parallel (> 16) branches.
        for &size in &[2usize, 20] {
            let mut batch = BatchStateVector::new(size, 2, Default::default()).unwrap();
            let x = PauliX { target: QubitId(0) };

            // Setting qubit 0 moves the amplitude to index 1; CNOT(0 -> 1) then sets qubit 1,
            // giving index 3.
            x.apply_batch(&mut batch, &[QubitId(0)]).unwrap();
            apply_two_qubit_gate_batch(&mut batch, &cnot_matrix(), QubitId(0), QubitId(1))
                .unwrap();

            for i in 0..size {
                let state = batch.get_state(i).unwrap();
                for k in 0..4 {
                    let want = if k == 3 { 1.0 } else { 0.0 };
                    assert!((state[k].re - want).abs() < 1e-10);
                    assert!(state[k].im.abs() < 1e-10);
                }
            }
        }
    }

    #[test]
    fn test_invalid_arguments_are_rejected() {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        let x_matrix = [zero, one, one, zero];

        let mut single = BatchStateVector::new(2, 1, Default::default()).unwrap();
        let x = PauliX { target: QubitId(0) };
        assert!(apply_single_qubit_gate_batch(&mut single, &x_matrix, QubitId(1)).is_err());
        assert!(x.apply_batch(&mut single, &[]).is_err());
        assert!(x.apply_batch(&mut single, &[QubitId(0), QubitId(0)]).is_err());

        let mut pair = BatchStateVector::new(2, 2, Default::default()).unwrap();
        assert!(
            apply_two_qubit_gate_batch(&mut pair, &cnot_matrix(), QubitId(0), QubitId(0))
                .is_err()
        );
        assert!(
            apply_two_qubit_gate_batch(&mut pair, &cnot_matrix(), QubitId(2), QubitId(0))
                .is_err()
        );
    }

    #[test]
    fn test_batch_state_matrix_multiply() {
        let batch = BatchStateVector::new(3, 1, Default::default()).unwrap();
        // Every matrix is Pauli-X, mapping |0⟩ to |1⟩.
        let matrices = Array3::from_shape_fn((3, 2, 2), |(_, r, c)| {
            if r != c {
                Complex64::new(1.0, 0.0)
            } else {
                Complex64::new(0.0, 0.0)
            }
        });

        let result = batch_state_matrix_multiply(&batch, &matrices).unwrap();
        for i in 0..3 {
            let state = result.get_state(i).unwrap();
            assert!(state[0].re.abs() < 1e-10);
            assert!((state[1].re - 1.0).abs() < 1e-10);
        }

        let too_few = Array3::from_shape_fn((2, 2, 2), |_| Complex64::new(0.0, 0.0));
        assert!(batch_state_matrix_multiply(&batch, &too_few).is_err());
    }

    #[test]
    fn test_expectation_values_batch() {
        let batch = BatchStateVector::new(5, 1, Default::default()).unwrap();

        let expectations = compute_expectation_values_batch(&batch, &pauli_z()).unwrap();

        for exp in expectations {
            assert!((exp - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_expectation_values_batch_parallel_path() {
        // 20 states flipped to |1⟩ exercise the parallel branch; ⟨Z⟩ = −1 for each.
        let mut batch = BatchStateVector::new(20, 1, Default::default()).unwrap();
        let x = PauliX { target: QubitId(0) };
        x.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        let expectations = compute_expectation_values_batch(&batch, &pauli_z()).unwrap();

        assert_eq!(expectations.len(), 20);
        for exp in expectations {
            assert!((exp + 1.0).abs() < 1e-10);
        }
    }
}