quantrs2_core/batch/
mod.rs

1//! Batch operations for quantum circuits using SciRS2 parallel algorithms
2//!
3//! This module provides efficient batch processing for quantum operations,
4//! leveraging SciRS2's parallel computing capabilities for performance.
5
6pub mod execution;
7pub mod measurement;
8pub mod operations;
9pub mod optimization;
10
11use crate::{
12    error::{QuantRS2Error, QuantRS2Result},
13    gate::GateOp,
14    qubit::QubitId,
15};
16use ndarray::{Array1, Array2};
17use num_complex::Complex64;
18
/// Tunable settings shared by the batch-processing routines.
///
/// [`BatchConfig::default`] provides the stock settings; individual fields can
/// be overridden with struct-update syntax.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// How many parallel workers to use; `None` defers to the system default.
    pub num_workers: Option<usize>,
    /// Upper bound on the number of states processed as one batch.
    pub max_batch_size: usize,
    /// Request GPU acceleration when it is available.
    pub use_gpu: bool,
    /// Optional memory ceiling, expressed in bytes; `None` means no ceiling.
    pub memory_limit: Option<usize>,
    /// Turn on caching for operations that are repeated.
    pub enable_cache: bool,
}

impl Default for BatchConfig {
    /// Stock configuration: system-chosen worker count, batches of up to 1024
    /// states, GPU and cache enabled, and no memory ceiling.
    fn default() -> Self {
        BatchConfig {
            enable_cache: true,
            use_gpu: true,
            max_batch_size: 1024,
            // `None` in both options means "no explicit setting".
            memory_limit: None,
            num_workers: None,
        }
    }
}
45
46/// Batch of quantum states for parallel processing
47#[derive(Clone)]
48pub struct BatchStateVector {
49    /// The batch of state vectors (batch_size, 2^n_qubits)
50    pub states: Array2<Complex64>,
51    /// Number of qubits
52    pub n_qubits: usize,
53    /// Batch configuration
54    pub config: BatchConfig,
55}
56
57impl BatchStateVector {
58    /// Create a new batch of quantum states
59    pub fn new(batch_size: usize, n_qubits: usize, config: BatchConfig) -> QuantRS2Result<Self> {
60        let state_size = 1 << n_qubits;
61
62        // Check memory constraints
63        if let Some(limit) = config.memory_limit {
64            let required_memory = batch_size * state_size * std::mem::size_of::<Complex64>();
65            if required_memory > limit {
66                return Err(QuantRS2Error::InvalidInput(format!(
67                    "Batch requires {} bytes, limit is {}",
68                    required_memory, limit
69                )));
70            }
71        }
72
73        // Initialize all states to |0...0>
74        let mut states = Array2::zeros((batch_size, state_size));
75        for i in 0..batch_size {
76            states[[i, 0]] = Complex64::new(1.0, 0.0);
77        }
78
79        Ok(Self {
80            states,
81            n_qubits,
82            config,
83        })
84    }
85
86    /// Create from existing state vectors
87    pub fn from_states(states: Array2<Complex64>, config: BatchConfig) -> QuantRS2Result<Self> {
88        let (_batch_size, state_size) = states.dim();
89
90        // Determine number of qubits
91        let n_qubits = (state_size as f64).log2().round() as usize;
92        if 1 << n_qubits != state_size {
93            return Err(QuantRS2Error::InvalidInput(
94                "State size must be a power of 2".to_string(),
95            ));
96        }
97
98        Ok(Self {
99            states,
100            n_qubits,
101            config,
102        })
103    }
104
105    /// Get batch size
106    pub fn batch_size(&self) -> usize {
107        self.states.nrows()
108    }
109
110    /// Get a specific state from the batch
111    pub fn get_state(&self, index: usize) -> QuantRS2Result<Array1<Complex64>> {
112        if index >= self.batch_size() {
113            return Err(QuantRS2Error::InvalidInput(format!(
114                "Index {} out of bounds for batch size {}",
115                index,
116                self.batch_size()
117            )));
118        }
119
120        Ok(self.states.row(index).to_owned())
121    }
122
123    /// Set a specific state in the batch
124    pub fn set_state(&mut self, index: usize, state: &Array1<Complex64>) -> QuantRS2Result<()> {
125        if index >= self.batch_size() {
126            return Err(QuantRS2Error::InvalidInput(format!(
127                "Index {} out of bounds for batch size {}",
128                index,
129                self.batch_size()
130            )));
131        }
132
133        if state.len() != self.states.ncols() {
134            return Err(QuantRS2Error::InvalidInput(format!(
135                "State size {} doesn't match expected {}",
136                state.len(),
137                self.states.ncols()
138            )));
139        }
140
141        self.states.row_mut(index).assign(state);
142        Ok(())
143    }
144}
145
/// Batch circuit execution result
///
/// Plain data record with public fields; nothing in this file constructs it.
/// NOTE(review): presumably produced by the `execution` submodule — confirm.
#[derive(Debug, Clone)]
pub struct BatchExecutionResult {
    /// Final state vectors
    /// NOTE(review): presumably one state per row, laid out like
    /// `BatchStateVector::states` — confirm against the producer.
    pub final_states: Array2<Complex64>,
    /// Execution time in milliseconds
    pub execution_time_ms: f64,
    /// Number of gates applied
    pub gates_applied: usize,
    /// Whether GPU was used
    pub used_gpu: bool,
}
158
/// Batch measurement result
///
/// Plain data record with public fields; nothing in this file constructs it.
/// NOTE(review): presumably produced by the `measurement` submodule — confirm.
#[derive(Debug, Clone)]
pub struct BatchMeasurementResult {
    /// Measurement outcomes for each state in the batch
    /// Shape: (batch_size, num_measurements)
    pub outcomes: Array2<u8>,
    /// Probabilities for each outcome
    /// Shape: (batch_size, num_measurements)
    pub probabilities: Array2<f64>,
    /// Post-measurement states (if requested); `None` when they were not
    /// requested.
    pub post_measurement_states: Option<Array2<Complex64>>,
}
171
/// Trait for batch-optimized gates
///
/// Extends [`GateOp`] with an operation that acts on a whole
/// [`BatchStateVector`] through a mutable reference.
pub trait BatchGateOp: GateOp {
    /// Apply this gate to a batch of states
    ///
    /// `batch` is passed mutably so implementors can update it in place;
    /// `target_qubits` lists the qubits the gate acts on. Failures are
    /// reported through the returned `QuantRS2Result`.
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()>;

    /// Check if this gate has batch optimization
    ///
    /// The provided default returns `true`; implementors may override it.
    fn has_batch_optimization(&self) -> bool {
        true
    }
}
186
187/// Helper to create batches from a collection of states
188pub fn create_batch<I>(states: I, config: BatchConfig) -> QuantRS2Result<BatchStateVector>
189where
190    I: IntoIterator<Item = Array1<Complex64>>,
191{
192    let states_vec: Vec<_> = states.into_iter().collect();
193    if states_vec.is_empty() {
194        return Err(QuantRS2Error::InvalidInput(
195            "Cannot create empty batch".to_string(),
196        ));
197    }
198
199    let state_size = states_vec[0].len();
200    let batch_size = states_vec.len();
201
202    // Validate all states have same size
203    for (i, state) in states_vec.iter().enumerate() {
204        if state.len() != state_size {
205            return Err(QuantRS2Error::InvalidInput(format!(
206                "State {} has size {}, expected {}",
207                i,
208                state.len(),
209                state_size
210            )));
211        }
212    }
213
214    // Create 2D array
215    let mut batch_array = Array2::zeros((batch_size, state_size));
216    for (i, state) in states_vec.iter().enumerate() {
217        batch_array.row_mut(i).assign(state);
218    }
219
220    BatchStateVector::from_states(batch_array, config)
221}
222
223/// Helper to split a large batch into smaller chunks
224pub fn split_batch(batch: &BatchStateVector, chunk_size: usize) -> Vec<BatchStateVector> {
225    let mut chunks = Vec::new();
226    let batch_size = batch.batch_size();
227
228    for start in (0..batch_size).step_by(chunk_size) {
229        let end = (start + chunk_size).min(batch_size);
230        let chunk_states = batch.states.slice(ndarray::s![start..end, ..]).to_owned();
231
232        if let Ok(chunk) = BatchStateVector::from_states(chunk_states, batch.config.clone()) {
233            chunks.push(chunk);
234        }
235    }
236
237    chunks
238}
239
240/// Merge multiple batches into one
241pub fn merge_batches(
242    batches: Vec<BatchStateVector>,
243    config: BatchConfig,
244) -> QuantRS2Result<BatchStateVector> {
245    if batches.is_empty() {
246        return Err(QuantRS2Error::InvalidInput(
247            "Cannot merge empty batches".to_string(),
248        ));
249    }
250
251    // Validate all batches have same n_qubits
252    let n_qubits = batches[0].n_qubits;
253    for (i, batch) in batches.iter().enumerate() {
254        if batch.n_qubits != n_qubits {
255            return Err(QuantRS2Error::InvalidInput(format!(
256                "Batch {} has {} qubits, expected {}",
257                i, batch.n_qubits, n_qubits
258            )));
259        }
260    }
261
262    // Concatenate states
263    let mut all_states = Vec::new();
264    for batch in batches {
265        for i in 0..batch.batch_size() {
266            all_states.push(batch.states.row(i).to_owned());
267        }
268    }
269
270    create_batch(all_states, config)
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh batch has the requested shape and every state equals |000>.
    #[test]
    fn test_batch_creation() {
        let batch = BatchStateVector::new(10, 3, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 10);
        assert_eq!(batch.n_qubits, 3);
        assert_eq!(batch.states.ncols(), 8); // 2^3

        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        for index in 0..10 {
            let state = batch.get_state(index).unwrap();
            // Amplitude 1 on basis state 0, amplitude 0 everywhere else.
            assert_eq!(state[0], one);
            for amplitude in state.iter().skip(1) {
                assert_eq!(*amplitude, zero);
            }
        }
    }

    /// `from_states` derives the qubit count from the number of columns.
    #[test]
    fn test_batch_from_states() {
        let mut states = Array2::zeros((5, 4));
        for row in 0..5 {
            states[[row, row % 4]] = Complex64::new(1.0, 0.0);
        }

        let batch = BatchStateVector::from_states(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 5);
        assert_eq!(batch.n_qubits, 2); // 2^2 = 4
    }

    /// `create_batch` stacks single-qubit states into one batch.
    #[test]
    fn test_create_batch_helper() {
        let amplitudes = [(1.0, 0.0), (0.0, 1.0), (0.707, 0.707)];
        let states: Vec<Array1<Complex64>> = amplitudes
            .iter()
            .map(|&(a, b)| Array1::from_vec(vec![Complex64::new(a, 0.0), Complex64::new(b, 0.0)]))
            .collect();

        let batch = create_batch(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.n_qubits, 1);
    }

    /// Ten states in chunks of three give chunk sizes 3, 3, 3, 1.
    #[test]
    fn test_split_batch() {
        let batch = BatchStateVector::new(10, 2, BatchConfig::default()).unwrap();
        let chunks = split_batch(&batch, 3);

        let sizes: Vec<usize> = chunks.iter().map(|chunk| chunk.batch_size()).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    /// Merging batches of 3 and 2 states gives one batch of 5 states.
    #[test]
    fn test_merge_batches() {
        let parts = vec![
            BatchStateVector::new(3, 2, BatchConfig::default()).unwrap(),
            BatchStateVector::new(2, 2, BatchConfig::default()).unwrap(),
        ];

        let merged = merge_batches(parts, BatchConfig::default()).unwrap();
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(merged.n_qubits, 2);
    }
}