quantrs2_sim/
optimized_chunked.rs

1//! Optimized quantum state vector simulation using chunked memory processing
2//!
3//! This module provides a memory-efficient implementation for large qubit counts (30+)
4//! by processing the state vector in manageable chunks to reduce memory pressure.
5
6use scirs2_core::Complex64;
7use std::cmp::min;
8
// Stand-in for scirs2's memory management, which is not available here yet: every chunk
// is simply a `Vec` whose length is fixed at the moment the chunk is created.
#[derive(Clone, Debug)]
struct MemoryChunk<T> {
    /// Element storage; holds exactly the number of elements requested in `new`.
    data: Vec<T>,
    /// Length requested at construction; kept for a future scirs2 integration.
    _capacity: usize,
}

impl<T: Clone + Default> MemoryChunk<T> {
    /// Build a chunk holding `capacity` copies of `T::default()`.
    fn new(capacity: usize) -> Self {
        let mut data = Vec::with_capacity(capacity);
        data.resize(capacity, T::default());
        Self {
            _capacity: capacity,
            data,
        }
    }

    /// Borrow the element at `idx`, or `None` when `idx` is past the end of the chunk.
    fn get(&self, idx: usize) -> Option<&T> {
        self.as_slice().get(idx)
    }

    /// Mutably borrow the element at `idx`, or `None` when `idx` is past the end.
    fn get_mut(&mut self, idx: usize) -> Option<&mut T> {
        self._as_mut_slice().get_mut(idx)
    }

    /// View all elements of the chunk as a shared slice.
    fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }

    /// View all elements of the chunk as a mutable slice.
    ///
    /// The leading underscore was added because the method was originally unused
    /// (it silences the dead-code lint).
    fn _as_mut_slice(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }
}
42
43use crate::utils::flip_bit;
44
/// Upper bound on the number of amplitudes stored in a single chunk:
/// 2^20 = 1,048,576 `Complex64` values, i.e. about 16 MiB per chunk.
const DEFAULT_CHUNK_SIZE: usize = 1_048_576;
47
48/// Represents a quantum state vector that uses chunked memory for large qubit counts
49pub struct ChunkedStateVector {
50    /// The full state vector stored as multiple chunks
51    chunks: Vec<MemoryChunk<Complex64>>,
52    /// Number of qubits represented
53    num_qubits: usize,
54    /// Size of each chunk (number of complex numbers)
55    chunk_size: usize,
56    /// Total dimension of the state vector (`2^num_qubits`)
57    dimension: usize,
58}
59
60impl ChunkedStateVector {
61    /// Create a new chunked state vector for given number of qubits
62    #[must_use]
63    pub fn new(num_qubits: usize) -> Self {
64        let dimension = 1 << num_qubits;
65        let chunk_size = min(DEFAULT_CHUNK_SIZE, dimension);
66        let num_chunks = dimension.div_ceil(chunk_size);
67
68        // Create empty chunks
69        let mut chunks = Vec::with_capacity(num_chunks);
70        for i in 0..num_chunks {
71            let this_chunk_size = if i == num_chunks - 1 && dimension % chunk_size != 0 {
72                dimension % chunk_size
73            } else {
74                chunk_size
75            };
76
77            let mut chunk = MemoryChunk::new(this_chunk_size);
78            if i == 0 {
79                // Initialize to |0...0>
80                if let Some(first) = chunk.get_mut(0) {
81                    *first = Complex64::new(1.0, 0.0);
82                }
83            }
84            chunks.push(chunk);
85        }
86
87        Self {
88            chunks,
89            num_qubits,
90            chunk_size,
91            dimension,
92        }
93    }
94
95    /// Get the number of qubits
96    #[must_use]
97    pub const fn num_qubits(&self) -> usize {
98        self.num_qubits
99    }
100
101    /// Get the dimension of the state vector
102    #[must_use]
103    pub const fn dimension(&self) -> usize {
104        self.dimension
105    }
106
107    /// Access a specific amplitude by global index
108    #[must_use]
109    pub fn get_amplitude(&self, idx: usize) -> Complex64 {
110        let chunk_idx = idx / self.chunk_size;
111        let local_idx = idx % self.chunk_size;
112
113        if chunk_idx >= self.chunks.len() {
114            return Complex64::new(0.0, 0.0);
115        }
116
117        match self.chunks[chunk_idx].get(local_idx) {
118            Some(val) => *val,
119            None => Complex64::new(0.0, 0.0),
120        }
121    }
122
123    /// Get all amplitudes as a flattened vector (for testing and conversion)
124    /// Warning: For large qubit counts, this will use a lot of memory
125    #[must_use]
126    pub fn as_vec(&self) -> Vec<Complex64> {
127        let mut result = Vec::with_capacity(self.dimension);
128        for chunk in &self.chunks {
129            result.extend_from_slice(chunk.as_slice());
130        }
131        result
132    }
133
134    /// Apply a single-qubit gate to the state vector using chunked processing
135    ///
136    /// # Arguments
137    ///
138    /// * `matrix` - The 2x2 matrix representation of the gate
139    /// * `target` - The target qubit index
140    pub fn apply_single_qubit_gate(&mut self, matrix: &[Complex64], target: usize) {
141        assert!(
142            (target < self.num_qubits),
143            "Target qubit index out of range"
144        );
145
146        // Copy current state as we need to read from old state while writing to new
147        let old_chunks = self.chunks.clone();
148
149        // Reset all values to zero
150        for chunk in &mut self.chunks {
151            for idx in 0..chunk.as_slice().len() {
152                if let Some(val) = chunk.get_mut(idx) {
153                    *val = Complex64::new(0.0, 0.0);
154                }
155            }
156        }
157
158        // Process each chunk - iterate through old chunks for reading
159        for (chunk_idx, chunk) in old_chunks.iter().enumerate() {
160            let base_idx = chunk_idx * self.chunk_size;
161
162            // Process each amplitude in this chunk
163            for (local_idx, &amp) in chunk.as_slice().iter().enumerate() {
164                let global_idx = base_idx + local_idx;
165                if global_idx >= self.dimension {
166                    break;
167                }
168
169                // Skip over zero amplitudes for efficiency
170                if amp == Complex64::new(0.0, 0.0) {
171                    continue;
172                }
173
174                let bit_val = (global_idx >> target) & 1;
175
176                // Find the paired index
177                let paired_global_idx = flip_bit(global_idx, target);
178                let paired_chunk_idx = paired_global_idx / self.chunk_size;
179                let paired_local_idx = paired_global_idx % self.chunk_size;
180
181                // Get the amplitude of the paired index from old state
182                let paired_amp = if paired_chunk_idx < old_chunks.len() {
183                    if let Some(val) = old_chunks[paired_chunk_idx].get(paired_local_idx) {
184                        *val
185                    } else {
186                        Complex64::new(0.0, 0.0)
187                    }
188                } else {
189                    Complex64::new(0.0, 0.0)
190                };
191
192                // Calculate new amplitudes
193                let new_amp0 = matrix[0] * amp + matrix[1] * paired_amp;
194                let new_amp1 = matrix[2] * amp + matrix[3] * paired_amp;
195
196                // Determine current chunk/idx from global index
197                if bit_val == 0 {
198                    // Update both indices in one go
199                    if let Some(val) = self.chunks[chunk_idx].get_mut(local_idx) {
200                        *val += new_amp0;
201                    }
202
203                    if paired_chunk_idx < self.chunks.len() {
204                        if let Some(val) = self.chunks[paired_chunk_idx].get_mut(paired_local_idx) {
205                            *val += new_amp1;
206                        }
207                    }
208                }
209            }
210        }
211    }
212
213    /// Apply a controlled-NOT gate to the state vector
214    ///
215    /// # Arguments
216    ///
217    /// * `control` - The control qubit index
218    /// * `target` - The target qubit index
219    pub fn apply_cnot(&mut self, control: usize, target: usize) {
220        assert!(
221            !(control >= self.num_qubits || target >= self.num_qubits),
222            "Qubit indices out of range"
223        );
224
225        assert!(
226            (control != target),
227            "Control and target qubits must be different"
228        );
229
230        // We're using standard qubit ordering where the target/control parameters
231        // are used directly with bit operations
232
233        // Create new chunks to hold the result
234        let mut new_chunks = Vec::with_capacity(self.chunks.len());
235        for chunk in &self.chunks {
236            new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
237        }
238
239        // Process each chunk in parallel
240        for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
241            let base_idx = chunk_idx * self.chunk_size;
242
243            // Process this chunk
244            for (local_idx, &amp) in chunk.as_slice().iter().enumerate() {
245                let global_idx = base_idx + local_idx;
246                if global_idx >= self.dimension {
247                    break;
248                }
249
250                let control_bit = (global_idx >> control) & 1;
251
252                if control_bit == 0 {
253                    // Control bit is 0: state remains unchanged
254                    if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
255                        *val = amp;
256                    }
257                } else {
258                    // Control bit is 1: flip the target bit
259                    let flipped_idx = flip_bit(global_idx, target);
260                    let flipped_chunk_idx = flipped_idx / self.chunk_size;
261                    let flipped_local_idx = flipped_idx % self.chunk_size;
262
263                    // Get the amplitude from the flipped position
264                    let flipped_amp = self.get_amplitude(flipped_idx);
265
266                    // Update the current position with the flipped amplitude
267                    if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
268                        *val = flipped_amp;
269                    }
270
271                    // Update the flipped position with the current amplitude
272                    if flipped_chunk_idx < self.chunks.len() {
273                        if let Some(val) = new_chunks[flipped_chunk_idx].get_mut(flipped_local_idx)
274                        {
275                            *val = amp;
276                        }
277                    }
278                }
279            }
280        }
281
282        // Update the state
283        self.chunks = new_chunks;
284    }
285
286    /// Apply a two-qubit gate to the state vector
287    ///
288    /// # Arguments
289    ///
290    /// * `matrix` - The 4x4 matrix representation of the gate
291    /// * `qubit1` - The first qubit index
292    /// * `qubit2` - The second qubit index
293    pub fn apply_two_qubit_gate(&mut self, matrix: &[Complex64], qubit1: usize, qubit2: usize) {
294        assert!(
295            !(qubit1 >= self.num_qubits || qubit2 >= self.num_qubits),
296            "Qubit indices out of range"
297        );
298
299        assert!((qubit1 != qubit2), "Qubit indices must be different");
300
301        // Create new chunks to hold the result
302        let mut new_chunks = Vec::with_capacity(self.chunks.len());
303        for chunk in &self.chunks {
304            new_chunks.push(MemoryChunk::new(chunk.as_slice().len()));
305        }
306
307        // Process each chunk
308        for (chunk_idx, chunk) in self.chunks.iter().enumerate() {
309            let base_idx = chunk_idx * self.chunk_size;
310
311            // Process this chunk
312            for (local_idx, &_) in chunk.as_slice().iter().enumerate() {
313                let global_idx = base_idx + local_idx;
314                if global_idx >= self.dimension {
315                    break;
316                }
317
318                // Determine which basis state this corresponds to in the 2-qubit subspace
319                let bit1 = (global_idx >> qubit1) & 1;
320                let bit2 = (global_idx >> qubit2) & 1;
321
322                // Calculate the indices of all four basis states in the 2-qubit subspace
323                let bits00 = global_idx & !(1 << qubit1) & !(1 << qubit2);
324                let bits01 = bits00 | (1 << qubit2);
325                let bits10 = bits00 | (1 << qubit1);
326                let bits11 = bits10 | (1 << qubit2);
327
328                // Get the amplitudes for all basis states
329                let amp00 = self.get_amplitude(bits00);
330                let amp01 = self.get_amplitude(bits01);
331                let amp10 = self.get_amplitude(bits10);
332                let amp11 = self.get_amplitude(bits11);
333
334                // Determine which amplitude to update
335                let subspace_idx = (bit1 << 1) | bit2;
336                let mut new_amp = Complex64::new(0.0, 0.0);
337
338                // Apply the 4x4 matrix to compute the new amplitude
339                new_amp += matrix[subspace_idx * 4] * amp00;
340                new_amp += matrix[subspace_idx * 4 + 1] * amp01;
341                new_amp += matrix[subspace_idx * 4 + 2] * amp10;
342                new_amp += matrix[subspace_idx * 4 + 3] * amp11;
343
344                // Update the amplitude in the result
345                if let Some(val) = new_chunks[chunk_idx].get_mut(local_idx) {
346                    *val = new_amp;
347                }
348            }
349        }
350
351        // Update the state
352        self.chunks = new_chunks;
353    }
354
355    /// Calculate probability of measuring a specific bit string
356    #[must_use]
357    pub fn probability(&self, bit_string: &[u8]) -> f64 {
358        assert!(
359            (bit_string.len() == self.num_qubits),
360            "Bit string length must match number of qubits"
361        );
362
363        // Convert bit string to index
364        let mut idx = 0;
365        for (i, &bit) in bit_string.iter().enumerate() {
366            if bit != 0 {
367                idx |= 1 << i;
368            }
369        }
370
371        // Return probability
372        self.get_amplitude(idx).norm_sqr()
373    }
374
375    /// Calculate probabilities for all basis states
376    /// Warning: For large qubit counts, this will use a lot of memory
377    #[must_use]
378    pub fn probabilities(&self) -> Vec<f64> {
379        self.chunks
380            .iter()
381            .flat_map(|chunk| chunk.as_slice().iter().map(scirs2_core::Complex::norm_sqr))
382            .collect()
383    }
384
385    /// Calculate the probability of a specified range of states
386    /// More memory efficient for large qubit counts
387    #[must_use]
388    pub fn probability_range(&self, start_idx: usize, end_idx: usize) -> Vec<f64> {
389        let real_end = std::cmp::min(end_idx, self.dimension);
390
391        (start_idx..real_end)
392            .map(|idx| self.get_amplitude(idx).norm_sqr())
393            .collect()
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Absolute tolerance used when comparing amplitudes and probabilities.
    const EPS: f64 = 1e-10;

    /// Hadamard gate as a row-major 2x2 matrix.
    fn hadamard() -> [Complex64; 4] {
        [
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(FRAC_1_SQRT_2, 0.0),
            Complex64::new(-FRAC_1_SQRT_2, 0.0),
        ]
    }

    /// Assert that amplitude `i` of `sv` equals the real number `expected[i]` for every `i`.
    fn assert_real_amplitudes(sv: &ChunkedStateVector, expected: &[f64]) {
        assert_eq!(sv.dimension(), expected.len(), "state dimension mismatch");
        for (idx, &want) in expected.iter().enumerate() {
            let got = sv.get_amplitude(idx);
            assert!(
                (got - Complex64::new(want, 0.0)).norm() < EPS,
                "amplitude[{idx}] = {got:?}, expected {want}"
            );
        }
    }

    #[test]
    fn test_chunked_state_vector_init() {
        let sv = ChunkedStateVector::new(2);
        assert_eq!(sv.num_qubits(), 2);
        assert_eq!(sv.dimension(), 4);

        // Initial state is |00>: all weight on basis index 0.
        assert_eq!(sv.get_amplitude(0), Complex64::new(1.0, 0.0));
        assert_eq!(sv.get_amplitude(1), Complex64::new(0.0, 0.0));
        assert_eq!(sv.get_amplitude(2), Complex64::new(0.0, 0.0));
        assert_eq!(sv.get_amplitude(3), Complex64::new(0.0, 0.0));
    }

    #[test]
    fn test_hadamard_gate_chunked() {
        let h = hadamard();
        let mut sv = ChunkedStateVector::new(2);

        // Qubit k is bit k of the basis index, so H on qubit 1 mixes indices 0 and 2:
        // (|0> + |2>) / sqrt(2).
        sv.apply_single_qubit_gate(&h, 1);
        assert_real_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, FRAC_1_SQRT_2, 0.0]);

        // H on qubit 0 then mixes 0<->1 and 2<->3. (H ⊗ H)|00> is the uniform
        // superposition (|0> + |1> + |2> + |3>) / 2 with all-positive amplitudes.
        sv.apply_single_qubit_gate(&h, 0);
        assert_real_amplitudes(&sv, &[0.5, 0.5, 0.5, 0.5]);
    }

    #[test]
    fn test_cnot_gate_chunked() {
        let mut sv = ChunkedStateVector::new(2);

        // H on qubit 0 gives (|0> + |1>) / sqrt(2) in basis-index notation.
        sv.apply_single_qubit_gate(&hadamard(), 0);

        // CNOT(control 0, target 1) maps index 1 -> 3, giving the Bell state
        // (|0> + |3>) / sqrt(2).
        sv.apply_cnot(0, 1);
        assert_real_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, 0.0, FRAC_1_SQRT_2]);
    }

    #[test]
    fn test_two_qubit_gate_matches_cnot() {
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);

        // CNOT with qubit1 as control, in the basis ordered by
        // (bit of qubit1) << 1 | (bit of qubit2): swaps the |10> and |11> components.
        #[rustfmt::skip]
        let cnot = [
            one,  zero, zero, zero,
            zero, one,  zero, zero,
            zero, zero, zero, one,
            zero, zero, one,  zero,
        ];

        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_two_qubit_gate(&cnot, 0, 1);
        assert_real_amplitudes(&sv, &[FRAC_1_SQRT_2, 0.0, 0.0, FRAC_1_SQRT_2]);
    }

    #[test]
    fn test_probabilities_of_bell_state() {
        let mut sv = ChunkedStateVector::new(2);
        sv.apply_single_qubit_gate(&hadamard(), 0);
        sv.apply_cnot(0, 1);

        // bit_string[k] is the value of qubit k.
        assert!((sv.probability(&[0, 0]) - 0.5).abs() < EPS);
        assert!(sv.probability(&[1, 0]).abs() < EPS);
        assert!((sv.probability(&[1, 1]) - 0.5).abs() < EPS);

        let probs = sv.probabilities();
        assert_eq!(probs.len(), 4);
        assert!((probs.iter().sum::<f64>() - 1.0).abs() < EPS);

        // The end of the range is clamped to the state dimension.
        let tail = sv.probability_range(2, 100);
        assert_eq!(tail.len(), 2);
        assert!(tail[0].abs() < EPS);
        assert!((tail[1] - 0.5).abs() < EPS);
    }
}