quantrs2_core/batch/
operations.rs

1//! Batch operations for quantum gates using SciRS2 parallel algorithms
2
3use super::{BatchGateOp, BatchStateVector};
4use crate::{
5    error::{QuantRS2Error, QuantRS2Result},
6    gate::{single::*, GateOp},
7    qubit::QubitId,
8};
9use ndarray::{s, Array1, Array2, Array3, ArrayView2, Axis};
10use num_complex::Complex64;
11use rayon::prelude::*;
12use std::sync::Arc;
13
14// Import SciRS2 batch operations
15// Note: SciRS2 batch operations don't support Complex numbers yet
16// extern crate scirs2_linalg;
17// use scirs2_linalg::batch::{batch_matmul, batch_matvec};
18
19/// Apply a single-qubit gate to all states in a batch
20pub fn apply_single_qubit_gate_batch(
21    batch: &mut BatchStateVector,
22    gate_matrix: &[Complex64; 4],
23    target: QubitId,
24) -> QuantRS2Result<()> {
25    let n_qubits = batch.n_qubits;
26    let target_idx = target.0 as usize;
27
28    if target_idx >= n_qubits {
29        return Err(QuantRS2Error::InvalidQubitId(target.0));
30    }
31
32    let batch_size = batch.batch_size();
33    let state_size = 1 << n_qubits;
34
35    // Use parallel processing for large batches
36    if batch_size > 32 {
37        batch
38            .states
39            .axis_iter_mut(Axis(0))
40            .into_par_iter()
41            .try_for_each(|mut state_row| -> QuantRS2Result<()> {
42                apply_single_qubit_to_state(
43                    &mut state_row.to_owned(),
44                    gate_matrix,
45                    target_idx,
46                    n_qubits,
47                )?;
48                Ok(())
49            })?;
50    } else {
51        // Sequential for small batches
52        for i in 0..batch_size {
53            let mut state = batch.states.row(i).to_owned();
54            apply_single_qubit_to_state(&mut state, gate_matrix, target_idx, n_qubits)?;
55            batch.states.row_mut(i).assign(&state);
56        }
57    }
58
59    Ok(())
60}
61
62/// Apply a two-qubit gate to all states in a batch
63pub fn apply_two_qubit_gate_batch(
64    batch: &mut BatchStateVector,
65    gate_matrix: &[Complex64; 16],
66    control: QubitId,
67    target: QubitId,
68) -> QuantRS2Result<()> {
69    let n_qubits = batch.n_qubits;
70    let control_idx = control.0 as usize;
71    let target_idx = target.0 as usize;
72
73    if control_idx >= n_qubits || target_idx >= n_qubits {
74        return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
75            control.0
76        } else {
77            target.0
78        }));
79    }
80
81    if control_idx == target_idx {
82        return Err(QuantRS2Error::InvalidInput(
83            "Control and target qubits must be different".to_string(),
84        ));
85    }
86
87    let batch_size = batch.batch_size();
88
89    // Use parallel processing for large batches
90    if batch_size > 16 {
91        batch
92            .states
93            .axis_iter_mut(Axis(0))
94            .into_par_iter()
95            .try_for_each(|mut state_row| -> QuantRS2Result<()> {
96                apply_two_qubit_to_state(
97                    &mut state_row.to_owned(),
98                    gate_matrix,
99                    control_idx,
100                    target_idx,
101                    n_qubits,
102                )?;
103                Ok(())
104            })?;
105    } else {
106        // Sequential for small batches
107        for i in 0..batch_size {
108            let mut state = batch.states.row(i).to_owned();
109            apply_two_qubit_to_state(&mut state, gate_matrix, control_idx, target_idx, n_qubits)?;
110            batch.states.row_mut(i).assign(&state);
111        }
112    }
113
114    Ok(())
115}
116
117/// Apply a single-qubit gate to a state vector
118fn apply_single_qubit_to_state(
119    state: &mut Array1<Complex64>,
120    gate_matrix: &[Complex64; 4],
121    target_idx: usize,
122    n_qubits: usize,
123) -> QuantRS2Result<()> {
124    let state_size = 1 << n_qubits;
125    let target_mask = 1 << target_idx;
126
127    for i in 0..state_size {
128        if i & target_mask == 0 {
129            let j = i | target_mask;
130
131            let a = state[i];
132            let b = state[j];
133
134            state[i] = gate_matrix[0] * a + gate_matrix[1] * b;
135            state[j] = gate_matrix[2] * a + gate_matrix[3] * b;
136        }
137    }
138
139    Ok(())
140}
141
142/// Apply a two-qubit gate to a state vector
143fn apply_two_qubit_to_state(
144    state: &mut Array1<Complex64>,
145    gate_matrix: &[Complex64; 16],
146    control_idx: usize,
147    target_idx: usize,
148    n_qubits: usize,
149) -> QuantRS2Result<()> {
150    let state_size = 1 << n_qubits;
151    let control_mask = 1 << control_idx;
152    let target_mask = 1 << target_idx;
153
154    for i in 0..state_size {
155        if (i & control_mask == 0) && (i & target_mask == 0) {
156            let i00 = i;
157            let i01 = i | target_mask;
158            let i10 = i | control_mask;
159            let i11 = i | control_mask | target_mask;
160
161            let a00 = state[i00];
162            let a01 = state[i01];
163            let a10 = state[i10];
164            let a11 = state[i11];
165
166            state[i00] = gate_matrix[0] * a00
167                + gate_matrix[1] * a01
168                + gate_matrix[2] * a10
169                + gate_matrix[3] * a11;
170            state[i01] = gate_matrix[4] * a00
171                + gate_matrix[5] * a01
172                + gate_matrix[6] * a10
173                + gate_matrix[7] * a11;
174            state[i10] = gate_matrix[8] * a00
175                + gate_matrix[9] * a01
176                + gate_matrix[10] * a10
177                + gate_matrix[11] * a11;
178            state[i11] = gate_matrix[12] * a00
179                + gate_matrix[13] * a01
180                + gate_matrix[14] * a10
181                + gate_matrix[15] * a11;
182        }
183    }
184
185    Ok(())
186}
187
/// Marker type for batched Hadamard application.
///
/// Not referenced elsewhere in this module: batched Hadamard support is
/// provided by the `BatchGateOp` implementation on [`Hadamard`], which uses
/// this module's own row-wise kernel rather than SciRS2.
#[derive(Debug, Clone, Copy, Default)]
pub struct BatchHadamard;
190
191impl BatchGateOp for Hadamard {
192    fn apply_batch(
193        &self,
194        batch: &mut BatchStateVector,
195        target_qubits: &[QubitId],
196    ) -> QuantRS2Result<()> {
197        if target_qubits.len() != 1 {
198            return Err(QuantRS2Error::InvalidInput(
199                "Hadamard gate requires exactly one target qubit".to_string(),
200            ));
201        }
202
203        let gate_matrix = [
204            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
205            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
206            Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
207            Complex64::new(-1.0 / std::f64::consts::SQRT_2, 0.0),
208        ];
209
210        apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
211    }
212}
213
214/// Batch-optimized Pauli-X gate
215impl BatchGateOp for PauliX {
216    fn apply_batch(
217        &self,
218        batch: &mut BatchStateVector,
219        target_qubits: &[QubitId],
220    ) -> QuantRS2Result<()> {
221        if target_qubits.len() != 1 {
222            return Err(QuantRS2Error::InvalidInput(
223                "Pauli-X gate requires exactly one target qubit".to_string(),
224            ));
225        }
226
227        let gate_matrix = [
228            Complex64::new(0.0, 0.0),
229            Complex64::new(1.0, 0.0),
230            Complex64::new(1.0, 0.0),
231            Complex64::new(0.0, 0.0),
232        ];
233
234        apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
235    }
236}
237
238/// Apply multiple gates to a batch using SciRS2 batch operations
239pub fn apply_gate_sequence_batch(
240    batch: &mut BatchStateVector,
241    gates: &[(Box<dyn GateOp>, Vec<QubitId>)],
242) -> QuantRS2Result<()> {
243    // For gates that support batch operations, use them
244    // Otherwise fall back to standard application
245
246    for (gate, qubits) in gates {
247        // For now, always use standard application
248        // TODO: Add batch-optimized gate detection
249        {
250            // Fall back to standard application
251            let matrix = gate.matrix()?;
252
253            match qubits.len() {
254                1 => {
255                    let mut gate_array = [Complex64::new(0.0, 0.0); 4];
256                    gate_array.copy_from_slice(&matrix[..4]);
257                    apply_single_qubit_gate_batch(batch, &gate_array, qubits[0])?;
258                }
259                2 => {
260                    let mut gate_array = [Complex64::new(0.0, 0.0); 16];
261                    gate_array.copy_from_slice(&matrix[..16]);
262                    apply_two_qubit_gate_batch(batch, &gate_array, qubits[0], qubits[1])?;
263                }
264                _ => {
265                    return Err(QuantRS2Error::InvalidInput(
266                        "Batch operations for gates with more than 2 qubits not yet supported"
267                            .to_string(),
268                    ));
269                }
270            }
271        }
272    }
273
274    Ok(())
275}
276
277/// Batch matrix multiplication
278/// Note: SciRS2 batch_matmul doesn't support Complex numbers, so we implement our own
279pub fn batch_state_matrix_multiply(
280    batch: &BatchStateVector,
281    matrices: &Array3<Complex64>,
282) -> QuantRS2Result<BatchStateVector> {
283    let batch_size = batch.batch_size();
284    let (num_matrices, rows, cols) = matrices.dim();
285
286    if num_matrices != batch_size {
287        return Err(QuantRS2Error::InvalidInput(format!(
288            "Number of matrices {} doesn't match batch size {}",
289            num_matrices, batch_size
290        )));
291    }
292
293    if cols != batch.states.ncols() {
294        return Err(QuantRS2Error::InvalidInput(format!(
295            "Matrix columns {} don't match state size {}",
296            cols,
297            batch.states.ncols()
298        )));
299    }
300
301    // Perform batch matrix multiplication manually
302    let mut result_states = Array2::zeros((batch_size, rows));
303
304    // Use parallel processing for large batches
305    if batch_size > 16 {
306        use rayon::prelude::*;
307
308        let results: Vec<_> = (0..batch_size)
309            .into_par_iter()
310            .map(|i| {
311                let matrix = matrices.slice(s![i, .., ..]);
312                let state = batch.states.row(i);
313                matrix.dot(&state)
314            })
315            .collect();
316
317        for (i, result) in results.into_iter().enumerate() {
318            result_states.row_mut(i).assign(&result);
319        }
320    } else {
321        // Sequential for small batches
322        for i in 0..batch_size {
323            let matrix = matrices.slice(s![i, .., ..]);
324            let state = batch.states.row(i);
325            let result = matrix.dot(&state);
326            result_states.row_mut(i).assign(&result);
327        }
328    }
329
330    BatchStateVector::from_states(result_states, batch.config.clone())
331}
332
333/// Parallel expectation value computation
334pub fn compute_expectation_values_batch(
335    batch: &BatchStateVector,
336    observable_matrix: &Array2<Complex64>,
337) -> QuantRS2Result<Vec<f64>> {
338    let batch_size = batch.batch_size();
339
340    // Use parallel computation for large batches
341    if batch_size > 16 {
342        let expectations: Vec<f64> = (0..batch_size)
343            .into_par_iter()
344            .map(|i| {
345                let state = batch.states.row(i);
346                compute_expectation_value(&state.to_owned(), observable_matrix)
347            })
348            .collect();
349
350        Ok(expectations)
351    } else {
352        // Sequential for small batches
353        let mut expectations = Vec::with_capacity(batch_size);
354        for i in 0..batch_size {
355            let state = batch.states.row(i);
356            expectations.push(compute_expectation_value(
357                &state.to_owned(),
358                observable_matrix,
359            ));
360        }
361        Ok(expectations)
362    }
363}
364
365/// Compute expectation value for a single state
366fn compute_expectation_value(state: &Array1<Complex64>, observable: &Array2<Complex64>) -> f64 {
367    // <ψ|O|ψ>
368    let temp = observable.dot(state);
369    let expectation = state
370        .iter()
371        .zip(temp.iter())
372        .map(|(a, b)| a.conj() * b)
373        .sum::<Complex64>();
374
375    expectation.re
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance for floating-point amplitude comparisons.
    const TOLERANCE: f64 = 1e-10;

    #[test]
    fn test_batch_hadamard() {
        let mut batch = BatchStateVector::new(3, 1, Default::default()).unwrap();
        let hadamard = Hadamard { target: QubitId(0) };

        hadamard.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        // H|0> = (|0> + |1>) / √2 for every member of the batch.
        let expected = 1.0 / std::f64::consts::SQRT_2;
        for idx in 0..3 {
            let state = batch.get_state(idx).unwrap();
            assert!((state[0].re - expected).abs() < TOLERANCE);
            assert!((state[1].re - expected).abs() < TOLERANCE);
        }
    }

    #[test]
    fn test_batch_pauli_x() {
        let mut batch = BatchStateVector::new(2, 1, Default::default()).unwrap();
        let pauli_x = PauliX { target: QubitId(0) };

        pauli_x.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        // X|0> = |1> for every member of the batch.
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        for idx in 0..2 {
            let state = batch.get_state(idx).unwrap();
            assert_eq!(state[0], zero);
            assert_eq!(state[1], one);
        }
    }

    #[test]
    fn test_expectation_values_batch() {
        let batch = BatchStateVector::new(5, 1, Default::default()).unwrap();

        // Pauli-Z observable: diag(1, -1).
        let z_observable = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
        ];

        let expectations = compute_expectation_values_batch(&batch, &z_observable).unwrap();

        // Every state starts in |0>, whose Z expectation is +1.
        for (idx, value) in expectations.iter().enumerate() {
            assert!(
                (value - 1.0).abs() < TOLERANCE,
                "state {} has <Z> = {}",
                idx,
                value
            );
        }
    }
}