1use super::{BatchConfig, BatchExecutionResult, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 gate::GateOp,
7 gpu::GpuBackendFactory,
8 qubit::QubitId,
9};
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::Complex64;
12use crate::parallel_ops_stubs::*;
14use std::sync::Arc;
15use std::time::Instant;
16
17pub struct BatchCircuit {
19 pub n_qubits: usize,
21 pub gates: Vec<Box<dyn GateOp>>,
23}
24
25impl BatchCircuit {
26 pub fn new(n_qubits: usize) -> Self {
28 Self {
29 n_qubits,
30 gates: Vec::new(),
31 }
32 }
33
34 pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> QuantRS2Result<()> {
36 for qubit in gate.qubits() {
38 if qubit.0 as usize >= self.n_qubits {
39 return Err(QuantRS2Error::InvalidQubitId(qubit.0));
40 }
41 }
42 self.gates.push(gate);
43 Ok(())
44 }
45
46 pub fn gate_sequence(&self) -> impl Iterator<Item = &Box<dyn GateOp>> {
48 self.gates.iter()
49 }
50
51 pub fn num_gates(&self) -> usize {
53 self.gates.len()
54 }
55}
56
57pub struct BatchCircuitExecutor {
63 pub config: BatchConfig,
65 pub gpu_backend: Option<Arc<dyn crate::gpu::GpuBackend>>,
67}
68
69impl BatchCircuitExecutor {
70 pub fn new(config: BatchConfig) -> QuantRS2Result<Self> {
72 let gpu_backend = if config.use_gpu {
74 GpuBackendFactory::create_best_available().ok()
75 } else {
76 None
77 };
78
79 Ok(Self {
84 config,
85 gpu_backend,
86 })
87 }
88
89 pub fn execute_batch(
91 &self,
92 circuit: &BatchCircuit,
93 batch: &mut BatchStateVector,
94 ) -> QuantRS2Result<BatchExecutionResult> {
95 if batch.n_qubits != circuit.n_qubits {
96 return Err(QuantRS2Error::InvalidInput(format!(
97 "Circuit has {} qubits but batch has {}",
98 circuit.n_qubits, batch.n_qubits
99 )));
100 }
101
102 let start_time = Instant::now();
103 let gates_applied = circuit.num_gates();
104
105 let used_gpu = if self.gpu_backend.is_some() && batch.batch_size() >= 64 {
107 self.execute_with_gpu(circuit, batch)?;
108 true
109 } else if batch.batch_size() > self.config.max_batch_size {
110 self.execute_chunked(circuit, batch)?;
111 false
112 } else {
113 self.execute_parallel(circuit, batch)?;
114 false
115 };
116
117 let execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
118
119 Ok(BatchExecutionResult {
120 final_states: batch.states.clone(),
121 execution_time_ms,
122 gates_applied,
123 used_gpu,
124 })
125 }
126
127 fn execute_with_gpu(
129 &self,
130 circuit: &BatchCircuit,
131 batch: &mut BatchStateVector,
132 ) -> QuantRS2Result<()> {
133 if let Some(gpu_backend) = &self.gpu_backend {
134 let mut gpu_states = Vec::new();
136
137 for i in 0..batch.batch_size() {
138 let state_data = batch.get_state(i)?;
139 let mut gpu_buffer = gpu_backend.allocate_state_vector(batch.n_qubits)?;
141 let state_slice = state_data.as_slice().ok_or_else(|| {
142 QuantRS2Error::RuntimeError("Failed to get state data as slice".to_string())
143 })?;
144 gpu_buffer.upload(state_slice)?;
145
146 gpu_states.push(gpu_buffer);
147 }
148
149 for gate in circuit.gate_sequence() {
151 let gate_qubits = gate.qubits();
152
153 for gpu_state in &mut gpu_states {
155 gpu_backend.apply_gate(
156 gpu_state.as_mut(),
157 gate.as_ref(),
158 &gate_qubits,
159 batch.n_qubits,
160 )?;
161 }
162 }
163
164 for (i, gpu_state) in gpu_states.iter().enumerate() {
166 let state_size = 1 << batch.n_qubits;
167 let mut result_data = vec![Complex64::new(0.0, 0.0); state_size];
168 gpu_state.download(&mut result_data)?;
169
170 let result_array = Array1::from_vec(result_data);
171 batch.set_state(i, &result_array)?;
172 }
173
174 Ok(())
175 } else {
176 self.execute_parallel(circuit, batch)
178 }
179 }
180
181 fn execute_chunked(
183 &self,
184 circuit: &BatchCircuit,
185 batch: &mut BatchStateVector,
186 ) -> QuantRS2Result<()> {
187 let chunk_size = self.config.max_batch_size;
188 let chunks = super::split_batch(batch, chunk_size);
189
190 let processed_chunks: Vec<_> = chunks
192 .into_par_iter()
193 .map(|mut chunk| {
194 self.execute_parallel(circuit, &mut chunk)?;
195 Ok(chunk)
196 })
197 .collect::<QuantRS2Result<Vec<_>>>()?;
198
199 let merged = super::merge_batches(processed_chunks, batch.config.clone())?;
201 batch.states = merged.states;
202
203 Ok(())
204 }
205
206 fn execute_parallel(
208 &self,
209 circuit: &BatchCircuit,
210 batch: &mut BatchStateVector,
211 ) -> QuantRS2Result<()> {
212 let _batch_size = batch.batch_size();
213 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
214 let gate_refs: Vec<&dyn GateOp> = gate_sequence.iter().map(|g| g.as_ref()).collect();
215
216 self.execute_with_thread_pool(batch, &gate_refs)?;
219
220 Ok(())
221 }
222
223 fn execute_with_thread_pool(
225 &self,
226 batch: &mut BatchStateVector,
227 gates: &[&dyn GateOp],
228 ) -> QuantRS2Result<()> {
229 let batch_size = batch.batch_size();
231 let n_qubits = batch.n_qubits;
232
233 let results: Vec<Array1<Complex64>> = (0..batch_size)
235 .into_par_iter()
236 .map(|i| {
237 let mut state = batch.states.row(i).to_owned();
238 apply_gates_to_state(&mut state, gates, n_qubits).map(|()| state)
239 })
240 .collect::<QuantRS2Result<Vec<_>>>()?;
241
242 for (i, state) in results.into_iter().enumerate() {
244 batch.states.row_mut(i).assign(&state);
245 }
246
247 Ok(())
248 }
249
250 pub fn execute_multiple_circuits(
252 &self,
253 circuits: &[BatchCircuit],
254 initial_batch: &BatchStateVector,
255 ) -> QuantRS2Result<Vec<BatchExecutionResult>> {
256 if circuits.is_empty() {
257 return Ok(Vec::new());
258 }
259
260 let results: Vec<_> = circuits
262 .par_iter()
263 .map(|circuit| {
264 let mut batch_copy = BatchStateVector::from_states(
265 initial_batch.states.clone(),
266 initial_batch.config.clone(),
267 )?;
268
269 self.execute_batch(circuit, &mut batch_copy)
270 })
271 .collect::<QuantRS2Result<Vec<_>>>()?;
272
273 Ok(results)
274 }
275
276 pub fn execute_parameterized_batch(
278 &self,
279 circuit_fn: impl Fn(&[f64]) -> QuantRS2Result<BatchCircuit> + Sync,
280 parameter_sets: &[Vec<f64>],
281 initial_state: &Array1<Complex64>,
282 ) -> QuantRS2Result<Vec<Array1<Complex64>>> {
283 let batch_size = parameter_sets.len();
285 let mut states = Array2::zeros((batch_size, initial_state.len()));
286 for i in 0..batch_size {
287 states.row_mut(i).assign(initial_state);
288 }
289
290 let batch = BatchStateVector::from_states(states, self.config.clone())?;
291
292 let results: Vec<_> = parameter_sets
294 .par_iter()
295 .enumerate()
296 .map(|(i, params)| {
297 let circuit = circuit_fn(params)?;
298 let mut state = batch.get_state(i)?;
299 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
300 let gate_refs: Vec<&dyn GateOp> =
301 gate_sequence.iter().map(|g| g.as_ref()).collect();
302 apply_gates_to_state(&mut state, &gate_refs, circuit.n_qubits)?;
303 Ok(state)
304 })
305 .collect::<QuantRS2Result<Vec<_>>>()?;
306
307 Ok(results)
308 }
309}
310
311fn apply_gates_to_state(
313 state: &mut Array1<Complex64>,
314 gates: &[&dyn GateOp],
315 n_qubits: usize,
316) -> QuantRS2Result<()> {
317 for gate in gates {
318 let qubits = gate.qubits();
319 let matrix = gate.matrix()?;
320
321 match qubits.len() {
322 1 => {
323 apply_single_qubit_gate(state, &matrix, qubits[0], n_qubits)?;
324 }
325 2 => {
326 apply_two_qubit_gate(state, &matrix, qubits[0], qubits[1], n_qubits)?;
327 }
328 _ => {
329 return Err(QuantRS2Error::InvalidInput(format!(
330 "Gates with {} qubits not yet supported",
331 qubits.len()
332 )));
333 }
334 }
335 }
336
337 Ok(())
338}
339
340fn apply_single_qubit_gate(
342 state: &mut Array1<Complex64>,
343 matrix: &[Complex64],
344 target: QubitId,
345 n_qubits: usize,
346) -> QuantRS2Result<()> {
347 let target_idx = target.0 as usize;
348 let state_size = 1 << n_qubits;
349 let target_mask = 1 << target_idx;
350
351 for i in 0..state_size {
352 if i & target_mask == 0 {
353 let j = i | target_mask;
354
355 let a = state[i];
356 let b = state[j];
357
358 state[i] = matrix[0] * a + matrix[1] * b;
359 state[j] = matrix[2] * a + matrix[3] * b;
360 }
361 }
362
363 Ok(())
364}
365
366fn apply_two_qubit_gate(
368 state: &mut Array1<Complex64>,
369 matrix: &[Complex64],
370 control: QubitId,
371 target: QubitId,
372 n_qubits: usize,
373) -> QuantRS2Result<()> {
374 let control_idx = control.0 as usize;
375 let target_idx = target.0 as usize;
376 let state_size = 1 << n_qubits;
377 let control_mask = 1 << control_idx;
378 let target_mask = 1 << target_idx;
379
380 for i in 0..state_size {
381 if (i & control_mask == 0) && (i & target_mask == 0) {
382 let i00 = i;
383 let i01 = i | target_mask;
384 let i10 = i | control_mask;
385 let i11 = i | control_mask | target_mask;
386
387 let a00 = state[i00];
388 let a01 = state[i01];
389 let a10 = state[i10];
390 let a11 = state[i11];
391
392 state[i00] = matrix[0] * a00 + matrix[1] * a01 + matrix[2] * a10 + matrix[3] * a11;
393 state[i01] = matrix[4] * a00 + matrix[5] * a01 + matrix[6] * a10 + matrix[7] * a11;
394 state[i10] = matrix[8] * a00 + matrix[9] * a01 + matrix[10] * a10 + matrix[11] * a11;
395 state[i11] = matrix[12] * a00 + matrix[13] * a01 + matrix[14] * a10 + matrix[15] * a11;
396 }
397 }
398
399 Ok(())
400}
401
402pub fn create_optimized_executor() -> QuantRS2Result<BatchCircuitExecutor> {
404 let config = BatchConfig {
405 num_workers: Some(8), max_batch_size: 1024,
407 use_gpu: true,
408 memory_limit: Some(8 * 1024 * 1024 * 1024), enable_cache: true,
410 };
411
412 BatchCircuitExecutor::new(config)
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gate::single::Hadamard;

    #[test]
    fn test_batch_circuit_execution() {
        let executor = BatchCircuitExecutor::new(BatchConfig {
            use_gpu: false,
            ..Default::default()
        })
        .expect("Failed to create batch circuit executor");

        // H on both qubits; starting from |00> (as the assertion below assumes), every
        // amplitude becomes 1/2.
        let mut circuit = BatchCircuit::new(2);
        circuit
            .add_gate(Box::new(Hadamard { target: QubitId(0) }))
            .expect("Failed to add Hadamard gate to qubit 0");
        circuit
            .add_gate(Box::new(Hadamard { target: QubitId(1) }))
            .expect("Failed to add Hadamard gate to qubit 1");

        let mut batch = BatchStateVector::new(5, 2, Default::default())
            .expect("Failed to create batch state vector");

        let result = executor
            .execute_batch(&circuit, &mut batch)
            .expect("Failed to execute batch circuit");

        assert_eq!(result.gates_applied, 2);
        assert!(!result.used_gpu);

        (0..5)
            .map(|i| batch.get_state(i).expect("Failed to get batch state"))
            .for_each(|state| assert!((state[0].re - 0.5).abs() < 1e-10));
    }

    #[test]
    fn test_parallel_circuit_execution() {
        let executor = BatchCircuitExecutor::new(BatchConfig {
            num_workers: Some(2),
            use_gpu: false,
            ..Default::default()
        })
        .expect("Failed to create batch circuit executor");

        let circuits: Vec<BatchCircuit> = (0..3)
            .map(|_| {
                let mut circuit = BatchCircuit::new(1);
                circuit
                    .add_gate(Box::new(Hadamard { target: QubitId(0) }))
                    .expect("Failed to add Hadamard gate");
                circuit
            })
            .collect();

        let batch = BatchStateVector::new(10, 1, Default::default())
            .expect("Failed to create batch state vector");

        let results = executor
            .execute_multiple_circuits(&circuits, &batch)
            .expect("Failed to execute multiple circuits");

        assert_eq!(results.len(), 3);
        for result in &results {
            assert_eq!(result.gates_applied, 1);
        }
    }
}