quantrs2_sim/mixed_precision_impl/state_vector.rs

//! Mixed-precision state vector implementations for quantum simulation.
//!
//! This module provides state vector representations that can dynamically
//! switch between different numerical precisions based on accuracy requirements
//! and performance constraints.
6
7use scirs2_core::ndarray::Array1;
8use scirs2_core::{Complex32, Complex64};
9
10use super::config::QuantumPrecision;
11use crate::error::{Result, SimulatorError};
12
13/// Mixed-precision state vector
14pub enum MixedPrecisionStateVector {
15    /// Half precision state vector (using Complex32 as approximation)
16    Half(Array1<Complex32>),
17    /// Single precision state vector
18    Single(Array1<Complex32>),
19    /// Double precision state vector
20    Double(Array1<Complex64>),
21    /// Adaptive precision with multiple representations
22    Adaptive {
23        primary: Box<Self>,
24        secondary: Option<Box<Self>>,
25        precision_map: Vec<QuantumPrecision>,
26    },
27}
28
29impl MixedPrecisionStateVector {
30    /// Create a new state vector with the specified precision
31    pub fn new(size: usize, precision: QuantumPrecision) -> Self {
32        match precision {
33            QuantumPrecision::Half => Self::Half(Array1::zeros(size)),
34            QuantumPrecision::Single => Self::Single(Array1::zeros(size)),
35            QuantumPrecision::Double => Self::Double(Array1::zeros(size)),
36            QuantumPrecision::Adaptive => {
37                // Start with single precision for adaptive
38                let primary = Box::new(Self::Single(Array1::zeros(size)));
39                Self::Adaptive {
40                    primary,
41                    secondary: None,
42                    precision_map: vec![QuantumPrecision::Single; size],
43                }
44            }
45        }
46    }
47
48    /// Create a computational basis state |0...0>
49    pub fn computational_basis(num_qubits: usize, precision: QuantumPrecision) -> Self {
50        let size = 1 << num_qubits;
51        let mut state = Self::new(size, precision);
52
53        // Set |0...0> state
54        match &mut state {
55            Self::Half(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
56            Self::Single(ref mut arr) => arr[0] = Complex32::new(1.0, 0.0),
57            Self::Double(ref mut arr) => arr[0] = Complex64::new(1.0, 0.0),
58            Self::Adaptive {
59                ref mut primary, ..
60            } => {
61                **primary = Self::computational_basis(num_qubits, QuantumPrecision::Single);
62            }
63        }
64
65        state
66    }
67
68    /// Get the length of the state vector
69    pub fn len(&self) -> usize {
70        match self {
71            Self::Half(arr) => arr.len(),
72            Self::Single(arr) => arr.len(),
73            Self::Double(arr) => arr.len(),
74            Self::Adaptive { primary, .. } => primary.len(),
75        }
76    }
77
78    /// Check if the state vector is empty
79    pub fn is_empty(&self) -> bool {
80        self.len() == 0
81    }
82
83    /// Get the current precision of the state vector
84    pub const fn precision(&self) -> QuantumPrecision {
85        match self {
86            Self::Half(_) => QuantumPrecision::Half,
87            Self::Single(_) => QuantumPrecision::Single,
88            Self::Double(_) => QuantumPrecision::Double,
89            Self::Adaptive { .. } => QuantumPrecision::Adaptive,
90        }
91    }
92
93    /// Convert to a specific precision
94    pub fn to_precision(&self, target_precision: QuantumPrecision) -> Result<Self> {
95        if self.precision() == target_precision {
96            return Ok(self.clone());
97        }
98
99        let size = self.len();
100        let mut result = Self::new(size, target_precision);
101
102        match (self, &mut result) {
103            (Self::Single(src), Self::Double(dst)) => {
104                for (i, &val) in src.iter().enumerate() {
105                    dst[i] = Complex64::new(val.re as f64, val.im as f64);
106                }
107            }
108            (Self::Double(src), Self::Single(dst)) => {
109                for (i, &val) in src.iter().enumerate() {
110                    dst[i] = Complex32::new(val.re as f32, val.im as f32);
111                }
112            }
113            (Self::Half(src), Self::Single(dst)) => {
114                *dst = src.clone();
115            }
116            (Self::Single(src), Self::Half(dst)) => {
117                *dst = src.clone();
118            }
119            (Self::Half(src), Self::Double(dst)) => {
120                for (i, &val) in src.iter().enumerate() {
121                    dst[i] = Complex64::new(val.re as f64, val.im as f64);
122                }
123            }
124            (Self::Double(src), Self::Half(dst)) => {
125                for (i, &val) in src.iter().enumerate() {
126                    dst[i] = Complex32::new(val.re as f32, val.im as f32);
127                }
128            }
129            _ => {
130                return Err(SimulatorError::UnsupportedOperation(
131                    "Complex precision conversion not supported".to_string(),
132                ));
133            }
134        }
135
136        Ok(result)
137    }
138
139    /// Normalize the state vector
140    pub fn normalize(&mut self) -> Result<()> {
141        let norm = self.norm();
142        if norm == 0.0 {
143            return Err(SimulatorError::InvalidInput(
144                "Cannot normalize zero vector".to_string(),
145            ));
146        }
147
148        match self {
149            Self::Half(arr) => {
150                let norm_f32 = norm as f32;
151                for val in arr.iter_mut() {
152                    *val /= norm_f32;
153                }
154            }
155            Self::Single(arr) => {
156                let norm_f32 = norm as f32;
157                for val in arr.iter_mut() {
158                    *val /= norm_f32;
159                }
160            }
161            Self::Double(arr) => {
162                for val in arr.iter_mut() {
163                    *val /= norm;
164                }
165            }
166            Self::Adaptive {
167                ref mut primary, ..
168            } => {
169                primary.normalize()?;
170            }
171        }
172
173        Ok(())
174    }
175
176    /// Calculate the L2 norm of the state vector
177    pub fn norm(&self) -> f64 {
178        match self {
179            Self::Half(arr) => arr.iter().map(|x| x.norm_sqr() as f64).sum::<f64>().sqrt(),
180            Self::Single(arr) => arr.iter().map(|x| x.norm_sqr() as f64).sum::<f64>().sqrt(),
181            Self::Double(arr) => arr.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt(),
182            Self::Adaptive { primary, .. } => primary.norm(),
183        }
184    }
185
186    /// Calculate probability of measuring a specific state
187    pub fn probability(&self, index: usize) -> Result<f64> {
188        if index >= self.len() {
189            return Err(SimulatorError::InvalidInput(format!(
190                "Index {} out of bounds for state vector of length {}",
191                index,
192                self.len()
193            )));
194        }
195
196        let prob = match self {
197            Self::Half(arr) => arr[index].norm_sqr() as f64,
198            Self::Single(arr) => arr[index].norm_sqr() as f64,
199            Self::Double(arr) => arr[index].norm_sqr(),
200            Self::Adaptive { primary, .. } => primary.probability(index)?,
201        };
202
203        Ok(prob)
204    }
205
206    /// Get amplitude at a specific index as Complex64
207    pub fn amplitude(&self, index: usize) -> Result<Complex64> {
208        if index >= self.len() {
209            return Err(SimulatorError::InvalidInput(format!(
210                "Index {} out of bounds for state vector of length {}",
211                index,
212                self.len()
213            )));
214        }
215
216        let amplitude = match self {
217            Self::Half(arr) => {
218                let val = arr[index];
219                Complex64::new(val.re as f64, val.im as f64)
220            }
221            Self::Single(arr) => {
222                let val = arr[index];
223                Complex64::new(val.re as f64, val.im as f64)
224            }
225            Self::Double(arr) => arr[index],
226            Self::Adaptive { primary, .. } => primary.amplitude(index)?,
227        };
228
229        Ok(amplitude)
230    }
231
232    /// Set amplitude at a specific index
233    pub fn set_amplitude(&mut self, index: usize, amplitude: Complex64) -> Result<()> {
234        if index >= self.len() {
235            return Err(SimulatorError::InvalidInput(format!(
236                "Index {} out of bounds for state vector of length {}",
237                index,
238                self.len()
239            )));
240        }
241
242        match self {
243            Self::Half(arr) => {
244                arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
245            }
246            Self::Single(arr) => {
247                arr[index] = Complex32::new(amplitude.re as f32, amplitude.im as f32);
248            }
249            Self::Double(arr) => {
250                arr[index] = amplitude;
251            }
252            Self::Adaptive {
253                ref mut primary, ..
254            } => {
255                primary.set_amplitude(index, amplitude)?;
256            }
257        }
258
259        Ok(())
260    }
261
262    /// Calculate fidelity with another state vector
263    pub fn fidelity(&self, other: &Self) -> Result<f64> {
264        if self.len() != other.len() {
265            return Err(SimulatorError::InvalidInput(
266                "State vectors must have the same length for fidelity calculation".to_string(),
267            ));
268        }
269
270        let mut inner_product = Complex64::new(0.0, 0.0);
271
272        for i in 0..self.len() {
273            let amp1 = self.amplitude(i)?;
274            let amp2 = other.amplitude(i)?;
275            inner_product += amp1.conj() * amp2;
276        }
277
278        Ok(inner_product.norm_sqr())
279    }
280
281    /// Clone the state vector to a specific precision
282    pub fn clone_to_precision(&self, precision: QuantumPrecision) -> Result<Self> {
283        self.to_precision(precision)
284    }
285
286    /// Estimate memory usage in bytes
287    pub fn memory_usage(&self) -> usize {
288        match self {
289            Self::Half(arr) => arr.len() * std::mem::size_of::<Complex32>(),
290            Self::Single(arr) => arr.len() * std::mem::size_of::<Complex32>(),
291            Self::Double(arr) => arr.len() * std::mem::size_of::<Complex64>(),
292            Self::Adaptive {
293                primary, secondary, ..
294            } => {
295                let mut usage = primary.memory_usage();
296                if let Some(sec) = secondary {
297                    usage += sec.memory_usage();
298                }
299                usage += std::mem::size_of::<QuantumPrecision>() * primary.len(); // precision_map
300                usage
301            }
302        }
303    }
304
305    /// Check if the state vector is normalized (within tolerance)
306    pub fn is_normalized(&self, tolerance: f64) -> bool {
307        (self.norm() - 1.0).abs() < tolerance
308    }
309
310    /// Get the number of qubits this state vector represents
311    pub fn num_qubits(&self) -> usize {
312        (self.len() as f64).log2() as usize
313    }
314}
315
316impl Clone for MixedPrecisionStateVector {
317    fn clone(&self) -> Self {
318        match self {
319            Self::Half(arr) => Self::Half(arr.clone()),
320            Self::Single(arr) => Self::Single(arr.clone()),
321            Self::Double(arr) => Self::Double(arr.clone()),
322            Self::Adaptive {
323                primary,
324                secondary,
325                precision_map,
326            } => Self::Adaptive {
327                primary: primary.clone(),
328                secondary: secondary.clone(),
329                precision_map: precision_map.clone(),
330            },
331        }
332    }
333}
334
335impl std::fmt::Debug for MixedPrecisionStateVector {
336    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        match self {
338            Self::Half(arr) => write!(f, "Half({} elements)", arr.len()),
339            Self::Single(arr) => write!(f, "Single({} elements)", arr.len()),
340            Self::Double(arr) => write!(f, "Double({} elements)", arr.len()),
341            Self::Adaptive {
342                primary, secondary, ..
343            } => {
344                write!(
345                    f,
346                    "Adaptive(primary: {primary:?}, secondary: {secondary:?})"
347                )
348            }
349        }
350    }
351}