1use scirs2_core::Complex64;
7use std::cmp::min;
8
/// A contiguous run of elements whose length is fixed when it is created.
///
/// Every slot starts out as `T::default()`. Elements can be read and overwritten by index, but
/// no method here grows or shrinks the run.
#[derive(Clone, Debug)]
struct MemoryChunk<T> {
    /// Backing storage; its length equals the `capacity` passed to `MemoryChunk::new`.
    data: Vec<T>,
    /// The length requested at construction; recorded but not read by any method below.
    _capacity: usize,
}

impl<T: Clone + Default> MemoryChunk<T> {
    /// Allocates a chunk holding exactly `capacity` copies of `T::default()`.
    fn new(capacity: usize) -> Self {
        let mut storage = Vec::with_capacity(capacity);
        storage.resize(capacity, T::default());
        Self {
            _capacity: capacity,
            data: storage,
        }
    }

    /// Returns a shared reference to element `idx`, or `None` when `idx` is past the end.
    fn get(&self, idx: usize) -> Option<&T> {
        self.as_slice().get(idx)
    }

    /// Returns a mutable reference to element `idx`, or `None` when `idx` is past the end.
    fn get_mut(&mut self, idx: usize) -> Option<&mut T> {
        self._as_mut_slice().get_mut(idx)
    }

    /// Borrows all elements as a read-only slice.
    fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }

    /// Borrows all elements as a mutable slice.
    fn _as_mut_slice(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }
}
42
43use crate::utils::flip_bit;
44
45const DEFAULT_CHUNK_SIZE: usize = 1 << 20; pub struct ChunkedStateVector {
50 chunks: Vec<MemoryChunk<Complex64>>,
52 num_qubits: usize,
54 chunk_size: usize,
56 dimension: usize,
58}
59
60impl ChunkedStateVector {
61 pub fn new(num_qubits: usize) -> Self {
63 let dimension = 1 << num_qubits;
64 let chunk_size = min(DEFAULT_CHUNK_SIZE, dimension);
65 let num_chunks = dimension.div_ceil(chunk_size);
66
67 let mut chunks = Vec::with_capacity(num_chunks);
69 for i in 0..num_chunks {
70 let this_chunk_size = if i == num_chunks - 1 && dimension % chunk_size != 0 {
71 dimension % chunk_size
72 } else {
73 chunk_size
74 };
75
76 let mut chunk = MemoryChunk::new(this_chunk_size);
77 if i == 0 {
78 if let Some(first) = chunk.get_mut(0) {
80 *first = Complex64::new(1.0, 0.0);
81 }
82 }
83 chunks.push(chunk);
84 }
85
86 Self {
87 chunks,
88 num_qubits,
89 chunk_size,
90 dimension,
91 }
92 }
93
94 pub fn num_qubits(&self) -> usize {
96 self.num_qubits
97 }
98
99 pub fn dimension(&self) -> usize {
101 self.dimension
102 }
103
104 pub fn get_amplitude(&self, idx: usize) -> Complex64 {
106 let chunk_idx = idx / self.chunk_size;
107 let local_idx = idx % self.chunk_size;
108
109 if chunk_idx >= self.chunks.len() {
110 return Complex64::new(0.0, 0.0);
111 }
112
113 match self.chunks[chunk_idx].get(local_idx) {
114 Some(val) => *val,
115 None => Complex64::new(0.0, 0.0),
116 }
117 }
118
119 pub fn as_vec(&self) -> Vec<Complex64> {
122 let mut result = Vec::with_capacity(self.dimension);
123 for chunk in &self.chunks {
124 result.extend_from_slice(chunk.as_slice());
125 }
126 result
127 }
128
129 pub fn apply_single_qubit_gate(&mut self, matrix: &[Complex64], target: usize) {
136 if target >= self.num_qubits {
137 panic!("Target qubit index out of range");
138 }
139
140 let old_chunks = self.chunks.clone();
142
143 for chunk in &mut self.chunks {
145 for idx in 0..chunk.as_slice().len() {
146 if let Some(val) = chunk.get_mut(idx) {
147 *val = Complex64::new(0.0, 0.0);
148 }
149 }
150 }
151
152 for (chunk_idx, chunk) in old_chunks.iter().enumerate() {
154 let base_idx = chunk_idx * self.chunk_size;
155
156 for (local_idx, &) in chunk.as_slice().iter().enumerate() {
158 let global_idx = base_idx + local_idx;
159 if global_idx >= self.dimension {
160 break;
161 }
162
163 if amp == Complex64::new(0.0, 0.0) {
165 continue;
166 }
167
168 let bit_val = (global_idx >> target) & 1;
169
170 let paired_global_idx = flip_bit(global_idx, target);
172 let paired_chunk_idx = paired_global_idx / self.chunk_size;
173 let paired_local_idx = paired_global_idx % self.chunk_size;
174
175 let paired_amp = if paired_chunk_idx < old_chunks.len() {
177 if let Some(val) = old_chunks[paired_chunk_idx].get(paired_local_idx) {
178 *val
179 } else {
180 Complex64::new(0.0, 0.0)
181 }
182 } else {
183 Complex64::new(0.0, 0.0)
184 };
185
186 let new_amp0 = matrix[0] * amp + matrix[1] * paired_amp;
188 let new_amp1 = matrix[2] * amp + matrix[3] * paired_amp;
189
190 if bit_val == 0 {
192 if let Some(val) = self.chunks[chunk_idx].get_mut(local_idx) {
194 *val += new_amp0;
195 }
196
197 if paired_chunk_idx < self.chunks.len() {
198 if let Some(val) = self.chunks[paired_chunk_idx].get_mut(paired_local_idx) {
199 *val += new_amp1;
200 }
201 }
202 }
203 }
204 }
205 }
206
207 pub fn apply_cnot(&mut self, control: usize, target: usize) {
214 if control >= self.num_qubits || target >= self.num_qubits {
215 panic!("Qubit indices out of range");
216 }
217
218 if control == target {
219 panic!("Control and target qubits must be different");
220 }
221
222 let mut new_chunks = Vec::with_capacity(self.chunks.len());
227 for chunk in &self.chunks {
228 new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
229 }
230
231 for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
233 let base_idx = chunk_idx * self.chunk_size;
234
235 for (local_idx, &) in chunk.as_slice().iter().enumerate() {
237 let global_idx = base_idx + local_idx;
238 if global_idx >= self.dimension {
239 break;
240 }
241
242 let control_bit = (global_idx >> control) & 1;
243
244 if control_bit == 0 {
245 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
247 *val = amp;
248 }
249 } else {
250 let flipped_idx = flip_bit(global_idx, target);
252 let flipped_chunk_idx = flipped_idx / self.chunk_size;
253 let flipped_local_idx = flipped_idx % self.chunk_size;
254
255 let flipped_amp = self.get_amplitude(flipped_idx);
257
258 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
260 *val = flipped_amp;
261 }
262
263 if flipped_chunk_idx < self.chunks.len() {
265 if let Some(val) = new_chunks[flipped_chunk_idx].get_mut(flipped_local_idx)
266 {
267 *val = amp;
268 }
269 }
270 }
271 }
272 }
273
274 self.chunks = new_chunks;
276 }
277
278 pub fn apply_two_qubit_gate(&mut self, matrix: &[Complex64], qubit1: usize, qubit2: usize) {
286 if qubit1 >= self.num_qubits || qubit2 >= self.num_qubits {
287 panic!("Qubit indices out of range");
288 }
289
290 if qubit1 == qubit2 {
291 panic!("Qubit indices must be different");
292 }
293
294 let mut new_chunks = Vec::with_capacity(self.chunks.len());
296 for chunk in &self.chunks {
297 new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
298 }
299
300 for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
302 let base_idx = chunk_idx * self.chunk_size;
303
304 for (local_idx, &_) in chunk.as_slice().iter().enumerate() {
306 let global_idx = base_idx + local_idx;
307 if global_idx >= self.dimension {
308 break;
309 }
310
311 let bit1 = (global_idx >> qubit1) & 1;
313 let bit2 = (global_idx >> qubit2) & 1;
314
315 let bits00 = global_idx & !(1 << qubit1) & !(1 << qubit2);
317 let bits01 = bits00 | (1 << qubit2);
318 let bits10 = bits00 | (1 << qubit1);
319 let bits11 = bits10 | (1 << qubit2);
320
321 let amp00 = self.get_amplitude(bits00);
323 let amp01 = self.get_amplitude(bits01);
324 let amp10 = self.get_amplitude(bits10);
325 let amp11 = self.get_amplitude(bits11);
326
327 let subspace_idx = (bit1 << 1) | bit2;
329 let mut new_amp = Complex64::new(0.0, 0.0);
330
331 new_amp += matrix[subspace_idx * 4] * amp00;
333 new_amp += matrix[subspace_idx * 4 + 1] * amp01;
334 new_amp += matrix[subspace_idx * 4 + 2] * amp10;
335 new_amp += matrix[subspace_idx * 4 + 3] * amp11;
336
337 if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
339 *val = new_amp;
340 }
341 }
342 }
343
344 self.chunks = new_chunks;
346 }
347
348 pub fn probability(&self, bit_string: &[u8]) -> f64 {
350 if bit_string.len() != self.num_qubits {
351 panic!("Bit string length must match number of qubits");
352 }
353
354 let mut idx = 0;
356 for (i, &bit) in bit_string.iter().enumerate() {
357 if bit != 0 {
358 idx |= 1 << i;
359 }
360 }
361
362 self.get_amplitude(idx).norm_sqr()
364 }
365
366 pub fn probabilities(&self) -> Vec<f64> {
369 self.chunks
370 .iter()
371 .flat_map(|chunk| chunk.as_slice().iter().map(|a| a.norm_sqr()))
372 .collect()
373 }
374
375 pub fn probability_range(&self, start_idx: usize, end_idx: usize) -> Vec<f64> {
378 let real_end = std::cmp::min(end_idx, self.dimension);
379
380 (start_idx..real_end)
381 .map(|idx| self.get_amplitude(idx).norm_sqr())
382 .collect()
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Absolute tolerance used when comparing floating-point amplitudes and probabilities.
    const EPS: f64 = 1e-10;

    /// Hadamard gate in the row-major `[m00, m01, m10, m11]` layout of `apply_single_qubit_gate`.
    fn hadamard() -> [Complex64; 4] {
        [
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(-FRAC_1_SQRT_2, 0.0),
        ]
    }

    /// CNOT in the 4×4 row-major layout of `apply_two_qubit_gate`, with `qubit1` as control.
    #[rustfmt::skip]
    fn cnot_matrix() -> [Complex64; 16] {
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        [
            one,  zero, zero, zero,
            zero, one,  zero, zero,
            zero, zero, zero, one,
            zero, zero, one,  zero,
        ]
    }

    /// Asserts that each amplitude of `sv` is within `EPS` of the real value in `expected`.
    fn assert_amplitudes(sv: &ChunkedStateVector, expected: &[f64]) {
        assert_eq!(sv.dimension(), expected.len());
        for (idx, &re) in expected.iter().enumerate() {
            let actual = sv.get_amplitude(idx);
            assert!(
                (actual - Complex64::new(re, 0.0)).norm() < EPS,
                "amplitude[{idx}] = {actual:?}, expected {re}"
            );
        }
    }

    #[test]
    fn test_chunked_state_vector_init() {
        let sv = ChunkedStateVector::new(2);
        assert_eq!(sv.num_qubits(), 2);
        assert_eq!(sv.dimension(), 4);

        assert_eq!(sv.get_amplitude(0), Complex64::new(1.0, 0.0));
        for idx in 1..4 {
            assert_eq!(sv.get_amplitude(idx), Complex64::new(0.0, 0.0));
        }
    }

    #[test]
    fn test_hadamard_gate_chunked() {
        let h = hadamard();
        let mut sv = ChunkedStateVector::new(2);

        sv.apply_single_qubit_gate(&h, 1);
        assert_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, FRAC_1_SQRT_2, 0.0]);

        sv.apply_single_qubit_gate(&h, 0);
        assert_amplitudes(&sv, &[0.5, 0.5, 0.5, 0.5]);
    }

    #[test]
    fn test_cnot_gate_chunked() {
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_cnot(0, 1);

        assert_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, 0.0, FRAC_1_SQRT_2]);
    }

    #[test]
    fn test_two_qubit_gate_matches_cnot() {
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_two_qubit_gate(&cnot_matrix(), 0, 1);

        assert_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, 0.0, FRAC_1_SQRT_2]);
    }

    #[test]
    fn test_probabilities_bell_state() {
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_cnot(0, 1);

        assert!((sv.probability(&[0, 0]) - 0.5).abs() < EPS);
        assert!((sv.probability(&[1, 1]) - 0.5).abs() < EPS);
        assert!(sv.probability(&[1, 0]).abs() < EPS);

        let probs = sv.probabilities();
        assert_eq!(probs.len(), 4);
        assert!((probs.iter().sum::<f64>() - 1.0).abs() < EPS);

        // The end index is clamped to the dimension; an inverted range is empty.
        assert_eq!(sv.probability_range(1, 10).len(), 3);
        assert!(sv.probability_range(3, 1).is_empty());
    }
}