1pub mod execution;
7pub mod measurement;
8pub mod operations;
9pub mod optimization;
10
11use crate::{
12 error::{QuantRS2Error, QuantRS2Result},
13 gate::GateOp,
14 qubit::QubitId,
15};
16use ndarray::{Array1, Array2, Array3};
17use num_complex::Complex64;
18use std::sync::Arc;
19
/// Tunable settings for batched state-vector operations.
#[derive(Clone, Debug)]
pub struct BatchConfig {
    /// Requested number of workers; `None` means no explicit count was requested.
    pub num_workers: Option<usize>,
    /// Upper bound on how many states a batch should hold (not enforced in this file).
    pub max_batch_size: usize,
    /// Requests GPU execution; consumers of this flag live outside this file.
    pub use_gpu: bool,
    /// Optional cap, in bytes, checked by `BatchStateVector::new` against the batch's
    /// amplitude storage (`batch_size * 2^n_qubits * size_of::<Complex64>()`).
    pub memory_limit: Option<usize>,
    /// Enables caching; the cache itself is implemented outside this file.
    pub enable_cache: bool,
}

impl Default for BatchConfig {
    /// Defaults: no explicit worker count, batches of up to 1024 states,
    /// GPU and cache requested, no memory cap.
    fn default() -> Self {
        BatchConfig {
            max_batch_size: 1024,
            use_gpu: true,
            enable_cache: true,
            num_workers: None,
            memory_limit: None,
        }
    }
}
46
47#[derive(Clone)]
49pub struct BatchStateVector {
50 pub states: Array2<Complex64>,
52 pub n_qubits: usize,
54 pub config: BatchConfig,
56}
57
58impl BatchStateVector {
59 pub fn new(batch_size: usize, n_qubits: usize, config: BatchConfig) -> QuantRS2Result<Self> {
61 let state_size = 1 << n_qubits;
62
63 if let Some(limit) = config.memory_limit {
65 let required_memory = batch_size * state_size * std::mem::size_of::<Complex64>();
66 if required_memory > limit {
67 return Err(QuantRS2Error::InvalidInput(format!(
68 "Batch requires {} bytes, limit is {}",
69 required_memory, limit
70 )));
71 }
72 }
73
74 let mut states = Array2::zeros((batch_size, state_size));
76 for i in 0..batch_size {
77 states[[i, 0]] = Complex64::new(1.0, 0.0);
78 }
79
80 Ok(Self {
81 states,
82 n_qubits,
83 config,
84 })
85 }
86
87 pub fn from_states(states: Array2<Complex64>, config: BatchConfig) -> QuantRS2Result<Self> {
89 let (batch_size, state_size) = states.dim();
90
91 let n_qubits = (state_size as f64).log2().round() as usize;
93 if 1 << n_qubits != state_size {
94 return Err(QuantRS2Error::InvalidInput(
95 "State size must be a power of 2".to_string(),
96 ));
97 }
98
99 Ok(Self {
100 states,
101 n_qubits,
102 config,
103 })
104 }
105
106 pub fn batch_size(&self) -> usize {
108 self.states.nrows()
109 }
110
111 pub fn get_state(&self, index: usize) -> QuantRS2Result<Array1<Complex64>> {
113 if index >= self.batch_size() {
114 return Err(QuantRS2Error::InvalidInput(format!(
115 "Index {} out of bounds for batch size {}",
116 index,
117 self.batch_size()
118 )));
119 }
120
121 Ok(self.states.row(index).to_owned())
122 }
123
124 pub fn set_state(&mut self, index: usize, state: &Array1<Complex64>) -> QuantRS2Result<()> {
126 if index >= self.batch_size() {
127 return Err(QuantRS2Error::InvalidInput(format!(
128 "Index {} out of bounds for batch size {}",
129 index,
130 self.batch_size()
131 )));
132 }
133
134 if state.len() != self.states.ncols() {
135 return Err(QuantRS2Error::InvalidInput(format!(
136 "State size {} doesn't match expected {}",
137 state.len(),
138 self.states.ncols()
139 )));
140 }
141
142 self.states.row_mut(index).assign(state);
143 Ok(())
144 }
145}
146
/// Summary of executing gates over a batch of states.
///
/// NOTE(review): this type is only declared here; it is presumably populated by the
/// `execution` submodule — confirm there.
#[derive(Debug, Clone)]
pub struct BatchExecutionResult {
    /// Final amplitudes, presumably one row per batch member like
    /// `BatchStateVector::states` — TODO confirm.
    pub final_states: Array2<Complex64>,
    /// Execution time in milliseconds (per the field name; how it is measured is not
    /// visible here).
    pub execution_time_ms: f64,
    /// Number of gates applied during the run.
    pub gates_applied: usize,
    /// Whether the run used the GPU path.
    pub used_gpu: bool,
}
159
/// Measurement results for a batch of states.
///
/// NOTE(review): the axis layout of these arrays is not established in this file; the
/// field notes below are assumptions to confirm against the `measurement` submodule.
#[derive(Debug, Clone)]
pub struct BatchMeasurementResult {
    /// Measured outcome values, presumably indexed `[batch_index, measured_qubit]` —
    /// TODO confirm.
    pub outcomes: Array2<u8>,
    /// Outcome probabilities, presumably with the same layout as `outcomes` — TODO confirm.
    pub probabilities: Array2<f64>,
    /// Post-measurement amplitudes when they were retained; `None` otherwise.
    pub post_measurement_states: Option<Array2<Complex64>>,
}
172
/// A gate operation that can be applied to a whole [`BatchStateVector`] in one call.
pub trait BatchGateOp: GateOp {
    /// Applies this gate in place to `batch` (taken by `&mut`), acting on `target_qubits`.
    ///
    /// # Errors
    ///
    /// Failures are reported through `QuantRS2Result`; the specific conditions are
    /// implementation-defined.
    fn apply_batch(
        &self,
        batch: &mut BatchStateVector,
        target_qubits: &[QubitId],
    ) -> QuantRS2Result<()>;

    /// Reports whether this implementation provides a batch-specific optimization.
    ///
    /// Defaults to `true`; override to return `false` where that is not the case.
    fn has_batch_optimization(&self) -> bool {
        true
    }
}
187
188pub fn create_batch<I>(states: I, config: BatchConfig) -> QuantRS2Result<BatchStateVector>
190where
191 I: IntoIterator<Item = Array1<Complex64>>,
192{
193 let states_vec: Vec<_> = states.into_iter().collect();
194 if states_vec.is_empty() {
195 return Err(QuantRS2Error::InvalidInput(
196 "Cannot create empty batch".to_string(),
197 ));
198 }
199
200 let state_size = states_vec[0].len();
201 let batch_size = states_vec.len();
202
203 for (i, state) in states_vec.iter().enumerate() {
205 if state.len() != state_size {
206 return Err(QuantRS2Error::InvalidInput(format!(
207 "State {} has size {}, expected {}",
208 i,
209 state.len(),
210 state_size
211 )));
212 }
213 }
214
215 let mut batch_array = Array2::zeros((batch_size, state_size));
217 for (i, state) in states_vec.iter().enumerate() {
218 batch_array.row_mut(i).assign(state);
219 }
220
221 BatchStateVector::from_states(batch_array, config)
222}
223
224pub fn split_batch(batch: &BatchStateVector, chunk_size: usize) -> Vec<BatchStateVector> {
226 let mut chunks = Vec::new();
227 let batch_size = batch.batch_size();
228
229 for start in (0..batch_size).step_by(chunk_size) {
230 let end = (start + chunk_size).min(batch_size);
231 let chunk_states = batch.states.slice(ndarray::s![start..end, ..]).to_owned();
232
233 if let Ok(chunk) = BatchStateVector::from_states(chunk_states, batch.config.clone()) {
234 chunks.push(chunk);
235 }
236 }
237
238 chunks
239}
240
241pub fn merge_batches(
243 batches: Vec<BatchStateVector>,
244 config: BatchConfig,
245) -> QuantRS2Result<BatchStateVector> {
246 if batches.is_empty() {
247 return Err(QuantRS2Error::InvalidInput(
248 "Cannot merge empty batches".to_string(),
249 ));
250 }
251
252 let n_qubits = batches[0].n_qubits;
254 for (i, batch) in batches.iter().enumerate() {
255 if batch.n_qubits != n_qubits {
256 return Err(QuantRS2Error::InvalidInput(format!(
257 "Batch {} has {} qubits, expected {}",
258 i, batch.n_qubits, n_qubits
259 )));
260 }
261 }
262
263 let mut all_states = Vec::new();
265 for batch in batches {
266 for i in 0..batch.batch_size() {
267 all_states.push(batch.states.row(i).to_owned());
268 }
269 }
270
271 create_batch(all_states, config)
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_creation() {
        let batch = BatchStateVector::new(10, 3, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 10);
        assert_eq!(batch.n_qubits, 3);
        assert_eq!(batch.states.ncols(), 8);

        // Every member starts with amplitude 1 at index 0 and 0 everywhere else.
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        for idx in 0..batch.batch_size() {
            let state = batch.get_state(idx).unwrap();
            assert_eq!(state[0], one);
            assert!(state.iter().skip(1).all(|&amp| amp == zero));
        }
    }

    #[test]
    fn test_batch_from_states() {
        // Row r is the basis state with index r % 4.
        let states = Array2::from_shape_fn((5, 4), |(row, col)| {
            if col == row % 4 {
                Complex64::new(1.0, 0.0)
            } else {
                Complex64::new(0.0, 0.0)
            }
        });

        let batch = BatchStateVector::from_states(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 5);
        assert_eq!(batch.n_qubits, 2);
    }

    #[test]
    fn test_create_batch_helper() {
        let amp = |re: f64| Complex64::new(re, 0.0);
        let states = vec![
            Array1::from_vec(vec![amp(1.0), amp(0.0)]),
            Array1::from_vec(vec![amp(0.0), amp(1.0)]),
            Array1::from_vec(vec![amp(0.707), amp(0.707)]),
        ];

        let batch = create_batch(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.n_qubits, 1);
    }

    #[test]
    fn test_split_batch() {
        let batch = BatchStateVector::new(10, 2, BatchConfig::default()).unwrap();
        let chunks = split_batch(&batch, 3);

        let sizes: Vec<usize> = chunks.iter().map(BatchStateVector::batch_size).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_merge_batches() {
        let first = BatchStateVector::new(3, 2, BatchConfig::default()).unwrap();
        let second = BatchStateVector::new(2, 2, BatchConfig::default()).unwrap();

        let merged = merge_batches(vec![first, second], BatchConfig::default()).unwrap();
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(merged.n_qubits, 2);
    }
}