quantrs2_core/
memory_efficient.rs1use crate::error::{QuantRS2Error, QuantRS2Result};
7use num_complex::Complex64;
8
9pub struct EfficientStateVector {
14 num_qubits: usize,
16 data: Vec<Complex64>,
18}
19
20impl EfficientStateVector {
21 pub fn new(num_qubits: usize) -> QuantRS2Result<Self> {
23 let size = 1 << num_qubits;
24 if size > 1 << 30 {
25 let mut data = vec![Complex64::new(0.0, 0.0); size];
28 data[0] = Complex64::new(1.0, 0.0); Ok(Self { num_qubits, data })
30 } else {
31 let mut data = vec![Complex64::new(0.0, 0.0); size];
33 data[0] = Complex64::new(1.0, 0.0); Ok(Self { num_qubits, data })
35 }
36 }
37
38 pub fn num_qubits(&self) -> usize {
40 self.num_qubits
41 }
42
43 pub fn size(&self) -> usize {
45 self.data.len()
46 }
47
48 pub fn data(&self) -> &[Complex64] {
50 &self.data
51 }
52
53 pub fn data_mut(&mut self) -> &mut [Complex64] {
55 &mut self.data
56 }
57
58 pub fn normalize(&mut self) -> QuantRS2Result<()> {
60 let norm_sqr: f64 = self.data.iter().map(|c| c.norm_sqr()).sum();
61 if norm_sqr == 0.0 {
62 return Err(QuantRS2Error::InvalidInput(
63 "Cannot normalize zero vector".to_string(),
64 ));
65 }
66 let norm = norm_sqr.sqrt();
67 for amplitude in &mut self.data {
68 *amplitude /= norm;
69 }
70 Ok(())
71 }
72
73 pub fn get_probability(&self, basis_state: usize) -> QuantRS2Result<f64> {
75 if basis_state >= self.data.len() {
76 return Err(QuantRS2Error::InvalidInput(format!(
77 "Basis state {} out of range for {} qubits",
78 basis_state, self.num_qubits
79 )));
80 }
81 Ok(self.data[basis_state].norm_sqr())
82 }
83
84 pub fn process_chunks<F>(&mut self, chunk_size: usize, mut f: F) -> QuantRS2Result<()>
89 where
90 F: FnMut(&mut [Complex64], usize),
91 {
92 if chunk_size == 0 || chunk_size > self.data.len() {
93 return Err(QuantRS2Error::InvalidInput(
94 "Invalid chunk size".to_string(),
95 ));
96 }
97
98 for (chunk_idx, chunk) in self.data.chunks_mut(chunk_size).enumerate() {
99 f(chunk, chunk_idx * chunk_size);
100 }
101 Ok(())
102 }
103}
104
/// Memory footprint of a state vector's amplitude storage.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct StateMemoryStats {
    /// Number of complex amplitudes stored.
    pub num_amplitudes: usize,
    /// Bytes occupied by the stored amplitudes.
    pub memory_bytes: usize,
}
112
113impl EfficientStateVector {
114 pub fn memory_stats(&self) -> StateMemoryStats {
116 StateMemoryStats {
117 num_amplitudes: self.data.len(),
118 memory_bytes: self.data.len() * std::mem::size_of::<Complex64>(),
119 }
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_efficient_state_vector() {
        let state = EfficientStateVector::new(3).unwrap();
        assert_eq!(state.num_qubits(), 3);
        assert_eq!(state.size(), 8);

        // Freshly created state is |000⟩.
        assert_eq!(state.data()[0], Complex64::new(1.0, 0.0));
        for i in 1..8 {
            assert_eq!(state.data()[i], Complex64::new(0.0, 0.0));
        }
    }

    #[test]
    fn test_normalization() {
        let mut state = EfficientStateVector::new(2).unwrap();
        state.data_mut()[0] = Complex64::new(1.0, 0.0);
        state.data_mut()[1] = Complex64::new(0.0, 1.0);
        state.data_mut()[2] = Complex64::new(1.0, 0.0);
        state.data_mut()[3] = Complex64::new(0.0, -1.0);

        state.normalize().unwrap();

        let norm_sqr: f64 = state.data().iter().map(|c| c.norm_sqr()).sum();
        assert!((norm_sqr - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_zero_vector_fails() {
        let mut state = EfficientStateVector::new(2).unwrap();
        state.data_mut()[0] = Complex64::new(0.0, 0.0);
        assert!(state.normalize().is_err());
    }

    #[test]
    fn test_get_probability() {
        let state = EfficientStateVector::new(2).unwrap();
        assert!((state.get_probability(0).unwrap() - 1.0).abs() < 1e-12);
        assert!(state.get_probability(3).unwrap().abs() < 1e-12);
        assert!(state.get_probability(4).is_err());
    }

    #[test]
    fn test_chunk_processing() {
        let mut state = EfficientStateVector::new(3).unwrap();

        // Write each amplitude's global index into its real part.
        state
            .process_chunks(2, |chunk, start_idx| {
                for (i, amp) in chunk.iter_mut().enumerate() {
                    *amp = Complex64::new((start_idx + i) as f64, 0.0);
                }
            })
            .unwrap();

        for i in 0..8 {
            assert_eq!(state.data()[i], Complex64::new(i as f64, 0.0));
        }
    }

    #[test]
    fn test_uneven_last_chunk() {
        let mut state = EfficientStateVector::new(3).unwrap();
        let mut seen = Vec::new();
        state
            .process_chunks(3, |chunk, start| seen.push((start, chunk.len())))
            .unwrap();
        assert_eq!(seen, vec![(0, 3), (3, 3), (6, 2)]);
    }

    #[test]
    fn test_invalid_chunk_size() {
        let mut state = EfficientStateVector::new(3).unwrap();
        assert!(state.process_chunks(0, |_, _| {}).is_err());
        assert!(state.process_chunks(9, |_, _| {}).is_err());
    }

    #[test]
    fn test_memory_stats() {
        let state = EfficientStateVector::new(3).unwrap();
        let stats = state.memory_stats();
        assert_eq!(stats.num_amplitudes, 8);
        assert_eq!(stats.memory_bytes, 8 * std::mem::size_of::<Complex64>());
    }
}