/// Element type of an ONNX tensor.
///
/// The discriminants follow the ONNX `TensorProto.DataType` numbering
/// (e.g. `Float = 1`, `Double = 11`). Codes that have no variant here
/// (such as `0` or `10`) are not representable; `from_i32` returns `None`
/// for them.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OnnxDataType {
    Float = 1,
    Uint8 = 2,
    Int8 = 3,
    Uint16 = 4,
    Int16 = 5,
    Int32 = 6,
    Int64 = 7,
    String = 8,
    Bool = 9,
    Double = 11,
    Uint32 = 12,
    Uint64 = 13,
}
impl OnnxDataType {
    /// Every variant, in discriminant order. This is the lookup table used by
    /// `from_i32`; a new variant must be appended here as well.
    const ALL: [Self; 12] = [
        Self::Float,
        Self::Uint8,
        Self::Int8,
        Self::Uint16,
        Self::Int16,
        Self::Int32,
        Self::Int64,
        Self::String,
        Self::Bool,
        Self::Double,
        Self::Uint32,
        Self::Uint64,
    ];

    /// Maps an integer type code to its variant.
    ///
    /// Returns `None` when no variant has `v` as its discriminant.
    pub(crate) fn from_i32(v: i32) -> Option<Self> {
        // Discriminants are unique, so the first match is the only match.
        Self::ALL
            .iter()
            .copied()
            .find(|&candidate| candidate as i32 == v)
    }
}
/// A tensor value: name, element type, shape and element storage.
///
/// Elements may be held in one of the typed vectors (`float_data`,
/// `double_data`, `int32_data`, `int64_data`) or packed little-endian in
/// `raw_data`. [`OnnxTensor::to_f32_vec`] reads them regardless of which
/// storage is used.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct OnnxTensor {
    /// Tensor name.
    pub name: String,
    /// Element type; also selects how `raw_data` is decoded.
    pub data_type: OnnxDataType,
    /// Shape, one entry per axis.
    pub dims: Vec<i64>,
    /// Typed `f32` element storage.
    pub float_data: Vec<f32>,
    /// Typed `f64` element storage.
    pub double_data: Vec<f64>,
    /// Typed `i32` element storage.
    pub int32_data: Vec<i32>,
    /// Typed `i64` element storage.
    pub int64_data: Vec<i64>,
    /// Packed little-endian element bytes, interpreted according to `data_type`.
    pub raw_data: Vec<u8>,
}
impl OnnxTensor {
    /// Creates an unnamed `Float` tensor with no dims and no data.
    #[must_use]
    pub fn new() -> Self {
        Self {
            name: String::new(),
            data_type: OnnxDataType::Float,
            dims: Vec::new(),
            float_data: Vec::new(),
            double_data: Vec::new(),
            int32_data: Vec::new(),
            int64_data: Vec::new(),
            raw_data: Vec::new(),
        }
    }

    /// Returns the tensor's elements converted to `f32`.
    ///
    /// Storage is consulted in this order; the first non-empty source wins:
    /// 1. `float_data`, returned as-is;
    /// 2. `raw_data`, decoded as little-endian elements of `data_type`
    ///    (every type except `String`, which falls through to the next step);
    /// 3. `double_data`;
    /// 4. `int64_data`;
    /// 5. `int32_data`.
    ///
    /// Returns an empty vector when no source holds data. Trailing bytes in
    /// `raw_data` that do not form a whole element are ignored. Conversions
    /// from `f64` and from 32/64-bit integers may lose precision.
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    pub fn to_f32_vec(&self) -> Vec<f32> {
        if !self.float_data.is_empty() {
            return self.float_data.clone();
        }
        if !self.raw_data.is_empty() {
            if let Some(values) = self.raw_data_as_f32() {
                return values;
            }
        }
        if !self.double_data.is_empty() {
            return self.double_data.iter().map(|&v| v as f32).collect();
        }
        if !self.int64_data.is_empty() {
            return self.int64_data.iter().map(|&v| v as f32).collect();
        }
        if !self.int32_data.is_empty() {
            return self.int32_data.iter().map(|&v| v as f32).collect();
        }
        Vec::new()
    }

    /// Decodes `raw_data` as little-endian elements of `data_type`.
    ///
    /// 8-bit types and `Bool` use one byte per element (a nonzero `Bool` byte
    /// becomes `1.0`). Returns `None` for `String`, which has no fixed-width
    /// numeric encoding, so the caller can fall back to typed storage.
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    fn raw_data_as_f32(&self) -> Option<Vec<f32>> {
        let raw = self.raw_data.as_slice();
        // Exhaustive on purpose: adding a variant must force a decision here.
        let values = match self.data_type {
            OnnxDataType::Float => {
                Self::decode_le_chunks(raw, 4, |c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            }
            OnnxDataType::Double => Self::decode_le_chunks(raw, 8, |c| {
                f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
            }),
            OnnxDataType::Uint8 => Self::decode_le_chunks(raw, 1, |c| f32::from(c[0])),
            OnnxDataType::Int8 => {
                Self::decode_le_chunks(raw, 1, |c| f32::from(i8::from_le_bytes([c[0]])))
            }
            OnnxDataType::Bool => {
                Self::decode_le_chunks(raw, 1, |c| if c[0] == 0 { 0.0 } else { 1.0 })
            }
            OnnxDataType::Uint16 => {
                Self::decode_le_chunks(raw, 2, |c| f32::from(u16::from_le_bytes([c[0], c[1]])))
            }
            OnnxDataType::Int16 => {
                Self::decode_le_chunks(raw, 2, |c| f32::from(i16::from_le_bytes([c[0], c[1]])))
            }
            OnnxDataType::Int32 => Self::decode_le_chunks(raw, 4, |c| {
                i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32
            }),
            OnnxDataType::Uint32 => Self::decode_le_chunks(raw, 4, |c| {
                u32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32
            }),
            OnnxDataType::Int64 => Self::decode_le_chunks(raw, 8, |c| {
                i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
            }),
            OnnxDataType::Uint64 => Self::decode_le_chunks(raw, 8, |c| {
                u64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
            }),
            OnnxDataType::String => return None,
        };
        Some(values)
    }

    /// Splits `raw` into `width`-byte chunks and maps each through `decode`;
    /// a trailing partial chunk is dropped.
    fn decode_le_chunks(raw: &[u8], width: usize, decode: impl Fn(&[u8]) -> f32) -> Vec<f32> {
        raw.chunks_exact(width).map(decode).collect()
    }

    /// Returns `dims` as `usize` values.
    ///
    /// Negative entries, which a well-formed tensor does not contain, are
    /// clamped to `0` instead of wrapping around to a huge `usize`.
    // After `max(0)` the value is non-negative; truncation is only possible on
    // targets where `usize` is narrower than `i64`, for dims no real tensor has.
    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
    pub fn dims_usize(&self) -> Vec<usize> {
        self.dims.iter().map(|&d| d.max(0) as usize).collect()
    }
}
impl Default for OnnxTensor {
    /// Same as [`OnnxTensor::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Value carried by an `OnnxAttribute`.
///
/// `Float`, `Int`, `String`, `Tensor` and `Graph` each hold a single value;
/// `Floats`, `Ints` and `Strings` hold lists. `Graph` nests a whole
/// `OnnxGraph`, so graphs can contain further graphs through attributes.
/// `OnnxNode` has typed accessors only for `Int`, `Float` and `Ints`; other
/// variants are reached by matching on the result of `OnnxNode::get_attr`.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub enum OnnxAttributeValue {
Float(f32),
Int(i64),
String(String),
Tensor(OnnxTensor),
Graph(OnnxGraph),
Floats(Vec<f32>),
Ints(Vec<i64>),
Strings(Vec<String>),
}
/// A named attribute attached to an `OnnxNode`.
///
/// `OnnxNode::get_attr` returns the `value` of the first attribute whose
/// `name` matches, so a later attribute with a duplicate name is never
/// returned by it.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct OnnxAttribute {
pub name: String,
pub value: OnnxAttributeValue,
}
/// One operator node in a graph: its operator type, input/output names,
/// optional node name, and named attributes.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct OnnxNode {
    /// Operator type (for example `"Conv"`).
    pub op_type: String,
    /// Input names.
    pub inputs: Vec<String>,
    /// Output names.
    pub outputs: Vec<String>,
    /// Node name; empty when not set.
    pub name: String,
    /// Attributes, searched in order by the `get_*` helpers.
    pub attributes: Vec<OnnxAttribute>,
}
impl OnnxNode {
    /// Builds a node of type `op_type` with no inputs, outputs, name or
    /// attributes.
    pub fn new(op_type: &str) -> Self {
        Self {
            op_type: String::from(op_type),
            inputs: Vec::default(),
            outputs: Vec::default(),
            name: String::default(),
            attributes: Vec::default(),
        }
    }

    /// Returns the value of the first attribute called `name`, if any.
    pub fn get_attr(&self, name: &str) -> Option<&OnnxAttributeValue> {
        self.attributes.iter().find_map(|attribute| {
            if attribute.name == name {
                Some(&attribute.value)
            } else {
                None
            }
        })
    }

    /// Returns the `Int` attribute `name`, or `default` when it is missing or
    /// the first attribute with that name is not an `Int`.
    pub fn get_int_attr(&self, name: &str, default: i64) -> i64 {
        if let Some(&OnnxAttributeValue::Int(value)) = self.get_attr(name) {
            value
        } else {
            default
        }
    }

    /// Returns the `Float` attribute `name`, or `default` when it is missing
    /// or the first attribute with that name is not a `Float`.
    pub fn get_float_attr(&self, name: &str, default: f32) -> f32 {
        if let Some(&OnnxAttributeValue::Float(value)) = self.get_attr(name) {
            value
        } else {
            default
        }
    }

    /// Returns a copy of the `Ints` attribute `name`, or an empty vector when
    /// it is missing or the first attribute with that name is not `Ints`.
    pub fn get_ints_attr(&self, name: &str) -> Vec<i64> {
        if let Some(OnnxAttributeValue::Ints(values)) = self.get_attr(name) {
            values.to_vec()
        } else {
            Vec::new()
        }
    }
}
/// Name, element type and shape of a graph input or output (the element
/// type of `OnnxGraph::inputs` and `OnnxGraph::outputs`).
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct OnnxValueInfo {
pub name: String,
pub data_type: OnnxDataType,
// NOTE(review): how unknown or symbolic dimensions are encoded in `shape`
// (e.g. -1 or 0) is decided by whatever fills this struct and is not
// visible here — confirm against the parser before relying on it.
pub shape: Vec<i64>,
}
/// A computation graph: its nodes plus the initializer tensors, inputs and
/// outputs.
// `Default` is derived: the all-empty value it produces is exactly what the
// hand-written constructor used to build.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone, Default)]
pub struct OnnxGraph {
    /// Graph name; empty when not set.
    pub name: String,
    /// Operator nodes.
    pub nodes: Vec<OnnxNode>,
    /// Initializer tensors.
    pub initializers: Vec<OnnxTensor>,
    /// Graph inputs.
    pub inputs: Vec<OnnxValueInfo>,
    /// Graph outputs.
    pub outputs: Vec<OnnxValueInfo>,
}
impl OnnxGraph {
    /// Creates an unnamed graph with no nodes, initializers, inputs or
    /// outputs. Equivalent to [`OnnxGraph::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
/// One entry of `OnnxModel::opset_imports`: an operator-set `domain` and its
/// `version`.
// NOTE(review): presumably an empty `domain` denotes the default ONNX
// operator domain, as in the ONNX spec — confirm against the parser.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct OnnxOpsetImport {
pub domain: String,
pub version: i64,
}
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct OnnxModel {
pub ir_version: i64,
pub opset_imports: Vec<OnnxOpsetImport>,
pub graph: OnnxGraph,
pub producer_name: String,
pub model_version: i64,
}
impl OnnxModel {
pub fn new() -> Self {
Self {
ir_version: 0,
opset_imports: Vec::new(),
graph: OnnxGraph::new(),
producer_name: String::new(),
model_version: 0,
}
}
}
impl Default for OnnxModel {
fn default() -> Self {
Self::new()
}
}