quantrs2_sim/mixed_precision_impl/
state_vector.rs

//! Mixed-precision state vector implementations for quantum simulation.
//!
//! This module provides state vector representations that can be converted
//! between different numerical precisions, trading accuracy requirements
//! against memory and performance constraints.
6
7use scirs2_core::ndarray::Array1;
8use scirs2_core::{Complex32, Complex64};
9
10use super::config::QuantumPrecision;
11use crate::error::{Result, SimulatorError};
12
13/// Mixed-precision state vector
14pub enum MixedPrecisionStateVector {
15    /// Half precision state vector (using Complex32 as approximation)
16    Half(Array1<Complex32>),
17    /// BFloat16 precision state vector (using Complex32 as storage)
18    BFloat16(Array1<Complex32>),
19    /// TF32 precision state vector (stored as Complex32)
20    TF32(Array1<Complex32>),
21    /// Single precision state vector
22    Single(Array1<Complex32>),
23    /// Double precision state vector
24    Double(Array1<Complex64>),
25    /// Adaptive precision with multiple representations
26    Adaptive {
27        primary: Box<Self>,
28        secondary: Option<Box<Self>>,
29        precision_map: Vec<QuantumPrecision>,
30    },
31}
32
33impl MixedPrecisionStateVector {
34    /// Create a new state vector with the specified precision
35    #[must_use]
36    pub fn new(size: usize, precision: QuantumPrecision) -> Self {
37        match precision {
38            QuantumPrecision::Half => Self::Half(Array1::zeros(size)),
39            QuantumPrecision::BFloat16 => Self::BFloat16(Array1::zeros(size)),
40            QuantumPrecision::TF32 => Self::TF32(Array1::zeros(size)),
41            QuantumPrecision::Single => Self::Single(Array1::zeros(size)),
42            QuantumPrecision::Double => Self::Double(Array1::zeros(size)),
43            QuantumPrecision::Adaptive => {
44                // Start with single precision for adaptive
45                let primary = Box::new(Self::Single(Array1::zeros(size)));
46                Self::Adaptive {
47                    primary,
48                    secondary: None,
49                    precision_map: vec![QuantumPrecision::Single; size],
50                }
51            }
52        }
53    }
54
55    /// Create a computational basis state |0...0>
56    #[must_use]
57    pub fn computational_basis(num_qubits: usize, precision: QuantumPrecision) -> Self {
58        let size = 1 << num_qubits;
59        let mut state = Self::new(size, precision);
60
61        // Set |0...0> state
62        match &mut state {
63            Self::Half(ref mut arr)
64            | Self::BFloat16(ref mut arr)
65            | Self::TF32(ref mut arr)
66            | Self::Single(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
67            Self::Double(ref mut arr) => arr[0] = Complex64::new(1.0, 0.0),
68            Self::Adaptive {
69                ref mut primary, ..
70            } => {
71                **primary = Self::computational_basis(num_qubits, QuantumPrecision::Single);
72            }
73        }
74
75        state
76    }
77
78    /// Get the length of the state vector
79    #[must_use]
80    pub fn len(&self) -> usize {
81        match self {
82            Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
83                arr.len()
84            }
85            Self::Double(arr) => arr.len(),
86            Self::Adaptive { primary, .. } => primary.len(),
87        }
88    }
89
90    /// Check if the state vector is empty
91    #[must_use]
92    pub fn is_empty(&self) -> bool {
93        self.len() == 0
94    }
95
96    /// Get the current precision of the state vector
97    #[must_use]
98    pub const fn precision(&self) -> QuantumPrecision {
99        match self {
100            Self::Half(_) => QuantumPrecision::Half,
101            Self::BFloat16(_) => QuantumPrecision::BFloat16,
102            Self::TF32(_) => QuantumPrecision::TF32,
103            Self::Single(_) => QuantumPrecision::Single,
104            Self::Double(_) => QuantumPrecision::Double,
105            Self::Adaptive { .. } => QuantumPrecision::Adaptive,
106        }
107    }
108
109    /// Convert to a specific precision
110    pub fn to_precision(&self, target_precision: QuantumPrecision) -> Result<Self> {
111        if self.precision() == target_precision {
112            return Ok(self.clone());
113        }
114
115        let size = self.len();
116        let mut result = Self::new(size, target_precision);
117
118        match (self, &mut result) {
119            (Self::Single(src), Self::Double(dst)) => {
120                for (i, &val) in src.iter().enumerate() {
121                    dst[i] = Complex64::new(f64::from(val.re), f64::from(val.im));
122                }
123            }
124            (Self::Double(src), Self::Single(dst)) => {
125                for (i, &val) in src.iter().enumerate() {
126                    dst[i] = Complex32::new(val.re as f32, val.im as f32);
127                }
128            }
129            (Self::Half(src), Self::Single(dst)) => {
130                dst.clone_from(src);
131            }
132            (Self::Single(src), Self::Half(dst)) => {
133                dst.clone_from(src);
134            }
135            (Self::Half(src), Self::Double(dst)) => {
136                for (i, &val) in src.iter().enumerate() {
137                    dst[i] = Complex64::new(f64::from(val.re), f64::from(val.im));
138                }
139            }
140            (Self::Double(src), Self::Half(dst)) => {
141                for (i, &val) in src.iter().enumerate() {
142                    dst[i] = Complex32::new(val.re as f32, val.im as f32);
143                }
144            }
145            _ => {
146                return Err(SimulatorError::UnsupportedOperation(
147                    "Complex precision conversion not supported".to_string(),
148                ));
149            }
150        }
151
152        Ok(result)
153    }
154
155    /// Normalize the state vector
156    pub fn normalize(&mut self) -> Result<()> {
157        let norm = self.norm();
158        if norm == 0.0 {
159            return Err(SimulatorError::InvalidInput(
160                "Cannot normalize zero vector".to_string(),
161            ));
162        }
163
164        match self {
165            Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
166                let norm_f32 = norm as f32;
167                for val in arr.iter_mut() {
168                    *val /= norm_f32;
169                }
170            }
171            Self::Double(arr) => {
172                for val in arr.iter_mut() {
173                    *val /= norm;
174                }
175            }
176            Self::Adaptive {
177                ref mut primary, ..
178            } => {
179                primary.normalize()?;
180            }
181        }
182
183        Ok(())
184    }
185
186    /// Calculate the L2 norm of the state vector
187    #[must_use]
188    pub fn norm(&self) -> f64 {
189        match self {
190            Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => arr
191                .iter()
192                .map(|x| f64::from(x.norm_sqr()))
193                .sum::<f64>()
194                .sqrt(),
195            Self::Double(arr) => arr
196                .iter()
197                .map(scirs2_core::Complex::norm_sqr)
198                .sum::<f64>()
199                .sqrt(),
200            Self::Adaptive { primary, .. } => primary.norm(),
201        }
202    }
203
204    /// Calculate probability of measuring a specific state
205    pub fn probability(&self, index: usize) -> Result<f64> {
206        if index >= self.len() {
207            return Err(SimulatorError::InvalidInput(format!(
208                "Index {} out of bounds for state vector of length {}",
209                index,
210                self.len()
211            )));
212        }
213
214        let prob = match self {
215            Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
216                f64::from(arr[index].norm_sqr())
217            }
218            Self::Double(arr) => arr[index].norm_sqr(),
219            Self::Adaptive { primary, .. } => primary.probability(index)?,
220        };
221
222        Ok(prob)
223    }
224
225    /// Get amplitude at a specific index as Complex64
226    pub fn amplitude(&self, index: usize) -> Result<Complex64> {
227        if index >= self.len() {
228            return Err(SimulatorError::InvalidInput(format!(
229                "Index {} out of bounds for state vector of length {}",
230                index,
231                self.len()
232            )));
233        }
234
235        let amplitude = match self {
236            Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
237                let val = arr[index];
238                Complex64::new(f64::from(val.re), f64::from(val.im))
239            }
240            Self::Double(arr) => arr[index],
241            Self::Adaptive { primary, .. } => primary.amplitude(index)?,
242        };
243
244        Ok(amplitude)
245    }
246
247    /// Set amplitude at a specific index
248    pub fn set_amplitude(&mut self, index: usize, amplitude: Complex64) -> Result<()> {
249        if index >= self.len() {
250            return Err(SimulatorError::InvalidInput(format!(
251                "Index {} out of bounds for state vector of length {}",
252                index,
253                self.len()
254            )));
255        }
256
257        match self {
258            Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
259                arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
260            }
261            Self::Double(arr) => {
262                arr[index] = amplitude;
263            }
264            Self::Adaptive {
265                ref mut primary, ..
266            } => {
267                primary.set_amplitude(index, amplitude)?;
268            }
269        }
270
271        Ok(())
272    }
273
274    /// Calculate fidelity with another state vector
275    pub fn fidelity(&self, other: &Self) -> Result<f64> {
276        if self.len() != other.len() {
277            return Err(SimulatorError::InvalidInput(
278                "State vectors must have the same length for fidelity calculation".to_string(),
279            ));
280        }
281
282        let mut inner_product = Complex64::new(0.0, 0.0);
283
284        for i in 0..self.len() {
285            let amp1 = self.amplitude(i)?;
286            let amp2 = other.amplitude(i)?;
287            inner_product += amp1.conj() * amp2;
288        }
289
290        Ok(inner_product.norm_sqr())
291    }
292
293    /// Clone the state vector to a specific precision
294    pub fn clone_to_precision(&self, precision: QuantumPrecision) -> Result<Self> {
295        self.to_precision(precision)
296    }
297
298    /// Estimate memory usage in bytes
299    #[must_use]
300    pub fn memory_usage(&self) -> usize {
301        match self {
302            Self::Half(arr) | Self::BFloat16(arr) | Self::TF32(arr) | Self::Single(arr) => {
303                arr.len() * std::mem::size_of::<Complex32>()
304            }
305            Self::Double(arr) => arr.len() * std::mem::size_of::<Complex64>(),
306            Self::Adaptive {
307                primary, secondary, ..
308            } => {
309                let mut usage = primary.memory_usage();
310                if let Some(sec) = secondary {
311                    usage += sec.memory_usage();
312                }
313                usage += std::mem::size_of::<QuantumPrecision>() * primary.len(); // precision_map
314                usage
315            }
316        }
317    }
318
319    /// Check if the state vector is normalized (within tolerance)
320    #[must_use]
321    pub fn is_normalized(&self, tolerance: f64) -> bool {
322        (self.norm() - 1.0).abs() < tolerance
323    }
324
325    /// Get the number of qubits this state vector represents
326    #[must_use]
327    pub fn num_qubits(&self) -> usize {
328        (self.len() as f64).log2() as usize
329    }
330}
331
332impl Clone for MixedPrecisionStateVector {
333    fn clone(&self) -> Self {
334        match self {
335            Self::Half(arr) => Self::Half(arr.clone()),
336            Self::BFloat16(arr) => Self::BFloat16(arr.clone()),
337            Self::TF32(arr) => Self::TF32(arr.clone()),
338            Self::Single(arr) => Self::Single(arr.clone()),
339            Self::Double(arr) => Self::Double(arr.clone()),
340            Self::Adaptive {
341                primary,
342                secondary,
343                precision_map,
344            } => Self::Adaptive {
345                primary: primary.clone(),
346                secondary: secondary.clone(),
347                precision_map: precision_map.clone(),
348            },
349        }
350    }
351}
352
353impl std::fmt::Debug for MixedPrecisionStateVector {
354    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        match self {
356            Self::Half(arr) => write!(f, "Half({} elements)", arr.len()),
357            Self::BFloat16(arr) => write!(f, "BFloat16({} elements)", arr.len()),
358            Self::TF32(arr) => write!(f, "TF32({} elements)", arr.len()),
359            Self::Single(arr) => write!(f, "Single({} elements)", arr.len()),
360            Self::Double(arr) => write!(f, "Double({} elements)", arr.len()),
361            Self::Adaptive {
362                primary, secondary, ..
363            } => {
364                write!(
365                    f,
366                    "Adaptive(primary: {primary:?}, secondary: {secondary:?})"
367                )
368            }
369        }
370    }
371}