use crate::{DType, Device, Result, Shape, Tensor, TensorError};
use std::collections::HashMap;
/// ONNX `TensorProto.DataType` wire enum (discriminants match onnx.proto3).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum OnnxDataType {
    Undefined = 0,
    Float = 1,
    UInt8 = 2,
    Int8 = 3,
    UInt16 = 4,
    Int16 = 5,
    Int32 = 6,
    Int64 = 7,
    String = 8,
    Bool = 9,
    Float16 = 10,
    Double = 11,
    UInt32 = 12,
    UInt64 = 13,
    Complex64 = 14,
    Complex128 = 15,
    BFloat16 = 16,
}
impl OnnxDataType {
    /// Maps a TenfloweRS `DType` onto its ONNX wire enum value.
    ///
    /// # Errors
    /// Returns an unsupported-operation error for any `DType` with no ONNX
    /// counterpart (handled by the catch-all arm).
    pub fn from_dtype(dtype: DType) -> Result<Self> {
        let mapped = match dtype {
            DType::Float32 => Self::Float,
            DType::Float64 => Self::Double,
            DType::Float16 => Self::Float16,
            DType::BFloat16 => Self::BFloat16,
            DType::Int8 => Self::Int8,
            DType::Int16 => Self::Int16,
            DType::Int32 => Self::Int32,
            DType::Int64 => Self::Int64,
            DType::UInt8 => Self::UInt8,
            DType::UInt16 => Self::UInt16,
            DType::UInt32 => Self::UInt32,
            DType::UInt64 => Self::UInt64,
            DType::Bool => Self::Bool,
            DType::Complex64 => Self::Complex64,
            DType::String => Self::String,
            _ => {
                return Err(TensorError::unsupported_operation_simple(format!(
                    "ONNX does not support dtype: {:?}",
                    dtype
                )))
            }
        };
        Ok(mapped)
    }

    /// Inverse of [`Self::from_dtype`].
    ///
    /// # Errors
    /// Fails for `Undefined` (no concrete type exists) and `Complex128`
    /// (no TenfloweRS counterpart yet).
    pub fn to_dtype(&self) -> Result<DType> {
        let dtype = match self {
            Self::Float => DType::Float32,
            Self::Double => DType::Float64,
            Self::Float16 => DType::Float16,
            Self::BFloat16 => DType::BFloat16,
            Self::Int8 => DType::Int8,
            Self::Int16 => DType::Int16,
            Self::Int32 => DType::Int32,
            Self::Int64 => DType::Int64,
            Self::UInt8 => DType::UInt8,
            Self::UInt16 => DType::UInt16,
            Self::UInt32 => DType::UInt32,
            Self::UInt64 => DType::UInt64,
            Self::Bool => DType::Bool,
            Self::Complex64 => DType::Complex64,
            Self::String => DType::String,
            Self::Undefined => {
                return Err(TensorError::invalid_argument(
                    "Cannot convert undefined ONNX type to DType".to_string(),
                ))
            }
            Self::Complex128 => {
                return Err(TensorError::unsupported_operation_simple(
                    "Complex128 not yet supported in TenfloweRS".to_string(),
                ))
            }
        };
        Ok(dtype)
    }

    /// Size in bytes of one element of this type.
    ///
    /// `String` and `Undefined` report 0: strings are variable-length and
    /// `Undefined` carries no storage.
    pub fn element_size(&self) -> usize {
        match self {
            Self::Int8 | Self::UInt8 | Self::Bool => 1,
            Self::Int16 | Self::UInt16 | Self::Float16 | Self::BFloat16 => 2,
            Self::Float | Self::Int32 | Self::UInt32 => 4,
            Self::Double | Self::Int64 | Self::UInt64 | Self::Complex64 => 8,
            Self::Complex128 => 16,
            Self::String | Self::Undefined => 0,
        }
    }
}
/// In-memory mirror of a subset of the ONNX `TensorProto` message.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct OnnxTensorProto {
    /// Tensor shape, one entry per axis. Empty here means "no shape set"
    /// (`num_elements()` reports 0 for it).
    pub dims: Vec<i64>,
    /// Wire value of `OnnxDataType`, stored as `i32` like the protobuf field.
    pub data_type: i32,
    /// Optional begin/end segment; see `OnnxSegment`.
    pub segment: Option<OnnxSegment>,
    /// Typed payload fields mirroring the protobuf's repeated fields.
    /// `serialize_tensor_onnx` in this module writes `raw_data` instead,
    /// but `deserialize_tensor_onnx` also accepts `float_data`.
    pub float_data: Vec<f32>,
    pub int32_data: Vec<i32>,
    pub string_data: Vec<Vec<u8>>,
    pub int64_data: Vec<i64>,
    /// Optional tensor name.
    pub name: String,
    /// Human-readable documentation string.
    pub doc_string: String,
    /// Raw byte payload (native byte order as produced by `bytemuck`).
    pub raw_data: Vec<u8>,
    /// Key/value entries describing externally stored data.
    pub external_data: Vec<OnnxExternalData>,
}
impl Default for OnnxTensorProto {
    /// An empty proto: no dims, all payload fields empty, and `data_type`
    /// set to the `Undefined` wire value (0).
    fn default() -> Self {
        Self {
            // Identity / metadata.
            name: String::new(),
            doc_string: String::new(),
            // Shape and type.
            dims: Vec::new(),
            data_type: OnnxDataType::Undefined as i32,
            segment: None,
            // Payload containers (all empty).
            float_data: Vec::new(),
            int32_data: Vec::new(),
            string_data: Vec::new(),
            int64_data: Vec::new(),
            raw_data: Vec::new(),
            external_data: Vec::new(),
        }
    }
}
/// Begin/end index pair mirroring ONNX `TensorProto.Segment`.
/// NOTE(review): nothing in this module reads or writes segments yet.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct OnnxSegment {
    pub begin: i64,
    pub end: i64,
}
/// One key/value entry of ONNX `TensorProto.external_data`
/// (e.g. "location" -> file path in the ONNX spec).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct OnnxExternalData {
    pub key: String,
    pub value: String,
}
impl OnnxTensorProto {
    /// Creates an empty proto; equivalent to `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for the tensor's name.
    pub fn with_name(mut self, name: String) -> Self {
        self.name = name;
        self
    }

    /// Builder-style setter for the documentation string.
    pub fn with_doc(mut self, doc: String) -> Self {
        self.doc_string = doc;
        self
    }

    /// Total element count implied by `dims`, or 0 when no dims are set.
    /// NOTE(review): ONNX encodes scalars as empty `dims` (1 element); this
    /// returns 0 in that case — confirm callers depend on that before changing.
    pub fn num_elements(&self) -> i64 {
        match self.dims.is_empty() {
            true => 0,
            false => self.dims.iter().product(),
        }
    }
}
pub fn serialize_tensor_onnx<T>(tensor: &Tensor<T>, name: Option<String>) -> Result<OnnxTensorProto>
where
T: bytemuck::Pod + scirs2_core::num_traits::Float + Default + 'static,
{
let dtype = tensor.dtype();
let onnx_dtype = OnnxDataType::from_dtype(dtype)?;
let mut proto = OnnxTensorProto::new();
proto.dims = tensor.shape().dims().iter().map(|&d| d as i64).collect();
proto.data_type = onnx_dtype as i32;
if let Some(n) = name {
proto.name = n;
}
let data_slice = tensor.data();
let data_bytes: &[u8] = bytemuck::cast_slice(data_slice);
proto.raw_data = data_bytes.to_vec();
Ok(proto)
}
pub fn deserialize_tensor_onnx<T>(proto: &OnnxTensorProto) -> Result<Tensor<T>>
where
T: bytemuck::Pod + scirs2_core::num_traits::Float + Default + 'static,
{
let onnx_dtype = match proto.data_type {
x if x == OnnxDataType::Float as i32 => OnnxDataType::Float,
x if x == OnnxDataType::Double as i32 => OnnxDataType::Double,
x if x == OnnxDataType::Int32 as i32 => OnnxDataType::Int32,
x if x == OnnxDataType::Int64 as i32 => OnnxDataType::Int64,
x if x == OnnxDataType::Float16 as i32 => OnnxDataType::Float16,
x if x == OnnxDataType::BFloat16 as i32 => OnnxDataType::BFloat16,
_ => {
return Err(TensorError::unsupported_operation_simple(format!(
"Unsupported ONNX data type: {}",
proto.data_type
)))
}
};
let _dtype = onnx_dtype.to_dtype()?;
if proto.dims.is_empty() {
return Err(TensorError::invalid_shape_simple(
"ONNX tensor has empty dimensions".to_string(),
));
}
let shape_vec: Vec<usize> = proto.dims.iter().map(|&d| d as usize).collect();
let shape = Shape::from_slice(&shape_vec);
if !proto.raw_data.is_empty() {
let data_slice: &[T] = bytemuck::cast_slice(&proto.raw_data);
let data_vec = data_slice.to_vec();
use scirs2_core::ndarray::ArrayD;
let array = ArrayD::from_shape_vec(shape.dims(), data_vec).map_err(|e| {
TensorError::invalid_shape_simple(format!("Failed to create array from ONNX: {}", e))
})?;
return Ok(Tensor::from_array(array));
}
if !proto.float_data.is_empty() {
let data_slice: &[T] = bytemuck::cast_slice(&proto.float_data);
let data_vec = data_slice.to_vec();
use scirs2_core::ndarray::ArrayD;
let array = ArrayD::from_shape_vec(shape.dims(), data_vec).map_err(|e| {
TensorError::invalid_shape_simple(format!("Failed to create array from ONNX: {}", e))
})?;
return Ok(Tensor::from_array(array));
}
Err(TensorError::serialization_error_simple(
"ONNX tensor has no data".to_string(),
))
}
/// Serializes an `f32` tensor to bytes.
///
/// With the `serialize` feature the output is JSON of the whole proto;
/// otherwise a fixed binary layout (all integers little-endian):
///   `u32` dim count | `i64` per dim | `i32` data type | `u64` payload len | payload bytes
///
/// # Errors
/// Propagates proto-construction failures, and JSON encoding failures under
/// the `serialize` feature.
pub fn serialize_f32_tensor_onnx(tensor: &Tensor<f32>, name: Option<String>) -> Result<Vec<u8>> {
    let proto = serialize_tensor_onnx(tensor, name)?;
    #[cfg(feature = "serialize")]
    {
        serde_json::to_vec(&proto).map_err(|e| {
            TensorError::serialization_error_simple(format!(
                "ONNX JSON serialization failed: {}",
                e
            ))
        })
    }
    #[cfg(not(feature = "serialize"))]
    {
        // The output size is known exactly up front; preallocate to avoid
        // repeated grow-and-copy while appending.
        let mut bytes =
            Vec::with_capacity(4 + proto.dims.len() * 8 + 4 + 8 + proto.raw_data.len());
        bytes.extend_from_slice(&(proto.dims.len() as u32).to_le_bytes());
        for dim in &proto.dims {
            bytes.extend_from_slice(&dim.to_le_bytes());
        }
        bytes.extend_from_slice(&proto.data_type.to_le_bytes());
        bytes.extend_from_slice(&(proto.raw_data.len() as u64).to_le_bytes());
        bytes.extend_from_slice(&proto.raw_data);
        Ok(bytes)
    }
}
/// Deserializes bytes produced by [`serialize_f32_tensor_onnx`].
///
/// With the `serialize` feature the input is JSON; otherwise the fixed binary
/// layout written by the non-serde branch of the serializer:
///   `u32` dim count | `i64` per dim | `i32` data type | `u64` payload len | payload
/// (all integers little-endian).
///
/// # Errors
/// Returns a serialization error for truncated input, hostile length fields
/// (offset overflow), or a payload whose length is not a multiple of 4;
/// returns a shape error when the payload does not match the dims.
pub fn deserialize_f32_tensor_onnx(bytes: &[u8]) -> Result<Tensor<f32>> {
    #[cfg(feature = "serialize")]
    {
        let proto: OnnxTensorProto = serde_json::from_slice(bytes).map_err(|e| {
            TensorError::serialization_error_simple(format!(
                "ONNX JSON deserialization failed: {}",
                e
            ))
        })?;
        deserialize_tensor_onnx(&proto)
    }
    #[cfg(not(feature = "serialize"))]
    {
        // Bounds-checked read of `n` bytes at `*cursor`, advancing the cursor.
        // checked_add guards against offset overflow from hostile length fields
        // (the original `bytes.len() < cursor + n` checks could overflow).
        fn take<'a>(
            bytes: &'a [u8],
            cursor: &mut usize,
            n: usize,
            what: &str,
        ) -> Result<&'a [u8]> {
            let end = cursor.checked_add(n).ok_or_else(|| {
                TensorError::serialization_error_simple(format!(
                    "ONNX data offset overflow while reading {}",
                    what
                ))
            })?;
            let slice = bytes.get(*cursor..end).ok_or_else(|| {
                TensorError::serialization_error_simple(format!(
                    "ONNX data too small for {}",
                    what
                ))
            })?;
            *cursor = end;
            Ok(slice)
        }

        // Little-endian field readers built on `take`.
        fn read_u32(bytes: &[u8], cursor: &mut usize, what: &str) -> Result<u32> {
            let mut buf = [0u8; 4];
            buf.copy_from_slice(take(bytes, cursor, 4, what)?);
            Ok(u32::from_le_bytes(buf))
        }
        fn read_i32(bytes: &[u8], cursor: &mut usize, what: &str) -> Result<i32> {
            let mut buf = [0u8; 4];
            buf.copy_from_slice(take(bytes, cursor, 4, what)?);
            Ok(i32::from_le_bytes(buf))
        }
        fn read_u64(bytes: &[u8], cursor: &mut usize, what: &str) -> Result<u64> {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(take(bytes, cursor, 8, what)?);
            Ok(u64::from_le_bytes(buf))
        }
        fn read_i64(bytes: &[u8], cursor: &mut usize, what: &str) -> Result<i64> {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(take(bytes, cursor, 8, what)?);
            Ok(i64::from_le_bytes(buf))
        }

        let mut cursor = 0usize;
        let num_dims = read_u32(bytes, &mut cursor, "dimension count")? as usize;
        // Deliberately NOT Vec::with_capacity(num_dims): the count is
        // attacker-controlled, and each push below is bounds-checked anyway.
        let mut dims = Vec::new();
        for _ in 0..num_dims {
            dims.push(read_i64(bytes, &mut cursor, "dimensions")?);
        }
        // Data type is read to keep the cursor in sync but not validated here;
        // the element type is fixed to f32 by this function's signature.
        let _data_type = read_i32(bytes, &mut cursor, "data type")?;
        let data_len = read_u64(bytes, &mut cursor, "data length")? as usize;
        let raw_data = take(bytes, &mut cursor, data_len, "raw data")?;
        if raw_data.len() % 4 != 0 {
            return Err(TensorError::serialization_error_simple(
                "ONNX raw data length is not a multiple of 4 bytes".to_string(),
            ));
        }
        // Decode explicitly instead of bytemuck::cast_slice: a &[u8] sub-slice
        // has no 4-byte alignment guarantee, and cast_slice panics on
        // misaligned input.
        let data_vec: Vec<f32> = raw_data
            .chunks_exact(4)
            .map(|chunk| {
                let mut buf = [0u8; 4];
                buf.copy_from_slice(chunk);
                f32::from_le_bytes(buf)
            })
            .collect();
        let shape_vec: Vec<usize> = dims.iter().map(|&d| d as usize).collect();
        let shape = Shape::from_slice(&shape_vec);
        use scirs2_core::ndarray::ArrayD;
        let array = ArrayD::from_shape_vec(shape.dims(), data_vec).map_err(|e| {
            TensorError::invalid_shape_simple(format!("Failed to create array from ONNX: {}", e))
        })?;
        Ok(Tensor::from_array(array))
    }
}
/// Computes row-major (C-order) strides for `shape`, in elements:
/// the last axis is contiguous (stride 1) and each earlier axis steps over
/// the product of all later axis sizes. Returns an empty vector for an
/// empty shape.
pub fn onnx_strides(shape: &[usize]) -> Vec<i64> {
    // Scan from the innermost axis outward, accumulating the running product,
    // then flip back into axis order.
    let mut running = 1i64;
    let mut strides: Vec<i64> = shape
        .iter()
        .rev()
        .map(|&dim| {
            let stride = running;
            running *= dim as i64;
            stride
        })
        .collect();
    strides.reverse();
    strides
}
/// Reports whether `_tensor` can be exported to ONNX.
///
/// Currently always `true`: any `T` satisfying the `Pod + Float` bounds has an
/// ONNX counterpart. The parameter is underscore-prefixed to silence the
/// unused-variable warning (parameter names are not part of the call API in
/// Rust) while keeping the signature ready for future shape/layout checks.
pub fn is_onnx_compatible<T>(_tensor: &Tensor<T>) -> bool
where
    T: bytemuck::Pod + scirs2_core::num_traits::Float + Default + 'static,
{
    true
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // DType <-> OnnxDataType mapping for the common numeric types.
    #[test]
    fn test_onnx_dtype_conversion() {
        assert_eq!(
            OnnxDataType::from_dtype(DType::Float32).expect("test: from_dtype should succeed"),
            OnnxDataType::Float
        );
        assert_eq!(
            OnnxDataType::from_dtype(DType::Float64).expect("test: from_dtype should succeed"),
            OnnxDataType::Double
        );
        assert_eq!(
            OnnxDataType::from_dtype(DType::Int32).expect("test: from_dtype should succeed"),
            OnnxDataType::Int32
        );
        assert_eq!(
            OnnxDataType::Float
                .to_dtype()
                .expect("test: to_dtype should succeed"),
            DType::Float32
        );
        assert_eq!(
            OnnxDataType::Double
                .to_dtype()
                .expect("test: to_dtype should succeed"),
            DType::Float64
        );
    }

    // Per-element byte sizes for a sample of fixed-width types.
    #[test]
    fn test_onnx_element_size() {
        assert_eq!(OnnxDataType::Float.element_size(), 4);
        assert_eq!(OnnxDataType::Double.element_size(), 8);
        assert_eq!(OnnxDataType::Int32.element_size(), 4);
        assert_eq!(OnnxDataType::Int8.element_size(), 1);
    }

    // Builder-style setters populate name and doc_string.
    #[test]
    fn test_onnx_tensor_proto_creation() {
        let proto = OnnxTensorProto::new()
            .with_name("test_tensor".to_string())
            .with_doc("Test documentation".to_string());
        assert_eq!(proto.name, "test_tensor");
        assert_eq!(proto.doc_string, "Test documentation");
    }

    // A 2x2 f32 tensor serializes to a proto with matching metadata and a
    // raw_data payload of 4 elements * 4 bytes.
    #[test]
    fn test_serialize_onnx_f32() {
        let data = array![[1.0f32, 2.0], [3.0, 4.0]];
        let tensor = Tensor::from_array(data.into_dyn());
        let proto = serialize_tensor_onnx(&tensor, Some("weights".to_string()))
            .expect("test: operation should succeed");
        assert_eq!(proto.name, "weights");
        assert_eq!(proto.dims, vec![2, 2]);
        assert_eq!(proto.data_type, OnnxDataType::Float as i32);
        assert!(!proto.raw_data.is_empty());
        assert_eq!(proto.raw_data.len(), 4 * 4);
    }

    // Proto round-trip: serialize then deserialize preserves shape and data.
    #[test]
    fn test_deserialize_onnx_f32() {
        let data = array![[1.0f32, 2.0], [3.0, 4.0]];
        let tensor = Tensor::from_array(data.into_dyn());
        let proto = serialize_tensor_onnx(&tensor, None)
            .expect("test: serialize_tensor_onnx should succeed");
        let deserialized =
            deserialize_tensor_onnx::<f32>(&proto).expect("test: operation should succeed");
        assert_eq!(tensor.shape(), deserialized.shape());
        assert_eq!(tensor.data(), deserialized.data());
    }

    // Row-major strides for a 3-D shape and a 1-D shape.
    #[test]
    fn test_onnx_strides() {
        let shape = vec![2, 3, 4];
        let strides = onnx_strides(&shape);
        assert_eq!(strides, vec![12, 4, 1]);
        let shape2 = vec![5];
        let strides2 = onnx_strides(&shape2);
        assert_eq!(strides2, vec![1]);
    }

    // Byte-level round-trip through serialize/deserialize_f32_tensor_onnx
    // (exercises whichever cfg branch — JSON or binary — is compiled in).
    #[test]
    fn test_serialize_deserialize_f32_onnx() {
        let data = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let tensor = Tensor::from_array(data.into_dyn());
        let bytes = serialize_f32_tensor_onnx(&tensor, Some("test".to_string()))
            .expect("test: operation should succeed");
        let deserialized = deserialize_f32_tensor_onnx(&bytes)
            .expect("test: deserialize_f32_tensor_onnx should succeed");
        assert_eq!(tensor.shape(), deserialized.shape());
        assert_eq!(tensor.data(), deserialized.data());
    }

    // Compatibility check is currently a constant `true` for f32 tensors.
    #[test]
    fn test_onnx_compatible_check() {
        let data = array![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_array(data.into_dyn());
        assert!(is_onnx_compatible(&tensor));
    }

    // num_elements multiplies dims; empty dims are reported as 0 elements.
    #[test]
    fn test_onnx_proto_num_elements() {
        let mut proto = OnnxTensorProto::new();
        proto.dims = vec![2, 3, 4];
        assert_eq!(proto.num_elements(), 24);
        proto.dims = vec![];
        assert_eq!(proto.num_elements(), 0);
    }
}