quantrs2_sim/
precision.rs

1//! Adaptive precision control for quantum state vectors.
2//!
3//! This module provides mechanisms to dynamically adjust numerical precision
4//! based on the requirements of the simulation, enabling efficient memory usage
5//! and computation for large quantum systems.
6
7use crate::prelude::SimulatorError;
8use half::f16;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::{Complex32, Complex64};
11use std::fmt;
12
13use crate::error::Result;
14
/// Numeric precision used to store state-vector amplitudes.
///
/// Variants are listed from least to most precise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Half precision: two 16-bit floats per amplitude.
    Half,
    /// Single precision: two 32-bit floats per amplitude.
    Single,
    /// Double precision: two 64-bit floats per amplitude.
    Double,
    /// Extended precision: two 128-bit floats per amplitude (reserved for
    /// future support).
    Extended,
}

impl Precision {
    /// Storage size in bytes of one complex amplitude (real + imaginary part).
    pub const fn bytes_per_complex(&self) -> usize {
        let bytes_per_component = match self {
            Self::Half => 2,
            Self::Single => 4,
            Self::Double => 8,
            Self::Extended => 16,
        };
        2 * bytes_per_component
    }

    /// Approximate relative rounding error (machine epsilon) of this precision.
    ///
    /// The values are deliberately rounded, order-of-magnitude figures:
    /// ~2^-10, ~2^-23, ~2^-52 and ~2^-100 respectively.
    pub const fn epsilon(&self) -> f64 {
        match self {
            Self::Half => 0.001,
            Self::Single => 1e-7,
            Self::Double => 1e-15,
            Self::Extended => 1e-30,
        }
    }

    /// Pick the cheapest precision whose [`epsilon`](Self::epsilon) does not
    /// exceed `tolerance`.
    ///
    /// Tolerances tighter than double precision map to
    /// [`Precision::Extended`]. Negative and NaN tolerances satisfy no
    /// threshold, so they map there as well.
    pub fn from_tolerance(tolerance: f64) -> Self {
        [Self::Half, Self::Single, Self::Double]
            .iter()
            .copied()
            .find(|candidate| tolerance >= candidate.epsilon())
            .unwrap_or(Self::Extended)
    }
}
62
63/// Trait for types that can represent complex amplitudes
64pub trait ComplexAmplitude: Clone + Send + Sync {
65    /// Convert to Complex64 for computation
66    fn to_complex64(&self) -> Complex64;
67
68    /// Create from Complex64
69    fn from_complex64(c: Complex64) -> Self;
70
71    /// Get norm squared
72    fn norm_sqr(&self) -> f64;
73
74    /// Multiply by scalar
75    fn scale(&mut self, factor: f64);
76}
77
78impl ComplexAmplitude for Complex64 {
79    fn to_complex64(&self) -> Complex64 {
80        *self
81    }
82
83    fn from_complex64(c: Complex64) -> Self {
84        c
85    }
86
87    fn norm_sqr(&self) -> f64 {
88        self.norm_sqr()
89    }
90
91    fn scale(&mut self, factor: f64) {
92        *self *= factor;
93    }
94}
95
96impl ComplexAmplitude for Complex32 {
97    fn to_complex64(&self) -> Complex64 {
98        Complex64::new(self.re as f64, self.im as f64)
99    }
100
101    fn from_complex64(c: Complex64) -> Self {
102        Self::new(c.re as f32, c.im as f32)
103    }
104
105    fn norm_sqr(&self) -> f64 {
106        self.re.mul_add(self.re, self.im * self.im) as f64
107    }
108
109    fn scale(&mut self, factor: f64) {
110        *self *= factor as f32;
111    }
112}
113
114/// Half-precision complex number
115#[derive(Debug, Clone, Copy)]
116pub struct ComplexF16 {
117    pub re: f16,
118    pub im: f16,
119}
120
121impl ComplexAmplitude for ComplexF16 {
122    fn to_complex64(&self) -> Complex64 {
123        Complex64::new(self.re.to_f64(), self.im.to_f64())
124    }
125
126    fn from_complex64(c: Complex64) -> Self {
127        Self {
128            re: f16::from_f64(c.re),
129            im: f16::from_f64(c.im),
130        }
131    }
132
133    fn norm_sqr(&self) -> f64 {
134        let r = self.re.to_f64();
135        let i = self.im.to_f64();
136        r.mul_add(r, i * i)
137    }
138
139    fn scale(&mut self, factor: f64) {
140        self.re = f16::from_f64(self.re.to_f64() * factor);
141        self.im = f16::from_f64(self.im.to_f64() * factor);
142    }
143}
144
145/// Adaptive precision state vector
146pub enum AdaptiveStateVector {
147    Half(Array1<ComplexF16>),
148    Single(Array1<Complex32>),
149    Double(Array1<Complex64>),
150}
151
152impl AdaptiveStateVector {
153    /// Create a new state vector with specified precision
154    pub fn new(num_qubits: usize, precision: Precision) -> Result<Self> {
155        let size = 1 << num_qubits;
156
157        if num_qubits > 30 {
158            return Err(SimulatorError::InvalidQubits(num_qubits));
159        }
160
161        match precision {
162            Precision::Half => {
163                let mut state = Array1::from_elem(
164                    size,
165                    ComplexF16 {
166                        re: f16::from_f64(0.0),
167                        im: f16::from_f64(0.0),
168                    },
169                );
170                state[0] = ComplexF16 {
171                    re: f16::from_f64(1.0),
172                    im: f16::from_f64(0.0),
173                };
174                Ok(Self::Half(state))
175            }
176            Precision::Single => {
177                let mut state = Array1::zeros(size);
178                state[0] = Complex32::new(1.0, 0.0);
179                Ok(Self::Single(state))
180            }
181            Precision::Double => {
182                let mut state = Array1::zeros(size);
183                state[0] = Complex64::new(1.0, 0.0);
184                Ok(Self::Double(state))
185            }
186            Precision::Extended => Err(SimulatorError::InvalidConfiguration(
187                "Extended precision not yet supported".to_string(),
188            )),
189        }
190    }
191
192    /// Get current precision
193    pub const fn precision(&self) -> Precision {
194        match self {
195            Self::Half(_) => Precision::Half,
196            Self::Single(_) => Precision::Single,
197            Self::Double(_) => Precision::Double,
198        }
199    }
200
201    /// Get number of qubits
202    pub fn num_qubits(&self) -> usize {
203        let size = match self {
204            Self::Half(v) => v.len(),
205            Self::Single(v) => v.len(),
206            Self::Double(v) => v.len(),
207        };
208        (size as f64).log2() as usize
209    }
210
211    /// Convert to double precision for computation
212    pub fn to_complex64(&self) -> Array1<Complex64> {
213        match self {
214            Self::Half(v) => v.map(|c| c.to_complex64()),
215            Self::Single(v) => v.map(|c| c.to_complex64()),
216            Self::Double(v) => v.clone(),
217        }
218    }
219
220    /// Update from double precision
221    pub fn from_complex64(&mut self, data: &Array1<Complex64>) -> Result<()> {
222        match self {
223            Self::Half(v) => {
224                if v.len() != data.len() {
225                    return Err(SimulatorError::DimensionMismatch(format!(
226                        "Size mismatch: {} vs {}",
227                        v.len(),
228                        data.len()
229                    )));
230                }
231                for (i, &c) in data.iter().enumerate() {
232                    v[i] = ComplexF16::from_complex64(c);
233                }
234            }
235            Self::Single(v) => {
236                if v.len() != data.len() {
237                    return Err(SimulatorError::DimensionMismatch(format!(
238                        "Size mismatch: {} vs {}",
239                        v.len(),
240                        data.len()
241                    )));
242                }
243                for (i, &c) in data.iter().enumerate() {
244                    v[i] = Complex32::from_complex64(c);
245                }
246            }
247            Self::Double(v) => {
248                v.assign(data);
249            }
250        }
251        Ok(())
252    }
253
254    /// Check if precision upgrade is needed
255    pub fn needs_precision_upgrade(&self, threshold: f64) -> bool {
256        // Check if small amplitudes might be lost
257        let min_amplitude = match self {
258            Self::Half(v) => {
259                v.iter()
260                    .map(|c| c.norm_sqr())
261                    .filter(|&n| n > 0.0)
262                    .fold(None, |acc, x| match acc {
263                        None => Some(x),
264                        Some(y) => Some(if x < y { x } else { y }),
265                    })
266            }
267            Self::Single(v) => v
268                .iter()
269                .map(|c| c.norm_sqr() as f64)
270                .filter(|&n| n > 0.0)
271                .fold(None, |acc, x| match acc {
272                    None => Some(x),
273                    Some(y) => Some(if x < y { x } else { y }),
274                }),
275            Self::Double(v) => {
276                v.iter()
277                    .map(|c| c.norm_sqr())
278                    .filter(|&n| n > 0.0)
279                    .fold(None, |acc, x| match acc {
280                        None => Some(x),
281                        Some(y) => Some(if x < y { x } else { y }),
282                    })
283            }
284        };
285
286        if let Some(min_amp) = min_amplitude {
287            min_amp < threshold * self.precision().epsilon()
288        } else {
289            false
290        }
291    }
292
293    /// Upgrade precision if necessary
294    pub fn upgrade_precision(&mut self) -> Result<()> {
295        let new_precision = match self.precision() {
296            Precision::Half => Precision::Single,
297            Precision::Single => Precision::Double,
298            Precision::Double => return Ok(()), // Already at max
299            Precision::Extended => unreachable!(),
300        };
301
302        let data = self.to_complex64();
303        *self = Self::new(self.num_qubits(), new_precision)?;
304        self.from_complex64(&data)?;
305
306        Ok(())
307    }
308
309    /// Downgrade precision if possible
310    pub fn downgrade_precision(&mut self, tolerance: f64) -> Result<()> {
311        let new_precision = match self.precision() {
312            Precision::Half => return Ok(()), // Already at min
313            Precision::Single => Precision::Half,
314            Precision::Double => Precision::Single,
315            Precision::Extended => Precision::Double,
316        };
317
318        // Check if downgrade would lose too much precision
319        let data = self.to_complex64();
320        let test_vec = Self::new(self.num_qubits(), new_precision)?;
321
322        // Compute error from downgrade
323        let mut max_error: f64 = 0.0;
324        match &test_vec {
325            Self::Half(_) => {
326                for &c in &data {
327                    let converted = ComplexF16::from_complex64(c).to_complex64();
328                    let error = (c - converted).norm();
329                    max_error = max_error.max(error);
330                }
331            }
332            Self::Single(_) => {
333                for &c in &data {
334                    let converted = Complex32::from_complex64(c).to_complex64();
335                    let error = (c - converted).norm();
336                    max_error = max_error.max(error);
337                }
338            }
339            _ => unreachable!(),
340        }
341
342        if max_error < tolerance {
343            *self = test_vec;
344            self.from_complex64(&data)?;
345        }
346
347        Ok(())
348    }
349
350    /// Memory usage in bytes
351    pub fn memory_usage(&self) -> usize {
352        let elements = match self {
353            Self::Half(v) => v.len(),
354            Self::Single(v) => v.len(),
355            Self::Double(v) => v.len(),
356        };
357        elements * self.precision().bytes_per_complex()
358    }
359}
360
361/// Adaptive precision simulator config
362#[derive(Debug, Clone)]
363pub struct AdaptivePrecisionConfig {
364    /// Initial precision
365    pub initial_precision: Precision,
366    /// Error tolerance for automatic precision adjustment
367    pub error_tolerance: f64,
368    /// Check precision every N gates
369    pub check_interval: usize,
370    /// Enable automatic precision upgrade
371    pub auto_upgrade: bool,
372    /// Enable automatic precision downgrade
373    pub auto_downgrade: bool,
374    /// Minimum amplitude threshold
375    pub min_amplitude: f64,
376}
377
378impl Default for AdaptivePrecisionConfig {
379    fn default() -> Self {
380        Self {
381            initial_precision: Precision::Single,
382            error_tolerance: 1e-10,
383            check_interval: 100,
384            auto_upgrade: true,
385            auto_downgrade: true,
386            min_amplitude: 1e-12,
387        }
388    }
389}
390
391/// Track precision changes during simulation
392#[derive(Debug)]
393pub struct PrecisionTracker {
394    /// History of precision changes
395    changes: Vec<(usize, Precision, Precision)>, // (gate_count, from, to)
396    /// Current gate count
397    gate_count: usize,
398    /// Config
399    config: AdaptivePrecisionConfig,
400}
401
402impl PrecisionTracker {
403    /// Create a new tracker
404    pub const fn new(config: AdaptivePrecisionConfig) -> Self {
405        Self {
406            changes: Vec::new(),
407            gate_count: 0,
408            config,
409        }
410    }
411
412    /// Record a gate application
413    pub const fn record_gate(&mut self) {
414        self.gate_count += 1;
415    }
416
417    /// Check if precision adjustment is needed
418    pub const fn should_check_precision(&self) -> bool {
419        self.gate_count % self.config.check_interval == 0
420    }
421
422    /// Record precision change
423    pub fn record_change(&mut self, from: Precision, to: Precision) {
424        self.changes.push((self.gate_count, from, to));
425    }
426
427    /// Get precision history
428    pub fn history(&self) -> &[(usize, Precision, Precision)] {
429        &self.changes
430    }
431
432    /// Get statistics
433    pub fn stats(&self) -> PrecisionStats {
434        let mut upgrades = 0;
435        let mut downgrades = 0;
436
437        for (_, from, to) in &self.changes {
438            match (from, to) {
439                (Precision::Half, Precision::Single)
440                | (Precision::Single, Precision::Double)
441                | (Precision::Double, Precision::Extended) => upgrades += 1,
442                _ => downgrades += 1,
443            }
444        }
445
446        PrecisionStats {
447            total_gates: self.gate_count,
448            precision_changes: self.changes.len(),
449            upgrades,
450            downgrades,
451        }
452    }
453}
454
/// Summary counters produced by [`PrecisionTracker::stats`].
#[derive(Debug)]
pub struct PrecisionStats {
    /// Gates recorded by the tracker.
    pub total_gates: usize,
    /// Number of entries in the precision-change history.
    pub precision_changes: usize,
    /// Changes classified as upgrades.
    pub upgrades: usize,
    /// Changes classified as downgrades.
    pub downgrades: usize,
}

impl fmt::Display for PrecisionStats {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            total_gates,
            precision_changes,
            upgrades,
            downgrades,
        } = self;
        write!(
            f,
            "Precision Stats: {total_gates} gates, {precision_changes} changes ({upgrades} upgrades, {downgrades} downgrades)"
        )
    }
}
473
474/// Benchmark different precisions
475pub fn benchmark_precisions(num_qubits: usize) -> Result<()> {
476    println!("\nPrecision Benchmark for {num_qubits} qubits:");
477    println!("{:-<60}", "");
478
479    for precision in [Precision::Half, Precision::Single, Precision::Double] {
480        let state = AdaptiveStateVector::new(num_qubits, precision)?;
481        let memory = state.memory_usage();
482        let memory_mb = memory as f64 / (1024.0 * 1024.0);
483
484        println!(
485            "{:?} precision: {:.2} MB ({} bytes per amplitude)",
486            precision,
487            memory_mb,
488            precision.bytes_per_complex()
489        );
490    }
491
492    Ok(())
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_levels() {
        assert_eq!(Precision::Half.bytes_per_complex(), 4);
        assert_eq!(Precision::Single.bytes_per_complex(), 8);
        assert_eq!(Precision::Double.bytes_per_complex(), 16);
    }

    #[test]
    fn test_precision_from_tolerance() {
        assert_eq!(Precision::from_tolerance(0.01), Precision::Half);
        // 1e-8 is tighter than single's 1e-7, so double is required.
        assert_eq!(Precision::from_tolerance(1e-8), Precision::Double);
        // 1e-16 is tighter than double's 1e-15, so extended is required.
        assert_eq!(Precision::from_tolerance(1e-16), Precision::Extended);
    }

    #[test]
    fn test_complex_f16() {
        let c = ComplexF16 {
            re: f16::from_f64(0.5),
            im: f16::from_f64(0.5),
        };

        let c64 = c.to_complex64();
        assert!((c64.re - 0.5).abs() < 0.01);
        assert!((c64.im - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_adaptive_state_vector() {
        let state = AdaptiveStateVector::new(2, Precision::Single).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.num_qubits(), 2);

        // The |00⟩ state survives widening to double precision.
        let c64 = state.to_complex64();
        assert_eq!(c64.len(), 4);
        assert_eq!(c64[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_precision_upgrade() {
        let mut state = AdaptiveStateVector::new(2, Precision::Half).unwrap();
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Single);
    }

    #[test]
    fn test_upgrade_preserves_amplitudes() {
        let mut state = AdaptiveStateVector::new(2, Precision::Half).unwrap();
        // 0.5 is exactly representable in f16, f32 and f64.
        let uniform = Array1::from(vec![Complex64::new(0.5, 0.0); 4]);
        state.from_complex64(&uniform).unwrap();
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.to_complex64(), uniform);
    }

    #[test]
    fn test_downgrade_within_tolerance() {
        let mut state = AdaptiveStateVector::new(3, Precision::Double).unwrap();
        state.downgrade_precision(1e-6).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.to_complex64()[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_downgrade_rejected_when_error_too_large() {
        let mut state = AdaptiveStateVector::new(1, Precision::Single).unwrap();
        let data = Array1::from(vec![Complex64::new(0.1, 0.0), Complex64::new(0.0, 0.0)]);
        state.from_complex64(&data).unwrap();
        // 0.1 in f16 is off by ~2.4e-5, far above the 1e-6 tolerance.
        state.downgrade_precision(1e-6).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(
            state.to_complex64()[0],
            Complex64::new(f64::from(0.1f32), 0.0)
        );
    }

    #[test]
    fn test_from_complex64_size_mismatch() {
        let mut state = AdaptiveStateVector::new(2, Precision::Single).unwrap();
        let wrong = Array1::<Complex64>::zeros(8);
        assert!(state.from_complex64(&wrong).is_err());
    }

    #[test]
    fn test_needs_precision_upgrade() {
        let mut state = AdaptiveStateVector::new(1, Precision::Double).unwrap();
        assert!(!state.needs_precision_upgrade(1.0));

        let tiny = Array1::from(vec![Complex64::new(1.0, 0.0), Complex64::new(1e-20, 0.0)]);
        state.from_complex64(&tiny).unwrap();
        assert!(state.needs_precision_upgrade(1.0));
    }

    #[test]
    fn test_precision_tracker() {
        let config = AdaptivePrecisionConfig::default();
        let mut tracker = PrecisionTracker::new(config);

        // Record exactly 100 gates so gate_count % check_interval == 0.
        for _ in 0..100 {
            tracker.record_gate();
        }

        assert!(tracker.should_check_precision());

        tracker.record_change(Precision::Single, Precision::Double);
        let stats = tracker.stats();
        assert_eq!(stats.upgrades, 1);
        assert_eq!(stats.downgrades, 0);
    }

    #[test]
    fn test_memory_usage() {
        let state = AdaptiveStateVector::new(10, Precision::Half).unwrap();
        let memory = state.memory_usage();
        assert_eq!(memory, 1024 * 4); // 2^10 amplitudes * 4 bytes
    }
}