quantrs2_core/batch/
mod.rs

1//! Batch operations for quantum circuits using SciRS2 parallel algorithms
2//!
3//! This module provides efficient batch processing for quantum operations,
4//! leveraging SciRS2's parallel computing capabilities for performance.
5
6pub mod execution;
7pub mod measurement;
8pub mod operations;
9pub mod optimization;
10
11use crate::{
12    error::{QuantRS2Error, QuantRS2Result},
13    gate::GateOp,
14    qubit::QubitId,
15};
16use scirs2_core::ndarray::{Array1, Array2};
17use scirs2_core::Complex64;
18
/// Tunable settings shared by all batch operations.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// How many parallel workers to use; `None` defers to the system default
    pub num_workers: Option<usize>,
    /// Upper bound on the number of states processed in one batch
    pub max_batch_size: usize,
    /// Use GPU acceleration when it is available
    pub use_gpu: bool,
    /// Optional memory budget, expressed in bytes
    pub memory_limit: Option<usize>,
    /// Cache results of repeated operations
    pub enable_cache: bool,
}

impl Default for BatchConfig {
    /// Default settings: system-chosen worker count, batches of up to 1024
    /// states, GPU and caching enabled, and no memory limit.
    fn default() -> Self {
        BatchConfig {
            // `None` means "use the system default"
            num_workers: None,
            max_batch_size: 1024,
            use_gpu: true,
            memory_limit: None,
            enable_cache: true,
        }
    }
}
45
46/// Batch of quantum states for parallel processing
47#[derive(Clone)]
48pub struct BatchStateVector {
49    /// The batch of state vectors (batch_size, 2^n_qubits)
50    pub states: Array2<Complex64>,
51    /// Number of qubits
52    pub n_qubits: usize,
53    /// Batch configuration
54    pub config: BatchConfig,
55}
56
57impl BatchStateVector {
58    /// Create a new batch of quantum states
59    pub fn new(batch_size: usize, n_qubits: usize, config: BatchConfig) -> QuantRS2Result<Self> {
60        let state_size = 1 << n_qubits;
61
62        // Check memory constraints
63        if let Some(limit) = config.memory_limit {
64            let required_memory = batch_size * state_size * std::mem::size_of::<Complex64>();
65            if required_memory > limit {
66                return Err(QuantRS2Error::InvalidInput(format!(
67                    "Batch requires {} bytes, limit is {}",
68                    required_memory, limit
69                )));
70            }
71        }
72
73        // Initialize all states to |0...0>
74        let mut states = Array2::zeros((batch_size, state_size));
75        for i in 0..batch_size {
76            states[[i, 0]] = Complex64::new(1.0, 0.0);
77        }
78
79        Ok(Self {
80            states,
81            n_qubits,
82            config,
83        })
84    }
85
86    /// Create from existing state vectors
87    pub fn from_states(states: Array2<Complex64>, config: BatchConfig) -> QuantRS2Result<Self> {
88        let (_batch_size, state_size) = states.dim();
89
90        // Determine number of qubits
91        let n_qubits = (state_size as f64).log2().round() as usize;
92        if 1 << n_qubits != state_size {
93            return Err(QuantRS2Error::InvalidInput(
94                "State size must be a power of 2".to_string(),
95            ));
96        }
97
98        Ok(Self {
99            states,
100            n_qubits,
101            config,
102        })
103    }
104
105    /// Get batch size
106    pub fn batch_size(&self) -> usize {
107        self.states.nrows()
108    }
109
110    /// Get a specific state from the batch
111    pub fn get_state(&self, index: usize) -> QuantRS2Result<Array1<Complex64>> {
112        if index >= self.batch_size() {
113            return Err(QuantRS2Error::InvalidInput(format!(
114                "Index {} out of bounds for batch size {}",
115                index,
116                self.batch_size()
117            )));
118        }
119
120        Ok(self.states.row(index).to_owned())
121    }
122
123    /// Set a specific state in the batch
124    pub fn set_state(&mut self, index: usize, state: &Array1<Complex64>) -> QuantRS2Result<()> {
125        if index >= self.batch_size() {
126            return Err(QuantRS2Error::InvalidInput(format!(
127                "Index {} out of bounds for batch size {}",
128                index,
129                self.batch_size()
130            )));
131        }
132
133        if state.len() != self.states.ncols() {
134            return Err(QuantRS2Error::InvalidInput(format!(
135                "State size {} doesn't match expected {}",
136                state.len(),
137                self.states.ncols()
138            )));
139        }
140
141        self.states.row_mut(index).assign(state);
142        Ok(())
143    }
144}
145
146/// Batch circuit execution result
147#[derive(Debug, Clone)]
148pub struct BatchExecutionResult {
149    /// Final state vectors
150    pub final_states: Array2<Complex64>,
151    /// Execution time in milliseconds
152    pub execution_time_ms: f64,
153    /// Number of gates applied
154    pub gates_applied: usize,
155    /// Whether GPU was used
156    pub used_gpu: bool,
157}
158
159/// Batch measurement result
160#[derive(Debug, Clone)]
161pub struct BatchMeasurementResult {
162    /// Measurement outcomes for each state in the batch
163    /// Shape: (batch_size, num_measurements)
164    pub outcomes: Array2<u8>,
165    /// Probabilities for each outcome
166    /// Shape: (batch_size, num_measurements)
167    pub probabilities: Array2<f64>,
168    /// Post-measurement states (if requested)
169    pub post_measurement_states: Option<Array2<Complex64>>,
170}
171
172/// Trait for batch-optimized gates
173pub trait BatchGateOp: GateOp {
174    /// Apply this gate to a batch of states
175    fn apply_batch(
176        &self,
177        batch: &mut BatchStateVector,
178        target_qubits: &[QubitId],
179    ) -> QuantRS2Result<()>;
180
181    /// Check if this gate has batch optimization
182    fn has_batch_optimization(&self) -> bool {
183        true
184    }
185}
186
187/// Helper to create batches from a collection of states
188pub fn create_batch<I>(states: I, config: BatchConfig) -> QuantRS2Result<BatchStateVector>
189where
190    I: IntoIterator<Item = Array1<Complex64>>,
191{
192    let states_vec: Vec<_> = states.into_iter().collect();
193    if states_vec.is_empty() {
194        return Err(QuantRS2Error::InvalidInput(
195            "Cannot create empty batch".to_string(),
196        ));
197    }
198
199    let state_size = states_vec[0].len();
200    let batch_size = states_vec.len();
201
202    // Validate all states have same size
203    for (i, state) in states_vec.iter().enumerate() {
204        if state.len() != state_size {
205            return Err(QuantRS2Error::InvalidInput(format!(
206                "State {} has size {}, expected {}",
207                i,
208                state.len(),
209                state_size
210            )));
211        }
212    }
213
214    // Create 2D array
215    let mut batch_array = Array2::zeros((batch_size, state_size));
216    for (i, state) in states_vec.iter().enumerate() {
217        batch_array.row_mut(i).assign(state);
218    }
219
220    BatchStateVector::from_states(batch_array, config)
221}
222
223/// Helper to split a large batch into smaller chunks
224pub fn split_batch(batch: &BatchStateVector, chunk_size: usize) -> Vec<BatchStateVector> {
225    let mut chunks = Vec::new();
226    let batch_size = batch.batch_size();
227
228    for start in (0..batch_size).step_by(chunk_size) {
229        let end = (start + chunk_size).min(batch_size);
230        let chunk_states = batch
231            .states
232            .slice(scirs2_core::ndarray::s![start..end, ..])
233            .to_owned();
234
235        if let Ok(chunk) = BatchStateVector::from_states(chunk_states, batch.config.clone()) {
236            chunks.push(chunk);
237        }
238    }
239
240    chunks
241}
242
243/// Merge multiple batches into one
244pub fn merge_batches(
245    batches: Vec<BatchStateVector>,
246    config: BatchConfig,
247) -> QuantRS2Result<BatchStateVector> {
248    if batches.is_empty() {
249        return Err(QuantRS2Error::InvalidInput(
250            "Cannot merge empty batches".to_string(),
251        ));
252    }
253
254    // Validate all batches have same n_qubits
255    let n_qubits = batches[0].n_qubits;
256    for (i, batch) in batches.iter().enumerate() {
257        if batch.n_qubits != n_qubits {
258            return Err(QuantRS2Error::InvalidInput(format!(
259                "Batch {} has {} qubits, expected {}",
260                i, batch.n_qubits, n_qubits
261            )));
262        }
263    }
264
265    // Concatenate states
266    let mut all_states = Vec::new();
267    for batch in batches {
268        for i in 0..batch.batch_size() {
269            all_states.push(batch.states.row(i).to_owned());
270        }
271    }
272
273    create_batch(all_states, config)
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a purely real amplitude.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Build a batch with the default configuration, panicking on failure.
    fn default_batch(batch_size: usize, n_qubits: usize) -> BatchStateVector {
        BatchStateVector::new(batch_size, n_qubits, BatchConfig::default()).unwrap()
    }

    #[test]
    fn test_batch_creation() {
        let batch = default_batch(10, 3);
        assert_eq!(batch.batch_size(), 10);
        assert_eq!(batch.n_qubits, 3);
        assert_eq!(batch.states.ncols(), 8); // 2^3

        // Every state must start out as |000>
        for index in 0..10 {
            let state = batch.get_state(index).unwrap();
            assert_eq!(state[0], re(1.0));
            for amplitude in state.iter().skip(1) {
                assert_eq!(*amplitude, re(0.0));
            }
        }
    }

    #[test]
    fn test_batch_from_states() {
        let mut states = Array2::zeros((5, 4));
        for row in 0..5 {
            states[[row, row % 4]] = re(1.0);
        }

        let batch = BatchStateVector::from_states(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 5);
        assert_eq!(batch.n_qubits, 2); // 2^2 = 4
    }

    #[test]
    fn test_create_batch_helper() {
        let states = vec![
            Array1::from_vec(vec![re(1.0), re(0.0)]),
            Array1::from_vec(vec![re(0.0), re(1.0)]),
            Array1::from_vec(vec![re(0.707), re(0.707)]),
        ];

        let batch = create_batch(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.n_qubits, 1);
    }

    #[test]
    fn test_split_batch() {
        let chunks = split_batch(&default_batch(10, 2), 3);

        // Expected chunk sizes: 3, 3, 3, 1
        let sizes: Vec<usize> = chunks.iter().map(|chunk| chunk.batch_size()).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_merge_batches() {
        let parts = vec![default_batch(3, 2), default_batch(2, 2)];

        let merged = merge_batches(parts, BatchConfig::default()).unwrap();
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(merged.n_qubits, 2);
    }

    // === Comprehensive Batch Operation Tests ===

    #[test]
    fn test_batch_memory_limit_enforcement() {
        // A very small memory limit that the batch below cannot satisfy
        let config = BatchConfig {
            memory_limit: Some(100),
            ..BatchConfig::default()
        };

        let result = BatchStateVector::new(10, 5, config);
        assert!(result.is_err());

        // The error should mention the byte count or the limit
        let rendered = result
            .err()
            .map(|error| format!("{:?}", error))
            .unwrap_or_default();
        assert!(rendered.contains("bytes") || rendered.contains("limit"));
    }

    #[test]
    fn test_batch_state_normalization() {
        let batch = default_batch(5, 2);

        // Each freshly created state must have unit norm
        for index in 0..batch.batch_size() {
            let norm: f64 = batch
                .get_state(index)
                .unwrap()
                .iter()
                .map(|amplitude| amplitude.norm_sqr())
                .sum();
            assert!(
                (norm - 1.0).abs() < 1e-10,
                "State {} not normalized: {}",
                index,
                norm
            );
        }
    }

    #[test]
    fn test_batch_state_get_set_roundtrip() {
        let mut batch = default_batch(3, 2);

        // Uniform superposition as the custom state
        let custom_state = Array1::from_vec(vec![re(0.5); 4]);

        batch.set_state(1, &custom_state).unwrap();
        let retrieved = batch.get_state(1).unwrap();

        for k in 0..4 {
            assert!((retrieved[k] - custom_state[k]).norm() < 1e-10);
        }
    }

    #[test]
    fn test_batch_out_of_bounds_access() {
        let batch = default_batch(5, 2);

        // Indices at and beyond the batch size are rejected
        for index in [5, 100] {
            assert!(batch.get_state(index).is_err());
        }
    }

    #[test]
    fn test_batch_set_wrong_size_state() {
        let mut batch = default_batch(5, 2);

        // A 2-entry state does not fit a 2-qubit (4-entry) batch
        let wrong_state = Array1::from_vec(vec![re(1.0), re(0.0)]);
        assert!(batch.set_state(0, &wrong_state).is_err());
    }

    #[test]
    fn test_empty_batch_creation_fails() {
        let no_states: Vec<Array1<Complex64>> = Vec::new();
        assert!(create_batch(no_states, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_batch_mismatched_state_sizes() {
        let states = vec![
            Array1::from_vec(vec![re(1.0), re(0.0)]),
            Array1::from_vec(vec![re(1.0), re(0.0), re(0.0), re(0.0)]),
        ];

        assert!(create_batch(states, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_batch_invalid_state_size() {
        // Three columns: not a power of 2
        let states = Array2::<Complex64>::zeros((5, 3));
        assert!(BatchStateVector::from_states(states, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_split_batch_single_element() {
        let chunks = split_batch(&default_batch(1, 2), 10);

        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].batch_size(), 1);
    }

    #[test]
    fn test_split_batch_exact_division() {
        let chunks = split_batch(&default_batch(9, 2), 3);

        assert_eq!(chunks.len(), 3);
        for chunk in chunks.iter() {
            assert_eq!(chunk.batch_size(), 3);
        }
    }

    #[test]
    fn test_merge_batches_empty() {
        assert!(merge_batches(Vec::new(), BatchConfig::default()).is_err());
    }

    #[test]
    fn test_merge_batches_mismatched_qubits() {
        let parts = vec![default_batch(3, 2), default_batch(2, 3)];

        assert!(merge_batches(parts, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_batch_config_defaults() {
        let config = BatchConfig::default();
        assert!(config.num_workers.is_none());
        assert_eq!(config.max_batch_size, 1024);
        assert!(config.use_gpu);
        assert!(config.memory_limit.is_none());
        assert!(config.enable_cache);
    }

    #[test]
    fn test_large_batch_creation() {
        // A larger batch size than the other tests use
        let batch = default_batch(100, 4);
        assert_eq!(batch.batch_size(), 100);
        assert_eq!(batch.n_qubits, 4);
        assert_eq!(batch.states.ncols(), 16); // 2^4
    }

    #[test]
    fn test_batch_state_modification_isolation() {
        let mut batch = default_batch(3, 2);

        // Overwrite only the middle state
        let modified = Array1::from_vec(vec![re(0.0), re(1.0), re(0.0), re(0.0)]);
        batch.set_state(1, &modified).unwrap();

        // Its neighbours must still be in their initial state
        for untouched in [0, 2] {
            let state = batch.get_state(untouched).unwrap();
            assert_eq!(state[0], re(1.0));
        }
    }

    #[test]
    fn test_split_merge_roundtrip() {
        let batch = default_batch(10, 2);
        let original_states = batch.states.clone();

        // Split, then merge back together
        let merged = merge_batches(split_batch(&batch, 3), BatchConfig::default()).unwrap();

        // Every amplitude must survive the round trip
        assert_eq!(merged.batch_size(), 10);
        for row in 0..10 {
            for col in 0..4 {
                assert_eq!(merged.states[[row, col]], original_states[[row, col]]);
            }
        }
    }

    #[test]
    fn test_batch_execution_result_fields() {
        let result = BatchExecutionResult {
            final_states: Array2::zeros((5, 4)),
            execution_time_ms: 100.0,
            gates_applied: 50,
            used_gpu: false,
        };

        assert_eq!(result.execution_time_ms, 100.0);
        assert_eq!(result.gates_applied, 50);
        assert!(!result.used_gpu);
    }

    #[test]
    fn test_batch_measurement_result_fields() {
        let result = BatchMeasurementResult {
            outcomes: Array2::zeros((5, 10)),
            probabilities: Array2::zeros((5, 10)),
            post_measurement_states: None,
        };

        assert_eq!(result.outcomes.dim(), (5, 10));
        assert_eq!(result.probabilities.dim(), (5, 10));
        assert!(result.post_measurement_states.is_none());
    }
}