1use super::{BatchConfig, BatchExecutionResult, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 gate::GateOp,
7 gpu::GpuBackendFactory,
8 qubit::QubitId,
9};
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::Complex64;
12use crate::parallel_ops_stubs::*;
14use std::sync::Arc;
15use std::time::Instant;
16
17pub struct BatchCircuit {
19 pub n_qubits: usize,
21 pub gates: Vec<Box<dyn GateOp>>,
23}
24
25impl BatchCircuit {
26 pub fn new(n_qubits: usize) -> Self {
28 Self {
29 n_qubits,
30 gates: Vec::new(),
31 }
32 }
33
34 pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> QuantRS2Result<()> {
36 for qubit in gate.qubits() {
38 if qubit.0 as usize >= self.n_qubits {
39 return Err(QuantRS2Error::InvalidQubitId(qubit.0));
40 }
41 }
42 self.gates.push(gate);
43 Ok(())
44 }
45
46 pub fn gate_sequence(&self) -> impl Iterator<Item = &Box<dyn GateOp>> {
48 self.gates.iter()
49 }
50
51 pub fn num_gates(&self) -> usize {
53 self.gates.len()
54 }
55}
56
57pub struct BatchCircuitExecutor {
59 pub config: BatchConfig,
61 pub gpu_backend: Option<Arc<dyn crate::gpu::GpuBackend>>,
63 }
67
68impl BatchCircuitExecutor {
69 pub fn new(config: BatchConfig) -> QuantRS2Result<Self> {
71 let gpu_backend = if config.use_gpu {
73 GpuBackendFactory::create_best_available().ok()
74 } else {
75 None
76 };
77
78 Ok(Self {
96 config,
97 gpu_backend,
98 })
100 }
101
102 pub fn execute_batch(
104 &self,
105 circuit: &BatchCircuit,
106 batch: &mut BatchStateVector,
107 ) -> QuantRS2Result<BatchExecutionResult> {
108 if batch.n_qubits != circuit.n_qubits {
109 return Err(QuantRS2Error::InvalidInput(format!(
110 "Circuit has {} qubits but batch has {}",
111 circuit.n_qubits, batch.n_qubits
112 )));
113 }
114
115 let start_time = Instant::now();
116 let gates_applied = circuit.num_gates();
117
118 let used_gpu = if self.gpu_backend.is_some() && batch.batch_size() >= 64 {
120 self.execute_with_gpu(circuit, batch)?;
121 true
122 } else if batch.batch_size() > self.config.max_batch_size {
123 self.execute_chunked(circuit, batch)?;
124 false
125 } else {
126 self.execute_parallel(circuit, batch)?;
127 false
128 };
129
130 let execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
131
132 Ok(BatchExecutionResult {
133 final_states: batch.states.clone(),
134 execution_time_ms,
135 gates_applied,
136 used_gpu,
137 })
138 }
139
140 fn execute_with_gpu(
142 &self,
143 circuit: &BatchCircuit,
144 batch: &mut BatchStateVector,
145 ) -> QuantRS2Result<()> {
146 if let Some(gpu_backend) = &self.gpu_backend {
147 let mut gpu_states = Vec::new();
149
150 for i in 0..batch.batch_size() {
151 let state_data = batch.get_state(i)?;
152 let mut gpu_buffer = gpu_backend.allocate_state_vector(batch.n_qubits)?;
154 let state_slice = state_data.as_slice().ok_or_else(|| {
155 QuantRS2Error::RuntimeError("Failed to get state data as slice".to_string())
156 })?;
157 gpu_buffer.upload(state_slice)?;
158
159 gpu_states.push(gpu_buffer);
160 }
161
162 for gate in circuit.gate_sequence() {
164 let gate_qubits = gate.qubits();
165
166 for gpu_state in &mut gpu_states {
168 gpu_backend.apply_gate(
169 gpu_state.as_mut(),
170 gate.as_ref(),
171 &gate_qubits,
172 batch.n_qubits,
173 )?;
174 }
175 }
176
177 for (i, gpu_state) in gpu_states.iter().enumerate() {
179 let state_size = 1 << batch.n_qubits;
180 let mut result_data = vec![Complex64::new(0.0, 0.0); state_size];
181 gpu_state.download(&mut result_data)?;
182
183 let result_array = Array1::from_vec(result_data);
184 batch.set_state(i, &result_array)?;
185 }
186
187 Ok(())
188 } else {
189 self.execute_parallel(circuit, batch)
191 }
192 }
193
194 fn execute_chunked(
196 &self,
197 circuit: &BatchCircuit,
198 batch: &mut BatchStateVector,
199 ) -> QuantRS2Result<()> {
200 let chunk_size = self.config.max_batch_size;
201 let chunks = super::split_batch(batch, chunk_size);
202
203 let processed_chunks: Vec<_> = chunks
205 .into_par_iter()
206 .map(|mut chunk| {
207 self.execute_parallel(circuit, &mut chunk)?;
208 Ok(chunk)
209 })
210 .collect::<QuantRS2Result<Vec<_>>>()?;
211
212 let merged = super::merge_batches(processed_chunks, batch.config.clone())?;
214 batch.states = merged.states;
215
216 Ok(())
217 }
218
219 fn execute_parallel(
221 &self,
222 circuit: &BatchCircuit,
223 batch: &mut BatchStateVector,
224 ) -> QuantRS2Result<()> {
225 let _batch_size = batch.batch_size();
226 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
227 let gate_refs: Vec<&dyn GateOp> = gate_sequence.iter().map(|g| g.as_ref()).collect();
228
229 self.execute_with_thread_pool(batch, &gate_refs)?;
232
233 Ok(())
234 }
235
236 fn execute_with_thread_pool(
238 &self,
239 batch: &mut BatchStateVector,
240 gates: &[&dyn GateOp],
241 ) -> QuantRS2Result<()> {
242 let batch_size = batch.batch_size();
244 let n_qubits = batch.n_qubits;
245
246 let results: Vec<Array1<Complex64>> = (0..batch_size)
248 .into_par_iter()
249 .map(|i| {
250 let mut state = batch.states.row(i).to_owned();
251 apply_gates_to_state(&mut state, gates, n_qubits).map(|()| state)
252 })
253 .collect::<QuantRS2Result<Vec<_>>>()?;
254
255 for (i, state) in results.into_iter().enumerate() {
257 batch.states.row_mut(i).assign(&state);
258 }
259
260 Ok(())
261 }
262
263 pub fn execute_multiple_circuits(
265 &self,
266 circuits: &[BatchCircuit],
267 initial_batch: &BatchStateVector,
268 ) -> QuantRS2Result<Vec<BatchExecutionResult>> {
269 if circuits.is_empty() {
270 return Ok(Vec::new());
271 }
272
273 let results: Vec<_> = circuits
275 .par_iter()
276 .map(|circuit| {
277 let mut batch_copy = BatchStateVector::from_states(
278 initial_batch.states.clone(),
279 initial_batch.config.clone(),
280 )?;
281
282 self.execute_batch(circuit, &mut batch_copy)
283 })
284 .collect::<QuantRS2Result<Vec<_>>>()?;
285
286 Ok(results)
287 }
288
289 pub fn execute_parameterized_batch(
291 &self,
292 circuit_fn: impl Fn(&[f64]) -> QuantRS2Result<BatchCircuit> + Sync,
293 parameter_sets: &[Vec<f64>],
294 initial_state: &Array1<Complex64>,
295 ) -> QuantRS2Result<Vec<Array1<Complex64>>> {
296 let batch_size = parameter_sets.len();
298 let mut states = Array2::zeros((batch_size, initial_state.len()));
299 for i in 0..batch_size {
300 states.row_mut(i).assign(initial_state);
301 }
302
303 let batch = BatchStateVector::from_states(states, self.config.clone())?;
304
305 let results: Vec<_> = parameter_sets
307 .par_iter()
308 .enumerate()
309 .map(|(i, params)| {
310 let circuit = circuit_fn(params)?;
311 let mut state = batch.get_state(i)?;
312 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
313 let gate_refs: Vec<&dyn GateOp> =
314 gate_sequence.iter().map(|g| g.as_ref()).collect();
315 apply_gates_to_state(&mut state, &gate_refs, circuit.n_qubits)?;
316 Ok(state)
317 })
318 .collect::<QuantRS2Result<Vec<_>>>()?;
319
320 Ok(results)
321 }
322}
323
324fn apply_gates_to_state(
326 state: &mut Array1<Complex64>,
327 gates: &[&dyn GateOp],
328 n_qubits: usize,
329) -> QuantRS2Result<()> {
330 for gate in gates {
331 let qubits = gate.qubits();
332 let matrix = gate.matrix()?;
333
334 match qubits.len() {
335 1 => {
336 apply_single_qubit_gate(state, &matrix, qubits[0], n_qubits)?;
337 }
338 2 => {
339 apply_two_qubit_gate(state, &matrix, qubits[0], qubits[1], n_qubits)?;
340 }
341 _ => {
342 return Err(QuantRS2Error::InvalidInput(format!(
343 "Gates with {} qubits not yet supported",
344 qubits.len()
345 )));
346 }
347 }
348 }
349
350 Ok(())
351}
352
353fn apply_single_qubit_gate(
355 state: &mut Array1<Complex64>,
356 matrix: &[Complex64],
357 target: QubitId,
358 n_qubits: usize,
359) -> QuantRS2Result<()> {
360 let target_idx = target.0 as usize;
361 let state_size = 1 << n_qubits;
362 let target_mask = 1 << target_idx;
363
364 for i in 0..state_size {
365 if i & target_mask == 0 {
366 let j = i | target_mask;
367
368 let a = state[i];
369 let b = state[j];
370
371 state[i] = matrix[0] * a + matrix[1] * b;
372 state[j] = matrix[2] * a + matrix[3] * b;
373 }
374 }
375
376 Ok(())
377}
378
379fn apply_two_qubit_gate(
381 state: &mut Array1<Complex64>,
382 matrix: &[Complex64],
383 control: QubitId,
384 target: QubitId,
385 n_qubits: usize,
386) -> QuantRS2Result<()> {
387 let control_idx = control.0 as usize;
388 let target_idx = target.0 as usize;
389 let state_size = 1 << n_qubits;
390 let control_mask = 1 << control_idx;
391 let target_mask = 1 << target_idx;
392
393 for i in 0..state_size {
394 if (i & control_mask == 0) && (i & target_mask == 0) {
395 let i00 = i;
396 let i01 = i | target_mask;
397 let i10 = i | control_mask;
398 let i11 = i | control_mask | target_mask;
399
400 let a00 = state[i00];
401 let a01 = state[i01];
402 let a10 = state[i10];
403 let a11 = state[i11];
404
405 state[i00] = matrix[0] * a00 + matrix[1] * a01 + matrix[2] * a10 + matrix[3] * a11;
406 state[i01] = matrix[4] * a00 + matrix[5] * a01 + matrix[6] * a10 + matrix[7] * a11;
407 state[i10] = matrix[8] * a00 + matrix[9] * a01 + matrix[10] * a10 + matrix[11] * a11;
408 state[i11] = matrix[12] * a00 + matrix[13] * a01 + matrix[14] * a10 + matrix[15] * a11;
409 }
410 }
411
412 Ok(())
413}
414
415pub fn create_optimized_executor() -> QuantRS2Result<BatchCircuitExecutor> {
417 let config = BatchConfig {
418 num_workers: Some(8), max_batch_size: 1024,
420 use_gpu: true,
421 memory_limit: Some(8 * 1024 * 1024 * 1024), enable_cache: true,
423 };
424
425 BatchCircuitExecutor::new(config)
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gate::single::Hadamard;

    /// A Hadamard on each qubit of a 2-qubit circuit, run over a batch of 5 states
    /// on the CPU: checks the gate count, the GPU flag, and that the first
    /// amplitude of every state has real part 0.5.
    #[test]
    fn test_batch_circuit_execution() {
        let executor = BatchCircuitExecutor::new(BatchConfig {
            use_gpu: false,
            ..Default::default()
        })
        .expect("Failed to create batch circuit executor");

        let mut circuit = BatchCircuit::new(2);
        let h0 = Box::new(Hadamard { target: QubitId(0) });
        circuit
            .add_gate(h0)
            .expect("Failed to add Hadamard gate to qubit 0");
        let h1 = Box::new(Hadamard { target: QubitId(1) });
        circuit
            .add_gate(h1)
            .expect("Failed to add Hadamard gate to qubit 1");

        let mut batch = BatchStateVector::new(5, 2, Default::default())
            .expect("Failed to create batch state vector");

        let result = executor
            .execute_batch(&circuit, &mut batch)
            .expect("Failed to execute batch circuit");

        assert_eq!(result.gates_applied, 2);
        assert!(!result.used_gpu);

        for idx in 0..5 {
            let state = batch.get_state(idx).expect("Failed to get batch state");
            let deviation = (state[0].re - 0.5).abs();
            assert!(deviation < 1e-10);
        }
    }

    /// Three single-Hadamard circuits run against the same 10-state batch: expects
    /// one result per circuit, each reporting exactly one applied gate.
    #[test]
    fn test_parallel_circuit_execution() {
        let executor = BatchCircuitExecutor::new(BatchConfig {
            num_workers: Some(2),
            use_gpu: false,
            ..Default::default()
        })
        .expect("Failed to create batch circuit executor");

        let circuits: Vec<BatchCircuit> = (0..3)
            .map(|_| {
                let mut circuit = BatchCircuit::new(1);
                circuit
                    .add_gate(Box::new(Hadamard { target: QubitId(0) }))
                    .expect("Failed to add Hadamard gate");
                circuit
            })
            .collect();

        let batch = BatchStateVector::new(10, 1, Default::default())
            .expect("Failed to create batch state vector");

        let results = executor
            .execute_multiple_circuits(&circuits, &batch)
            .expect("Failed to execute multiple circuits");

        assert_eq!(results.len(), 3);
        for outcome in &results {
            assert_eq!(outcome.gates_applied, 1);
        }
    }
}