1use super::{BatchGateOp, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 gate::{single::*, GateOp},
7 qubit::QubitId,
8};
9use ndarray::{s, Array1, Array2, Array3, ArrayView2, Axis};
10use num_complex::Complex64;
11use rayon::prelude::*;
12use std::sync::Arc;
13
14pub fn apply_single_qubit_gate_batch(
21 batch: &mut BatchStateVector,
22 gate_matrix: &[Complex64; 4],
23 target: QubitId,
24) -> QuantRS2Result<()> {
25 let n_qubits = batch.n_qubits;
26 let target_idx = target.0 as usize;
27
28 if target_idx >= n_qubits {
29 return Err(QuantRS2Error::InvalidQubitId(target.0));
30 }
31
32 let batch_size = batch.batch_size();
33 let state_size = 1 << n_qubits;
34
35 if batch_size > 32 {
37 batch
38 .states
39 .axis_iter_mut(Axis(0))
40 .into_par_iter()
41 .try_for_each(|mut state_row| -> QuantRS2Result<()> {
42 apply_single_qubit_to_state(
43 &mut state_row.to_owned(),
44 gate_matrix,
45 target_idx,
46 n_qubits,
47 )?;
48 Ok(())
49 })?;
50 } else {
51 for i in 0..batch_size {
53 let mut state = batch.states.row(i).to_owned();
54 apply_single_qubit_to_state(&mut state, gate_matrix, target_idx, n_qubits)?;
55 batch.states.row_mut(i).assign(&state);
56 }
57 }
58
59 Ok(())
60}
61
62pub fn apply_two_qubit_gate_batch(
64 batch: &mut BatchStateVector,
65 gate_matrix: &[Complex64; 16],
66 control: QubitId,
67 target: QubitId,
68) -> QuantRS2Result<()> {
69 let n_qubits = batch.n_qubits;
70 let control_idx = control.0 as usize;
71 let target_idx = target.0 as usize;
72
73 if control_idx >= n_qubits || target_idx >= n_qubits {
74 return Err(QuantRS2Error::InvalidQubitId(if control_idx >= n_qubits {
75 control.0
76 } else {
77 target.0
78 }));
79 }
80
81 if control_idx == target_idx {
82 return Err(QuantRS2Error::InvalidInput(
83 "Control and target qubits must be different".to_string(),
84 ));
85 }
86
87 let batch_size = batch.batch_size();
88
89 if batch_size > 16 {
91 batch
92 .states
93 .axis_iter_mut(Axis(0))
94 .into_par_iter()
95 .try_for_each(|mut state_row| -> QuantRS2Result<()> {
96 apply_two_qubit_to_state(
97 &mut state_row.to_owned(),
98 gate_matrix,
99 control_idx,
100 target_idx,
101 n_qubits,
102 )?;
103 Ok(())
104 })?;
105 } else {
106 for i in 0..batch_size {
108 let mut state = batch.states.row(i).to_owned();
109 apply_two_qubit_to_state(&mut state, gate_matrix, control_idx, target_idx, n_qubits)?;
110 batch.states.row_mut(i).assign(&state);
111 }
112 }
113
114 Ok(())
115}
116
117fn apply_single_qubit_to_state(
119 state: &mut Array1<Complex64>,
120 gate_matrix: &[Complex64; 4],
121 target_idx: usize,
122 n_qubits: usize,
123) -> QuantRS2Result<()> {
124 let state_size = 1 << n_qubits;
125 let target_mask = 1 << target_idx;
126
127 for i in 0..state_size {
128 if i & target_mask == 0 {
129 let j = i | target_mask;
130
131 let a = state[i];
132 let b = state[j];
133
134 state[i] = gate_matrix[0] * a + gate_matrix[1] * b;
135 state[j] = gate_matrix[2] * a + gate_matrix[3] * b;
136 }
137 }
138
139 Ok(())
140}
141
142fn apply_two_qubit_to_state(
144 state: &mut Array1<Complex64>,
145 gate_matrix: &[Complex64; 16],
146 control_idx: usize,
147 target_idx: usize,
148 n_qubits: usize,
149) -> QuantRS2Result<()> {
150 let state_size = 1 << n_qubits;
151 let control_mask = 1 << control_idx;
152 let target_mask = 1 << target_idx;
153
154 for i in 0..state_size {
155 if (i & control_mask == 0) && (i & target_mask == 0) {
156 let i00 = i;
157 let i01 = i | target_mask;
158 let i10 = i | control_mask;
159 let i11 = i | control_mask | target_mask;
160
161 let a00 = state[i00];
162 let a01 = state[i01];
163 let a10 = state[i10];
164 let a11 = state[i11];
165
166 state[i00] = gate_matrix[0] * a00
167 + gate_matrix[1] * a01
168 + gate_matrix[2] * a10
169 + gate_matrix[3] * a11;
170 state[i01] = gate_matrix[4] * a00
171 + gate_matrix[5] * a01
172 + gate_matrix[6] * a10
173 + gate_matrix[7] * a11;
174 state[i10] = gate_matrix[8] * a00
175 + gate_matrix[9] * a01
176 + gate_matrix[10] * a10
177 + gate_matrix[11] * a11;
178 state[i11] = gate_matrix[12] * a00
179 + gate_matrix[13] * a01
180 + gate_matrix[14] * a10
181 + gate_matrix[15] * a11;
182 }
183 }
184
185 Ok(())
186}
187
/// Unit marker type for batched Hadamard application.
///
/// NOTE(review): no `BatchGateOp` impl for this type is visible in this file.
/// The batched Hadamard kernel is implemented on `Hadamard` itself.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BatchHadamard;
190
191impl BatchGateOp for Hadamard {
192 fn apply_batch(
193 &self,
194 batch: &mut BatchStateVector,
195 target_qubits: &[QubitId],
196 ) -> QuantRS2Result<()> {
197 if target_qubits.len() != 1 {
198 return Err(QuantRS2Error::InvalidInput(
199 "Hadamard gate requires exactly one target qubit".to_string(),
200 ));
201 }
202
203 let gate_matrix = [
204 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
205 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
206 Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0),
207 Complex64::new(-1.0 / std::f64::consts::SQRT_2, 0.0),
208 ];
209
210 apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
211 }
212}
213
214impl BatchGateOp for PauliX {
216 fn apply_batch(
217 &self,
218 batch: &mut BatchStateVector,
219 target_qubits: &[QubitId],
220 ) -> QuantRS2Result<()> {
221 if target_qubits.len() != 1 {
222 return Err(QuantRS2Error::InvalidInput(
223 "Pauli-X gate requires exactly one target qubit".to_string(),
224 ));
225 }
226
227 let gate_matrix = [
228 Complex64::new(0.0, 0.0),
229 Complex64::new(1.0, 0.0),
230 Complex64::new(1.0, 0.0),
231 Complex64::new(0.0, 0.0),
232 ];
233
234 apply_single_qubit_gate_batch(batch, &gate_matrix, target_qubits[0])
235 }
236}
237
238pub fn apply_gate_sequence_batch(
240 batch: &mut BatchStateVector,
241 gates: &[(Box<dyn GateOp>, Vec<QubitId>)],
242) -> QuantRS2Result<()> {
243 for (gate, qubits) in gates {
247 {
250 let matrix = gate.matrix()?;
252
253 match qubits.len() {
254 1 => {
255 let mut gate_array = [Complex64::new(0.0, 0.0); 4];
256 gate_array.copy_from_slice(&matrix[..4]);
257 apply_single_qubit_gate_batch(batch, &gate_array, qubits[0])?;
258 }
259 2 => {
260 let mut gate_array = [Complex64::new(0.0, 0.0); 16];
261 gate_array.copy_from_slice(&matrix[..16]);
262 apply_two_qubit_gate_batch(batch, &gate_array, qubits[0], qubits[1])?;
263 }
264 _ => {
265 return Err(QuantRS2Error::InvalidInput(
266 "Batch operations for gates with more than 2 qubits not yet supported"
267 .to_string(),
268 ));
269 }
270 }
271 }
272 }
273
274 Ok(())
275}
276
277pub fn batch_state_matrix_multiply(
280 batch: &BatchStateVector,
281 matrices: &Array3<Complex64>,
282) -> QuantRS2Result<BatchStateVector> {
283 let batch_size = batch.batch_size();
284 let (num_matrices, rows, cols) = matrices.dim();
285
286 if num_matrices != batch_size {
287 return Err(QuantRS2Error::InvalidInput(format!(
288 "Number of matrices {} doesn't match batch size {}",
289 num_matrices, batch_size
290 )));
291 }
292
293 if cols != batch.states.ncols() {
294 return Err(QuantRS2Error::InvalidInput(format!(
295 "Matrix columns {} don't match state size {}",
296 cols,
297 batch.states.ncols()
298 )));
299 }
300
301 let mut result_states = Array2::zeros((batch_size, rows));
303
304 if batch_size > 16 {
306 use rayon::prelude::*;
307
308 let results: Vec<_> = (0..batch_size)
309 .into_par_iter()
310 .map(|i| {
311 let matrix = matrices.slice(s![i, .., ..]);
312 let state = batch.states.row(i);
313 matrix.dot(&state)
314 })
315 .collect();
316
317 for (i, result) in results.into_iter().enumerate() {
318 result_states.row_mut(i).assign(&result);
319 }
320 } else {
321 for i in 0..batch_size {
323 let matrix = matrices.slice(s![i, .., ..]);
324 let state = batch.states.row(i);
325 let result = matrix.dot(&state);
326 result_states.row_mut(i).assign(&result);
327 }
328 }
329
330 BatchStateVector::from_states(result_states, batch.config.clone())
331}
332
333pub fn compute_expectation_values_batch(
335 batch: &BatchStateVector,
336 observable_matrix: &Array2<Complex64>,
337) -> QuantRS2Result<Vec<f64>> {
338 let batch_size = batch.batch_size();
339
340 if batch_size > 16 {
342 let expectations: Vec<f64> = (0..batch_size)
343 .into_par_iter()
344 .map(|i| {
345 let state = batch.states.row(i);
346 compute_expectation_value(&state.to_owned(), observable_matrix)
347 })
348 .collect();
349
350 Ok(expectations)
351 } else {
352 let mut expectations = Vec::with_capacity(batch_size);
354 for i in 0..batch_size {
355 let state = batch.states.row(i);
356 expectations.push(compute_expectation_value(
357 &state.to_owned(),
358 observable_matrix,
359 ));
360 }
361 Ok(expectations)
362 }
363}
364
365fn compute_expectation_value(state: &Array1<Complex64>, observable: &Array2<Complex64>) -> f64 {
367 let temp = observable.dot(state);
369 let expectation = state
370 .iter()
371 .zip(temp.iter())
372 .map(|(a, b)| a.conj() * b)
373 .sum::<Complex64>();
374
375 expectation.re
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Absolute tolerance for floating-point amplitude comparisons.
    const TOL: f64 = 1e-10;

    /// H on |0> gives equal real amplitudes 1/sqrt(2) on |0> and |1> for every batch entry.
    #[test]
    fn test_batch_hadamard() {
        let mut batch = BatchStateVector::new(3, 1, Default::default()).unwrap();
        let gate = Hadamard { target: QubitId(0) };

        gate.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        let amplitude = 1.0 / std::f64::consts::SQRT_2;
        for idx in 0..3 {
            let psi = batch.get_state(idx).unwrap();
            assert!((psi[0].re - amplitude).abs() < TOL);
            assert!((psi[1].re - amplitude).abs() < TOL);
        }
    }

    /// X on |0> yields exactly |1> for every batch entry.
    #[test]
    fn test_batch_pauli_x() {
        let mut batch = BatchStateVector::new(2, 1, Default::default()).unwrap();
        let gate = PauliX { target: QubitId(0) };

        gate.apply_batch(&mut batch, &[QubitId(0)]).unwrap();

        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        for idx in 0..2 {
            let psi = batch.get_state(idx).unwrap();
            assert_eq!(psi[0], zero);
            assert_eq!(psi[1], one);
        }
    }

    /// The Pauli-Z expectation of the freshly created states is +1.
    #[test]
    fn test_expectation_values_batch() {
        let batch = BatchStateVector::new(5, 1, Default::default()).unwrap();

        let z_observable = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
        ];

        let values = compute_expectation_values_batch(&batch, &z_observable).unwrap();

        assert!(values.iter().all(|value| (value - 1.0).abs() < TOL));
    }
}