quantrs2_sim/
optimized_chunked.rs

1//! Optimized quantum state vector simulation using chunked memory processing
2//!
3//! This module provides a memory-efficient implementation for large qubit counts (30+)
4//! by processing the state vector in manageable chunks to reduce memory pressure.
5
6use scirs2_core::Complex64;
7use std::cmp::min;
8
// Plain `Vec`-backed storage is used because the scirs2 memory module is not
// available here; this type is a stand-in for a future scirs2-based allocator.

/// A fixed-length buffer holding one chunk of a larger state vector.
#[derive(Clone, Debug)]
struct MemoryChunk<T> {
    /// The chunk's elements.
    data: Vec<T>,
    /// Length requested at construction; stored but not read anywhere.
    _capacity: usize,
}

impl<T: Clone + Default> MemoryChunk<T> {
    /// Allocate a chunk of exactly `capacity` elements, each set to `T::default()`.
    fn new(capacity: usize) -> Self {
        let mut data = Vec::with_capacity(capacity);
        data.resize(capacity, T::default());
        Self {
            data,
            _capacity: capacity,
        }
    }

    /// Borrow the element at `idx`, or `None` when `idx` is out of bounds.
    fn get(&self, idx: usize) -> Option<&T> {
        self.as_slice().get(idx)
    }

    /// Mutably borrow the element at `idx`, or `None` when `idx` is out of bounds.
    fn get_mut(&mut self, idx: usize) -> Option<&mut T> {
        self.data.as_mut_slice().get_mut(idx)
    }

    /// View the whole chunk as an immutable slice.
    fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }

    // Unused at the time of writing; the leading `_` keeps the dead-code lint quiet.
    /// View the whole chunk as a mutable slice.
    fn _as_mut_slice(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }
}
42
43use crate::utils::flip_bit;
44
/// Default number of amplitudes held by one chunk: 2^20 = 1,048,576 elements,
/// roughly 16 MiB of `Complex64` values per chunk.
const DEFAULT_CHUNK_SIZE: usize = 1024 * 1024;
47
48/// Represents a quantum state vector that uses chunked memory for large qubit counts
49pub struct ChunkedStateVector {
50    /// The full state vector stored as multiple chunks
51    chunks: Vec<MemoryChunk<Complex64>>,
52    /// Number of qubits represented
53    num_qubits: usize,
54    /// Size of each chunk (number of complex numbers)
55    chunk_size: usize,
56    /// Total dimension of the state vector (2^num_qubits)
57    dimension: usize,
58}
59
60impl ChunkedStateVector {
61    /// Create a new chunked state vector for given number of qubits
62    pub fn new(num_qubits: usize) -> Self {
63        let dimension = 1 << num_qubits;
64        let chunk_size = min(DEFAULT_CHUNK_SIZE, dimension);
65        let num_chunks = dimension.div_ceil(chunk_size);
66
67        // Create empty chunks
68        let mut chunks = Vec::with_capacity(num_chunks);
69        for i in 0..num_chunks {
70            let this_chunk_size = if i == num_chunks - 1 && dimension % chunk_size != 0 {
71                dimension % chunk_size
72            } else {
73                chunk_size
74            };
75
76            let mut chunk = MemoryChunk::new(this_chunk_size);
77            if i == 0 {
78                // Initialize to |0...0>
79                if let Some(first) = chunk.get_mut(0) {
80                    *first = Complex64::new(1.0, 0.0);
81                }
82            }
83            chunks.push(chunk);
84        }
85
86        Self {
87            chunks,
88            num_qubits,
89            chunk_size,
90            dimension,
91        }
92    }
93
94    /// Get the number of qubits
95    pub fn num_qubits(&self) -> usize {
96        self.num_qubits
97    }
98
99    /// Get the dimension of the state vector
100    pub fn dimension(&self) -> usize {
101        self.dimension
102    }
103
104    /// Access a specific amplitude by global index
105    pub fn get_amplitude(&self, idx: usize) -> Complex64 {
106        let chunk_idx = idx / self.chunk_size;
107        let local_idx = idx % self.chunk_size;
108
109        if chunk_idx >= self.chunks.len() {
110            return Complex64::new(0.0, 0.0);
111        }
112
113        match self.chunks[chunk_idx].get(local_idx) {
114            Some(val) => *val,
115            None => Complex64::new(0.0, 0.0),
116        }
117    }
118
119    /// Get all amplitudes as a flattened vector (for testing and conversion)
120    /// Warning: For large qubit counts, this will use a lot of memory
121    pub fn as_vec(&self) -> Vec<Complex64> {
122        let mut result = Vec::with_capacity(self.dimension);
123        for chunk in &self.chunks {
124            result.extend_from_slice(chunk.as_slice());
125        }
126        result
127    }
128
129    /// Apply a single-qubit gate to the state vector using chunked processing
130    ///
131    /// # Arguments
132    ///
133    /// * `matrix` - The 2x2 matrix representation of the gate
134    /// * `target` - The target qubit index
135    pub fn apply_single_qubit_gate(&mut self, matrix: &[Complex64], target: usize) {
136        if target >= self.num_qubits {
137            panic!("Target qubit index out of range");
138        }
139
140        // Copy current state as we need to read from old state while writing to new
141        let old_chunks = self.chunks.clone();
142
143        // Reset all values to zero
144        for chunk in &mut self.chunks {
145            for idx in 0..chunk.as_slice().len() {
146                if let Some(val) = chunk.get_mut(idx) {
147                    *val = Complex64::new(0.0, 0.0);
148                }
149            }
150        }
151
152        // Process each chunk - iterate through old chunks for reading
153        for (chunk_idx, chunk) in old_chunks.iter().enumerate() {
154            let base_idx = chunk_idx * self.chunk_size;
155
156            // Process each amplitude in this chunk
157            for (local_idx, &amp) in chunk.as_slice().iter().enumerate() {
158                let global_idx = base_idx + local_idx;
159                if global_idx >= self.dimension {
160                    break;
161                }
162
163                // Skip over zero amplitudes for efficiency
164                if amp == Complex64::new(0.0, 0.0) {
165                    continue;
166                }
167
168                let bit_val = (global_idx >> target) & 1;
169
170                // Find the paired index
171                let paired_global_idx = flip_bit(global_idx, target);
172                let paired_chunk_idx = paired_global_idx / self.chunk_size;
173                let paired_local_idx = paired_global_idx % self.chunk_size;
174
175                // Get the amplitude of the paired index from old state
176                let paired_amp = if paired_chunk_idx < old_chunks.len() {
177                    if let Some(val) = old_chunks[paired_chunk_idx].get(paired_local_idx) {
178                        *val
179                    } else {
180                        Complex64::new(0.0, 0.0)
181                    }
182                } else {
183                    Complex64::new(0.0, 0.0)
184                };
185
186                // Calculate new amplitudes
187                let new_amp0 = matrix[0] * amp + matrix[1] * paired_amp;
188                let new_amp1 = matrix[2] * amp + matrix[3] * paired_amp;
189
190                // Determine current chunk/idx from global index
191                if bit_val == 0 {
192                    // Update both indices in one go
193                    if let Some(val) = self.chunks[chunk_idx].get_mut(local_idx) {
194                        *val += new_amp0;
195                    }
196
197                    if paired_chunk_idx < self.chunks.len() {
198                        if let Some(val) = self.chunks[paired_chunk_idx].get_mut(paired_local_idx) {
199                            *val += new_amp1;
200                        }
201                    }
202                }
203            }
204        }
205    }
206
207    /// Apply a controlled-NOT gate to the state vector
208    ///
209    /// # Arguments
210    ///
211    /// * `control` - The control qubit index
212    /// * `target` - The target qubit index
213    pub fn apply_cnot(&mut self, control: usize, target: usize) {
214        if control >= self.num_qubits || target >= self.num_qubits {
215            panic!("Qubit indices out of range");
216        }
217
218        if control == target {
219            panic!("Control and target qubits must be different");
220        }
221
222        // We're using standard qubit ordering where the target/control parameters
223        // are used directly with bit operations
224
225        // Create new chunks to hold the result
226        let mut new_chunks = Vec::with_capacity(self.chunks.len());
227        for chunk in &self.chunks {
228            new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
229        }
230
231        // Process each chunk in parallel
232        for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
233            let base_idx = chunk_idx * self.chunk_size;
234
235            // Process this chunk
236            for (local_idx, &amp) in chunk.as_slice().iter().enumerate() {
237                let global_idx = base_idx + local_idx;
238                if global_idx >= self.dimension {
239                    break;
240                }
241
242                let control_bit = (global_idx >> control) & 1;
243
244                if control_bit == 0 {
245                    // Control bit is 0: state remains unchanged
246                    if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
247                        *val = amp;
248                    }
249                } else {
250                    // Control bit is 1: flip the target bit
251                    let flipped_idx = flip_bit(global_idx, target);
252                    let flipped_chunk_idx = flipped_idx / self.chunk_size;
253                    let flipped_local_idx = flipped_idx % self.chunk_size;
254
255                    // Get the amplitude from the flipped position
256                    let flipped_amp = self.get_amplitude(flipped_idx);
257
258                    // Update the current position with the flipped amplitude
259                    if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
260                        *val = flipped_amp;
261                    }
262
263                    // Update the flipped position with the current amplitude
264                    if flipped_chunk_idx < self.chunks.len() {
265                        if let Some(val) = new_chunks[flipped_chunk_idx].get_mut(flipped_local_idx)
266                        {
267                            *val = amp;
268                        }
269                    }
270                }
271            }
272        }
273
274        // Update the state
275        self.chunks = new_chunks;
276    }
277
278    /// Apply a two-qubit gate to the state vector
279    ///
280    /// # Arguments
281    ///
282    /// * `matrix` - The 4x4 matrix representation of the gate
283    /// * `qubit1` - The first qubit index
284    /// * `qubit2` - The second qubit index
285    pub fn apply_two_qubit_gate(&mut self, matrix: &[Complex64], qubit1: usize, qubit2: usize) {
286        if qubit1 >= self.num_qubits || qubit2 >= self.num_qubits {
287            panic!("Qubit indices out of range");
288        }
289
290        if qubit1 == qubit2 {
291            panic!("Qubit indices must be different");
292        }
293
294        // Create new chunks to hold the result
295        let mut new_chunks = Vec::with_capacity(self.chunks.len());
296        for chunk in &self.chunks {
297            new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
298        }
299
300        // Process each chunk
301        for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
302            let base_idx = chunk_idx * self.chunk_size;
303
304            // Process this chunk
305            for (local_idx, &_) in chunk.as_slice().iter().enumerate() {
306                let global_idx = base_idx + local_idx;
307                if global_idx >= self.dimension {
308                    break;
309                }
310
311                // Determine which basis state this corresponds to in the 2-qubit subspace
312                let bit1 = (global_idx >> qubit1) & 1;
313                let bit2 = (global_idx >> qubit2) & 1;
314
315                // Calculate the indices of all four basis states in the 2-qubit subspace
316                let bits00 = global_idx & !(1 << qubit1) & !(1 << qubit2);
317                let bits01 = bits00 | (1 << qubit2);
318                let bits10 = bits00 | (1 << qubit1);
319                let bits11 = bits10 | (1 << qubit2);
320
321                // Get the amplitudes for all basis states
322                let amp00 = self.get_amplitude(bits00);
323                let amp01 = self.get_amplitude(bits01);
324                let amp10 = self.get_amplitude(bits10);
325                let amp11 = self.get_amplitude(bits11);
326
327                // Determine which amplitude to update
328                let subspace_idx = (bit1 << 1) | bit2;
329                let mut new_amp = Complex64::new(0.0, 0.0);
330
331                // Apply the 4x4 matrix to compute the new amplitude
332                new_amp += matrix[subspace_idx * 4] * amp00;
333                new_amp += matrix[subspace_idx * 4 + 1] * amp01;
334                new_amp += matrix[subspace_idx * 4 + 2] * amp10;
335                new_amp += matrix[subspace_idx * 4 + 3] * amp11;
336
337                // Update the amplitude in the result
338                if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
339                    *val = new_amp;
340                }
341            }
342        }
343
344        // Update the state
345        self.chunks = new_chunks;
346    }
347
348    /// Calculate probability of measuring a specific bit string
349    pub fn probability(&self, bit_string: &[u8]) -> f64 {
350        if bit_string.len() != self.num_qubits {
351            panic!("Bit string length must match number of qubits");
352        }
353
354        // Convert bit string to index
355        let mut idx = 0;
356        for (i, &bit) in bit_string.iter().enumerate() {
357            if bit != 0 {
358                idx |= 1 << i;
359            }
360        }
361
362        // Return probability
363        self.get_amplitude(idx).norm_sqr()
364    }
365
366    /// Calculate probabilities for all basis states
367    /// Warning: For large qubit counts, this will use a lot of memory
368    pub fn probabilities(&self) -> Vec<f64> {
369        self.chunks
370            .iter()
371            .flat_map(|chunk| chunk.as_slice().iter().map(|a| a.norm_sqr()))
372            .collect()
373    }
374
375    /// Calculate the probability of a specified range of states
376    /// More memory efficient for large qubit counts
377    pub fn probability_range(&self, start_idx: usize, end_idx: usize) -> Vec<f64> {
378        let real_end = std::cmp::min(end_idx, self.dimension);
379
380        (start_idx..real_end)
381            .map(|idx| self.get_amplitude(idx).norm_sqr())
382            .collect()
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Tolerance used when comparing floating-point amplitudes.
    const EPS: f64 = 1e-10;

    /// Hadamard gate in row-major order.
    fn hadamard() -> [Complex64; 4] {
        [
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(-FRAC_1_SQRT_2, 0.0),
        ]
    }

    /// Assert that the amplitudes of `sv` match `expected` (by global index) within `EPS`.
    fn assert_amplitudes(sv: &ChunkedStateVector, expected: &[Complex64]) {
        assert_eq!(sv.dimension(), expected.len());
        for (idx, &want) in expected.iter().enumerate() {
            let got = sv.get_amplitude(idx);
            assert!(
                (got - want).norm() < EPS,
                "amplitude[{idx}]: got {got:?}, expected {want:?}"
            );
        }
    }

    #[test]
    fn test_chunked_state_vector_init() {
        let sv = ChunkedStateVector::new(2);
        assert_eq!(sv.num_qubits(), 2);
        assert_eq!(sv.dimension(), 4);

        // Initial state is the basis state at index 0 (all qubits 0).
        assert_eq!(sv.get_amplitude(0), Complex64::new(1.0, 0.0));
        assert_eq!(sv.get_amplitude(1), Complex64::new(0.0, 0.0));
        assert_eq!(sv.get_amplitude(2), Complex64::new(0.0, 0.0));
        assert_eq!(sv.get_amplitude(3), Complex64::new(0.0, 0.0));
    }

    #[test]
    fn test_hadamard_gate_chunked() {
        let h = hadamard();
        let s = Complex64::new(FRAC_1_SQRT_2, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        let mut sv = ChunkedStateVector::new(2);

        // H on qubit 1 (bit 1 of the index) spreads amplitude over indices 0 and 2.
        sv.apply_single_qubit_gate(&h, 1);
        assert_amplitudes(&sv, &[s, zero, s, zero]);

        // H on qubit 0 as well: H⊗H on the zero state gives a uniform superposition
        // with every amplitude equal to 1/2.
        sv.apply_single_qubit_gate(&h, 0);
        let half = Complex64::new(0.5, 0.0);
        assert_amplitudes(&sv, &[half, half, half, half]);
    }

    #[test]
    fn test_cnot_gate_chunked() {
        let s = Complex64::new(FRAC_1_SQRT_2, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        let mut sv = ChunkedStateVector::new(2);

        // H on qubit 0: equal superposition of indices 0 and 1.
        sv.apply_single_qubit_gate(&hadamard(), 0);

        // CNOT(control 0, target 1) moves the index-1 amplitude to index 3,
        // producing the Bell state (|index 0> + |index 3>) / sqrt(2).
        sv.apply_cnot(0, 1);
        assert_amplitudes(&sv, &[s, zero, zero, s]);
    }

    #[test]
    fn test_two_qubit_swap_gate() {
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        let x = [zero, one, one, zero];
        #[rustfmt::skip]
        let swap = [
            one,  zero, zero, zero,
            zero, zero, one,  zero,
            zero, one,  zero, zero,
            zero, zero, zero, one,
        ];

        let mut sv = ChunkedStateVector::new(2);
        // X on qubit 0: the state becomes index 1.
        sv.apply_single_qubit_gate(&x, 0);
        assert_amplitudes(&sv, &[zero, one, zero, zero]);

        // SWAP qubits 0 and 1: the excitation moves to qubit 1 (index 2).
        sv.apply_two_qubit_gate(&swap, 0, 1);
        assert_amplitudes(&sv, &[zero, zero, one, zero]);
    }

    #[test]
    fn test_probabilities_of_bell_state() {
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_cnot(0, 1);

        assert!((sv.probability(&[0, 0]) - 0.5).abs() < EPS);
        assert!((sv.probability(&[1, 1]) - 0.5).abs() < EPS);
        assert!(sv.probability(&[1, 0]).abs() < EPS);

        let total: f64 = sv.probabilities().iter().sum();
        assert!((total - 1.0).abs() < EPS);

        // The range is clamped to the dimension (indices 1..4).
        let range = sv.probability_range(1, 10);
        assert_eq!(range.len(), 3);
        assert!((range[2] - 0.5).abs() < EPS);
    }
}