//! Mixed-precision state vector storage (`quantrs2_sim/mixed_precision_impl/state_vector.rs`).

use scirs2_core::ndarray::Array1;
use scirs2_core::{Complex32, Complex64};

use super::config::QuantumPrecision;
use crate::error::{Result, SimulatorError};

13pub enum MixedPrecisionStateVector {
15 Half(Array1<Complex32>),
17 Single(Array1<Complex32>),
19 Double(Array1<Complex64>),
21 Adaptive {
23 primary: Box<MixedPrecisionStateVector>,
24 secondary: Option<Box<MixedPrecisionStateVector>>,
25 precision_map: Vec<QuantumPrecision>,
26 },
27}
28
29impl MixedPrecisionStateVector {
30 pub fn new(size: usize, precision: QuantumPrecision) -> Self {
32 match precision {
33 QuantumPrecision::Half => Self::Half(Array1::zeros(size)),
34 QuantumPrecision::Single => Self::Single(Array1::zeros(size)),
35 QuantumPrecision::Double => Self::Double(Array1::zeros(size)),
36 QuantumPrecision::Adaptive => {
37 let primary = Box::new(Self::Single(Array1::zeros(size)));
39 Self::Adaptive {
40 primary,
41 secondary: None,
42 precision_map: vec![QuantumPrecision::Single; size],
43 }
44 }
45 }
46 }
47
48 pub fn computational_basis(num_qubits: usize, precision: QuantumPrecision) -> Self {
50 let size = 1 << num_qubits;
51 let mut state = Self::new(size, precision);
52
53 match &mut state {
55 Self::Half(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
56 Self::Single(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
57 Self::Double(ref mut arr) => arr[0] = Complex64::new(1.0, 0.0),
58 Self::Adaptive {
59 ref mut primary, ..
60 } => {
61 *primary = Box::new(Self::computational_basis(
62 num_qubits,
63 QuantumPrecision::Single,
64 ));
65 }
66 }
67
68 state
69 }
70
71 pub fn len(&self) -> usize {
73 match self {
74 Self::Half(arr) => arr.len(),
75 Self::Single(arr) => arr.len(),
76 Self::Double(arr) => arr.len(),
77 Self::Adaptive { primary, .. } => primary.len(),
78 }
79 }
80
81 pub fn is_empty(&self) -> bool {
83 self.len() == 0
84 }
85
86 pub fn precision(&self) -> QuantumPrecision {
88 match self {
89 Self::Half(_) => QuantumPrecision::Half,
90 Self::Single(_) => QuantumPrecision::Single,
91 Self::Double(_) => QuantumPrecision::Double,
92 Self::Adaptive { .. } => QuantumPrecision::Adaptive,
93 }
94 }
95
96 pub fn to_precision(&self, target_precision: QuantumPrecision) -> Result<Self> {
98 if self.precision() == target_precision {
99 return Ok(self.clone());
100 }
101
102 let size = self.len();
103 let mut result = Self::new(size, target_precision);
104
105 match (self, &mut result) {
106 (Self::Single(src), Self::Double(dst)) => {
107 for (i, &val) in src.iter().enumerate() {
108 dst[i] = Complex64::new(val.re as f64, val.im as f64);
109 }
110 }
111 (Self::Double(src), Self::Single(dst)) => {
112 for (i, &val) in src.iter().enumerate() {
113 dst[i] = Complex32::new(val.re as f32, val.im as f32);
114 }
115 }
116 (Self::Half(src), Self::Single(dst)) => {
117 *dst = src.clone();
118 }
119 (Self::Single(src), Self::Half(dst)) => {
120 *dst = src.clone();
121 }
122 (Self::Half(src), Self::Double(dst)) => {
123 for (i, &val) in src.iter().enumerate() {
124 dst[i] = Complex64::new(val.re as f64, val.im as f64);
125 }
126 }
127 (Self::Double(src), Self::Half(dst)) => {
128 for (i, &val) in src.iter().enumerate() {
129 dst[i] = Complex32::new(val.re as f32, val.im as f32);
130 }
131 }
132 _ => {
133 return Err(SimulatorError::UnsupportedOperation(
134 "Complex precision conversion not supported".to_string(),
135 ));
136 }
137 }
138
139 Ok(result)
140 }
141
142 pub fn normalize(&mut self) -> Result<()> {
144 let norm = self.norm();
145 if norm == 0.0 {
146 return Err(SimulatorError::InvalidInput(
147 "Cannot normalize zero vector".to_string(),
148 ));
149 }
150
151 match self {
152 Self::Half(arr) => {
153 let norm_f32 = norm as f32;
154 for val in arr.iter_mut() {
155 *val /= norm_f32;
156 }
157 }
158 Self::Single(arr) => {
159 let norm_f32 = norm as f32;
160 for val in arr.iter_mut() {
161 *val /= norm_f32;
162 }
163 }
164 Self::Double(arr) => {
165 for val in arr.iter_mut() {
166 *val /= norm;
167 }
168 }
169 Self::Adaptive {
170 ref mut primary, ..
171 } => {
172 primary.normalize()?;
173 }
174 }
175
176 Ok(())
177 }
178
179 pub fn norm(&self) -> f64 {
181 match self {
182 Self::Half(arr) => arr.iter().map(|x| x.norm_sqr() as f64).sum::<f64>().sqrt(),
183 Self::Single(arr) => arr.iter().map(|x| x.norm_sqr() as f64).sum::<f64>().sqrt(),
184 Self::Double(arr) => arr.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt(),
185 Self::Adaptive { primary, .. } => primary.norm(),
186 }
187 }
188
189 pub fn probability(&self, index: usize) -> Result<f64> {
191 if index >= self.len() {
192 return Err(SimulatorError::InvalidInput(format!(
193 "Index {} out of bounds for state vector of length {}",
194 index,
195 self.len()
196 )));
197 }
198
199 let prob = match self {
200 Self::Half(arr) => arr[index].norm_sqr() as f64,
201 Self::Single(arr) => arr[index].norm_sqr() as f64,
202 Self::Double(arr) => arr[index].norm_sqr(),
203 Self::Adaptive { primary, .. } => primary.probability(index)?,
204 };
205
206 Ok(prob)
207 }
208
209 pub fn amplitude(&self, index: usize) -> Result<Complex64> {
211 if index >= self.len() {
212 return Err(SimulatorError::InvalidInput(format!(
213 "Index {} out of bounds for state vector of length {}",
214 index,
215 self.len()
216 )));
217 }
218
219 let amplitude = match self {
220 Self::Half(arr) => {
221 let val = arr[index];
222 Complex64::new(val.re as f64, val.im as f64)
223 }
224 Self::Single(arr) => {
225 let val = arr[index];
226 Complex64::new(val.re as f64, val.im as f64)
227 }
228 Self::Double(arr) => arr[index],
229 Self::Adaptive { primary, .. } => primary.amplitude(index)?,
230 };
231
232 Ok(amplitude)
233 }
234
235 pub fn set_amplitude(&mut self, index: usize, amplitude: Complex64) -> Result<()> {
237 if index >= self.len() {
238 return Err(SimulatorError::InvalidInput(format!(
239 "Index {} out of bounds for state vector of length {}",
240 index,
241 self.len()
242 )));
243 }
244
245 match self {
246 Self::Half(arr) => {
247 arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
248 }
249 Self::Single(arr) => {
250 arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
251 }
252 Self::Double(arr) => {
253 arr[index] = amplitude;
254 }
255 Self::Adaptive {
256 ref mut primary, ..
257 } => {
258 primary.set_amplitude(index, amplitude)?;
259 }
260 }
261
262 Ok(())
263 }
264
265 pub fn fidelity(&self, other: &Self) -> Result<f64> {
267 if self.len() != other.len() {
268 return Err(SimulatorError::InvalidInput(
269 "State vectors must have the same length for fidelity calculation".to_string(),
270 ));
271 }
272
273 let mut inner_product = Complex64::new(0.0, 0.0);
274
275 for i in 0..self.len() {
276 let amp1 = self.amplitude(i)?;
277 let amp2 = other.amplitude(i)?;
278 inner_product += amp1.conj() * amp2;
279 }
280
281 Ok(inner_product.norm_sqr())
282 }
283
284 pub fn clone_to_precision(&self, precision: QuantumPrecision) -> Result<Self> {
286 self.to_precision(precision)
287 }
288
289 pub fn memory_usage(&self) -> usize {
291 match self {
292 Self::Half(arr) => arr.len() * std::mem::size_of::<Complex32>(),
293 Self::Single(arr) => arr.len() * std::mem::size_of::<Complex32>(),
294 Self::Double(arr) => arr.len() * std::mem::size_of::<Complex64>(),
295 Self::Adaptive {
296 primary, secondary, ..
297 } => {
298 let mut usage = primary.memory_usage();
299 if let Some(sec) = secondary {
300 usage += sec.memory_usage();
301 }
302 usage += std::mem::size_of::<QuantumPrecision>() * primary.len(); usage
304 }
305 }
306 }
307
308 pub fn is_normalized(&self, tolerance: f64) -> bool {
310 (self.norm() - 1.0).abs() < tolerance
311 }
312
313 pub fn num_qubits(&self) -> usize {
315 (self.len() as f64).log2() as usize
316 }
317}
318
319impl Clone for MixedPrecisionStateVector {
320 fn clone(&self) -> Self {
321 match self {
322 Self::Half(arr) => Self::Half(arr.clone()),
323 Self::Single(arr) => Self::Single(arr.clone()),
324 Self::Double(arr) => Self::Double(arr.clone()),
325 Self::Adaptive {
326 primary,
327 secondary,
328 precision_map,
329 } => Self::Adaptive {
330 primary: primary.clone(),
331 secondary: secondary.clone(),
332 precision_map: precision_map.clone(),
333 },
334 }
335 }
336}
337
338impl std::fmt::Debug for MixedPrecisionStateVector {
339 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340 match self {
341 Self::Half(arr) => write!(f, "Half({} elements)", arr.len()),
342 Self::Single(arr) => write!(f, "Single({} elements)", arr.len()),
343 Self::Double(arr) => write!(f, "Double({} elements)", arr.len()),
344 Self::Adaptive {
345 primary, secondary, ..
346 } => {
347 write!(
348 f,
349 "Adaptive(primary: {:?}, secondary: {:?})",
350 primary, secondary
351 )
352 }
353 }
354 }
355}