1use super::{BatchConfig, BatchExecutionResult, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 gate::GateOp,
7 gpu::{GpuBackendFactory, GpuStateVector},
8 qubit::QubitId,
9};
10use ndarray::{Array1, Array2};
11use num_complex::Complex64;
12use rayon::prelude::*;
13use std::sync::Arc;
14use std::time::Instant;
15
16extern crate scirs2_core;
18use scirs2_core::parallel::{SchedulerConfig, WorkStealingScheduler};
19
20pub struct BatchCircuit {
22 pub n_qubits: usize,
24 pub gates: Vec<Box<dyn GateOp>>,
26}
27
28impl BatchCircuit {
29 pub fn new(n_qubits: usize) -> Self {
31 Self {
32 n_qubits,
33 gates: Vec::new(),
34 }
35 }
36
37 pub fn add_gate(&mut self, gate: Box<dyn GateOp>) -> QuantRS2Result<()> {
39 for qubit in gate.qubits() {
41 if qubit.0 as usize >= self.n_qubits {
42 return Err(QuantRS2Error::InvalidQubitId(qubit.0));
43 }
44 }
45 self.gates.push(gate);
46 Ok(())
47 }
48
49 pub fn gate_sequence(&self) -> impl Iterator<Item = &Box<dyn GateOp>> {
51 self.gates.iter()
52 }
53
54 pub fn num_gates(&self) -> usize {
56 self.gates.len()
57 }
58}
59
60pub struct BatchCircuitExecutor {
62 pub config: BatchConfig,
64 pub gpu_backend: Option<Arc<dyn crate::gpu::GpuBackend>>,
66 pub scheduler: Option<WorkStealingScheduler>,
68}
69
70impl BatchCircuitExecutor {
71 pub fn new(config: BatchConfig) -> QuantRS2Result<Self> {
73 let gpu_backend = if config.use_gpu {
75 GpuBackendFactory::create_best_available().ok()
76 } else {
77 None
78 };
79
80 let scheduler = None; Ok(Self {
84 config,
85 gpu_backend,
86 scheduler,
87 })
88 }
89
90 pub fn execute_batch(
92 &self,
93 circuit: &BatchCircuit,
94 batch: &mut BatchStateVector,
95 ) -> QuantRS2Result<BatchExecutionResult> {
96 if batch.n_qubits != circuit.n_qubits {
97 return Err(QuantRS2Error::InvalidInput(format!(
98 "Circuit has {} qubits but batch has {}",
99 circuit.n_qubits, batch.n_qubits
100 )));
101 }
102
103 let start_time = Instant::now();
104 let gates_applied = circuit.num_gates();
105
106 let used_gpu = if self.gpu_backend.is_some() && batch.batch_size() >= 64 {
108 self.execute_with_gpu(circuit, batch)?;
109 true
110 } else if batch.batch_size() > self.config.max_batch_size {
111 self.execute_chunked(circuit, batch)?;
112 false
113 } else {
114 self.execute_parallel(circuit, batch)?;
115 false
116 };
117
118 let execution_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
119
120 Ok(BatchExecutionResult {
121 final_states: batch.states.clone(),
122 execution_time_ms,
123 gates_applied,
124 used_gpu,
125 })
126 }
127
128 fn execute_with_gpu(
130 &self,
131 circuit: &BatchCircuit,
132 batch: &mut BatchStateVector,
133 ) -> QuantRS2Result<()> {
134 self.execute_parallel(circuit, batch)
138 }
139
140 fn execute_chunked(
142 &self,
143 circuit: &BatchCircuit,
144 batch: &mut BatchStateVector,
145 ) -> QuantRS2Result<()> {
146 let chunk_size = self.config.max_batch_size;
147 let chunks = super::split_batch(batch, chunk_size);
148
149 let processed_chunks: Vec<_> = chunks
151 .into_par_iter()
152 .map(|mut chunk| {
153 self.execute_parallel(circuit, &mut chunk)?;
154 Ok(chunk)
155 })
156 .collect::<QuantRS2Result<Vec<_>>>()?;
157
158 let merged = super::merge_batches(processed_chunks, batch.config.clone())?;
160 batch.states = merged.states;
161
162 Ok(())
163 }
164
165 fn execute_parallel(
167 &self,
168 circuit: &BatchCircuit,
169 batch: &mut BatchStateVector,
170 ) -> QuantRS2Result<()> {
171 let batch_size = batch.batch_size();
172 let gate_sequence: Vec<_> = circuit.gate_sequence().collect();
173
174 if let Some(scheduler) = &self.scheduler {
175 self.execute_with_scheduler(batch, &gate_sequence, scheduler)?;
177 } else {
178 batch
180 .states
181 .axis_iter_mut(ndarray::Axis(0))
182 .into_par_iter()
183 .try_for_each(|mut state_row| -> QuantRS2Result<()> {
184 let mut state = state_row.to_owned();
185 apply_gates_to_state(&mut state, &gate_sequence, batch.n_qubits)?;
186 state_row.assign(&state);
187 Ok(())
188 })?;
189 }
190
191 Ok(())
192 }
193
194 fn execute_with_scheduler(
196 &self,
197 batch: &mut BatchStateVector,
198 gates: &[&Box<dyn GateOp>],
199 scheduler: &WorkStealingScheduler,
200 ) -> QuantRS2Result<()> {
201 let batch_size = batch.batch_size();
203 let n_qubits = batch.n_qubits;
204
205 let results: Vec<Array1<Complex64>> = (0..batch_size)
207 .into_par_iter()
208 .map(|i| {
209 let mut state = batch.states.row(i).to_owned();
210 apply_gates_to_state(&mut state, gates, n_qubits).map(|_| state)
211 })
212 .collect::<QuantRS2Result<Vec<_>>>()?;
213
214 for (i, state) in results.into_iter().enumerate() {
216 batch.states.row_mut(i).assign(&state);
217 }
218
219 Ok(())
220 }
221
222 pub fn execute_multiple_circuits(
224 &self,
225 circuits: &[BatchCircuit],
226 initial_batch: &BatchStateVector,
227 ) -> QuantRS2Result<Vec<BatchExecutionResult>> {
228 if circuits.is_empty() {
229 return Ok(Vec::new());
230 }
231
232 let results: Vec<_> = circuits
234 .par_iter()
235 .map(|circuit| {
236 let mut batch_copy = BatchStateVector::from_states(
237 initial_batch.states.clone(),
238 initial_batch.config.clone(),
239 )?;
240
241 self.execute_batch(circuit, &mut batch_copy)
242 })
243 .collect::<QuantRS2Result<Vec<_>>>()?;
244
245 Ok(results)
246 }
247
248 pub fn execute_parameterized_batch(
250 &self,
251 circuit_fn: impl Fn(&[f64]) -> QuantRS2Result<BatchCircuit> + Sync,
252 parameter_sets: &[Vec<f64>],
253 initial_state: &Array1<Complex64>,
254 ) -> QuantRS2Result<Vec<Array1<Complex64>>> {
255 let batch_size = parameter_sets.len();
257 let mut states = Array2::zeros((batch_size, initial_state.len()));
258 for i in 0..batch_size {
259 states.row_mut(i).assign(initial_state);
260 }
261
262 let mut batch = BatchStateVector::from_states(states, self.config.clone())?;
263
264 let results: Vec<_> = parameter_sets
266 .par_iter()
267 .enumerate()
268 .map(|(i, params)| {
269 let circuit = circuit_fn(params)?;
270 let mut state = batch.get_state(i)?;
271 apply_gates_to_state(
272 &mut state,
273 &circuit.gate_sequence().collect::<Vec<_>>(),
274 circuit.n_qubits,
275 )?;
276 Ok(state)
277 })
278 .collect::<QuantRS2Result<Vec<_>>>()?;
279
280 Ok(results)
281 }
282}
283
284fn apply_gates_to_state(
286 state: &mut Array1<Complex64>,
287 gates: &[&Box<dyn GateOp>],
288 n_qubits: usize,
289) -> QuantRS2Result<()> {
290 for gate in gates {
291 let qubits = gate.qubits();
292 let matrix = gate.matrix()?;
293
294 match qubits.len() {
295 1 => {
296 apply_single_qubit_gate(state, &matrix, qubits[0], n_qubits)?;
297 }
298 2 => {
299 apply_two_qubit_gate(state, &matrix, qubits[0], qubits[1], n_qubits)?;
300 }
301 _ => {
302 return Err(QuantRS2Error::InvalidInput(format!(
303 "Gates with {} qubits not yet supported",
304 qubits.len()
305 )));
306 }
307 }
308 }
309
310 Ok(())
311}
312
313fn apply_single_qubit_gate(
315 state: &mut Array1<Complex64>,
316 matrix: &[Complex64],
317 target: QubitId,
318 n_qubits: usize,
319) -> QuantRS2Result<()> {
320 let target_idx = target.0 as usize;
321 let state_size = 1 << n_qubits;
322 let target_mask = 1 << target_idx;
323
324 for i in 0..state_size {
325 if i & target_mask == 0 {
326 let j = i | target_mask;
327
328 let a = state[i];
329 let b = state[j];
330
331 state[i] = matrix[0] * a + matrix[1] * b;
332 state[j] = matrix[2] * a + matrix[3] * b;
333 }
334 }
335
336 Ok(())
337}
338
339fn apply_two_qubit_gate(
341 state: &mut Array1<Complex64>,
342 matrix: &[Complex64],
343 control: QubitId,
344 target: QubitId,
345 n_qubits: usize,
346) -> QuantRS2Result<()> {
347 let control_idx = control.0 as usize;
348 let target_idx = target.0 as usize;
349 let state_size = 1 << n_qubits;
350 let control_mask = 1 << control_idx;
351 let target_mask = 1 << target_idx;
352
353 for i in 0..state_size {
354 if (i & control_mask == 0) && (i & target_mask == 0) {
355 let i00 = i;
356 let i01 = i | target_mask;
357 let i10 = i | control_mask;
358 let i11 = i | control_mask | target_mask;
359
360 let a00 = state[i00];
361 let a01 = state[i01];
362 let a10 = state[i10];
363 let a11 = state[i11];
364
365 state[i00] = matrix[0] * a00 + matrix[1] * a01 + matrix[2] * a10 + matrix[3] * a11;
366 state[i01] = matrix[4] * a00 + matrix[5] * a01 + matrix[6] * a10 + matrix[7] * a11;
367 state[i10] = matrix[8] * a00 + matrix[9] * a01 + matrix[10] * a10 + matrix[11] * a11;
368 state[i11] = matrix[12] * a00 + matrix[13] * a01 + matrix[14] * a10 + matrix[15] * a11;
369 }
370 }
371
372 Ok(())
373}
374
375pub fn create_optimized_executor() -> QuantRS2Result<BatchCircuitExecutor> {
377 let config = BatchConfig {
378 num_workers: Some(8), max_batch_size: 1024,
380 use_gpu: true,
381 memory_limit: Some(8 * 1024 * 1024 * 1024), enable_cache: true,
383 };
384
385 BatchCircuitExecutor::new(config)
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gate::single::Hadamard;

    const TOLERANCE: f64 = 1e-10;

    fn cpu_executor() -> BatchCircuitExecutor {
        let config = BatchConfig {
            use_gpu: false,
            ..Default::default()
        };
        BatchCircuitExecutor::new(config).unwrap()
    }

    /// H on both qubits of a basis state gives every amplitude magnitude 1/2, in every row.
    #[test]
    fn test_batch_circuit_execution() {
        let executor = cpu_executor();

        let mut circuit = BatchCircuit::new(2);
        circuit
            .add_gate(Box::new(Hadamard { target: QubitId(0) }))
            .unwrap();
        circuit
            .add_gate(Box::new(Hadamard { target: QubitId(1) }))
            .unwrap();

        let mut batch = BatchStateVector::new(5, 2, Default::default()).unwrap();

        let result = executor.execute_batch(&circuit, &mut batch).unwrap();

        assert_eq!(result.gates_applied, 2);
        assert!(!result.used_gpu);
        assert_eq!(result.final_states.nrows(), 5);

        for i in 0..5 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state.len(), 4);
            assert!((state[0].re - 0.5).abs() < TOLERANCE);
            for (k, amp) in state.iter().enumerate() {
                assert!((amp.norm() - 0.5).abs() < TOLERANCE);
                // The reported final states must match what was written back into the batch.
                assert!((result.final_states[[i, k]] - *amp).norm() < TOLERANCE);
            }
        }
    }

    /// Each circuit runs on its own copy; the shared input batch is left untouched.
    #[test]
    fn test_parallel_circuit_execution() {
        let config = BatchConfig {
            num_workers: Some(2),
            use_gpu: false,
            ..Default::default()
        };

        let executor = BatchCircuitExecutor::new(config).unwrap();

        let mut circuits = Vec::new();
        for _ in 0..3 {
            let mut circuit = BatchCircuit::new(1);
            circuit
                .add_gate(Box::new(Hadamard { target: QubitId(0) }))
                .unwrap();
            circuits.push(circuit);
        }

        let batch = BatchStateVector::new(10, 1, Default::default()).unwrap();
        let before = batch.states.clone();

        let results = executor
            .execute_multiple_circuits(&circuits, &batch)
            .unwrap();

        assert_eq!(results.len(), 3);
        for result in &results {
            assert_eq!(result.gates_applied, 1);
            assert_eq!(result.final_states.nrows(), 10);
            // H on a single-qubit basis state yields amplitudes of magnitude 1/sqrt(2).
            for amp in result.final_states.iter() {
                assert!((amp.norm() - std::f64::consts::FRAC_1_SQRT_2).abs() < TOLERANCE);
            }
        }

        assert_eq!(batch.states, before);
    }

    #[test]
    fn test_empty_circuit_list_yields_no_results() {
        let executor = cpu_executor();
        let batch = BatchStateVector::new(2, 1, Default::default()).unwrap();

        let results = executor.execute_multiple_circuits(&[], &batch).unwrap();

        assert!(results.is_empty());
    }

    #[test]
    fn test_qubit_count_mismatch_is_rejected() {
        let executor = cpu_executor();
        let circuit = BatchCircuit::new(3);
        let mut batch = BatchStateVector::new(2, 2, Default::default()).unwrap();

        assert!(executor.execute_batch(&circuit, &mut batch).is_err());
    }

    #[test]
    fn test_add_gate_rejects_out_of_range_qubit() {
        let mut circuit = BatchCircuit::new(1);

        let outcome = circuit.add_gate(Box::new(Hadamard { target: QubitId(1) }));

        assert!(outcome.is_err());
        assert_eq!(circuit.num_gates(), 0);
    }
}