// quantrs2_core/batch/mod.rs

//! Batch operations for quantum circuits using SciRS2 parallel algorithms.
//!
//! This module provides efficient batch processing for quantum operations,
//! leveraging SciRS2's parallel computing capabilities for performance.

6pub mod execution;
7pub mod measurement;
8pub mod operations;
9pub mod optimization;
10
11use crate::{
12    error::{QuantRS2Error, QuantRS2Result},
13    gate::GateOp,
14    qubit::QubitId,
15};
16use scirs2_core::ndarray::{Array1, Array2};
17use scirs2_core::Complex64;
18
/// Tuning knobs that control how batched quantum operations are executed.
///
/// Start from [`BatchConfig::default()`] and override individual fields with
/// struct-update syntax when only a few settings need to change.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Worker count for parallel execution; `None` defers to the system default.
    pub num_workers: Option<usize>,
    /// Upper bound on how many states are processed together in one batch.
    pub max_batch_size: usize,
    /// Request GPU acceleration when it is available.
    pub use_gpu: bool,
    /// Optional cap, in bytes, on the memory a batch may occupy.
    pub memory_limit: Option<usize>,
    /// Allow caching for repeated operations.
    pub enable_cache: bool,
}

impl Default for BatchConfig {
    /// Stock configuration: system-chosen worker count, batches of up to
    /// 1024 states, GPU and caching enabled, and no memory cap.
    fn default() -> Self {
        BatchConfig {
            max_batch_size: 1024,
            use_gpu: true,
            enable_cache: true,
            num_workers: None,
            memory_limit: None,
        }
    }
}
45
46/// Batch of quantum states for parallel processing
47#[derive(Clone)]
48pub struct BatchStateVector {
49    /// The batch of state vectors (batch_size, 2^n_qubits)
50    pub states: Array2<Complex64>,
51    /// Number of qubits
52    pub n_qubits: usize,
53    /// Batch configuration
54    pub config: BatchConfig,
55}
56
57impl BatchStateVector {
58    /// Create a new batch of quantum states
59    pub fn new(batch_size: usize, n_qubits: usize, config: BatchConfig) -> QuantRS2Result<Self> {
60        let state_size = 1 << n_qubits;
61
62        // Check memory constraints
63        if let Some(limit) = config.memory_limit {
64            let required_memory = batch_size * state_size * std::mem::size_of::<Complex64>();
65            if required_memory > limit {
66                return Err(QuantRS2Error::InvalidInput(format!(
67                    "Batch requires {required_memory} bytes, limit is {limit}"
68                )));
69            }
70        }
71
72        // Initialize all states to |0...0>
73        let mut states = Array2::zeros((batch_size, state_size));
74        for i in 0..batch_size {
75            states[[i, 0]] = Complex64::new(1.0, 0.0);
76        }
77
78        Ok(Self {
79            states,
80            n_qubits,
81            config,
82        })
83    }
84
85    /// Create from existing state vectors
86    pub fn from_states(states: Array2<Complex64>, config: BatchConfig) -> QuantRS2Result<Self> {
87        let (_batch_size, state_size) = states.dim();
88
89        // Determine number of qubits
90        let n_qubits = (state_size as f64).log2().round() as usize;
91        if 1 << n_qubits != state_size {
92            return Err(QuantRS2Error::InvalidInput(
93                "State size must be a power of 2".to_string(),
94            ));
95        }
96
97        Ok(Self {
98            states,
99            n_qubits,
100            config,
101        })
102    }
103
104    /// Get batch size
105    pub fn batch_size(&self) -> usize {
106        self.states.nrows()
107    }
108
109    /// Get a specific state from the batch
110    pub fn get_state(&self, index: usize) -> QuantRS2Result<Array1<Complex64>> {
111        if index >= self.batch_size() {
112            return Err(QuantRS2Error::InvalidInput(format!(
113                "Index {} out of bounds for batch size {}",
114                index,
115                self.batch_size()
116            )));
117        }
118
119        Ok(self.states.row(index).to_owned())
120    }
121
122    /// Set a specific state in the batch
123    pub fn set_state(&mut self, index: usize, state: &Array1<Complex64>) -> QuantRS2Result<()> {
124        if index >= self.batch_size() {
125            return Err(QuantRS2Error::InvalidInput(format!(
126                "Index {} out of bounds for batch size {}",
127                index,
128                self.batch_size()
129            )));
130        }
131
132        if state.len() != self.states.ncols() {
133            return Err(QuantRS2Error::InvalidInput(format!(
134                "State size {} doesn't match expected {}",
135                state.len(),
136                self.states.ncols()
137            )));
138        }
139
140        self.states.row_mut(index).assign(state);
141        Ok(())
142    }
143}
144
145/// Batch circuit execution result
146#[derive(Debug, Clone)]
147pub struct BatchExecutionResult {
148    /// Final state vectors
149    pub final_states: Array2<Complex64>,
150    /// Execution time in milliseconds
151    pub execution_time_ms: f64,
152    /// Number of gates applied
153    pub gates_applied: usize,
154    /// Whether GPU was used
155    pub used_gpu: bool,
156}
157
158/// Batch measurement result
159#[derive(Debug, Clone)]
160pub struct BatchMeasurementResult {
161    /// Measurement outcomes for each state in the batch
162    /// Shape: (batch_size, num_measurements)
163    pub outcomes: Array2<u8>,
164    /// Probabilities for each outcome
165    /// Shape: (batch_size, num_measurements)
166    pub probabilities: Array2<f64>,
167    /// Post-measurement states (if requested)
168    pub post_measurement_states: Option<Array2<Complex64>>,
169}
170
171/// Trait for batch-optimized gates
172pub trait BatchGateOp: GateOp {
173    /// Apply this gate to a batch of states
174    fn apply_batch(
175        &self,
176        batch: &mut BatchStateVector,
177        target_qubits: &[QubitId],
178    ) -> QuantRS2Result<()>;
179
180    /// Check if this gate has batch optimization
181    fn has_batch_optimization(&self) -> bool {
182        true
183    }
184}
185
186/// Helper to create batches from a collection of states
187pub fn create_batch<I>(states: I, config: BatchConfig) -> QuantRS2Result<BatchStateVector>
188where
189    I: IntoIterator<Item = Array1<Complex64>>,
190{
191    let states_vec: Vec<_> = states.into_iter().collect();
192    if states_vec.is_empty() {
193        return Err(QuantRS2Error::InvalidInput(
194            "Cannot create empty batch".to_string(),
195        ));
196    }
197
198    let state_size = states_vec[0].len();
199    let batch_size = states_vec.len();
200
201    // Validate all states have same size
202    for (i, state) in states_vec.iter().enumerate() {
203        if state.len() != state_size {
204            return Err(QuantRS2Error::InvalidInput(format!(
205                "State {} has size {}, expected {}",
206                i,
207                state.len(),
208                state_size
209            )));
210        }
211    }
212
213    // Create 2D array
214    let mut batch_array = Array2::zeros((batch_size, state_size));
215    for (i, state) in states_vec.iter().enumerate() {
216        batch_array.row_mut(i).assign(state);
217    }
218
219    BatchStateVector::from_states(batch_array, config)
220}
221
222/// Helper to split a large batch into smaller chunks
223pub fn split_batch(batch: &BatchStateVector, chunk_size: usize) -> Vec<BatchStateVector> {
224    let mut chunks = Vec::new();
225    let batch_size = batch.batch_size();
226
227    for start in (0..batch_size).step_by(chunk_size) {
228        let end = (start + chunk_size).min(batch_size);
229        let chunk_states = batch
230            .states
231            .slice(scirs2_core::ndarray::s![start..end, ..])
232            .to_owned();
233
234        if let Ok(chunk) = BatchStateVector::from_states(chunk_states, batch.config.clone()) {
235            chunks.push(chunk);
236        }
237    }
238
239    chunks
240}
241
242/// Merge multiple batches into one
243pub fn merge_batches(
244    batches: Vec<BatchStateVector>,
245    config: BatchConfig,
246) -> QuantRS2Result<BatchStateVector> {
247    if batches.is_empty() {
248        return Err(QuantRS2Error::InvalidInput(
249            "Cannot merge empty batches".to_string(),
250        ));
251    }
252
253    // Validate all batches have same n_qubits
254    let n_qubits = batches[0].n_qubits;
255    for (i, batch) in batches.iter().enumerate() {
256        if batch.n_qubits != n_qubits {
257            return Err(QuantRS2Error::InvalidInput(format!(
258                "Batch {} has {} qubits, expected {}",
259                i, batch.n_qubits, n_qubits
260            )));
261        }
262    }
263
264    // Concatenate states
265    let mut all_states = Vec::new();
266    for batch in batches {
267        for i in 0..batch.batch_size() {
268            all_states.push(batch.states.row(i).to_owned());
269        }
270    }
271
272    create_batch(all_states, config)
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a batch whose rows are pairwise distinct, so tests can detect rows
    /// being dropped, duplicated, or reordered.
    fn distinct_batch(batch_size: usize, n_qubits: usize) -> BatchStateVector {
        let state_size = 1 << n_qubits;
        let states = Array2::from_shape_fn((batch_size, state_size), |(i, j)| {
            Complex64::new(i as f64, j as f64)
        });
        BatchStateVector::from_states(states, BatchConfig::default())
            .expect("Failed to create distinct batch")
    }

    #[test]
    fn test_batch_creation() {
        let batch = BatchStateVector::new(10, 3, BatchConfig::default())
            .expect("Failed to create batch state vector");
        assert_eq!(batch.batch_size(), 10);
        assert_eq!(batch.n_qubits, 3);
        assert_eq!(batch.states.ncols(), 8); // 2^3

        // Every state starts as |000>
        for i in 0..10 {
            let state = batch.get_state(i).expect("Failed to get state from batch");
            assert_eq!(state[0], Complex64::new(1.0, 0.0));
            assert!(state
                .iter()
                .skip(1)
                .all(|&amp| amp == Complex64::new(0.0, 0.0)));
        }
    }

    #[test]
    fn test_batch_from_states() {
        let mut states = Array2::zeros((5, 4));
        for i in 0..5 {
            states[[i, i % 4]] = Complex64::new(1.0, 0.0);
        }

        let batch = BatchStateVector::from_states(states, BatchConfig::default())
            .expect("Failed to create batch from states");
        assert_eq!(batch.batch_size(), 5);
        assert_eq!(batch.n_qubits, 2); // 2^2 = 4
    }

    #[test]
    fn test_create_batch_helper() {
        let states = vec![
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]),
            Array1::from_vec(vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]),
            Array1::from_vec(vec![Complex64::new(0.707, 0.0), Complex64::new(0.707, 0.0)]),
        ];

        let batch = create_batch(states.clone(), BatchConfig::default())
            .expect("Failed to create batch from state collection");
        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.n_qubits, 1);

        // Rows appear in input order
        for (i, expected) in states.iter().enumerate() {
            assert_eq!(batch.states.row(i), expected.view());
        }
    }

    #[test]
    fn test_split_batch() {
        let batch = distinct_batch(10, 2);
        let chunks = split_batch(&batch, 3);

        let sizes: Vec<usize> = chunks.iter().map(BatchStateVector::batch_size).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);

        // Chunks hold consecutive rows of the source, in order
        for (c, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.n_qubits, 2);
            for r in 0..chunk.batch_size() {
                assert_eq!(chunk.states.row(r), batch.states.row(c * 3 + r));
            }
        }
    }

    #[test]
    #[should_panic]
    fn test_split_batch_zero_chunk_size_panics() {
        let batch = distinct_batch(4, 1);
        let _ = split_batch(&batch, 0);
    }

    #[test]
    fn test_merge_batches() {
        let batch1 = BatchStateVector::new(3, 2, BatchConfig::default())
            .expect("Failed to create first batch");
        let batch2 = BatchStateVector::new(2, 2, BatchConfig::default())
            .expect("Failed to create second batch");

        let merged = merge_batches(vec![batch1, batch2], BatchConfig::default())
            .expect("Failed to merge batches");
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(merged.n_qubits, 2);
    }

    // === Comprehensive Batch Operation Tests ===

    #[test]
    fn test_batch_memory_limit_enforcement() {
        // Set a very small memory limit
        let config = BatchConfig {
            memory_limit: Some(100),
            ..BatchConfig::default()
        };

        // 10 states * 32 amplitudes * 16 bytes per Complex64 = 5120 bytes
        let err = BatchStateVector::new(10, 5, config)
            .err()
            .expect("Batch exceeding the memory limit must be rejected");
        let msg = format!("{:?}", err);
        assert!(msg.contains("5120 bytes"), "unexpected error: {msg}");
    }

    #[test]
    fn test_batch_state_normalization() {
        let batch = BatchStateVector::new(5, 2, BatchConfig::default())
            .expect("Failed to create batch for normalization test");

        // Check that all states are normalized
        for i in 0..batch.batch_size() {
            let state = batch
                .get_state(i)
                .expect("Failed to get state for normalization check");
            let norm: f64 = state.iter().map(|c| c.norm_sqr()).sum();
            assert!(
                (norm - 1.0).abs() < 1e-10,
                "State {} not normalized: {}",
                i,
                norm
            );
        }
    }

    #[test]
    fn test_batch_state_get_set_roundtrip() {
        let mut batch = BatchStateVector::new(3, 2, BatchConfig::default())
            .expect("Failed to create batch for get/set test");

        let custom_state = Array1::from_elem(4, Complex64::new(0.5, 0.0));

        batch
            .set_state(1, &custom_state)
            .expect("Failed to set custom state");
        let retrieved = batch
            .get_state(1)
            .expect("Failed to retrieve state after set");

        for (got, want) in retrieved.iter().zip(custom_state.iter()) {
            assert!((got - want).norm() < 1e-10);
        }
    }

    #[test]
    fn test_batch_out_of_bounds_access() {
        let mut batch = BatchStateVector::new(5, 2, BatchConfig::default())
            .expect("Failed to create batch for bounds test");

        assert!(batch.get_state(5).is_err());
        assert!(batch.get_state(100).is_err());

        let state = Array1::from_elem(4, Complex64::new(0.5, 0.0));
        assert!(batch.set_state(5, &state).is_err());
    }

    #[test]
    fn test_batch_set_wrong_size_state() {
        let mut batch = BatchStateVector::new(5, 2, BatchConfig::default())
            .expect("Failed to create batch for wrong size test");

        // Try to set state with wrong size
        let wrong_state =
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]);
        assert!(batch.set_state(0, &wrong_state).is_err());
    }

    #[test]
    fn test_empty_batch_creation_fails() {
        let result = create_batch(Vec::<Array1<Complex64>>::new(), BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_mismatched_state_sizes() {
        let states = vec![
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]),
            Array1::from_vec(vec![
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(0.0, 0.0),
            ]),
        ];

        let result = create_batch(states, BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_invalid_state_size() {
        // State size not a power of 2
        let states = Array2::zeros((5, 3));
        let result = BatchStateVector::from_states(states, BatchConfig::default());
        assert!(result.is_err());

        // Zero-width states are not a power of 2 either
        let empty_cols: Array2<Complex64> = Array2::zeros((5, 0));
        assert!(BatchStateVector::from_states(empty_cols, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_split_batch_single_element() {
        let batch = BatchStateVector::new(1, 2, BatchConfig::default())
            .expect("Failed to create single element batch");
        let chunks = split_batch(&batch, 10);

        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].batch_size(), 1);
    }

    #[test]
    fn test_split_batch_exact_division() {
        let batch = BatchStateVector::new(9, 2, BatchConfig::default())
            .expect("Failed to create batch for exact division test");
        let chunks = split_batch(&batch, 3);

        assert_eq!(chunks.len(), 3);
        assert!(chunks.iter().all(|chunk| chunk.batch_size() == 3));
    }

    #[test]
    fn test_merge_batches_empty() {
        let result = merge_batches(Vec::new(), BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_merge_batches_mismatched_qubits() {
        let batch1 = BatchStateVector::new(3, 2, BatchConfig::default())
            .expect("Failed to create first batch with 2 qubits");
        let batch2 = BatchStateVector::new(2, 3, BatchConfig::default())
            .expect("Failed to create second batch with 3 qubits");

        let result = merge_batches(vec![batch1, batch2], BatchConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_merge_batches_skips_empty_members() {
        let a = distinct_batch(2, 1);
        let empty = distinct_batch(0, 1);
        let b = distinct_batch(3, 1);

        let merged = merge_batches(vec![a.clone(), empty, b.clone()], BatchConfig::default())
            .expect("Failed to merge batches with an empty member");
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(
            merged.states.slice(scirs2_core::ndarray::s![0..2, ..]),
            a.states.view()
        );
        assert_eq!(
            merged.states.slice(scirs2_core::ndarray::s![2..5, ..]),
            b.states.view()
        );
    }

    #[test]
    fn test_merge_batches_all_empty_fails() {
        let result = merge_batches(
            vec![distinct_batch(0, 1), distinct_batch(0, 1)],
            BatchConfig::default(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_config_defaults() {
        let config = BatchConfig::default();
        assert!(config.num_workers.is_none());
        assert_eq!(config.max_batch_size, 1024);
        assert!(config.use_gpu);
        assert!(config.memory_limit.is_none());
        assert!(config.enable_cache);
    }

    #[test]
    fn test_large_batch_creation() {
        let batch = BatchStateVector::new(100, 4, BatchConfig::default())
            .expect("Failed to create large batch");
        assert_eq!(batch.batch_size(), 100);
        assert_eq!(batch.n_qubits, 4);
        assert_eq!(batch.states.ncols(), 16); // 2^4
    }

    #[test]
    fn test_batch_state_modification_isolation() {
        let mut batch = BatchStateVector::new(3, 2, BatchConfig::default())
            .expect("Failed to create batch for isolation test");
        let untouched = batch.get_state(0).expect("Failed to get initial state");

        // Modify one state
        let modified = Array1::from_vec(vec![
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
        ]);
        batch
            .set_state(1, &modified)
            .expect("Failed to set modified state");

        // The target row changed; its neighbours are entirely unchanged
        assert_eq!(batch.get_state(1).expect("Failed to get state 1"), modified);
        assert_eq!(batch.get_state(0).expect("Failed to get state 0"), untouched);
        assert_eq!(batch.get_state(2).expect("Failed to get state 2"), untouched);
    }

    #[test]
    fn test_split_merge_roundtrip() {
        // Distinct rows so that dropped or reordered rows are detected
        let batch = distinct_batch(10, 2);

        let chunks = split_batch(&batch, 3);
        let merged = merge_batches(chunks, BatchConfig::default())
            .expect("Failed to merge chunks in roundtrip test");

        assert_eq!(merged.batch_size(), 10);
        assert_eq!(merged.n_qubits, 2);
        assert_eq!(merged.states, batch.states);
    }

    #[test]
    fn test_batch_execution_result_fields() {
        let result = BatchExecutionResult {
            final_states: Array2::zeros((5, 4)),
            execution_time_ms: 100.0,
            gates_applied: 50,
            used_gpu: false,
        };

        assert_eq!(result.execution_time_ms, 100.0);
        assert_eq!(result.gates_applied, 50);
        assert!(!result.used_gpu);
    }

    #[test]
    fn test_batch_measurement_result_fields() {
        let result = BatchMeasurementResult {
            outcomes: Array2::zeros((5, 10)),
            probabilities: Array2::zeros((5, 10)),
            post_measurement_states: None,
        };

        assert_eq!(result.outcomes.dim(), (5, 10));
        assert_eq!(result.probabilities.dim(), (5, 10));
        assert!(result.post_measurement_states.is_none());
    }
}