quantrs2_core/batch/
mod.rs

1//! Batch operations for quantum circuits using SciRS2 parallel algorithms
2//!
3//! This module provides efficient batch processing for quantum operations,
4//! leveraging SciRS2's parallel computing capabilities for performance.
5
6pub mod execution;
7pub mod measurement;
8pub mod operations;
9pub mod optimization;
10
11use crate::{
12    error::{QuantRS2Error, QuantRS2Result},
13    gate::GateOp,
14    qubit::QubitId,
15};
16use ndarray::{Array1, Array2, Array3};
17use num_complex::Complex64;
18use std::sync::Arc;
19
/// Tunable settings shared by the batch-processing routines.
#[derive(Clone, Debug)]
pub struct BatchConfig {
    /// How many parallel workers to use; `None` defers to the system default.
    pub num_workers: Option<usize>,
    /// Upper bound on how many states are processed together in one batch.
    pub max_batch_size: usize,
    /// Request GPU acceleration whenever it is available.
    pub use_gpu: bool,
    /// Optional cap, in bytes, on the memory a batch may occupy.
    pub memory_limit: Option<usize>,
    /// Turn on caching for operations that are repeated.
    pub enable_cache: bool,
}
34
35impl Default for BatchConfig {
36    fn default() -> Self {
37        Self {
38            num_workers: None, // Use system default
39            max_batch_size: 1024,
40            use_gpu: true,
41            memory_limit: None,
42            enable_cache: true,
43        }
44    }
45}
46
47/// Batch of quantum states for parallel processing
48#[derive(Clone)]
49pub struct BatchStateVector {
50    /// The batch of state vectors (batch_size, 2^n_qubits)
51    pub states: Array2<Complex64>,
52    /// Number of qubits
53    pub n_qubits: usize,
54    /// Batch configuration
55    pub config: BatchConfig,
56}
57
58impl BatchStateVector {
59    /// Create a new batch of quantum states
60    pub fn new(batch_size: usize, n_qubits: usize, config: BatchConfig) -> QuantRS2Result<Self> {
61        let state_size = 1 << n_qubits;
62
63        // Check memory constraints
64        if let Some(limit) = config.memory_limit {
65            let required_memory = batch_size * state_size * std::mem::size_of::<Complex64>();
66            if required_memory > limit {
67                return Err(QuantRS2Error::InvalidInput(format!(
68                    "Batch requires {} bytes, limit is {}",
69                    required_memory, limit
70                )));
71            }
72        }
73
74        // Initialize all states to |0...0>
75        let mut states = Array2::zeros((batch_size, state_size));
76        for i in 0..batch_size {
77            states[[i, 0]] = Complex64::new(1.0, 0.0);
78        }
79
80        Ok(Self {
81            states,
82            n_qubits,
83            config,
84        })
85    }
86
87    /// Create from existing state vectors
88    pub fn from_states(states: Array2<Complex64>, config: BatchConfig) -> QuantRS2Result<Self> {
89        let (batch_size, state_size) = states.dim();
90
91        // Determine number of qubits
92        let n_qubits = (state_size as f64).log2().round() as usize;
93        if 1 << n_qubits != state_size {
94            return Err(QuantRS2Error::InvalidInput(
95                "State size must be a power of 2".to_string(),
96            ));
97        }
98
99        Ok(Self {
100            states,
101            n_qubits,
102            config,
103        })
104    }
105
106    /// Get batch size
107    pub fn batch_size(&self) -> usize {
108        self.states.nrows()
109    }
110
111    /// Get a specific state from the batch
112    pub fn get_state(&self, index: usize) -> QuantRS2Result<Array1<Complex64>> {
113        if index >= self.batch_size() {
114            return Err(QuantRS2Error::InvalidInput(format!(
115                "Index {} out of bounds for batch size {}",
116                index,
117                self.batch_size()
118            )));
119        }
120
121        Ok(self.states.row(index).to_owned())
122    }
123
124    /// Set a specific state in the batch
125    pub fn set_state(&mut self, index: usize, state: &Array1<Complex64>) -> QuantRS2Result<()> {
126        if index >= self.batch_size() {
127            return Err(QuantRS2Error::InvalidInput(format!(
128                "Index {} out of bounds for batch size {}",
129                index,
130                self.batch_size()
131            )));
132        }
133
134        if state.len() != self.states.ncols() {
135            return Err(QuantRS2Error::InvalidInput(format!(
136                "State size {} doesn't match expected {}",
137                state.len(),
138                self.states.ncols()
139            )));
140        }
141
142        self.states.row_mut(index).assign(state);
143        Ok(())
144    }
145}
146
147/// Batch circuit execution result
148#[derive(Debug, Clone)]
149pub struct BatchExecutionResult {
150    /// Final state vectors
151    pub final_states: Array2<Complex64>,
152    /// Execution time in milliseconds
153    pub execution_time_ms: f64,
154    /// Number of gates applied
155    pub gates_applied: usize,
156    /// Whether GPU was used
157    pub used_gpu: bool,
158}
159
160/// Batch measurement result
161#[derive(Debug, Clone)]
162pub struct BatchMeasurementResult {
163    /// Measurement outcomes for each state in the batch
164    /// Shape: (batch_size, num_measurements)
165    pub outcomes: Array2<u8>,
166    /// Probabilities for each outcome
167    /// Shape: (batch_size, num_measurements)
168    pub probabilities: Array2<f64>,
169    /// Post-measurement states (if requested)
170    pub post_measurement_states: Option<Array2<Complex64>>,
171}
172
173/// Trait for batch-optimized gates
174pub trait BatchGateOp: GateOp {
175    /// Apply this gate to a batch of states
176    fn apply_batch(
177        &self,
178        batch: &mut BatchStateVector,
179        target_qubits: &[QubitId],
180    ) -> QuantRS2Result<()>;
181
182    /// Check if this gate has batch optimization
183    fn has_batch_optimization(&self) -> bool {
184        true
185    }
186}
187
188/// Helper to create batches from a collection of states
189pub fn create_batch<I>(states: I, config: BatchConfig) -> QuantRS2Result<BatchStateVector>
190where
191    I: IntoIterator<Item = Array1<Complex64>>,
192{
193    let states_vec: Vec<_> = states.into_iter().collect();
194    if states_vec.is_empty() {
195        return Err(QuantRS2Error::InvalidInput(
196            "Cannot create empty batch".to_string(),
197        ));
198    }
199
200    let state_size = states_vec[0].len();
201    let batch_size = states_vec.len();
202
203    // Validate all states have same size
204    for (i, state) in states_vec.iter().enumerate() {
205        if state.len() != state_size {
206            return Err(QuantRS2Error::InvalidInput(format!(
207                "State {} has size {}, expected {}",
208                i,
209                state.len(),
210                state_size
211            )));
212        }
213    }
214
215    // Create 2D array
216    let mut batch_array = Array2::zeros((batch_size, state_size));
217    for (i, state) in states_vec.iter().enumerate() {
218        batch_array.row_mut(i).assign(state);
219    }
220
221    BatchStateVector::from_states(batch_array, config)
222}
223
224/// Helper to split a large batch into smaller chunks
225pub fn split_batch(batch: &BatchStateVector, chunk_size: usize) -> Vec<BatchStateVector> {
226    let mut chunks = Vec::new();
227    let batch_size = batch.batch_size();
228
229    for start in (0..batch_size).step_by(chunk_size) {
230        let end = (start + chunk_size).min(batch_size);
231        let chunk_states = batch.states.slice(ndarray::s![start..end, ..]).to_owned();
232
233        if let Ok(chunk) = BatchStateVector::from_states(chunk_states, batch.config.clone()) {
234            chunks.push(chunk);
235        }
236    }
237
238    chunks
239}
240
241/// Merge multiple batches into one
242pub fn merge_batches(
243    batches: Vec<BatchStateVector>,
244    config: BatchConfig,
245) -> QuantRS2Result<BatchStateVector> {
246    if batches.is_empty() {
247        return Err(QuantRS2Error::InvalidInput(
248            "Cannot merge empty batches".to_string(),
249        ));
250    }
251
252    // Validate all batches have same n_qubits
253    let n_qubits = batches[0].n_qubits;
254    for (i, batch) in batches.iter().enumerate() {
255        if batch.n_qubits != n_qubits {
256            return Err(QuantRS2Error::InvalidInput(format!(
257                "Batch {} has {} qubits, expected {}",
258                i, batch.n_qubits, n_qubits
259            )));
260        }
261    }
262
263    // Concatenate states
264    let mut all_states = Vec::new();
265    for batch in batches {
266        for i in 0..batch.batch_size() {
267            all_states.push(batch.states.row(i).to_owned());
268        }
269    }
270
271    create_batch(all_states, config)
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-qubit basis state |1>.
    fn ket_one() -> Array1<Complex64> {
        Array1::from_vec(vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)])
    }

    /// Single-qubit basis state |0>.
    fn ket_zero() -> Array1<Complex64> {
        Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)])
    }

    #[test]
    fn test_batch_creation() {
        let batch = BatchStateVector::new(10, 3, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 10);
        assert_eq!(batch.n_qubits, 3);
        assert_eq!(batch.states.ncols(), 8); // 2^3

        // Check initial state is |000>
        for i in 0..10 {
            let state = batch.get_state(i).unwrap();
            assert_eq!(state[0], Complex64::new(1.0, 0.0));
            for j in 1..8 {
                assert_eq!(state[j], Complex64::new(0.0, 0.0));
            }
        }
    }

    #[test]
    fn test_batch_memory_limit_rejected() {
        let config = BatchConfig {
            memory_limit: Some(100),
            ..BatchConfig::default()
        };
        // 10 states * 8 amplitudes * 16 bytes = 1280 bytes > 100.
        assert!(BatchStateVector::new(10, 3, config).is_err());
    }

    #[test]
    fn test_batch_from_states() {
        let mut states = Array2::zeros((5, 4));
        for i in 0..5 {
            states[[i, i % 4]] = Complex64::new(1.0, 0.0);
        }

        let batch = BatchStateVector::from_states(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 5);
        assert_eq!(batch.n_qubits, 2); // 2^2 = 4
    }

    #[test]
    fn test_from_states_rejects_non_power_of_two() {
        let states = Array2::<Complex64>::zeros((2, 3));
        assert!(BatchStateVector::from_states(states, BatchConfig::default()).is_err());

        let empty_columns = Array2::<Complex64>::zeros((2, 0));
        assert!(BatchStateVector::from_states(empty_columns, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_get_and_set_state_bounds() {
        let mut batch = BatchStateVector::new(2, 1, BatchConfig::default()).unwrap();
        assert!(batch.get_state(2).is_err());

        assert!(batch.set_state(2, &ket_one()).is_err());
        assert!(batch.set_state(0, &Array1::zeros(4)).is_err());

        batch.set_state(1, &ket_one()).unwrap();
        assert_eq!(batch.get_state(1).unwrap(), ket_one());
        assert_eq!(batch.get_state(0).unwrap(), ket_zero());
    }

    #[test]
    fn test_create_batch_helper() {
        let states = vec![
            Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]),
            Array1::from_vec(vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]),
            Array1::from_vec(vec![Complex64::new(0.707, 0.0), Complex64::new(0.707, 0.0)]),
        ];

        let batch = create_batch(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.n_qubits, 1);
    }

    #[test]
    fn test_create_batch_errors() {
        let empty: Vec<Array1<Complex64>> = Vec::new();
        assert!(create_batch(empty, BatchConfig::default()).is_err());

        let mismatched = vec![Array1::<Complex64>::zeros(2), Array1::<Complex64>::zeros(4)];
        assert!(create_batch(mismatched, BatchConfig::default()).is_err());
    }

    #[test]
    fn test_split_batch() {
        let batch = BatchStateVector::new(10, 2, BatchConfig::default()).unwrap();
        let chunks = split_batch(&batch, 3);

        assert_eq!(chunks.len(), 4); // 3, 3, 3, 1
        assert_eq!(chunks[0].batch_size(), 3);
        assert_eq!(chunks[1].batch_size(), 3);
        assert_eq!(chunks[2].batch_size(), 3);
        assert_eq!(chunks[3].batch_size(), 1);
    }

    #[test]
    fn test_split_batch_preserves_rows() {
        let batch = create_batch(
            vec![ket_zero(), ket_one(), ket_one()],
            BatchConfig::default(),
        )
        .unwrap();
        let chunks = split_batch(&batch, 2);

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].get_state(0).unwrap(), ket_zero());
        assert_eq!(chunks[0].get_state(1).unwrap(), ket_one());
        assert_eq!(chunks[1].get_state(0).unwrap(), ket_one());
        assert!(chunks.iter().all(|chunk| chunk.n_qubits == 1));
    }

    #[test]
    fn test_merge_batches() {
        let batch1 = BatchStateVector::new(3, 2, BatchConfig::default()).unwrap();
        let batch2 = BatchStateVector::new(2, 2, BatchConfig::default()).unwrap();

        let merged = merge_batches(vec![batch1, batch2], BatchConfig::default()).unwrap();
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(merged.n_qubits, 2);
    }

    #[test]
    fn test_merge_batches_preserves_order() {
        let zero = BatchStateVector::new(1, 1, BatchConfig::default()).unwrap();
        let one = create_batch(vec![ket_one()], BatchConfig::default()).unwrap();

        let merged = merge_batches(vec![zero, one], BatchConfig::default()).unwrap();
        assert_eq!(merged.get_state(0).unwrap(), ket_zero());
        assert_eq!(merged.get_state(1).unwrap(), ket_one());
    }

    #[test]
    fn test_merge_batches_errors() {
        assert!(merge_batches(Vec::new(), BatchConfig::default()).is_err());

        let one_qubit = BatchStateVector::new(1, 1, BatchConfig::default()).unwrap();
        let two_qubits = BatchStateVector::new(1, 2, BatchConfig::default()).unwrap();
        assert!(merge_batches(vec![one_qubit, two_qubits], BatchConfig::default()).is_err());
    }
}