use super::Precision;
/// Converts an `f32` to a bfloat16 bit pattern by keeping its upper 16 bits.
///
/// The low 16 bits of the `f32` encoding are discarded (no rounding is
/// applied), so finite values are truncated toward zero.
///
/// NaN inputs always stay NaN: if dropping the low bits would clear the whole
/// mantissa (which would otherwise encode infinity), the quiet-NaN bit is set
/// in the result. NaNs whose upper mantissa bits are already non-zero are
/// returned unchanged.
pub fn f32_to_bf16(value: f32) -> u16 {
    // After the shift only 16 significant bits remain, so the cast is lossless.
    let upper = (value.to_bits() >> 16) as u16;

    // The bf16 mantissa is the low 7 bits. A NaN whose payload lived solely in
    // the discarded half would have an all-zero mantissa here, i.e. it would
    // silently turn into +/- infinity.
    if value.is_nan() && (upper & 0x007F) == 0 {
        upper | 0x0040
    } else {
        upper
    }
}
/// Expands a bfloat16 bit pattern into the `f32` with the same upper 16 bits.
///
/// The 16 input bits become the most significant half of the `f32` encoding
/// and the least significant half is filled with zeros.
pub fn bf16_to_f32(value: u16) -> f32 {
    // Big-endian order puts the two bf16 bytes in the most significant positions.
    let [high, low] = value.to_be_bytes();
    f32::from_be_bytes([high, low, 0, 0])
}
/// Converts an `f32` to a 16-bit value by delegating to `trueno::f32_to_f16`.
///
/// NOTE(review): the result is presumably the IEEE 754 half-precision bit
/// pattern; rounding, overflow and NaN handling are whatever `trueno`
/// implements — confirm against that crate.
pub fn f32_to_fp16(value: f32) -> u16 {
trueno::f32_to_f16(value)
}
/// Converts a 16-bit value to `f32` by delegating to `trueno::f16_to_f32`.
///
/// NOTE(review): `value` is presumably an IEEE 754 half-precision bit
/// pattern; the exact decoding is defined by `trueno` — confirm against that
/// crate.
pub fn fp16_to_f32(value: u16) -> f32 {
trueno::f16_to_f32(value)
}
/// Reduces an `f32` to bfloat16 precision while keeping it stored as `f32`.
///
/// The low 16 bits of the encoding are cleared (no rounding is applied), so
/// finite values are truncated toward zero.
///
/// NaN inputs always stay NaN: if clearing the low bits would leave an
/// all-zero mantissa (which would otherwise encode infinity), the quiet-NaN
/// bit is set in the result. NaNs whose upper mantissa bits are already
/// non-zero are returned with only their low bits cleared.
#[inline]
pub fn bf16_truncate(val: f32) -> f32 {
    let kept = val.to_bits() & 0xFFFF_0000;

    // Mantissa bits that survive the mask are bits 16..=22. If none of them is
    // set for a NaN input, the masked pattern would decode as +/- infinity.
    if val.is_nan() && (kept & 0x007F_0000) == 0 {
        f32::from_bits(kept | 0x0040_0000)
    } else {
        f32::from_bits(kept)
    }
}
/// Reference matrix multiply with bf16-truncated inputs and `f32` accumulation.
///
/// `a` is read as a row-major `m x k` matrix (`a[row * k + i]`) and `b` as a
/// row-major `k x n` matrix (`b[i * n + col]`). The result is a row-major
/// `m x n` matrix.
///
/// Every operand goes through [`bf16_truncate`] before it is used. Each output
/// element is accumulated in `f32`, starting from `0.0`, with one fused
/// multiply-add per `i` in increasing order.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `a` or `b` holds fewer elements than
/// the given dimensions require.
pub fn gemm_bf16_reference(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];

    for (flat, slot) in out.iter_mut().enumerate() {
        // `out` is empty whenever `n == 0`, so `n` is never a zero divisor here.
        let (r, c) = (flat / n, flat % n);

        *slot = (0..k).fold(0.0f32, |sum, p| {
            let lhs = bf16_truncate(a[r * k + p]);
            let rhs = bf16_truncate(b[p * n + c]);
            lhs.mul_add(rhs, sum)
        });
    }

    out
}
/// Estimates memory use for FP32 training versus a mixed-precision setup.
///
/// Both estimates add up three terms: parameters, one activation tensor of
/// `batch_size * seq_len * hidden_size` elements, and gradients (one per
/// parameter). Nothing else (e.g. optimizer state) is counted.
///
/// * FP32 baseline: every term uses 4 bytes per element.
/// * Mixed estimate: parameters stay at 4 bytes per element, while
///   activations and gradients use `precision.size_bytes()` bytes per element.
///
/// # Returns
///
/// `(total_fp32, total_mixed, savings)` where the totals are in bytes and
/// `savings` is `1.0 - total_mixed / total_fp32`. The fraction is negative
/// when the mixed estimate is larger than the baseline. When the baseline is
/// zero bytes (no parameters and no activations) `savings` is `0.0` rather
/// than NaN.
pub fn estimate_memory_savings(
    num_params: usize,
    batch_size: usize,
    seq_len: usize,
    hidden_size: usize,
    precision: Precision,
) -> (usize, usize, f32) {
    let reduced_bytes = precision.size_bytes();

    // FP32 baseline: 4 bytes for every parameter, activation and gradient.
    let param_bytes_fp32 = num_params * 4;
    let activation_bytes_fp32 = batch_size * seq_len * hidden_size * 4;
    let grad_bytes_fp32 = num_params * 4;
    let total_fp32 = param_bytes_fp32 + activation_bytes_fp32 + grad_bytes_fp32;

    // Mixed estimate: parameters remain 4 bytes wide; only activations and
    // gradients shrink to the requested precision.
    let param_bytes_mixed = num_params * 4;
    let activation_bytes_mixed = batch_size * seq_len * hidden_size * reduced_bytes;
    let grad_bytes_mixed = num_params * reduced_bytes;
    let total_mixed = param_bytes_mixed + activation_bytes_mixed + grad_bytes_mixed;

    // A zero baseline would make the ratio 0/0 (NaN); report "no savings".
    let savings = if total_fp32 == 0 {
        0.0
    } else {
        1.0 - (total_mixed as f32 / total_fp32 as f32)
    };

    (total_fp32, total_mixed, savings)
}