use crate::proto::onnx::TensorProto;
/// A primitive element type that a tensor can hold.
///
/// Implementors are plain `Copy` values that can be shared across threads
/// and converted to `f32`.
pub trait Scalar
where
    Self: Copy + Send + Sync + 'static,
{
    /// Returns this value converted to `f32`.
    fn to_f32(&self) -> f32;
}

impl Scalar for f32 {
    /// Identity conversion: the value is returned unchanged.
    fn to_f32(&self) -> f32 {
        let &value = self;
        value
    }
}

impl Scalar for f64 {
    /// Narrows to `f32` with an `as` cast, which is lossy.
    ///
    /// The result is rounded to the nearest representable `f32`. Magnitudes
    /// beyond the `f32` range become `±inf`, and NaN stays NaN.
    fn to_f32(&self) -> f32 {
        let &wide = self;
        wide as f32
    }
}
/// An n-dimensional tensor that can be converted to and from the
/// [`TensorProto`] protobuf message.
///
/// Implementors must be cloneable, printable (`Debug` and `Display`),
/// thread-safe, `'static`, and serde-(de)serializable.
pub trait Tensor
where
    Self: Clone
        + std::fmt::Debug
        + std::fmt::Display
        + Send
        + Sync
        + 'static
        + serde::Serialize
        + serde::de::DeserializeOwned,
{
    /// Element type stored in this tensor.
    type Scalar: Scalar;

    /// Returns the tensor's dimensions.
    ///
    /// NOTE(review): `i64` presumably mirrors the ONNX `dims` field — confirm.
    fn dims(&self) -> &[i64];

    /// Returns the tensor's length.
    ///
    /// NOTE(review): presumably the total element count — confirm against
    /// implementors.
    fn len(&self) -> usize;

    /// Returns `true` when [`Tensor::len`] is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Encodes this tensor as a [`TensorProto`].
    fn to_proto(&self) -> TensorProto;

    /// Decodes a tensor of this type from `proto`, consuming it.
    ///
    /// # Errors
    ///
    /// Returns a [`TensorSerializationError`] when `proto` cannot be
    /// interpreted as `Self`.
    fn from_proto(proto: TensorProto) -> Result<Self, TensorSerializationError>;
}
/// Error type returned by `Tensor::from_proto` when a protobuf tensor cannot
/// be decoded into a concrete tensor type.
///
/// The type derives `Clone`, `PartialEq` and `Eq` (all payloads are `i32` or
/// `String`), so errors can be copied and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorSerializationError {
    /// The element-type code in the proto differs from the one required.
    ///
    /// NOTE(review): the codes are presumably ONNX `TensorProto.DataType`
    /// values — confirm.
    ElementTypeMismatch {
        /// Element-type code the target tensor type requires.
        expected: i32,
        /// Element-type code actually found in the proto.
        found: i32,
    },
    /// A shape-related problem, described by the message.
    ShapeError(String),
    /// Any other failure, described by the message.
    Custom(String),
}

impl std::fmt::Display for TensorSerializationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ElementTypeMismatch { expected, found } => {
                write!(
                    f,
                    "tensor elem_type mismatch: expected {expected}, found {found}"
                )
            }
            Self::ShapeError(m) => write!(f, "tensor shape error: {m}"),
            Self::Custom(m) => write!(f, "tensor serialization failure: {m}"),
        }
    }
}

// No variant wraps another error, so the default `source()` (which returns
// `None`) is correct.
impl std::error::Error for TensorSerializationError {}