use crate::errors::{Result, TrustformersError};
/// Element encodings supported by this Microscaling (MX) block quantizer.
///
/// Every block of elements shares one 8-bit power-of-two exponent; each element inside
/// the block is stored with the narrow encoding selected here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MxFormat {
    /// 8-bit float elements: 1 sign, 4 exponent, 3 mantissa bits (E4M3).
    Mxfp8,
    /// 6-bit float elements: 1 sign, 3 exponent, 2 mantissa bits (E3M2).
    Mxfp6,
    /// 4-bit float elements: 1 sign, 2 exponent, 1 mantissa bit (E2M1).
    Mxfp4,
    /// 8-bit integer elements, stored by this module as sign + 7-bit magnitude.
    Mxint8,
}

impl MxFormat {
    /// Total storage width of one element in bits.
    pub fn element_bits(&self) -> u8 {
        match self {
            MxFormat::Mxfp8 | MxFormat::Mxint8 => 8,
            MxFormat::Mxfp6 => 6,
            MxFormat::Mxfp4 => 4,
        }
    }

    /// Number of explicit mantissa bits per element (magnitude bits for `Mxint8`).
    pub fn mantissa_bits(&self) -> u8 {
        match self {
            MxFormat::Mxfp8 => 3,
            MxFormat::Mxfp6 => 2,
            MxFormat::Mxfp4 => 1,
            MxFormat::Mxint8 => 7,
        }
    }

    /// Number of per-element exponent bits (0 for the integer format).
    pub fn element_exponent_bits(&self) -> u8 {
        match self {
            MxFormat::Mxfp8 => 4,
            MxFormat::Mxfp6 => 3,
            MxFormat::Mxfp4 => 2,
            MxFormat::Mxint8 => 0,
        }
    }

    /// Bias of the per-element exponent field (0 for the integer format).
    pub fn element_exponent_bias(&self) -> i32 {
        match self {
            MxFormat::Mxfp8 => 7,
            MxFormat::Mxfp6 => 3,
            MxFormat::Mxfp4 => 1,
            MxFormat::Mxint8 => 0,
        }
    }

    /// `true` for the floating-point element formats, `false` for `Mxint8`.
    pub fn is_float_format(&self) -> bool {
        !matches!(self, MxFormat::Mxint8)
    }

    /// Largest finite magnitude one element can hold, in units of the block scale.
    ///
    /// The float formats treat their all-ones exponent code as reserved, so the largest
    /// finite value uses biased exponent `2^E - 2` with an all-ones mantissa:
    /// 240 for MXFP8, 14 for MXFP6 and 3 for MXFP4. `Mxint8` returns its largest
    /// magnitude, 127.
    ///
    /// NOTE(review): the OCP MX v1.0 spec uses that exponent code for E3M2/E2M1 and
    /// reserves only one code for E4M3 (maxima 28, 6 and 448). Changing this must be
    /// done together with the element encoder/decoder.
    pub fn max_element_value(&self) -> f32 {
        match self {
            MxFormat::Mxint8 => 127.0,
            MxFormat::Mxfp8 | MxFormat::Mxfp6 | MxFormat::Mxfp4 => {
                let max_biased_exp = (1i32 << self.element_exponent_bits()) - 2;
                let max_exp = max_biased_exp - self.element_exponent_bias();
                let mantissa_frac = 1.0 - 2.0f32.powi(-i32::from(self.mantissa_bits()));
                (1.0 + mantissa_frac) * 2.0f32.powi(max_exp)
            }
        }
    }
}

impl std::fmt::Display for MxFormat {
    /// Writes the upper-case format name (e.g. `MXFP8`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            MxFormat::Mxfp8 => "MXFP8",
            MxFormat::Mxfp6 => "MXFP6",
            MxFormat::Mxfp4 => "MXFP4",
            MxFormat::Mxint8 => "MXINT8",
        };
        // `pad` (unlike `write!`) honours width/alignment flags such as `{:>8}`.
        f.pad(name)
    }
}
/// Block sizes accepted by [`MxQuantConfig::new`].
const VALID_BLOCK_SIZES: [usize; 5] = [2, 4, 8, 16, 32];

/// Parameters of an MX quantization: the element format and how many consecutive
/// elements share one exponent.
#[derive(Debug, Clone)]
pub struct MxQuantConfig {
    /// Encoding used for every element.
    pub format: MxFormat,
    /// Number of consecutive elements sharing one exponent byte.
    pub block_size: usize,
}

impl MxQuantConfig {
    /// Builds a validated configuration.
    ///
    /// # Errors
    ///
    /// Returns a quantization error when `block_size` is not in `VALID_BLOCK_SIZES`.
    pub fn new(format: MxFormat, block_size: usize) -> Result<Self> {
        if VALID_BLOCK_SIZES.contains(&block_size) {
            Ok(Self { format, block_size })
        } else {
            Err(TrustformersError::quantization_error(format!(
                "Invalid MX block size: {}. Must be one of: 2, 4, 8, 16, 32",
                block_size
            )))
        }
    }
}
/// A tensor quantized into MX blocks: one shared exponent byte per block plus the
/// bit-packed element codes.
#[derive(Debug, Clone)]
pub struct MxQuantized {
    // One shared exponent per block, stored with a bias of 127.
    shared_exponents: Vec<u8>,
    // Full element codes (not only mantissas), packed MSB-first at
    // `config.format.element_bits()` bits each.
    mantissa_data: Vec<u8>,
    // Configuration the data was quantized with.
    config: MxQuantConfig,
    // Logical tensor shape reported to callers.
    shape: Vec<usize>,
    // Number of real (unpadded) elements.
    num_elements: usize,
}

impl MxQuantized {
    /// Biased shared exponents, one per block.
    pub fn shared_exponents(&self) -> &[u8] {
        self.shared_exponents.as_slice()
    }

    /// Bit-packed element codes.
    pub fn mantissa_data(&self) -> &[u8] {
        self.mantissa_data.as_slice()
    }

    /// Configuration used during quantization.
    pub fn config(&self) -> &MxQuantConfig {
        &self.config
    }

    /// Logical shape of the quantized tensor.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Number of real elements (excluding block padding).
    pub fn num_elements(&self) -> usize {
        self.num_elements
    }

    /// Payload size in bytes: shared exponents plus packed element codes.
    pub fn size_bytes(&self) -> usize {
        let exponent_bytes = self.shared_exponents.len();
        let element_bytes = self.mantissa_data.len();
        exponent_bytes + element_bytes
    }
}
/// Theoretical compression ratio of MX storage relative to `f32`.
///
/// Each element costs `format.element_bits()` bits plus its share of the 8-bit shared
/// exponent (`8 / block_size` bits). A `block_size` of 0 yields `0.0` (the division
/// produces infinite overhead) rather than panicking.
pub fn compression_ratio(format: MxFormat, block_size: usize) -> f32 {
    const F32_BITS: f32 = 32.0;
    const SHARED_EXPONENT_BITS: f32 = 8.0;
    let amortized_exponent_bits = SHARED_EXPONENT_BITS / block_size as f32;
    F32_BITS / (f32::from(format.element_bits()) + amortized_exponent_bits)
}
/// Returns the raw 8-bit biased exponent field of an `f32` (bits 23..=30).
fn extract_f32_exponent(value: f32) -> u8 {
    // Shifting left once discards the sign bit; the top byte is then the exponent field.
    ((value.to_bits() << 1) >> 24) as u8
}

/// Unbiased `floor(log2(|value|))` taken straight from the exponent field.
///
/// Zero and subnormal inputs return the sentinel `-127`; infinities and NaN return `128`
/// because their exponent field is all ones.
fn floor_log2_abs(value: f32) -> i32 {
    use std::num::FpCategory;
    let magnitude = value.abs();
    match magnitude.classify() {
        FpCategory::Zero | FpCategory::Subnormal => -127,
        _ => i32::from(extract_f32_exponent(magnitude)) - 127,
    }
}

/// Shared exponent of `block`, biased by 127.
///
/// It is the largest `floor_log2_abs` over the elements, clamped to `0..=254`.
/// Empty, all-zero and all-subnormal blocks give 0.
fn compute_shared_exponent(block: &[f32]) -> u8 {
    let largest = block
        .iter()
        .map(|&v| floor_log2_abs(v))
        .fold(-127, i32::max);
    // The clamp never produces 255 (an infinite element is capped at 254).
    (largest + 127).clamp(0, 254) as u8
}
/// Encodes one element as a float code of `format`, relative to the block scale
/// `2^shared_exp_unbiased`.
///
/// Returns the code `sign | biased_exponent | mantissa` in the low
/// `1 + element_exponent_bits + mantissa_bits` bits.
///
/// - Magnitudes are rounded to the nearest code, with halves going away from zero.
/// - Anything beyond the largest finite code (including infinities) saturates to it.
/// - Values that round to zero return `0`.
///
/// The all-ones exponent code is never emitted, matching `MxFormat::max_element_value`.
/// Biased exponent 0 encodes IEEE-style subnormals worth `(mantissa / 2^M) * 2^(1 - bias)`;
/// `dequantize_element_fp` decodes them with the same convention.
fn quantize_element_fp(value: f32, shared_exp_unbiased: i32, format: MxFormat) -> u16 {
    let sign_bit: u16 = if value < 0.0 { 1 } else { 0 };
    let abs_val = value.abs();
    if abs_val == 0.0 {
        return 0;
    }
    let exp_bits = u32::from(format.element_exponent_bits());
    let mant_bits = u32::from(format.mantissa_bits());
    let elem_bias = format.element_exponent_bias();
    let mantissa_mask = (1u32 << mant_bits) - 1;
    // Smallest exponent that still carries an implicit leading 1 (biased exponent 1).
    let min_normal_exp = 1 - elem_bias;
    // Largest usable biased exponent: the all-ones exponent code is reserved.
    let max_biased_exp = (1u32 << exp_bits).saturating_sub(2);
    let max_elem_exp = max_biased_exp as i32 - elem_bias;
    let max_code = (max_biased_exp << mant_bits) | mantissa_mask;
    let sign_field = sign_bit << (exp_bits + mant_bits);

    // Scale in f64 so 2^-shared is exact for every shared exponent; out-of-range results
    // become inf/0 at the single cast and are handled below.
    let scaled = (f64::from(abs_val) * 2.0f64.powi(-shared_exp_unbiased)) as f32;
    if scaled < f32::MIN_POSITIVE {
        return 0;
    }
    let raw_local_exp = floor_log2_abs(scaled);
    if raw_local_exp > max_elem_exp {
        // Overflow (including +/-inf): saturate to the largest finite code.
        return sign_field | max_code as u16;
    }

    let magnitude_code = if raw_local_exp >= min_normal_exp {
        // Normal: significand in [1, 2); encode the fraction after the implicit 1.
        let significand = scaled / 2.0f32.powi(raw_local_exp);
        let fraction = (significand - 1.0).max(0.0);
        let mantissa = (fraction * (mantissa_mask + 1) as f32).round() as u32;
        // Adding (rather than OR-ing) lets a mantissa that rounds up to 2^M carry into
        // the exponent field, i.e. round to the next binade.
        (((raw_local_exp + elem_bias) as u32) << mant_bits) + mantissa
    } else {
        // Subnormal: the mantissa counts steps of 2^(min_normal_exp - M). Rounding up to
        // 2^M yields exactly the code of the smallest normal value.
        (scaled * 2.0f32.powi(mant_bits as i32 - min_normal_exp)).round() as u32
    };
    let magnitude_code = magnitude_code.min(max_code);
    if magnitude_code == 0 {
        return 0;
    }
    sign_field | magnitude_code as u16
}

/// Encodes one element as an 8-bit sign-magnitude integer relative to the block scale.
///
/// The step is `2^(shared_exp_unbiased - 6)`. A block maximum in
/// `[2^shared, 2^(shared + 1))` therefore maps to magnitudes 64..=127, using the full
/// 7-bit range. Magnitudes are rounded half away from zero and clamped to 127.
fn quantize_element_int(value: f32, shared_exp_unbiased: i32) -> u8 {
    let sign_bit: u8 = if value < 0.0 { 1 } else { 0 };
    let abs_val = value.abs();
    if abs_val == 0.0 {
        return 0;
    }
    // A step of 2^-6 relative to the shared scale (the OCP MXINT8 implicit scale). A
    // 2^-7 step pushes every block maximum to >= 128, so it would always saturate.
    let step = 2.0f64.powi(shared_exp_unbiased - 6);
    // Float-to-int `as` saturates, so infinities end up clamped to 127.
    let magnitude = (f64::from(abs_val) / step).round() as u32;
    (sign_bit << 7) | magnitude.min(127) as u8
}

/// Decodes a float element code produced by `quantize_element_fp`.
///
/// - Biased exponent 0 is subnormal: `(mantissa / 2^M) * 2^(1 - bias)`.
/// - Otherwise the value is `(1 + mantissa / 2^M) * 2^(biased - bias)`.
///
/// Both are multiplied by the block scale `2^shared_exp_unbiased`.
fn dequantize_element_fp(packed: u16, shared_exp_unbiased: i32, format: MxFormat) -> f32 {
    let exp_bits = u32::from(format.element_exponent_bits());
    let mant_bits = u32::from(format.mantissa_bits());
    let elem_bias = format.element_exponent_bias();
    let sign_bit = (packed >> (exp_bits + mant_bits)) & 1;
    let biased_exp = i32::from((packed >> mant_bits) & ((1u16 << exp_bits) - 1));
    let mantissa = packed & ((1u16 << mant_bits) - 1);
    let fraction = f64::from(mantissa) / f64::from(1u32 << mant_bits);
    let (significand, local_exp) = if biased_exp == 0 {
        // Subnormal (and zero): no implicit 1, same exponent as the smallest normal.
        (fraction, 1 - elem_bias)
    } else {
        (1.0 + fraction, biased_exp - elem_bias)
    };
    // Scale in f64 so extreme shared exponents round to subnormal/inf only at the cast.
    let magnitude = (significand * 2.0f64.powi(local_exp + shared_exp_unbiased)) as f32;
    if sign_bit != 0 {
        -magnitude
    } else {
        magnitude
    }
}

/// Decodes an 8-bit sign-magnitude code produced by `quantize_element_int`.
fn dequantize_element_int(packed: u8, shared_exp_unbiased: i32) -> f32 {
    let magnitude = f64::from(packed & 0x7F) * 2.0f64.powi(shared_exp_unbiased - 6);
    let value = magnitude as f32;
    if packed & 0x80 != 0 {
        -value
    } else {
        value
    }
}
/// Packs the low `bits_per_element` bits of each value into a byte stream, MSB-first.
///
/// Higher bits of each value are ignored. The output is zero-padded to a whole byte,
/// and an empty input yields an empty vector.
fn pack_bits(values: &[u16], bits_per_element: u8) -> Vec<u8> {
    let width = usize::from(bits_per_element);
    let mut out = vec![0u8; (values.len() * width).div_ceil(8)];
    let bit_stream = values
        .iter()
        .flat_map(|&word| (0..width).rev().map(move |shift| (word >> shift) & 1 == 1));
    for (pos, is_set) in bit_stream.enumerate() {
        if is_set {
            out[pos / 8] |= 0x80 >> (pos % 8);
        }
    }
    out
}
/// Reads `count` MSB-first fields of `bits_per_element` bits each from `packed`.
///
/// Bits beyond the end of `packed` read as zero, so a short buffer yields zero-filled
/// values instead of panicking.
fn unpack_bits(packed: &[u8], bits_per_element: u8, count: usize) -> Vec<u16> {
    let width = usize::from(bits_per_element);
    (0..count)
        .map(|index| {
            let base = index * width;
            (0..width).fold(0u16, |acc, k| {
                let pos = base + k;
                match packed.get(pos / 8) {
                    Some(&byte) => {
                        let bit = u16::from((byte >> (7 - pos % 8)) & 1);
                        acc | (bit << (width - 1 - k))
                    }
                    None => acc,
                }
            })
        })
        .collect()
}
/// Quantizes `data` into MX blocks described by `config`.
///
/// The input is split into consecutive blocks of `config.block_size` elements.
/// - Each block gets one shared exponent from `compute_shared_exponent`.
/// - Every element is encoded relative to that exponent.
/// - A final partial block is zero-padded, so the packed payload always holds
///   `num_blocks * block_size` codes.
///
/// The result has the 1-D shape `[data.len()]`.
///
/// # Errors
///
/// Returns a quantization error in any of these cases:
/// - `data` is empty;
/// - `data` contains a NaN;
/// - `config.block_size` is zero. This can happen because the config fields are public
///   and so can bypass [`MxQuantConfig::new`].
pub fn quantize_mx(data: &[f32], config: &MxQuantConfig) -> Result<MxQuantized> {
    if data.is_empty() {
        return Err(TrustformersError::quantization_error(
            "MX quantize: input data is empty".to_string(),
        ));
    }
    if let Some(i) = data.iter().position(|val| val.is_nan()) {
        return Err(TrustformersError::quantization_error(format!(
            "MX quantize: NaN value at index {}",
            i
        )));
    }
    let block_size = config.block_size;
    if block_size == 0 {
        // Would otherwise panic with a division by zero in `div_ceil` / `chunks`.
        return Err(TrustformersError::quantization_error(
            "MX quantize: block size must be non-zero".to_string(),
        ));
    }
    let format = config.format;
    let is_float = format.is_float_format();
    let num_elements = data.len();
    let num_blocks = num_elements.div_ceil(block_size);
    let mut shared_exponents = Vec::with_capacity(num_blocks);
    // Sized for the padded payload, not just the real elements.
    let mut element_codes: Vec<u16> = Vec::with_capacity(num_blocks * block_size);
    for block in data.chunks(block_size) {
        let shared_exp = compute_shared_exponent(block);
        shared_exponents.push(shared_exp);
        let shared_exp_unbiased = i32::from(shared_exp) - 127;
        element_codes.extend(block.iter().map(|&val| {
            if is_float {
                quantize_element_fp(val, shared_exp_unbiased, format)
            } else {
                u16::from(quantize_element_int(val, shared_exp_unbiased))
            }
        }));
        // Only the last chunk can be short; pad it so each block occupies block_size codes.
        element_codes.extend(std::iter::repeat_n(0u16, block_size - block.len()));
    }
    let mantissa_data = pack_bits(&element_codes, format.element_bits());
    Ok(MxQuantized {
        shared_exponents,
        mantissa_data,
        config: config.clone(),
        shape: vec![num_elements],
        num_elements,
    })
}
/// Quantizes `data` like [`quantize_mx`], but records `shape` as the logical tensor shape.
///
/// # Errors
///
/// Returns a quantization error if the product of `shape` differs from `data.len()`.
/// Also propagates any error from [`quantize_mx`].
pub fn quantize_mx_with_shape(
    data: &[f32],
    config: &MxQuantConfig,
    shape: &[usize],
) -> Result<MxQuantized> {
    let expected_len = shape.iter().product::<usize>();
    if expected_len == data.len() {
        quantize_mx(data, config).map(|mut quantized| {
            quantized.shape = shape.to_vec();
            quantized
        })
    } else {
        Err(TrustformersError::quantization_error(format!(
            "MX quantize: shape {:?} implies {} elements but data has {} elements",
            shape,
            expected_len,
            data.len()
        )))
    }
}
/// Reconstructs `f32` values from an [`MxQuantized`] tensor.
///
/// All `num_blocks * block_size` codes, padding included, are unpacked. Only the first
/// `num_elements` are decoded; each is decoded with the shared exponent of its block.
pub fn dequantize_mx(quantized: &MxQuantized) -> Vec<f32> {
    let format = quantized.config.format;
    let block_size = quantized.config.block_size;
    let total = quantized.num_elements;
    let codes = unpack_bits(
        &quantized.mantissa_data,
        format.element_bits(),
        quantized.shared_exponents.len() * block_size,
    );
    let mut output = Vec::with_capacity(total);
    for (block_idx, &shared_exp) in quantized.shared_exponents.iter().enumerate() {
        let shared_exp_unbiased = i32::from(shared_exp) - 127;
        let first = block_idx * block_size;
        // The last block stops at the real element count, skipping its padding.
        let last = total.min(first + block_size);
        output.extend(codes[first..last].iter().map(|&code| {
            if format.is_float_format() {
                dequantize_element_fp(code, shared_exp_unbiased, format)
            } else {
                dequantize_element_int(code as u8, shared_exp_unbiased)
            }
        }));
    }
    output
}
/// Error statistics between an original tensor and its MX round trip.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MxErrorStats {
    /// Mean absolute error.
    pub mae: f32,
    /// Root-mean-square error.
    pub rmse: f32,
    /// Largest element-wise absolute error.
    pub max_error: f32,
    /// Signal-to-noise ratio in dB, `10 * log10(sum(orig^2) / sum(err^2))`.
    /// It is `+inf` when the squared error is zero.
    pub snr_db: f32,
}

/// Computes [`MxErrorStats`] for `dequantized` against `original`.
///
/// All accumulation, including the element count, is done in `f64`.
///
/// # Errors
///
/// Returns a quantization error if the slices differ in length or are empty.
pub fn compute_mx_error(original: &[f32], dequantized: &[f32]) -> Result<MxErrorStats> {
    if original.len() != dequantized.len() {
        return Err(TrustformersError::quantization_error(format!(
            "MX error computation: length mismatch ({} vs {})",
            original.len(),
            dequantized.len()
        )));
    }
    if original.is_empty() {
        return Err(TrustformersError::quantization_error(
            "MX error computation: empty data".to_string(),
        ));
    }
    // An f32 count would be inexact beyond 2^24 elements.
    let n = original.len() as f64;
    let mut sum_abs_error = 0.0f64;
    let mut sum_sq_error = 0.0f64;
    let mut max_error = 0.0f64;
    let mut signal_power = 0.0f64;
    for (&o, &d) in original.iter().zip(dequantized) {
        let o = f64::from(o);
        let error = o - f64::from(d);
        let abs_error = error.abs();
        sum_abs_error += abs_error;
        sum_sq_error += error * error;
        // `f64::max` ignores a NaN operand, like the strict `>` comparison it replaces.
        max_error = max_error.max(abs_error);
        signal_power += o * o;
    }
    let snr_db = if sum_sq_error > 0.0 {
        (10.0 * (signal_power / sum_sq_error).log10()) as f32
    } else {
        f32::INFINITY
    };
    Ok(MxErrorStats {
        mae: (sum_abs_error / n) as f32,
        rmse: (sum_sq_error / n).sqrt() as f32,
        max_error: max_error as f32,
        snr_db,
    })
}
/// Small deterministic 64-bit linear congruential generator for reproducible test data.
#[cfg(test)]
struct LcgRng {
    state: u64,
}

#[cfg(test)]
impl LcgRng {
    // LCG multiplier and increment (Knuth's MMIX constants).
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    /// Seeds the generator; equal seeds give equal sequences.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state and maps its top 31 bits into `[0, 1]`.
    fn next_unit(&mut self) -> f32 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        (self.state >> 33) as f32 / (u32::MAX >> 1) as f32
    }

    /// Sample in `[-range, range]`.
    fn next_f32(&mut self, range: f32) -> f32 {
        (self.next_unit() * 2.0 - 1.0) * range
    }

    /// Sample in `[0, range]`.
    fn next_f32_positive(&mut self, range: f32) -> f32 {
        self.next_unit() * range
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for MX format metadata, configuration validation, bit packing,
    // quantize/dequantize round trips and error statistics.
    use super::*;

    /// Returns `count` samples in `[-range, range]` from an `LcgRng` seeded with `seed`.
    fn generate_test_data(seed: u64, count: usize, range: f32) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        (0..count).map(|_| rng.next_f32(range)).collect()
    }

    /// Returns `count` samples in `[0, range]` from an `LcgRng` seeded with `seed`.
    fn generate_positive_test_data(seed: u64, count: usize, range: f32) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        (0..count).map(|_| rng.next_f32_positive(range)).collect()
    }

    // ---- Configuration validation ----

    #[test]
    fn test_config_valid_block_sizes() {
        for &bs in &[2, 4, 8, 16, 32] {
            let config = MxQuantConfig::new(MxFormat::Mxfp8, bs);
            assert!(config.is_ok(), "Block size {} should be valid", bs);
        }
    }

    #[test]
    fn test_config_invalid_block_sizes() {
        // Includes 0 and sizes outside the allowed powers of two.
        for &bs in &[0, 1, 3, 5, 6, 7, 9, 15, 17, 31, 33, 64, 128] {
            let config = MxQuantConfig::new(MxFormat::Mxfp8, bs);
            assert!(config.is_err(), "Block size {} should be invalid", bs);
        }
    }

    // ---- Format metadata ----

    #[test]
    fn test_format_element_bits() {
        assert_eq!(MxFormat::Mxfp8.element_bits(), 8);
        assert_eq!(MxFormat::Mxfp6.element_bits(), 6);
        assert_eq!(MxFormat::Mxfp4.element_bits(), 4);
        assert_eq!(MxFormat::Mxint8.element_bits(), 8);
    }

    #[test]
    fn test_format_mantissa_bits() {
        assert_eq!(MxFormat::Mxfp8.mantissa_bits(), 3);
        assert_eq!(MxFormat::Mxfp6.mantissa_bits(), 2);
        assert_eq!(MxFormat::Mxfp4.mantissa_bits(), 1);
        assert_eq!(MxFormat::Mxint8.mantissa_bits(), 7);
    }

    #[test]
    fn test_format_exponent_bits() {
        assert_eq!(MxFormat::Mxfp8.element_exponent_bits(), 4);
        assert_eq!(MxFormat::Mxfp6.element_exponent_bits(), 3);
        assert_eq!(MxFormat::Mxfp4.element_exponent_bits(), 2);
        assert_eq!(MxFormat::Mxint8.element_exponent_bits(), 0);
    }

    #[test]
    fn test_format_is_float() {
        assert!(MxFormat::Mxfp8.is_float_format());
        assert!(MxFormat::Mxfp6.is_float_format());
        assert!(MxFormat::Mxfp4.is_float_format());
        assert!(!MxFormat::Mxint8.is_float_format());
    }

    #[test]
    fn test_format_display() {
        assert_eq!(format!("{}", MxFormat::Mxfp8), "MXFP8");
        assert_eq!(format!("{}", MxFormat::Mxfp6), "MXFP6");
        assert_eq!(format!("{}", MxFormat::Mxfp4), "MXFP4");
        assert_eq!(format!("{}", MxFormat::Mxint8), "MXINT8");
    }

    #[test]
    fn test_format_max_element_value_positive() {
        for fmt in &[
            MxFormat::Mxfp8,
            MxFormat::Mxfp6,
            MxFormat::Mxfp4,
            MxFormat::Mxint8,
        ] {
            assert!(fmt.max_element_value() > 0.0, "{:?} max must be > 0", fmt);
        }
    }

    // ---- Compression ratio: 32 / (element_bits + 8 / block_size) ----

    #[test]
    fn test_compression_ratio_mxfp8_block32() {
        let ratio = compression_ratio(MxFormat::Mxfp8, 32);
        assert!(ratio > 3.8 && ratio < 4.0, "MXFP8/32 ratio = {}", ratio);
    }

    #[test]
    fn test_compression_ratio_mxfp4_block32() {
        let ratio = compression_ratio(MxFormat::Mxfp4, 32);
        assert!(ratio > 7.0 && ratio < 8.0, "MXFP4/32 ratio = {}", ratio);
    }

    #[test]
    fn test_compression_ratio_mxfp6_block16() {
        let ratio = compression_ratio(MxFormat::Mxfp6, 16);
        assert!(ratio > 4.5 && ratio < 5.5, "MXFP6/16 ratio = {}", ratio);
    }

    #[test]
    fn test_compression_ratio_increases_with_block_size() {
        // Larger blocks amortize the shared exponent byte over more elements.
        let fmt = MxFormat::Mxfp8;
        let r2 = compression_ratio(fmt, 2);
        let r4 = compression_ratio(fmt, 4);
        let r8 = compression_ratio(fmt, 8);
        let r16 = compression_ratio(fmt, 16);
        let r32 = compression_ratio(fmt, 32);
        assert!(r2 < r4, "ratio should increase: {} < {}", r2, r4);
        assert!(r4 < r8, "ratio should increase: {} < {}", r4, r8);
        assert!(r8 < r16, "ratio should increase: {} < {}", r8, r16);
        assert!(r16 < r32, "ratio should increase: {} < {}", r16, r32);
    }

    // ---- Bit packing ----

    #[test]
    fn test_pack_unpack_8bit_roundtrip() {
        let values: Vec<u16> = vec![0xFF, 0x00, 0xAB, 0x55, 0x01, 0xFE];
        let packed = pack_bits(&values, 8);
        let unpacked = unpack_bits(&packed, 8, values.len());
        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_6bit_roundtrip() {
        // 6-bit fields straddle byte boundaries.
        let values: Vec<u16> = vec![0x3F, 0x00, 0x15, 0x2A, 0x01, 0x3E];
        let packed = pack_bits(&values, 6);
        let unpacked = unpack_bits(&packed, 6, values.len());
        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_unpack_4bit_roundtrip() {
        let values: Vec<u16> = vec![0x0F, 0x00, 0x05, 0x0A, 0x01, 0x0E, 0x07, 0x03];
        let packed = pack_bits(&values, 4);
        let unpacked = unpack_bits(&packed, 4, values.len());
        assert_eq!(values, unpacked);
    }

    #[test]
    fn test_pack_empty() {
        let packed = pack_bits(&[], 8);
        assert!(packed.is_empty());
        let unpacked = unpack_bits(&packed, 8, 0);
        assert!(unpacked.is_empty());
    }

    // ---- Quantize/dequantize round trips ----

    #[test]
    fn test_mxfp8_roundtrip_zeros() {
        let data = vec![0.0f32; 16];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), data.len());
        for val in &dequantized {
            assert!(val.abs() < 1e-10, "expected zero, got {}", val);
        }
    }

    #[test]
    fn test_mxfp8_roundtrip_small_block() {
        let data = vec![1.0, -1.0, 0.5, -0.5];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 2).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), 4);
        for (o, d) in data.iter().zip(dequantized.iter()) {
            let error = (o - d).abs();
            assert!(error < 0.5, "error too large: orig={}, deq={}", o, d);
        }
    }

    #[test]
    fn test_mxfp8_roundtrip_random_data() {
        let data = generate_test_data(42, 128, 10.0);
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 16).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), data.len());
        let stats = compute_mx_error(&data, &dequantized).expect("error stats");
        // Loose bound: only guards against grossly broken reconstruction.
        assert!(stats.rmse < 5.0, "MXFP8 RMSE too high: {}", stats.rmse);
    }

    #[test]
    fn test_mxfp6_roundtrip_random_data() {
        let data = generate_test_data(123, 64, 5.0);
        let config = MxQuantConfig::new(MxFormat::Mxfp6, 8).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), data.len());
    }

    #[test]
    fn test_mxfp4_roundtrip_random_data() {
        let data = generate_test_data(456, 32, 2.0);
        let config = MxQuantConfig::new(MxFormat::Mxfp4, 4).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), data.len());
    }

    #[test]
    fn test_mxint8_roundtrip_random_data() {
        let data = generate_test_data(789, 64, 100.0);
        let config = MxQuantConfig::new(MxFormat::Mxint8, 16).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), data.len());
    }

    #[test]
    fn test_mxint8_roundtrip_zeros() {
        let data = vec![0.0f32; 32];
        let config = MxQuantConfig::new(MxFormat::Mxint8, 8).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        for val in &dequantized {
            assert!(val.abs() < 1e-10, "expected zero, got {}", val);
        }
    }

    // ---- Input validation and edge cases ----

    #[test]
    fn test_quantize_empty_data() {
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let result = quantize_mx(&[], &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_quantize_nan_data() {
        let data = vec![1.0, f32::NAN, 3.0];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 4).expect("valid config");
        let result = quantize_mx(&data, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_quantize_partial_block() {
        // 5 elements with block size 4: two blocks, the second one padded.
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 4).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        assert_eq!(quantized.num_elements(), 5);
        assert_eq!(quantized.shared_exponents().len(), 2);
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), 5);
    }

    #[test]
    fn test_quantize_single_element() {
        let data = vec![42.0];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 2).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        assert_eq!(quantized.num_elements(), 1);
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), 1);
    }

    #[test]
    fn test_quantize_exactly_one_block() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 4).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        assert_eq!(quantized.shared_exponents().len(), 1);
    }

    #[test]
    fn test_quantize_large_values() {
        let data = vec![1e6, -1e6, 5e5, -5e5, 1e3, -1e3, 1.0, -1.0];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), 8);
        // Elements far below the block maximum may underflow to zero; only non-zero
        // reconstructions are checked for sign preservation.
        for i in 0..data.len() {
            if data[i] != 0.0 && dequantized[i] != 0.0 {
                assert_eq!(
                    data[i].is_sign_positive(),
                    dequantized[i].is_sign_positive(),
                    "Sign mismatch at index {}: orig={}, deq={}",
                    i,
                    data[i],
                    dequantized[i]
                );
            }
        }
    }

    #[test]
    fn test_quantize_very_small_values() {
        let data = vec![1e-20, -1e-20, 1e-30, -1e-30];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 4).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), 4);
    }

    #[test]
    fn test_quantize_mixed_magnitudes() {
        let data = vec![1000.0, 0.001, -500.0, 0.0005];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 4).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), 4);
        // The block maximum sets the shared exponent, so it must survive reasonably well.
        assert!((dequantized[0] - 1000.0).abs() < 200.0);
    }

    // ---- Shapes ----

    #[test]
    fn test_quantize_with_shape() {
        let data = generate_test_data(100, 24, 1.0);
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let quantized =
            quantize_mx_with_shape(&data, &config, &[2, 3, 4]).expect("quantize with shape");
        assert_eq!(quantized.shape(), &[2, 3, 4]);
        assert_eq!(quantized.num_elements(), 24);
    }

    #[test]
    fn test_quantize_with_shape_mismatch() {
        // 2 * 3 * 5 = 30 != 24.
        let data = vec![1.0; 24];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let result = quantize_mx_with_shape(&data, &config, &[2, 3, 5]);
        assert!(result.is_err());
    }

    // ---- Error statistics ----

    #[test]
    fn test_error_stats_identical() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let stats = compute_mx_error(&data, &data).expect("error stats");
        assert!(stats.mae < 1e-10);
        assert!(stats.rmse < 1e-10);
        assert!(stats.max_error < 1e-10);
        // Zero error is reported as infinite SNR.
        assert!(stats.snr_db.is_infinite());
    }

    #[test]
    fn test_error_stats_length_mismatch() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0];
        let result = compute_mx_error(&a, &b);
        assert!(result.is_err());
    }

    #[test]
    fn test_error_stats_empty() {
        let result = compute_mx_error(&[], &[]);
        assert!(result.is_err());
    }

    // ---- Accessors and sizes ----

    #[test]
    fn test_quantized_accessors() {
        let data = generate_test_data(999, 32, 5.0);
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        assert_eq!(quantized.num_elements(), 32);
        assert_eq!(quantized.shape(), &[32]);
        // 32 elements / block size 8 = 4 shared exponents.
        assert_eq!(quantized.shared_exponents().len(), 4);
        assert!(!quantized.mantissa_data().is_empty());
        assert!(quantized.size_bytes() > 0);
    }

    #[test]
    fn test_quantized_size_bytes() {
        let data = vec![1.0f32; 32];
        let config_fp8 = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let q_fp8 = quantize_mx(&data, &config_fp8).expect("quantize");
        let config_fp4 = MxQuantConfig::new(MxFormat::Mxfp4, 8).expect("valid config");
        let q_fp4 = quantize_mx(&data, &config_fp4).expect("quantize");
        assert!(
            q_fp4.mantissa_data().len() < q_fp8.mantissa_data().len(),
            "FP4 mantissa ({}) should be smaller than FP8 ({})",
            q_fp4.mantissa_data().len(),
            q_fp8.mantissa_data().len()
        );
    }

    // ---- Quality ordering and shared exponents ----

    #[test]
    fn test_quality_ordering_fp4_worst() {
        // Strictly positive data in roughly [0.5, 2.0].
        let mut data = Vec::with_capacity(128);
        let mut rng = LcgRng::new(555);
        for _ in 0..128 {
            data.push(0.5 + rng.next_f32_positive(1.5));
        }
        let config_fp8 = MxQuantConfig::new(MxFormat::Mxfp8, 32).expect("valid config");
        let config_fp4 = MxQuantConfig::new(MxFormat::Mxfp4, 32).expect("valid config");
        let d_fp8 = dequantize_mx(&quantize_mx(&data, &config_fp8).expect("q"));
        let d_fp4 = dequantize_mx(&quantize_mx(&data, &config_fp4).expect("q"));
        let e_fp8 = compute_mx_error(&data, &d_fp8).expect("err");
        let e_fp4 = compute_mx_error(&data, &d_fp4).expect("err");
        assert!(
            e_fp8.rmse <= e_fp4.rmse,
            "FP8 RMSE ({}) should be <= FP4 ({})",
            e_fp8.rmse,
            e_fp4.rmse
        );
    }

    #[test]
    fn test_shared_exponent_all_zeros() {
        // All-zero blocks clamp to the minimum biased exponent.
        let exp = compute_shared_exponent(&[0.0, 0.0, 0.0, 0.0]);
        assert_eq!(exp, 0);
    }

    #[test]
    fn test_shared_exponent_powers_of_two() {
        let exp = compute_shared_exponent(&[1.0, 2.0, 4.0]);
        let unbiased = exp as i32 - 127;
        assert_eq!(unbiased, 2, "shared exp should be 2 for max=4.0");
    }

    #[test]
    fn test_shared_exponent_negative_values() {
        // The shared exponent depends on magnitude, not sign.
        let exp = compute_shared_exponent(&[-8.0, 1.0, -0.5]);
        let unbiased = exp as i32 - 127;
        assert_eq!(unbiased, 3, "shared exp should be 3 for |max|=8.0");
    }

    // ---- Exhaustive configuration sweep and misc ----

    #[test]
    fn test_all_format_block_size_combinations() {
        let formats = [
            MxFormat::Mxfp8,
            MxFormat::Mxfp6,
            MxFormat::Mxfp4,
            MxFormat::Mxint8,
        ];
        let block_sizes = [2, 4, 8, 16, 32];
        let data = generate_test_data(777, 64, 10.0);
        for fmt in &formats {
            for &bs in &block_sizes {
                let config = MxQuantConfig::new(*fmt, bs).expect("valid config");
                let quantized = quantize_mx(&data, &config);
                assert!(
                    quantized.is_ok(),
                    "Failed for format {:?} block_size {}",
                    fmt,
                    bs
                );
                let deq = dequantize_mx(&quantized.expect("already checked"));
                assert_eq!(deq.len(), data.len());
            }
        }
    }

    #[test]
    fn test_mxfp8_positive_only_data() {
        let data = generate_positive_test_data(333, 64, 20.0);
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), data.len());
        for (i, val) in dequantized.iter().enumerate() {
            assert!(
                *val >= 0.0,
                "Expected non-negative at index {}, got {}",
                i,
                val
            );
        }
    }

    #[test]
    fn test_quantize_inf_values() {
        // Infinities are accepted (only NaN is rejected); just the output length is checked.
        let data = vec![f32::INFINITY, f32::NEG_INFINITY, 1.0, -1.0];
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 4).expect("valid config");
        let quantized = quantize_mx(&data, &config).expect("quantize");
        let dequantized = dequantize_mx(&quantized);
        assert_eq!(dequantized.len(), 4);
    }

    #[test]
    fn test_quantization_deterministic() {
        let data = generate_test_data(12345, 64, 8.0);
        let config = MxQuantConfig::new(MxFormat::Mxfp8, 8).expect("valid config");
        let q1 = quantize_mx(&data, &config).expect("q1");
        let q2 = quantize_mx(&data, &config).expect("q2");
        assert_eq!(q1.shared_exponents(), q2.shared_exponents());
        assert_eq!(q1.mantissa_data(), q2.mantissa_data());
    }
}