quantrs2_sim/
precision.rs

1//! Adaptive precision control for quantum state vectors.
2//!
3//! This module provides mechanisms to dynamically adjust numerical precision
4//! based on the requirements of the simulation, enabling efficient memory usage
5//! and computation for large quantum systems.
6
7use crate::prelude::SimulatorError;
8use half::f16;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::{Complex32, Complex64};
11use std::fmt;
12
13use crate::error::Result;
14
/// Numeric precision used to store the amplitudes of a state vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Half precision (16-bit components)
    Half,
    /// Single precision (32-bit components)
    Single,
    /// Double precision (64-bit components)
    Double,
    /// Extended precision (128-bit components) - future support
    Extended,
}

impl Precision {
    /// Number of bytes occupied by one complex amplitude at this precision.
    ///
    /// A complex value stores a real and an imaginary component, so the
    /// result is twice the width of the underlying scalar.
    #[must_use]
    pub const fn bytes_per_complex(&self) -> usize {
        let scalar_bytes = match *self {
            Self::Half => 2,      // f16
            Self::Single => 4,    // f32
            Self::Double => 8,    // f64
            Self::Extended => 16, // f128 (future)
        };
        2 * scalar_bytes
    }

    /// Relative epsilon associated with this precision.
    #[must_use]
    pub const fn epsilon(&self) -> f64 {
        match *self {
            Self::Extended => 1e-30, // ~2^-100
            Self::Double => 1e-15,   // ~2^-52
            Self::Single => 1e-7,    // ~2^-23
            Self::Half => 0.001,     // ~2^-10
        }
    }

    /// Pick the coarsest precision whose epsilon does not exceed `tolerance`.
    ///
    /// Levels are tried from coarsest to finest; when none of `Half`,
    /// `Single` or `Double` qualifies (including a NaN tolerance, for which
    /// every comparison is false) the result is `Extended`.
    #[must_use]
    pub fn from_tolerance(tolerance: f64) -> Self {
        [Self::Half, Self::Single, Self::Double]
            .iter()
            .copied()
            .find(|level| tolerance >= level.epsilon())
            .unwrap_or(Self::Extended)
    }
}
65
/// Trait for types that can represent complex amplitudes
///
/// Implementors must be `Clone + Send + Sync`. `Complex64` acts as the common
/// interchange type: every implementor can be converted to it and built
/// from it.
pub trait ComplexAmplitude: Clone + Send + Sync {
    /// Convert to Complex64 for computation
    fn to_complex64(&self) -> Complex64;

    /// Create from Complex64
    ///
    /// Implementations backed by narrower scalars convert the components of
    /// `c` down to their own scalar type.
    fn from_complex64(c: Complex64) -> Self;

    /// Get norm squared
    ///
    /// The result is always reported as `f64`, whatever the storage type.
    fn norm_sqr(&self) -> f64;

    /// Multiply by scalar
    ///
    /// Scales the amplitude in place by the real factor `factor`.
    fn scale(&mut self, factor: f64);
}
80
81impl ComplexAmplitude for Complex64 {
82    fn to_complex64(&self) -> Complex64 {
83        *self
84    }
85
86    fn from_complex64(c: Complex64) -> Self {
87        c
88    }
89
90    fn norm_sqr(&self) -> f64 {
91        self.norm_sqr()
92    }
93
94    fn scale(&mut self, factor: f64) {
95        *self *= factor;
96    }
97}
98
99impl ComplexAmplitude for Complex32 {
100    fn to_complex64(&self) -> Complex64 {
101        Complex64::new(f64::from(self.re), f64::from(self.im))
102    }
103
104    fn from_complex64(c: Complex64) -> Self {
105        Self::new(c.re as f32, c.im as f32)
106    }
107
108    fn norm_sqr(&self) -> f64 {
109        f64::from(self.re.mul_add(self.re, self.im * self.im))
110    }
111
112    fn scale(&mut self, factor: f64) {
113        *self *= factor as f32;
114    }
115}
116
117/// Half-precision complex number
118#[derive(Debug, Clone, Copy)]
119pub struct ComplexF16 {
120    pub re: f16,
121    pub im: f16,
122}
123
124impl ComplexAmplitude for ComplexF16 {
125    fn to_complex64(&self) -> Complex64 {
126        Complex64::new(self.re.to_f64(), self.im.to_f64())
127    }
128
129    fn from_complex64(c: Complex64) -> Self {
130        Self {
131            re: f16::from_f64(c.re),
132            im: f16::from_f64(c.im),
133        }
134    }
135
136    fn norm_sqr(&self) -> f64 {
137        let r = self.re.to_f64();
138        let i = self.im.to_f64();
139        r.mul_add(r, i * i)
140    }
141
142    fn scale(&mut self, factor: f64) {
143        self.re = f16::from_f64(self.re.to_f64() * factor);
144        self.im = f16::from_f64(self.im.to_f64() * factor);
145    }
146}
147
/// Adaptive precision state vector
///
/// Each variant owns the amplitudes at one storage precision. There is no
/// variant for `Precision::Extended`.
pub enum AdaptiveStateVector {
    /// Amplitudes stored as half-precision `ComplexF16` values
    Half(Array1<ComplexF16>),
    /// Amplitudes stored as single-precision `Complex32` values
    Single(Array1<Complex32>),
    /// Amplitudes stored as double-precision `Complex64` values
    Double(Array1<Complex64>),
}
154
155impl AdaptiveStateVector {
156    /// Create a new state vector with specified precision
157    pub fn new(num_qubits: usize, precision: Precision) -> Result<Self> {
158        let size = 1 << num_qubits;
159
160        if num_qubits > 30 {
161            return Err(SimulatorError::InvalidQubits(num_qubits));
162        }
163
164        match precision {
165            Precision::Half => {
166                let mut state = Array1::from_elem(
167                    size,
168                    ComplexF16 {
169                        re: f16::from_f64(0.0),
170                        im: f16::from_f64(0.0),
171                    },
172                );
173                state[0] = ComplexF16 {
174                    re: f16::from_f64(1.0),
175                    im: f16::from_f64(0.0),
176                };
177                Ok(Self::Half(state))
178            }
179            Precision::Single => {
180                let mut state = Array1::zeros(size);
181                state[0] = Complex32::new(1.0, 0.0);
182                Ok(Self::Single(state))
183            }
184            Precision::Double => {
185                let mut state = Array1::zeros(size);
186                state[0] = Complex64::new(1.0, 0.0);
187                Ok(Self::Double(state))
188            }
189            Precision::Extended => Err(SimulatorError::InvalidConfiguration(
190                "Extended precision not yet supported".to_string(),
191            )),
192        }
193    }
194
195    /// Get current precision
196    #[must_use]
197    pub const fn precision(&self) -> Precision {
198        match self {
199            Self::Half(_) => Precision::Half,
200            Self::Single(_) => Precision::Single,
201            Self::Double(_) => Precision::Double,
202        }
203    }
204
205    /// Get number of qubits
206    #[must_use]
207    pub fn num_qubits(&self) -> usize {
208        let size = match self {
209            Self::Half(v) => v.len(),
210            Self::Single(v) => v.len(),
211            Self::Double(v) => v.len(),
212        };
213        (size as f64).log2() as usize
214    }
215
216    /// Convert to double precision for computation
217    #[must_use]
218    pub fn to_complex64(&self) -> Array1<Complex64> {
219        match self {
220            Self::Half(v) => v.map(ComplexAmplitude::to_complex64),
221            Self::Single(v) => v.map(ComplexAmplitude::to_complex64),
222            Self::Double(v) => v.clone(),
223        }
224    }
225
226    /// Update from double precision
227    pub fn from_complex64(&mut self, data: &Array1<Complex64>) -> Result<()> {
228        match self {
229            Self::Half(v) => {
230                if v.len() != data.len() {
231                    return Err(SimulatorError::DimensionMismatch(format!(
232                        "Size mismatch: {} vs {}",
233                        v.len(),
234                        data.len()
235                    )));
236                }
237                for (i, &c) in data.iter().enumerate() {
238                    v[i] = ComplexF16::from_complex64(c);
239                }
240            }
241            Self::Single(v) => {
242                if v.len() != data.len() {
243                    return Err(SimulatorError::DimensionMismatch(format!(
244                        "Size mismatch: {} vs {}",
245                        v.len(),
246                        data.len()
247                    )));
248                }
249                for (i, &c) in data.iter().enumerate() {
250                    v[i] = Complex32::from_complex64(c);
251                }
252            }
253            Self::Double(v) => {
254                v.assign(data);
255            }
256        }
257        Ok(())
258    }
259
260    /// Check if precision upgrade is needed
261    #[must_use]
262    pub fn needs_precision_upgrade(&self, threshold: f64) -> bool {
263        // Check if small amplitudes might be lost
264        let min_amplitude = match self {
265            Self::Half(v) => v
266                .iter()
267                .map(ComplexAmplitude::norm_sqr)
268                .filter(|&n| n > 0.0)
269                .fold(None, |acc, x| match acc {
270                    None => Some(x),
271                    Some(y) => Some(if x < y { x } else { y }),
272                }),
273            Self::Single(v) => v
274                .iter()
275                .map(|c| f64::from(c.norm_sqr()))
276                .filter(|&n| n > 0.0)
277                .fold(None, |acc, x| match acc {
278                    None => Some(x),
279                    Some(y) => Some(if x < y { x } else { y }),
280                }),
281            Self::Double(v) => v
282                .iter()
283                .map(scirs2_core::Complex::norm_sqr)
284                .filter(|&n| n > 0.0)
285                .fold(None, |acc, x| match acc {
286                    None => Some(x),
287                    Some(y) => Some(if x < y { x } else { y }),
288                }),
289        };
290
291        if let Some(min_amp) = min_amplitude {
292            min_amp < threshold * self.precision().epsilon()
293        } else {
294            false
295        }
296    }
297
298    /// Upgrade precision if necessary
299    pub fn upgrade_precision(&mut self) -> Result<()> {
300        let new_precision = match self.precision() {
301            Precision::Half => Precision::Single,
302            Precision::Single => Precision::Double,
303            Precision::Double => return Ok(()), // Already at max
304            Precision::Extended => unreachable!(),
305        };
306
307        let data = self.to_complex64();
308        *self = Self::new(self.num_qubits(), new_precision)?;
309        self.from_complex64(&data)?;
310
311        Ok(())
312    }
313
314    /// Downgrade precision if possible
315    pub fn downgrade_precision(&mut self, tolerance: f64) -> Result<()> {
316        let new_precision = match self.precision() {
317            Precision::Half => return Ok(()), // Already at min
318            Precision::Single => Precision::Half,
319            Precision::Double => Precision::Single,
320            Precision::Extended => Precision::Double,
321        };
322
323        // Check if downgrade would lose too much precision
324        let data = self.to_complex64();
325        let test_vec = Self::new(self.num_qubits(), new_precision)?;
326
327        // Compute error from downgrade
328        let mut max_error: f64 = 0.0;
329        match &test_vec {
330            Self::Half(_) => {
331                for &c in &data {
332                    let converted = ComplexF16::from_complex64(c).to_complex64();
333                    let error = (c - converted).norm();
334                    max_error = max_error.max(error);
335                }
336            }
337            Self::Single(_) => {
338                for &c in &data {
339                    let converted = Complex32::from_complex64(c).to_complex64();
340                    let error = (c - converted).norm();
341                    max_error = max_error.max(error);
342                }
343            }
344            Self::Double(_) => unreachable!(),
345        }
346
347        if max_error < tolerance {
348            *self = test_vec;
349            self.from_complex64(&data)?;
350        }
351
352        Ok(())
353    }
354
355    /// Memory usage in bytes
356    #[must_use]
357    pub fn memory_usage(&self) -> usize {
358        let elements = match self {
359            Self::Half(v) => v.len(),
360            Self::Single(v) => v.len(),
361            Self::Double(v) => v.len(),
362        };
363        elements * self.precision().bytes_per_complex()
364    }
365}
366
367/// Adaptive precision simulator config
368#[derive(Debug, Clone)]
369pub struct AdaptivePrecisionConfig {
370    /// Initial precision
371    pub initial_precision: Precision,
372    /// Error tolerance for automatic precision adjustment
373    pub error_tolerance: f64,
374    /// Check precision every N gates
375    pub check_interval: usize,
376    /// Enable automatic precision upgrade
377    pub auto_upgrade: bool,
378    /// Enable automatic precision downgrade
379    pub auto_downgrade: bool,
380    /// Minimum amplitude threshold
381    pub min_amplitude: f64,
382}
383
384impl Default for AdaptivePrecisionConfig {
385    fn default() -> Self {
386        Self {
387            initial_precision: Precision::Single,
388            error_tolerance: 1e-10,
389            check_interval: 100,
390            auto_upgrade: true,
391            auto_downgrade: true,
392            min_amplitude: 1e-12,
393        }
394    }
395}
396
397/// Track precision changes during simulation
398#[derive(Debug)]
399pub struct PrecisionTracker {
400    /// History of precision changes
401    changes: Vec<(usize, Precision, Precision)>, // (gate_count, from, to)
402    /// Current gate count
403    gate_count: usize,
404    /// Config
405    config: AdaptivePrecisionConfig,
406}
407
408impl PrecisionTracker {
409    /// Create a new tracker
410    #[must_use]
411    pub const fn new(config: AdaptivePrecisionConfig) -> Self {
412        Self {
413            changes: Vec::new(),
414            gate_count: 0,
415            config,
416        }
417    }
418
419    /// Record a gate application
420    pub const fn record_gate(&mut self) {
421        self.gate_count += 1;
422    }
423
424    /// Check if precision adjustment is needed
425    #[must_use]
426    pub const fn should_check_precision(&self) -> bool {
427        self.gate_count % self.config.check_interval == 0
428    }
429
430    /// Record precision change
431    pub fn record_change(&mut self, from: Precision, to: Precision) {
432        self.changes.push((self.gate_count, from, to));
433    }
434
435    /// Get precision history
436    #[must_use]
437    pub fn history(&self) -> &[(usize, Precision, Precision)] {
438        &self.changes
439    }
440
441    /// Get statistics
442    #[must_use]
443    pub fn stats(&self) -> PrecisionStats {
444        let mut upgrades = 0;
445        let mut downgrades = 0;
446
447        for (_, from, to) in &self.changes {
448            match (from, to) {
449                (Precision::Half, Precision::Single)
450                | (Precision::Single, Precision::Double)
451                | (Precision::Double, Precision::Extended) => upgrades += 1,
452                _ => downgrades += 1,
453            }
454        }
455
456        PrecisionStats {
457            total_gates: self.gate_count,
458            precision_changes: self.changes.len(),
459            upgrades,
460            downgrades,
461        }
462    }
463}
464
/// Summary counters describing precision changes over a simulation run.
#[derive(Debug)]
pub struct PrecisionStats {
    /// Number of gates recorded
    pub total_gates: usize,
    /// Number of precision changes recorded
    pub precision_changes: usize,
    /// Number of changes counted as upgrades
    pub upgrades: usize,
    /// Number of changes counted as downgrades
    pub downgrades: usize,
}

impl fmt::Display for PrecisionStats {
    /// Render the counters as a single human-readable line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            total_gates,
            precision_changes,
            upgrades,
            downgrades,
        } = self;
        f.write_fmt(format_args!(
            "Precision Stats: {} gates, {} changes ({} upgrades, {} downgrades)",
            total_gates, precision_changes, upgrades, downgrades
        ))
    }
}
483
484/// Benchmark different precisions
485pub fn benchmark_precisions(num_qubits: usize) -> Result<()> {
486    println!("\nPrecision Benchmark for {num_qubits} qubits:");
487    println!("{:-<60}", "");
488
489    for precision in [Precision::Half, Precision::Single, Precision::Double] {
490        let state = AdaptiveStateVector::new(num_qubits, precision)?;
491        let memory = state.memory_usage();
492        let memory_mb = memory as f64 / (1024.0 * 1024.0);
493
494        println!(
495            "{:?} precision: {:.2} MB ({} bytes per amplitude)",
496            precision,
497            memory_mb,
498            precision.bytes_per_complex()
499        );
500    }
501
502    Ok(())
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-amplitude sizes for every precision level.
    #[test]
    fn test_precision_levels() {
        assert_eq!(Precision::Half.bytes_per_complex(), 4);
        assert_eq!(Precision::Single.bytes_per_complex(), 8);
        assert_eq!(Precision::Double.bytes_per_complex(), 16);
        assert_eq!(Precision::Extended.bytes_per_complex(), 32);
    }

    /// Every branch of `from_tolerance`, including the inclusive boundaries.
    #[test]
    fn test_precision_from_tolerance() {
        assert_eq!(Precision::from_tolerance(0.01), Precision::Half);
        assert_eq!(Precision::from_tolerance(0.001), Precision::Half); // boundary is inclusive
        assert_eq!(Precision::from_tolerance(1e-5), Precision::Single);
        assert_eq!(Precision::from_tolerance(1e-7), Precision::Single); // boundary is inclusive
        assert_eq!(Precision::from_tolerance(1e-8), Precision::Double); // 1e-8 < 1e-7, so Double
        assert_eq!(Precision::from_tolerance(1e-16), Precision::Extended); // 1e-16 < 1e-15, so Extended
    }

    /// Half-precision values survive a widening conversion within tolerance.
    #[test]
    fn test_complex_f16() {
        let c = ComplexF16 {
            re: f16::from_f64(0.5),
            im: f16::from_f64(0.5),
        };

        let c64 = c.to_complex64();
        assert!((c64.re - 0.5).abs() < 0.01);
        assert!((c64.im - 0.5).abs() < 0.01);
    }

    /// A fresh state reports its precision and size and starts in |0...0>.
    #[test]
    fn test_adaptive_state_vector() {
        // No `mut`: the state is only read in this test.
        let state = AdaptiveStateVector::new(2, Precision::Single)
            .expect("Failed to create adaptive state vector");
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.num_qubits(), 2);

        // Test conversion
        let c64 = state.to_complex64();
        assert_eq!(c64.len(), 4);
        assert_eq!(c64[0], Complex64::new(1.0, 0.0));
    }

    /// Upgrading a half-precision state yields single precision.
    #[test]
    fn test_precision_upgrade() {
        let mut state = AdaptiveStateVector::new(2, Precision::Half)
            .expect("Failed to create half precision state");
        state
            .upgrade_precision()
            .expect("Failed to upgrade precision");
        assert_eq!(state.precision(), Precision::Single);
    }

    /// The tracker signals a check at the interval and counts a one-step upgrade.
    #[test]
    fn test_precision_tracker() {
        let config = AdaptivePrecisionConfig::default();
        let mut tracker = PrecisionTracker::new(config);

        // Record exactly 100 gates so gate_count % check_interval == 0
        for _ in 0..100 {
            tracker.record_gate();
        }

        assert!(tracker.should_check_precision());

        tracker.record_change(Precision::Single, Precision::Double);
        let stats = tracker.stats();
        assert_eq!(stats.upgrades, 1);
        assert_eq!(stats.downgrades, 0);
    }

    /// Memory usage is amplitude count times bytes per amplitude.
    #[test]
    fn test_memory_usage() {
        let state = AdaptiveStateVector::new(10, Precision::Half)
            .expect("Failed to create state for memory test");
        let memory = state.memory_usage();
        assert_eq!(memory, 1024 * 4); // 2^10 * 4 bytes
    }
}