quantrs2_sim/mixed_precision_impl/
state_vector.rs

1//! Mixed-precision state vector implementations for quantum simulation.
2//!
3//! This module provides state vector representations that can dynamically
4//! switch between different numerical precisions based on accuracy requirements
5//! and performance constraints.
6
7use scirs2_core::ndarray::Array1;
8use scirs2_core::{Complex32, Complex64};
9
10use super::config::QuantumPrecision;
11use crate::error::{Result, SimulatorError};
12
13/// Mixed-precision state vector
14pub enum MixedPrecisionStateVector {
15    /// Half precision state vector (using Complex32 as approximation)
16    Half(Array1<Complex32>),
17    /// Single precision state vector
18    Single(Array1<Complex32>),
19    /// Double precision state vector
20    Double(Array1<Complex64>),
21    /// Adaptive precision with multiple representations
22    Adaptive {
23        primary: Box<MixedPrecisionStateVector>,
24        secondary: Option<Box<MixedPrecisionStateVector>>,
25        precision_map: Vec<QuantumPrecision>,
26    },
27}
28
29impl MixedPrecisionStateVector {
30    /// Create a new state vector with the specified precision
31    pub fn new(size: usize, precision: QuantumPrecision) -> Self {
32        match precision {
33            QuantumPrecision::Half => Self::Half(Array1::zeros(size)),
34            QuantumPrecision::Single => Self::Single(Array1::zeros(size)),
35            QuantumPrecision::Double => Self::Double(Array1::zeros(size)),
36            QuantumPrecision::Adaptive => {
37                // Start with single precision for adaptive
38                let primary = Box::new(Self::Single(Array1::zeros(size)));
39                Self::Adaptive {
40                    primary,
41                    secondary: None,
42                    precision_map: vec![QuantumPrecision::Single; size],
43                }
44            }
45        }
46    }
47
48    /// Create a computational basis state |0...0>
49    pub fn computational_basis(num_qubits: usize, precision: QuantumPrecision) -> Self {
50        let size = 1 << num_qubits;
51        let mut state = Self::new(size, precision);
52
53        // Set |0...0> state
54        match &mut state {
55            Self::Half(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
56            Self::Single(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
57            Self::Double(ref mut arr) => arr[0] = Complex64::new(1.0, 0.0),
58            Self::Adaptive {
59                ref mut primary, ..
60            } => {
61                *primary = Box::new(Self::computational_basis(
62                    num_qubits,
63                    QuantumPrecision::Single,
64                ));
65            }
66        }
67
68        state
69    }
70
71    /// Get the length of the state vector
72    pub fn len(&self) -> usize {
73        match self {
74            Self::Half(arr) => arr.len(),
75            Self::Single(arr) => arr.len(),
76            Self::Double(arr) => arr.len(),
77            Self::Adaptive { primary, .. } => primary.len(),
78        }
79    }
80
81    /// Check if the state vector is empty
82    pub fn is_empty(&self) -> bool {
83        self.len() == 0
84    }
85
86    /// Get the current precision of the state vector
87    pub fn precision(&self) -> QuantumPrecision {
88        match self {
89            Self::Half(_) => QuantumPrecision::Half,
90            Self::Single(_) => QuantumPrecision::Single,
91            Self::Double(_) => QuantumPrecision::Double,
92            Self::Adaptive { .. } => QuantumPrecision::Adaptive,
93        }
94    }
95
96    /// Convert to a specific precision
97    pub fn to_precision(&self, target_precision: QuantumPrecision) -> Result<Self> {
98        if self.precision() == target_precision {
99            return Ok(self.clone());
100        }
101
102        let size = self.len();
103        let mut result = Self::new(size, target_precision);
104
105        match (self, &mut result) {
106            (Self::Single(src), Self::Double(dst)) => {
107                for (i, &val) in src.iter().enumerate() {
108                    dst[i] = Complex64::new(val.re as f64, val.im as f64);
109                }
110            }
111            (Self::Double(src), Self::Single(dst)) => {
112                for (i, &val) in src.iter().enumerate() {
113                    dst[i] = Complex32::new(val.re as f32, val.im as f32);
114                }
115            }
116            (Self::Half(src), Self::Single(dst)) => {
117                *dst = src.clone();
118            }
119            (Self::Single(src), Self::Half(dst)) => {
120                *dst = src.clone();
121            }
122            (Self::Half(src), Self::Double(dst)) => {
123                for (i, &val) in src.iter().enumerate() {
124                    dst[i] = Complex64::new(val.re as f64, val.im as f64);
125                }
126            }
127            (Self::Double(src), Self::Half(dst)) => {
128                for (i, &val) in src.iter().enumerate() {
129                    dst[i] = Complex32::new(val.re as f32, val.im as f32);
130                }
131            }
132            _ => {
133                return Err(SimulatorError::UnsupportedOperation(
134                    "Complex precision conversion not supported".to_string(),
135                ));
136            }
137        }
138
139        Ok(result)
140    }
141
142    /// Normalize the state vector
143    pub fn normalize(&mut self) -> Result<()> {
144        let norm = self.norm();
145        if norm == 0.0 {
146            return Err(SimulatorError::InvalidInput(
147                "Cannot normalize zero vector".to_string(),
148            ));
149        }
150
151        match self {
152            Self::Half(arr) => {
153                let norm_f32 = norm as f32;
154                for val in arr.iter_mut() {
155                    *val /= norm_f32;
156                }
157            }
158            Self::Single(arr) => {
159                let norm_f32 = norm as f32;
160                for val in arr.iter_mut() {
161                    *val /= norm_f32;
162                }
163            }
164            Self::Double(arr) => {
165                for val in arr.iter_mut() {
166                    *val /= norm;
167                }
168            }
169            Self::Adaptive {
170                ref mut primary, ..
171            } => {
172                primary.normalize()?;
173            }
174        }
175
176        Ok(())
177    }
178
179    /// Calculate the L2 norm of the state vector
180    pub fn norm(&self) -> f64 {
181        match self {
182            Self::Half(arr) => arr.iter().map(|x| x.norm_sqr() as f64).sum::<f64>().sqrt(),
183            Self::Single(arr) => arr.iter().map(|x| x.norm_sqr() as f64).sum::<f64>().sqrt(),
184            Self::Double(arr) => arr.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt(),
185            Self::Adaptive { primary, .. } => primary.norm(),
186        }
187    }
188
189    /// Calculate probability of measuring a specific state
190    pub fn probability(&self, index: usize) -> Result<f64> {
191        if index >= self.len() {
192            return Err(SimulatorError::InvalidInput(format!(
193                "Index {} out of bounds for state vector of length {}",
194                index,
195                self.len()
196            )));
197        }
198
199        let prob = match self {
200            Self::Half(arr) => arr[index].norm_sqr() as f64,
201            Self::Single(arr) => arr[index].norm_sqr() as f64,
202            Self::Double(arr) => arr[index].norm_sqr(),
203            Self::Adaptive { primary, .. } => primary.probability(index)?,
204        };
205
206        Ok(prob)
207    }
208
209    /// Get amplitude at a specific index as Complex64
210    pub fn amplitude(&self, index: usize) -> Result<Complex64> {
211        if index >= self.len() {
212            return Err(SimulatorError::InvalidInput(format!(
213                "Index {} out of bounds for state vector of length {}",
214                index,
215                self.len()
216            )));
217        }
218
219        let amplitude = match self {
220            Self::Half(arr) => {
221                let val = arr[index];
222                Complex64::new(val.re as f64, val.im as f64)
223            }
224            Self::Single(arr) => {
225                let val = arr[index];
226                Complex64::new(val.re as f64, val.im as f64)
227            }
228            Self::Double(arr) => arr[index],
229            Self::Adaptive { primary, .. } => primary.amplitude(index)?,
230        };
231
232        Ok(amplitude)
233    }
234
235    /// Set amplitude at a specific index
236    pub fn set_amplitude(&mut self, index: usize, amplitude: Complex64) -> Result<()> {
237        if index >= self.len() {
238            return Err(SimulatorError::InvalidInput(format!(
239                "Index {} out of bounds for state vector of length {}",
240                index,
241                self.len()
242            )));
243        }
244
245        match self {
246            Self::Half(arr) => {
247                arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
248            }
249            Self::Single(arr) => {
250                arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
251            }
252            Self::Double(arr) => {
253                arr[index] = amplitude;
254            }
255            Self::Adaptive {
256                ref mut primary, ..
257            } => {
258                primary.set_amplitude(index, amplitude)?;
259            }
260        }
261
262        Ok(())
263    }
264
265    /// Calculate fidelity with another state vector
266    pub fn fidelity(&self, other: &Self) -> Result<f64> {
267        if self.len() != other.len() {
268            return Err(SimulatorError::InvalidInput(
269                "State vectors must have the same length for fidelity calculation".to_string(),
270            ));
271        }
272
273        let mut inner_product = Complex64::new(0.0, 0.0);
274
275        for i in 0..self.len() {
276            let amp1 = self.amplitude(i)?;
277            let amp2 = other.amplitude(i)?;
278            inner_product += amp1.conj() * amp2;
279        }
280
281        Ok(inner_product.norm_sqr())
282    }
283
284    /// Clone the state vector to a specific precision
285    pub fn clone_to_precision(&self, precision: QuantumPrecision) -> Result<Self> {
286        self.to_precision(precision)
287    }
288
289    /// Estimate memory usage in bytes
290    pub fn memory_usage(&self) -> usize {
291        match self {
292            Self::Half(arr) => arr.len() * std::mem::size_of::<Complex32>(),
293            Self::Single(arr) => arr.len() * std::mem::size_of::<Complex32>(),
294            Self::Double(arr) => arr.len() * std::mem::size_of::<Complex64>(),
295            Self::Adaptive {
296                primary, secondary, ..
297            } => {
298                let mut usage = primary.memory_usage();
299                if let Some(sec) = secondary {
300                    usage += sec.memory_usage();
301                }
302                usage += std::mem::size_of::<QuantumPrecision>() * primary.len(); // precision_map
303                usage
304            }
305        }
306    }
307
308    /// Check if the state vector is normalized (within tolerance)
309    pub fn is_normalized(&self, tolerance: f64) -> bool {
310        (self.norm() - 1.0).abs() < tolerance
311    }
312
313    /// Get the number of qubits this state vector represents
314    pub fn num_qubits(&self) -> usize {
315        (self.len() as f64).log2() as usize
316    }
317}
318
319impl Clone for MixedPrecisionStateVector {
320    fn clone(&self) -> Self {
321        match self {
322            Self::Half(arr) => Self::Half(arr.clone()),
323            Self::Single(arr) => Self::Single(arr.clone()),
324            Self::Double(arr) => Self::Double(arr.clone()),
325            Self::Adaptive {
326                primary,
327                secondary,
328                precision_map,
329            } => Self::Adaptive {
330                primary: primary.clone(),
331                secondary: secondary.clone(),
332                precision_map: precision_map.clone(),
333            },
334        }
335    }
336}
337
338impl std::fmt::Debug for MixedPrecisionStateVector {
339    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340        match self {
341            Self::Half(arr) => write!(f, "Half({} elements)", arr.len()),
342            Self::Single(arr) => write!(f, "Single({} elements)", arr.len()),
343            Self::Double(arr) => write!(f, "Double({} elements)", arr.len()),
344            Self::Adaptive {
345                primary, secondary, ..
346            } => {
347                write!(
348                    f,
349                    "Adaptive(primary: {:?}, secondary: {:?})",
350                    primary, secondary
351                )
352            }
353        }
354    }
355}