use serde::{Deserialize, Serialize};
use crate::error::{Result, SimulatorError};
/// Placeholder mixed-precision context used when the `advanced_math`
/// feature is disabled. Construction always fails; see
/// [`MixedPrecisionContext::new`].
#[derive(Debug)]
pub struct MixedPrecisionContext;
/// Precision levels mirroring the SciRS2 backend's numeric formats.
/// Fallback definition available when the `advanced_math` feature is off.
#[derive(Debug)]
pub enum PrecisionLevel {
/// 16-bit floating point.
F16,
/// 32-bit floating point.
F32,
/// 64-bit floating point.
F64,
/// Precision chosen dynamically at runtime.
Adaptive,
}
/// Strategy controlling how a mixed-precision context selects precision.
#[derive(Debug)]
pub enum AdaptiveStrategy {
/// Adapt precision based on an error bound (presumably a target error
/// threshold — confirm against the advanced-math backend).
ErrorBased(f64),
/// Always use the given fixed precision level.
Fixed(PrecisionLevel),
}
impl MixedPrecisionContext {
    /// Stub constructor for builds without the `advanced_math` feature.
    ///
    /// # Errors
    /// Always returns [`SimulatorError::UnsupportedOperation`], since the
    /// real context requires the feature-gated backend.
    pub fn new(_strategy: AdaptiveStrategy) -> Result<Self> {
        let reason =
            "Mixed precision context not available without advanced_math feature".to_string();
        Err(SimulatorError::UnsupportedOperation(reason))
    }
}
/// Numeric precision formats available for quantum simulation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QuantumPrecision {
/// IEEE 754 half precision (fp16): 5 exponent, 10 mantissa bits.
Half,
/// bfloat16: 8 exponent, 7 mantissa bits.
BFloat16,
/// NVIDIA TensorFloat-32: 19 significant bits in 32-bit storage.
TF32,
/// IEEE 754 single precision (fp32).
Single,
/// IEEE 754 double precision (fp64).
Double,
/// Precision selected dynamically per operation.
Adaptive,
}
impl QuantumPrecision {
    /// Translate this precision into the SciRS2 backend's precision level.
    #[cfg(feature = "advanced_math")]
    pub const fn to_scirs2_precision(&self) -> PrecisionLevel {
        match self {
            Self::Half | Self::BFloat16 => PrecisionLevel::F16,
            Self::TF32 | Self::Single => PrecisionLevel::F32,
            Self::Double => PrecisionLevel::F64,
            Self::Adaptive => PrecisionLevel::Adaptive,
        }
    }

    /// Memory footprint relative to double precision (1.0).
    #[must_use]
    pub const fn memory_factor(&self) -> f64 {
        match self {
            Self::Half | Self::BFloat16 => 0.25,
            Self::TF32 | Self::Single => 0.5,
            Self::Adaptive => 0.75,
            Self::Double => 1.0,
        }
    }

    /// Relative computational cost, with double precision as 1.0.
    #[must_use]
    pub const fn computation_factor(&self) -> f64 {
        match self {
            Self::Half | Self::BFloat16 => 0.25,
            Self::TF32 => 0.35,
            Self::Adaptive => 0.6,
            Self::Single => 0.7,
            Self::Double => 1.0,
        }
    }

    /// Order-of-magnitude relative rounding error typical of this format.
    #[must_use]
    pub const fn typical_error(&self) -> f64 {
        match self {
            Self::BFloat16 => 1e-2,
            Self::Half => 1e-3,
            Self::TF32 => 1e-4,
            Self::Single | Self::Adaptive => 1e-6,
            Self::Double => 1e-15,
        }
    }

    /// True for formats that only pay off on tensor-core hardware.
    #[must_use]
    pub const fn requires_tensor_cores(&self) -> bool {
        matches!(self, Self::BFloat16 | Self::TF32)
    }

    /// True for any format below single precision.
    #[must_use]
    pub const fn is_reduced_precision(&self) -> bool {
        !matches!(self, Self::Single | Self::Double | Self::Adaptive)
    }

    /// Total significant bits of the format (TF32 uses 19 of its 32
    /// storage bits).
    #[must_use]
    pub const fn bit_width(&self) -> usize {
        match self {
            Self::Half | Self::BFloat16 => 16,
            Self::TF32 => 19,
            Self::Single | Self::Adaptive => 32,
            Self::Double => 64,
        }
    }

    /// Number of explicit mantissa (fraction) bits.
    #[must_use]
    pub const fn mantissa_bits(&self) -> usize {
        match self {
            Self::BFloat16 => 7,
            Self::Half | Self::TF32 => 10,
            Self::Single | Self::Adaptive => 23,
            Self::Double => 52,
        }
    }

    /// Number of exponent bits.
    #[must_use]
    pub const fn exponent_bits(&self) -> usize {
        match self {
            Self::Half => 5,
            Self::Double => 11,
            Self::BFloat16 | Self::TF32 | Self::Single | Self::Adaptive => 8,
        }
    }

    /// Whether this format's typical error is acceptable for `tolerance`.
    /// A 10x margin is allowed: errors up to an order of magnitude above
    /// the tolerance still count as sufficient.
    #[must_use]
    pub fn is_sufficient_for_tolerance(&self, tolerance: f64) -> bool {
        tolerance * 10.0 >= self.typical_error()
    }

    /// Next step up the precision ladder, or `None` at the top.
    #[must_use]
    pub const fn higher_precision(&self) -> Option<Self> {
        Some(match self {
            Self::Half => Self::BFloat16,
            Self::BFloat16 => Self::TF32,
            Self::TF32 => Self::Single,
            Self::Single | Self::Adaptive => Self::Double,
            Self::Double => return None,
        })
    }

    /// Next step down the precision ladder, or `None` at the bottom.
    #[must_use]
    pub const fn lower_precision(&self) -> Option<Self> {
        Some(match self {
            Self::Half => return None,
            Self::BFloat16 => Self::Half,
            Self::TF32 => Self::BFloat16,
            Self::Single => Self::TF32,
            Self::Double | Self::Adaptive => Self::Single,
        })
    }

    /// Pick the cheapest precision meeting `tolerance`, preferring
    /// tensor-core formats (bfloat16 / TF32) when the hardware has them.
    #[must_use]
    pub fn select_for_accuracy_and_tensor_cores(tolerance: f64, has_tensor_cores: bool) -> Self {
        match (tolerance, has_tensor_cores) {
            (t, true) if t >= 1e-2 => Self::BFloat16,
            (t, false) if t >= 1e-2 => Self::Half,
            (t, true) if t >= 1e-4 => Self::TF32,
            (t, false) if t >= 1e-4 => Self::Single,
            (t, _) if t >= 1e-6 => Self::Single,
            _ => Self::Double,
        }
    }
}
/// Configuration for mixed-precision quantum simulation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedPrecisionConfig {
/// Precision used to store state-vector amplitudes.
pub state_vector_precision: QuantumPrecision,
/// Precision used for gate operations.
pub gate_precision: QuantumPrecision,
/// Precision used for measurement results.
pub measurement_precision: QuantumPrecision,
/// Target numerical error tolerance; must be positive (see `validate`).
pub error_tolerance: f64,
/// Allow automatic precision downgrades for large systems
/// (see `adjust_for_qubits`).
pub adaptive_precision: bool,
/// Lowest precision adaptive adjustment may use.
pub min_precision: QuantumPrecision,
/// Highest precision adaptive adjustment may use.
pub max_precision: QuantumPrecision,
/// Qubit count at or above which a system counts as "large".
pub large_system_threshold: usize,
/// Enable precision analysis — NOTE(review): consumed outside this
/// file; confirm exact effect against the simulator core.
pub enable_analysis: bool,
}
impl Default for MixedPrecisionConfig {
    /// Balanced defaults: single precision for bulk state-vector and gate
    /// math, double precision for measurements, adaptive downgrades on.
    fn default() -> Self {
        Self {
            error_tolerance: 1e-6,
            adaptive_precision: true,
            enable_analysis: true,
            large_system_threshold: 20,
            state_vector_precision: QuantumPrecision::Single,
            gate_precision: QuantumPrecision::Single,
            measurement_precision: QuantumPrecision::Double,
            min_precision: QuantumPrecision::Half,
            max_precision: QuantumPrecision::Double,
        }
    }
}
impl MixedPrecisionConfig {
    /// Preset tuned for maximum numerical accuracy: double precision
    /// everywhere, tight tolerance, no adaptive downgrades.
    #[must_use]
    pub const fn for_accuracy() -> Self {
        Self {
            state_vector_precision: QuantumPrecision::Double,
            gate_precision: QuantumPrecision::Double,
            measurement_precision: QuantumPrecision::Double,
            error_tolerance: 1e-12,
            adaptive_precision: false,
            min_precision: QuantumPrecision::Double,
            max_precision: QuantumPrecision::Double,
            large_system_threshold: 50,
            enable_analysis: true,
        }
    }

    /// Preset tuned for speed: reduced precision, loose tolerance,
    /// adaptive downgrades enabled, analysis disabled.
    #[must_use]
    pub const fn for_performance() -> Self {
        Self {
            state_vector_precision: QuantumPrecision::Half,
            gate_precision: QuantumPrecision::Single,
            measurement_precision: QuantumPrecision::Single,
            error_tolerance: 1e-3,
            adaptive_precision: true,
            min_precision: QuantumPrecision::Half,
            max_precision: QuantumPrecision::Single,
            large_system_threshold: 10,
            enable_analysis: false,
        }
    }

    /// Balanced preset; identical to [`Default::default`].
    #[must_use]
    pub fn balanced() -> Self {
        Self::default()
    }

    /// Validate internal consistency of the configuration.
    ///
    /// # Errors
    /// Returns [`SimulatorError::InvalidInput`] when the tolerance is
    /// non-positive, the threshold is zero, or the precision bounds are
    /// inverted.
    pub fn validate(&self) -> Result<()> {
        if self.error_tolerance <= 0.0 {
            return Err(SimulatorError::InvalidInput(
                "Error tolerance must be positive".to_string(),
            ));
        }
        if self.large_system_threshold == 0 {
            return Err(SimulatorError::InvalidInput(
                "Large system threshold must be positive".to_string(),
            ));
        }
        // NOTE(review): this compares declaration-order discriminants, which
        // ranks `Adaptive` (last variant) above `Double` — confirm intended.
        if self.min_precision as u8 > self.max_precision as u8 {
            return Err(SimulatorError::InvalidInput(
                "Minimum precision cannot be higher than maximum precision".to_string(),
            ));
        }
        Ok(())
    }

    /// Downgrade the state-vector precision one step (Double -> Single,
    /// Single -> Half) for systems at or above `large_system_threshold`,
    /// but only when `adaptive_precision` is enabled.
    pub const fn adjust_for_qubits(&mut self, num_qubits: usize) {
        if num_qubits >= self.large_system_threshold && self.adaptive_precision {
            match self.state_vector_precision {
                QuantumPrecision::Double => {
                    self.state_vector_precision = QuantumPrecision::Single;
                }
                QuantumPrecision::Single => {
                    self.state_vector_precision = QuantumPrecision::Half;
                }
                // Already reduced (or Adaptive): leave unchanged.
                _ => {}
            }
        }
    }

    /// Estimate state-vector memory usage in bytes for `num_qubits` qubits.
    ///
    /// Baseline is `2^num_qubits` amplitudes at 16 bytes each (a complex
    /// number with two `f64` components), scaled by the configured
    /// state-vector precision's memory factor. Saturates at `usize::MAX`
    /// instead of overflowing for very large qubit counts.
    #[must_use]
    pub fn estimate_memory_usage(&self, num_qubits: usize) -> usize {
        // Bug fix: the previous code used `f64::from(base_memory)`, which
        // forced the shift `1 << num_qubits` onto `i32` (f64 has no
        // `From<usize>` impl) and overflowed for num_qubits >= 27.
        // Compute in usize with checked/saturating arithmetic instead.
        let amplitudes = match u32::try_from(num_qubits) {
            Ok(shift) => 1usize.checked_shl(shift).unwrap_or(usize::MAX),
            Err(_) => usize::MAX,
        };
        let base_memory = amplitudes.saturating_mul(16);
        let factor = self.state_vector_precision.memory_factor();
        (base_memory as f64 * factor) as usize
    }

    /// Whether the estimated state vector fits in `available_memory` bytes.
    #[must_use]
    pub fn fits_in_memory(&self, num_qubits: usize, available_memory: usize) -> bool {
        self.estimate_memory_usage(num_qubits) <= available_memory
    }
}