1use super::{BatchConfig, BatchExecutionResult, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 gate::GateOp,
7 gpu::GpuBackendFactory,
8 qubit::QubitId,
9};
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::Complex64;
12use crate::parallel_ops_stubs::*;
14use std::sync::Arc;
15use std::time::Instant;
16
17pub struct BatchCircuit {
19 pub n_qubits: usize,
21 pub gates: Vec<Box<dyn GateOp>>,
23}
24
25impl BatchCircuit {
26 pub fn new(n_qubits: usize) -> Self {
28 Self {
29 n_qubits,
30 gates: Vec::new(),
31 }
32 }
33
34 pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> QuantRS2Result<()> {
36 for qubit in gate.qubits() {
38 if qubit.0 as usize >= self.n_qubits {
39 return Err(QuantRS2Error::InvalidQubitId(qubit.0));
40 }
41 }
42 self.gates.push(gate);
43 Ok(())
44 }
45
46 pub fn gate_sequence(&self) -> impl Iterator<Item = &Box<dyn GateOp>> {
48 self.gates.iter()
49 }
50
51 pub fn num_gates(&self) -> usize {
53 self.gates.len()
54 }
55}
56
57pub struct BatchCircuitExecutor {
59 pub config: BatchConfig,
61 pub gpu_backend: Option<Arc<dyn crate::gpu::GpuBackend>>,
63 }
67
68impl BatchCircuitExecutor {
69 pub fn new(config: BatchConfig) -> QuantRS2Result<Self> {
71 let gpu_backend = if config.use_gpu {
73 GpuBackendFactory::create_best_available().ok()
74 } else {
75 None
76 };
77
78 Ok(Self {
96 config,
97 gpu_backend,
98 })
100 }
101
102 pub fn execute_batch(
104 &self,
105 circuit: &BatchCircuit,
106 batch: &mut BatchStateVector,
107 ) -> QuantRS2Result<BatchExecutionResult> {
108 if batch.n_qubits != circuit.n_qubits {
109 return Err(QuantRS2Error::InvalidInput(format!(
110 "Circuit has {} qubits but batch has {}",
111 circuit.n_qubits, batch.n_qubits
112 )));
113 }
114
115 let start_time = Instant::now();
116 let gates_applied = circuit.num_gates();
117
118 let used_gpu = if self.gpu_backend.is_some() && batch.batch_size() >= 64 {
120 self.execute_with_gpu(circuit, batch)?;
121 true
122 } else if batch.batch_size() > self.config.max_batch_size {
123 self.execute_chunked(circuit, batch)?;
124 false
125 } else {
126 self.execute_parallel(circuit, batch)?;
127 false
128 };
129
130 let execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
131
132 Ok(BatchExecutionResult {
133 final_states: batch.states.clone(),
134 execution_time_ms,
135 gates_applied,
136 used_gpu,
137 })
138 }
139
140 fn execute_with_gpu(
142 &self,
143 circuit: &BatchCircuit,
144 batch: &mut BatchStateVector,
145 ) -> QuantRS2Result<()> {
146 if let Some(gpu_backend) = &self.gpu_backend {
147 let mut gpu_states = Vec::new();
149
150 for i in 0..batch.batch_size() {
151 let state_data = batch.get_state(i)?;
152 let mut gpu_buffer = gpu_backend.allocate_state_vector(batch.n_qubits)?;
154 gpu_buffer.upload(state_data.as_slice().unwrap())?;
155
156 gpu_states.push(gpu_buffer);
157 }
158
159 for gate in circuit.gate_sequence() {
161 let gate_qubits = gate.qubits();
162
163 for gpu_state in &mut gpu_states {
165 gpu_backend.apply_gate(
166 gpu_state.as_mut(),
167 gate.as_ref(),
168 &gate_qubits,
169 batch.n_qubits,
170 )?;
171 }
172 }
173
174 for (i, gpu_state) in gpu_states.iter().enumerate() {
176 let state_size = 1 << batch.n_qubits;
177 let mut result_data = vec![Complex64::new(0.0, 0.0); state_size];
178 gpu_state.download(&mut result_data)?;
179
180 let result_array = Array1::from_vec(result_data);
181 batch.set_state(i, &result_array)?;
182 }
183
184 Ok(())
185 } else {
186 self.execute_parallel(circuit, batch)
188 }
189 }
190
191 fn execute_chunked(
193 &self,
194 circuit: &BatchCircuit,
195 batch: &mut BatchStateVector,
196 ) -> QuantRS2Result<()> {
197 let chunk_size = self.config.max_batch_size;
198 let chunks = super::split_batch(batch, chunk_size);
199
200 let processed_chunks: Vec<_> = chunks
202 .into_par_iter()
203 .map(|mut chunk| {
204 self.execute_parallel(circuit, &mut chunk)?;
205 Ok(chunk)
206 })
207 .collect::<QuantRS2Result<Vec<_>>>()?;
208
209 let merged = super::merge_batches(processed_chunks, batch.config.clone())?;
211 batch.states = merged.states;
212
213 Ok(())
214 }
215
216 fn execute_parallel(
218 &self,
219 circuit: &BatchCircuit,
220 batch: &mut BatchStateVector,
221 ) -> QuantRS2Result<()> {
222 let _batch_size = batch.batch_size();
223 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
224 let gate_refs: Vec<&dyn GateOp> = gate_sequence.iter().map(|g| g.as_ref()).collect();
225
226 self.execute_with_thread_pool(batch, &gate_refs)?;
229
230 Ok(())
231 }
232
233 fn execute_with_thread_pool(
235 &self,
236 batch: &mut BatchStateVector,
237 gates: &[&dyn GateOp],
238 ) -> QuantRS2Result<()> {
239 let batch_size = batch.batch_size();
241 let n_qubits = batch.n_qubits;
242
243 let results: Vec<Array1<Complex64>> = (0..batch_size)
245 .into_par_iter()
246 .map(|i| {
247 let mut state = batch.states.row(i).to_owned();
248 apply_gates_to_state(&mut state, gates, n_qubits).map(|_| state)
249 })
250 .collect::<QuantRS2Result<Vec<_>>>()?;
251
252 for (i, state) in results.into_iter().enumerate() {
254 batch.states.row_mut(i).assign(&state);
255 }
256
257 Ok(())
258 }
259
260 pub fn execute_multiple_circuits(
262 &self,
263 circuits: &[BatchCircuit],
264 initial_batch: &BatchStateVector,
265 ) -> QuantRS2Result<Vec<BatchExecutionResult>> {
266 if circuits.is_empty() {
267 return Ok(Vec::new());
268 }
269
270 let results: Vec<_> = circuits
272 .par_iter()
273 .map(|circuit| {
274 let mut batch_copy = BatchStateVector::from_states(
275 initial_batch.states.clone(),
276 initial_batch.config.clone(),
277 )?;
278
279 self.execute_batch(circuit, &mut batch_copy)
280 })
281 .collect::<QuantRS2Result<Vec<_>>>()?;
282
283 Ok(results)
284 }
285
286 pub fn execute_parameterized_batch(
288 &self,
289 circuit_fn: impl Fn(&[f64]) -> QuantRS2Result<BatchCircuit> + Sync,
290 parameter_sets: &[Vec<f64>],
291 initial_state: &Array1<Complex64>,
292 ) -> QuantRS2Result<Vec<Array1<Complex64>>> {
293 let batch_size = parameter_sets.len();
295 let mut states = Array2::zeros((batch_size, initial_state.len()));
296 for i in 0..batch_size {
297 states.row_mut(i).assign(initial_state);
298 }
299
300 let batch = BatchStateVector::from_states(states, self.config.clone())?;
301
302 let results: Vec<_> = parameter_sets
304 .par_iter()
305 .enumerate()
306 .map(|(i, params)| {
307 let circuit = circuit_fn(params)?;
308 let mut state = batch.get_state(i)?;
309 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
310 let gate_refs: Vec<&dyn GateOp> =
311 gate_sequence.iter().map(|g| g.as_ref()).collect();
312 apply_gates_to_state(&mut state, &gate_refs, circuit.n_qubits)?;
313 Ok(state)
314 })
315 .collect::<QuantRS2Result<Vec<_>>>()?;
316
317 Ok(results)
318 }
319}
320
321fn apply_gates_to_state(
323 state: &mut Array1<Complex64>,
324 gates: &[&dyn GateOp],
325 n_qubits: usize,
326) -> QuantRS2Result<()> {
327 for gate in gates {
328 let qubits = gate.qubits();
329 let matrix = gate.matrix()?;
330
331 match qubits.len() {
332 1 => {
333 apply_single_qubit_gate(state, &matrix, qubits[0], n_qubits)?;
334 }
335 2 => {
336 apply_two_qubit_gate(state, &matrix, qubits[0], qubits[1], n_qubits)?;
337 }
338 _ => {
339 return Err(QuantRS2Error::InvalidInput(format!(
340 "Gates with {} qubits not yet supported",
341 qubits.len()
342 )));
343 }
344 }
345 }
346
347 Ok(())
348}
349
350fn apply_single_qubit_gate(
352 state: &mut Array1<Complex64>,
353 matrix: &[Complex64],
354 target: QubitId,
355 n_qubits: usize,
356) -> QuantRS2Result<()> {
357 let target_idx = target.0 as usize;
358 let state_size = 1 << n_qubits;
359 let target_mask = 1 << target_idx;
360
361 for i in 0..state_size {
362 if i & target_mask == 0 {
363 let j = i | target_mask;
364
365 let a = state[i];
366 let b = state[j];
367
368 state[i] = matrix[0] * a + matrix[1] * b;
369 state[j] = matrix[2] * a + matrix[3] * b;
370 }
371 }
372
373 Ok(())
374}
375
376fn apply_two_qubit_gate(
378 state: &mut Array1<Complex64>,
379 matrix: &[Complex64],
380 control: QubitId,
381 target: QubitId,
382 n_qubits: usize,
383) -> QuantRS2Result<()> {
384 let control_idx = control.0 as usize;
385 let target_idx = target.0 as usize;
386 let state_size = 1 << n_qubits;
387 let control_mask = 1 << control_idx;
388 let target_mask = 1 << target_idx;
389
390 for i in 0..state_size {
391 if (i & control_mask == 0) && (i & target_mask == 0) {
392 let i00 = i;
393 let i01 = i | target_mask;
394 let i10 = i | control_mask;
395 let i11 = i | control_mask | target_mask;
396
397 let a00 = state[i00];
398 let a01 = state[i01];
399 let a10 = state[i10];
400 let a11 = state[i11];
401
402 state[i00] = matrix[0] * a00 + matrix[1] * a01 + matrix[2] * a10 + matrix[3] * a11;
403 state[i01] = matrix[4] * a00 + matrix[5] * a01 + matrix[6] * a10 + matrix[7] * a11;
404 state[i10] = matrix[8] * a00 + matrix[9] * a01 + matrix[10] * a10 + matrix[11] * a11;
405 state[i11] = matrix[12] * a00 + matrix[13] * a01 + matrix[14] * a10 + matrix[15] * a11;
406 }
407 }
408
409 Ok(())
410}
411
412pub fn create_optimized_executor() -> QuantRS2Result<BatchCircuitExecutor> {
414 let config = BatchConfig {
415 num_workers: Some(8), max_batch_size: 1024,
417 use_gpu: true,
418 memory_limit: Some(8 * 1024 * 1024 * 1024), enable_cache: true,
420 };
421
422 BatchCircuitExecutor::new(config)
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gate::single::Hadamard;
    use std::f64::consts::FRAC_1_SQRT_2;

    const TOLERANCE: f64 = 1e-10;

    /// True if `amp` equals the real number `expected` within `TOLERANCE`.
    fn approx_real(amp: Complex64, expected: f64) -> bool {
        (amp.re - expected).abs() < TOLERANCE && amp.im.abs() < TOLERANCE
    }

    #[test]
    fn test_batch_circuit_execution() {
        let config = BatchConfig {
            use_gpu: false,
            ..Default::default()
        };
        let executor = BatchCircuitExecutor::new(config).unwrap();

        let mut circuit = BatchCircuit::new(2);
        circuit
            .add_gate(Box::new(Hadamard { target: QubitId(0) }))
            .unwrap();
        circuit
            .add_gate(Box::new(Hadamard { target: QubitId(1) }))
            .unwrap();

        let mut batch = BatchStateVector::new(5, 2, Default::default()).unwrap();
        let result = executor.execute_batch(&circuit, &mut batch).unwrap();

        assert_eq!(result.gates_applied, 2);
        assert!(!result.used_gpu);

        // Assumes `BatchStateVector::new` starts every state in |00⟩. H⊗H then gives the
        // uniform superposition, so every amplitude (not just the first) equals 1/2.
        for i in 0..5 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state.len(), 4);
            for (k, amp) in state.iter().enumerate() {
                assert!(
                    approx_real(*amp, 0.5),
                    "state {}, amplitude {}: {:?}",
                    i,
                    k,
                    amp
                );
            }
        }
    }

    #[test]
    fn test_parallel_circuit_execution() {
        let config = BatchConfig {
            num_workers: Some(2),
            use_gpu: false,
            ..Default::default()
        };
        let executor = BatchCircuitExecutor::new(config).unwrap();

        let mut circuits = Vec::new();
        for _ in 0..3 {
            let mut circuit = BatchCircuit::new(1);
            circuit
                .add_gate(Box::new(Hadamard { target: QubitId(0) }))
                .unwrap();
            circuits.push(circuit);
        }

        let batch = BatchStateVector::new(10, 1, Default::default()).unwrap();
        let results = executor
            .execute_multiple_circuits(&circuits, &batch)
            .unwrap();

        assert_eq!(results.len(), 3);
        for result in &results {
            assert_eq!(result.gates_applied, 1);
            assert_eq!(result.final_states.nrows(), 10);
            // H|0⟩ = (|0⟩ + |1⟩)/√2 for every state of every copy.
            assert!(result
                .final_states
                .iter()
                .all(|amp| approx_real(*amp, FRAC_1_SQRT_2)));
        }
    }

    #[test]
    fn test_add_gate_rejects_out_of_range_qubit() {
        let mut circuit = BatchCircuit::new(1);
        assert!(circuit
            .add_gate(Box::new(Hadamard { target: QubitId(1) }))
            .is_err());
        assert_eq!(circuit.num_gates(), 0);
    }

    #[test]
    fn test_execute_batch_rejects_qubit_mismatch() {
        let config = BatchConfig {
            use_gpu: false,
            ..Default::default()
        };
        let executor = BatchCircuitExecutor::new(config).unwrap();

        let circuit = BatchCircuit::new(2);
        let mut batch = BatchStateVector::new(3, 1, Default::default()).unwrap();
        assert!(executor.execute_batch(&circuit, &mut batch).is_err());
    }

    #[test]
    fn test_parameterized_batch_execution() {
        let config = BatchConfig {
            use_gpu: false,
            ..Default::default()
        };
        let executor = BatchCircuitExecutor::new(config).unwrap();

        let initial_state =
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]);
        let parameter_sets = vec![vec![0.0], vec![1.0], vec![2.0]];

        let results = executor
            .execute_parameterized_batch(
                |_params: &[f64]| {
                    let mut circuit = BatchCircuit::new(1);
                    circuit.add_gate(Box::new(Hadamard { target: QubitId(0) }))?;
                    Ok(circuit)
                },
                &parameter_sets,
                &initial_state,
            )
            .unwrap();

        assert_eq!(results.len(), parameter_sets.len());
        for state in &results {
            assert!(state.iter().all(|amp| approx_real(*amp, FRAC_1_SQRT_2)));
        }
    }
}