quantrs2_core/batch/operations.rs

//! Batch operations for quantum gates using SciRS2 parallel algorithms

use super::{BatchGateOp, BatchStateVector};
use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::{single::*, GateOp},
    qubit::QubitId,
};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
use scirs2_core::Complex64;
// use scirs2_core::parallel_ops::*;
use crate::parallel_ops_stubs::*;
// use scirs2_core::simd_ops::SimdUnifiedOps;
use crate::simd_ops_stubs::SimdF64;

/// Apply a single-qubit gate to all states in a batch
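///
/// A minimal usage sketch (doctest-ignored; assumes the
/// `BatchStateVector::new(batch_size, n_qubits, config)` constructor used in
/// the tests below):
///
/// ```ignore
/// // Apply a Pauli-Z to qubit 0 of every state in a 10-state batch.
/// let mut batch = BatchStateVector::new(10, 2, Default::default())?;
/// let z = [
///     Complex64::new(1.0, 0.0),
///     Complex64::new(0.0, 0.0),
///     Complex64::new(0.0, 0.0),
///     Complex64::new(-1.0, 0.0),
/// ];
/// apply_single_qubit_gate_batch(&mut batch, &z, QubitId(0))?;
/// ```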
pub fn apply_single_qubit_gate_batch(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 4],
    target: QubitId,
) -> QuantRS2Result<()> {
    let n_qubits = batch.n_qubits;
    let target_idx = target.0 as usize;

    if target_idx >= n_qubits {
        return Err(QuantRS2Error::InvalidQubitId(target.0));
    }

    let batch_size = batch.batch_size();

    // Use optimized SIMD batch processing for large batches
    if batch_size > 32 {
        apply_single_qubit_batch_simd(batch, gate_matrix, target_idx, n_qubits)?;
    } else if batch_size > 16 {
        // Use parallel processing for medium batches
        batch
            .states
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .try_for_each(|mut state_row| -> QuantRS2Result<()> {
                let mut state = state_row.to_owned();
                apply_single_qubit_to_state_optimized(
                    &mut state,
                    gate_matrix,
                    target_idx,
                    n_qubits,
                )?;
                state_row.assign(&state);
                Ok(())
            })?;
    } else {
        // Sequential for small batches
        for i in 0..batch_size {
            let mut state = batch.states.row(i).to_owned();
            apply_single_qubit_to_state_optimized(&mut state, gate_matrix, target_idx, n_qubits)?;
            batch.states.row_mut(i).assign(&state);
        }
    }

    Ok(())
}

/// Apply a two-qubit gate to all states in a batch
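///
/// Sketch (doctest-ignored): applying a CNOT, using the row-major
/// |control, target> basis ordering assumed by `apply_two_qubit_to_state`
/// below:
///
/// ```ignore
/// let zero = Complex64::new(0.0, 0.0);
/// let one = Complex64::new(1.0, 0.0);
/// let cnot = [
///     one, zero, zero, zero,
///     zero, one, zero, zero,
///     zero, zero, zero, one,
///     zero, zero, one, zero, // swaps the |10> and |11> amplitudes
/// ];
/// apply_two_qubit_gate_batch(&mut batch, &cnot, QubitId(1), QubitId(0))?;
/// ```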
pub fn apply_two_qubit_gate_batch(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 16],
    control: QubitId,
    target: QubitId,
) -> QuantRS2Result<()> {
    let n_qubits = batch.n_qubits;
    let control_idx = control.0 as usize;
    let target_idx = target.0 as usize;

    if control_idx >= n_qubits || target_idx >= n_qubits {
        return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
            control.0
        } else {
            target.0
        }));
    }

    if control_idx == target_idx {
        return Err(QuantRS2Error::InvalidInput(
            "Control and target qubits must be different".to_string(),
        ));
    }

    let batch_size = batch.batch_size();

    // Use parallel processing for large batches
    if batch_size > 16 {
        batch
            .states
            .axis_iter_mut(Axis(0))
            .into_par_iter()
            .try_for_each(|mut state_row| -> QuantRS2Result<()> {
                let mut state = state_row.to_owned();
                apply_two_qubit_to_state(
                    &mut state,
                    gate_matrix,
                    control_idx,
                    target_idx,
                    n_qubits,
                )?;
                state_row.assign(&state);
                Ok(())
            })?;
    } else {
        // Sequential for small batches
        for i in 0..batch_size {
            let mut state = batch.states.row(i).to_owned();
            apply_two_qubit_to_state(&mut state, gate_matrix, control_idx, target_idx, n_qubits)?;
            batch.states.row_mut(i).assign(&state);
        }
    }

    Ok(())
}

/// Apply a single-qubit gate to a state vector (optimized version)
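///
/// Amplitude indices are processed in pairs `(i, i | target_mask)`; e.g. for
/// `target_idx = 1` in a 3-qubit register the pairs are (0, 2), (1, 3),
/// (4, 6), and (5, 7).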
fn apply_single_qubit_to_state_optimized(
    state: &mut Array1<Complex64>,
    gate_matrix: &[Complex64; 4],
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    let state_size = 1 << n_qubits;
    let target_mask = 1 << target_idx;

    for i in 0..state_size {
        if i & target_mask == 0 {
            let j = i | target_mask;

            let a = state[i];
            let b = state[j];

            state[i] = gate_matrix[0] * a + gate_matrix[1] * b;
            state[j] = gate_matrix[2] * a + gate_matrix[3] * b;
        }
    }

    Ok(())
}

/// SIMD-optimized batch single-qubit gate application
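///
/// The amplitudes are split into structure-of-arrays form (separate `re`/`im`
/// vectors) so the complex update `new_a = g00 * a + g01 * b` can be computed
/// with real-valued SIMD primitives:
///
/// ```text
/// new_a.re = g00.re * a.re - g00.im * a.im + g01.re * b.re - g01.im * b.im
/// new_a.im = g00.re * a.im + g00.im * a.re + g01.re * b.im + g01.im * b.re
/// ```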
fn apply_single_qubit_batch_simd(
    batch: &mut BatchStateVector,
    gate_matrix: &[Complex64; 4],
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    use scirs2_core::ndarray::ArrayView1;

    let batch_size = batch.batch_size();
    let state_size = 1 << n_qubits;
    let target_mask = 1 << target_idx;

    // Extract gate matrix components
    let g00 = gate_matrix[0];
    let g01 = gate_matrix[1];
    let g10 = gate_matrix[2];
    let g11 = gate_matrix[3];

    // Process each batch item individually, vectorizing the amplitude-pair
    // updates within each item using SIMD operations.
    for batch_idx in 0..batch_size {
        // Collect the (i, j) index pairs and their current amplitudes
        let mut idx_pairs = Vec::new();
        let mut a_values = Vec::new();
        let mut b_values = Vec::new();

        for i in 0..state_size {
            if i & target_mask == 0 {
                let j = i | target_mask;
                idx_pairs.push((i, j));
                a_values.push(batch.states[[batch_idx, i]]);
                b_values.push(batch.states[[batch_idx, j]]);
            }
        }

        if idx_pairs.is_empty() {
            continue;
        }

        // Apply the gate transformation using SIMD:
        //   new_a = g00 * a + g01 * b
        //   new_b = g10 * a + g11 * b

        // Split amplitudes into real and imaginary parts
        let a_real: Vec<f64> = a_values.iter().map(|c| c.re).collect();
        let a_imag: Vec<f64> = a_values.iter().map(|c| c.im).collect();
        let b_real: Vec<f64> = b_values.iter().map(|c| c.re).collect();
        let b_imag: Vec<f64> = b_values.iter().map(|c| c.im).collect();

        let a_real_view = ArrayView1::from(&a_real);
        let a_imag_view = ArrayView1::from(&a_imag);
        let b_real_view = ArrayView1::from(&b_real);
        let b_imag_view = ArrayView1::from(&b_imag);

        // new_a_real = g00.re * a.re - g00.im * a.im + g01.re * b.re - g01.im * b.im
        let term1 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g00.re);
        let term2 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g00.im);
        let term3 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g01.re);
        let term4 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g01.im);

        let temp1 = <f64 as SimdF64>::simd_sub_arrays(&term1.view(), &term2.view());
        let temp2 = <f64 as SimdF64>::simd_sub_arrays(&term3.view(), &term4.view());
        let new_a_real = <f64 as SimdF64>::simd_add_arrays(&temp1.view(), &temp2.view());

        // new_a_imag = g00.re * a.im + g00.im * a.re + g01.re * b.im + g01.im * b.re
        let term5 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g00.re);
        let term6 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g00.im);
        let term7 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g01.re);
        let term8 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g01.im);

        let temp3 = <f64 as SimdF64>::simd_add_arrays(&term5.view(), &term6.view());
        let temp4 = <f64 as SimdF64>::simd_add_arrays(&term7.view(), &term8.view());
        let new_a_imag = <f64 as SimdF64>::simd_add_arrays(&temp3.view(), &temp4.view());

        // Compute new_b the same way with g10/g11
        let term9 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g10.re);
        let term10 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g10.im);
        let term11 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g11.re);
        let term12 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g11.im);

        let temp5 = <f64 as SimdF64>::simd_sub_arrays(&term9.view(), &term10.view());
        let temp6 = <f64 as SimdF64>::simd_sub_arrays(&term11.view(), &term12.view());
        let new_b_real = <f64 as SimdF64>::simd_add_arrays(&temp5.view(), &temp6.view());

        let term13 = <f64 as SimdF64>::simd_scalar_mul(&a_imag_view, g10.re);
        let term14 = <f64 as SimdF64>::simd_scalar_mul(&a_real_view, g10.im);
        let term15 = <f64 as SimdF64>::simd_scalar_mul(&b_imag_view, g11.re);
        let term16 = <f64 as SimdF64>::simd_scalar_mul(&b_real_view, g11.im);

        let temp7 = <f64 as SimdF64>::simd_add_arrays(&term13.view(), &term14.view());
        let temp8 = <f64 as SimdF64>::simd_add_arrays(&term15.view(), &term16.view());
        let new_b_imag = <f64 as SimdF64>::simd_add_arrays(&temp7.view(), &temp8.view());

        // Write back results
        for (idx, &(i, j)) in idx_pairs.iter().enumerate() {
            batch.states[[batch_idx, i]] = Complex64::new(new_a_real[idx], new_a_imag[idx]);
            batch.states[[batch_idx, j]] = Complex64::new(new_b_real[idx], new_b_imag[idx]);
        }
    }

    Ok(())
}

/// Apply a two-qubit gate to a state vector
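///
/// Basis indices follow the bit-mask convention: for each index `i` with both
/// the control and target bits clear, the four amplitudes at `i00 = i`,
/// `i01 = i | target_mask`, `i10 = i | control_mask`, and
/// `i11 = i | control_mask | target_mask` are mixed by the row-major 4x4 matrix.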
fn apply_two_qubit_to_state(
    state: &mut Array1<Complex64>,
    gate_matrix: &[Complex64; 16],
    control_idx: usize,
    target_idx: usize,
    n_qubits: usize,
) -> QuantRS2Result<()> {
    let state_size = 1 << n_qubits;
    let control_mask = 1 << control_idx;
    let target_mask = 1 << target_idx;

    for i in 0..state_size {
        if (i & control_mask == 0) && (i & target_mask == 0) {
            let i00 = i;
            let i01 = i | target_mask;
            let i10 = i | control_mask;
            let i11 = i | control_mask | target_mask;

            let a00 = state[i00];
            let a01 = state[i01];
            let a10 = state[i10];
            let a11 = state[i11];

            state[i00] = gate_matrix[0] * a00
                + gate_matrix[1] * a01
                + gate_matrix[2] * a10
                + gate_matrix[3] * a11;
            state[i01] = gate_matrix[4] * a00
                + gate_matrix[5] * a01
                + gate_matrix[6] * a10
                + gate_matrix[7] * a11;
            state[i10] = gate_matrix[8] * a00
                + gate_matrix[9] * a01
                + gate_matrix[10] * a10
                + gate_matrix[11] * a11;
            state[i11] = gate_matrix[12] * a00
                + gate_matrix[13] * a01
                + gate_matrix[14] * a10
                + gate_matrix[15] * a11;
        }
    }

    Ok(())
}

/// Batch-optimized Hadamard gate
impl BatchGateOp for Hadamard {
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()> {
        if target_qubits.len() != 1 {
            return Err(QuantRS2Error::InvalidInput(
                "Hadamard gate requires exactly one target qubit".to_string(),
            ));
        }

        let gate_matrix = [
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
            Complex64::new(-1.0 / std::f64::consts::SQRT_2, 0.0),
        ];

        apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
    }
}

/// Batch-optimized Pauli-X gate
impl BatchGateOp for PauliX {
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()> {
        if target_qubits.len() != 1 {
            return Err(QuantRS2Error::InvalidInput(
                "Pauli-X gate requires exactly one target qubit".to_string(),
            ));
        }

        let gate_matrix = [
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
        ];

        apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
    }
}

/// Apply multiple gates to a batch using SciRS2 batch operations
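///
/// Sketch (doctest-ignored; uses the `Hadamard`/`PauliX` constructors shown in
/// the tests below):
///
/// ```ignore
/// let gates: Vec<(Box<dyn GateOp>, Vec<QubitId>)> = vec![
///     (Box::new(Hadamard { target: QubitId(0) }), vec![QubitId(0)]),
///     (Box::new(PauliX { target: QubitId(1) }), vec![QubitId(1)]),
/// ];
/// apply_gate_sequence_batch(&mut batch, &gates)?;
/// ```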
pub fn apply_gate_sequence_batch(
    batch: &mut BatchStateVector,
    gates: &[(Box<dyn GateOp>, Vec<QubitId>)],
) -> QuantRS2Result<()> {
    // For now, every gate goes through the standard matrix-based path.
    // TODO: detect gates with batch-optimized implementations and dispatch to them.
    for (gate, qubits) in gates {
        let matrix = gate.matrix()?;

        match qubits.len() {
            1 => {
                let mut gate_array = [Complex64::new(0.0, 0.0); 4];
                gate_array.copy_from_slice(&matrix[..4]);
                apply_single_qubit_gate_batch(batch, &gate_array, qubits[0])?;
            }
            2 => {
                let mut gate_array = [Complex64::new(0.0, 0.0); 16];
                gate_array.copy_from_slice(&matrix[..16]);
                apply_two_qubit_gate_batch(batch, &gate_array, qubits[0], qubits[1])?;
            }
            _ => {
                return Err(QuantRS2Error::InvalidInput(
                    "Batch operations for gates with more than 2 qubits not yet supported"
                        .to_string(),
                ));
            }
        }
    }

    Ok(())
}

/// Batch matrix multiplication
/// Note: SciRS2 `batch_matmul` doesn't support complex numbers, so we implement our own
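///
/// Sketch (doctest-ignored): multiplying each state in the batch by its own
/// matrix, here a 2x2 identity per 1-qubit state:
///
/// ```ignore
/// // Shape is (batch_size, rows, cols).
/// let mut matrices = Array3::<Complex64>::zeros((batch.batch_size(), 2, 2));
/// for mut m in matrices.axis_iter_mut(Axis(0)) {
///     m[[0, 0]] = Complex64::new(1.0, 0.0);
///     m[[1, 1]] = Complex64::new(1.0, 0.0);
/// }
/// let result = batch_state_matrix_multiply(&batch, &matrices)?;
/// ```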
pub fn batch_state_matrix_multiply(
    batch: &BatchStateVector,
    matrices: &Array3<Complex64>,
) -> QuantRS2Result<BatchStateVector> {
    let batch_size = batch.batch_size();
    let (num_matrices, rows, cols) = matrices.dim();

    if num_matrices != batch_size {
        return Err(QuantRS2Error::InvalidInput(format!(
            "Number of matrices {} doesn't match batch size {}",
            num_matrices, batch_size
        )));
    }

    if cols != batch.states.ncols() {
        return Err(QuantRS2Error::InvalidInput(format!(
            "Matrix columns {} don't match state size {}",
            cols,
            batch.states.ncols()
        )));
    }

    // Perform batch matrix multiplication manually
    let mut result_states = Array2::zeros((batch_size, rows));

    // Use parallel processing for large batches
    if batch_size > 16 {
        let results: Vec<_> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let matrix = matrices.slice(s![i, .., ..]);
                let state = batch.states.row(i);
                matrix.dot(&state)
            })
            .collect();

        for (i, result) in results.into_iter().enumerate() {
            result_states.row_mut(i).assign(&result);
        }
    } else {
        // Sequential for small batches
        for i in 0..batch_size {
            let matrix = matrices.slice(s![i, .., ..]);
            let state = batch.states.row(i);
            let result = matrix.dot(&state);
            result_states.row_mut(i).assign(&result);
        }
    }

    BatchStateVector::from_states(result_states, batch.config.clone())
}

/// Parallel expectation value computation
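///
/// Returns `Re(<psi_i|O|psi_i>)` for each state `i` in the batch. Sketch
/// (doctest-ignored):
///
/// ```ignore
/// // Pauli-Z expectation for every state in the batch.
/// let mut z = Array2::<Complex64>::zeros((2, 2));
/// z[[0, 0]] = Complex64::new(1.0, 0.0);
/// z[[1, 1]] = Complex64::new(-1.0, 0.0);
/// let expectations = compute_expectation_values_batch(&batch, &z)?;
/// ```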
pub fn compute_expectation_values_batch(
    batch: &BatchStateVector,
    observable_matrix: &Array2<Complex64>,
) -> QuantRS2Result<Vec<f64>> {
    let batch_size = batch.batch_size();

    // Use parallel computation for large batches
    if batch_size > 16 {
        let expectations: Vec<f64> = (0..batch_size)
            .into_par_iter()
            .map(|i| {
                let state = batch.states.row(i);
                compute_expectation_value(&state.to_owned(), observable_matrix)
            })
            .collect();

        Ok(expectations)
    } else {
        // Sequential for small batches
        let mut expectations = Vec::with_capacity(batch_size);
        for i in 0..batch_size {
            let state = batch.states.row(i);
            expectations.push(compute_expectation_value(
                &state.to_owned(),
                observable_matrix,
            ));
        }
        Ok(expectations)
    }
}

/// Compute expectation value for a single state
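///
/// Computes `<psi|O|psi> = sum_k conj(psi_k) * (O psi)_k` and returns its real
/// part, which is the full value for Hermitian observables.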
fn compute_expectation_value(state: &Array1<Complex64>, observable: &Array2<Complex64>) -> f64 {
    // <ψ|O|ψ>
    let temp = observable.dot(state);
    let expectation = state
        .iter()
        .zip(temp.iter())
        .map(|(a, b)| a.conj() * b)
        .sum::<Complex64>();

    expectation.re
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_batch_hadamard() {
        let mut batch = BatchStateVector::new(3, 1, Default::default()).unwrap();
        let h = Hadamard { target: QubitId(0) };

        h.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        // Check all states are in superposition
        for i in 0..3 {
            let state = batch.get_state(i).unwrap();
            assert!((state[0].re - 1.0 / std::f64::consts::SQRT_2).abs() < 1e-10);
            assert!((state[1].re - 1.0 / std::f64::consts::SQRT_2).abs() < 1e-10);
        }
    }

    #[test]
    fn test_batch_pauli_x() {
        let mut batch = BatchStateVector::new(2, 1, Default::default()).unwrap();
        let x = PauliX { target: QubitId(0) };

        x.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        // Check all states are flipped
        for i in 0..2 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[0], Complex64::new(0.0, 0.0));
            assert_eq!(state[1], Complex64::new(1.0, 0.0));
        }
    }
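
    /// Illustrative check of the two-qubit path; assumes the same
    /// `BatchStateVector::new(batch_size, n_qubits, config)` constructor and
    /// row-major |control, target> matrix ordering used above.
    #[test]
    fn test_batch_cnot() {
        let mut batch = BatchStateVector::new(2, 2, Default::default()).unwrap();

        // Flip the control qubit so every state becomes |10> (index 2)
        let x = PauliX { target: QubitId(1) };
        x.apply_batch(&mut batch, &[QubitId(1)]).unwrap();

        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        let cnot = [
            one, zero, zero, zero,
            zero, one, zero, zero,
            zero, zero, zero, one,
            zero, zero, one, zero, // swaps the |10> and |11> amplitudes
        ];
        apply_two_qubit_gate_batch(&mut batch, &cnot, QubitId(1), QubitId(0)).unwrap();

        // With the control set, the target is flipped: |11> (index 3)
        for i in 0..2 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[3], one);
        }
    }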

    #[test]
    fn test_expectation_values_batch() {
        let batch = BatchStateVector::new(5, 1, Default::default()).unwrap();

        // Pauli Z observable
        let z_observable = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
        ];

        let expectations = compute_expectation_values_batch(&batch, &z_observable).unwrap();

        // All states are |0>, so the expectation of Z should be 1
        for exp in expectations {
            assert!((exp - 1.0).abs() < 1e-10);
        }
    }
}