use crate::error::AnamnesisError;
use crate::parse::safetensors::Dtype;
pub(crate) fn read_u32_le(data: &[u8], byte_offset: usize) -> crate::Result<u32> {
let end = byte_offset
.checked_add(4)
.ok_or_else(|| AnamnesisError::Parse {
reason: "u32 byte offset overflow".into(),
})?;
let slice = data
.get(byte_offset..end)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!(
"u32 read out of bounds: need bytes {byte_offset}..{end}, have {}",
data.len()
),
})?;
let arr: [u8; 4] = slice.try_into().map_err(|_| AnamnesisError::Parse {
reason: "u32 slice is not 4 bytes".into(),
})?;
Ok(u32::from_le_bytes(arr))
}
/// Reads one scale value encoded as `dtype` from `data` at `byte_offset` and
/// returns it as an `f32`.
///
/// Only `F16`, `BF16` and `F32` are decoded. The dtype is inspected before any
/// bytes are touched, so an unsupported dtype is reported even when
/// `byte_offset` is itself out of range.
///
/// # Errors
///
/// * [`AnamnesisError::Parse`] when the offset arithmetic overflows or the
///   encoded value does not fit inside `data`.
/// * [`AnamnesisError::Unsupported`] for every other dtype.
pub(crate) fn read_scale_f32(data: &[u8], byte_offset: usize, dtype: Dtype) -> crate::Result<f32> {
    // Bounds-checked copy of `N` bytes starting at `byte_offset`; `label` is
    // the dtype name that prefixes every error message.
    fn take<const N: usize>(
        data: &[u8],
        byte_offset: usize,
        label: &str,
    ) -> crate::Result<[u8; N]> {
        let end = match byte_offset.checked_add(N) {
            Some(end) => end,
            None => {
                return Err(AnamnesisError::Parse {
                    reason: format!("{label} scale byte offset overflow"),
                })
            }
        };
        let window = match data.get(byte_offset..end) {
            Some(window) => window,
            None => {
                return Err(AnamnesisError::Parse {
                    reason: format!("{label} scale read out of bounds at offset {byte_offset}"),
                })
            }
        };
        let raw: [u8; N] = window.try_into().map_err(|_| AnamnesisError::Parse {
            reason: format!("{label} scale slice is not {} bytes", N),
        })?;
        Ok(raw)
    }

    match dtype {
        Dtype::F16 => {
            let raw = take::<2>(data, byte_offset, "F16")?;
            Ok(half::f16::from_le_bytes(raw).to_f32())
        }
        Dtype::BF16 => {
            let raw = take::<2>(data, byte_offset, "BF16")?;
            // The 16 stored bits become the high half of an `f32` bit
            // pattern; the low half is left as zero.
            let high_bits = u32::from(u16::from_le_bytes(raw));
            Ok(f32::from_bits(high_bits << 16))
        }
        Dtype::F32 => {
            let raw = take::<4>(data, byte_offset, "F32")?;
            Ok(f32::from_le_bytes(raw))
        }
        // Listed variant by variant (no wildcard) so that adding a new
        // `Dtype` forces this match to be revisited.
        Dtype::F8E4M3
        | Dtype::F8E5M2
        | Dtype::F64
        | Dtype::Bool
        | Dtype::U8
        | Dtype::I8
        | Dtype::U16
        | Dtype::I16
        | Dtype::U32
        | Dtype::I32
        | Dtype::U64
        | Dtype::I64 => Err(AnamnesisError::Unsupported {
            format: dtype.to_string(),
            detail: "scale dtype must be F16, BF16, or F32".into(),
        }),
    }
}
/// Edge length, in elements, of the square tiles walked by [`transpose_bf16`].
const TRANSPOSE_TILE: usize = 32;

/// Transposes a row-major `rows × cols` matrix of 2-byte elements.
///
/// `data` must hold exactly `rows * cols * 2` bytes. The result is a new
/// buffer of the same length holding the row-major `cols × rows` transpose;
/// each 2-byte element is copied verbatim, without being decoded.
///
/// The matrix is visited in `TRANSPOSE_TILE × TRANSPOSE_TILE` tiles
/// (presumably for cache locality — the tiling does not affect the result).
///
/// A matrix with a zero dimension yields an empty buffer immediately,
/// whatever the size of the other dimension.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] when `rows * cols * 2` overflows `usize`
/// or when `data.len()` differs from that byte count.
pub(crate) fn transpose_bf16(data: &[u8], rows: usize, cols: usize) -> crate::Result<Vec<u8>> {
    let n_elements = rows
        .checked_mul(cols)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "transpose element count overflow".into(),
        })?;
    let byte_len = n_elements
        .checked_mul(2)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "transpose byte count overflow".into(),
        })?;
    if data.len() != byte_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "transpose input length {} != rows × cols × 2 = {byte_len} \
                 (rows={rows}, cols={cols})",
                data.len()
            ),
        });
    }

    // Nothing to move when either dimension is zero. Without this guard a
    // shape such as `rows = usize::MAX, cols = 0` passes the length check and
    // then spins through `rows / TRANSPOSE_TILE` empty row tiles.
    if byte_len == 0 {
        return Ok(Vec::new());
    }

    let mut output = vec![0u8; byte_len];

    // Indexing cannot go out of bounds: `r < rows` and `c < cols`, so both
    // element indices are below `rows * cols`, and each buffer holds exactly
    // `rows * cols * 2` bytes (checked above).
    #[allow(clippy::indexing_slicing)]
    for row_tile in (0..rows).step_by(TRANSPOSE_TILE) {
        // Past the guard above `rows * 2` fits in `usize`, so adding the tile
        // width to a value below `rows` cannot overflow.
        let row_end = (row_tile + TRANSPOSE_TILE).min(rows);
        for col_tile in (0..cols).step_by(TRANSPOSE_TILE) {
            let col_end = (col_tile + TRANSPOSE_TILE).min(cols);
            for r in row_tile..row_end {
                for c in col_tile..col_end {
                    let src = (r * cols + c) * 2;
                    let dst = (c * rows + r) * 2;
                    output[dst] = data[src];
                    output[dst + 1] = data[src + 1];
                }
            }
        }
    }
    Ok(output)
}
#[cfg(test)]
// Relaxed lints: these tests deliberately unwrap, index, cast, and compare
// floats exactly on small fixed fixtures.
#[allow(
    clippy::panic,
    clippy::indexing_slicing,
    clippy::unwrap_used,
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::float_cmp
)]
mod tests {
    use super::*;

    // ---- read_u32_le -----------------------------------------------------

    #[test]
    fn read_u32_le_decodes_at_offset() {
        let data: [u8; 5] = [0xFF, 0x78, 0x56, 0x34, 0x12];
        assert_eq!(read_u32_le(&data, 1).unwrap(), 0x1234_5678);
    }

    #[test]
    fn read_u32_le_rejects_bad_offsets() {
        let data = [0u8; 6];
        // Window 3..7 runs past the end of a 6-byte buffer.
        assert!(matches!(
            read_u32_le(&data, 3),
            Err(AnamnesisError::Parse { .. })
        ));
        // `usize::MAX + 4` overflows before any slice lookup happens.
        assert!(matches!(
            read_u32_le(&data, usize::MAX),
            Err(AnamnesisError::Parse { .. })
        ));
    }

    // ---- read_scale_f32 --------------------------------------------------

    #[test]
    fn read_scale_f16() {
        // 0x3C00 is the F16 bit pattern of 1.0.
        let data = 0x3C00u16.to_le_bytes();
        let val = read_scale_f32(&data, 0, Dtype::F16).unwrap();
        assert_eq!(val, 1.0);
    }

    #[test]
    fn read_scale_bf16() {
        // 0x3F80 is the upper half of the f32 bit pattern of 1.0.
        let data = 0x3F80u16.to_le_bytes();
        let val = read_scale_f32(&data, 0, Dtype::BF16).unwrap();
        assert_eq!(val, 1.0);
    }

    #[test]
    fn read_scale_f32_dtype() {
        let data = 2.0_f32.to_le_bytes();
        let val = read_scale_f32(&data, 0, Dtype::F32).unwrap();
        assert_eq!(val, 2.0);
    }

    #[test]
    fn read_scale_rejects_truncated_input() {
        let data = [0u8; 3];
        assert!(matches!(
            read_scale_f32(&data[..1], 0, Dtype::F16),
            Err(AnamnesisError::Parse { .. })
        ));
        assert!(matches!(
            read_scale_f32(&data, 2, Dtype::BF16),
            Err(AnamnesisError::Parse { .. })
        ));
        assert!(matches!(
            read_scale_f32(&data, 0, Dtype::F32),
            Err(AnamnesisError::Parse { .. })
        ));
    }

    #[test]
    fn read_scale_rejects_unsupported_dtype() {
        let data = [0u8; 8];
        for dtype in [Dtype::U8, Dtype::F64, Dtype::F8E4M3] {
            assert!(matches!(
                read_scale_f32(&data, 0, dtype),
                Err(AnamnesisError::Unsupported { .. })
            ));
        }
    }

    // ---- transpose_bf16 --------------------------------------------------

    /// Builds `n` BF16 elements holding 0.0, 1.0, 2.0, … (each `f32` truncated
    /// to its upper 16 bits), serialized little-endian.
    fn bf16_ramp(n: usize) -> Vec<u8> {
        (0..n)
            .flat_map(|k| {
                let bits = ((k as f32).to_bits() >> 16) as u16;
                bits.to_le_bytes()
            })
            .collect()
    }

    /// Decodes the BF16 element at element index `idx` back to `f32`.
    fn bf16_value(data: &[u8], idx: usize) -> f32 {
        let bits = u16::from_le_bytes([data[idx * 2], data[idx * 2 + 1]]);
        f32::from_bits(u32::from(bits) << 16)
    }

    #[test]
    fn transpose_maps_elements_exactly() {
        // 2×3 input [[0, 1, 2], [3, 4, 5]] → 3×2 output [[0, 3], [1, 4], [2, 5]].
        let data = bf16_ramp(6);
        let out = transpose_bf16(&data, 2, 3).unwrap();
        let expected = [0.0, 3.0, 1.0, 4.0, 2.0, 5.0];
        for (idx, &exp) in expected.iter().enumerate() {
            assert_eq!(bf16_value(&out, idx), exp, "element {idx}");
        }
    }

    #[test]
    fn transpose_round_trips_non_square() {
        // Shapes below, at, and across the 32-element tile edge.
        for &(rows, cols) in &[(1usize, 7usize), (5, 33), (33, 5), (64, 96), (40, 40)] {
            let data = bf16_ramp(rows * cols);
            let once = transpose_bf16(&data, rows, cols).unwrap();
            let twice = transpose_bf16(&once, cols, rows).unwrap();
            assert_eq!(twice, data, "transpose² != identity for {rows}×{cols}");
        }
    }

    #[test]
    fn transpose_rejects_length_mismatch() {
        let data = bf16_ramp(6);
        assert!(transpose_bf16(&data, 2, 4).is_err());
        assert!(transpose_bf16(&data, 7, 1).is_err());
    }

    #[test]
    fn transpose_rejects_size_overflow() {
        assert!(transpose_bf16(&[], usize::MAX, 2).is_err());
        assert!(transpose_bf16(&[], usize::MAX / 2 + 1, 1).is_err());
    }

    #[test]
    fn transpose_accepts_empty_matrix() {
        assert!(transpose_bf16(&[], 0, 0).unwrap().is_empty());
        assert!(transpose_bf16(&[], 0, 4).unwrap().is_empty());
        assert!(transpose_bf16(&[], 4, 0).unwrap().is_empty());
    }
}