quantrs2_sim/
optimized_chunked.rs

1//! Optimized quantum state vector simulation using chunked memory processing
2//!
3//! This module provides a memory-efficient implementation for large qubit counts (30+)
4//! by processing the state vector in manageable chunks to reduce memory pressure.
5
6use scirs2_core::Complex64;
7use std::cmp::min;
8
// The scirs2 memory module is not available, so chunks are backed by a plain
// `Vec` for now. This is a placeholder for a future scirs2 integration.

/// Fixed-length buffer holding one chunk of elements.
#[derive(Clone, Debug)]
struct MemoryChunk<T> {
    /// Element storage; none of the methods below resize it.
    data: Vec<T>,
    /// Length requested at construction. Recorded but currently never read.
    _capacity: usize,
}

impl<T: Clone + Default> MemoryChunk<T> {
    /// Creates a chunk of `capacity` elements, every one set to `T::default()`.
    fn new(capacity: usize) -> Self {
        let mut data = Vec::with_capacity(capacity);
        data.resize(capacity, T::default());
        Self {
            data,
            _capacity: capacity,
        }
    }

    /// Returns the element at `idx`, or `None` when `idx` is out of bounds.
    fn get(&self, idx: usize) -> Option<&T> {
        self.as_slice().get(idx)
    }

    /// Mutable counterpart of `get`: `None` when `idx` is out of bounds.
    fn get_mut(&mut self, idx: usize) -> Option<&mut T> {
        self.data.as_mut_slice().get_mut(idx)
    }

    /// Borrows all elements of the chunk as a slice.
    fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }

    // Currently unused, hence the leading underscore in the name.
    fn _as_mut_slice(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }
}
42
43use crate::utils::flip_bit;
44
/// Maximum number of amplitudes stored in a single chunk.
///
/// 2^20 = 1,048,576 `Complex64` values; at 16 bytes each that is 16 MiB per chunk.
const DEFAULT_CHUNK_SIZE: usize = 1024 * 1024;
47
48/// Represents a quantum state vector that uses chunked memory for large qubit counts
49pub struct ChunkedStateVector {
50    /// The full state vector stored as multiple chunks
51    chunks: Vec<MemoryChunk<Complex64>>,
52    /// Number of qubits represented
53    num_qubits: usize,
54    /// Size of each chunk (number of complex numbers)
55    chunk_size: usize,
56    /// Total dimension of the state vector (2^num_qubits)
57    dimension: usize,
58}
59
60impl ChunkedStateVector {
61    /// Create a new chunked state vector for given number of qubits
62    pub fn new(num_qubits: usize) -> Self {
63        let dimension = 1 << num_qubits;
64        let chunk_size = min(DEFAULT_CHUNK_SIZE, dimension);
65        let num_chunks = dimension.div_ceil(chunk_size);
66
67        // Create empty chunks
68        let mut chunks = Vec::with_capacity(num_chunks);
69        for i in 0..num_chunks {
70            let this_chunk_size = if i == num_chunks - 1 && dimension % chunk_size != 0 {
71                dimension % chunk_size
72            } else {
73                chunk_size
74            };
75
76            let mut chunk = MemoryChunk::new(this_chunk_size);
77            if i == 0 {
78                // Initialize to |0...0>
79                if let Some(first) = chunk.get_mut(0) {
80                    *first = Complex64::new(1.0, 0.0);
81                }
82            }
83            chunks.push(chunk);
84        }
85
86        Self {
87            chunks,
88            num_qubits,
89            chunk_size,
90            dimension,
91        }
92    }
93
94    /// Get the number of qubits
95    pub const fn num_qubits(&self) -> usize {
96        self.num_qubits
97    }
98
99    /// Get the dimension of the state vector
100    pub const fn dimension(&self) -> usize {
101        self.dimension
102    }
103
104    /// Access a specific amplitude by global index
105    pub fn get_amplitude(&self, idx: usize) -> Complex64 {
106        let chunk_idx = idx / self.chunk_size;
107        let local_idx = idx % self.chunk_size;
108
109        if chunk_idx >= self.chunks.len() {
110            return Complex64::new(0.0, 0.0);
111        }
112
113        match self.chunks[chunk_idx].get(local_idx) {
114            Some(val) => *val,
115            None => Complex64::new(0.0, 0.0),
116        }
117    }
118
119    /// Get all amplitudes as a flattened vector (for testing and conversion)
120    /// Warning: For large qubit counts, this will use a lot of memory
121    pub fn as_vec(&self) -> Vec<Complex64> {
122        let mut result = Vec::with_capacity(self.dimension);
123        for chunk in &self.chunks {
124            result.extend_from_slice(chunk.as_slice());
125        }
126        result
127    }
128
129    /// Apply a single-qubit gate to the state vector using chunked processing
130    ///
131    /// # Arguments
132    ///
133    /// * `matrix` - The 2x2 matrix representation of the gate
134    /// * `target` - The target qubit index
135    pub fn apply_single_qubit_gate(&mut self, matrix: &[Complex64], target: usize) {
136        assert!(
137            (target < self.num_qubits),
138            "Target qubit index out of range"
139        );
140
141        // Copy current state as we need to read from old state while writing to new
142        let old_chunks = self.chunks.clone();
143
144        // Reset all values to zero
145        for chunk in &mut self.chunks {
146            for idx in 0..chunk.as_slice().len() {
147                if let Some(val) = chunk.get_mut(idx) {
148                    *val = Complex64::new(0.0, 0.0);
149                }
150            }
151        }
152
153        // Process each chunk - iterate through old chunks for reading
154        for (chunk_idx, chunk) in old_chunks.iter().enumerate() {
155            let base_idx = chunk_idx * self.chunk_size;
156
157            // Process each amplitude in this chunk
158            for (local_idx, &amp) in chunk.as_slice().iter().enumerate() {
159                let global_idx = base_idx + local_idx;
160                if global_idx >= self.dimension {
161                    break;
162                }
163
164                // Skip over zero amplitudes for efficiency
165                if amp == Complex64::new(0.0, 0.0) {
166                    continue;
167                }
168
169                let bit_val = (global_idx >> target) & 1;
170
171                // Find the paired index
172                let paired_global_idx = flip_bit(global_idx, target);
173                let paired_chunk_idx = paired_global_idx / self.chunk_size;
174                let paired_local_idx = paired_global_idx % self.chunk_size;
175
176                // Get the amplitude of the paired index from old state
177                let paired_amp = if paired_chunk_idx < old_chunks.len() {
178                    if let Some(val) = old_chunks[paired_chunk_idx].get(paired_local_idx) {
179                        *val
180                    } else {
181                        Complex64::new(0.0, 0.0)
182                    }
183                } else {
184                    Complex64::new(0.0, 0.0)
185                };
186
187                // Calculate new amplitudes
188                let new_amp0 = matrix[0] * amp + matrix[1] * paired_amp;
189                let new_amp1 = matrix[2] * amp + matrix[3] * paired_amp;
190
191                // Determine current chunk/idx from global index
192                if bit_val == 0 {
193                    // Update both indices in one go
194                    if let Some(val) = self.chunks[chunk_idx].get_mut(local_idx) {
195                        *val += new_amp0;
196                    }
197
198                    if paired_chunk_idx < self.chunks.len() {
199                        if let Some(val) = self.chunks[paired_chunk_idx].get_mut(paired_local_idx) {
200                            *val += new_amp1;
201                        }
202                    }
203                }
204            }
205        }
206    }
207
208    /// Apply a controlled-NOT gate to the state vector
209    ///
210    /// # Arguments
211    ///
212    /// * `control` - The control qubit index
213    /// * `target` - The target qubit index
214    pub fn apply_cnot(&mut self, control: usize, target: usize) {
215        assert!(
216            !(control >= self.num_qubits || target >= self.num_qubits),
217            "Qubit indices out of range"
218        );
219
220        assert!(
221            (control != target),
222            "Control and target qubits must be different"
223        );
224
225        // We're using standard qubit ordering where the target/control parameters
226        // are used directly with bit operations
227
228        // Create new chunks to hold the result
229        let mut new_chunks = Vec::with_capacity(self.chunks.len());
230        for chunk in &self.chunks {
231            new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
232        }
233
234        // Process each chunk in parallel
235        for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
236            let base_idx = chunk_idx * self.chunk_size;
237
238            // Process this chunk
239            for (local_idx, &amp) in chunk.as_slice().iter().enumerate() {
240                let global_idx = base_idx + local_idx;
241                if global_idx >= self.dimension {
242                    break;
243                }
244
245                let control_bit = (global_idx >> control) & 1;
246
247                if control_bit == 0 {
248                    // Control bit is 0: state remains unchanged
249                    if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
250                        *val = amp;
251                    }
252                } else {
253                    // Control bit is 1: flip the target bit
254                    let flipped_idx = flip_bit(global_idx, target);
255                    let flipped_chunk_idx = flipped_idx / self.chunk_size;
256                    let flipped_local_idx = flipped_idx % self.chunk_size;
257
258                    // Get the amplitude from the flipped position
259                    let flipped_amp = self.get_amplitude(flipped_idx);
260
261                    // Update the current position with the flipped amplitude
262                    if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
263                        *val = flipped_amp;
264                    }
265
266                    // Update the flipped position with the current amplitude
267                    if flipped_chunk_idx < self.chunks.len() {
268                        if let Some(val) = new_chunks[flipped_chunk_idx].get_mut(flipped_local_idx)
269                        {
270                            *val = amp;
271                        }
272                    }
273                }
274            }
275        }
276
277        // Update the state
278        self.chunks = new_chunks;
279    }
280
281    /// Apply a two-qubit gate to the state vector
282    ///
283    /// # Arguments
284    ///
285    /// * `matrix` - The 4x4 matrix representation of the gate
286    /// * `qubit1` - The first qubit index
287    /// * `qubit2` - The second qubit index
288    pub fn apply_two_qubit_gate(&mut self, matrix: &[Complex64], qubit1: usize, qubit2: usize) {
289        assert!(
290            !(qubit1 >= self.num_qubits || qubit2 >= self.num_qubits),
291            "Qubit indices out of range"
292        );
293
294        assert!((qubit1 != qubit2), "Qubit indices must be different");
295
296        // Create new chunks to hold the result
297        let mut new_chunks = Vec::with_capacity(self.chunks.len());
298        for chunk in &self.chunks {
299            new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
300        }
301
302        // Process each chunk
303        for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
304            let base_idx = chunk_idx * self.chunk_size;
305
306            // Process this chunk
307            for (local_idx, &_) in chunk.as_slice().iter().enumerate() {
308                let global_idx = base_idx + local_idx;
309                if global_idx >= self.dimension {
310                    break;
311                }
312
313                // Determine which basis state this corresponds to in the 2-qubit subspace
314                let bit1 = (global_idx >> qubit1) & 1;
315                let bit2 = (global_idx >> qubit2) & 1;
316
317                // Calculate the indices of all four basis states in the 2-qubit subspace
318                let bits00 = global_idx & !(1 << qubit1) & !(1 << qubit2);
319                let bits01 = bits00 | (1 << qubit2);
320                let bits10 = bits00 | (1 << qubit1);
321                let bits11 = bits10 | (1 << qubit2);
322
323                // Get the amplitudes for all basis states
324                let amp00 = self.get_amplitude(bits00);
325                let amp01 = self.get_amplitude(bits01);
326                let amp10 = self.get_amplitude(bits10);
327                let amp11 = self.get_amplitude(bits11);
328
329                // Determine which amplitude to update
330                let subspace_idx = (bit1 << 1) | bit2;
331                let mut new_amp = Complex64::new(0.0, 0.0);
332
333                // Apply the 4x4 matrix to compute the new amplitude
334                new_amp += matrix[subspace_idx * 4] * amp00;
335                new_amp += matrix[subspace_idx * 4 + 1] * amp01;
336                new_amp += matrix[subspace_idx * 4 + 2] * amp10;
337                new_amp += matrix[subspace_idx * 4 + 3] * amp11;
338
339                // Update the amplitude in the result
340                if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
341                    *val = new_amp;
342                }
343            }
344        }
345
346        // Update the state
347        self.chunks = new_chunks;
348    }
349
350    /// Calculate probability of measuring a specific bit string
351    pub fn probability(&self, bit_string: &[u8]) -> f64 {
352        assert!(
353            (bit_string.len() == self.num_qubits),
354            "Bit string length must match number of qubits"
355        );
356
357        // Convert bit string to index
358        let mut idx = 0;
359        for (i, &bit) in bit_string.iter().enumerate() {
360            if bit != 0 {
361                idx |= 1 << i;
362            }
363        }
364
365        // Return probability
366        self.get_amplitude(idx).norm_sqr()
367    }
368
369    /// Calculate probabilities for all basis states
370    /// Warning: For large qubit counts, this will use a lot of memory
371    pub fn probabilities(&self) -> Vec<f64> {
372        self.chunks
373            .iter()
374            .flat_map(|chunk| chunk.as_slice().iter().map(|a| a.norm_sqr()))
375            .collect()
376    }
377
378    /// Calculate the probability of a specified range of states
379    /// More memory efficient for large qubit counts
380    pub fn probability_range(&self, start_idx: usize, end_idx: usize) -> Vec<f64> {
381        let real_end = std::cmp::min(end_idx, self.dimension);
382
383        (start_idx..real_end)
384            .map(|idx| self.get_amplitude(idx).norm_sqr())
385            .collect()
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Absolute tolerance used when comparing amplitudes and probabilities.
    const EPS: f64 = 1e-10;

    /// Hadamard gate as a row-major 2x2 matrix `[m00, m01, m10, m11]`.
    fn hadamard() -> [Complex64; 4] {
        [
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(-FRAC_1_SQRT_2, 0.0),
        ]
    }

    /// Assert that amplitude `i` of `sv` equals `expected[i] + 0i` within `EPS`.
    fn assert_amplitudes(sv: &ChunkedStateVector, expected: &[f64]) {
        for (idx, &re) in expected.iter().enumerate() {
            let actual = sv.get_amplitude(idx);
            assert!(
                (actual - Complex64::new(re, 0.0)).norm() < EPS,
                "amplitude[{idx}] = {actual:?}, expected {re}"
            );
        }
    }

    #[test]
    fn test_chunked_state_vector_init() {
        let sv = ChunkedStateVector::new(2);
        assert_eq!(sv.num_qubits(), 2);
        assert_eq!(sv.dimension(), 4);

        // Initial state should be |00>
        let zero = Complex64::new(0.0, 0.0);
        assert_eq!(
            sv.as_vec(),
            vec![Complex64::new(1.0, 0.0), zero, zero, zero]
        );
        assert_eq!(sv.get_amplitude(0), Complex64::new(1.0, 0.0));
        for idx in 1..4 {
            assert_eq!(sv.get_amplitude(idx), zero);
        }
    }

    #[test]
    fn test_hadamard_gate_chunked() {
        let h = hadamard();
        let mut sv = ChunkedStateVector::new(2);

        // H on qubit 1 (bit 1 of the index): amplitude is split between indices 0 and 2.
        sv.apply_single_qubit_gate(&h, 1);
        assert_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, FRAC_1_SQRT_2, 0.0]);

        // H on qubit 0 as well yields the uniform superposition
        // (|00> + |01> + |10> + |11>) / 2: every amplitude is +1/2.
        sv.apply_single_qubit_gate(&h, 0);
        assert_amplitudes(&sv, &[0.5, 0.5, 0.5, 0.5]);
    }

    #[test]
    fn test_cnot_gate_chunked() {
        let mut sv = ChunkedStateVector::new(2);

        // H on qubit 0: amplitude 1/sqrt(2) on indices 0 and 1.
        sv.apply_single_qubit_gate(&hadamard(), 0);

        // CNOT(control 0, target 1) moves index 1 to index 3, giving the Bell state
        // (|00> + |11>) / sqrt(2).
        sv.apply_cnot(0, 1);
        assert_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, 0.0, FRAC_1_SQRT_2]);
    }

    #[test]
    fn test_two_qubit_swap_gate_chunked() {
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        // SWAP in the subspace basis (bit(qubit1) << 1) | bit(qubit2): exchanges 01 and 10.
        let swap = [
            one, zero, zero, zero,
            zero, zero, one, zero,
            zero, one, zero, zero,
            zero, zero, zero, one,
        ];

        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_two_qubit_gate(&swap, 0, 1);

        // The excitation on qubit 0 (index 1) moves to qubit 1 (index 2).
        assert_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, FRAC_1_SQRT_2, 0.0]);
    }

    #[test]
    fn test_probabilities_bell_state() {
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_cnot(0, 1);

        assert!((sv.probability(&[0, 0]) - 0.5).abs() < EPS);
        assert!((sv.probability(&[1, 1]) - 0.5).abs() < EPS);
        assert!(sv.probability(&[1, 0]).abs() < EPS);

        let total: f64 = sv.probabilities().iter().sum();
        assert!((total - 1.0).abs() < EPS);

        // The range end is clamped to the dimension: indices 1, 2, 3.
        let range = sv.probability_range(1, 10);
        assert_eq!(range.len(), 3);
        assert!((range[2] - 0.5).abs() < EPS);
    }
}