use crate::prelude::SimulatorError;
use half::f16;
use scirs2_core::ndarray::Array1;
use scirs2_core::{Complex32, Complex64};
use std::fmt;
use crate::error::Result;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
Half,
Single,
Double,
Extended,
}
impl Precision {
#[must_use]
pub const fn bytes_per_complex(&self) -> usize {
match self {
Self::Half => 4, Self::Single => 8, Self::Double => 16, Self::Extended => 32, }
}
#[must_use]
pub const fn epsilon(&self) -> f64 {
match self {
Self::Half => 0.001, Self::Single => 1e-7, Self::Double => 1e-15, Self::Extended => 1e-30, }
}
#[must_use]
pub fn from_tolerance(tolerance: f64) -> Self {
if tolerance >= 0.001 {
Self::Half
} else if tolerance >= 1e-7 {
Self::Single
} else if tolerance >= 1e-15 {
Self::Double
} else {
Self::Extended
}
}
}
pub trait ComplexAmplitude: Clone + Send + Sync {
fn to_complex64(&self) -> Complex64;
fn from_complex64(c: Complex64) -> Self;
fn norm_sqr(&self) -> f64;
fn scale(&mut self, factor: f64);
}
impl ComplexAmplitude for Complex64 {
fn to_complex64(&self) -> Complex64 {
*self
}
fn from_complex64(c: Complex64) -> Self {
c
}
fn norm_sqr(&self) -> f64 {
self.norm_sqr()
}
fn scale(&mut self, factor: f64) {
*self *= factor;
}
}
impl ComplexAmplitude for Complex32 {
fn to_complex64(&self) -> Complex64 {
Complex64::new(f64::from(self.re), f64::from(self.im))
}
fn from_complex64(c: Complex64) -> Self {
Self::new(c.re as f32, c.im as f32)
}
fn norm_sqr(&self) -> f64 {
f64::from(self.re.mul_add(self.re, self.im * self.im))
}
fn scale(&mut self, factor: f64) {
*self *= factor as f32;
}
}
#[derive(Debug, Clone, Copy)]
pub struct ComplexF16 {
pub re: f16,
pub im: f16,
}
impl ComplexAmplitude for ComplexF16 {
fn to_complex64(&self) -> Complex64 {
Complex64::new(self.re.to_f64(), self.im.to_f64())
}
fn from_complex64(c: Complex64) -> Self {
Self {
re: f16::from_f64(c.re),
im: f16::from_f64(c.im),
}
}
fn norm_sqr(&self) -> f64 {
let r = self.re.to_f64();
let i = self.im.to_f64();
r.mul_add(r, i * i)
}
fn scale(&mut self, factor: f64) {
self.re = f16::from_f64(self.re.to_f64() * factor);
self.im = f16::from_f64(self.im.to_f64() * factor);
}
}
pub enum AdaptiveStateVector {
Half(Array1<ComplexF16>),
Single(Array1<Complex32>),
Double(Array1<Complex64>),
}
impl AdaptiveStateVector {
pub fn new(num_qubits: usize, precision: Precision) -> Result<Self> {
let size = 1 << num_qubits;
if num_qubits > 30 {
return Err(SimulatorError::InvalidQubits(num_qubits));
}
match precision {
Precision::Half => {
let mut state = Array1::from_elem(
size,
ComplexF16 {
re: f16::from_f64(0.0),
im: f16::from_f64(0.0),
},
);
state[0] = ComplexF16 {
re: f16::from_f64(1.0),
im: f16::from_f64(0.0),
};
Ok(Self::Half(state))
}
Precision::Single => {
let mut state = Array1::zeros(size);
state[0] = Complex32::new(1.0, 0.0);
Ok(Self::Single(state))
}
Precision::Double => {
let mut state = Array1::zeros(size);
state[0] = Complex64::new(1.0, 0.0);
Ok(Self::Double(state))
}
Precision::Extended => Err(SimulatorError::InvalidConfiguration(
"Extended precision not yet supported".to_string(),
)),
}
}
#[must_use]
pub const fn precision(&self) -> Precision {
match self {
Self::Half(_) => Precision::Half,
Self::Single(_) => Precision::Single,
Self::Double(_) => Precision::Double,
}
}
#[must_use]
pub fn num_qubits(&self) -> usize {
let size = match self {
Self::Half(v) => v.len(),
Self::Single(v) => v.len(),
Self::Double(v) => v.len(),
};
(size as f64).log2() as usize
}
#[must_use]
pub fn to_complex64(&self) -> Array1<Complex64> {
match self {
Self::Half(v) => v.map(ComplexAmplitude::to_complex64),
Self::Single(v) => v.map(ComplexAmplitude::to_complex64),
Self::Double(v) => v.clone(),
}
}
pub fn from_complex64(&mut self, data: &Array1<Complex64>) -> Result<()> {
match self {
Self::Half(v) => {
if v.len() != data.len() {
return Err(SimulatorError::DimensionMismatch(format!(
"Size mismatch: {} vs {}",
v.len(),
data.len()
)));
}
for (i, &c) in data.iter().enumerate() {
v[i] = ComplexF16::from_complex64(c);
}
}
Self::Single(v) => {
if v.len() != data.len() {
return Err(SimulatorError::DimensionMismatch(format!(
"Size mismatch: {} vs {}",
v.len(),
data.len()
)));
}
for (i, &c) in data.iter().enumerate() {
v[i] = Complex32::from_complex64(c);
}
}
Self::Double(v) => {
v.assign(data);
}
}
Ok(())
}
#[must_use]
pub fn needs_precision_upgrade(&self, threshold: f64) -> bool {
let min_amplitude = match self {
Self::Half(v) => v
.iter()
.map(ComplexAmplitude::norm_sqr)
.filter(|&n| n > 0.0)
.fold(None, |acc, x| match acc {
None => Some(x),
Some(y) => Some(if x < y { x } else { y }),
}),
Self::Single(v) => v
.iter()
.map(|c| f64::from(c.norm_sqr()))
.filter(|&n| n > 0.0)
.fold(None, |acc, x| match acc {
None => Some(x),
Some(y) => Some(if x < y { x } else { y }),
}),
Self::Double(v) => v
.iter()
.map(scirs2_core::Complex::norm_sqr)
.filter(|&n| n > 0.0)
.fold(None, |acc, x| match acc {
None => Some(x),
Some(y) => Some(if x < y { x } else { y }),
}),
};
if let Some(min_amp) = min_amplitude {
min_amp < threshold * self.precision().epsilon()
} else {
false
}
}
pub fn upgrade_precision(&mut self) -> Result<()> {
let new_precision = match self.precision() {
Precision::Half => Precision::Single,
Precision::Single => Precision::Double,
Precision::Double => return Ok(()), Precision::Extended => unreachable!(),
};
let data = self.to_complex64();
*self = Self::new(self.num_qubits(), new_precision)?;
self.from_complex64(&data)?;
Ok(())
}
pub fn downgrade_precision(&mut self, tolerance: f64) -> Result<()> {
let new_precision = match self.precision() {
Precision::Half => return Ok(()), Precision::Single => Precision::Half,
Precision::Double => Precision::Single,
Precision::Extended => Precision::Double,
};
let data = self.to_complex64();
let test_vec = Self::new(self.num_qubits(), new_precision)?;
let mut max_error: f64 = 0.0;
match &test_vec {
Self::Half(_) => {
for &c in &data {
let converted = ComplexF16::from_complex64(c).to_complex64();
let error = (c - converted).norm();
max_error = max_error.max(error);
}
}
Self::Single(_) => {
for &c in &data {
let converted = Complex32::from_complex64(c).to_complex64();
let error = (c - converted).norm();
max_error = max_error.max(error);
}
}
Self::Double(_) => unreachable!(),
}
if max_error < tolerance {
*self = test_vec;
self.from_complex64(&data)?;
}
Ok(())
}
#[must_use]
pub fn memory_usage(&self) -> usize {
let elements = match self {
Self::Half(v) => v.len(),
Self::Single(v) => v.len(),
Self::Double(v) => v.len(),
};
elements * self.precision().bytes_per_complex()
}
}
#[derive(Debug, Clone)]
pub struct AdaptivePrecisionConfig {
pub initial_precision: Precision,
pub error_tolerance: f64,
pub check_interval: usize,
pub auto_upgrade: bool,
pub auto_downgrade: bool,
pub min_amplitude: f64,
}
impl Default for AdaptivePrecisionConfig {
fn default() -> Self {
Self {
initial_precision: Precision::Single,
error_tolerance: 1e-10,
check_interval: 100,
auto_upgrade: true,
auto_downgrade: true,
min_amplitude: 1e-12,
}
}
}
#[derive(Debug)]
pub struct PrecisionTracker {
changes: Vec<(usize, Precision, Precision)>, gate_count: usize,
config: AdaptivePrecisionConfig,
}
impl PrecisionTracker {
#[must_use]
pub const fn new(config: AdaptivePrecisionConfig) -> Self {
Self {
changes: Vec::new(),
gate_count: 0,
config,
}
}
pub const fn record_gate(&mut self) {
self.gate_count += 1;
}
#[must_use]
pub const fn should_check_precision(&self) -> bool {
self.gate_count % self.config.check_interval == 0
}
pub fn record_change(&mut self, from: Precision, to: Precision) {
self.changes.push((self.gate_count, from, to));
}
#[must_use]
pub fn history(&self) -> &[(usize, Precision, Precision)] {
&self.changes
}
#[must_use]
pub fn stats(&self) -> PrecisionStats {
let mut upgrades = 0;
let mut downgrades = 0;
for (_, from, to) in &self.changes {
match (from, to) {
(Precision::Half, Precision::Single)
| (Precision::Single, Precision::Double)
| (Precision::Double, Precision::Extended) => upgrades += 1,
_ => downgrades += 1,
}
}
PrecisionStats {
total_gates: self.gate_count,
precision_changes: self.changes.len(),
upgrades,
downgrades,
}
}
}
#[derive(Debug)]
pub struct PrecisionStats {
pub total_gates: usize,
pub precision_changes: usize,
pub upgrades: usize,
pub downgrades: usize,
}
impl fmt::Display for PrecisionStats {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Precision Stats: {} gates, {} changes ({} upgrades, {} downgrades)",
self.total_gates, self.precision_changes, self.upgrades, self.downgrades
)
}
}
pub fn benchmark_precisions(num_qubits: usize) -> Result<()> {
println!("\nPrecision Benchmark for {num_qubits} qubits:");
println!("{:-<60}", "");
for precision in [Precision::Half, Precision::Single, Precision::Double] {
let state = AdaptiveStateVector::new(num_qubits, precision)?;
let memory = state.memory_usage();
let memory_mb = memory as f64 / (1024.0 * 1024.0);
println!(
"{:?} precision: {:.2} MB ({} bytes per amplitude)",
precision,
memory_mb,
precision.bytes_per_complex()
);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_precision_levels() {
assert_eq!(Precision::Half.bytes_per_complex(), 4);
assert_eq!(Precision::Single.bytes_per_complex(), 8);
assert_eq!(Precision::Double.bytes_per_complex(), 16);
}
#[test]
fn test_precision_from_tolerance() {
assert_eq!(Precision::from_tolerance(0.01), Precision::Half);
assert_eq!(Precision::from_tolerance(1e-8), Precision::Double); assert_eq!(Precision::from_tolerance(1e-16), Precision::Extended); }
#[test]
fn test_complex_f16() {
let c = ComplexF16 {
re: f16::from_f64(0.5),
im: f16::from_f64(0.5),
};
let c64 = c.to_complex64();
assert!((c64.re - 0.5).abs() < 0.01);
assert!((c64.im - 0.5).abs() < 0.01);
}
#[test]
fn test_adaptive_state_vector() {
let mut state = AdaptiveStateVector::new(2, Precision::Single)
.expect("Failed to create adaptive state vector");
assert_eq!(state.precision(), Precision::Single);
assert_eq!(state.num_qubits(), 2);
let c64 = state.to_complex64();
assert_eq!(c64.len(), 4);
assert_eq!(c64[0], Complex64::new(1.0, 0.0));
}
#[test]
fn test_precision_upgrade() {
let mut state = AdaptiveStateVector::new(2, Precision::Half)
.expect("Failed to create half precision state");
state
.upgrade_precision()
.expect("Failed to upgrade precision");
assert_eq!(state.precision(), Precision::Single);
}
#[test]
fn test_precision_tracker() {
let config = AdaptivePrecisionConfig::default();
let mut tracker = PrecisionTracker::new(config);
for _ in 0..100 {
tracker.record_gate();
}
assert!(tracker.should_check_precision());
tracker.record_change(Precision::Single, Precision::Double);
let stats = tracker.stats();
assert_eq!(stats.upgrades, 1);
assert_eq!(stats.downgrades, 0);
}
#[test]
fn test_memory_usage() {
let state = AdaptiveStateVector::new(10, Precision::Half)
.expect("Failed to create state for memory test");
let memory = state.memory_usage();
assert_eq!(memory, 1024 * 4); }
}