use std::collections::HashMap;
use std::fmt;
/// Error type for every fallible operation in this module.
///
/// Each variant wraps a human-readable message; the `Display` impl renders it
/// as `"<category>: <message>"`.
#[derive(Debug, Clone, PartialEq)]
pub enum OnnxError {
    /// The graph structure itself is malformed.
    InvalidGraph(String),
    /// An operator kind that this implementation does not handle.
    UnsupportedOp(String),
    /// Tensor shapes are incompatible (e.g. broadcasting failure, dynamic dim).
    ShapeMismatch(String),
    /// A tensor's dtype does not match what the caller requested.
    TypeError(String),
    /// A node attribute is missing or has the wrong variant.
    InvalidAttribute(String),
    /// A failure raised while executing the graph.
    ExecutionError(String),
    /// Raw tensor bytes are inconsistent or corrupt.
    InvalidData(String),
}
impl fmt::Display for OnnxError {
    /// Renders the error as `"<category>: <message>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, msg) = match self {
            Self::InvalidGraph(m) => ("invalid graph", m),
            Self::UnsupportedOp(m) => ("unsupported op", m),
            Self::ShapeMismatch(m) => ("shape mismatch", m),
            Self::TypeError(m) => ("type error", m),
            Self::InvalidAttribute(m) => ("invalid attribute", m),
            Self::ExecutionError(m) => ("execution error", m),
            Self::InvalidData(m) => ("invalid data", m),
        };
        write!(f, "{label}: {msg}")
    }
}
// `OnnxError` stores its message inline and `Display` already renders it,
// so the default (empty) `Error` impl suffices.
impl std::error::Error for OnnxError {}
/// Convenience alias used by every fallible API in this module.
pub type OnnxResult<T> = Result<T, OnnxError>;
/// Element types supported for tensor data.
///
/// `size_bytes` gives each type's fixed width; `Display` gives the lowercase
/// name (e.g. `"float32"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit float (4 bytes).
    Float32,
    /// 64-bit float (8 bytes).
    Float64,
    /// 16-bit float payload (2 bytes); no decode helper exists in this module,
    /// and whether it is IEEE half or bfloat16 is not distinguished here.
    Float16,
    /// 32-bit signed integer (4 bytes).
    Int32,
    /// 64-bit signed integer (8 bytes).
    Int64,
    /// 8-bit signed integer (1 byte).
    Int8,
    /// 8-bit unsigned integer (1 byte).
    Uint8,
    /// Boolean stored as one byte; nonzero means `true`.
    Bool,
}
impl DataType {
    /// Fixed width of one element of this type, in bytes.
    pub fn size_bytes(self) -> usize {
        // Arms grouped by width rather than listed per variant.
        match self {
            Self::Int8 | Self::Uint8 | Self::Bool => 1,
            Self::Float16 => 2,
            Self::Float32 | Self::Int32 => 4,
            Self::Float64 | Self::Int64 => 8,
        }
    }
}
impl fmt::Display for DataType {
    /// Writes the lowercase type name (e.g. `"float32"`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Float32 => "float32",
            Self::Float64 => "float64",
            Self::Float16 => "float16",
            Self::Int32 => "int32",
            Self::Int64 => "int64",
            Self::Int8 => "int8",
            Self::Uint8 => "uint8",
            Self::Bool => "bool",
        };
        f.write_str(name)
    }
}
/// A possibly-dynamic tensor shape.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorShape {
    /// One entry per dimension: `Some(extent)` when known, `None` when dynamic.
    pub dims: Vec<Option<usize>>,
}
impl TensorShape {
    /// Builds a shape in which every dimension is known.
    pub fn fixed(dims: Vec<usize>) -> Self {
        let dims = dims.into_iter().map(Some).collect();
        Self { dims }
    }

    /// Builds a shape from a mix of known and dynamic dimensions.
    pub fn new(dims: Vec<Option<usize>>) -> Self {
        Self { dims }
    }

    /// True when no dimension is dynamic.
    pub fn is_fully_known(&self) -> bool {
        !self.dims.contains(&None)
    }

    /// Number of dimensions.
    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    /// Converts to a concrete `Vec<usize>`, failing with `ShapeMismatch`
    /// on the first dynamic dimension encountered.
    pub fn to_concrete(&self) -> OnnxResult<Vec<usize>> {
        let mut concrete = Vec::with_capacity(self.dims.len());
        for (i, dim) in self.dims.iter().enumerate() {
            match dim {
                Some(d) => concrete.push(*d),
                None => {
                    return Err(OnnxError::ShapeMismatch(format!(
                        "dimension {i} is dynamic"
                    )))
                }
            }
        }
        Ok(concrete)
    }

    /// Total element count, or `None` if any dimension is dynamic or the
    /// product overflows `usize`. An empty shape counts as 1 (scalar).
    pub fn element_count(&self) -> Option<usize> {
        let mut total = 1usize;
        for dim in &self.dims {
            total = total.checked_mul((*dim)?)?;
        }
        Some(total)
    }
}
/// Declared name, element type, and (possibly dynamic) shape of a tensor,
/// as used for graph inputs and outputs.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    pub name: String,
    pub dtype: DataType,
    pub shape: TensorShape,
}
/// Value of a node attribute; one variant per attribute kind.
///
/// The `as_*` accessors return `InvalidAttribute` when the variant does not
/// match what the caller expects.
#[derive(Debug, Clone, PartialEq)]
pub enum AttributeValue {
    Int(i64),
    Float(f64),
    String(String),
    Tensor(OnnxTensor),
    Ints(Vec<i64>),
    Floats(Vec<f64>),
    Strings(Vec<String>),
}
impl AttributeValue {
    /// Returns the wrapped `i64`, or `InvalidAttribute` for any other variant.
    pub fn as_int(&self) -> OnnxResult<i64> {
        if let Self::Int(v) = self {
            Ok(*v)
        } else {
            Err(OnnxError::InvalidAttribute(format!(
                "expected Int, got {self:?}"
            )))
        }
    }

    /// Returns the wrapped `f64`, or `InvalidAttribute` for any other variant.
    pub fn as_float(&self) -> OnnxResult<f64> {
        if let Self::Float(v) = self {
            Ok(*v)
        } else {
            Err(OnnxError::InvalidAttribute(format!(
                "expected Float, got {self:?}"
            )))
        }
    }

    /// Borrows the wrapped string, or `InvalidAttribute` for any other variant.
    pub fn as_string(&self) -> OnnxResult<&str> {
        if let Self::String(v) = self {
            Ok(v.as_str())
        } else {
            Err(OnnxError::InvalidAttribute(format!(
                "expected String, got {self:?}"
            )))
        }
    }

    /// Borrows the wrapped `i64` list, or `InvalidAttribute` for any other variant.
    pub fn as_ints(&self) -> OnnxResult<&[i64]> {
        if let Self::Ints(v) = self {
            Ok(v.as_slice())
        } else {
            Err(OnnxError::InvalidAttribute(format!(
                "expected Ints, got {self:?}"
            )))
        }
    }

    /// Borrows the wrapped `f64` list, or `InvalidAttribute` for any other variant.
    pub fn as_floats(&self) -> OnnxResult<&[f64]> {
        if let Self::Floats(v) = self {
            Ok(v.as_slice())
        } else {
            Err(OnnxError::InvalidAttribute(format!(
                "expected Floats, got {self:?}"
            )))
        }
    }
}
/// A single operator instance in a graph.
#[derive(Debug, Clone)]
pub struct Node {
    /// Operator kind, e.g. "Relu".
    pub op_type: String,
    /// Unique node name.
    pub name: String,
    /// Names of the tensors this node consumes.
    pub inputs: Vec<String>,
    /// Names of the tensors this node produces.
    pub outputs: Vec<String>,
    /// Operator configuration, keyed by attribute name.
    pub attributes: HashMap<String, AttributeValue>,
}
/// A computation graph: operator nodes plus declared inputs/outputs and
/// named constant tensors (initializers).
#[derive(Debug, Clone)]
pub struct Graph {
    pub name: String,
    pub nodes: Vec<Node>,
    pub inputs: Vec<TensorInfo>,
    pub outputs: Vec<TensorInfo>,
    /// Constant tensors available to nodes by name (typically weights —
    /// NOTE(review): assumed from the field name; confirm against the loader).
    pub initializers: HashMap<String, OnnxTensor>,
}
/// A concrete tensor: raw element bytes plus dtype and a fully-known shape.
///
/// Elements are stored little-endian (the `from_*`/`as_*` helpers use
/// `to_le_bytes`/`from_le_bytes`); an empty `shape` denotes a scalar.
#[derive(Debug, Clone, PartialEq)]
pub struct OnnxTensor {
    pub data: Vec<u8>,
    pub dtype: DataType,
    pub shape: Vec<usize>,
}
impl OnnxTensor {
    /// Number of elements implied by `shape`; an empty shape is a scalar
    /// (one element).
    pub fn element_count(&self) -> usize {
        if self.shape.is_empty() {
            1
        } else {
            self.shape.iter().product()
        }
    }

    /// Byte length implied by `shape` and `dtype`. For a well-formed tensor
    /// this equals `data.len()`; the `as_*` decoders verify that.
    pub fn size_bytes(&self) -> usize {
        self.element_count() * self.dtype.size_bytes()
    }

    /// Zero-filled tensor with the given shape and dtype.
    pub fn zeros(shape: Vec<usize>, dtype: DataType) -> Self {
        let count: usize = if shape.is_empty() {
            1
        } else {
            shape.iter().product()
        };
        let data = vec![0u8; count * dtype.size_bytes()];
        Self { data, dtype, shape }
    }

    /// Builds a Float32 tensor from host values (stored little-endian).
    pub fn from_f32(values: &[f32], shape: Vec<usize>) -> Self {
        let data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        Self {
            data,
            dtype: DataType::Float32,
            shape,
        }
    }

    /// Builds a Float64 tensor from host values (stored little-endian).
    pub fn from_f64(values: &[f64], shape: Vec<usize>) -> Self {
        let data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        Self {
            data,
            dtype: DataType::Float64,
            shape,
        }
    }

    /// Builds an Int32 tensor from host values (stored little-endian).
    pub fn from_i32(values: &[i32], shape: Vec<usize>) -> Self {
        let data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        Self {
            data,
            dtype: DataType::Int32,
            shape,
        }
    }

    /// Builds an Int64 tensor from host values (stored little-endian).
    pub fn from_i64(values: &[i64], shape: Vec<usize>) -> Self {
        let data: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        Self {
            data,
            dtype: DataType::Int64,
            shape,
        }
    }

    /// Builds a Bool tensor; `true` is stored as 1, `false` as 0.
    pub fn from_bool(values: &[bool], shape: Vec<usize>) -> Self {
        let data: Vec<u8> = values.iter().map(|&v| u8::from(v)).collect();
        Self {
            data,
            dtype: DataType::Bool,
            shape,
        }
    }

    /// Scalar (rank-0) Float32 tensor.
    pub fn scalar_f32(value: f32) -> Self {
        Self::from_f32(&[value], vec![])
    }

    /// Scalar (rank-0) Int64 tensor.
    pub fn scalar_i64(value: i64) -> Self {
        Self::from_i64(&[value], vec![])
    }

    /// Shared precondition for the `as_*` decoders: `dtype` must equal
    /// `expected`, and `data` must hold exactly `element_count()` elements.
    ///
    /// Without the length check, `chunks_exact` silently dropped a trailing
    /// partial element and a shape-inconsistent buffer decoded to the wrong
    /// number of values; both cases now fail with `InvalidData`.
    fn validate_decode(&self, expected: DataType) -> OnnxResult<()> {
        if self.dtype != expected {
            return Err(OnnxError::TypeError(format!(
                "expected {expected:?}, got {:?}",
                self.dtype
            )));
        }
        let want = self.element_count() * self.dtype.size_bytes();
        if self.data.len() != want {
            return Err(OnnxError::InvalidData(format!(
                "tensor with shape {:?} and dtype {} expects {want} bytes, got {}",
                self.shape,
                self.dtype,
                self.data.len()
            )));
        }
        Ok(())
    }

    /// Decodes to `Vec<f32>`; errors if the dtype is not Float32 or the
    /// buffer length is inconsistent with the shape.
    pub fn as_f32(&self) -> OnnxResult<Vec<f32>> {
        self.validate_decode(DataType::Float32)?;
        Ok(self
            .data
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect())
    }

    /// Decodes to `Vec<f64>`; errors if the dtype is not Float64 or the
    /// buffer length is inconsistent with the shape.
    pub fn as_f64(&self) -> OnnxResult<Vec<f64>> {
        self.validate_decode(DataType::Float64)?;
        Ok(self
            .data
            .chunks_exact(8)
            .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
            .collect())
    }

    /// Decodes to `Vec<i32>`; errors if the dtype is not Int32 or the
    /// buffer length is inconsistent with the shape.
    pub fn as_i32(&self) -> OnnxResult<Vec<i32>> {
        self.validate_decode(DataType::Int32)?;
        Ok(self
            .data
            .chunks_exact(4)
            .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect())
    }

    /// Decodes to `Vec<i64>`; errors if the dtype is not Int64 or the
    /// buffer length is inconsistent with the shape.
    pub fn as_i64(&self) -> OnnxResult<Vec<i64>> {
        self.validate_decode(DataType::Int64)?;
        Ok(self
            .data
            .chunks_exact(8)
            .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
            .collect())
    }

    /// Decodes to `Vec<bool>` (any nonzero byte is `true`); errors if the
    /// dtype is not Bool or the buffer length is inconsistent with the shape.
    pub fn as_bool(&self) -> OnnxResult<Vec<bool>> {
        self.validate_decode(DataType::Bool)?;
        Ok(self.data.iter().map(|&b| b != 0).collect())
    }
}
/// Computes the broadcast shape of `a` and `b` using trailing-aligned
/// (numpy-style) rules: shorter shapes are left-padded with 1s, and each
/// dimension pair must be equal or contain a 1.
///
/// Fails with `ShapeMismatch` at the first (leftmost) incompatible dimension.
pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> OnnxResult<Vec<usize>> {
    let rank = a.len().max(b.len());
    // How many leading 1s each input is implicitly padded with.
    let pad_a = rank - a.len();
    let pad_b = rank - b.len();
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        let da = if i >= pad_a { a[i - pad_a] } else { 1 };
        let db = if i >= pad_b { b[i - pad_b] } else { 1 };
        let dim = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => {
                return Err(OnnxError::ShapeMismatch(format!(
                    "cannot broadcast shapes {a:?} and {b:?} at dimension {i}"
                )))
            }
        };
        out.push(dim);
    }
    Ok(out)
}
/// Converts a row-major flat offset into per-dimension indices for `shape`.
///
/// Zero-sized dimensions yield index 0 and do not consume any of the offset;
/// an empty shape yields an empty index vector.
pub fn flat_to_multi(flat: usize, shape: &[usize]) -> Vec<usize> {
    let mut digits = Vec::with_capacity(shape.len());
    let mut rest = flat;
    // Peel off the fastest-varying (last) dimension first, then reverse.
    for &dim in shape.iter().rev() {
        if dim == 0 {
            digits.push(0);
        } else {
            digits.push(rest % dim);
            rest /= dim;
        }
    }
    digits.reverse();
    digits
}
/// Converts per-dimension indices into a row-major flat offset for `shape`.
///
/// Horner-style accumulation: equivalent to summing `index * stride` with
/// strides built from the trailing dimensions. Panics if `indices` is
/// shorter than `shape`.
pub fn multi_to_flat(indices: &[usize], shape: &[usize]) -> usize {
    let mut flat = 0usize;
    for (axis, &dim) in shape.iter().enumerate() {
        flat = flat * dim + indices[axis];
    }
    flat
}
/// Maps output-tensor indices back to a flat offset in a broadcast input.
///
/// `in_shape` is trailing-aligned against `out_shape`; size-1 input
/// dimensions always map to index 0. A scalar input (empty shape) maps
/// to offset 0. Assumes `in_shape.len() <= out_shape.len()` — the
/// subtraction underflows otherwise.
pub fn broadcast_index(out_indices: &[usize], in_shape: &[usize], out_shape: &[usize]) -> usize {
    if in_shape.is_empty() {
        return 0;
    }
    let offset = out_shape.len() - in_shape.len();
    // Flatten the mapped indices directly (Horner form) instead of building
    // an intermediate index vector.
    let mut flat = 0usize;
    for (i, &dim) in in_shape.iter().enumerate() {
        let idx = if dim == 1 { 0 } else { out_indices[i + offset] };
        flat = flat * dim + idx;
    }
    flat
}
#[cfg(test)]
mod tests {
    use super::*;

    // Every dtype reports its fixed per-element width.
    #[test]
    fn test_data_type_sizes() {
        assert_eq!(DataType::Float32.size_bytes(), 4);
        assert_eq!(DataType::Float64.size_bytes(), 8);
        assert_eq!(DataType::Float16.size_bytes(), 2);
        assert_eq!(DataType::Int32.size_bytes(), 4);
        assert_eq!(DataType::Int64.size_bytes(), 8);
        assert_eq!(DataType::Int8.size_bytes(), 1);
        assert_eq!(DataType::Uint8.size_bytes(), 1);
        assert_eq!(DataType::Bool.size_bytes(), 1);
    }

    // A fully-known shape can be counted and concretized.
    #[test]
    fn test_tensor_shape_fixed() {
        let shape = TensorShape::fixed(vec![2, 3, 4]);
        assert!(shape.is_fully_known());
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.element_count(), Some(24));
        assert_eq!(shape.to_concrete().ok(), Some(vec![2, 3, 4]));
    }

    // Any dynamic (None) dimension blocks counting and concretization.
    #[test]
    fn test_tensor_shape_dynamic() {
        let shape = TensorShape::new(vec![Some(2), None, Some(4)]);
        assert!(!shape.is_fully_known());
        assert_eq!(shape.element_count(), None);
        assert!(shape.to_concrete().is_err());
    }

    #[test]
    fn test_tensor_f32_roundtrip() {
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = OnnxTensor::from_f32(&values, vec![2, 3]);
        assert_eq!(t.element_count(), 6);
        assert_eq!(t.dtype, DataType::Float32);
        assert_eq!(t.as_f32().ok(), Some(values));
    }

    #[test]
    fn test_tensor_i64_roundtrip() {
        let values = vec![10i64, 20, 30];
        let t = OnnxTensor::from_i64(&values, vec![3]);
        assert_eq!(t.as_i64().ok(), Some(values));
    }

    #[test]
    fn test_tensor_bool_roundtrip() {
        let values = vec![true, false, true, false];
        let t = OnnxTensor::from_bool(&values, vec![4]);
        assert_eq!(t.as_bool().ok(), Some(values));
    }

    // Decoding with the wrong dtype must fail rather than reinterpret bytes.
    #[test]
    fn test_tensor_type_mismatch() {
        let t = OnnxTensor::from_f32(&[1.0], vec![1]);
        assert!(t.as_i64().is_err());
        assert!(t.as_bool().is_err());
    }

    // Rank-0 (empty shape) tensors hold exactly one element.
    #[test]
    fn test_scalar_tensor() {
        let t = OnnxTensor::scalar_f32(7.125);
        assert_eq!(t.shape, Vec::<usize>::new());
        assert_eq!(t.element_count(), 1);
        assert_eq!(t.as_f32().ok(), Some(vec![7.125]));
    }

    #[test]
    fn test_zeros() {
        let t = OnnxTensor::zeros(vec![2, 3], DataType::Float32);
        let data = t.as_f32().ok();
        assert_eq!(data, Some(vec![0.0; 6]));
    }

    // Accessors succeed only for the exactly-matching variant.
    #[test]
    fn test_attribute_accessors() {
        let a = AttributeValue::Int(42);
        assert_eq!(a.as_int().ok(), Some(42));
        assert!(a.as_float().is_err());
        let b = AttributeValue::Float(7.125);
        assert!((b.as_float().ok().unwrap_or(0.0) - 7.125).abs() < 1e-10);
        let c = AttributeValue::String("relu".into());
        assert_eq!(c.as_string().ok(), Some("relu"));
        let d = AttributeValue::Ints(vec![1, 2, 3]);
        assert_eq!(d.as_ints().ok(), Some(&[1i64, 2, 3][..]));
    }

    // Trailing-aligned broadcasting, including rank promotion and rejection.
    #[test]
    fn test_broadcast_shapes() {
        assert_eq!(broadcast_shapes(&[3, 4], &[3, 4]).ok(), Some(vec![3, 4]));
        assert_eq!(broadcast_shapes(&[1, 4], &[3, 1]).ok(), Some(vec![3, 4]));
        assert_eq!(broadcast_shapes(&[4], &[3, 4]).ok(), Some(vec![3, 4]));
        assert_eq!(
            broadcast_shapes(&[2, 1, 4], &[3, 1]).ok(),
            Some(vec![2, 3, 4])
        );
        assert!(broadcast_shapes(&[3, 4], &[3, 5]).is_err());
    }

    #[test]
    fn test_flat_to_multi() {
        assert_eq!(flat_to_multi(0, &[2, 3]), vec![0, 0]);
        assert_eq!(flat_to_multi(1, &[2, 3]), vec![0, 1]);
        assert_eq!(flat_to_multi(3, &[2, 3]), vec![1, 0]);
        assert_eq!(flat_to_multi(5, &[2, 3]), vec![1, 2]);
    }

    #[test]
    fn test_multi_to_flat() {
        assert_eq!(multi_to_flat(&[0, 0], &[2, 3]), 0);
        assert_eq!(multi_to_flat(&[0, 1], &[2, 3]), 1);
        assert_eq!(multi_to_flat(&[1, 0], &[2, 3]), 3);
        assert_eq!(multi_to_flat(&[1, 2], &[2, 3]), 5);
    }

    // A scalar input always broadcasts to flat offset 0.
    #[test]
    fn test_broadcast_index_scalar() {
        assert_eq!(broadcast_index(&[0, 0], &[], &[2, 3]), 0);
    }

    #[test]
    fn test_graph_construction() {
        let graph = Graph {
            name: "test".into(),
            nodes: vec![Node {
                op_type: "Relu".into(),
                name: "relu0".into(),
                inputs: vec!["x".into()],
                outputs: vec!["y".into()],
                attributes: HashMap::new(),
            }],
            inputs: vec![TensorInfo {
                name: "x".into(),
                dtype: DataType::Float32,
                shape: TensorShape::fixed(vec![1, 3]),
            }],
            outputs: vec![TensorInfo {
                name: "y".into(),
                dtype: DataType::Float32,
                shape: TensorShape::fixed(vec![1, 3]),
            }],
            initializers: HashMap::new(),
        };
        assert_eq!(graph.nodes.len(), 1);
        assert_eq!(graph.nodes[0].op_type, "Relu");
    }
}