1use std::error::Error;
2use std::fmt::{Display, Formatter};
3use std::io;
4use std::path::{Path, PathBuf};
5
6use crate::onnx::load::OnnxDimValue;
7use crate::onnx::proto::attribute_proto::AttributeType;
8use crate::onnx::proto::tensor_proto::DataType;
9use crate::onnx::typed_value::AsShapeError;
10use crate::shape::Shape;
11
12pub type OnnxResult<T> = Result<T, OnnxError>;
13
/// Identifies a node by its name and operation type.
///
/// The string type `S` is generic and defaults to an owned [`String`]. The
/// derived `Copy` impl is only available when `S` itself is `Copy` (for
/// example `&str`); `Node<String>` is `Clone` but not `Copy`.
#[derive(Debug, Copy, Clone)]
pub struct Node<S = String> {
    /// Name of the node.
    pub name: S,
    /// Operation type of the node.
    pub op_type: S,
}
19
20#[derive(Debug)]
22pub enum OnnxError {
23 IO(PathBuf, io::Error),
24
25 NonNormalExternalDataPath(PathBuf),
26 MustHaveParentPath(PathBuf),
27 FailedToShapeInput(Vec<OnnxDimValue>, String, usize),
28
29 MissingProtoField(&'static str),
30
31 LeftoverInputs(Node, Vec<usize>),
32 LeftoverAttributes(Node, Vec<String>),
33
34 InvalidOperationArgs(Node, String),
35 InputNodeDoesNotExist(Node, usize, String),
36 MissingInput(Node, usize, usize),
37 MissingAttribute(Node, String, AttributeType, Vec<String>),
38 UnexpectedAttributeType(Node, String, AttributeType, AttributeType),
39 InvalidAttributeBool(Node, String, i64),
40
41 UnsupportedOperation(Node),
42
43 UnsupportedNonFloatOutput(String),
44 UnsupportedType(String, DataType),
45
46 UnsupportedNdConvolution(Node, usize),
47
48 UnsupportedPartialShape(Node, String),
49 UnsupportedShape(Node, String),
50
51 UnsupportedElementWiseCombination(Node, String, String),
52
53 ExpectedNonBatchValue(String),
55 ExpectedSizeError(AsShapeError),
56
57 InvalidAutoPadValue(Node, String),
58 DifferentPadding(Node, usize, i64, i64),
59 NonDividingPooling(Node, Shape, Vec<i64>),
60}
61
62impl From<AsShapeError> for OnnxError {
63 fn from(e: AsShapeError) -> Self {
64 OnnxError::ExpectedSizeError(e)
65 }
66}
67
68pub trait ToOnnxLoadResult {
69 type T;
70 fn to_onnx_result(self, path: impl AsRef<Path>) -> OnnxResult<Self::T>;
71}
72
73impl<T> ToOnnxLoadResult for Result<T, io::Error> {
74 type T = T;
75 fn to_onnx_result(self, path: impl AsRef<Path>) -> OnnxResult<T> {
76 self.map_err(|e| OnnxError::IO(path.as_ref().to_owned(), e))
77 }
78}
79
80pub trait UnwrapProto {
81 type T;
82 fn unwrap_proto(self, field: &'static str) -> OnnxResult<Self::T>;
83}
84
85impl<T> UnwrapProto for Option<T> {
86 type T = T;
87 fn unwrap_proto(self, field: &'static str) -> OnnxResult<T> {
88 self.ok_or(OnnxError::MissingProtoField(field))
89 }
90}
91
92impl<S: AsRef<str>> Node<S> {
93 pub fn to_owned(self) -> Node<String> {
94 Node {
95 name: self.name.as_ref().to_owned(),
96 op_type: self.op_type.as_ref().to_owned(),
97 }
98 }
99}
100
101impl Display for OnnxError {
102 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103 write!(f, "{:?}", self)
104 }
105}
106
107impl Error for OnnxError {}