use crate::session::Session;
use oxionnx_core::{OnnxError, Tensor};
use std::marker::PhantomData;
/// Compile-time description of a tensor shape.
///
/// Implementors are marker types: the shape lives in the type, not in a value.
pub trait Shape {
    /// Returns the expected extent of every axis.
    ///
    /// An extent of `0` is treated by [`Shape::matches`] as "any size".
    fn dims() -> &'static [usize];

    /// Reports whether the runtime `shape` is compatible with [`Shape::dims`].
    ///
    /// The ranks must be equal, and every axis must either be a wildcard (`0`)
    /// in the expected dims or be exactly equal to the runtime extent.
    fn matches(shape: &[usize]) -> bool {
        let template = Self::dims();
        template.len() == shape.len()
            && template
                .iter()
                .zip(shape.iter())
                .all(|(want, got)| *want == 0 || want == got)
    }
}
/// A [`Session`] bound to one named input and one named output, tagged with
/// [`Shape`] marker types.
///
/// `I` is the shape marker for the input tensor and `O` the marker for the
/// output tensor. Neither is stored as a value; they only appear through
/// [`PhantomData`].
pub struct TypedSession<I, O>
where
    I: Shape,
    O: Shape,
{
    /// The wrapped session.
    inner: Session,
    /// Name of the model input this wrapper feeds.
    input_name: String,
    /// Name of the model output this wrapper reads.
    output_name: String,
    /// Carries the `I`/`O` markers without storing any data.
    _phantom: PhantomData<(I, O)>,
}
impl<I: Shape, O: Shape> TypedSession<I, O> {
    /// Wraps `session`, binding it to the input `input_name` and the output
    /// `output_name`.
    ///
    /// # Errors
    ///
    /// Returns [`OnnxError::TensorNotFound`] when `input_name` is not listed by
    /// `session.input_names()` or `output_name` is not listed by
    /// `session.output_names()`.
    pub fn new(session: Session, input_name: &str, output_name: &str) -> Result<Self, OnnxError> {
        // Compare as `&str` so the membership test does not have to allocate a
        // temporary `String` just to call `contains`.
        if !session
            .input_names()
            .iter()
            .any(|name| name.as_str() == input_name)
        {
            return Err(OnnxError::TensorNotFound(format!(
                "TypedSession: input '{}' not found in model",
                input_name
            )));
        }
        if !session
            .output_names()
            .iter()
            .any(|name| name.as_str() == output_name)
        {
            return Err(OnnxError::TensorNotFound(format!(
                "TypedSession: output '{}' not found in model",
                output_name
            )));
        }
        Ok(Self {
            inner: session,
            input_name: input_name.to_string(),
            output_name: output_name.to_string(),
            _phantom: PhantomData,
        })
    }

    /// Runs the wrapped session on `input` and returns the bound output.
    ///
    /// The input shape is checked against `I` before the session is invoked,
    /// and the produced output shape is checked against `O` before it is
    /// returned.
    ///
    /// # Errors
    ///
    /// * [`OnnxError::ShapeMismatch`] when `input.shape` does not match `I`, or
    ///   the produced output's shape does not match `O`.
    /// * [`OnnxError::TensorNotFound`] when the session's result has no entry
    ///   for the bound output name.
    /// * Any error returned by `Session::run_one` is propagated unchanged.
    pub fn run(&self, input: &Tensor) -> Result<Tensor, OnnxError> {
        if !I::matches(&input.shape) {
            return Err(OnnxError::ShapeMismatch(format!(
                "TypedSession: input shape {:?} does not match expected {:?}",
                input.shape,
                I::dims()
            )));
        }

        // `run_one` is handed an owned tensor, so the borrowed input is cloned.
        let outputs = self.inner.run_one(&self.input_name, input.clone())?;
        let output = outputs.get(&self.output_name).ok_or_else(|| {
            OnnxError::TensorNotFound(format!(
                "TypedSession: output '{}' not produced",
                self.output_name
            ))
        })?;

        if !O::matches(&output.shape) {
            return Err(OnnxError::ShapeMismatch(format!(
                "TypedSession: output shape {:?} does not match expected {:?}",
                output.shape,
                O::dims()
            )));
        }

        // NOTE(review): if `run_one` returns an owned map, removing the entry
        // instead of cloning it would avoid this copy — confirm the return type.
        Ok(output.clone())
    }

    /// Borrows the wrapped [`Session`].
    pub fn inner(&self) -> &Session {
        &self.inner
    }
}
/// Declares a public unit struct and implements the `typed_session::Shape`
/// trait for it.
///
/// The first argument is the name of the struct to generate; the second is a
/// bracketed, comma-separated list of `usize` extents that `dims()` returns.
/// An extent of `0` is treated as a wildcard by `Shape::matches`.
///
/// A trailing comma is accepted both inside the bracketed list and after it.
/// Attributes (including `///` doc comments) placed before the name are
/// forwarded to the generated struct.
///
/// ```ignore
/// define_shape!(Embedding, [0, 768]);
/// define_shape!(
///     /// One RGB frame per batch entry.
///     #[derive(Debug)]
///     Frame,
///     [0, 3, 64, 64,],
/// );
/// ```
#[macro_export]
macro_rules! define_shape {
    ($(#[$meta:meta])* $name:ident, [$($dim:expr),* $(,)?] $(,)?) => {
        $(#[$meta])*
        pub struct $name;

        impl $crate::typed_session::Shape for $name {
            fn dims() -> &'static [usize] {
                &[$($dim),*]
            }
        }
    };
}
/// Marker for the single-element shape `[1]`.
///
/// This is a rank-1 shape: with the default `Shape::matches`, an empty
/// (rank-0) shape is rejected because the ranks differ.
pub struct Scalar;

impl Shape for Scalar {
    fn dims() -> &'static [usize] {
        const DIMS: &[usize] = &[1];
        DIMS
    }
}
/// Marker for a rank-1 shape whose single axis is a wildcard (`0`).
pub struct Dynamic1D;

impl Shape for Dynamic1D {
    fn dims() -> &'static [usize] {
        const DIMS: &[usize] = &[0];
        DIMS
    }
}
/// Marker for a rank-2 shape in which both axes are wildcards (`0`).
pub struct Dynamic2D;

impl Shape for Dynamic2D {
    fn dims() -> &'static [usize] {
        const DIMS: &[usize] = &[0, 0];
        DIMS
    }
}
/// Marker for the rank-4 shape `[0, 3, 224, 224]`.
///
/// The leading axis is a wildcard; the remaining three are fixed.
/// NOTE(review): presumably laid out as (batch, channels, height, width) —
/// confirm against the models this is used with.
pub struct ImageNet224;

impl Shape for ImageNet224 {
    fn dims() -> &'static [usize] {
        const DIMS: &[usize] = &[0, 3, 224, 224];
        DIMS
    }
}
/// Marker for a rank-2 shape in which both axes are wildcards (`0`).
///
/// NOTE(review): presumably (batch, sequence length) — confirm against the
/// models this is used with.
pub struct BertInput;

impl Shape for BertInput {
    fn dims() -> &'static [usize] {
        const DIMS: &[usize] = &[0, 0];
        DIMS
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks each built-in marker through accepting and rejecting cases.
    #[test]
    fn test_shape_matches() {
        // ImageNet224 leaves only the leading axis free.
        let image_ok: [&[usize]; 2] = [&[1, 3, 224, 224], &[32, 3, 224, 224]];
        for &candidate in &image_ok {
            assert!(ImageNet224::matches(candidate), "{:?} should match", candidate);
        }
        let image_bad: [&[usize]; 3] = [&[1, 4, 224, 224], &[1, 3, 256, 256], &[1, 3, 224]];
        for &candidate in &image_bad {
            assert!(!ImageNet224::matches(candidate), "{:?} should not match", candidate);
        }

        // Dynamic1D accepts any length but only rank 1.
        assert!(Dynamic1D::matches(&[42]));
        assert!(Dynamic1D::matches(&[1]));
        assert!(!Dynamic1D::matches(&[1, 2]));

        // Dynamic2D accepts any extents but only rank 2.
        assert!(Dynamic2D::matches(&[10, 20]));
        assert!(!Dynamic2D::matches(&[10]));

        // Scalar is pinned to exactly one element.
        assert!(Scalar::matches(&[1]));
        assert!(!Scalar::matches(&[2]));
    }

    /// `ImageNet224::dims` exposes the wildcard batch axis and fixed image axes.
    #[test]
    fn test_imagenet_shape() {
        let image_dims = ImageNet224::dims();
        assert_eq!(image_dims, &[0, 3, 224, 224]);
        assert_eq!(image_dims.len(), 4);
    }

    /// `BertInput` is rank 2 with both axes free.
    #[test]
    fn test_bert_shape() {
        let bert_dims = BertInput::dims();
        assert_eq!(bert_dims, &[0, 0]);
        assert_eq!(bert_dims.len(), 2);

        assert!(BertInput::matches(&[1, 128]));
        assert!(BertInput::matches(&[16, 512]));
        // A third axis changes the rank and must be rejected.
        assert!(!BertInput::matches(&[1, 128, 768]));
    }

    /// Shapes generated by `define_shape!` follow the same matching rules.
    #[test]
    fn test_define_shape_macro() {
        crate::define_shape!(CustomShape, [0, 10, 20]);

        assert_eq!(CustomShape::dims(), &[0, 10, 20]);
        assert!(CustomShape::matches(&[5, 10, 20]));
        assert!(!CustomShape::matches(&[5, 10, 30]));
        assert!(!CustomShape::matches(&[5, 10]));
    }

    /// A wrong spatial size is rejected while the exact size is accepted.
    #[test]
    fn test_typed_session_wrong_input_shape() {
        let oversized = &[1, 3, 256, 256];
        assert!(
            !ImageNet224::matches(oversized),
            "256x256 should not match 224x224"
        );

        let exact = &[1, 3, 224, 224];
        assert!(ImageNet224::matches(exact));
    }

    /// A macro-defined shape with a free leading axis and a fixed trailing axis.
    #[test]
    fn test_typed_session_dynamic() {
        crate::define_shape!(DynBatch, [0, 768]);

        assert!(DynBatch::matches(&[1, 768]));
        assert!(DynBatch::matches(&[128, 768]));
        // Wrong trailing extent.
        assert!(!DynBatch::matches(&[1, 512]));
        // Wrong rank.
        assert!(!DynBatch::matches(&[1]));
    }

    /// Smoke checks across the built-in markers, including `Scalar` edge cases.
    #[test]
    fn test_typed_session_basic() {
        assert!(Scalar::matches(&[1]));
        assert!(!Scalar::matches(&[0]));
        assert!(!Scalar::matches(&[]));

        assert!(Dynamic1D::matches(&[100]));
        assert!(Dynamic2D::matches(&[3, 4]));
        assert!(ImageNet224::matches(&[8, 3, 224, 224]));
    }
}