1pub mod execution;
7pub mod measurement;
8pub mod operations;
9pub mod optimization;
10
11use crate::{
12 error::{QuantRS2Error, QuantRS2Result},
13 gate::GateOp,
14 qubit::QubitId,
15};
16use scirs2_core::ndarray::{Array1, Array2};
17use scirs2_core::Complex64;
18
/// Tunable options for batch processing of quantum states.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Requested number of workers.
    // NOTE(review): `None` presumably lets the executor choose — confirm against the execution module.
    pub num_workers: Option<usize>,
    /// Maximum batch size.
    // NOTE(review): nothing in this file enforces it; presumably consumed by the executor — confirm.
    pub max_batch_size: usize,
    /// Whether GPU execution is requested.
    pub use_gpu: bool,
    /// Optional cap, in bytes, on the memory taken by a batch's amplitudes.
    /// `BatchStateVector::new` rejects batches that would exceed it.
    pub memory_limit: Option<usize>,
    /// Whether caching is enabled.
    // NOTE(review): what gets cached is not visible from this file — confirm.
    pub enable_cache: bool,
}

impl Default for BatchConfig {
    /// Builds the default configuration: no explicit worker count, no memory
    /// limit, batches of up to 1024 states, GPU and cache both enabled.
    fn default() -> Self {
        BatchConfig {
            // Unset options first.
            num_workers: None,
            memory_limit: None,
            // Sizing.
            max_batch_size: 1024,
            // Feature switches.
            use_gpu: true,
            enable_cache: true,
        }
    }
}
45
46#[derive(Clone)]
48pub struct BatchStateVector {
49 pub states: Array2<Complex64>,
51 pub n_qubits: usize,
53 pub config: BatchConfig,
55}
56
57impl BatchStateVector {
58 pub fn new(batch_size: usize, n_qubits: usize, config: BatchConfig) -> QuantRS2Result<Self> {
60 let state_size = 1 << n_qubits;
61
62 if let Some(limit) = config.memory_limit {
64 let required_memory = batch_size * state_size * std::mem::size_of::<Complex64>();
65 if required_memory > limit {
66 return Err(QuantRS2Error::InvalidInput(format!(
67 "Batch requires {} bytes, limit is {}",
68 required_memory, limit
69 )));
70 }
71 }
72
73 let mut states = Array2::zeros((batch_size, state_size));
75 for i in 0..batch_size {
76 states[[i, 0]] = Complex64::new(1.0, 0.0);
77 }
78
79 Ok(Self {
80 states,
81 n_qubits,
82 config,
83 })
84 }
85
86 pub fn from_states(states: Array2<Complex64>, config: BatchConfig) -> QuantRS2Result<Self> {
88 let (_batch_size, state_size) = states.dim();
89
90 let n_qubits = (state_size as f64).log2().round() as usize;
92 if 1 << n_qubits != state_size {
93 return Err(QuantRS2Error::InvalidInput(
94 "State size must be a power of 2".to_string(),
95 ));
96 }
97
98 Ok(Self {
99 states,
100 n_qubits,
101 config,
102 })
103 }
104
105 pub fn batch_size(&self) -> usize {
107 self.states.nrows()
108 }
109
110 pub fn get_state(&self, index: usize) -> QuantRS2Result<Array1<Complex64>> {
112 if index >= self.batch_size() {
113 return Err(QuantRS2Error::InvalidInput(format!(
114 "Index {} out of bounds for batch size {}",
115 index,
116 self.batch_size()
117 )));
118 }
119
120 Ok(self.states.row(index).to_owned())
121 }
122
123 pub fn set_state(&mut self, index: usize, state: &Array1<Complex64>) -> QuantRS2Result<()> {
125 if index >= self.batch_size() {
126 return Err(QuantRS2Error::InvalidInput(format!(
127 "Index {} out of bounds for batch size {}",
128 index,
129 self.batch_size()
130 )));
131 }
132
133 if state.len() != self.states.ncols() {
134 return Err(QuantRS2Error::InvalidInput(format!(
135 "State size {} doesn't match expected {}",
136 state.len(),
137 self.states.ncols()
138 )));
139 }
140
141 self.states.row_mut(index).assign(state);
142 Ok(())
143 }
144}
145
/// Outcome of executing operations on a batch of states.
#[derive(Debug, Clone)]
pub struct BatchExecutionResult {
    /// State amplitudes after execution.
    // NOTE(review): presumably one state per row, like `BatchStateVector::states` — confirm.
    pub final_states: Array2<Complex64>,
    /// Execution time in milliseconds.
    pub execution_time_ms: f64,
    /// Number of gates applied.
    pub gates_applied: usize,
    /// Whether the GPU was used for execution.
    pub used_gpu: bool,
}
158
/// Outcome of measuring a batch of states.
#[derive(Debug, Clone)]
pub struct BatchMeasurementResult {
    /// Measurement outcomes.
    // NOTE(review): axis meaning (likely batch x measured qubits) is not established in this file — confirm.
    pub outcomes: Array2<u8>,
    /// Probabilities associated with the measurement.
    // NOTE(review): axis meaning is not established in this file — confirm.
    pub probabilities: Array2<f64>,
    /// States after measurement, when they were produced; `None` otherwise.
    pub post_measurement_states: Option<Array2<Complex64>>,
}
171
/// Extension of [`GateOp`] for gates that can be applied to a whole
/// [`BatchStateVector`] in one call.
pub trait BatchGateOp: GateOp {
    /// Applies this gate to `batch` on the given `target_qubits`.
    ///
    /// The batch is passed mutably, so implementations can update its states
    /// in place. Failures are reported through the returned
    /// [`QuantRS2Result`].
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()>;

    /// Reports whether this gate has a batch-optimized implementation.
    ///
    /// The provided default returns `true`; implementors can override it.
    fn has_batch_optimization(&self) -> bool {
        true
    }
}
186
187pub fn create_batch<I>(states: I, config: BatchConfig) -> QuantRS2Result<BatchStateVector>
189where
190 I: IntoIterator<Item = Array1<Complex64>>,
191{
192 let states_vec: Vec<_> = states.into_iter().collect();
193 if states_vec.is_empty() {
194 return Err(QuantRS2Error::InvalidInput(
195 "Cannot create empty batch".to_string(),
196 ));
197 }
198
199 let state_size = states_vec[0].len();
200 let batch_size = states_vec.len();
201
202 for (i, state) in states_vec.iter().enumerate() {
204 if state.len() != state_size {
205 return Err(QuantRS2Error::InvalidInput(format!(
206 "State {} has size {}, expected {}",
207 i,
208 state.len(),
209 state_size
210 )));
211 }
212 }
213
214 let mut batch_array = Array2::zeros((batch_size, state_size));
216 for (i, state) in states_vec.iter().enumerate() {
217 batch_array.row_mut(i).assign(state);
218 }
219
220 BatchStateVector::from_states(batch_array, config)
221}
222
223pub fn split_batch(batch: &BatchStateVector, chunk_size: usize) -> Vec<BatchStateVector> {
225 let mut chunks = Vec::new();
226 let batch_size = batch.batch_size();
227
228 for start in (0..batch_size).step_by(chunk_size) {
229 let end = (start + chunk_size).min(batch_size);
230 let chunk_states = batch
231 .states
232 .slice(scirs2_core::ndarray::s![start..end, ..])
233 .to_owned();
234
235 if let Ok(chunk) = BatchStateVector::from_states(chunk_states, batch.config.clone()) {
236 chunks.push(chunk);
237 }
238 }
239
240 chunks
241}
242
243pub fn merge_batches(
245 batches: Vec<BatchStateVector>,
246 config: BatchConfig,
247) -> QuantRS2Result<BatchStateVector> {
248 if batches.is_empty() {
249 return Err(QuantRS2Error::InvalidInput(
250 "Cannot merge empty batches".to_string(),
251 ));
252 }
253
254 let n_qubits = batches[0].n_qubits;
256 for (i, batch) in batches.iter().enumerate() {
257 if batch.n_qubits != n_qubits {
258 return Err(QuantRS2Error::InvalidInput(format!(
259 "Batch {} has {} qubits, expected {}",
260 i, batch.n_qubits, n_qubits
261 )));
262 }
263 }
264
265 let mut all_states = Vec::new();
267 for batch in batches {
268 for i in 0..batch.batch_size() {
269 all_states.push(batch.states.row(i).to_owned());
270 }
271 }
272
273 create_batch(all_states, config)
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// The amplitude `1 + 0i`.
    fn one() -> Complex64 {
        Complex64::new(1.0, 0.0)
    }

    /// The amplitude `0 + 0i`.
    fn zero() -> Complex64 {
        Complex64::new(0.0, 0.0)
    }

    /// Shorthand for a default-configured batch that must construct successfully.
    fn fresh_batch(batch_size: usize, n_qubits: usize) -> BatchStateVector {
        BatchStateVector::new(batch_size, n_qubits, BatchConfig::default()).unwrap()
    }

    #[test]
    fn test_batch_creation() {
        let batch = fresh_batch(10, 3);
        assert_eq!(batch.batch_size(), 10);
        assert_eq!(batch.n_qubits, 3);
        assert_eq!(batch.states.ncols(), 8);

        // Every state must have amplitude 1 at index 0 and 0 elsewhere.
        for i in 0..10 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[0], one());
            for amplitude in state.iter().skip(1) {
                assert_eq!(*amplitude, zero());
            }
        }
    }

    #[test]
    fn test_batch_from_states() {
        let mut states = Array2::zeros((5, 4));
        for row in 0..5 {
            states[[row, row % 4]] = one();
        }

        let batch = BatchStateVector::from_states(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 5);
        // Four amplitudes per state means two qubits.
        assert_eq!(batch.n_qubits, 2);
    }

    #[test]
    fn test_create_batch_helper() {
        let states = vec![
            Array1::from_vec(vec![one(), zero()]),
            Array1::from_vec(vec![zero(), one()]),
            Array1::from_vec(vec![Complex64::new(0.707, 0.0), Complex64::new(0.707, 0.0)]),
        ];

        let batch = create_batch(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.n_qubits, 1);
    }

    #[test]
    fn test_split_batch() {
        let batch = fresh_batch(10, 2);
        let chunks = split_batch(&batch, 3);

        // Ten states in chunks of three: 3 + 3 + 3 + 1.
        let sizes: Vec<usize> = chunks.iter().map(|chunk| chunk.batch_size()).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_merge_batches() {
        let parts = vec![fresh_batch(3, 2), fresh_batch(2, 2)];

        let merged = merge_batches(parts, BatchConfig::default()).unwrap();
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(merged.n_qubits, 2);
    }

    #[test]
    fn test_batch_memory_limit_enforcement() {
        let config = BatchConfig {
            memory_limit: Some(100),
            ..BatchConfig::default()
        };

        // 10 states of 32 amplitudes each need far more than 100 bytes.
        match BatchStateVector::new(10, 5, config) {
            Ok(_) => panic!("expected the memory limit to reject the batch"),
            Err(e) => {
                let msg = format!("{:?}", e);
                assert!(msg.contains("bytes") || msg.contains("limit"));
            }
        }
    }

    #[test]
    fn test_batch_state_normalization() {
        let batch = fresh_batch(5, 2);

        for i in 0..batch.batch_size() {
            let state = batch.get_state(i).unwrap();
            let norm: f64 = state.iter().map(|c| c.norm_sqr()).sum();
            assert!(
                (norm - 1.0).abs() < 1e-10,
                "State {} not normalized: {}",
                i,
                norm
            );
        }
    }

    #[test]
    fn test_batch_state_get_set_roundtrip() {
        let mut batch = fresh_batch(3, 2);

        // Uniform superposition over the four basis states.
        let custom_state = Array1::from_vec(vec![Complex64::new(0.5, 0.0); 4]);

        batch.set_state(1, &custom_state).unwrap();
        let retrieved = batch.get_state(1).unwrap();

        assert_eq!(retrieved.len(), 4);
        for (got, want) in retrieved.iter().zip(custom_state.iter()) {
            assert!((got - want).norm() < 1e-10);
        }
    }

    #[test]
    fn test_batch_out_of_bounds_access() {
        let batch = fresh_batch(5, 2);

        // Index equal to the batch size is the first invalid one.
        assert!(batch.get_state(5).is_err());
        assert!(batch.get_state(100).is_err());
    }

    #[test]
    fn test_batch_set_wrong_size_state() {
        let mut batch = fresh_batch(5, 2);

        // Two amplitudes where four are expected.
        let wrong_state = Array1::from_vec(vec![one(), zero()]);
        assert!(batch.set_state(0, &wrong_state).is_err());
    }

    #[test]
    fn test_empty_batch_creation_fails() {
        let no_states: Vec<Array1<Complex64>> = Vec::new();
        assert!(create_batch(no_states, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_batch_mismatched_state_sizes() {
        let states = vec![
            Array1::from_vec(vec![one(), zero()]),
            Array1::from_vec(vec![one(), zero(), zero(), zero()]),
        ];

        assert!(create_batch(states, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_batch_invalid_state_size() {
        // Three amplitudes per state is not a power of two.
        let states = Array2::zeros((5, 3));
        assert!(BatchStateVector::from_states(states, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_split_batch_single_element() {
        let batch = fresh_batch(1, 2);
        let chunks = split_batch(&batch, 10);

        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].batch_size(), 1);
    }

    #[test]
    fn test_split_batch_exact_division() {
        let batch = fresh_batch(9, 2);
        let chunks = split_batch(&batch, 3);

        assert_eq!(chunks.len(), 3);
        assert!(chunks.iter().all(|chunk| chunk.batch_size() == 3));
    }

    #[test]
    fn test_merge_batches_empty() {
        assert!(merge_batches(Vec::new(), BatchConfig::default()).is_err());
    }

    #[test]
    fn test_merge_batches_mismatched_qubits() {
        let parts = vec![fresh_batch(3, 2), fresh_batch(2, 3)];
        assert!(merge_batches(parts, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_batch_config_defaults() {
        let config = BatchConfig::default();
        assert!(config.num_workers.is_none());
        assert_eq!(config.max_batch_size, 1024);
        assert!(config.use_gpu);
        assert!(config.memory_limit.is_none());
        assert!(config.enable_cache);
    }

    #[test]
    fn test_large_batch_creation() {
        let batch = fresh_batch(100, 4);
        assert_eq!(batch.batch_size(), 100);
        assert_eq!(batch.n_qubits, 4);
        // 2^4 amplitudes per state.
        assert_eq!(batch.states.ncols(), 16);
    }

    #[test]
    fn test_batch_state_modification_isolation() {
        let mut batch = fresh_batch(3, 2);

        let modified = Array1::from_vec(vec![zero(), one(), zero(), zero()]);
        batch.set_state(1, &modified).unwrap();

        // Neighbouring states must be untouched by the write to index 1.
        for untouched in [0, 2] {
            let state = batch.get_state(untouched).unwrap();
            assert_eq!(state[0], one());
        }
    }

    #[test]
    fn test_split_merge_roundtrip() {
        let batch = fresh_batch(10, 2);
        let original_states = batch.states.clone();

        let chunks = split_batch(&batch, 3);
        let merged = merge_batches(chunks, BatchConfig::default()).unwrap();

        assert_eq!(merged.batch_size(), 10);
        for row in 0..10 {
            for col in 0..4 {
                assert_eq!(merged.states[[row, col]], original_states[[row, col]]);
            }
        }
    }

    #[test]
    fn test_batch_execution_result_fields() {
        let result = BatchExecutionResult {
            final_states: Array2::zeros((5, 4)),
            execution_time_ms: 100.0,
            gates_applied: 50,
            used_gpu: false,
        };

        assert_eq!(result.execution_time_ms, 100.0);
        assert_eq!(result.gates_applied, 50);
        assert!(!result.used_gpu);
    }

    #[test]
    fn test_batch_measurement_result_fields() {
        let shape = (5, 10);
        let result = BatchMeasurementResult {
            outcomes: Array2::zeros(shape),
            probabilities: Array2::zeros(shape),
            post_measurement_states: None,
        };

        assert_eq!(result.outcomes.dim(), shape);
        assert_eq!(result.probabilities.dim(), shape);
        assert!(result.post_measurement_states.is_none());
    }
}