1use scirs2_core::Complex64;
7use std::cmp::min;
8
/// A fixed-length run of values that serves as one storage segment.
///
/// Every slot is filled with `T::default()` when the chunk is created, and no
/// method changes the number of slots afterwards.
#[derive(Clone, Debug)]
struct MemoryChunk<T> {
    /// Backing storage for the slots.
    data: Vec<T>,
    /// Slot count requested at construction (recorded but not read).
    _capacity: usize,
}

impl<T: Clone + Default> MemoryChunk<T> {
    /// Builds a chunk with `capacity` slots, each holding `T::default()`.
    fn new(capacity: usize) -> Self {
        let mut data = Vec::new();
        data.resize(capacity, T::default());
        Self {
            _capacity: capacity,
            data,
        }
    }

    /// Returns a shared reference to slot `idx`, or `None` when out of bounds.
    fn get(&self, idx: usize) -> Option<&T> {
        self.as_slice().get(idx)
    }

    /// Returns a mutable reference to slot `idx`, or `None` when out of bounds.
    fn get_mut(&mut self, idx: usize) -> Option<&mut T> {
        self.data.as_mut_slice().get_mut(idx)
    }

    /// Borrows all slots as an immutable slice.
    fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }

    /// Borrows all slots as a mutable slice.
    fn _as_mut_slice(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }
}
42
43use crate::utils::flip_bit;
44
45const DEFAULT_CHUNK_SIZE: usize = 1 << 20; pub struct ChunkedStateVector {
50 chunks: Vec<MemoryChunk<Complex64>>,
52 num_qubits: usize,
54 chunk_size: usize,
56 dimension: usize,
58}
59
60impl ChunkedStateVector {
61 #[must_use]
63 pub fn new(num_qubits: usize) -> Self {
64 let dimension = 1 << num_qubits;
65 let chunk_size = min(DEFAULT_CHUNK_SIZE, dimension);
66 let num_chunks = dimension.div_ceil(chunk_size);
67
68 let mut chunks = Vec::with_capacity(num_chunks);
70 for i in 0..num_chunks {
71 let this_chunk_size = if i == num_chunks - 1 && dimension % chunk_size != 0 {
72 dimension % chunk_size
73 } else {
74 chunk_size
75 };
76
77 let mut chunk = MemoryChunk::new(this_chunk_size);
78 if i == 0 {
79 if let Some(first) = chunk.get_mut(0) {
81 *first = Complex64::new(1.0, 0.0);
82 }
83 }
84 chunks.push(chunk);
85 }
86
87 Self {
88 chunks,
89 num_qubits,
90 chunk_size,
91 dimension,
92 }
93 }
94
95 #[must_use]
97 pub const fn num_qubits(&self) -> usize {
98 self.num_qubits
99 }
100
101 #[must_use]
103 pub const fn dimension(&self) -> usize {
104 self.dimension
105 }
106
107 #[must_use]
109 pub fn get_amplitude(&self, idx: usize) -> Complex64 {
110 let chunk_idx = idx / self.chunk_size;
111 let local_idx = idx % self.chunk_size;
112
113 if chunk_idx >= self.chunks.len() {
114 return Complex64::new(0.0, 0.0);
115 }
116
117 match self.chunks[chunk_idx].get(local_idx) {
118 Some(val) => *val,
119 None => Complex64::new(0.0, 0.0),
120 }
121 }
122
123 #[must_use]
126 pub fn as_vec(&self) -> Vec<Complex64> {
127 let mut result = Vec::with_capacity(self.dimension);
128 for chunk in &self.chunks {
129 result.extend_from_slice(chunk.as_slice());
130 }
131 result
132 }
133
134 pub fn apply_single_qubit_gate(&mut self, matrix: &[Complex64], target: usize) {
141 assert!(
142 (target < self.num_qubits),
143 "Target qubit index out of range"
144 );
145
146 let old_chunks = self.chunks.clone();
148
149 for chunk in &mut self.chunks {
151 for idx in 0..chunk.as_slice().len() {
152 if let Some(val) = chunk.get_mut(idx) {
153 *val = Complex64::new(0.0, 0.0);
154 }
155 }
156 }
157
158 for (chunk_idx, chunk) in old_chunks.iter().enumerate() {
160 let base_idx = chunk_idx * self.chunk_size;
161
162 for (local_idx, &) in chunk.as_slice().iter().enumerate() {
164 let global_idx = base_idx + local_idx;
165 if global_idx >= self.dimension {
166 break;
167 }
168
169 if amp == Complex64::new(0.0, 0.0) {
171 continue;
172 }
173
174 let bit_val = (global_idx >> target) & 1;
175
176 let paired_global_idx = flip_bit(global_idx, target);
178 let paired_chunk_idx = paired_global_idx / self.chunk_size;
179 let paired_local_idx = paired_global_idx % self.chunk_size;
180
181 let paired_amp = if paired_chunk_idx < old_chunks.len() {
183 if let Some(val) = old_chunks[paired_chunk_idx].get(paired_local_idx) {
184 *val
185 } else {
186 Complex64::new(0.0, 0.0)
187 }
188 } else {
189 Complex64::new(0.0, 0.0)
190 };
191
192 let new_amp0 = matrix[0] * amp + matrix[1] * paired_amp;
194 let new_amp1 = matrix[2] * amp + matrix[3] * paired_amp;
195
196 if bit_val == 0 {
198 if let Some(val) = self.chunks[chunk_idx].get_mut(local_idx) {
200 *val += new_amp0;
201 }
202
203 if paired_chunk_idx < self.chunks.len() {
204 if let Some(val) = self.chunks[paired_chunk_idx].get_mut(paired_local_idx) {
205 *val += new_amp1;
206 }
207 }
208 }
209 }
210 }
211 }
212
213 pub fn apply_cnot(&mut self, control: usize, target: usize) {
220 assert!(
221 !(control >= self.num_qubits || target >= self.num_qubits),
222 "Qubit indices out of range"
223 );
224
225 assert!(
226 (control != target),
227 "Control and target qubits must be different"
228 );
229
230 let mut new_chunks = Vec::with_capacity(self.chunks.len());
235 for chunk in &self.chunks {
236 new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
237 }
238
239 for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
241 let base_idx = chunk_idx * self.chunk_size;
242
243 for (local_idx, &) in chunk.as_slice().iter().enumerate() {
245 let global_idx = base_idx + local_idx;
246 if global_idx >= self.dimension {
247 break;
248 }
249
250 let control_bit = (global_idx >> control) & 1;
251
252 if control_bit == 0 {
253 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
255 *val = amp;
256 }
257 } else {
258 let flipped_idx = flip_bit(global_idx, target);
260 let flipped_chunk_idx = flipped_idx / self.chunk_size;
261 let flipped_local_idx = flipped_idx % self.chunk_size;
262
263 let flipped_amp = self.get_amplitude(flipped_idx);
265
266 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
268 *val = flipped_amp;
269 }
270
271 if flipped_chunk_idx < self.chunks.len() {
273 if let Some(val) = new_chunks[flipped_chunk_idx].get_mut(flipped_local_idx)
274 {
275 *val = amp;
276 }
277 }
278 }
279 }
280 }
281
282 self.chunks = new_chunks;
284 }
285
286 pub fn apply_two_qubit_gate(&mut self, matrix: &[Complex64], qubit1: usize, qubit2: usize) {
294 assert!(
295 !(qubit1 >= self.num_qubits || qubit2 >= self.num_qubits),
296 "Qubit indices out of range"
297 );
298
299 assert!((qubit1 != qubit2), "Qubit indices must be different");
300
301 let mut new_chunks = Vec::with_capacity(self.chunks.len());
303 for chunk in &self.chunks {
304 new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
305 }
306
307 for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
309 let base_idx = chunk_idx * self.chunk_size;
310
311 for (local_idx, &_) in chunk.as_slice().iter().enumerate() {
313 let global_idx = base_idx + local_idx;
314 if global_idx >= self.dimension {
315 break;
316 }
317
318 let bit1 = (global_idx >> qubit1) & 1;
320 let bit2 = (global_idx >> qubit2) & 1;
321
322 let bits00 = global_idx & !(1 << qubit1) & !(1 << qubit2);
324 let bits01 = bits00 | (1 << qubit2);
325 let bits10 = bits00 | (1 << qubit1);
326 let bits11 = bits10 | (1 << qubit2);
327
328 let amp00 = self.get_amplitude(bits00);
330 let amp01 = self.get_amplitude(bits01);
331 let amp10 = self.get_amplitude(bits10);
332 let amp11 = self.get_amplitude(bits11);
333
334 let subspace_idx = (bit1 << 1) | bit2;
336 let mut new_amp = Complex64::new(0.0, 0.0);
337
338 new_amp += matrix[subspace_idx * 4] * amp00;
340 new_amp += matrix[subspace_idx * 4 + 1] * amp01;
341 new_amp += matrix[subspace_idx * 4 + 2] * amp10;
342 new_amp += matrix[subspace_idx * 4 + 3] * amp11;
343
344 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
346 *val = new_amp;
347 }
348 }
349 }
350
351 self.chunks = new_chunks;
353 }
354
355 #[must_use]
357 pub fn probability(&self, bit_string: &[u8]) -> f64 {
358 assert!(
359 (bit_string.len() == self.num_qubits),
360 "Bit string length must match number of qubits"
361 );
362
363 let mut idx = 0;
365 for (i, &bit) in bit_string.iter().enumerate() {
366 if bit != 0 {
367 idx |= 1 << i;
368 }
369 }
370
371 self.get_amplitude(idx).norm_sqr()
373 }
374
375 #[must_use]
378 pub fn probabilities(&self) -> Vec<f64> {
379 self.chunks
380 .iter()
381 .flat_map(|chunk| chunk.as_slice().iter().map(scirs2_core::Complex::norm_sqr))
382 .collect()
383 }
384
385 #[must_use]
388 pub fn probability_range(&self, start_idx: usize, end_idx: usize) -> Vec<f64> {
389 let real_end = std::cmp::min(end_idx, self.dimension);
390
391 (start_idx..real_end)
392 .map(|idx| self.get_amplitude(idx).norm_sqr())
393 .collect()
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Tolerance used when comparing floating-point results.
    const EPS: f64 = 1e-10;

    /// Shorthand for a purely real complex number.
    fn real(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Hadamard gate in row-major order.
    fn hadamard() -> [Complex64; 4] {
        [
            real(FRAC_1_SQRT_2),
            real(FRAC_1_SQRT_2),
            real(FRAC_1_SQRT_2),
            real(-FRAC_1_SQRT_2),
        ]
    }

    /// Asserts that every amplitude of `sv` matches `expected` within `EPS`.
    fn assert_amplitudes(sv: &ChunkedStateVector, expected: &[Complex64]) {
        assert_eq!(sv.dimension(), expected.len());
        for (idx, &want) in expected.iter().enumerate() {
            let got = sv.get_amplitude(idx);
            assert!(
                (got - want).norm() < EPS,
                "amplitude[{idx}] = {got:?}, expected {want:?}"
            );
        }
    }

    #[test]
    fn test_chunked_state_vector_init() {
        let sv = ChunkedStateVector::new(2);
        assert_eq!(sv.num_qubits(), 2);
        assert_eq!(sv.dimension(), 4);

        assert_eq!(sv.get_amplitude(0), Complex64::new(1.0, 0.0));
        assert_eq!(sv.get_amplitude(1), Complex64::new(0.0, 0.0));
        assert_eq!(sv.get_amplitude(2), Complex64::new(0.0, 0.0));
        assert_eq!(sv.get_amplitude(3), Complex64::new(0.0, 0.0));
    }

    #[test]
    fn test_hadamard_gate_chunked() {
        let h = hadamard();
        let mut sv = ChunkedStateVector::new(2);

        sv.apply_single_qubit_gate(&h, 1);
        assert_amplitudes(
            &sv,
            &[real(FRAC_1_SQRT_2), real(0.0), real(FRAC_1_SQRT_2), real(0.0)],
        );

        sv.apply_single_qubit_gate(&h, 0);
        assert_amplitudes(&sv, &[real(0.5), real(0.5), real(0.5), real(0.5)]);
    }

    #[test]
    fn test_cnot_gate_chunked() {
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_cnot(0, 1);

        assert_amplitudes(
            &sv,
            &[real(FRAC_1_SQRT_2), real(0.0), real(0.0), real(FRAC_1_SQRT_2)],
        );
    }

    #[test]
    fn test_two_qubit_gate_cnot_matrix() {
        let (o, l) = (real(0.0), real(1.0));
        // CNOT with qubit1 as control, in the |qubit1 qubit2> basis.
        let cnot = [l, o, o, o, o, l, o, o, o, o, o, l, o, o, l, o];

        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_two_qubit_gate(&cnot, 0, 1);

        assert_amplitudes(
            &sv,
            &[real(FRAC_1_SQRT_2), real(0.0), real(0.0), real(FRAC_1_SQRT_2)],
        );
    }

    #[test]
    fn test_two_qubit_gate_swap_matrix() {
        let (o, l) = (real(0.0), real(1.0));
        let swap = [l, o, o, o, o, o, l, o, o, l, o, o, o, o, o, l];

        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_two_qubit_gate(&swap, 0, 1);

        assert_amplitudes(
            &sv,
            &[real(FRAC_1_SQRT_2), real(0.0), real(FRAC_1_SQRT_2), real(0.0)],
        );
    }

    #[test]
    fn test_probabilities() {
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);

        assert!((sv.probability(&[1, 0]) - 0.5).abs() < EPS);
        assert!(sv.probability(&[0, 1]).abs() < EPS);

        let probs = sv.probabilities();
        assert_eq!(probs.len(), 4);
        assert!((probs.iter().sum::<f64>() - 1.0).abs() < EPS);

        let range = sv.probability_range(1, 10);
        assert_eq!(range.len(), 3);
        assert!((range[0] - 0.5).abs() < EPS);
        assert!(sv.probability_range(3, 2).is_empty());
    }
}