use crate::error::AnamnesisError;
use crate::parse::safetensors::Dtype;
/// Side length (in elements) of one square quantization tile: every
/// `BLOCK_SIZE × BLOCK_SIZE` region of the weight matrix shares one scale.
const BLOCK_SIZE: usize = 128;
/// f32 bit patterns for the eight E4M3 subnormal magnitudes
/// (mantissa × 2⁻⁹ for mantissa 0..=7), indexed by the 3-bit mantissa.
#[allow(clippy::indexing_slicing)]
const SUBNORMAL_TABLE: [u32; 8] = [
    0x0000_0000, 0x3B00_0000, 0x3B80_0000, 0x3BC0_0000, 0x3C00_0000, 0x3C20_0000, 0x3C40_0000, 0x3C60_0000, ];
/// Expands one FP8 E4M3 byte (sign 1, exponent 4, mantissa 3, bias 7) into
/// the bit pattern of the equivalent `f32`.
///
/// E4M3 has no infinities; the all-ones magnitude encoding (`0x7F`/`0xFF`)
/// is NaN and maps to a quiet f32 NaN with the sign preserved.
#[must_use]
pub(crate) fn e4m3_to_f32_bits(byte: u8) -> u32 {
    let b = u32::from(byte);
    let sign_bit = (b >> 7) << 31;
    let exponent = (b >> 3) & 0xF;
    let mantissa = b & 0x7;
    if b & 0x7F == 0x7F {
        // Exponent and mantissa all ones: the single E4M3 NaN encoding.
        return sign_bit | 0x7FC0_0000;
    }
    if exponent == 0 {
        // Subnormal: magnitude is mantissa * 2^-9, precomputed as f32 bits.
        #[allow(clippy::indexing_slicing, clippy::as_conversions)]
        return SUBNORMAL_TABLE[mantissa as usize] | sign_bit;
    }
    // Normal: rebias the exponent (E4M3 bias 7 -> f32 bias 127) and
    // left-align the 3-bit mantissa into f32's 23-bit field.
    sign_bit | ((exponent + 120) << 23) | (mantissa << 20)
}
/// Narrows an `f32` bit pattern to BF16 bits with round-to-nearest-even.
///
/// NaN inputs are handled explicitly: the rounding-bias add below would
/// otherwise carry a NaN whose payload lives only in the low 16 mantissa
/// bits (e.g. `0x7F80_0001`) up into the exponent, turning it into
/// infinity. Such NaNs are truncated with the quiet bit forced instead.
#[must_use]
pub(crate) fn f32_bits_to_bf16_bits(bits: u32) -> u16 {
    // NaN iff exponent is all ones and the mantissa is nonzero.
    if bits & 0x7FFF_FFFF > 0x7F80_0000 {
        #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
        // Keep the sign and high payload bits; force the quiet bit so the
        // result is always a (quiet) NaN.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Round to nearest even: add 0x7FFF plus the lowest kept bit, then
    // truncate. Infinity passes through unchanged (bias add cannot carry).
    let lsb = (bits >> 16) & 1;
    let rounding_bias = 0x7FFF_u32 + lsb;
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let bf16 = (bits.wrapping_add(rounding_bias) >> 16) as u16;
    bf16
}
/// Decodes one E4M3 byte, applies `scale` in f32 precision, and narrows the
/// product to BF16 bits with round-to-nearest-even.
#[must_use]
fn e4m3_to_scaled_bf16(byte: u8, scale: f32) -> u16 {
    let value = f32::from_bits(e4m3_to_f32_bits(byte));
    let scaled = value * scale;
    f32_bits_to_bf16_bits(scaled.to_bits())
}
/// Reads the scale factor at `(block_row, block_col)` from the row-major
/// scale grid in `scale_data`, where the grid has `scale_cols` columns of
/// `scale_dtype` values. All index arithmetic is overflow-checked.
fn load_scale(
    scale_data: &[u8],
    block_row: usize,
    block_col: usize,
    scale_cols: usize,
    scale_dtype: Dtype,
) -> crate::Result<f32> {
    let bps = scale_dtype.byte_size();
    // Row-major flattening of the (block_row, block_col) grid coordinate.
    let Some(scale_idx) = block_row
        .checked_mul(scale_cols)
        .and_then(|v| v.checked_add(block_col))
    else {
        return Err(AnamnesisError::Parse {
            reason: "scale index overflow".into(),
        });
    };
    let Some(byte_offset) = scale_idx.checked_mul(bps) else {
        return Err(AnamnesisError::Parse {
            reason: "scale byte offset overflow".into(),
        });
    };
    let Some(end) = byte_offset.checked_add(bps) else {
        return Err(AnamnesisError::Parse {
            reason: "scale byte range overflow".into(),
        });
    };
    let Some(slice) = scale_data.get(byte_offset..end) else {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "scale data too short: need bytes {byte_offset}..{end}, have {}",
                scale_data.len()
            ),
        });
    };
    read_scale_bytes(slice, scale_dtype)
}
/// Decodes a single little-endian scale factor from `slice` according to
/// `dtype`. Only F32, BF16, and F16 scales are supported.
fn read_scale_bytes(slice: &[u8], dtype: Dtype) -> crate::Result<f32> {
    match dtype {
        Dtype::F32 => {
            let bytes: [u8; 4] = slice.try_into().map_err(|_| AnamnesisError::Parse {
                reason: "scale slice is not 4 bytes".into(),
            })?;
            Ok(f32::from_le_bytes(bytes))
        }
        Dtype::BF16 => {
            let bytes: [u8; 2] = slice.try_into().map_err(|_| AnamnesisError::Parse {
                reason: "scale slice is not 2 bytes".into(),
            })?;
            // BF16 is the top half of an f32, so widening is a 16-bit shift.
            let widened = u32::from(u16::from_le_bytes(bytes)) << 16;
            Ok(f32::from_bits(widened))
        }
        Dtype::F16 => {
            let bytes: [u8; 2] = slice.try_into().map_err(|_| AnamnesisError::Parse {
                reason: "scale slice is not 2 bytes".into(),
            })?;
            Ok(half::f16::from_le_bytes(bytes).to_f32())
        }
        // Remaining dtypes are listed explicitly (no `_` arm) so that adding
        // a Dtype variant forces this match to be revisited.
        Dtype::F8E4M3
        | Dtype::F8E5M2
        | Dtype::F64
        | Dtype::Bool
        | Dtype::U8
        | Dtype::I8
        | Dtype::U16
        | Dtype::I16
        | Dtype::U32
        | Dtype::I32
        | Dtype::U64
        | Dtype::I64 => Err(AnamnesisError::Parse {
            reason: format!("unsupported scale dtype: {dtype}"),
        }),
    }
}
/// Dequantizes a block-quantized FP8 (E4M3) weight matrix into BF16 bytes.
///
/// `weight_data` holds a row-major `rows × cols` matrix of E4M3 bytes.
/// `scale_data` holds a row-major grid of `scale_dtype` scale factors, one
/// per `BLOCK_SIZE × BLOCK_SIZE` tile of the weight matrix; the grid may be
/// wider than strictly required, but never narrower. Returns the
/// little-endian BF16 matrix, `2 × rows × cols` bytes long.
///
/// # Errors
///
/// Returns `AnamnesisError::Parse` when a dimension product overflows
/// `usize`, a buffer length disagrees with `rows`/`cols`, or the scale grid
/// is empty, non-rectangular, or too narrow.
pub fn dequantize_fp8_to_bf16(
    weight_data: &[u8],
    scale_data: &[u8],
    rows: usize,
    cols: usize,
    scale_dtype: Dtype,
) -> crate::Result<Vec<u8>> {
    let bytes_per_scale = scale_dtype.byte_size();
    if bytes_per_scale == 0 {
        return Err(AnamnesisError::Parse {
            reason: format!("unsupported scale dtype: {scale_dtype}"),
        });
    }
    let expected_weight_len = rows
        .checked_mul(cols)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: format!("rows × cols overflow: {rows} × {cols}"),
        })?;
    if weight_data.len() != expected_weight_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "weight data length {} != rows × cols {expected_weight_len}",
                weight_data.len()
            ),
        });
    }
    if !scale_data.len().is_multiple_of(bytes_per_scale) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "scale data length {} is not a multiple of {bytes_per_scale} ({scale_dtype})",
                scale_data.len()
            ),
        });
    }
    let scale_elements = scale_data.len() / bytes_per_scale;
    let scale_rows = rows.div_ceil(BLOCK_SIZE);
    if scale_rows == 0 {
        return Err(AnamnesisError::Parse {
            reason: "zero rows".into(),
        });
    }
    // The grid's column count is inferred from its total element count, so
    // the element count must divide evenly by the row count.
    if !scale_elements.is_multiple_of(scale_rows) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "scale grid is not rectangular: {scale_elements} elements / {scale_rows} rows \
                 has remainder {}",
                scale_elements % scale_rows
            ),
        });
    }
    let scale_cols = scale_elements / scale_rows;
    let col_blocks_needed = cols.div_ceil(BLOCK_SIZE);
    if scale_cols < col_blocks_needed {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "scale has {scale_cols} column blocks but weight needs {col_blocks_needed} \
                 (cols={cols}, block_size={BLOCK_SIZE})"
            ),
        });
    }
    let out_byte_len = expected_weight_len
        .checked_mul(2)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "output size overflow".into(),
        })?;
    let mut output = vec![0u8; out_byte_len];
    for r in 0..rows {
        let block_row = r / BLOCK_SIZE;
        let row_offset = r.checked_mul(cols).ok_or_else(|| AnamnesisError::Parse {
            reason: "row offset overflow".into(),
        })?;
        let row_w = weight_data
            .get(row_offset..row_offset + cols)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("weight row {r} out of bounds"),
            })?;
        let out_row_offset = row_offset
            .checked_mul(2)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "output row offset overflow".into(),
            })?;
        let row_o = output
            .get_mut(out_row_offset..out_row_offset + cols * 2)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("output row {r} out of bounds"),
            })?;
        // `chunks` yields every full BLOCK_SIZE chunk followed by the short
        // tail chunk (if any), so full blocks and the row remainder share a
        // single code path. The tail's enumerate index equals
        // cols / BLOCK_SIZE — exactly its column in the scale grid.
        for (block_col, w_chunk) in row_w.chunks(BLOCK_SIZE).enumerate() {
            let scale = load_scale(scale_data, block_row, block_col, scale_cols, scale_dtype)?;
            let o_start = block_col * BLOCK_SIZE * 2;
            let o_chunk = row_o
                .get_mut(o_start..o_start + w_chunk.len() * 2)
                .ok_or_else(|| AnamnesisError::Parse {
                    reason: format!("output chunk at row {r}, block_col {block_col} out of bounds"),
                })?;
            for (&byte, out_pair) in w_chunk.iter().zip(o_chunk.chunks_exact_mut(2)) {
                out_pair.copy_from_slice(&e4m3_to_scaled_bf16(byte, scale).to_le_bytes());
            }
        }
    }
    Ok(output)
}
/// Dequantizes an FP8 (E4M3) tensor to BF16 bytes using a single scalar
/// scale for the whole tensor.
///
/// Returns little-endian BF16 bytes, twice as long as `weight_data`.
///
/// # Errors
///
/// Returns `AnamnesisError::Parse` if the output length computation
/// overflows `usize`.
pub fn dequantize_per_tensor_fp8_to_bf16(weight_data: &[u8], scale: f32) -> crate::Result<Vec<u8>> {
    let out_byte_len = weight_data
        .len()
        .checked_mul(2)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "output size overflow".into(),
        })?;
    let mut converted = Vec::with_capacity(out_byte_len);
    for &fp8 in weight_data {
        converted.extend_from_slice(&e4m3_to_scaled_bf16(fp8, scale).to_le_bytes());
    }
    Ok(converted)
}
/// Dequantizes a per-channel (row-wise) quantized FP8 (E4M3) matrix into
/// BF16 bytes.
///
/// `weight_data` holds a row-major `rows × cols` matrix of E4M3 bytes, and
/// `scale_data` holds exactly one `scale_dtype` scale per row. Returns the
/// little-endian BF16 matrix, `2 × rows × cols` bytes long.
///
/// # Errors
///
/// Returns `AnamnesisError::Parse` when a dimension product overflows
/// `usize` or either buffer's length disagrees with `rows`/`cols`.
pub fn dequantize_per_channel_fp8_to_bf16(
    weight_data: &[u8],
    scale_data: &[u8],
    rows: usize,
    cols: usize,
    scale_dtype: Dtype,
) -> crate::Result<Vec<u8>> {
    let bps = scale_dtype.byte_size();
    // Reject zero-sized scale dtypes up front, mirroring the guard in
    // `dequantize_fp8_to_bf16`; otherwise empty scale data would pass the
    // length check below and fail later with a confusing slice-size error.
    if bps == 0 {
        return Err(AnamnesisError::Parse {
            reason: format!("unsupported scale dtype: {scale_dtype}"),
        });
    }
    let expected_weight_len = rows
        .checked_mul(cols)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: format!("rows × cols overflow: {rows} × {cols}"),
        })?;
    if weight_data.len() != expected_weight_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "weight data length {} != rows × cols {expected_weight_len}",
                weight_data.len()
            ),
        });
    }
    let expected_scale_len = rows.checked_mul(bps).ok_or_else(|| AnamnesisError::Parse {
        reason: "scale byte count overflow".into(),
    })?;
    if scale_data.len() != expected_scale_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "per-channel scale data length {} != expected {expected_scale_len} \
                 (rows={rows}, {bps} bytes per scale)",
                scale_data.len()
            ),
        });
    }
    let out_byte_len = expected_weight_len
        .checked_mul(2)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "output size overflow".into(),
        })?;
    let mut output = vec![0u8; out_byte_len];
    for r in 0..rows {
        // One scale per output row; offsets below cannot overflow because
        // r < rows and rows*bps / rows*cols*2 were overflow-checked above.
        let scale_offset = r * bps;
        let scale_slice = scale_data
            .get(scale_offset..scale_offset + bps)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("per-channel scale for row {r} out of bounds"),
            })?;
        let scale = read_scale_bytes(scale_slice, scale_dtype)?;
        let row_start = r * cols;
        let row_w = weight_data
            .get(row_start..row_start + cols)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("weight row {r} out of bounds"),
            })?;
        let out_row_start = row_start * 2;
        let row_o = output
            .get_mut(out_row_start..out_row_start + cols * 2)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("output row {r} out of bounds"),
            })?;
        for (&byte, out_pair) in row_w.iter().zip(row_o.chunks_exact_mut(2)) {
            out_pair.copy_from_slice(&e4m3_to_scaled_bf16(byte, scale).to_le_bytes());
        }
    }
    Ok(output)
}
// Unit tests. E4M3 decoding is cross-validated exhaustively against the
// `float8` reference crate; the dequantize entry points are exercised on
// single-block, multi-block, ragged-edge, and degenerate shapes, plus the
// validation error paths for each scale dtype.
#[cfg(test)]
#[allow(
    clippy::panic,
    clippy::indexing_slicing,
    clippy::unwrap_used,
    clippy::float_cmp,
    clippy::as_conversions,
    clippy::cast_possible_truncation
)]
mod tests {
    use super::*;

    fn bits_to_f32(bits: u32) -> f32 {
        f32::from_bits(bits)
    }

    // --- e4m3_to_f32_bits: one test per E4M3 value class ---

    #[test]
    fn e4m3_zero() {
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0x00)), 0.0);
    }
    #[test]
    fn e4m3_negative_zero() {
        let val = bits_to_f32(e4m3_to_f32_bits(0x80));
        assert!(val.is_sign_negative());
        assert_eq!(val, -0.0);
    }
    #[test]
    fn e4m3_one() {
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0x38)), 1.0);
    }
    #[test]
    fn e4m3_negative_one() {
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0xB8)), -1.0);
    }
    #[test]
    fn e4m3_two() {
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0x40)), 2.0);
    }
    #[test]
    fn e4m3_half() {
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0x30)), 0.5);
    }
    #[test]
    fn e4m3_max_normal() {
        // 0x7E is the largest finite E4M3 value (448).
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0x7E)), 448.0);
    }
    #[test]
    fn e4m3_min_positive_normal() {
        // exp=1, mant=0: 2^-6.
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0x08)), 0.015_625);
    }
    #[test]
    fn e4m3_min_positive_subnormal() {
        // exp=0, mant=1: 2^-9.
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0x01)), 0.001_953_125);
    }
    #[test]
    fn e4m3_max_subnormal() {
        // exp=0, mant=7: 7 * 2^-9.
        assert_eq!(bits_to_f32(e4m3_to_f32_bits(0x07)), 0.013_671_875);
    }
    #[test]
    fn e4m3_nan_positive() {
        assert!(bits_to_f32(e4m3_to_f32_bits(0x7F)).is_nan());
    }
    #[test]
    fn e4m3_nan_negative() {
        let val = bits_to_f32(e4m3_to_f32_bits(0xFF));
        assert!(val.is_nan());
    }
    // All 256 byte values against the `float8` crate as ground truth.
    #[test]
    fn exhaustive_cross_validation_with_float8() {
        for byte_val in 0u16..=255 {
            #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
            let byte = byte_val as u8;
            let our_f32 = bits_to_f32(e4m3_to_f32_bits(byte));
            let ref_f32 = float8::F8E4M3::from_bits(byte).to_f32();
            if ref_f32.is_nan() {
                assert!(
                    our_f32.is_nan(),
                    "byte {byte:#04X}: expected NaN, got {our_f32}"
                );
            } else {
                assert_eq!(
                    our_f32, ref_f32,
                    "byte {byte:#04X}: our={our_f32}, ref={ref_f32}"
                );
            }
        }
    }

    // --- f32_bits_to_bf16_bits ---

    #[test]
    fn bf16_one() {
        assert_eq!(f32_bits_to_bf16_bits(1.0_f32.to_bits()), 0x3F80);
    }
    #[test]
    fn bf16_zero() {
        assert_eq!(f32_bits_to_bf16_bits(0.0_f32.to_bits()), 0x0000);
    }
    #[test]
    fn bf16_negative_one() {
        assert_eq!(f32_bits_to_bf16_bits((-1.0_f32).to_bits()), 0xBF80);
    }
    #[test]
    fn bf16_nan() {
        let nan_bits = f32::NAN.to_bits();
        let bf16 = f32_bits_to_bf16_bits(nan_bits);
        let reconstructed = f32::from_bits(u32::from(bf16) << 16);
        assert!(reconstructed.is_nan());
    }
    #[test]
    fn bf16_round_to_nearest_even() {
        // Exact tie rounds down to even; tie above odd rounds up.
        assert_eq!(f32_bits_to_bf16_bits(0x3F80_8000), 0x3F80);
        assert_eq!(f32_bits_to_bf16_bits(0x3F81_8000), 0x3F82);
    }

    // --- e4m3_to_scaled_bf16 ---

    #[test]
    fn scaled_bf16_identity() {
        let byte = 0x38;
        let bf16 = e4m3_to_scaled_bf16(byte, 1.0);
        assert_eq!(bf16, 0x3F80);
    }
    #[test]
    fn scaled_bf16_by_two() {
        let bf16 = e4m3_to_scaled_bf16(0x38, 2.0);
        assert_eq!(bf16, 0x4000);
    }
    #[test]
    fn scaled_bf16_nan_times_scale() {
        let bf16 = e4m3_to_scaled_bf16(0x7F, 42.0);
        let f = f32::from_bits(u32::from(bf16) << 16);
        assert!(f.is_nan());
    }
    #[test]
    fn scaled_bf16_zero_times_scale() {
        let bf16 = e4m3_to_scaled_bf16(0x00, 100.0);
        assert_eq!(bf16, 0x0000);
    }

    // Serializes f32 scales as little-endian F32 scale data.
    fn make_scale_bytes(scales: &[f32]) -> Vec<u8> {
        scales.iter().flat_map(|s| s.to_le_bytes()).collect()
    }

    // --- dequantize_fp8_to_bf16 (block-wise) ---

    #[test]
    fn single_block_128x128() {
        let rows = 128;
        let cols = 128;
        let weight_data = vec![0x38u8; rows * cols];
        let scale_data = make_scale_bytes(&[2.0]);
        let output =
            dequantize_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::F32).unwrap();
        assert_eq!(output.len(), rows * cols * 2);
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x00, 0x40], "expected BF16 2.0");
        }
    }
    #[test]
    fn multi_block_256x256() {
        // 2x2 scale grid; probe one element inside each quadrant.
        let rows = 256;
        let cols = 256;
        let weight_data = vec![0x38u8; rows * cols];
        let scales = [1.0_f32, 2.0, 3.0, 4.0];
        let scale_data = make_scale_bytes(&scales);
        let output =
            dequantize_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::F32).unwrap();
        assert_eq!(&output[0..2], &[0x80, 0x3F]);
        assert_eq!(&output[256..258], &[0x00, 0x40]);
        let offset_10 = 128 * 256 * 2;
        assert_eq!(&output[offset_10..offset_10 + 2], &[0x40, 0x40]);
        let offset_11 = offset_10 + 128 * 2;
        assert_eq!(&output[offset_11..offset_11 + 2], &[0x80, 0x40]);
    }
    #[test]
    fn edge_block_130x130() {
        // Ragged edges: 2-element remainder blocks in both dimensions.
        let rows = 130;
        let cols = 130;
        let weight_data = vec![0x38u8; rows * cols];
        let scales = [1.0_f32, 2.0, 3.0, 4.0];
        let scale_data = make_scale_bytes(&scales);
        let output =
            dequantize_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::F32).unwrap();
        assert_eq!(output.len(), rows * cols * 2);
        assert_eq!(&output[0..2], &[0x80, 0x3F]);
        assert_eq!(&output[256..258], &[0x00, 0x40]);
        assert_eq!(&output[258..260], &[0x00, 0x40]);
    }
    #[test]
    fn single_element_1x1() {
        let weight_data = vec![0x38u8];
        let scale_data = make_scale_bytes(&[3.0]);
        let output = dequantize_fp8_to_bf16(&weight_data, &scale_data, 1, 1, Dtype::F32).unwrap();
        assert_eq!(output.len(), 2);
        assert_eq!(&output[..], &[0x40, 0x40]);
    }
    #[test]
    fn single_row_1x128() {
        let weight_data = vec![0x40u8; 128];
        let scale_data = make_scale_bytes(&[0.5]);
        let output = dequantize_fp8_to_bf16(&weight_data, &scale_data, 1, 128, Dtype::F32).unwrap();
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x80, 0x3F]);
        }
    }

    // --- dequantize_fp8_to_bf16 validation errors ---

    #[test]
    fn validation_weight_length_mismatch() {
        let result = dequantize_fp8_to_bf16(&[0u8; 10], &[0u8; 4], 2, 6, Dtype::F32);
        assert!(result.is_err());
    }
    #[test]
    fn validation_scale_not_multiple_of_4() {
        let result = dequantize_fp8_to_bf16(&[0u8; 4], &[0u8; 5], 2, 2, Dtype::F32);
        assert!(result.is_err());
    }
    #[test]
    fn validation_scale_too_small() {
        let weight = vec![0u8; 256 * 256];
        let scale = vec![0u8; 4];
        let result = dequantize_fp8_to_bf16(&weight, &scale, 256, 256, Dtype::F32);
        assert!(result.is_err());
    }
    #[test]
    fn validation_zero_dimensions() {
        let result = dequantize_fp8_to_bf16(&[], &[], 0, 0, Dtype::F32);
        assert!(result.is_err());
    }

    // --- dequantize_per_tensor_fp8_to_bf16 ---

    #[test]
    fn per_tensor_all_ones_scale_one() {
        let weight = vec![0x38u8; 128];
        let output = dequantize_per_tensor_fp8_to_bf16(&weight, 1.0).unwrap();
        assert_eq!(output.len(), 256);
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x80, 0x3F]);
        }
    }
    #[test]
    fn per_tensor_scale_two() {
        let weight = vec![0x38u8; 64];
        let output = dequantize_per_tensor_fp8_to_bf16(&weight, 2.0).unwrap();
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x00, 0x40]);
        }
    }
    #[test]
    fn per_tensor_non_aligned_length() {
        // Per-tensor mode has no block structure; any length works.
        let weight = vec![0x40u8; 130];
        let output = dequantize_per_tensor_fp8_to_bf16(&weight, 0.5).unwrap();
        assert_eq!(output.len(), 260);
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x80, 0x3F]);
        }
    }
    #[test]
    fn per_tensor_single_element() {
        let output = dequantize_per_tensor_fp8_to_bf16(&[0x38], 3.0).unwrap();
        assert_eq!(output.len(), 2);
        assert_eq!(&output[..], &[0x40, 0x40]);
    }
    #[test]
    fn per_tensor_empty() {
        let output = dequantize_per_tensor_fp8_to_bf16(&[], 1.0).unwrap();
        assert!(output.is_empty());
    }
    #[test]
    fn per_tensor_nan_preserved() {
        let output = dequantize_per_tensor_fp8_to_bf16(&[0x7F], 42.0).unwrap();
        let bf16_bits = u16::from_le_bytes([output[0], output[1]]);
        let f = f32::from_bits(u32::from(bf16_bits) << 16);
        assert!(f.is_nan());
    }

    // Serializes f32 scales as truncated little-endian BF16 scale data.
    fn make_bf16_scale_bytes(scales: &[f32]) -> Vec<u8> {
        scales
            .iter()
            .flat_map(|s| ((s.to_bits() >> 16) as u16).to_le_bytes())
            .collect()
    }
    // Serializes f32 scales as little-endian F16 scale data.
    fn make_f16_scale_bytes(scales: &[f32]) -> Vec<u8> {
        scales
            .iter()
            .flat_map(|s| half::f16::from_f32(*s).to_le_bytes())
            .collect()
    }

    // --- dequantize_per_channel_fp8_to_bf16 ---

    #[test]
    fn per_channel_basic_f32_scale() {
        let rows = 2;
        let cols = 4;
        let weight_data = vec![0x38u8; rows * cols];
        let scale_data = make_scale_bytes(&[2.0, 3.0]);
        let output =
            dequantize_per_channel_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::F32)
                .unwrap();
        assert_eq!(output.len(), rows * cols * 2);
        for chunk in output[..cols * 2].chunks_exact(2) {
            assert_eq!(chunk, &[0x00, 0x40], "row 0: expected BF16 2.0");
        }
        for chunk in output[cols * 2..].chunks_exact(2) {
            assert_eq!(chunk, &[0x40, 0x40], "row 1: expected BF16 3.0");
        }
    }
    #[test]
    fn per_channel_bf16_scale() {
        let rows = 2;
        let cols = 4;
        let weight_data = vec![0x38u8; rows * cols];
        let scale_data = make_bf16_scale_bytes(&[2.0, 3.0]);
        let output =
            dequantize_per_channel_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::BF16)
                .unwrap();
        assert_eq!(output.len(), rows * cols * 2);
        for chunk in output[..cols * 2].chunks_exact(2) {
            assert_eq!(chunk, &[0x00, 0x40], "row 0: expected BF16 2.0");
        }
        for chunk in output[cols * 2..].chunks_exact(2) {
            assert_eq!(chunk, &[0x40, 0x40], "row 1: expected BF16 3.0");
        }
    }
    #[test]
    fn per_channel_f16_scale() {
        let rows = 2;
        let cols = 4;
        let weight_data = vec![0x38u8; rows * cols];
        let scale_data = make_f16_scale_bytes(&[2.0, 3.0]);
        let output =
            dequantize_per_channel_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::F16)
                .unwrap();
        assert_eq!(output.len(), rows * cols * 2);
        for chunk in output[..cols * 2].chunks_exact(2) {
            assert_eq!(chunk, &[0x00, 0x40], "row 0: expected BF16 2.0");
        }
        for chunk in output[cols * 2..].chunks_exact(2) {
            assert_eq!(chunk, &[0x40, 0x40], "row 1: expected BF16 3.0");
        }
    }
    #[test]
    fn per_channel_single_row() {
        let weight_data = vec![0x40u8; 128];
        let scale_data = make_scale_bytes(&[0.5]);
        let output =
            dequantize_per_channel_fp8_to_bf16(&weight_data, &scale_data, 1, 128, Dtype::F32)
                .unwrap();
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x80, 0x3F]);
        }
    }
    #[test]
    fn per_channel_nan_preserved() {
        let weight_data = vec![0x7F];
        let scale_data = make_scale_bytes(&[42.0]);
        let output =
            dequantize_per_channel_fp8_to_bf16(&weight_data, &scale_data, 1, 1, Dtype::F32)
                .unwrap();
        let bf16_bits = u16::from_le_bytes([output[0], output[1]]);
        let f = f32::from_bits(u32::from(bf16_bits) << 16);
        assert!(f.is_nan());
    }
    #[test]
    fn per_channel_validation_weight_mismatch() {
        let result = dequantize_per_channel_fp8_to_bf16(&[0u8; 10], &[0u8; 8], 2, 6, Dtype::F32);
        assert!(result.is_err());
    }
    #[test]
    fn per_channel_validation_scale_mismatch() {
        let result = dequantize_per_channel_fp8_to_bf16(&[0u8; 8], &[0u8; 2], 2, 4, Dtype::F32);
        assert!(result.is_err());
    }

    // --- block-wise dequantize with non-F32 scale dtypes ---

    #[test]
    fn fine_grained_bf16_scale() {
        let rows = 128;
        let cols = 128;
        let weight_data = vec![0x38u8; rows * cols];
        let scale_data = make_bf16_scale_bytes(&[2.0]);
        let output =
            dequantize_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::BF16).unwrap();
        assert_eq!(output.len(), rows * cols * 2);
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x00, 0x40], "expected BF16 2.0");
        }
    }
    #[test]
    fn fine_grained_f32_scale() {
        let rows = 128;
        let cols = 128;
        let weight_data = vec![0x38u8; rows * cols];
        let scale_data = make_scale_bytes(&[3.0]);
        let output =
            dequantize_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::F32).unwrap();
        assert_eq!(output.len(), rows * cols * 2);
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x40, 0x40], "expected BF16 3.0");
        }
    }
    #[test]
    fn fine_grained_f16_scale() {
        let rows = 128;
        let cols = 128;
        let weight_data = vec![0x38u8; rows * cols];
        let scale_data = make_f16_scale_bytes(&[2.0]);
        let output =
            dequantize_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::F16).unwrap();
        assert_eq!(output.len(), rows * cols * 2);
        for chunk in output.chunks_exact(2) {
            assert_eq!(chunk, &[0x00, 0x40], "expected BF16 2.0");
        }
    }
    #[test]
    fn fine_grained_f32_multi_block() {
        // Distinct scale per quadrant; probe one element in each.
        let rows = 256;
        let cols = 256;
        let weight_data = vec![0x38u8; rows * cols];
        let scales = [1.0_f32, 4.0, 2.0, 8.0];
        let scale_data = make_scale_bytes(&scales);
        let output =
            dequantize_fp8_to_bf16(&weight_data, &scale_data, rows, cols, Dtype::F32).unwrap();
        assert_eq!(&output[0..2], &[0x80, 0x3F]);
        assert_eq!(&output[256..258], &[0x80, 0x40]);
        let offset_10 = 128 * 256 * 2;
        assert_eq!(&output[offset_10..offset_10 + 2], &[0x00, 0x40]);
        let offset_11 = offset_10 + 128 * 2;
        assert_eq!(&output[offset_11..offset_11 + 2], &[0x00, 0x41]);
    }
}