use crate::error::{AprenderError, Result};
use std::collections::BTreeMap;
use std::path::Path;
/// Element type of an ONNX tensor.
///
/// The numeric codes noted on each variant follow the ONNX
/// `TensorProto.DataType` enumeration. Codes this module does not recognize
/// are carried verbatim in [`OnnxDataType::Unknown`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OnnxDataType {
    /// 32-bit IEEE float (code 1).
    Float,
    /// Unsigned 8-bit integer (code 2).
    Uint8,
    /// Signed 8-bit integer (code 3).
    Int8,
    /// Unsigned 16-bit integer (code 4).
    Uint16,
    /// Signed 16-bit integer (code 5).
    Int16,
    /// Signed 32-bit integer (code 6).
    Int32,
    /// Signed 64-bit integer (code 7).
    Int64,
    /// String elements (code 8); no fixed byte width.
    String,
    /// Boolean (code 9).
    Bool,
    /// 16-bit IEEE half-precision float (code 10).
    Float16,
    /// 64-bit IEEE float (code 11).
    Double,
    /// Unsigned 32-bit integer (code 12).
    Uint32,
    /// Unsigned 64-bit integer (code 13).
    Uint64,
    /// bfloat16: the upper 16 bits of an `f32` (code 16).
    BFloat16,
    /// Any other type code, kept as read.
    Unknown(i32),
}
impl OnnxDataType {
    /// Converts a raw ONNX `TensorProto.DataType` code into a variant.
    ///
    /// Codes without a dedicated variant (0 = undefined, 14/15 = complex,
    /// anything above 16, or negative values) are kept verbatim in `Unknown`.
    fn from_i32(code: i32) -> Self {
        match code {
            // Floating point.
            1 => Self::Float,
            10 => Self::Float16,
            11 => Self::Double,
            16 => Self::BFloat16,
            // Signed integers.
            3 => Self::Int8,
            5 => Self::Int16,
            6 => Self::Int32,
            7 => Self::Int64,
            // Unsigned integers.
            2 => Self::Uint8,
            4 => Self::Uint16,
            12 => Self::Uint32,
            13 => Self::Uint64,
            // Non-numeric element kinds.
            8 => Self::String,
            9 => Self::Bool,
            unrecognized => Self::Unknown(unrecognized),
        }
    }

    /// Width in bytes of a single element of this type.
    ///
    /// Returns `0` for `String` and `Unknown`, which have no fixed width.
    pub fn element_size(&self) -> usize {
        use std::mem::size_of;
        match *self {
            Self::Double | Self::Int64 | Self::Uint64 => size_of::<u64>(),
            Self::Float | Self::Int32 | Self::Uint32 => size_of::<u32>(),
            // Half-precision floats and bfloat16 are stored as 16-bit words.
            Self::Float16 | Self::BFloat16 | Self::Int16 | Self::Uint16 => size_of::<u16>(),
            Self::Uint8 | Self::Int8 | Self::Bool => size_of::<u8>(),
            Self::String | Self::Unknown(_) => 0,
        }
    }
}
/// A named tensor whose element data is kept as undecoded bytes.
///
/// Use [`OnnxTensor::to_f32`] to interpret `raw_data` according to
/// `data_type`.
#[derive(Clone, Debug)]
pub struct OnnxTensor {
    /// Tensor name.
    pub name: String,
    /// Dimension sizes.
    pub shape: Vec<usize>,
    /// Element type of the values stored in `raw_data`.
    pub data_type: OnnxDataType,
    /// Element bytes; `to_f32` decodes them as little-endian values.
    pub raw_data: Vec<u8>,
}
impl OnnxTensor {
    /// Decodes `raw_data` as little-endian elements of `data_type` and widens
    /// each element to `f32`.
    ///
    /// All fixed-width numeric types are supported:
    /// - floating point: `Float`, `Double`, `Float16`, `BFloat16`;
    /// - every signed and unsigned integer width;
    /// - `Bool`, which maps to `0.0` / `1.0`.
    ///
    /// `Double` values, 64-bit integers and 32-bit integers above 2^24 are
    /// rounded to the nearest `f32`.
    ///
    /// `String` and `Unknown` have no numeric interpretation and yield an
    /// empty vector. Trailing bytes that do not form a whole element are
    /// ignored.
    pub fn to_f32(&self) -> Vec<f32> {
        let raw = self.raw_data.as_slice();
        match self.data_type {
            OnnxDataType::Float => Self::decode_le(raw, f32::from_le_bytes),
            OnnxDataType::Double => {
                Self::decode_le(raw, |b: [u8; 8]| f64::from_le_bytes(b) as f32)
            }
            OnnxDataType::Float16 => {
                Self::decode_le(raw, |b: [u8; 2]| f16_to_f32(u16::from_le_bytes(b)))
            }
            // bfloat16 is the high half of an f32, so widening is a 16-bit shift.
            OnnxDataType::BFloat16 => Self::decode_le(raw, |b: [u8; 2]| {
                f32::from_bits(u32::from(u16::from_le_bytes(b)) << 16)
            }),
            OnnxDataType::Int8 => raw.iter().map(|&b| (b as i8) as f32).collect(),
            OnnxDataType::Uint8 => raw.iter().map(|&b| f32::from(b)).collect(),
            // Any non-zero byte is treated as `true`.
            OnnxDataType::Bool => raw.iter().map(|&b| if b != 0 { 1.0 } else { 0.0 }).collect(),
            OnnxDataType::Int16 => {
                Self::decode_le(raw, |b: [u8; 2]| f32::from(i16::from_le_bytes(b)))
            }
            OnnxDataType::Uint16 => {
                Self::decode_le(raw, |b: [u8; 2]| f32::from(u16::from_le_bytes(b)))
            }
            OnnxDataType::Int32 => Self::decode_le(raw, |b: [u8; 4]| i32::from_le_bytes(b) as f32),
            OnnxDataType::Uint32 => Self::decode_le(raw, |b: [u8; 4]| u32::from_le_bytes(b) as f32),
            OnnxDataType::Int64 => Self::decode_le(raw, |b: [u8; 8]| i64::from_le_bytes(b) as f32),
            OnnxDataType::Uint64 => Self::decode_le(raw, |b: [u8; 8]| u64::from_le_bytes(b) as f32),
            OnnxDataType::String | OnnxDataType::Unknown(_) => Vec::new(),
        }
    }

    /// Splits `raw` into whole `N`-byte chunks and converts each with `convert`.
    ///
    /// A trailing partial chunk (fewer than `N` bytes) is ignored.
    fn decode_le<const N: usize, F>(raw: &[u8], convert: F) -> Vec<f32>
    where
        F: Fn([u8; N]) -> f32,
    {
        raw.chunks_exact(N)
            .map(|chunk| {
                // `chunks_exact` guarantees `chunk.len() == N`, so this cannot panic.
                let mut bytes = [0u8; N];
                bytes.copy_from_slice(chunk);
                convert(bytes)
            })
            .collect()
    }
}
/// Converts an IEEE 754 half-precision (binary16) bit pattern to `f32`.
///
/// Every binary16 value is exactly representable in `f32`, so the conversion
/// is lossless:
/// - zeros keep their sign;
/// - subnormals are renormalized;
/// - infinities stay infinite;
/// - NaNs stay NaN, with the payload moved into the high mantissa bits.
fn f16_to_f32(bits: u16) -> f32 {
    let sign = u32::from(bits >> 15) << 31;
    let exponent = u32::from((bits >> 10) & 0x1F);
    let mantissa = u32::from(bits & 0x3FF);

    let (f32_exp, f32_mant) = match exponent {
        // Signed zero.
        0 if mantissa == 0 => (0, 0),
        0 => {
            // Subnormal: value = (mantissa / 2^10) * 2^-14.
            // Shift the highest set bit up to the implicit-one position
            // (bit 10, which has 21 leading zeros in a u32). Each shift
            // lowers the exponent by one.
            let shift = mantissa.leading_zeros() - 21;
            let normalized = mantissa << shift;
            // The unbiased exponent is -14 - shift. Rebias it for f32 (bias 127).
            // The old code used 15 here instead of 14, which halved every
            // subnormal.
            (127 - 14 - shift, (normalized & 0x3FF) << 13)
        }
        // Infinity (mantissa == 0) or NaN (mantissa != 0).
        0x1F => (0xFF, mantissa << 13),
        // Normal number: rebias the exponent from 15 to 127.
        _ => (exponent + 127 - 15, mantissa << 13),
    };
    f32::from_bits(sign | (f32_exp << 23) | f32_mant)
}
/// Model-level information accompanying the tensors of an ONNX model.
///
/// Field names mirror ONNX `ModelProto`. Every field defaults to zero or
/// empty, presumably meaning "absent" (TODO: confirm against the parser in
/// `reader.rs`).
#[derive(Clone, Debug, Default)]
pub struct OnnxMetadata {
    /// ONNX IR version.
    pub ir_version: i64,
    /// Name of the tool that produced the model.
    pub producer_name: String,
    /// Version string of the producing tool.
    pub producer_version: String,
    /// Model namespace/domain.
    pub domain: String,
    /// Model version number.
    pub model_version: i64,
    /// Free-form model documentation.
    pub doc_string: String,
    /// Operator-set imports; presumably `(domain, version)` pairs
    /// (verify against the parser).
    pub opset_versions: Vec<(String, i64)>,
}
/// In-memory view of an ONNX model: its tensors plus model-level metadata.
///
/// Both fields are private. Presumably they are populated and exposed by the
/// methods defined in `reader.rs` (TODO: confirm).
#[derive(Debug)]
pub struct OnnxReader {
    /// Tensors read from the model.
    tensors: Vec<OnnxTensor>,
    /// Model-level metadata.
    metadata: OnnxMetadata,
}
// The rest of this module lives in `reader.rs`, presumably the `OnnxReader`
// parsing and accessor methods. `include!` splices that file in verbatim, so
// its code shares this module's scope and can use private items such as
// `OnnxDataType::from_i32` and `OnnxReader`'s private fields.
include!("reader.rs");