use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// ONNX IR version stamped into `OnnxModel::ir_version` by `OnnxModel::new`.
pub const ONNX_IR_VERSION: i64 = 9;
/// Operator-set version used by `OnnxOpsetImport::default` (and hence by `OnnxModel::new`).
pub const ONNX_OPSET_VERSION: i64 = 20;
/// Producer name stamped into `OnnxModel::producer_name` by `OnnxModel::new`.
pub const ONNX_PRODUCER_NAME: &str = "scirs2-autograd";
/// Producer version: this crate's Cargo package version, captured at compile time.
pub const ONNX_PRODUCER_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Tensor element types supported by the exporter.
///
/// Each discriminant is the numeric code returned by `code()` and accepted by
/// `from_code()`; the numbering follows ONNX `TensorProto.DataType`. Codes that are
/// not listed (e.g. 0, 4, 8) have no variant.
///
/// NOTE: the serde derives (de)serialize by variant *name* (e.g. `"Float32"`), not by
/// numeric code. Variant order also determines serde's variant index, so do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u32)]
pub enum OnnxDataType {
/// 32-bit float (4 bytes per element).
Float32 = 1,
/// Unsigned 8-bit integer (1 byte per element).
Uint8 = 2,
/// Signed 8-bit integer (1 byte per element).
Int8 = 3,
/// Signed 16-bit integer (2 bytes per element).
Int16 = 5,
/// Signed 32-bit integer (4 bytes per element).
Int32 = 6,
/// Signed 64-bit integer (8 bytes per element).
Int64 = 7,
/// Boolean (1 byte per element).
Bool = 9,
/// 16-bit float (2 bytes per element).
Float16 = 10,
/// 64-bit float (8 bytes per element).
Float64 = 11,
/// Unsigned 32-bit integer (4 bytes per element).
Uint32 = 12,
/// Unsigned 64-bit integer (8 bytes per element).
Uint64 = 13,
}
impl OnnxDataType {
    /// Maps an ONNX numeric type code back to its variant.
    ///
    /// Returns `None` for any code that has no variant in this enum.
    pub fn from_code(code: u32) -> Option<Self> {
        // Every variant, so the reverse lookup cannot drift from the discriminants.
        const KNOWN: [OnnxDataType; 11] = [
            OnnxDataType::Float32,
            OnnxDataType::Uint8,
            OnnxDataType::Int8,
            OnnxDataType::Int16,
            OnnxDataType::Int32,
            OnnxDataType::Int64,
            OnnxDataType::Bool,
            OnnxDataType::Float16,
            OnnxDataType::Float64,
            OnnxDataType::Uint32,
            OnnxDataType::Uint64,
        ];
        KNOWN.iter().copied().find(|candidate| candidate.code() == code)
    }

    /// The numeric type code of this variant (its `repr(u32)` discriminant).
    pub fn code(&self) -> u32 {
        let discriminant = *self as u32;
        discriminant
    }

    /// Size in bytes of a single element of this type.
    pub fn element_size(&self) -> usize {
        match *self {
            Self::Float64 | Self::Int64 | Self::Uint64 => 8,
            Self::Float32 | Self::Int32 | Self::Uint32 => 4,
            Self::Float16 | Self::Int16 => 2,
            Self::Bool | Self::Uint8 | Self::Int8 => 1,
        }
    }

    /// Lower-case, human-readable type name (also used by `Display`).
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Float16 => "float16",
            Self::Float32 => "float32",
            Self::Float64 => "float64",
            Self::Int8 => "int8",
            Self::Int16 => "int16",
            Self::Int32 => "int32",
            Self::Int64 => "int64",
            Self::Uint8 => "uint8",
            Self::Uint32 => "uint32",
            Self::Uint64 => "uint64",
            Self::Bool => "bool",
        }
    }
}
impl std::fmt::Display for OnnxDataType {
    /// Writes the lower-case type name from `OnnxDataType::name`.
    ///
    /// Uses `Formatter::pad` so that width, fill, alignment and precision in the
    /// format spec are honoured (e.g. `{:>8}` for tabular output), as they are
    /// for `str`. The previous `write!(f, "{}", …)` silently discarded them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
/// Value of a node attribute.
///
/// Serialized adjacently tagged: the variant name goes under `"type"` and the payload
/// under `"value"`, e.g. `{"type": "Int", "value": 3}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value")]
pub enum OnnxAttribute {
/// Single integer (read with `as_int`).
Int(i64),
/// Single float, stored as `f64` (read with `as_float`).
Float(f64),
/// Single string (read with `as_string`).
String(String),
/// Embedded tensor value (no typed accessor is provided for it).
Tensor(OnnxTensor),
/// List of integers (read with `as_ints`).
Ints(Vec<i64>),
/// List of floats, stored as `f64` (read with `as_floats`).
Floats(Vec<f64>),
/// List of strings (read with `as_strings`).
Strings(Vec<String>),
}
impl OnnxAttribute {
pub fn as_int(&self) -> Option<i64> {
match self {
OnnxAttribute::Int(v) => Some(*v),
_ => None,
}
}
pub fn as_float(&self) -> Option<f64> {
match self {
OnnxAttribute::Float(v) => Some(*v),
_ => None,
}
}
pub fn as_string(&self) -> Option<&str> {
match self {
OnnxAttribute::String(v) => Some(v),
_ => None,
}
}
pub fn as_ints(&self) -> Option<&[i64]> {
match self {
OnnxAttribute::Ints(v) => Some(v),
_ => None,
}
}
pub fn as_floats(&self) -> Option<&[f64]> {
match self {
OnnxAttribute::Floats(v) => Some(v),
_ => None,
}
}
pub fn as_strings(&self) -> Option<&[String]> {
match self {
OnnxAttribute::Strings(v) => Some(v),
_ => None,
}
}
}
/// A named tensor: either a shape/type description only (built by `OnnxTensor::spec`)
/// or a constant whose values live in one of the typed `*_data` vectors.
///
/// Empty data vectors and a missing `doc_string` are omitted when serializing and
/// default to empty/`None` when deserializing.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OnnxTensor {
/// Tensor (value) name.
pub name: String,
/// Element type.
pub data_type: OnnxDataType,
/// Shape. A negative entry makes `num_elements` return `None`.
pub dims: Vec<i64>,
/// Values for `Float32` tensors (filled by `from_f32`).
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub float_data: Vec<f32>,
/// Values for `Float64` tensors (filled by `from_f64`).
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub double_data: Vec<f64>,
/// Values for `Int32` tensors (filled by `from_i32`).
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub int32_data: Vec<i32>,
/// Values for `Int64` tensors (filled by `from_i64`).
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub int64_data: Vec<i64>,
/// Raw byte payload. None of the `from_*` constructors shown here fill it.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub raw_data: Vec<u8>,
/// Optional free-form documentation.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub doc_string: Option<String>,
}
impl OnnxTensor {
    /// Creates a data-less tensor description (name, shape and element type only).
    ///
    /// All data vectors are empty and `doc_string` is `None`, so `has_data()` is false.
    /// A negative entry in `dims` is allowed; `num_elements` then returns `None`.
    pub fn spec(name: &str, dims: &[i64], data_type: OnnxDataType) -> Self {
        OnnxTensor {
            name: name.to_string(),
            data_type,
            dims: dims.to_vec(),
            float_data: Vec::new(),
            double_data: Vec::new(),
            int32_data: Vec::new(),
            int64_data: Vec::new(),
            raw_data: Vec::new(),
            doc_string: None,
        }
    }

    /// Creates a `Float32` tensor whose values are stored in `float_data`.
    ///
    /// `data.len()` is not checked against `dims`.
    pub fn from_f32(name: &str, dims: &[i64], data: Vec<f32>) -> Self {
        OnnxTensor {
            float_data: data,
            ..Self::spec(name, dims, OnnxDataType::Float32)
        }
    }

    /// Creates a `Float64` tensor whose values are stored in `double_data`.
    ///
    /// `data.len()` is not checked against `dims`.
    pub fn from_f64(name: &str, dims: &[i64], data: Vec<f64>) -> Self {
        OnnxTensor {
            double_data: data,
            ..Self::spec(name, dims, OnnxDataType::Float64)
        }
    }

    /// Creates an `Int32` tensor whose values are stored in `int32_data`.
    ///
    /// `data.len()` is not checked against `dims`.
    pub fn from_i32(name: &str, dims: &[i64], data: Vec<i32>) -> Self {
        OnnxTensor {
            int32_data: data,
            ..Self::spec(name, dims, OnnxDataType::Int32)
        }
    }

    /// Creates an `Int64` tensor whose values are stored in `int64_data`.
    ///
    /// `data.len()` is not checked against `dims`.
    pub fn from_i64(name: &str, dims: &[i64], data: Vec<i64>) -> Self {
        OnnxTensor {
            int64_data: data,
            ..Self::spec(name, dims, OnnxDataType::Int64)
        }
    }

    /// Number of elements implied by `dims`: their product, or 1 for empty `dims` (a scalar).
    ///
    /// Returns `None` if any dimension is negative, does not fit in `usize`, or the
    /// product overflows `usize`.
    pub fn num_elements(&self) -> Option<usize> {
        self.dims.iter().try_fold(1usize, |count, &dim| {
            // A checked conversion rejects negative dims and, on 32-bit targets, dims above
            // `usize::MAX`, which the former `dim as usize` cast silently truncated.
            // Called via its full path so it works without `TryFrom` in the prelude.
            let dim: usize = std::convert::TryFrom::try_from(dim).ok()?;
            count.checked_mul(dim)
        })
    }

    /// Whether any data vector (typed or raw) is non-empty.
    pub fn has_data(&self) -> bool {
        !self.float_data.is_empty()
            || !self.double_data.is_empty()
            || !self.int32_data.is_empty()
            || !self.int64_data.is_empty()
            || !self.raw_data.is_empty()
    }
}
/// One operator invocation in a graph.
///
/// `inputs` and `outputs` hold value names. Empty `attributes`, a missing `doc_string`
/// and an empty `domain` are omitted when serializing.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OnnxNode {
/// Operator type, e.g. the string passed to `OnnxNode::new`.
pub op_type: String,
/// Node name; `OnnxGraph::get_node` looks nodes up by it.
pub name: String,
/// Names of the values consumed by this node.
pub inputs: Vec<String>,
/// Names of the values produced by this node.
pub outputs: Vec<String>,
/// Attributes keyed by attribute name.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub attributes: HashMap<String, OnnxAttribute>,
/// Optional free-form documentation.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub doc_string: Option<String>,
/// Operator domain; `OnnxNode::new` leaves it empty.
#[serde(default, skip_serializing_if = "String::is_empty")]
pub domain: String,
}
impl OnnxNode {
pub fn new(op_type: &str, name: &str, inputs: Vec<String>, outputs: Vec<String>) -> Self {
OnnxNode {
op_type: op_type.to_string(),
name: name.to_string(),
inputs,
outputs,
attributes: HashMap::new(),
doc_string: None,
domain: String::new(),
}
}
pub fn with_attribute(mut self, key: &str, value: OnnxAttribute) -> Self {
self.attributes.insert(key.to_string(), value);
self
}
pub fn get_attribute(&self, key: &str) -> Option<&OnnxAttribute> {
self.attributes.get(key)
}
}
/// An (operator domain, operator-set version) pair declared by a model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OnnxOpsetImport {
/// Operator domain; defaults to the empty string when absent during deserialization.
#[serde(default)]
pub domain: String,
/// Operator-set version for `domain`.
pub version: i64,
}
impl Default for OnnxOpsetImport {
    /// The empty (default) operator domain at `ONNX_OPSET_VERSION`.
    fn default() -> Self {
        Self {
            version: ONNX_OPSET_VERSION,
            domain: String::default(),
        }
    }
}
/// A computation graph: nodes plus its input/output descriptions and constant initializers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OnnxGraph {
/// Graph name.
pub name: String,
/// Nodes in stored order (no ordering is enforced here).
pub nodes: Vec<OnnxNode>,
/// Graph input descriptions.
pub inputs: Vec<OnnxTensor>,
/// Graph output descriptions.
pub outputs: Vec<OnnxTensor>,
/// Constant tensors; their element counts are summed by `total_parameters`.
/// Omitted from serde output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub initializers: Vec<OnnxTensor>,
/// Optional free-form documentation.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub doc_string: Option<String>,
}
impl OnnxGraph {
pub fn new(name: &str) -> Self {
OnnxGraph {
name: name.to_string(),
nodes: Vec::new(),
inputs: Vec::new(),
outputs: Vec::new(),
initializers: Vec::new(),
doc_string: None,
}
}
pub fn get_node(&self, name: &str) -> Option<&OnnxNode> {
self.nodes.iter().find(|n| n.name == name)
}
pub fn get_initializer(&self, name: &str) -> Option<&OnnxTensor> {
self.initializers.iter().find(|t| t.name == name)
}
pub fn total_parameters(&self) -> usize {
self.initializers
.iter()
.filter_map(|t| t.num_elements())
.sum()
}
}
/// Top-level model: version/producer metadata wrapped around a single graph.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OnnxModel {
/// IR version; `OnnxModel::new` sets `ONNX_IR_VERSION`.
pub ir_version: i64,
/// Operator sets the model declares; `OnnxModel::new` adds only `OnnxOpsetImport::default()`.
pub opset_imports: Vec<OnnxOpsetImport>,
/// Name of the tool that produced the model.
pub producer_name: String,
/// Version of the tool that produced the model.
pub producer_version: String,
/// Model domain; defaults to empty when absent during deserialization.
#[serde(default)]
pub domain: String,
/// Model version. NOTE: `OnnxModel::new` sets 1, but a deserialized model missing
/// this field gets 0 (the `#[serde(default)]` for `i64`).
#[serde(default)]
pub model_version: i64,
/// Optional free-form documentation.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub doc_string: Option<String>,
/// The model's computation graph.
pub graph: OnnxGraph,
/// Arbitrary string key/value metadata; omitted from serde output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub metadata: HashMap<String, String>,
}
impl OnnxModel {
    /// Wraps `graph` in a model stamped with this crate's IR version, default opset
    /// and producer information, `model_version` 1 and no metadata.
    pub fn new(graph: OnnxGraph) -> Self {
        let opset_imports = vec![OnnxOpsetImport::default()];
        Self {
            graph,
            ir_version: ONNX_IR_VERSION,
            opset_imports,
            producer_name: String::from(ONNX_PRODUCER_NAME),
            producer_version: String::from(ONNX_PRODUCER_VERSION),
            domain: String::new(),
            model_version: 1,
            doc_string: None,
            metadata: HashMap::new(),
        }
    }

    /// Builder-style setter for one metadata entry, overwriting any previous value for `key`.
    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
        let _previous = self.metadata.insert(key.to_owned(), value.to_owned());
        self
    }

    /// Builder-style setter for the model-level documentation string.
    pub fn with_doc_string(mut self, doc: &str) -> Self {
        let _previous = self.doc_string.replace(doc.to_owned());
        self
    }
}
/// Typed view of a convolution node's attributes (see `to_attributes` / `from_attributes`).
///
/// NOTE(review): the serde defaults for `strides` and `dilations` are always the 2-D
/// value `[1, 1]`, whatever the length of `kernel_shape`. A deserialized 1-D or 3-D
/// convolution that omits those fields therefore gets wrong-length lists. Confirm that
/// callers always serialize them explicitly.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ConvAttributes {
/// Kernel size per spatial axis (required when deserializing).
pub kernel_shape: Vec<i64>,
/// Stride per spatial axis; serde default `[1, 1]`.
#[serde(default = "ConvAttributes::default_strides")]
pub strides: Vec<i64>,
/// Padding values; serde default is an empty list.
#[serde(default)]
pub pads: Vec<i64>,
/// Dilation per spatial axis; serde default `[1, 1]`.
#[serde(default = "ConvAttributes::default_dilations")]
pub dilations: Vec<i64>,
/// Group count; serde default 1.
#[serde(default = "ConvAttributes::default_group")]
pub group: i64,
/// Auto-padding mode string; serde default `"NOTSET"`.
#[serde(default = "ConvAttributes::default_auto_pad")]
pub auto_pad: String,
}
impl ConvAttributes {
    /// Serde default for a missing `strides` field: unit strides, 2-D.
    ///
    /// Serde default functions cannot see `kernel_shape`, so this stays 2-D.
    /// `new` and `from_attributes` size their defaults from the kernel rank instead.
    fn default_strides() -> Vec<i64> {
        vec![1, 1]
    }

    /// Serde default for a missing `dilations` field: no dilation, 2-D (see `default_strides`).
    fn default_dilations() -> Vec<i64> {
        vec![1, 1]
    }

    /// Default `group`: 1 (an ordinary, non-grouped convolution).
    fn default_group() -> i64 {
        1
    }

    /// Default `auto_pad`: `"NOTSET"`, i.e. padding is taken explicitly from `pads`.
    fn default_auto_pad() -> String {
        "NOTSET".to_string()
    }

    /// Number of spatial axes implied by `kernel_shape`.
    ///
    /// Falls back to 2 for an empty `kernel_shape`, the rank that was previously hard-coded.
    fn spatial_rank(kernel_shape: &[i64]) -> usize {
        if kernel_shape.is_empty() {
            2
        } else {
            kernel_shape.len()
        }
    }

    /// Creates attributes for a convolution with the given kernel shape.
    ///
    /// Strides and dilations default to 1, and pads to 0, on every spatial axis. The axis
    /// count comes from `kernel_shape`, so 1-D and 3-D kernels get correctly sized lists;
    /// `pads` holds a begin value and an end value per axis. `group` is 1 and `auto_pad`
    /// is `"NOTSET"`.
    pub fn new(kernel_shape: Vec<i64>) -> Self {
        let rank = Self::spatial_rank(&kernel_shape);
        ConvAttributes {
            kernel_shape,
            strides: vec![1; rank],
            pads: vec![0; 2 * rank],
            dilations: vec![1; rank],
            group: Self::default_group(),
            auto_pad: Self::default_auto_pad(),
        }
    }

    /// Builder-style setter for `strides`.
    pub fn with_strides(mut self, strides: Vec<i64>) -> Self {
        self.strides = strides;
        self
    }

    /// Builder-style setter for `pads`.
    pub fn with_pads(mut self, pads: Vec<i64>) -> Self {
        self.pads = pads;
        self
    }

    /// Builder-style setter for `dilations`.
    pub fn with_dilations(mut self, dilations: Vec<i64>) -> Self {
        self.dilations = dilations;
        self
    }

    /// Builder-style setter for `group`.
    pub fn with_group(mut self, group: i64) -> Self {
        self.group = group;
        self
    }

    /// Converts to a node attribute map.
    ///
    /// `kernel_shape`, `strides`, `dilations` and `group` are always emitted.
    /// - If `auto_pad` is `"NOTSET"`, `pads` is emitted when non-empty and `auto_pad` is
    ///   omitted.
    /// - Otherwise `auto_pad` is emitted and `pads` is omitted, because the ONNX spec
    ///   forbids using both together.
    pub fn to_attributes(&self) -> HashMap<String, OnnxAttribute> {
        let mut attrs = HashMap::new();
        attrs.insert(
            "kernel_shape".to_string(),
            OnnxAttribute::Ints(self.kernel_shape.clone()),
        );
        attrs.insert(
            "strides".to_string(),
            OnnxAttribute::Ints(self.strides.clone()),
        );
        attrs.insert(
            "dilations".to_string(),
            OnnxAttribute::Ints(self.dilations.clone()),
        );
        attrs.insert("group".to_string(), OnnxAttribute::Int(self.group));
        if self.auto_pad == "NOTSET" {
            // ONNX shape inference rejects `pads: []` as the wrong size. Omitting the
            // attribute instead means zero padding.
            if !self.pads.is_empty() {
                attrs.insert("pads".to_string(), OnnxAttribute::Ints(self.pads.clone()));
            }
        } else {
            // Emitting `pads` alongside a SAME_*/VALID mode is invalid, and consumers
            // disagree on which one wins.
            attrs.insert(
                "auto_pad".to_string(),
                OnnxAttribute::String(self.auto_pad.clone()),
            );
        }
        attrs
    }

    /// Parses a node attribute map.
    ///
    /// Returns `None` if `kernel_shape` is missing or is not an integer list. Missing or
    /// wrongly typed optional attributes fall back to:
    /// - `strides` and `dilations`: 1 per spatial axis of `kernel_shape`;
    /// - `pads`: empty;
    /// - `group`: 1;
    /// - `auto_pad`: `"NOTSET"`.
    pub fn from_attributes(attrs: &HashMap<String, OnnxAttribute>) -> Option<Self> {
        let kernel_shape = attrs.get("kernel_shape")?.as_ints()?.to_vec();
        let rank = Self::spatial_rank(&kernel_shape);
        let ints = |key: &str| -> Option<Vec<i64>> {
            attrs.get(key).and_then(|a| a.as_ints()).map(|s| s.to_vec())
        };
        let strides = ints("strides").unwrap_or_else(|| vec![1; rank]);
        let pads = ints("pads").unwrap_or_default();
        let dilations = ints("dilations").unwrap_or_else(|| vec![1; rank]);
        let group = attrs
            .get("group")
            .and_then(|a| a.as_int())
            .unwrap_or_else(Self::default_group);
        let auto_pad = attrs
            .get("auto_pad")
            .and_then(|a| a.as_string())
            .map(|s| s.to_string())
            .unwrap_or_else(Self::default_auto_pad);
        Some(ConvAttributes {
            kernel_shape,
            strides,
            pads,
            dilations,
            group,
            auto_pad,
        })
    }
}
/// Typed view of a pooling node's attributes (see `to_attributes`).
///
/// NOTE(review): the serde default for `strides` is always the 2-D value `[1, 1]`,
/// whatever the length of `kernel_shape`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PoolAttributes {
/// Pooling window size per spatial axis (required when deserializing).
pub kernel_shape: Vec<i64>,
/// Stride per spatial axis; serde default `[1, 1]`.
#[serde(default = "PoolAttributes::default_strides")]
pub strides: Vec<i64>,
/// Padding values; serde default is an empty list.
#[serde(default)]
pub pads: Vec<i64>,
/// Integer flag; `to_attributes` emits it only when non-zero. Serde default 0.
#[serde(default)]
pub ceil_mode: i64,
}
impl PoolAttributes {
    /// Serde default for a missing `strides` field: unit strides, 2-D.
    ///
    /// Serde default functions cannot see `kernel_shape`, so this stays 2-D.
    /// `new` sizes its defaults from the kernel rank instead.
    fn default_strides() -> Vec<i64> {
        vec![1, 1]
    }

    /// Number of spatial axes implied by `kernel_shape`.
    ///
    /// Falls back to 2 for an empty `kernel_shape`, the rank that was previously hard-coded.
    fn spatial_rank(kernel_shape: &[i64]) -> usize {
        if kernel_shape.is_empty() {
            2
        } else {
            kernel_shape.len()
        }
    }

    /// Creates attributes for a pooling window of the given shape.
    ///
    /// Strides default to 1 and pads to 0 on every spatial axis. The axis count comes from
    /// `kernel_shape`, so 1-D and 3-D windows get correctly sized lists; `pads` holds a
    /// begin value and an end value per axis. `ceil_mode` is 0.
    pub fn new(kernel_shape: Vec<i64>) -> Self {
        let rank = Self::spatial_rank(&kernel_shape);
        PoolAttributes {
            kernel_shape,
            strides: vec![1; rank],
            pads: vec![0; 2 * rank],
            ceil_mode: 0,
        }
    }

    /// Builder-style setter for `strides`.
    pub fn with_strides(mut self, strides: Vec<i64>) -> Self {
        self.strides = strides;
        self
    }

    /// Builder-style setter for `pads`.
    pub fn with_pads(mut self, pads: Vec<i64>) -> Self {
        self.pads = pads;
        self
    }

    /// Converts to a node attribute map.
    ///
    /// `kernel_shape` and `strides` are always emitted. `pads` is emitted only when
    /// non-empty, and `ceil_mode` only when non-zero.
    pub fn to_attributes(&self) -> HashMap<String, OnnxAttribute> {
        let mut attrs = HashMap::new();
        attrs.insert(
            "kernel_shape".to_string(),
            OnnxAttribute::Ints(self.kernel_shape.clone()),
        );
        attrs.insert(
            "strides".to_string(),
            OnnxAttribute::Ints(self.strides.clone()),
        );
        // `pads: []` (e.g. from the serde default) fails ONNX's size check. Omitting the
        // attribute means zero padding.
        if !self.pads.is_empty() {
            attrs.insert("pads".to_string(), OnnxAttribute::Ints(self.pads.clone()));
        }
        if self.ceil_mode != 0 {
            attrs.insert("ceil_mode".to_string(), OnnxAttribute::Int(self.ceil_mode));
        }
        attrs
    }
}
/// Typed view of a general-matrix-multiply node's attributes (see `to_attributes`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GemmAttributes {
/// Scalar multiplier; serde default 1.0.
#[serde(default = "GemmAttributes::default_alpha")]
pub alpha: f64,
/// Scalar multiplier; serde default 1.0.
#[serde(default = "GemmAttributes::default_beta")]
pub beta: f64,
/// Integer flag, written as `transA` by `to_attributes` when non-zero; serde default 0.
#[serde(default)]
pub trans_a: i64,
/// Integer flag, written as `transB` by `to_attributes` when non-zero; serde default 0.
#[serde(default)]
pub trans_b: i64,
}
impl GemmAttributes {
    /// Serde default for `alpha`.
    fn default_alpha() -> f64 {
        1.0
    }

    /// Serde default for `beta`.
    fn default_beta() -> f64 {
        1.0
    }

    /// Whether a scale coefficient must be written out because it is not the default 1.0.
    ///
    /// NaN is checked explicitly. `(NaN - 1.0).abs() > EPSILON` is false, so without
    /// the check a NaN coefficient was silently dropped and read back as 1.0.
    fn differs_from_unit(value: f64) -> bool {
        value.is_nan() || (value - 1.0).abs() > f64::EPSILON
    }

    /// Default attributes: `alpha = beta = 1.0`, `trans_a = trans_b = 0`.
    pub fn new() -> Self {
        GemmAttributes {
            alpha: 1.0,
            beta: 1.0,
            trans_a: 0,
            trans_b: 0,
        }
    }

    /// Converts to a node attribute map, emitting only non-default values.
    ///
    /// - `alpha`/`beta` are emitted when they differ from 1.0 (including NaN).
    /// - `transA`/`transB` are emitted when non-zero.
    pub fn to_attributes(&self) -> HashMap<String, OnnxAttribute> {
        let mut attrs = HashMap::new();
        if Self::differs_from_unit(self.alpha) {
            attrs.insert("alpha".to_string(), OnnxAttribute::Float(self.alpha));
        }
        if Self::differs_from_unit(self.beta) {
            attrs.insert("beta".to_string(), OnnxAttribute::Float(self.beta));
        }
        if self.trans_a != 0 {
            attrs.insert("transA".to_string(), OnnxAttribute::Int(self.trans_a));
        }
        if self.trans_b != 0 {
            attrs.insert("transB".to_string(), OnnxAttribute::Int(self.trans_b));
        }
        attrs
    }
}
impl Default for GemmAttributes {
fn default() -> Self {
Self::new()
}
}