use crate::error::{Result, SimulatorError};
use crate::mixed_precision_impl::QuantumPrecision;
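/// NVIDIA tensor core generations, ordered oldest to newest so that the
/// derived `Ord` follows hardware age (`None` sorts before everything).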
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum TensorCoreGeneration {
None,
Volta,
Turing,
Ampere,
AdaLovelace,
Hopper,
Blackwell,
}
impl TensorCoreGeneration {
#[must_use]
pub fn from_compute_capability(major: i32, minor: i32) -> Self {
match (major, minor) {
            (10, _) => Self::Blackwell,
            (9, _) => Self::Hopper,
            (8, 9) => Self::AdaLovelace,
            (8, _) => Self::Ampere,
            (7, 5) => Self::Turing,
            (7, _) => Self::Volta,
            _ => Self::None,
}
}
#[must_use]
pub const fn supports_tf32(&self) -> bool {
matches!(
self,
Self::Ampere | Self::AdaLovelace | Self::Hopper | Self::Blackwell
)
}
#[must_use]
pub const fn supports_bf16(&self) -> bool {
matches!(
self,
Self::Ampere | Self::AdaLovelace | Self::Hopper | Self::Blackwell
)
}
#[must_use]
pub const fn supports_fp16_tensor(&self) -> bool {
!matches!(self, Self::None)
}
#[must_use]
pub const fn supports_fp8(&self) -> bool {
matches!(self, Self::AdaLovelace | Self::Hopper | Self::Blackwell)
}
#[must_use]
pub const fn optimal_batch_size(&self) -> usize {
match self {
Self::None => 1,
Self::Volta | Self::Turing => 16,
Self::Ampere => 32,
Self::AdaLovelace | Self::Hopper | Self::Blackwell => 64,
}
}
}
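/// Tunable settings controlling when and how tensor cores are used.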
#[derive(Debug, Clone)]
pub struct TensorCoreConfig {
pub enable_tf32: bool,
pub enable_fp16_accumulate: bool,
pub min_matrix_size_for_tc: usize,
pub accumulator_precision: AccumulatorPrecision,
pub enable_mixed_precision: bool,
pub generation: TensorCoreGeneration,
}
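/// Precision used for the accumulator when summing tensor core partial
/// products; wider accumulators trade throughput for accuracy.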
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccumulatorPrecision {
Fp16,
Fp32,
Fp64,
}
impl Default for TensorCoreConfig {
fn default() -> Self {
Self {
enable_tf32: true,
enable_fp16_accumulate: false,
min_matrix_size_for_tc: 16,
accumulator_precision: AccumulatorPrecision::Fp32,
enable_mixed_precision: true,
generation: TensorCoreGeneration::None,
}
}
}
impl TensorCoreConfig {
#[must_use]
pub const fn for_accuracy() -> Self {
Self {
enable_tf32: false,
enable_fp16_accumulate: false,
min_matrix_size_for_tc: 64,
accumulator_precision: AccumulatorPrecision::Fp64,
enable_mixed_precision: false,
generation: TensorCoreGeneration::None,
}
}
#[must_use]
pub const fn for_performance() -> Self {
Self {
enable_tf32: true,
enable_fp16_accumulate: true,
min_matrix_size_for_tc: 8,
accumulator_precision: AccumulatorPrecision::Fp16,
enable_mixed_precision: true,
generation: TensorCoreGeneration::None,
}
}
pub fn detect_capabilities(&mut self, major: i32, minor: i32) {
self.generation = TensorCoreGeneration::from_compute_capability(major, minor);
if !self.generation.supports_tf32() {
self.enable_tf32 = false;
}
}
#[must_use]
pub const fn tensor_cores_available(&self) -> bool {
!matches!(self.generation, TensorCoreGeneration::None)
}
#[must_use]
pub fn best_precision_for_tolerance(&self, tolerance: f64) -> QuantumPrecision {
QuantumPrecision::select_for_accuracy_and_tensor_cores(
tolerance,
self.tensor_cores_available(),
)
}
}
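/// Matrix-multiply and gate-application primitives that can be backed by
/// tensor cores, or by the software reference paths in [`TensorCoreKernels`].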
pub trait TensorCoreOps {
fn supports_tensor_cores(&self) -> bool;
fn tensor_core_config(&self) -> &TensorCoreConfig;
fn matmul_tf32(
&self,
a: &[f32],
b: &[f32],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) -> Result<()>;
fn matmul_fp16(
&self,
a: &[u16],
b: &[u16],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) -> Result<()>;
fn matmul_bf16(
&self,
a: &[u16],
b: &[u16],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) -> Result<()>;
fn apply_gate_mixed_precision(
&self,
state: &mut [f32],
gate: &[f32],
target_qubits: &[usize],
precision: QuantumPrecision,
) -> Result<()>;
}
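/// Bit-level software conversions between `f32` and the reduced-precision
/// formats consumed by tensor cores (IEEE binary16, bfloat16, and TF32).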
pub mod fp16_utils {
    /// Converts an `f32` to IEEE 754 half precision (binary16), truncating
    /// the mantissa (round toward zero).
    #[must_use]
    pub fn f32_to_fp16(value: f32) -> u16 {
        let bits = value.to_bits();
        let sign = (bits >> 31) & 1;
        let exp = ((bits >> 23) & 0xFF) as i32 - 127;
        let mantissa = bits & 0x7F_FFFF;
        // Infinity and NaN: saturate the exponent; a nonzero payload stays NaN.
        if exp == 128 {
            return ((sign << 15) | 0x7C00 | (mantissa >> 13)) as u16;
        }
        // Below the smallest normal fp16 value (2^-14): produce a subnormal,
        // whose value is m * 2^-24 for the 10-bit m computed here.
        if exp < -14 {
            if exp < -24 {
                // Too small even for a subnormal: flush to signed zero.
                return (sign << 15) as u16;
            }
            // Restore the implicit leading 1, then shift the 24-bit
            // significand so it is expressed in units of 2^-24.
            let m24 = mantissa | 0x80_0000;
            return ((sign << 15) | (m24 >> (-exp - 1))) as u16;
        }
        // Finite values above the fp16 range overflow to infinity.
        if exp > 15 {
            return ((sign << 15) | 0x7C00) as u16;
        }
        let fp16_exp = (exp + 15) as u32;
        let fp16_mantissa = mantissa >> 13;
        ((sign << 15) | (fp16_exp << 10) | fp16_mantissa) as u16
    }
#[must_use]
pub fn fp16_to_f32(value: u16) -> f32 {
let sign = ((value >> 15) & 1) as u32;
let exp = ((value >> 10) & 0x1F) as i32;
let mantissa = (value & 0x3FF) as u32;
if exp == 0 {
if mantissa == 0 {
return f32::from_bits(sign << 31);
}
let mut m = mantissa;
let mut e = -14i32;
while (m & 0x400) == 0 {
m <<= 1;
e -= 1;
}
let mantissa = (m & 0x3FF) << 13;
let exp = (e + 127) as u32;
return f32::from_bits((sign << 31) | (exp << 23) | mantissa);
}
if exp == 31 {
let bits = (sign << 31) | 0x7F800000 | (mantissa << 13);
return f32::from_bits(bits);
}
let fp32_exp = (exp - 15 + 127) as u32;
let fp32_mantissa = mantissa << 13;
f32::from_bits((sign << 31) | (fp32_exp << 23) | fp32_mantissa)
}
    /// Converts an `f32` to bfloat16 using round-to-nearest-even.
    #[must_use]
    pub fn f32_to_bf16(value: f32) -> u16 {
        let bits = value.to_bits();
        // NaN payloads can carry into the exponent under the rounding add;
        // return a quiet NaN with the original sign instead.
        if value.is_nan() {
            return ((bits >> 16) as u16) | 0x0040;
        }
        let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
        ((bits + rounding_bias) >> 16) as u16
    }
#[must_use]
pub fn bf16_to_f32(value: u16) -> f32 {
f32::from_bits((value as u32) << 16)
}
    /// Truncates an `f32` to TF32 precision by zeroing the 13 low mantissa
    /// bits; TF32 keeps the full 8-bit f32 exponent and 10 mantissa bits.
    #[must_use]
    pub fn f32_to_tf32(value: f32) -> f32 {
        let bits = value.to_bits();
        let mask = 0xFFFF_E000u32;
        f32::from_bits(bits & mask)
    }
    /// Like [`f32_to_tf32`], but rounds half-up before truncating.
    #[must_use]
    pub fn f32_to_tf32_rounded(value: f32) -> f32 {
        // A NaN payload could carry into the exponent under the rounding
        // add, so pass NaN through untouched.
        if value.is_nan() {
            return value;
        }
        let bits = value.to_bits();
        // Add half of the dropped range (2^12), saturating so the bit
        // pattern cannot wrap, then clear the 13 low mantissa bits.
        let rounded = bits.saturating_add(0x1000);
        f32::from_bits(rounded & 0xFFFF_E000)
    }
}
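/// Kernels that reproduce tensor core numerics on the host: operands are
/// narrowed to the target format and accumulation happens in `f32`, mirroring
/// the hardware MMA behaviour without requiring a GPU.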
pub struct TensorCoreKernels {
config: TensorCoreConfig,
cuda_available: bool,
}
impl TensorCoreKernels {
#[must_use]
    pub fn new(config: TensorCoreConfig) -> Self {
        Self {
            config,
            cuda_available: false,
        }
    }
pub fn initialize(&mut self) -> Result<()> {
#[cfg(feature = "advanced_math")]
{
use super::context::CudaContext;
if let Ok(context) = CudaContext::new(0) {
let props = context.get_device_properties();
self.config
.detect_capabilities(props.compute_capability.0, props.compute_capability.1);
self.cuda_available = true;
}
}
Ok(())
}
    /// Applies a real-valued 2x2 gate to `target_qubit` of a `num_qubits`-qubit
    /// state, rounding the gate entries to TF32 precision first. Amplitudes are
    /// stored as interleaved `(re, im)` pairs of `f32`.
    pub fn apply_single_qubit_gate_tf32(
        &self,
        state: &mut [f32],
        gate: &[[f32; 2]; 2],
        target_qubit: usize,
        num_qubits: usize,
    ) -> Result<()> {
if !self.config.enable_tf32 || !self.config.generation.supports_tf32() {
return Err(SimulatorError::UnsupportedOperation(
"TF32 not supported on this GPU".to_string(),
));
}
let state_size = 1 << num_qubits;
if state.len() != state_size * 2 {
return Err(SimulatorError::InvalidInput(
"State vector size mismatch".to_string(),
));
}
let gate_tf32 = [
[
fp16_utils::f32_to_tf32_rounded(gate[0][0]),
fp16_utils::f32_to_tf32_rounded(gate[0][1]),
],
[
fp16_utils::f32_to_tf32_rounded(gate[1][0]),
fp16_utils::f32_to_tf32_rounded(gate[1][1]),
],
];
        let mask = 1 << target_qubit;
        for i in 0..(state_size / 2) {
            // Expand i into the pair of basis-state indices that differ only
            // in the target qubit.
            let i0 = ((i & !(mask - 1)) << 1) | (i & (mask - 1));
            let i1 = i0 | mask;
            let idx0 = i0 * 2;
            let idx1 = i1 * 2;
            let (a_re, a_im) = (state[idx0], state[idx0 + 1]);
            let (b_re, b_im) = (state[idx1], state[idx1 + 1]);
            // The gate is a real-valued 2x2 matrix, so the real and imaginary
            // parts of each amplitude transform independently.
            let new_a_re = gate_tf32[0][0] * a_re + gate_tf32[0][1] * b_re;
            let new_a_im = gate_tf32[0][0] * a_im + gate_tf32[0][1] * b_im;
            let new_b_re = gate_tf32[1][0] * a_re + gate_tf32[1][1] * b_re;
            let new_b_im = gate_tf32[1][0] * a_im + gate_tf32[1][1] * b_im;
            state[idx0] = new_a_re;
            state[idx0 + 1] = new_a_im;
            state[idx1] = new_b_re;
            state[idx1 + 1] = new_b_im;
        }
Ok(())
}
    /// Indicative peak dense-math throughput in TFLOPS for the configured
    /// generation, using figures from representative parts (V100, T4, A100,
    /// H100, RTX 4090).
    #[must_use]
    pub fn estimated_tflops(&self, precision: QuantumPrecision) -> f64 {
        match (self.config.generation, precision) {
            (TensorCoreGeneration::Ampere, QuantumPrecision::TF32) => 156.0,
            (TensorCoreGeneration::Ampere, QuantumPrecision::Half) => 312.0,
            (TensorCoreGeneration::Ampere, QuantumPrecision::BFloat16) => 312.0,
            (TensorCoreGeneration::Ampere, QuantumPrecision::Single) => 19.5,
            (TensorCoreGeneration::Hopper, QuantumPrecision::TF32) => 495.0,
            (TensorCoreGeneration::Hopper, QuantumPrecision::Half) => 990.0,
            (TensorCoreGeneration::Hopper, QuantumPrecision::BFloat16) => 990.0,
            (TensorCoreGeneration::Volta, QuantumPrecision::Half) => 125.0,
            (TensorCoreGeneration::Turing, QuantumPrecision::Half) => 65.0,
            (TensorCoreGeneration::AdaLovelace, QuantumPrecision::TF32) => 82.6,
            (TensorCoreGeneration::AdaLovelace, QuantumPrecision::Half) => 165.2,
            _ => 0.0,
        }
    }
#[must_use]
pub fn supports_precision(&self, precision: QuantumPrecision) -> bool {
match precision {
QuantumPrecision::TF32 => self.config.generation.supports_tf32(),
QuantumPrecision::BFloat16 => self.config.generation.supports_bf16(),
QuantumPrecision::Half => self.config.generation.supports_fp16_tensor(),
QuantumPrecision::Single | QuantumPrecision::Double => true,
QuantumPrecision::Adaptive => true,
}
}
}
impl TensorCoreOps for TensorCoreKernels {
fn supports_tensor_cores(&self) -> bool {
self.config.tensor_cores_available() && self.cuda_available
}
fn tensor_core_config(&self) -> &TensorCoreConfig {
&self.config
}
fn matmul_tf32(
&self,
a: &[f32],
b: &[f32],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) -> Result<()> {
if !self.config.generation.supports_tf32() {
return Err(SimulatorError::UnsupportedOperation(
"TF32 not supported".to_string(),
));
}
if a.len() != m * k || b.len() != k * n || c.len() != m * n {
return Err(SimulatorError::InvalidInput(
"Matrix dimension mismatch".to_string(),
));
}
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for l in 0..k {
let a_tf32 = fp16_utils::f32_to_tf32(a[i * k + l]);
let b_tf32 = fp16_utils::f32_to_tf32(b[l * n + j]);
sum += a_tf32 * b_tf32;
}
c[i * n + j] = sum;
}
}
Ok(())
}
fn matmul_fp16(
&self,
a: &[u16],
b: &[u16],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) -> Result<()> {
if !self.config.generation.supports_fp16_tensor() {
return Err(SimulatorError::UnsupportedOperation(
"FP16 Tensor Cores not supported".to_string(),
));
}
if a.len() != m * k || b.len() != k * n || c.len() != m * n {
return Err(SimulatorError::InvalidInput(
"Matrix dimension mismatch".to_string(),
));
}
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for l in 0..k {
let a_f32 = fp16_utils::fp16_to_f32(a[i * k + l]);
let b_f32 = fp16_utils::fp16_to_f32(b[l * n + j]);
sum += a_f32 * b_f32;
}
c[i * n + j] = sum;
}
}
Ok(())
}
fn matmul_bf16(
&self,
a: &[u16],
b: &[u16],
c: &mut [f32],
m: usize,
n: usize,
k: usize,
) -> Result<()> {
if !self.config.generation.supports_bf16() {
return Err(SimulatorError::UnsupportedOperation(
"BF16 not supported".to_string(),
));
}
if a.len() != m * k || b.len() != k * n || c.len() != m * n {
return Err(SimulatorError::InvalidInput(
"Matrix dimension mismatch".to_string(),
));
}
for i in 0..m {
for j in 0..n {
let mut sum = 0.0f32;
for l in 0..k {
let a_f32 = fp16_utils::bf16_to_f32(a[i * k + l]);
let b_f32 = fp16_utils::bf16_to_f32(b[l * n + j]);
sum += a_f32 * b_f32;
}
c[i * n + j] = sum;
}
}
Ok(())
}
fn apply_gate_mixed_precision(
&self,
state: &mut [f32],
gate: &[f32],
_target_qubits: &[usize],
precision: QuantumPrecision,
) -> Result<()> {
if !self.supports_precision(precision) {
return Err(SimulatorError::UnsupportedOperation(format!(
"Precision {:?} not supported",
precision
)));
}
        // Narrow the gate to the requested precision. This host-side path
        // currently performs only the conversion; the state itself is passed
        // through unchanged pending a real device kernel.
        match precision {
            QuantumPrecision::TF32 => {
                let _gate_tf32: Vec<f32> =
                    gate.iter().map(|&v| fp16_utils::f32_to_tf32(v)).collect();
            }
            QuantumPrecision::Half => {
                let _gate_fp16: Vec<u16> =
                    gate.iter().map(|&v| fp16_utils::f32_to_fp16(v)).collect();
            }
            QuantumPrecision::BFloat16 => {
                // BF16 gates must use the bf16 converter, not the fp16 one.
                let _gate_bf16: Vec<u16> =
                    gate.iter().map(|&v| fp16_utils::f32_to_bf16(v)).collect();
            }
            _ => {}
        }
        let _ = state;
        Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tensor_core_generation_detection() {
assert_eq!(
TensorCoreGeneration::from_compute_capability(9, 0),
TensorCoreGeneration::Hopper
);
assert_eq!(
TensorCoreGeneration::from_compute_capability(8, 9),
TensorCoreGeneration::AdaLovelace
);
assert_eq!(
TensorCoreGeneration::from_compute_capability(8, 0),
TensorCoreGeneration::Ampere
);
assert_eq!(
TensorCoreGeneration::from_compute_capability(7, 5),
TensorCoreGeneration::Turing
);
assert_eq!(
TensorCoreGeneration::from_compute_capability(7, 0),
TensorCoreGeneration::Volta
);
assert_eq!(
TensorCoreGeneration::from_compute_capability(6, 1),
TensorCoreGeneration::None
);
}
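    #[test]
    fn test_fp8_and_blackwell_detection() {
        // Companion to the test above for the paths it leaves uncovered:
        // compute capability 10.x maps to Blackwell, and FP8 tensor cores
        // first appear with Ada Lovelace.
        assert_eq!(
            TensorCoreGeneration::from_compute_capability(10, 0),
            TensorCoreGeneration::Blackwell
        );
        assert!(TensorCoreGeneration::AdaLovelace.supports_fp8());
        assert!(TensorCoreGeneration::Hopper.supports_fp8());
        assert!(TensorCoreGeneration::Blackwell.supports_fp8());
        assert!(!TensorCoreGeneration::Ampere.supports_fp8());
    }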
#[test]
fn test_tf32_support() {
assert!(TensorCoreGeneration::Ampere.supports_tf32());
assert!(TensorCoreGeneration::Hopper.supports_tf32());
assert!(!TensorCoreGeneration::Volta.supports_tf32());
assert!(!TensorCoreGeneration::Turing.supports_tf32());
}
#[test]
fn test_fp16_conversion() {
let one_f32 = 1.0f32;
let one_fp16 = fp16_utils::f32_to_fp16(one_f32);
let one_back = fp16_utils::fp16_to_f32(one_fp16);
assert!((one_f32 - one_back).abs() < 1e-6);
let half_f32 = 0.5f32;
let half_fp16 = fp16_utils::f32_to_fp16(half_f32);
let half_back = fp16_utils::fp16_to_f32(half_fp16);
assert!((half_f32 - half_back).abs() < 1e-6);
}
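    // A sketch of the subnormal fp16 path: 2^-20 lies between the smallest
    // normal value (2^-14) and the smallest subnormal (2^-24), so with the
    // corrected subnormal shift it round-trips exactly, while anything below
    // 2^-24 flushes to zero.
    #[test]
    fn test_fp16_subnormal_conversion() {
        let tiny = 2.0f32.powi(-20);
        let back = fp16_utils::fp16_to_f32(fp16_utils::f32_to_fp16(tiny));
        assert_eq!(tiny, back);
        let flushed = fp16_utils::fp16_to_f32(fp16_utils::f32_to_fp16(2.0f32.powi(-30)));
        assert_eq!(flushed, 0.0);
    }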
#[test]
fn test_bf16_conversion() {
let value = std::f32::consts::PI;
let bf16 = fp16_utils::f32_to_bf16(value);
let back = fp16_utils::bf16_to_f32(bf16);
assert!((value - back).abs() < 0.01);
}
#[test]
fn test_tf32_truncation() {
let value = 1.234_568_f32;
let tf32 = fp16_utils::f32_to_tf32(value);
assert!((value - tf32).abs() < 0.001);
}
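    // Contrast of the two TF32 conversions: f32_to_tf32 truncates the 13 low
    // mantissa bits, while f32_to_tf32_rounded adds a half-ULP bias first, so
    // a value with its topmost dropped bit set rounds up to the next step.
    #[test]
    fn test_tf32_rounding_vs_truncation() {
        let value = f32::from_bits(0x3F80_1FFF); // slightly above 1.0
        let truncated = fp16_utils::f32_to_tf32(value);
        let rounded = fp16_utils::f32_to_tf32_rounded(value);
        assert_eq!(truncated.to_bits(), 0x3F80_0000);
        assert_eq!(rounded.to_bits(), 0x3F80_2000);
        assert!(rounded > truncated);
    }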
#[test]
fn test_tensor_core_config_default() {
let config = TensorCoreConfig::default();
assert!(config.enable_tf32);
assert!(!config.enable_fp16_accumulate);
assert_eq!(config.accumulator_precision, AccumulatorPrecision::Fp32);
}
#[test]
fn test_tensor_core_config_performance() {
let config = TensorCoreConfig::for_performance();
assert!(config.enable_tf32);
assert!(config.enable_fp16_accumulate);
assert_eq!(config.accumulator_precision, AccumulatorPrecision::Fp16);
}
#[test]
fn test_tensor_core_config_accuracy() {
let config = TensorCoreConfig::for_accuracy();
assert!(!config.enable_tf32);
assert!(!config.enable_fp16_accumulate);
assert_eq!(config.accumulator_precision, AccumulatorPrecision::Fp64);
}
#[test]
fn test_precision_selection() {
let config_no_tc = TensorCoreConfig::default();
assert!(!config_no_tc.tensor_cores_available());
let precision_no_tc = config_no_tc.best_precision_for_tolerance(1e-3);
assert_eq!(precision_no_tc, QuantumPrecision::Single);
let config_with_tc = TensorCoreConfig {
generation: TensorCoreGeneration::Ampere,
..TensorCoreConfig::default()
};
assert!(config_with_tc.tensor_cores_available());
let precision_with_tc = config_with_tc.best_precision_for_tolerance(1e-3);
assert_eq!(precision_with_tc, QuantumPrecision::TF32);
        let precision_loose = config_with_tc.best_precision_for_tolerance(1e-2);
        assert_eq!(precision_loose, QuantumPrecision::BFloat16);
let precision_high = config_no_tc.best_precision_for_tolerance(1e-8);
assert_eq!(precision_high, QuantumPrecision::Double);
}
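    // supports_precision should mirror the per-generation capability flags;
    // Volta has FP16 tensor cores but predates both TF32 and BF16.
    #[test]
    fn test_supports_precision_by_generation() {
        let kernels = TensorCoreKernels::new(TensorCoreConfig {
            generation: TensorCoreGeneration::Volta,
            ..TensorCoreConfig::default()
        });
        assert!(kernels.supports_precision(QuantumPrecision::Half));
        assert!(!kernels.supports_precision(QuantumPrecision::TF32));
        assert!(!kernels.supports_precision(QuantumPrecision::BFloat16));
        assert!(kernels.supports_precision(QuantumPrecision::Single));
        assert!(kernels.supports_precision(QuantumPrecision::Double));
    }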
#[test]
fn test_matmul_tf32_dimensions() {
let config = TensorCoreConfig {
generation: TensorCoreGeneration::Ampere,
..TensorCoreConfig::default()
};
let kernels = TensorCoreKernels::new(config);
let a = vec![1.0f32; 4 * 4];
let b = vec![1.0f32; 4 * 4];
let mut c = vec![0.0f32; 4 * 4];
let result = kernels.matmul_tf32(&a, &b, &mut c, 4, 4, 4);
assert!(result.is_ok());
for val in c {
assert!((val - 4.0).abs() < 0.01);
}
}
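    // Same shape and expected result as the TF32 test above, but through the
    // bf16 path: operands are narrowed with f32_to_bf16 up front and widened
    // back to f32 inside the kernel, which is exact for small integers.
    #[test]
    fn test_matmul_bf16_dimensions() {
        let config = TensorCoreConfig {
            generation: TensorCoreGeneration::Ampere,
            ..TensorCoreConfig::default()
        };
        let kernels = TensorCoreKernels::new(config);
        let a = vec![fp16_utils::f32_to_bf16(1.0); 4 * 4];
        let b = vec![fp16_utils::f32_to_bf16(1.0); 4 * 4];
        let mut c = vec![0.0f32; 4 * 4];
        assert!(kernels.matmul_bf16(&a, &b, &mut c, 4, 4, 4).is_ok());
        for val in c {
            assert!((val - 4.0).abs() < 1e-6);
        }
    }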
#[test]
fn test_estimated_tflops() {
let config = TensorCoreConfig {
generation: TensorCoreGeneration::Ampere,
..TensorCoreConfig::default()
};
let kernels = TensorCoreKernels::new(config);
let tflops_tf32 = kernels.estimated_tflops(QuantumPrecision::TF32);
let tflops_fp16 = kernels.estimated_tflops(QuantumPrecision::Half);
assert!(tflops_tf32 > 0.0);
        assert!(tflops_fp16 > tflops_tf32);
    }
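    // A usage sketch for the TF32 gate kernel: apply a real-valued Hadamard
    // to a one-qubit |0> state stored as interleaved (re, im) pairs. TF32
    // keeps 10 mantissa bits, so 1e-3 is a comfortable tolerance.
    #[test]
    fn test_apply_single_qubit_gate_tf32() {
        let config = TensorCoreConfig {
            generation: TensorCoreGeneration::Ampere,
            ..TensorCoreConfig::default()
        };
        let kernels = TensorCoreKernels::new(config);
        let h = std::f32::consts::FRAC_1_SQRT_2;
        let gate = [[h, h], [h, -h]];
        let mut state = vec![1.0f32, 0.0, 0.0, 0.0]; // |0>
        kernels
            .apply_single_qubit_gate_tf32(&mut state, &gate, 0, 1)
            .unwrap();
        assert!((state[0] - h).abs() < 1e-3);
        assert!((state[2] - h).abs() < 1e-3);
        assert!(state[1].abs() < 1e-3 && state[3].abs() < 1e-3);
    }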
#[test]
fn test_quantum_precision_tensor_core_requirements() {
assert!(QuantumPrecision::TF32.requires_tensor_cores());
assert!(QuantumPrecision::BFloat16.requires_tensor_cores());
assert!(!QuantumPrecision::Half.requires_tensor_cores());
assert!(!QuantumPrecision::Single.requires_tensor_cores());
assert!(!QuantumPrecision::Double.requires_tensor_cores());
}
#[test]
fn test_quantum_precision_bit_width() {
assert_eq!(QuantumPrecision::Half.bit_width(), 16);
assert_eq!(QuantumPrecision::BFloat16.bit_width(), 16);
assert_eq!(QuantumPrecision::TF32.bit_width(), 19);
assert_eq!(QuantumPrecision::Single.bit_width(), 32);
assert_eq!(QuantumPrecision::Double.bit_width(), 64);
}
#[test]
fn test_quantum_precision_mantissa_bits() {
assert_eq!(QuantumPrecision::Half.mantissa_bits(), 10);
assert_eq!(QuantumPrecision::BFloat16.mantissa_bits(), 7);
assert_eq!(QuantumPrecision::TF32.mantissa_bits(), 10);
assert_eq!(QuantumPrecision::Single.mantissa_bits(), 23);
assert_eq!(QuantumPrecision::Double.mantissa_bits(), 52);
}
}