1use scirs2_core::Complex64;
7use std::cmp::min;
8
/// One fixed-length slab of elements; `ChunkedStateVector` stores its
/// amplitudes as a list of these.
#[derive(Clone, Debug)]
struct MemoryChunk<T> {
    /// Element storage, filled with `T::default()` at construction. No method
    /// in this file changes its length.
    data: Vec<T>,
    /// Length requested at construction. It is stored but never read in this file.
    _capacity: usize,
}
16
17impl<T: Clone + Default> MemoryChunk<T> {
18 fn new(capacity: usize) -> Self {
19 Self {
20 data: vec![T::default(); capacity],
21 _capacity: capacity,
22 }
23 }
24
25 fn get(&self, idx: usize) -> Option<&T> {
26 self.data.get(idx)
27 }
28
29 fn get_mut(&mut self, idx: usize) -> Option<&mut T> {
30 self.data.get_mut(idx)
31 }
32
33 fn as_slice(&self) -> &[T] {
34 &self.data
35 }
36
37 fn _as_mut_slice(&mut self) -> &mut [T] {
39 &mut self.data
40 }
41}
42
43use crate::utils::flip_bit;
44
45const DEFAULT_CHUNK_SIZE: usize = 1 << 20; pub struct ChunkedStateVector {
50 chunks: Vec<MemoryChunk<Complex64>>,
52 num_qubits: usize,
54 chunk_size: usize,
56 dimension: usize,
58}
59
60impl ChunkedStateVector {
61 pub fn new(num_qubits: usize) -> Self {
63 let dimension = 1 << num_qubits;
64 let chunk_size = min(DEFAULT_CHUNK_SIZE, dimension);
65 let num_chunks = dimension.div_ceil(chunk_size);
66
67 let mut chunks = Vec::with_capacity(num_chunks);
69 for i in 0..num_chunks {
70 let this_chunk_size = if i == num_chunks - 1 && dimension % chunk_size != 0 {
71 dimension % chunk_size
72 } else {
73 chunk_size
74 };
75
76 let mut chunk = MemoryChunk::new(this_chunk_size);
77 if i == 0 {
78 if let Some(first) = chunk.get_mut(0) {
80 *first = Complex64::new(1.0, 0.0);
81 }
82 }
83 chunks.push(chunk);
84 }
85
86 Self {
87 chunks,
88 num_qubits,
89 chunk_size,
90 dimension,
91 }
92 }
93
94 pub const fn num_qubits(&self) -> usize {
96 self.num_qubits
97 }
98
99 pub const fn dimension(&self) -> usize {
101 self.dimension
102 }
103
104 pub fn get_amplitude(&self, idx: usize) -> Complex64 {
106 let chunk_idx = idx / self.chunk_size;
107 let local_idx = idx % self.chunk_size;
108
109 if chunk_idx >= self.chunks.len() {
110 return Complex64::new(0.0, 0.0);
111 }
112
113 match self.chunks[chunk_idx].get(local_idx) {
114 Some(val) => *val,
115 None => Complex64::new(0.0, 0.0),
116 }
117 }
118
119 pub fn as_vec(&self) -> Vec<Complex64> {
122 let mut result = Vec::with_capacity(self.dimension);
123 for chunk in &self.chunks {
124 result.extend_from_slice(chunk.as_slice());
125 }
126 result
127 }
128
129 pub fn apply_single_qubit_gate(&mut self, matrix: &[Complex64], target: usize) {
136 assert!(
137 (target < self.num_qubits),
138 "Target qubit index out of range"
139 );
140
141 let old_chunks = self.chunks.clone();
143
144 for chunk in &mut self.chunks {
146 for idx in 0..chunk.as_slice().len() {
147 if let Some(val) = chunk.get_mut(idx) {
148 *val = Complex64::new(0.0, 0.0);
149 }
150 }
151 }
152
153 for (chunk_idx, chunk) in old_chunks.iter().enumerate() {
155 let base_idx = chunk_idx * self.chunk_size;
156
157 for (local_idx, &) in chunk.as_slice().iter().enumerate() {
159 let global_idx = base_idx + local_idx;
160 if global_idx >= self.dimension {
161 break;
162 }
163
164 if amp == Complex64::new(0.0, 0.0) {
166 continue;
167 }
168
169 let bit_val = (global_idx >> target) & 1;
170
171 let paired_global_idx = flip_bit(global_idx, target);
173 let paired_chunk_idx = paired_global_idx / self.chunk_size;
174 let paired_local_idx = paired_global_idx % self.chunk_size;
175
176 let paired_amp = if paired_chunk_idx < old_chunks.len() {
178 if let Some(val) = old_chunks[paired_chunk_idx].get(paired_local_idx) {
179 *val
180 } else {
181 Complex64::new(0.0, 0.0)
182 }
183 } else {
184 Complex64::new(0.0, 0.0)
185 };
186
187 let new_amp0 = matrix[0] * amp + matrix[1] * paired_amp;
189 let new_amp1 = matrix[2] * amp + matrix[3] * paired_amp;
190
191 if bit_val == 0 {
193 if let Some(val) = self.chunks[chunk_idx].get_mut(local_idx) {
195 *val += new_amp0;
196 }
197
198 if paired_chunk_idx < self.chunks.len() {
199 if let Some(val) = self.chunks[paired_chunk_idx].get_mut(paired_local_idx) {
200 *val += new_amp1;
201 }
202 }
203 }
204 }
205 }
206 }
207
208 pub fn apply_cnot(&mut self, control: usize, target: usize) {
215 assert!(
216 !(control >= self.num_qubits || target >= self.num_qubits),
217 "Qubit indices out of range"
218 );
219
220 assert!(
221 (control != target),
222 "Control and target qubits must be different"
223 );
224
225 let mut new_chunks = Vec::with_capacity(self.chunks.len());
230 for chunk in &self.chunks {
231 new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
232 }
233
234 for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
236 let base_idx = chunk_idx * self.chunk_size;
237
238 for (local_idx, &) in chunk.as_slice().iter().enumerate() {
240 let global_idx = base_idx + local_idx;
241 if global_idx >= self.dimension {
242 break;
243 }
244
245 let control_bit = (global_idx >> control) & 1;
246
247 if control_bit == 0 {
248 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
250 *val = amp;
251 }
252 } else {
253 let flipped_idx = flip_bit(global_idx, target);
255 let flipped_chunk_idx = flipped_idx / self.chunk_size;
256 let flipped_local_idx = flipped_idx % self.chunk_size;
257
258 let flipped_amp = self.get_amplitude(flipped_idx);
260
261 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
263 *val = flipped_amp;
264 }
265
266 if flipped_chunk_idx < self.chunks.len() {
268 if let Some(val) = new_chunks[flipped_chunk_idx].get_mut(flipped_local_idx)
269 {
270 *val = amp;
271 }
272 }
273 }
274 }
275 }
276
277 self.chunks = new_chunks;
279 }
280
281 pub fn apply_two_qubit_gate(&mut self, matrix: &[Complex64], qubit1: usize, qubit2: usize) {
289 assert!(
290 !(qubit1 >= self.num_qubits || qubit2 >= self.num_qubits),
291 "Qubit indices out of range"
292 );
293
294 assert!((qubit1 != qubit2), "Qubit indices must be different");
295
296 let mut new_chunks = Vec::with_capacity(self.chunks.len());
298 for chunk in &self.chunks {
299 new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
300 }
301
302 for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
304 let base_idx = chunk_idx * self.chunk_size;
305
306 for (local_idx, &_) in chunk.as_slice().iter().enumerate() {
308 let global_idx = base_idx + local_idx;
309 if global_idx >= self.dimension {
310 break;
311 }
312
313 let bit1 = (global_idx >> qubit1) & 1;
315 let bit2 = (global_idx >> qubit2) & 1;
316
317 let bits00 = global_idx & !(1 << qubit1) & !(1 << qubit2);
319 let bits01 = bits00 | (1 << qubit2);
320 let bits10 = bits00 | (1 << qubit1);
321 let bits11 = bits10 | (1 << qubit2);
322
323 let amp00 = self.get_amplitude(bits00);
325 let amp01 = self.get_amplitude(bits01);
326 let amp10 = self.get_amplitude(bits10);
327 let amp11 = self.get_amplitude(bits11);
328
329 let subspace_idx = (bit1 << 1) | bit2;
331 let mut new_amp = Complex64::new(0.0, 0.0);
332
333 new_amp += matrix[subspace_idx * 4] * amp00;
335 new_amp += matrix[subspace_idx * 4 + 1] * amp01;
336 new_amp += matrix[subspace_idx * 4 + 2] * amp10;
337 new_amp += matrix[subspace_idx * 4 + 3] * amp11;
338
339 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
341 *val = new_amp;
342 }
343 }
344 }
345
346 self.chunks = new_chunks;
348 }
349
350 pub fn probability(&self, bit_string: &[u8]) -> f64 {
352 assert!(
353 (bit_string.len() == self.num_qubits),
354 "Bit string length must match number of qubits"
355 );
356
357 let mut idx = 0;
359 for (i, &bit) in bit_string.iter().enumerate() {
360 if bit != 0 {
361 idx |= 1 << i;
362 }
363 }
364
365 self.get_amplitude(idx).norm_sqr()
367 }
368
369 pub fn probabilities(&self) -> Vec<f64> {
372 self.chunks
373 .iter()
374 .flat_map(|chunk| chunk.as_slice().iter().map(|a| a.norm_sqr()))
375 .collect()
376 }
377
378 pub fn probability_range(&self, start_idx: usize, end_idx: usize) -> Vec<f64> {
381 let real_end = std::cmp::min(end_idx, self.dimension);
382
383 (start_idx..real_end)
384 .map(|idx| self.get_amplitude(idx).norm_sqr())
385 .collect()
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Tolerance used when comparing floating-point amplitudes.
    const EPS: f64 = 1e-10;

    /// Row-major Hadamard matrix `[h00, h01, h10, h11]`.
    fn hadamard() -> [Complex64; 4] {
        [
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(-FRAC_1_SQRT_2, 0.0),
        ]
    }

    /// Asserts that `sv` holds exactly `expected` (element-wise within `EPS`),
    /// reporting the offending index on failure.
    fn assert_amplitudes(sv: &ChunkedStateVector, expected: &[Complex64]) {
        assert_eq!(sv.dimension(), expected.len());
        for (idx, &want) in expected.iter().enumerate() {
            let got = sv.get_amplitude(idx);
            assert!(
                (got - want).norm() < EPS,
                "amplitude[{idx}] = {got:?}, expected {want:?}"
            );
        }
    }

    #[test]
    fn test_chunked_state_vector_init() {
        let sv = ChunkedStateVector::new(2);
        assert_eq!(sv.num_qubits(), 2);
        assert_eq!(sv.dimension(), 4);

        assert_eq!(sv.get_amplitude(0), Complex64::new(1.0, 0.0));
        assert_eq!(sv.get_amplitude(1), Complex64::new(0.0, 0.0));
        assert_eq!(sv.get_amplitude(2), Complex64::new(0.0, 0.0));
        assert_eq!(sv.get_amplitude(3), Complex64::new(0.0, 0.0));
    }

    #[test]
    fn test_hadamard_gate_chunked() {
        let h = hadamard();
        let r = Complex64::new(FRAC_1_SQRT_2, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        let half = Complex64::new(0.5, 0.0);

        let mut sv = ChunkedStateVector::new(2);

        // H on qubit 1 spreads |00⟩ over indices 0 and 2.
        sv.apply_single_qubit_gate(&h, 1);
        assert_amplitudes(&sv, &[r, zero, r, zero]);

        // A second H on qubit 0 gives the uniform superposition.
        sv.apply_single_qubit_gate(&h, 0);
        assert_amplitudes(&sv, &[half, half, half, half]);
    }

    #[test]
    fn test_cnot_gate_chunked() {
        let r = Complex64::new(FRAC_1_SQRT_2, 0.0);
        let zero = Complex64::new(0.0, 0.0);

        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_cnot(0, 1);

        // Bell state (|00⟩ + |11⟩)/√2.
        assert_amplitudes(&sv, &[r, zero, zero, r]);
    }

    #[test]
    fn test_probabilities_of_uniform_superposition() {
        let h = hadamard();
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&h, 0);
        sv.apply_single_qubit_gate(&h, 1);

        let probs = sv.probabilities();
        assert_eq!(probs.len(), 4);
        assert!(probs.iter().all(|p| (p - 0.25).abs() < EPS));
        assert!((probs.iter().sum::<f64>() - 1.0).abs() < EPS);

        assert!((sv.probability(&[1, 0]) - 0.25).abs() < EPS);
        // The end index is clamped to the dimension.
        assert_eq!(sv.probability_range(1, 100).len(), 3);
    }
}