quantrs2_sim/
precision.rs

1//! Adaptive precision control for quantum state vectors.
2//!
3//! This module provides mechanisms to dynamically adjust numerical precision
4//! based on the requirements of the simulation, enabling efficient memory usage
5//! and computation for large quantum systems.
6
7use crate::prelude::SimulatorError;
8use half::f16;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::{Complex32, Complex64};
11use std::fmt;
12
13use crate::error::Result;
14
/// Numeric format used to store the amplitudes of a state vector.
///
/// Variants are declared from cheapest/least precise to most precise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Two 16-bit floats per amplitude.
    Half,
    /// Two 32-bit floats per amplitude.
    Single,
    /// Two 64-bit floats per amplitude.
    Double,
    /// Two 128-bit floats per amplitude; reserved for future use
    /// (`AdaptiveStateVector` has no storage variant for it yet).
    Extended,
}
27
28impl Precision {
29    /// Get bytes per complex number
30    pub fn bytes_per_complex(&self) -> usize {
31        match self {
32            Precision::Half => 4,      // 2 * f16
33            Precision::Single => 8,    // 2 * f32
34            Precision::Double => 16,   // 2 * f64
35            Precision::Extended => 32, // 2 * f128 (future)
36        }
37    }
38
39    /// Get relative epsilon for this precision
40    pub fn epsilon(&self) -> f64 {
41        match self {
42            Precision::Half => 0.001,     // ~2^-10
43            Precision::Single => 1e-7,    // ~2^-23
44            Precision::Double => 1e-15,   // ~2^-52
45            Precision::Extended => 1e-30, // ~2^-100
46        }
47    }
48
49    /// Determine minimum precision needed for a given error tolerance
50    pub fn from_tolerance(tolerance: f64) -> Self {
51        if tolerance >= 0.001 {
52            Precision::Half
53        } else if tolerance >= 1e-7 {
54            Precision::Single
55        } else if tolerance >= 1e-15 {
56            Precision::Double
57        } else {
58            Precision::Extended
59        }
60    }
61}
62
63/// Trait for types that can represent complex amplitudes
64pub trait ComplexAmplitude: Clone + Send + Sync {
65    /// Convert to Complex64 for computation
66    fn to_complex64(&self) -> Complex64;
67
68    /// Create from Complex64
69    fn from_complex64(c: Complex64) -> Self;
70
71    /// Get norm squared
72    fn norm_sqr(&self) -> f64;
73
74    /// Multiply by scalar
75    fn scale(&mut self, factor: f64);
76}
77
78impl ComplexAmplitude for Complex64 {
79    fn to_complex64(&self) -> Complex64 {
80        *self
81    }
82
83    fn from_complex64(c: Complex64) -> Self {
84        c
85    }
86
87    fn norm_sqr(&self) -> f64 {
88        self.norm_sqr()
89    }
90
91    fn scale(&mut self, factor: f64) {
92        *self *= factor;
93    }
94}
95
96impl ComplexAmplitude for Complex32 {
97    fn to_complex64(&self) -> Complex64 {
98        Complex64::new(self.re as f64, self.im as f64)
99    }
100
101    fn from_complex64(c: Complex64) -> Self {
102        Complex32::new(c.re as f32, c.im as f32)
103    }
104
105    fn norm_sqr(&self) -> f64 {
106        (self.re * self.re + self.im * self.im) as f64
107    }
108
109    fn scale(&mut self, factor: f64) {
110        *self *= factor as f32;
111    }
112}
113
114/// Half-precision complex number
115#[derive(Debug, Clone, Copy)]
116pub struct ComplexF16 {
117    pub re: f16,
118    pub im: f16,
119}
120
121impl ComplexAmplitude for ComplexF16 {
122    fn to_complex64(&self) -> Complex64 {
123        Complex64::new(self.re.to_f64(), self.im.to_f64())
124    }
125
126    fn from_complex64(c: Complex64) -> Self {
127        ComplexF16 {
128            re: f16::from_f64(c.re),
129            im: f16::from_f64(c.im),
130        }
131    }
132
133    fn norm_sqr(&self) -> f64 {
134        let r = self.re.to_f64();
135        let i = self.im.to_f64();
136        r * r + i * i
137    }
138
139    fn scale(&mut self, factor: f64) {
140        self.re = f16::from_f64(self.re.to_f64() * factor);
141        self.im = f16::from_f64(self.im.to_f64() * factor);
142    }
143}
144
145/// Adaptive precision state vector
146pub enum AdaptiveStateVector {
147    Half(Array1<ComplexF16>),
148    Single(Array1<Complex32>),
149    Double(Array1<Complex64>),
150}
151
152impl AdaptiveStateVector {
153    /// Create a new state vector with specified precision
154    pub fn new(num_qubits: usize, precision: Precision) -> Result<Self> {
155        let size = 1 << num_qubits;
156
157        if num_qubits > 30 {
158            return Err(SimulatorError::InvalidQubits(num_qubits));
159        }
160
161        match precision {
162            Precision::Half => {
163                let mut state = Array1::from_elem(
164                    size,
165                    ComplexF16 {
166                        re: f16::from_f64(0.0),
167                        im: f16::from_f64(0.0),
168                    },
169                );
170                state[0] = ComplexF16 {
171                    re: f16::from_f64(1.0),
172                    im: f16::from_f64(0.0),
173                };
174                Ok(AdaptiveStateVector::Half(state))
175            }
176            Precision::Single => {
177                let mut state = Array1::zeros(size);
178                state[0] = Complex32::new(1.0, 0.0);
179                Ok(AdaptiveStateVector::Single(state))
180            }
181            Precision::Double => {
182                let mut state = Array1::zeros(size);
183                state[0] = Complex64::new(1.0, 0.0);
184                Ok(AdaptiveStateVector::Double(state))
185            }
186            Precision::Extended => Err(SimulatorError::InvalidConfiguration(
187                "Extended precision not yet supported".to_string(),
188            )),
189        }
190    }
191
192    /// Get current precision
193    pub fn precision(&self) -> Precision {
194        match self {
195            AdaptiveStateVector::Half(_) => Precision::Half,
196            AdaptiveStateVector::Single(_) => Precision::Single,
197            AdaptiveStateVector::Double(_) => Precision::Double,
198        }
199    }
200
201    /// Get number of qubits
202    pub fn num_qubits(&self) -> usize {
203        let size = match self {
204            AdaptiveStateVector::Half(v) => v.len(),
205            AdaptiveStateVector::Single(v) => v.len(),
206            AdaptiveStateVector::Double(v) => v.len(),
207        };
208        (size as f64).log2() as usize
209    }
210
211    /// Convert to double precision for computation
212    pub fn to_complex64(&self) -> Array1<Complex64> {
213        match self {
214            AdaptiveStateVector::Half(v) => v.map(|c| c.to_complex64()),
215            AdaptiveStateVector::Single(v) => v.map(|c| c.to_complex64()),
216            AdaptiveStateVector::Double(v) => v.clone(),
217        }
218    }
219
220    /// Update from double precision
221    pub fn from_complex64(&mut self, data: &Array1<Complex64>) -> Result<()> {
222        match self {
223            AdaptiveStateVector::Half(v) => {
224                if v.len() != data.len() {
225                    return Err(SimulatorError::DimensionMismatch(format!(
226                        "Size mismatch: {} vs {}",
227                        v.len(),
228                        data.len()
229                    )));
230                }
231                for (i, &c) in data.iter().enumerate() {
232                    v[i] = ComplexF16::from_complex64(c);
233                }
234            }
235            AdaptiveStateVector::Single(v) => {
236                if v.len() != data.len() {
237                    return Err(SimulatorError::DimensionMismatch(format!(
238                        "Size mismatch: {} vs {}",
239                        v.len(),
240                        data.len()
241                    )));
242                }
243                for (i, &c) in data.iter().enumerate() {
244                    v[i] = Complex32::from_complex64(c);
245                }
246            }
247            AdaptiveStateVector::Double(v) => {
248                v.assign(data);
249            }
250        }
251        Ok(())
252    }
253
254    /// Check if precision upgrade is needed
255    pub fn needs_precision_upgrade(&self, threshold: f64) -> bool {
256        // Check if small amplitudes might be lost
257        let min_amplitude = match self {
258            AdaptiveStateVector::Half(v) => v
259                .iter()
260                .map(|c| c.norm_sqr())
261                .filter(|&n| n > 0.0)
262                .fold(None, |acc, x| match acc {
263                    None => Some(x),
264                    Some(y) => Some(if x < y { x } else { y }),
265                }),
266            AdaptiveStateVector::Single(v) => v
267                .iter()
268                .map(|c| c.norm_sqr() as f64)
269                .filter(|&n| n > 0.0)
270                .fold(None, |acc, x| match acc {
271                    None => Some(x),
272                    Some(y) => Some(if x < y { x } else { y }),
273                }),
274            AdaptiveStateVector::Double(v) => v
275                .iter()
276                .map(|c| c.norm_sqr())
277                .filter(|&n| n > 0.0)
278                .fold(None, |acc, x| match acc {
279                    None => Some(x),
280                    Some(y) => Some(if x < y { x } else { y }),
281                }),
282        };
283
284        if let Some(min_amp) = min_amplitude {
285            min_amp < threshold * self.precision().epsilon()
286        } else {
287            false
288        }
289    }
290
291    /// Upgrade precision if necessary
292    pub fn upgrade_precision(&mut self) -> Result<()> {
293        let new_precision = match self.precision() {
294            Precision::Half => Precision::Single,
295            Precision::Single => Precision::Double,
296            Precision::Double => return Ok(()), // Already at max
297            Precision::Extended => unreachable!(),
298        };
299
300        let data = self.to_complex64();
301        *self = Self::new(self.num_qubits(), new_precision)?;
302        self.from_complex64(&data)?;
303
304        Ok(())
305    }
306
307    /// Downgrade precision if possible
308    pub fn downgrade_precision(&mut self, tolerance: f64) -> Result<()> {
309        let new_precision = match self.precision() {
310            Precision::Half => return Ok(()), // Already at min
311            Precision::Single => Precision::Half,
312            Precision::Double => Precision::Single,
313            Precision::Extended => Precision::Double,
314        };
315
316        // Check if downgrade would lose too much precision
317        let data = self.to_complex64();
318        let test_vec = Self::new(self.num_qubits(), new_precision)?;
319
320        // Compute error from downgrade
321        let mut max_error: f64 = 0.0;
322        match &test_vec {
323            AdaptiveStateVector::Half(_) => {
324                for &c in data.iter() {
325                    let converted = ComplexF16::from_complex64(c).to_complex64();
326                    let error = (c - converted).norm();
327                    max_error = max_error.max(error);
328                }
329            }
330            AdaptiveStateVector::Single(_) => {
331                for &c in data.iter() {
332                    let converted = Complex32::from_complex64(c).to_complex64();
333                    let error = (c - converted).norm();
334                    max_error = max_error.max(error);
335                }
336            }
337            _ => unreachable!(),
338        }
339
340        if max_error < tolerance {
341            *self = test_vec;
342            self.from_complex64(&data)?;
343        }
344
345        Ok(())
346    }
347
348    /// Memory usage in bytes
349    pub fn memory_usage(&self) -> usize {
350        let elements = match self {
351            AdaptiveStateVector::Half(v) => v.len(),
352            AdaptiveStateVector::Single(v) => v.len(),
353            AdaptiveStateVector::Double(v) => v.len(),
354        };
355        elements * self.precision().bytes_per_complex()
356    }
357}
358
359/// Adaptive precision simulator config
360#[derive(Debug, Clone)]
361pub struct AdaptivePrecisionConfig {
362    /// Initial precision
363    pub initial_precision: Precision,
364    /// Error tolerance for automatic precision adjustment
365    pub error_tolerance: f64,
366    /// Check precision every N gates
367    pub check_interval: usize,
368    /// Enable automatic precision upgrade
369    pub auto_upgrade: bool,
370    /// Enable automatic precision downgrade
371    pub auto_downgrade: bool,
372    /// Minimum amplitude threshold
373    pub min_amplitude: f64,
374}
375
376impl Default for AdaptivePrecisionConfig {
377    fn default() -> Self {
378        Self {
379            initial_precision: Precision::Single,
380            error_tolerance: 1e-10,
381            check_interval: 100,
382            auto_upgrade: true,
383            auto_downgrade: true,
384            min_amplitude: 1e-12,
385        }
386    }
387}
388
389/// Track precision changes during simulation
390#[derive(Debug)]
391pub struct PrecisionTracker {
392    /// History of precision changes
393    changes: Vec<(usize, Precision, Precision)>, // (gate_count, from, to)
394    /// Current gate count
395    gate_count: usize,
396    /// Config
397    config: AdaptivePrecisionConfig,
398}
399
400impl PrecisionTracker {
401    /// Create a new tracker
402    pub fn new(config: AdaptivePrecisionConfig) -> Self {
403        Self {
404            changes: Vec::new(),
405            gate_count: 0,
406            config,
407        }
408    }
409
410    /// Record a gate application
411    pub fn record_gate(&mut self) {
412        self.gate_count += 1;
413    }
414
415    /// Check if precision adjustment is needed
416    pub fn should_check_precision(&self) -> bool {
417        self.gate_count % self.config.check_interval == 0
418    }
419
420    /// Record precision change
421    pub fn record_change(&mut self, from: Precision, to: Precision) {
422        self.changes.push((self.gate_count, from, to));
423    }
424
425    /// Get precision history
426    pub fn history(&self) -> &[(usize, Precision, Precision)] {
427        &self.changes
428    }
429
430    /// Get statistics
431    pub fn stats(&self) -> PrecisionStats {
432        let mut upgrades = 0;
433        let mut downgrades = 0;
434
435        for (_, from, to) in &self.changes {
436            match (from, to) {
437                (Precision::Half, Precision::Single)
438                | (Precision::Single, Precision::Double)
439                | (Precision::Double, Precision::Extended) => upgrades += 1,
440                _ => downgrades += 1,
441            }
442        }
443
444        PrecisionStats {
445            total_gates: self.gate_count,
446            precision_changes: self.changes.len(),
447            upgrades,
448            downgrades,
449        }
450    }
451}
452
/// Summary of a [`PrecisionTracker`] run.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PrecisionStats {
    /// Gates recorded by the tracker.
    pub total_gates: usize,
    /// Number of recorded precision transitions.
    pub precision_changes: usize,
    /// Transitions that counted as upgrades.
    pub upgrades: usize,
    /// Transitions that counted as downgrades.
    pub downgrades: usize,
}
461
462impl fmt::Display for PrecisionStats {
463    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
464        write!(
465            f,
466            "Precision Stats: {} gates, {} changes ({} upgrades, {} downgrades)",
467            self.total_gates, self.precision_changes, self.upgrades, self.downgrades
468        )
469    }
470}
471
472/// Benchmark different precisions
473pub fn benchmark_precisions(num_qubits: usize) -> Result<()> {
474    println!("\nPrecision Benchmark for {} qubits:", num_qubits);
475    println!("{:-<60}", "");
476
477    for precision in [Precision::Half, Precision::Single, Precision::Double] {
478        let state = AdaptiveStateVector::new(num_qubits, precision)?;
479        let memory = state.memory_usage();
480        let memory_mb = memory as f64 / (1024.0 * 1024.0);
481
482        println!(
483            "{:?} precision: {:.2} MB ({} bytes per amplitude)",
484            precision,
485            memory_mb,
486            precision.bytes_per_complex()
487        );
488    }
489
490    Ok(())
491}
492
#[cfg(test)]
mod tests {
    //! Unit tests for precision selection, state-vector storage and tracking.
    use super::*;

    #[test]
    fn test_precision_levels() {
        assert_eq!(Precision::Half.bytes_per_complex(), 4);
        assert_eq!(Precision::Single.bytes_per_complex(), 8);
        assert_eq!(Precision::Double.bytes_per_complex(), 16);
    }

    #[test]
    fn test_precision_from_tolerance() {
        assert_eq!(Precision::from_tolerance(0.01), Precision::Half);
        assert_eq!(Precision::from_tolerance(1e-8), Precision::Double); // 1e-8 < 1e-7, so Double
        assert_eq!(Precision::from_tolerance(1e-16), Precision::Extended); // 1e-16 < 1e-15, so Extended
    }

    #[test]
    fn test_complex_f16() {
        let c = ComplexF16 {
            re: f16::from_f64(0.5),
            im: f16::from_f64(0.5),
        };

        let c64 = c.to_complex64();
        assert!((c64.re - 0.5).abs() < 0.01);
        assert!((c64.im - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_adaptive_state_vector() {
        // Only read here, so no `mut` (it triggered an unused_mut warning).
        let state = AdaptiveStateVector::new(2, Precision::Single).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.num_qubits(), 2);

        // Test conversion
        let c64 = state.to_complex64();
        assert_eq!(c64.len(), 4);
        assert_eq!(c64[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_new_rejects_invalid_requests() {
        assert!(AdaptiveStateVector::new(31, Precision::Double).is_err());
        assert!(AdaptiveStateVector::new(2, Precision::Extended).is_err());
    }

    #[test]
    fn test_precision_upgrade() {
        let mut state = AdaptiveStateVector::new(2, Precision::Half).unwrap();
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Single);
    }

    #[test]
    fn test_precision_upgrade_is_noop_at_double() {
        let mut state = AdaptiveStateVector::new(1, Precision::Double).unwrap();
        state.upgrade_precision().unwrap();
        assert_eq!(state.precision(), Precision::Double);
    }

    #[test]
    fn test_precision_downgrade_exact_state() {
        // |00⟩ is exactly representable in f32, so the round-trip error is 0.
        let mut state = AdaptiveStateVector::new(2, Precision::Double).unwrap();
        state.downgrade_precision(1e-6).unwrap();
        assert_eq!(state.precision(), Precision::Single);
        assert_eq!(state.to_complex64()[0], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_precision_tracker() {
        let config = AdaptivePrecisionConfig::default();
        let mut tracker = PrecisionTracker::new(config);

        // Record exactly 100 gates so gate_count % check_interval == 0
        for _ in 0..100 {
            tracker.record_gate();
        }

        assert!(tracker.should_check_precision());

        tracker.record_change(Precision::Single, Precision::Double);
        let stats = tracker.stats();
        assert_eq!(stats.upgrades, 1);
        assert_eq!(stats.downgrades, 0);
    }

    #[test]
    fn test_memory_usage() {
        let state = AdaptiveStateVector::new(10, Precision::Half).unwrap();
        let memory = state.memory_usage();
        assert_eq!(memory, 1024 * 4); // 2^10 * 4 bytes
    }
}