1pub mod execution;
7pub mod measurement;
8pub mod operations;
9pub mod optimization;
10
11use crate::{
12 error::{QuantRS2Error, QuantRS2Result},
13 gate::GateOp,
14 qubit::QubitId,
15};
16use ndarray::{Array1, Array2};
17use num_complex::Complex64;
18
/// Tuning options shared by batched state-vector operations.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Requested number of worker threads; `None` presumably means "let the
    /// executor decide" — confirm against the execution module.
    pub num_workers: Option<usize>,
    /// Maximum number of states intended to be processed together.
    pub max_batch_size: usize,
    /// Whether GPU execution is requested.
    pub use_gpu: bool,
    /// Optional budget, in bytes, for a batch's amplitude storage. It is compared
    /// against `batch_size * 2^n_qubits * size_of::<Complex64>()` when a batch is created.
    pub memory_limit: Option<usize>,
    /// Whether caching is enabled.
    pub enable_cache: bool,
}

impl Default for BatchConfig {
    /// Defaults: automatic worker count, batches of up to 1024 states, GPU and
    /// caching enabled, and no memory limit.
    fn default() -> Self {
        let max_batch_size = 1024;
        Self {
            max_batch_size,
            num_workers: None,
            memory_limit: None,
            use_gpu: true,
            enable_cache: true,
        }
    }
}
45
46#[derive(Clone)]
48pub struct BatchStateVector {
49 pub states: Array2<Complex64>,
51 pub n_qubits: usize,
53 pub config: BatchConfig,
55}
56
57impl BatchStateVector {
58 pub fn new(batch_size: usize, n_qubits: usize, config: BatchConfig) -> QuantRS2Result<Self> {
60 let state_size = 1 << n_qubits;
61
62 if let Some(limit) = config.memory_limit {
64 let required_memory = batch_size * state_size * std::mem::size_of::<Complex64>();
65 if required_memory > limit {
66 return Err(QuantRS2Error::InvalidInput(format!(
67 "Batch requires {} bytes, limit is {}",
68 required_memory, limit
69 )));
70 }
71 }
72
73 let mut states = Array2::zeros((batch_size, state_size));
75 for i in 0..batch_size {
76 states[[i, 0]] = Complex64::new(1.0, 0.0);
77 }
78
79 Ok(Self {
80 states,
81 n_qubits,
82 config,
83 })
84 }
85
86 pub fn from_states(states: Array2<Complex64>, config: BatchConfig) -> QuantRS2Result<Self> {
88 let (_batch_size, state_size) = states.dim();
89
90 let n_qubits = (state_size as f64).log2().round() as usize;
92 if 1 << n_qubits != state_size {
93 return Err(QuantRS2Error::InvalidInput(
94 "State size must be a power of 2".to_string(),
95 ));
96 }
97
98 Ok(Self {
99 states,
100 n_qubits,
101 config,
102 })
103 }
104
105 pub fn batch_size(&self) -> usize {
107 self.states.nrows()
108 }
109
110 pub fn get_state(&self, index: usize) -> QuantRS2Result<Array1<Complex64>> {
112 if index >= self.batch_size() {
113 return Err(QuantRS2Error::InvalidInput(format!(
114 "Index {} out of bounds for batch size {}",
115 index,
116 self.batch_size()
117 )));
118 }
119
120 Ok(self.states.row(index).to_owned())
121 }
122
123 pub fn set_state(&mut self, index: usize, state: &Array1<Complex64>) -> QuantRS2Result<()> {
125 if index >= self.batch_size() {
126 return Err(QuantRS2Error::InvalidInput(format!(
127 "Index {} out of bounds for batch size {}",
128 index,
129 self.batch_size()
130 )));
131 }
132
133 if state.len() != self.states.ncols() {
134 return Err(QuantRS2Error::InvalidInput(format!(
135 "State size {} doesn't match expected {}",
136 state.len(),
137 self.states.ncols()
138 )));
139 }
140
141 self.states.row_mut(index).assign(state);
142 Ok(())
143 }
144}
145
146#[derive(Debug, Clone)]
148pub struct BatchExecutionResult {
149 pub final_states: Array2<Complex64>,
151 pub execution_time_ms: f64,
153 pub gates_applied: usize,
155 pub used_gpu: bool,
157}
158
159#[derive(Debug, Clone)]
161pub struct BatchMeasurementResult {
162 pub outcomes: Array2<u8>,
165 pub probabilities: Array2<f64>,
168 pub post_measurement_states: Option<Array2<Complex64>>,
170}
171
172pub trait BatchGateOp: GateOp {
174 fn apply_batch(
176 &self,
177 batch: &mut BatchStateVector,
178 target_qubits: &[QubitId],
179 ) -> QuantRS2Result<()>;
180
181 fn has_batch_optimization(&self) -> bool {
183 true
184 }
185}
186
187pub fn create_batch<I>(states: I, config: BatchConfig) -> QuantRS2Result<BatchStateVector>
189where
190 I: IntoIterator<Item = Array1<Complex64>>,
191{
192 let states_vec: Vec<_> = states.into_iter().collect();
193 if states_vec.is_empty() {
194 return Err(QuantRS2Error::InvalidInput(
195 "Cannot create empty batch".to_string(),
196 ));
197 }
198
199 let state_size = states_vec[0].len();
200 let batch_size = states_vec.len();
201
202 for (i, state) in states_vec.iter().enumerate() {
204 if state.len() != state_size {
205 return Err(QuantRS2Error::InvalidInput(format!(
206 "State {} has size {}, expected {}",
207 i,
208 state.len(),
209 state_size
210 )));
211 }
212 }
213
214 let mut batch_array = Array2::zeros((batch_size, state_size));
216 for (i, state) in states_vec.iter().enumerate() {
217 batch_array.row_mut(i).assign(state);
218 }
219
220 BatchStateVector::from_states(batch_array, config)
221}
222
223pub fn split_batch(batch: &BatchStateVector, chunk_size: usize) -> Vec<BatchStateVector> {
225 let mut chunks = Vec::new();
226 let batch_size = batch.batch_size();
227
228 for start in (0..batch_size).step_by(chunk_size) {
229 let end = (start + chunk_size).min(batch_size);
230 let chunk_states = batch.states.slice(ndarray::s![start..end, ..]).to_owned();
231
232 if let Ok(chunk) = BatchStateVector::from_states(chunk_states, batch.config.clone()) {
233 chunks.push(chunk);
234 }
235 }
236
237 chunks
238}
239
240pub fn merge_batches(
242 batches: Vec<BatchStateVector>,
243 config: BatchConfig,
244) -> QuantRS2Result<BatchStateVector> {
245 if batches.is_empty() {
246 return Err(QuantRS2Error::InvalidInput(
247 "Cannot merge empty batches".to_string(),
248 ));
249 }
250
251 let n_qubits = batches[0].n_qubits;
253 for (i, batch) in batches.iter().enumerate() {
254 if batch.n_qubits != n_qubits {
255 return Err(QuantRS2Error::InvalidInput(format!(
256 "Batch {} has {} qubits, expected {}",
257 i, batch.n_qubits, n_qubits
258 )));
259 }
260 }
261
262 let mut all_states = Vec::new();
264 for batch in batches {
265 for i in 0..batch.batch_size() {
266 all_states.push(batch.states.row(i).to_owned());
267 }
268 }
269
270 create_batch(all_states, config)
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a purely real amplitude.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    #[test]
    fn test_batch_creation() {
        let batch = BatchStateVector::new(10, 3, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 10);
        assert_eq!(batch.n_qubits, 3);
        assert_eq!(batch.states.ncols(), 8);

        // Every state must start with amplitude 1 at index 0 and zeros elsewhere.
        for idx in 0..batch.batch_size() {
            let state = batch.get_state(idx).unwrap();
            assert_eq!(state[0], re(1.0));
            assert!(state.iter().skip(1).all(|&amp| amp == re(0.0)));
        }
    }

    #[test]
    fn test_batch_from_states() {
        let states = Array2::from_shape_fn((5, 4), |(row, col)| {
            re(if col == row % 4 { 1.0 } else { 0.0 })
        });

        let batch = BatchStateVector::from_states(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 5);
        assert_eq!(batch.n_qubits, 2);
    }

    #[test]
    fn test_create_batch_helper() {
        let states = vec![
            Array1::from_vec(vec![re(1.0), re(0.0)]),
            Array1::from_vec(vec![re(0.0), re(1.0)]),
            Array1::from_vec(vec![re(0.707), re(0.707)]),
        ];

        let batch = create_batch(states, BatchConfig::default()).unwrap();
        assert_eq!(batch.batch_size(), 3);
        assert_eq!(batch.n_qubits, 1);
    }

    #[test]
    fn test_split_batch() {
        let batch = BatchStateVector::new(10, 2, BatchConfig::default()).unwrap();
        let sizes: Vec<usize> = split_batch(&batch, 3)
            .iter()
            .map(|chunk| chunk.batch_size())
            .collect();

        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_merge_batches() {
        let parts = vec![
            BatchStateVector::new(3, 2, BatchConfig::default()).unwrap(),
            BatchStateVector::new(2, 2, BatchConfig::default()).unwrap(),
        ];

        let merged = merge_batches(parts, BatchConfig::default()).unwrap();
        assert_eq!(merged.batch_size(), 5);
        assert_eq!(merged.n_qubits, 2);
    }
}