1use super::{BatchConfig, BatchExecutionResult, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 gate::GateOp,
7 gpu::GpuBackendFactory,
8 qubit::QubitId,
9};
10use ndarray::{Array1, Array2};
11use num_complex::Complex64;
12use rayon::prelude::*;
13use std::sync::Arc;
14use std::time::Instant;
15
16use rayon::ThreadPool;
18
19pub struct BatchCircuit {
21 pub n_qubits: usize,
23 pub gates: Vec<Box<dyn GateOp>>,
25}
26
27impl BatchCircuit {
28 pub fn new(n_qubits: usize) -> Self {
30 Self {
31 n_qubits,
32 gates: Vec::new(),
33 }
34 }
35
36 pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> QuantRS2Result<()> {
38 for qubit in gate.qubits() {
40 if qubit.0 as usize >= self.n_qubits {
41 return Err(QuantRS2Error::InvalidQubitId(qubit.0));
42 }
43 }
44 self.gates.push(gate);
45 Ok(())
46 }
47
48 pub fn gate_sequence(&self) -> impl Iterator<Item = &Box<dyn GateOp>> {
50 self.gates.iter()
51 }
52
53 pub fn num_gates(&self) -> usize {
55 self.gates.len()
56 }
57}
58
59pub struct BatchCircuitExecutor {
61 pub config: BatchConfig,
63 pub gpu_backend: Option<Arc<dyn crate::gpu::GpuBackend>>,
65 pub thread_pool: Option<ThreadPool>,
67}
68
69impl BatchCircuitExecutor {
70 pub fn new(config: BatchConfig) -> QuantRS2Result<Self> {
72 let gpu_backend = if config.use_gpu {
74 GpuBackendFactory::create_best_available().ok()
75 } else {
76 None
77 };
78
79 let thread_pool = if let Some(num_workers) = config.num_workers {
81 Some(
82 rayon::ThreadPoolBuilder::new()
83 .num_threads(num_workers)
84 .build()
85 .map_err(|e| {
86 QuantRS2Error::ExecutionError(format!(
87 "Failed to create thread pool: {}",
88 e
89 ))
90 })?,
91 )
92 } else {
93 None
94 };
95
96 Ok(Self {
97 config,
98 gpu_backend,
99 thread_pool,
100 })
101 }
102
103 pub fn execute_batch(
105 &self,
106 circuit: &BatchCircuit,
107 batch: &mut BatchStateVector,
108 ) -> QuantRS2Result<BatchExecutionResult> {
109 if batch.n_qubits != circuit.n_qubits {
110 return Err(QuantRS2Error::InvalidInput(format!(
111 "Circuit has {} qubits but batch has {}",
112 circuit.n_qubits, batch.n_qubits
113 )));
114 }
115
116 let start_time = Instant::now();
117 let gates_applied = circuit.num_gates();
118
119 let used_gpu = if self.gpu_backend.is_some() && batch.batch_size() >= 64 {
121 self.execute_with_gpu(circuit, batch)?;
122 true
123 } else if batch.batch_size() > self.config.max_batch_size {
124 self.execute_chunked(circuit, batch)?;
125 false
126 } else {
127 self.execute_parallel(circuit, batch)?;
128 false
129 };
130
131 let execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
132
133 Ok(BatchExecutionResult {
134 final_states: batch.states.clone(),
135 execution_time_ms,
136 gates_applied,
137 used_gpu,
138 })
139 }
140
141 fn execute_with_gpu(
143 &self,
144 circuit: &BatchCircuit,
145 batch: &mut BatchStateVector,
146 ) -> QuantRS2Result<()> {
147 if let Some(gpu_backend) = &self.gpu_backend {
148 let mut gpu_states = Vec::new();
150
151 for i in 0..batch.batch_size() {
152 let state_data = batch.get_state(i)?;
153 let mut gpu_buffer = gpu_backend.allocate_state_vector(batch.n_qubits)?;
155 gpu_buffer.upload(state_data.as_slice().unwrap())?;
156
157 gpu_states.push(gpu_buffer);
158 }
159
160 for gate in circuit.gate_sequence() {
162 let gate_qubits = gate.qubits();
163
164 for gpu_state in &mut gpu_states {
166 gpu_backend.apply_gate(
167 gpu_state.as_mut(),
168 gate.as_ref(),
169 &gate_qubits,
170 batch.n_qubits,
171 )?;
172 }
173 }
174
175 for (i, gpu_state) in gpu_states.iter().enumerate() {
177 let state_size = 1 << batch.n_qubits;
178 let mut result_data = vec![Complex64::new(0.0, 0.0); state_size];
179 gpu_state.download(&mut result_data)?;
180
181 let result_array = Array1::from_vec(result_data);
182 batch.set_state(i, &result_array)?;
183 }
184
185 Ok(())
186 } else {
187 self.execute_parallel(circuit, batch)
189 }
190 }
191
192 fn execute_chunked(
194 &self,
195 circuit: &BatchCircuit,
196 batch: &mut BatchStateVector,
197 ) -> QuantRS2Result<()> {
198 let chunk_size = self.config.max_batch_size;
199 let chunks = super::split_batch(batch, chunk_size);
200
201 let processed_chunks: Vec<_> = chunks
203 .into_par_iter()
204 .map(|mut chunk| {
205 self.execute_parallel(circuit, &mut chunk)?;
206 Ok(chunk)
207 })
208 .collect::<QuantRS2Result<Vec<_>>>()?;
209
210 let merged = super::merge_batches(processed_chunks, batch.config.clone())?;
212 batch.states = merged.states;
213
214 Ok(())
215 }
216
217 fn execute_parallel(
219 &self,
220 circuit: &BatchCircuit,
221 batch: &mut BatchStateVector,
222 ) -> QuantRS2Result<()> {
223 let _batch_size = batch.batch_size();
224 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
225 let gate_refs: Vec<&dyn GateOp> = gate_sequence.iter().map(|g| g.as_ref()).collect();
226
227 if let Some(thread_pool) = &self.thread_pool {
228 self.execute_with_thread_pool(batch, &gate_refs, thread_pool)?;
230 } else {
231 batch
233 .states
234 .axis_iter_mut(ndarray::Axis(0))
235 .into_par_iter()
236 .try_for_each(|mut state_row| -> QuantRS2Result<()> {
237 let mut state = state_row.to_owned();
238 apply_gates_to_state(&mut state, &gate_refs, batch.n_qubits)?;
239 state_row.assign(&state);
240 Ok(())
241 })?;
242 }
243
244 Ok(())
245 }
246
247 fn execute_with_thread_pool(
249 &self,
250 batch: &mut BatchStateVector,
251 gates: &[&dyn GateOp],
252 _thread_pool: &ThreadPool,
253 ) -> QuantRS2Result<()> {
254 let batch_size = batch.batch_size();
256 let n_qubits = batch.n_qubits;
257
258 let results: Vec<Array1<Complex64>> = (0..batch_size)
260 .into_par_iter()
261 .map(|i| {
262 let mut state = batch.states.row(i).to_owned();
263 apply_gates_to_state(&mut state, gates, n_qubits).map(|_| state)
264 })
265 .collect::<QuantRS2Result<Vec<_>>>()?;
266
267 for (i, state) in results.into_iter().enumerate() {
269 batch.states.row_mut(i).assign(&state);
270 }
271
272 Ok(())
273 }
274
275 pub fn execute_multiple_circuits(
277 &self,
278 circuits: &[BatchCircuit],
279 initial_batch: &BatchStateVector,
280 ) -> QuantRS2Result<Vec<BatchExecutionResult>> {
281 if circuits.is_empty() {
282 return Ok(Vec::new());
283 }
284
285 let results: Vec<_> = circuits
287 .par_iter()
288 .map(|circuit| {
289 let mut batch_copy = BatchStateVector::from_states(
290 initial_batch.states.clone(),
291 initial_batch.config.clone(),
292 )?;
293
294 self.execute_batch(circuit, &mut batch_copy)
295 })
296 .collect::<QuantRS2Result<Vec<_>>>()?;
297
298 Ok(results)
299 }
300
301 pub fn execute_parameterized_batch(
303 &self,
304 circuit_fn: impl Fn(&[f64]) -> QuantRS2Result<BatchCircuit> + Sync,
305 parameter_sets: &[Vec<f64>],
306 initial_state: &Array1<Complex64>,
307 ) -> QuantRS2Result<Vec<Array1<Complex64>>> {
308 let batch_size = parameter_sets.len();
310 let mut states = Array2::zeros((batch_size, initial_state.len()));
311 for i in 0..batch_size {
312 states.row_mut(i).assign(initial_state);
313 }
314
315 let batch = BatchStateVector::from_states(states, self.config.clone())?;
316
317 let results: Vec<_> = parameter_sets
319 .par_iter()
320 .enumerate()
321 .map(|(i, params)| {
322 let circuit = circuit_fn(params)?;
323 let mut state = batch.get_state(i)?;
324 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
325 let gate_refs: Vec<&dyn GateOp> =
326 gate_sequence.iter().map(|g| g.as_ref()).collect();
327 apply_gates_to_state(&mut state, &gate_refs, circuit.n_qubits)?;
328 Ok(state)
329 })
330 .collect::<QuantRS2Result<Vec<_>>>()?;
331
332 Ok(results)
333 }
334}
335
336fn apply_gates_to_state(
338 state: &mut Array1<Complex64>,
339 gates: &[&dyn GateOp],
340 n_qubits: usize,
341) -> QuantRS2Result<()> {
342 for gate in gates {
343 let qubits = gate.qubits();
344 let matrix = gate.matrix()?;
345
346 match qubits.len() {
347 1 => {
348 apply_single_qubit_gate(state, &matrix, qubits[0], n_qubits)?;
349 }
350 2 => {
351 apply_two_qubit_gate(state, &matrix, qubits[0], qubits[1], n_qubits)?;
352 }
353 _ => {
354 return Err(QuantRS2Error::InvalidInput(format!(
355 "Gates with {} qubits not yet supported",
356 qubits.len()
357 )));
358 }
359 }
360 }
361
362 Ok(())
363}
364
365fn apply_single_qubit_gate(
367 state: &mut Array1<Complex64>,
368 matrix: &[Complex64],
369 target: QubitId,
370 n_qubits: usize,
371) -> QuantRS2Result<()> {
372 let target_idx = target.0 as usize;
373 let state_size = 1 << n_qubits;
374 let target_mask = 1 << target_idx;
375
376 for i in 0..state_size {
377 if i & target_mask == 0 {
378 let j = i | target_mask;
379
380 let a = state[i];
381 let b = state[j];
382
383 state[i] = matrix[0] * a + matrix[1] * b;
384 state[j] = matrix[2] * a + matrix[3] * b;
385 }
386 }
387
388 Ok(())
389}
390
391fn apply_two_qubit_gate(
393 state: &mut Array1<Complex64>,
394 matrix: &[Complex64],
395 control: QubitId,
396 target: QubitId,
397 n_qubits: usize,
398) -> QuantRS2Result<()> {
399 let control_idx = control.0 as usize;
400 let target_idx = target.0 as usize;
401 let state_size = 1 << n_qubits;
402 let control_mask = 1 << control_idx;
403 let target_mask = 1 << target_idx;
404
405 for i in 0..state_size {
406 if (i & control_mask == 0) && (i & target_mask == 0) {
407 let i00 = i;
408 let i01 = i | target_mask;
409 let i10 = i | control_mask;
410 let i11 = i | control_mask | target_mask;
411
412 let a00 = state[i00];
413 let a01 = state[i01];
414 let a10 = state[i10];
415 let a11 = state[i11];
416
417 state[i00] = matrix[0] * a00 + matrix[1] * a01 + matrix[2] * a10 + matrix[3] * a11;
418 state[i01] = matrix[4] * a00 + matrix[5] * a01 + matrix[6] * a10 + matrix[7] * a11;
419 state[i10] = matrix[8] * a00 + matrix[9] * a01 + matrix[10] * a10 + matrix[11] * a11;
420 state[i11] = matrix[12] * a00 + matrix[13] * a01 + matrix[14] * a10 + matrix[15] * a11;
421 }
422 }
423
424 Ok(())
425}
426
427pub fn create_optimized_executor() -> QuantRS2Result<BatchCircuitExecutor> {
429 let config = BatchConfig {
430 num_workers: Some(8), max_batch_size: 1024,
432 use_gpu: true,
433 memory_limit: Some(8 * 1024 * 1024 * 1024), enable_cache: true,
435 };
436
437 BatchCircuitExecutor::new(config)
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gate::single::Hadamard;

    /// Builds a circuit over `n_qubits` qubits with one Hadamard per entry of
    /// `targets`, added in the given order.
    fn hadamard_circuit(n_qubits: usize, targets: &[QubitId]) -> BatchCircuit {
        let mut circuit = BatchCircuit::new(n_qubits);
        for &target in targets {
            circuit.add_gate(Box::new(Hadamard { target })).unwrap();
        }
        circuit
    }

    /// A Hadamard on both qubits must leave amplitude 0 of every state with
    /// real part 0.5, on the CPU path.
    #[test]
    fn test_batch_circuit_execution() {
        let executor = BatchCircuitExecutor::new(BatchConfig {
            use_gpu: false,
            ..Default::default()
        })
        .unwrap();

        let circuit = hadamard_circuit(2, &[QubitId(0), QubitId(1)]);
        let mut batch = BatchStateVector::new(5, 2, Default::default()).unwrap();

        let outcome = executor.execute_batch(&circuit, &mut batch).unwrap();

        assert_eq!(outcome.gates_applied, 2);
        assert!(!outcome.used_gpu);

        (0..5).for_each(|index| {
            let state = batch.get_state(index).unwrap();
            assert!((state[0].re - 0.5).abs() < 1e-10);
        });
    }

    /// Three single-gate circuits run against one batch must produce three
    /// results, each reporting one applied gate.
    #[test]
    fn test_parallel_circuit_execution() {
        let executor = BatchCircuitExecutor::new(BatchConfig {
            num_workers: Some(2),
            use_gpu: false,
            ..Default::default()
        })
        .unwrap();

        let circuits: Vec<BatchCircuit> = (0..3)
            .map(|_| hadamard_circuit(1, &[QubitId(0)]))
            .collect();

        let batch = BatchStateVector::new(10, 1, Default::default()).unwrap();

        let results = executor
            .execute_multiple_circuits(&circuits, &batch)
            .unwrap();

        assert_eq!(results.len(), 3);
        results
            .iter()
            .for_each(|result| assert_eq!(result.gates_applied, 1));
    }
}