quantrs2_core/batch/operations.rs

//! Batch operations for quantum gates using SciRS2 parallel algorithms

use super::{BatchGateOp, BatchStateVector};
use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::{single::*, GateOp},
    qubit::QubitId,
};
use ndarray::{s, Array1, Array2, Array3, Axis};
use num_complex::Complex64;
use rayon::prelude::*;

/// Apply a single-qubit gate to all states in a batch
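///
/// `gate_matrix` holds the 2x2 unitary in row-major order
/// `[m00, m01, m10, m11]`, matching the update performed by
/// `apply_single_qubit_to_state_optimized`.
///
/// # Example
///
/// A minimal usage sketch, marked `ignore` because the `BatchStateVector`
/// constructor signature is assumed from the tests at the bottom of this
/// file:
///
/// ```ignore
/// // Eight single-qubit states, all initialized to |0>.
/// let mut batch = BatchStateVector::new(8, 1, Default::default())?;
/// // Pauli-X in row-major order.
/// let x = [
///     Complex64::new(0.0, 0.0),
///     Complex64::new(1.0, 0.0),
///     Complex64::new(1.0, 0.0),
///     Complex64::new(0.0, 0.0),
/// ];
/// apply_single_qubit_gate_batch(&mut batch, &x, QubitId(0))?;
/// ```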
pub fn apply_single_qubit_gate_batch(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 4],
    target: QubitId,
) -> QuantRS2Result<()> {
    let n_qubits = batch.n_qubits;
    let target_idx = target.0 as usize;

    if target_idx >= n_qubits {
        return Err(QuantRS2Error::InvalidQubitId(target.0));
    }

    let batch_size = batch.batch_size();

    // Use optimized SIMD batch processing for large batches.
    // `cfg!` is resolved at compile time, so without AVX2 this branch is
    // compiled out and batches fall through to the rayon/sequential paths.
    if batch_size > 32 && cfg!(target_feature = "avx2") {
        apply_single_qubit_batch_simd(batch, gate_matrix, target_idx, n_qubits)?;
    } else if batch_size > 16 {
        // Use parallel processing for medium batches
        batch
            .states
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .try_for_each(|mut state_row| -> QuantRS2Result<()> {
                let mut state = state_row.to_owned();
                apply_single_qubit_to_state_optimized(
                    &mut state,
                    gate_matrix,
                    target_idx,
                    n_qubits,
                )?;
                state_row.assign(&state);
                Ok(())
            })?;
    } else {
        // Sequential for small batches
        for i in 0..batch_size {
            let mut state = batch.states.row(i).to_owned();
            apply_single_qubit_to_state_optimized(&mut state, gate_matrix, target_idx, n_qubits)?;
            batch.states.row_mut(i).assign(&state);
        }
    }

    Ok(())
}

/// Apply a two-qubit gate to all states in a batch
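///
/// `gate_matrix` is the 4x4 unitary in row-major order over the
/// `|control, target>` basis (|00>, |01>, |10>, |11>). For example, a CNOT
/// in this layout has its ones at flat indices 0, 5, 11, and 14; see the
/// `test_batch_cnot` sketch in the tests below.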
pub fn apply_two_qubit_gate_batch(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 16],
    control: QubitId,
    target: QubitId,
) -> QuantRS2Result<()> {
    let n_qubits = batch.n_qubits;
    let control_idx = control.0 as usize;
    let target_idx = target.0 as usize;

    if control_idx >= n_qubits || target_idx >= n_qubits {
        return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
            control.0
        } else {
            target.0
        }));
    }

    if control_idx == target_idx {
        return Err(QuantRS2Error::InvalidInput(
            "Control and target qubits must be different".to_string(),
        ));
    }

    let batch_size = batch.batch_size();

    // Use parallel processing for large batches
    if batch_size > 16 {
        batch
            .states
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .try_for_each(|mut state_row| -> QuantRS2Result<()> {
                let mut state = state_row.to_owned();
                apply_two_qubit_to_state(
                    &mut state,
                    gate_matrix,
                    control_idx,
                    target_idx,
                    n_qubits,
                )?;
                state_row.assign(&state);
                Ok(())
            })?;
    } else {
        // Sequential for small batches
        for i in 0..batch_size {
            let mut state = batch.states.row(i).to_owned();
            apply_two_qubit_to_state(&mut state, gate_matrix, control_idx, target_idx, n_qubits)?;
            batch.states.row_mut(i).assign(&state);
        }
    }

    Ok(())
}

/// Apply a single-qubit gate to a state vector (optimized version)
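///
/// Iterates over all basis indices with the target bit clear; each such
/// index `i` is paired with `j = i | (1 << target_idx)`, and the amplitude
/// pair is updated (using the old values of both) as
///
/// ```text
/// state[i] <- m00 * state[i] + m01 * state[j]
/// state[j] <- m10 * state[i] + m11 * state[j]
/// ```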
fn apply_single_qubit_to_state_optimized(
    state: &mut Array1<Complex64>,
    gate_matrix: &[Complex64; 4],
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    let state_size = 1 << n_qubits;
    let target_mask = 1 << target_idx;

    for i in 0..state_size {
        if i & target_mask == 0 {
            let j = i | target_mask;

            let a = state[i];
            let b = state[j];

            state[i] = gate_matrix[0] * a + gate_matrix[1] * b;
            state[j] = gate_matrix[2] * a + gate_matrix[3] * b;
        }
    }

    Ok(())
}

/// SIMD-optimized batch single-qubit gate application
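///
/// Amplitudes from four batch members are packed into one AVX2 lane, and each
/// complex product is expanded into real arithmetic via
///
/// ```text
/// (c + d*i) * (a + b*i) = (c*a - d*b) + (c*b + d*a)*i
/// ```
///
/// which is what the `_mm256_sub_pd`/`_mm256_add_pd` pairs below compute.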
#[cfg(target_feature = "avx2")]
fn apply_single_qubit_batch_simd(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 4],
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    use std::arch::x86_64::*;

    let batch_size = batch.batch_size();
    let state_size = 1 << n_qubits;
    let target_mask = 1 << target_idx;

    // Extract gate matrix components for SIMD
    let g00_re = gate_matrix[0].re;
    let g00_im = gate_matrix[0].im;
    let g01_re = gate_matrix[1].re;
    let g01_im = gate_matrix[1].im;
    let g10_re = gate_matrix[2].re;
    let g10_im = gate_matrix[2].im;
    let g11_re = gate_matrix[3].re;
    let g11_im = gate_matrix[3].im;

    unsafe {
        // Broadcast gate matrix elements to SIMD registers
        let g00_re_vec = _mm256_set1_pd(g00_re);
        let g00_im_vec = _mm256_set1_pd(g00_im);
        let g01_re_vec = _mm256_set1_pd(g01_re);
        let g01_im_vec = _mm256_set1_pd(g01_im);
        let g10_re_vec = _mm256_set1_pd(g10_re);
        let g10_im_vec = _mm256_set1_pd(g10_im);
        let g11_re_vec = _mm256_set1_pd(g11_re);
        let g11_im_vec = _mm256_set1_pd(g11_im);

        // Process batches in groups of 4 (AVX2 width for double precision)
        for batch_start in (0..batch_size).step_by(4) {
            let batch_end = (batch_start + 4).min(batch_size);
            let actual_batch_size = batch_end - batch_start;

            for i in 0..state_size {
                if i & target_mask == 0 {
                    let j = i | target_mask;

                    // Gather amplitudes for up to four batch elements; the
                    // unused lanes of a partial group stay zero and are
                    // discarded again on store.
                    let mut a_re = [0.0; 4];
                    let mut a_im = [0.0; 4];
                    let mut b_re = [0.0; 4];
                    let mut b_im = [0.0; 4];

                    for k in 0..actual_batch_size {
                        a_re[k] = batch.states[[batch_start + k, i]].re;
                        a_im[k] = batch.states[[batch_start + k, i]].im;
                        b_re[k] = batch.states[[batch_start + k, j]].re;
                        b_im[k] = batch.states[[batch_start + k, j]].im;
                    }

                    // Load into SIMD registers
                    let a_re_vec = _mm256_loadu_pd(a_re.as_ptr());
                    let a_im_vec = _mm256_loadu_pd(a_im.as_ptr());
                    let b_re_vec = _mm256_loadu_pd(b_re.as_ptr());
                    let b_im_vec = _mm256_loadu_pd(b_im.as_ptr());

                    // Compute new_a = g00 * a + g01 * b
                    let new_a_re = _mm256_add_pd(
                        _mm256_sub_pd(
                            _mm256_mul_pd(g00_re_vec, a_re_vec),
                            _mm256_mul_pd(g00_im_vec, a_im_vec),
                        ),
                        _mm256_sub_pd(
                            _mm256_mul_pd(g01_re_vec, b_re_vec),
                            _mm256_mul_pd(g01_im_vec, b_im_vec),
                        ),
                    );

                    let new_a_im = _mm256_add_pd(
                        _mm256_add_pd(
                            _mm256_mul_pd(g00_re_vec, a_im_vec),
                            _mm256_mul_pd(g00_im_vec, a_re_vec),
                        ),
                        _mm256_add_pd(
                            _mm256_mul_pd(g01_re_vec, b_im_vec),
                            _mm256_mul_pd(g01_im_vec, b_re_vec),
                        ),
                    );

                    // Compute new_b = g10 * a + g11 * b
                    let new_b_re = _mm256_add_pd(
                        _mm256_sub_pd(
                            _mm256_mul_pd(g10_re_vec, a_re_vec),
                            _mm256_mul_pd(g10_im_vec, a_im_vec),
                        ),
                        _mm256_sub_pd(
                            _mm256_mul_pd(g11_re_vec, b_re_vec),
                            _mm256_mul_pd(g11_im_vec, b_im_vec),
                        ),
                    );

                    let new_b_im = _mm256_add_pd(
                        _mm256_add_pd(
                            _mm256_mul_pd(g10_re_vec, a_im_vec),
                            _mm256_mul_pd(g10_im_vec, a_re_vec),
                        ),
                        _mm256_add_pd(
                            _mm256_mul_pd(g11_re_vec, b_im_vec),
                            _mm256_mul_pd(g11_im_vec, b_re_vec),
                        ),
                    );

                    // Store results back
                    let mut result_a_re = [0.0; 4];
                    let mut result_a_im = [0.0; 4];
                    let mut result_b_re = [0.0; 4];
                    let mut result_b_im = [0.0; 4];

                    _mm256_storeu_pd(result_a_re.as_mut_ptr(), new_a_re);
                    _mm256_storeu_pd(result_a_im.as_mut_ptr(), new_a_im);
                    _mm256_storeu_pd(result_b_re.as_mut_ptr(), new_b_re);
                    _mm256_storeu_pd(result_b_im.as_mut_ptr(), new_b_im);

                    for k in 0..actual_batch_size {
                        batch.states[[batch_start + k, i]] =
                            Complex64::new(result_a_re[k], result_a_im[k]);
                        batch.states[[batch_start + k, j]] =
                            Complex64::new(result_b_re[k], result_b_im[k]);
                    }
                }
            }
        }
    }

    Ok(())
}

/// Fallback for non-AVX2 targets
#[cfg(not(target_feature = "avx2"))]
fn apply_single_qubit_batch_simd(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 4],
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    // Fall back to optimized sequential processing
    let batch_size = batch.batch_size();
    for i in 0..batch_size {
        let mut state = batch.states.row(i).to_owned();
        apply_single_qubit_to_state_optimized(&mut state, gate_matrix, target_idx, n_qubits)?;
        batch.states.row_mut(i).assign(&state);
    }
    Ok(())
}

/// Apply a two-qubit gate to a state vector
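///
/// Basis indices are grouped into quadruples that agree on all other qubits:
/// `i00` has both bits clear, `i01` sets the target bit, `i10` sets the
/// control bit, and `i11` sets both. The 4x4 matrix then acts on the
/// amplitude vector `(a00, a01, a10, a11)` in row-major order.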
fn apply_two_qubit_to_state(
    state: &mut Array1<Complex64>,
    gate_matrix: &[Complex64; 16],
    control_idx: usize,
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    let state_size = 1 << n_qubits;
    let control_mask = 1 << control_idx;
    let target_mask = 1 << target_idx;

    for i in 0..state_size {
        if (i & control_mask == 0) && (i & target_mask == 0) {
            let i00 = i;
            let i01 = i | target_mask;
            let i10 = i | control_mask;
            let i11 = i | control_mask | target_mask;

            let a00 = state[i00];
            let a01 = state[i01];
            let a10 = state[i10];
            let a11 = state[i11];

            state[i00] = gate_matrix[0] * a00
                + gate_matrix[1] * a01
                + gate_matrix[2] * a10
                + gate_matrix[3] * a11;
            state[i01] = gate_matrix[4] * a00
                + gate_matrix[5] * a01
                + gate_matrix[6] * a10
                + gate_matrix[7] * a11;
            state[i10] = gate_matrix[8] * a00
                + gate_matrix[9] * a01
                + gate_matrix[10] * a10
                + gate_matrix[11] * a11;
            state[i11] = gate_matrix[12] * a00
                + gate_matrix[13] * a01
                + gate_matrix[14] * a10
                + gate_matrix[15] * a11;
        }
    }

    Ok(())
}

/// Batch-optimized Hadamard gate
impl BatchGateOp for Hadamard {
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()> {
        if target_qubits.len() != 1 {
            return Err(QuantRS2Error::InvalidInput(
                "Hadamard gate requires exactly one target qubit".to_string(),
            ));
        }

        let gate_matrix = [
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(-1.0 / std::f64::consts::SQRT_2, 0.0),
        ];

        apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
    }
}

/// Batch-optimized Pauli-X gate
impl BatchGateOp for PauliX {
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()> {
        if target_qubits.len() != 1 {
            return Err(QuantRS2Error::InvalidInput(
                "Pauli-X gate requires exactly one target qubit".to_string(),
            ));
        }

        let gate_matrix = [
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
    }
}

/// Apply multiple gates to a batch using SciRS2 batch operations
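///
/// # Example
///
/// A minimal sketch, marked `ignore` since it assumes gate types from
/// `gate::single` (such as the `Hadamard` used elsewhere in this file):
///
/// ```ignore
/// let mut batch = BatchStateVector::new(8, 2, Default::default())?;
/// let gates: Vec<(Box<dyn GateOp>, Vec<QubitId>)> = vec![
///     (Box::new(Hadamard { target: QubitId(0) }), vec![QubitId(0)]),
/// ];
/// apply_gate_sequence_batch(&mut batch, &gates)?;
/// ```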
pub fn apply_gate_sequence_batch(
    batch: &mut BatchStateVector,
    gates: &[(Box<dyn GateOp>, Vec<QubitId>)],
) -> QuantRS2Result<()> {
    // For now, every gate goes through the generic matrix path.
    // TODO: Detect gates that implement `BatchGateOp` and dispatch to their
    // batch-optimized implementations instead.
    for (gate, qubits) in gates {
        let matrix = gate.matrix()?;

        match qubits.len() {
            1 => {
                let mut gate_array = [Complex64::new(0.0, 0.0); 4];
                gate_array.copy_from_slice(&matrix[..4]);
                apply_single_qubit_gate_batch(batch, &gate_array, qubits[0])?;
            }
            2 => {
                let mut gate_array = [Complex64::new(0.0, 0.0); 16];
                gate_array.copy_from_slice(&matrix[..16]);
                apply_two_qubit_gate_batch(batch, &gate_array, qubits[0], qubits[1])?;
            }
            _ => {
                return Err(QuantRS2Error::InvalidInput(
                    "Batch operations for gates with more than 2 qubits not yet supported"
                        .to_string(),
                ));
            }
        }
    }

    Ok(())
}

/// Batch matrix multiplication
///
/// Note: SciRS2's `batch_matmul` doesn't support complex numbers, so we
/// implement our own.
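///
/// # Example
///
/// A minimal sketch, marked `ignore` because the constructor signature is
/// assumed from the tests below: one 2x2 identity per 1-qubit state leaves
/// the batch unchanged.
///
/// ```ignore
/// let batch = BatchStateVector::new(4, 1, Default::default())?;
/// let mut matrices = Array3::zeros((4, 2, 2));
/// for i in 0..4 {
///     matrices[[i, 0, 0]] = Complex64::new(1.0, 0.0);
///     matrices[[i, 1, 1]] = Complex64::new(1.0, 0.0);
/// }
/// let result = batch_state_matrix_multiply(&batch, &matrices)?;
/// ```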
pub fn batch_state_matrix_multiply(
    batch: &BatchStateVector,
    matrices: &Array3<Complex64>,
) -> QuantRS2Result<BatchStateVector> {
    let batch_size = batch.batch_size();
    let (num_matrices, rows, cols) = matrices.dim();

    if num_matrices != batch_size {
        return Err(QuantRS2Error::InvalidInput(format!(
            "Number of matrices {} doesn't match batch size {}",
            num_matrices, batch_size
        )));
    }

    if cols != batch.states.ncols() {
        return Err(QuantRS2Error::InvalidInput(format!(
            "Matrix columns {} don't match state size {}",
            cols,
            batch.states.ncols()
        )));
    }

    // Perform batch matrix multiplication manually
    let mut result_states = Array2::zeros((batch_size, rows));

    // Use parallel processing for large batches
    if batch_size > 16 {
        let results: Vec<_> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let matrix = matrices.slice(s![i, .., ..]);
                let state = batch.states.row(i);
                matrix.dot(&state)
            })
            .collect();

        for (i, result) in results.into_iter().enumerate() {
            result_states.row_mut(i).assign(&result);
        }
    } else {
        // Sequential for small batches
        for i in 0..batch_size {
            let matrix = matrices.slice(s![i, .., ..]);
            let state = batch.states.row(i);
            let result = matrix.dot(&state);
            result_states.row_mut(i).assign(&result);
        }
    }

    BatchStateVector::from_states(result_states, batch.config.clone())
}

/// Parallel expectation value computation
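///
/// For each state `psi_i` in the batch this returns `Re <psi_i|O|psi_i>`;
/// taking the real part is justified when `observable_matrix` is Hermitian.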
pub fn compute_expectation_values_batch(
    batch: &BatchStateVector,
    observable_matrix: &Array2<Complex64>,
) -> QuantRS2Result<Vec<f64>> {
    let batch_size = batch.batch_size();

    // Use parallel computation for large batches
    if batch_size > 16 {
        let expectations: Vec<f64> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let state = batch.states.row(i);
                compute_expectation_value(&state.to_owned(), observable_matrix)
            })
            .collect();

        Ok(expectations)
    } else {
        // Sequential for small batches
        let mut expectations = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let state = batch.states.row(i);
            expectations.push(compute_expectation_value(
                &state.to_owned(),
                observable_matrix,
            ));
        }
        Ok(expectations)
    }
}

/// Compute expectation value for a single state
fn compute_expectation_value(state: &Array1<Complex64>, observable: &Array2<Complex64>) -> f64 {
    // <ψ|O|ψ> = Σ_i conj(ψ_i) * (Oψ)_i
    let temp = observable.dot(state);
    let expectation = state
        .iter()
        .zip(temp.iter())
        .map(|(a, b)| a.conj() * b)
        .sum::<Complex64>();

    expectation.re
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_batch_hadamard() {
        let mut batch = BatchStateVector::new(3, 1, Default::default()).unwrap();
        let h = Hadamard { target: QubitId(0) };

        h.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        // Check all states are in superposition
        for i in 0..3 {
            let state = batch.get_state(i).unwrap();
            assert!((state[0].re - 1.0 / std::f64::consts::SQRT_2).abs() < 1e-10);
            assert!((state[1].re - 1.0 / std::f64::consts::SQRT_2).abs() < 1e-10);
        }
    }

    #[test]
    fn test_batch_pauli_x() {
        let mut batch = BatchStateVector::new(2, 1, Default::default()).unwrap();
        let x = PauliX { target: QubitId(0) };

        x.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        // Check all states are flipped
        for i in 0..2 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[0], Complex64::new(0.0, 0.0));
            assert_eq!(state[1], Complex64::new(1.0, 0.0));
        }
    }
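
    #[test]
    fn test_batch_cnot() {
        // A hedged sketch exercising the two-qubit path; it assumes the same
        // `BatchStateVector::new`/`get_state` API used by the tests above.
        let mut batch = BatchStateVector::new(2, 2, Default::default()).unwrap();

        // Set the control: |00> -> |01> (qubit 0 is bit 0 of the index).
        let x = PauliX { target: QubitId(0) };
        x.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        // CNOT with control = qubit 0, target = qubit 1, laid out row-major
        // over the (|00>, |01>, |10>, |11>) = |control,target> basis.
        let one = Complex64::new(1.0, 0.0);
        let mut cnot = [Complex64::new(0.0, 0.0); 16];
        cnot[0] = one; // |00> -> |00>
        cnot[5] = one; // |01> -> |01>
        cnot[11] = one; // |10> -> |11>
        cnot[14] = one; // |11> -> |10>

        apply_two_qubit_gate_batch(&mut batch, &cnot, QubitId(0), QubitId(1)).unwrap();

        for i in 0..2 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[3], one); // both qubits flipped to |11>
        }
    }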

    #[test]
    fn test_expectation_values_batch() {
        let batch = BatchStateVector::new(5, 1, Default::default()).unwrap();

        // Pauli Z observable
        let z_observable = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
        ];

        let expectations = compute_expectation_values_batch(&batch, &z_observable).unwrap();

        // All states are |0>, so expectation of Z should be 1
        for exp in expectations {
            assert!((exp - 1.0).abs() < 1e-10);
        }
    }
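
    #[test]
    fn test_batch_state_matrix_multiply() {
        // A hedged sketch: apply a per-state Pauli-X matrix via the manual
        // batch matmul; every |0> state should come out as |1>. Assumes the
        // result exposes the same `get_state` accessor used above.
        let batch = BatchStateVector::new(3, 1, Default::default()).unwrap();

        let mut matrices = Array3::zeros((3, 2, 2));
        for i in 0..3 {
            matrices[[i, 0, 1]] = Complex64::new(1.0, 0.0);
            matrices[[i, 1, 0]] = Complex64::new(1.0, 0.0);
        }

        let result = batch_state_matrix_multiply(&batch, &matrices).unwrap();
        for i in 0..3 {
            let state = result.get_state(i).unwrap();
            assert_eq!(state[0], Complex64::new(0.0, 0.0));
            assert_eq!(state[1], Complex64::new(1.0, 0.0));
        }
    }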
}