// quantrs2_sim/mixed_precision_impl/state_vector.rs

//! Mixed-precision state vector implementations for quantum simulation.
//!
//! This module provides state vector representations that can dynamically
//! switch between different numerical precisions based on accuracy requirements
//! and performance constraints.

use scirs2_core::ndarray::Array1;
use scirs2_core::{Complex32, Complex64};

use super::config::QuantumPrecision;
use crate::error::{Result, SimulatorError};

/// Mixed-precision state vector
///
/// Each variant owns the complex amplitudes of a quantum state at one numerical
/// precision; `Adaptive` wraps other state vectors instead of owning an array directly.
pub enum MixedPrecisionStateVector {
    /// Half precision state vector (using Complex32 as approximation)
    ///
    /// NOTE: storage is `Complex32`, exactly like `Single`; no 16-bit type is used,
    /// so this variant records the requested precision but does not save memory.
    Half(Array1<Complex32>),
    /// Single precision state vector (`f32` real and imaginary parts).
    Single(Array1<Complex32>),
    /// Double precision state vector (`f64` real and imaginary parts).
    Double(Array1<Complex64>),
    /// Adaptive precision with multiple representations
    Adaptive {
        /// Representation that length, norm, probability and amplitude queries on the
        /// enclosing value delegate to.
        primary: Box<Self>,
        /// Optional extra representation; only cloned, formatted and counted in memory
        /// usage by the accessors on this type — TODO confirm its intended role.
        secondary: Option<Box<Self>>,
        /// Per-amplitude precision tags. `new` creates one `Single` entry per amplitude,
        /// but the type does not enforce that length — presumably it should match
        /// `primary`'s length; verify against callers.
        precision_map: Vec<QuantumPrecision>,
    },
}

29impl MixedPrecisionStateVector {
30    /// Create a new state vector with the specified precision
31    #[must_use]
32    pub fn new(size: usize, precision: QuantumPrecision) -> Self {
33        match precision {
34            QuantumPrecision::Half => Self::Half(Array1::zeros(size)),
35            QuantumPrecision::Single => Self::Single(Array1::zeros(size)),
36            QuantumPrecision::Double => Self::Double(Array1::zeros(size)),
37            QuantumPrecision::Adaptive => {
38                // Start with single precision for adaptive
39                let primary = Box::new(Self::Single(Array1::zeros(size)));
40                Self::Adaptive {
41                    primary,
42                    secondary: None,
43                    precision_map: vec![QuantumPrecision::Single; size],
44                }
45            }
46        }
47    }
48
49    /// Create a computational basis state |0...0>
50    #[must_use]
51    pub fn computational_basis(num_qubits: usize, precision: QuantumPrecision) -> Self {
52        let size = 1 << num_qubits;
53        let mut state = Self::new(size, precision);
54
55        // Set |0...0> state
56        match &mut state {
57            Self::Half(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
58            Self::Single(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
59            Self::Double(ref mut arr) => arr[0] = Complex64::new(1.0, 0.0),
60            Self::Adaptive {
61                ref mut primary, ..
62            } => {
63                **primary = Self::computational_basis(num_qubits, QuantumPrecision::Single);
64            }
65        }
66
67        state
68    }
69
70    /// Get the length of the state vector
71    #[must_use]
72    pub fn len(&self) -> usize {
73        match self {
74            Self::Half(arr) => arr.len(),
75            Self::Single(arr) => arr.len(),
76            Self::Double(arr) => arr.len(),
77            Self::Adaptive { primary, .. } => primary.len(),
78        }
79    }
80
81    /// Check if the state vector is empty
82    #[must_use]
83    pub fn is_empty(&self) -> bool {
84        self.len() == 0
85    }
86
87    /// Get the current precision of the state vector
88    #[must_use]
89    pub const fn precision(&self) -> QuantumPrecision {
90        match self {
91            Self::Half(_) => QuantumPrecision::Half,
92            Self::Single(_) => QuantumPrecision::Single,
93            Self::Double(_) => QuantumPrecision::Double,
94            Self::Adaptive { .. } => QuantumPrecision::Adaptive,
95        }
96    }
97
98    /// Convert to a specific precision
99    pub fn to_precision(&self, target_precision: QuantumPrecision) -> Result<Self> {
100        if self.precision() == target_precision {
101            return Ok(self.clone());
102        }
103
104        let size = self.len();
105        let mut result = Self::new(size, target_precision);
106
107        match (self, &mut result) {
108            (Self::Single(src), Self::Double(dst)) => {
109                for (i, &val) in src.iter().enumerate() {
110                    dst[i] = Complex64::new(f64::from(val.re), f64::from(val.im));
111                }
112            }
113            (Self::Double(src), Self::Single(dst)) => {
114                for (i, &val) in src.iter().enumerate() {
115                    dst[i] = Complex32::new(val.re as f32, val.im as f32);
116                }
117            }
118            (Self::Half(src), Self::Single(dst)) => {
119                dst.clone_from(src);
120            }
121            (Self::Single(src), Self::Half(dst)) => {
122                dst.clone_from(src);
123            }
124            (Self::Half(src), Self::Double(dst)) => {
125                for (i, &val) in src.iter().enumerate() {
126                    dst[i] = Complex64::new(f64::from(val.re), f64::from(val.im));
127                }
128            }
129            (Self::Double(src), Self::Half(dst)) => {
130                for (i, &val) in src.iter().enumerate() {
131                    dst[i] = Complex32::new(val.re as f32, val.im as f32);
132                }
133            }
134            _ => {
135                return Err(SimulatorError::UnsupportedOperation(
136                    "Complex precision conversion not supported".to_string(),
137                ));
138            }
139        }
140
141        Ok(result)
142    }
143
144    /// Normalize the state vector
145    pub fn normalize(&mut self) -> Result<()> {
146        let norm = self.norm();
147        if norm == 0.0 {
148            return Err(SimulatorError::InvalidInput(
149                "Cannot normalize zero vector".to_string(),
150            ));
151        }
152
153        match self {
154            Self::Half(arr) => {
155                let norm_f32 = norm as f32;
156                for val in arr.iter_mut() {
157                    *val /= norm_f32;
158                }
159            }
160            Self::Single(arr) => {
161                let norm_f32 = norm as f32;
162                for val in arr.iter_mut() {
163                    *val /= norm_f32;
164                }
165            }
166            Self::Double(arr) => {
167                for val in arr.iter_mut() {
168                    *val /= norm;
169                }
170            }
171            Self::Adaptive {
172                ref mut primary, ..
173            } => {
174                primary.normalize()?;
175            }
176        }
177
178        Ok(())
179    }
180
181    /// Calculate the L2 norm of the state vector
182    #[must_use]
183    pub fn norm(&self) -> f64 {
184        match self {
185            Self::Half(arr) => arr
186                .iter()
187                .map(|x| f64::from(x.norm_sqr()))
188                .sum::<f64>()
189                .sqrt(),
190            Self::Single(arr) => arr
191                .iter()
192                .map(|x| f64::from(x.norm_sqr()))
193                .sum::<f64>()
194                .sqrt(),
195            Self::Double(arr) => arr
196                .iter()
197                .map(scirs2_core::Complex::norm_sqr)
198                .sum::<f64>()
199                .sqrt(),
200            Self::Adaptive { primary, .. } => primary.norm(),
201        }
202    }
203
204    /// Calculate probability of measuring a specific state
205    pub fn probability(&self, index: usize) -> Result<f64> {
206        if index >= self.len() {
207            return Err(SimulatorError::InvalidInput(format!(
208                "Index {} out of bounds for state vector of length {}",
209                index,
210                self.len()
211            )));
212        }
213
214        let prob = match self {
215            Self::Half(arr) => f64::from(arr[index].norm_sqr()),
216            Self::Single(arr) => f64::from(arr[index].norm_sqr()),
217            Self::Double(arr) => arr[index].norm_sqr(),
218            Self::Adaptive { primary, .. } => primary.probability(index)?,
219        };
220
221        Ok(prob)
222    }
223
224    /// Get amplitude at a specific index as Complex64
225    pub fn amplitude(&self, index: usize) -> Result<Complex64> {
226        if index >= self.len() {
227            return Err(SimulatorError::InvalidInput(format!(
228                "Index {} out of bounds for state vector of length {}",
229                index,
230                self.len()
231            )));
232        }
233
234        let amplitude = match self {
235            Self::Half(arr) => {
236                let val = arr[index];
237                Complex64::new(f64::from(val.re), f64::from(val.im))
238            }
239            Self::Single(arr) => {
240                let val = arr[index];
241                Complex64::new(f64::from(val.re), f64::from(val.im))
242            }
243            Self::Double(arr) => arr[index],
244            Self::Adaptive { primary, .. } => primary.amplitude(index)?,
245        };
246
247        Ok(amplitude)
248    }
249
250    /// Set amplitude at a specific index
251    pub fn set_amplitude(&mut self, index: usize, amplitude: Complex64) -> Result<()> {
252        if index >= self.len() {
253            return Err(SimulatorError::InvalidInput(format!(
254                "Index {} out of bounds for state vector of length {}",
255                index,
256                self.len()
257            )));
258        }
259
260        match self {
261            Self::Half(arr) => {
262                arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
263            }
264            Self::Single(arr) => {
265                arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
266            }
267            Self::Double(arr) => {
268                arr[index] = amplitude;
269            }
270            Self::Adaptive {
271                ref mut primary, ..
272            } => {
273                primary.set_amplitude(index, amplitude)?;
274            }
275        }
276
277        Ok(())
278    }
279
280    /// Calculate fidelity with another state vector
281    pub fn fidelity(&self, other: &Self) -> Result<f64> {
282        if self.len() != other.len() {
283            return Err(SimulatorError::InvalidInput(
284                "State vectors must have the same length for fidelity calculation".to_string(),
285            ));
286        }
287
288        let mut inner_product = Complex64::new(0.0, 0.0);
289
290        for i in 0..self.len() {
291            let amp1 = self.amplitude(i)?;
292            let amp2 = other.amplitude(i)?;
293            inner_product += amp1.conj() * amp2;
294        }
295
296        Ok(inner_product.norm_sqr())
297    }
298
299    /// Clone the state vector to a specific precision
300    pub fn clone_to_precision(&self, precision: QuantumPrecision) -> Result<Self> {
301        self.to_precision(precision)
302    }
303
304    /// Estimate memory usage in bytes
305    #[must_use]
306    pub fn memory_usage(&self) -> usize {
307        match self {
308            Self::Half(arr) => arr.len() * std::mem::size_of::<Complex32>(),
309            Self::Single(arr) => arr.len() * std::mem::size_of::<Complex32>(),
310            Self::Double(arr) => arr.len() * std::mem::size_of::<Complex64>(),
311            Self::Adaptive {
312                primary, secondary, ..
313            } => {
314                let mut usage = primary.memory_usage();
315                if let Some(sec) = secondary {
316                    usage += sec.memory_usage();
317                }
318                usage += std::mem::size_of::<QuantumPrecision>() * primary.len(); // precision_map
319                usage
320            }
321        }
322    }
323
324    /// Check if the state vector is normalized (within tolerance)
325    #[must_use]
326    pub fn is_normalized(&self, tolerance: f64) -> bool {
327        (self.norm() - 1.0).abs() < tolerance
328    }
329
330    /// Get the number of qubits this state vector represents
331    #[must_use]
332    pub fn num_qubits(&self) -> usize {
333        (self.len() as f64).log2() as usize
334    }
335}
336
337impl Clone for MixedPrecisionStateVector {
338    fn clone(&self) -> Self {
339        match self {
340            Self::Half(arr) => Self::Half(arr.clone()),
341            Self::Single(arr) => Self::Single(arr.clone()),
342            Self::Double(arr) => Self::Double(arr.clone()),
343            Self::Adaptive {
344                primary,
345                secondary,
346                precision_map,
347            } => Self::Adaptive {
348                primary: primary.clone(),
349                secondary: secondary.clone(),
350                precision_map: precision_map.clone(),
351            },
352        }
353    }
354}
355
356impl std::fmt::Debug for MixedPrecisionStateVector {
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        match self {
359            Self::Half(arr) => write!(f, "Half({} elements)", arr.len()),
360            Self::Single(arr) => write!(f, "Single({} elements)", arr.len()),
361            Self::Double(arr) => write!(f, "Double({} elements)", arr.len()),
362            Self::Adaptive {
363                primary, secondary, ..
364            } => {
365                write!(
366                    f,
367                    "Adaptive(primary: {primary:?}, secondary: {secondary:?})"
368                )
369            }
370        }
371    }
372}