1pub mod execution;
7pub mod measurement;
8pub mod operations;
9pub mod optimization;
10
11use crate::{
12 error::{QuantRS2Error, QuantRS2Result},
13 gate::GateOp,
14 qubit::QubitId,
15};
16use scirs2_core::ndarray::{Array1, Array2};
17use scirs2_core::Complex64;
18
/// Settings shared by the batched state-vector helpers in this module.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Requested worker count; `None` leaves the choice to whoever executes the batch
    /// (how it is honoured is not visible in this module).
    pub num_workers: Option<usize>,
    /// Intended upper bound on states per batch. Not enforced by the constructors here.
    pub max_batch_size: usize,
    /// Request for GPU execution; interpreted by the executing code, not by this module.
    pub use_gpu: bool,
    /// Optional cap, in bytes, on the amplitude storage `BatchStateVector::new` may allocate.
    pub memory_limit: Option<usize>,
    /// Request to enable caching; interpreted by the executing code, not by this module.
    pub enable_cache: bool,
}

impl Default for BatchConfig {
    /// Baseline settings: at most 1024 states per batch, GPU and caching requested,
    /// no explicit worker count and no memory cap.
    fn default() -> Self {
        let max_batch_size = 1024;
        Self {
            max_batch_size,
            use_gpu: true,
            enable_cache: true,
            num_workers: None,
            memory_limit: None,
        }
    }
}
45
46#[derive(Clone)]
48pub struct BatchStateVector {
49 pub states: Array2<Complex64>,
51 pub n_qubits: usize,
53 pub config: BatchConfig,
55}
56
57impl BatchStateVector {
58 pub fn new(batch_size: usize, n_qubits: usize, config: BatchConfig) -> QuantRS2Result<Self> {
60 let state_size = 1 << n_qubits;
61
62 if let Some(limit) = config.memory_limit {
64 let required_memory = batch_size * state_size * std::mem::size_of::<Complex64>();
65 if required_memory > limit {
66 return Err(QuantRS2Error::InvalidInput(format!(
67 "Batch requires {required_memory} bytes, limit is {limit}"
68 )));
69 }
70 }
71
72 let mut states = Array2::zeros((batch_size, state_size));
74 for i in 0..batch_size {
75 states[[i, 0]] = Complex64::new(1.0, 0.0);
76 }
77
78 Ok(Self {
79 states,
80 n_qubits,
81 config,
82 })
83 }
84
85 pub fn from_states(states: Array2<Complex64>, config: BatchConfig) -> QuantRS2Result<Self> {
87 let (_batch_size, state_size) = states.dim();
88
89 let n_qubits = (state_size as f64).log2().round() as usize;
91 if 1 << n_qubits != state_size {
92 return Err(QuantRS2Error::InvalidInput(
93 "State size must be a power of 2".to_string(),
94 ));
95 }
96
97 Ok(Self {
98 states,
99 n_qubits,
100 config,
101 })
102 }
103
104 pub fn batch_size(&self) -> usize {
106 self.states.nrows()
107 }
108
109 pub fn get_state(&self, index: usize) -> QuantRS2Result<Array1<Complex64>> {
111 if index >= self.batch_size() {
112 return Err(QuantRS2Error::InvalidInput(format!(
113 "Index {} out of bounds for batch size {}",
114 index,
115 self.batch_size()
116 )));
117 }
118
119 Ok(self.states.row(index).to_owned())
120 }
121
122 pub fn set_state(&mut self, index: usize, state: &Array1<Complex64>) -> QuantRS2Result<()> {
124 if index >= self.batch_size() {
125 return Err(QuantRS2Error::InvalidInput(format!(
126 "Index {} out of bounds for batch size {}",
127 index,
128 self.batch_size()
129 )));
130 }
131
132 if state.len() != self.states.ncols() {
133 return Err(QuantRS2Error::InvalidInput(format!(
134 "State size {} doesn't match expected {}",
135 state.len(),
136 self.states.ncols()
137 )));
138 }
139
140 self.states.row_mut(index).assign(state);
141 Ok(())
142 }
143}
144
145#[derive(Debug, Clone)]
147pub struct BatchExecutionResult {
148 pub final_states: Array2<Complex64>,
150 pub execution_time_ms: f64,
152 pub gates_applied: usize,
154 pub used_gpu: bool,
156}
157
158#[derive(Debug, Clone)]
160pub struct BatchMeasurementResult {
161 pub outcomes: Array2<u8>,
164 pub probabilities: Array2<f64>,
167 pub post_measurement_states: Option<Array2<Complex64>>,
169}
170
171pub trait BatchGateOp: GateOp {
173 fn apply_batch(
175 &self,
176 batch: &mut BatchStateVector,
177 target_qubits: &[QubitId],
178 ) -> QuantRS2Result<()>;
179
180 fn has_batch_optimization(&self) -> bool {
182 true
183 }
184}
185
186pub fn create_batch<I>(states: I, config: BatchConfig) -> QuantRS2Result<BatchStateVector>
188where
189 I: IntoIterator<Item = Array1<Complex64>>,
190{
191 let states_vec: Vec<_> = states.into_iter().collect();
192 if states_vec.is_empty() {
193 return Err(QuantRS2Error::InvalidInput(
194 "Cannot create empty batch".to_string(),
195 ));
196 }
197
198 let state_size = states_vec[0].len();
199 let batch_size = states_vec.len();
200
201 for (i, state) in states_vec.iter().enumerate() {
203 if state.len() != state_size {
204 return Err(QuantRS2Error::InvalidInput(format!(
205 "State {} has size {}, expected {}",
206 i,
207 state.len(),
208 state_size
209 )));
210 }
211 }
212
213 let mut batch_array = Array2::zeros((batch_size, state_size));
215 for (i, state) in states_vec.iter().enumerate() {
216 batch_array.row_mut(i).assign(state);
217 }
218
219 BatchStateVector::from_states(batch_array, config)
220}
221
222pub fn split_batch(batch: &BatchStateVector, chunk_size: usize) -> Vec<BatchStateVector> {
224 let mut chunks = Vec::new();
225 let batch_size = batch.batch_size();
226
227 for start in (0..batch_size).step_by(chunk_size) {
228 let end = (start + chunk_size).min(batch_size);
229 let chunk_states = batch
230 .states
231 .slice(scirs2_core::ndarray::s![start..end, ..])
232 .to_owned();
233
234 if let Ok(chunk) = BatchStateVector::from_states(chunk_states, batch.config.clone()) {
235 chunks.push(chunk);
236 }
237 }
238
239 chunks
240}
241
242pub fn merge_batches(
244 batches: Vec<BatchStateVector>,
245 config: BatchConfig,
246) -> QuantRS2Result<BatchStateVector> {
247 if batches.is_empty() {
248 return Err(QuantRS2Error::InvalidInput(
249 "Cannot merge empty batches".to_string(),
250 ));
251 }
252
253 let n_qubits = batches[0].n_qubits;
255 for (i, batch) in batches.iter().enumerate() {
256 if batch.n_qubits != n_qubits {
257 return Err(QuantRS2Error::InvalidInput(format!(
258 "Batch {} has {} qubits, expected {}",
259 i, batch.n_qubits, n_qubits
260 )));
261 }
262 }
263
264 let mut all_states = Vec::new();
266 for batch in batches {
267 for i in 0..batch.batch_size() {
268 all_states.push(batch.states.row(i).to_owned());
269 }
270 }
271
272 create_batch(all_states, config)
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a batch with the default config, panicking with `context` on failure.
    fn default_batch(batch_size: usize, n_qubits: usize, context: &str) -> BatchStateVector {
        BatchStateVector::new(batch_size, n_qubits, BatchConfig::default()).expect(context)
    }

    /// Asserts that `state` is |0…0⟩: amplitude 1 at index 0 and 0 everywhere else.
    fn assert_ground_state(state: &Array1<Complex64>) {
        assert_eq!(state[0], Complex64::new(1.0, 0.0));
        assert!(state
            .iter()
            .skip(1)
            .all(|&amp| amp == Complex64::new(0.0, 0.0)));
    }

    #[test]
    fn test_batch_creation() {
        let batch = default_batch(10, 3, "Failed to create batch state vector");
        assert_eq!(batch.batch_size(), 10);
        assert_eq!(batch.n_qubits, 3);
        assert_eq!(batch.states.ncols(), 8); // 2^3 amplitudes per state

        for i in 0..batch.batch_size() {
            let state = batch.get_state(i).expect("Failed to get state from batch");
            assert_ground_state(&state);
        }
    }

    #[test]
    fn test_batch_from_states() {
        let mut states = Array2::<Complex64>::zeros((5, 4));
        for i in 0..5 {
            states[[i, i % 4]] = Complex64::new(1.0, 0.0);
        }

        let batch = BatchStateVector::from_states(states, BatchConfig::default())
            .expect("Failed to create batch from states");
        assert_eq!(batch.batch_size(), 5);
        assert_eq!(batch.n_qubits, 2); // 4 amplitudes = 2 qubits
    }

    #[test]
    fn test_create_batch_helper() {
        let states = vec![
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]),
            Array1::from_vec(vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]),
            Array1::from_vec(vec![Complex64::new(0.707, 0.0), Complex64::new(0.707, 0.0)]),
        ];

        let batch = create_batch(states, BatchConfig::default())
            .expect("Failed to create batch from state collection");
        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.n_qubits, 1);
    }

    #[test]
    fn test_split_batch() {
        let batch = default_batch(10, 2, "Failed to create batch for split test");
        let chunks = split_batch(&batch, 3);

        let sizes: Vec<usize> = chunks.iter().map(BatchStateVector::batch_size).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]); // 10 = 3 + 3 + 3 + 1
    }

    #[test]
    fn test_merge_batches() {
        let batch1 = default_batch(3, 2, "Failed to create first batch");
        let batch2 = default_batch(2, 2, "Failed to create second batch");

        let merged = merge_batches(vec![batch1, batch2], BatchConfig::default())
            .expect("Failed to merge batches");
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(merged.n_qubits, 2);
    }

    #[test]
    fn test_batch_memory_limit_enforcement() {
        // 10 states x 32 amplitudes x 16 bytes = 5120 bytes, far above the 100-byte cap.
        let config = BatchConfig {
            memory_limit: Some(100),
            ..BatchConfig::default()
        };

        let err = BatchStateVector::new(10, 5, config)
            .err()
            .expect("a 100-byte memory limit should reject this batch");
        let msg = format!("{err:?}");
        assert!(msg.contains("bytes") || msg.contains("limit"));
    }

    #[test]
    fn test_batch_state_normalization() {
        let batch = default_batch(5, 2, "Failed to create batch for normalization test");

        for i in 0..batch.batch_size() {
            let state = batch
                .get_state(i)
                .expect("Failed to get state for normalization check");
            let norm: f64 = state.iter().map(|c| c.norm_sqr()).sum();
            assert!((norm - 1.0).abs() < 1e-10, "State {i} not normalized: {norm}");
        }
    }

    #[test]
    fn test_batch_state_get_set_roundtrip() {
        let mut batch = default_batch(3, 2, "Failed to create batch for get/set test");

        // Uniform superposition over 2 qubits.
        let custom_state = Array1::from_elem(4, Complex64::new(0.5, 0.0));

        batch
            .set_state(1, &custom_state)
            .expect("Failed to set custom state");
        let retrieved = batch
            .get_state(1)
            .expect("Failed to retrieve state after set");

        assert_eq!(retrieved.len(), custom_state.len());
        for (got, want) in retrieved.iter().zip(custom_state.iter()) {
            assert!((*got - *want).norm() < 1e-10);
        }
    }

    #[test]
    fn test_batch_out_of_bounds_access() {
        let batch = default_batch(5, 2, "Failed to create batch for bounds test");

        assert!(batch.get_state(5).is_err()); // first index past the end
        assert!(batch.get_state(100).is_err());
    }

    #[test]
    fn test_batch_set_wrong_size_state() {
        let mut batch = default_batch(5, 2, "Failed to create batch for wrong size test");

        // A 1-qubit state cannot be stored in a 2-qubit batch.
        let wrong_state =
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]);
        assert!(batch.set_state(0, &wrong_state).is_err());
    }

    #[test]
    fn test_empty_batch_creation_fails() {
        let result = create_batch(Vec::<Array1<Complex64>>::new(), BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_mismatched_state_sizes() {
        let states = vec![
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]),
            Array1::from_vec(vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.0, 0.0),
            ]),
        ];

        let result = create_batch(states, BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_invalid_state_size() {
        // 3 is not a power of two, so no qubit count matches.
        let states = Array2::<Complex64>::zeros((5, 3));
        let result = BatchStateVector::from_states(states, BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_split_batch_single_element() {
        let batch = default_batch(1, 2, "Failed to create single element batch");
        let chunks = split_batch(&batch, 10);

        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].batch_size(), 1);
    }

    #[test]
    fn test_split_batch_exact_division() {
        let batch = default_batch(9, 2, "Failed to create batch for exact division test");
        let chunks = split_batch(&batch, 3);

        assert_eq!(chunks.len(), 3);
        assert!(chunks.iter().all(|chunk| chunk.batch_size() == 3));
    }

    #[test]
    fn test_merge_batches_empty() {
        let result = merge_batches(Vec::new(), BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_merge_batches_mismatched_qubits() {
        let batch1 = default_batch(3, 2, "Failed to create first batch with 2 qubits");
        let batch2 = default_batch(2, 3, "Failed to create second batch with 3 qubits");

        let result = merge_batches(vec![batch1, batch2], BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_config_defaults() {
        let config = BatchConfig::default();
        assert!(config.num_workers.is_none());
        assert_eq!(config.max_batch_size, 1024);
        assert!(config.use_gpu);
        assert!(config.memory_limit.is_none());
        assert!(config.enable_cache);
    }

    #[test]
    fn test_large_batch_creation() {
        let batch = default_batch(100, 4, "Failed to create large batch");
        assert_eq!(batch.batch_size(), 100);
        assert_eq!(batch.n_qubits, 4);
        assert_eq!(batch.states.ncols(), 16); // 2^4 amplitudes per state
    }

    #[test]
    fn test_batch_state_modification_isolation() {
        let mut batch = default_batch(3, 2, "Failed to create batch for isolation test");

        // |01⟩ in the 2-qubit basis ordering used by the batch.
        let modified = Array1::from_vec(vec![
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ]);
        batch
            .set_state(1, &modified)
            .expect("Failed to set modified state");

        // The write lands on row 1 only; its neighbours keep the initial state.
        assert_eq!(batch.get_state(1).expect("Failed to get state 1"), modified);
        let state0 = batch.get_state(0).expect("Failed to get state 0");
        let state2 = batch.get_state(2).expect("Failed to get state 2");
        assert_eq!(state0[0], Complex64::new(1.0, 0.0));
        assert_eq!(state2[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_split_merge_roundtrip() {
        let batch = default_batch(10, 2, "Failed to create batch for roundtrip test");
        let original_states = batch.states.clone();

        let chunks = split_batch(&batch, 3);
        let merged = merge_batches(chunks, BatchConfig::default())
            .expect("Failed to merge chunks in roundtrip test");

        assert_eq!(merged.batch_size(), 10);
        assert_eq!(merged.states, original_states);
    }

    #[test]
    fn test_batch_execution_result_fields() {
        let result = BatchExecutionResult {
            final_states: Array2::zeros((5, 4)),
            execution_time_ms: 100.0,
            gates_applied: 50,
            used_gpu: false,
        };

        assert_eq!(result.execution_time_ms, 100.0);
        assert_eq!(result.gates_applied, 50);
        assert!(!result.used_gpu);
    }

    #[test]
    fn test_batch_measurement_result_fields() {
        // `Array2` is already in scope through `use super::*`.
        let result = BatchMeasurementResult {
            outcomes: Array2::zeros((5, 10)),
            probabilities: Array2::zeros((5, 10)),
            post_measurement_states: None,
        };

        assert_eq!(result.outcomes.dim(), (5, 10));
        assert_eq!(result.probabilities.dim(), (5, 10));
        assert!(result.post_measurement_states.is_none());
    }
}
598}