use crate::error::{AprenderError, Result};
use std::collections::BTreeMap;
use std::path::Path;
/// Element type tag of a tensor.
///
/// Each known variant corresponds to one integer code accepted by
/// `OnnxDataType::from_i32`; any code without a dedicated variant is kept
/// verbatim in `Unknown`. Byte widths quoted below are the values returned by
/// `OnnxDataType::element_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OnnxDataType {
/// Code 1; 4 bytes per element.
Float,
/// Code 2; 1 byte per element.
Uint8,
/// Code 3; 1 byte per element.
Int8,
/// Code 4; 2 bytes per element.
Uint16,
/// Code 5; 2 bytes per element.
Int16,
/// Code 6; 4 bytes per element.
Int32,
/// Code 7; 8 bytes per element.
Int64,
/// Code 8; reported with an element size of 0 (no fixed width).
String,
/// Code 9; 1 byte per element.
Bool,
/// Code 10; 2 bytes per element.
Float16,
/// Code 11; 8 bytes per element.
Double,
/// Code 12; 4 bytes per element.
Uint32,
/// Code 13; 8 bytes per element.
Uint64,
/// Code 16; 2 bytes per element.
BFloat16,
/// Any other code, stored unchanged; reported with an element size of 0.
Unknown(i32),
}
impl OnnxDataType {
fn from_i32(v: i32) -> Self {
match v {
1 => Self::Float,
2 => Self::Uint8,
3 => Self::Int8,
4 => Self::Uint16,
5 => Self::Int16,
6 => Self::Int32,
7 => Self::Int64,
8 => Self::String,
9 => Self::Bool,
10 => Self::Float16,
11 => Self::Double,
12 => Self::Uint32,
13 => Self::Uint64,
16 => Self::BFloat16,
other => Self::Unknown(other),
}
}
pub fn element_size(&self) -> usize {
match self {
Self::Float | Self::Int32 | Self::Uint32 => 4,
Self::Double | Self::Int64 | Self::Uint64 => 8,
Self::Float16 | Self::BFloat16 | Self::Int16 | Self::Uint16 => 2,
Self::Uint8 | Self::Int8 | Self::Bool => 1,
Self::String | Self::Unknown(_) => 0,
}
}
}
/// A named tensor: its shape, element type, and still-encoded element bytes.
#[derive(Debug, Clone)]
pub struct OnnxTensor {
/// Name of the tensor.
pub name: String,
/// Size of each dimension.
/// NOTE(review): dimension order is not established in this part of the file — confirm against the parser.
pub shape: Vec<usize>,
/// Element type; selects how `to_f32` decodes `raw_data`.
pub data_type: OnnxDataType,
/// Packed element bytes; `to_f32` reads multi-byte elements as little-endian.
pub raw_data: Vec<u8>,
}
impl OnnxTensor {
    /// Decodes `raw_data` into `f32` values according to `data_type`.
    ///
    /// The bytes are read as tightly packed little-endian elements. Every
    /// fixed-width numeric type is supported: `Float`, `Double`, `Float16`,
    /// `BFloat16` and the signed/unsigned 8-, 16-, 32- and 64-bit integers.
    /// Values that `f32` cannot represent exactly (`Double`, wide integers)
    /// are converted with an `as` cast, i.e. rounded to the nearest `f32`.
    ///
    /// Trailing bytes that do not form a complete element are ignored.
    ///
    /// # Returns
    ///
    /// One `f32` per complete element, or an empty vector for types without
    /// a numeric decoding here (`String`, `Bool`, `Unknown`).
    pub fn to_f32(&self) -> Vec<f32> {
        // Splits `raw` into complete `N`-byte elements and converts each one.
        fn decode<const N: usize>(raw: &[u8], convert: impl Fn([u8; N]) -> f32) -> Vec<f32> {
            raw.chunks_exact(N)
                .map(|chunk| {
                    let mut bytes = [0u8; N];
                    // `chunks_exact` only yields chunks of length `N`, so the
                    // length check inside `copy_from_slice` cannot fail.
                    bytes.copy_from_slice(chunk);
                    convert(bytes)
                })
                .collect()
        }

        let raw = self.raw_data.as_slice();
        match self.data_type {
            OnnxDataType::Float => decode(raw, f32::from_le_bytes),
            OnnxDataType::Double => decode(raw, |b: [u8; 8]| f64::from_le_bytes(b) as f32),
            OnnxDataType::Float16 => {
                decode(raw, |b: [u8; 2]| f16_to_f32(u16::from_le_bytes(b)))
            }
            // bfloat16 is the upper 16 bits of an f32, so widening is a plain shift.
            OnnxDataType::BFloat16 => decode(raw, |b: [u8; 2]| {
                f32::from_bits(u32::from(u16::from_le_bytes(b)) << 16)
            }),
            OnnxDataType::Int8 => decode(raw, |b: [u8; 1]| f32::from(i8::from_le_bytes(b))),
            OnnxDataType::Uint8 => decode(raw, |b: [u8; 1]| f32::from(b[0])),
            OnnxDataType::Int16 => decode(raw, |b: [u8; 2]| f32::from(i16::from_le_bytes(b))),
            OnnxDataType::Uint16 => decode(raw, |b: [u8; 2]| f32::from(u16::from_le_bytes(b))),
            OnnxDataType::Int32 => decode(raw, |b: [u8; 4]| i32::from_le_bytes(b) as f32),
            OnnxDataType::Uint32 => decode(raw, |b: [u8; 4]| u32::from_le_bytes(b) as f32),
            OnnxDataType::Int64 => decode(raw, |b: [u8; 8]| i64::from_le_bytes(b) as f32),
            OnnxDataType::Uint64 => decode(raw, |b: [u8; 8]| u64::from_le_bytes(b) as f32),
            // No numeric decoding is defined for these; callers get no values.
            OnnxDataType::String | OnnxDataType::Bool | OnnxDataType::Unknown(_) => Vec::new(),
        }
    }
}
/// Widens a half-precision bit pattern (1 sign, 5 exponent, 10 fraction bits)
/// to the `f32` with the same value.
///
/// Signed zeros, subnormals, infinities and NaNs are all handled; a NaN keeps
/// its fraction bits, moved into the top of the `f32` fraction.
fn f16_to_f32(bits: u16) -> f32 {
    let wide = u32::from(bits);
    let sign_bit = (wide & 0x8000) << 16;
    let exp = (wide >> 10) & 0x1F;
    let frac = wide & 0x3FF;

    let magnitude = match (exp, frac) {
        // Signed zero: nothing but the sign bit.
        (0, 0) => 0,
        // Subnormal: shift the fraction left until its leading one reaches the
        // implicit-bit position (bit 10), lowering the exponent once per shift.
        (0, _) => {
            let shift = frac.leading_zeros() - 21;
            let normalized = (frac << shift) & 0x3FF;
            ((127 - 15 + 1 - shift) << 23) | (normalized << 13)
        }
        // Infinity (zero fraction) or NaN (non-zero fraction).
        (0x1F, _) => (0xFF << 23) | (frac << 13),
        // Normal number: re-bias the exponent from 15 to 127.
        _ => ((exp + 127 - 15) << 23) | (frac << 13),
    };

    f32::from_bits(sign_bit | magnitude)
}
/// Model-level metadata.
///
/// `Default` yields zeroed numbers, empty strings and an empty opset list.
/// NOTE(review): the field names match the ONNX `ModelProto` message; the code
/// that populates them is not visible in this part of the file — confirm.
#[derive(Debug, Clone, Default)]
pub struct OnnxMetadata {
/// IR version number.
pub ir_version: i64,
/// Name of the producer; presumably the tool that wrote the model — confirm.
pub producer_name: String,
/// Version string of the producer.
pub producer_version: String,
/// Domain string.
pub domain: String,
/// Version number of the model.
pub model_version: i64,
/// Free-form documentation string.
pub doc_string: String,
/// Opset entries; presumably `(domain, version)` pairs — confirm against the parser.
pub opset_versions: Vec<(String, i64)>,
}
/// Holds the tensors and metadata of one model.
///
/// Both fields are private to this module.
/// NOTE(review): construction and accessors are presumably supplied by the
/// `reader.rs` file included right after this struct — confirm.
#[derive(Debug)]
pub struct OnnxReader {
// Tensors held by the reader.
tensors: Vec<OnnxTensor>,
// Model-level metadata held by the reader.
metadata: OnnxMetadata,
}
include!("reader.rs");
#[cfg(test)]
mod f16_tests {
    use super::f16_to_f32;

    /// Arithmetic oracle for half -> `f32` conversion.
    ///
    /// Computes the value from sign, exponent and fraction with floating-point
    /// math rather than by assembling bits, so it shares no logic with the
    /// function under test.
    fn golden_f16_to_f32(bits: u16) -> f32 {
        let negative = (bits >> 15) & 1 == 1;
        let exp = i32::from((bits >> 10) & 0x1F);
        let man = u32::from(bits & 0x3FF);
        let s: f32 = if negative { -1.0 } else { 1.0 };

        match exp {
            // Zero and subnormals: fraction * 2^-24.
            0 => s * (man as f32) * 2f32.powi(-24),
            0x1F if man != 0 => f32::NAN,
            0x1F => {
                if negative {
                    f32::NEG_INFINITY
                } else {
                    f32::INFINITY
                }
            }
            // Normal numbers: (1 + fraction / 2^10) * 2^(exp - 15).
            _ => s * (1.0 + (man as f32) / 1024.0) * 2f32.powi(exp - 15),
        }
    }

    /// The smallest positive subnormal must convert to exactly 2^-24.
    #[test]
    fn smallest_subnormal_not_halved() {
        let got = f16_to_f32(0x0001).to_bits();
        assert_eq!(
            got, 0x3380_0000,
            "f16_to_f32(0x0001) = {got:#010x}, expected 0x33800000 (5.9604645e-8)"
        );
    }

    /// Every non-NaN bit pattern must agree bit-for-bit with the oracle.
    #[test]
    fn all_bit_patterns_match_golden() {
        // NaN inputs are skipped: the oracle returns one fixed NaN, so a
        // bitwise comparison is not meaningful for them.
        let is_nan = |bits: u16| ((bits >> 10) & 0x1F) == 0x1F && (bits & 0x3FF) != 0;

        let mismatches = (0..=u16::MAX)
            .filter(|&bits| !is_nan(bits))
            .filter(|&bits| f16_to_f32(bits).to_bits() != golden_f16_to_f32(bits).to_bits())
            .count();

        assert_eq!(
            mismatches, 0,
            "{mismatches} f16->f32 conversions disagree with golden oracle"
        );
    }
}