kn_graph/onnx/result.rs

1use std::error::Error;
2use std::fmt::{Display, Formatter};
3use std::io;
4use std::path::{Path, PathBuf};
5
6use crate::onnx::load::OnnxDimValue;
7use crate::onnx::proto::attribute_proto::AttributeType;
8use crate::onnx::proto::tensor_proto::DataType;
9use crate::onnx::typed_value::AsShapeError;
10use crate::shape::Shape;
11
/// Shorthand for a `Result` whose error type is [`OnnxError`].
pub type OnnxResult<T> = Result<T, OnnxError>;
13
/// Identifies a graph node, by name and operator type, inside [`OnnxError`] variants.
///
/// `S` is the string representation: `String` by default, or any `AsRef<str>` type (e.g. `&str`)
/// that can be converted to the owned form with [`Node::to_owned`].
///
/// Equality and hashing compare both fields. This makes node descriptors easy to assert on in
/// tests and to collect into sets or maps.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct Node<S = String> {
    /// Name of the node.
    pub name: S,
    /// Operator type of the node.
    pub op_type: S,
}
19
// TODO remove variants that are never constructed
/// Error type for the ONNX loading code; [`OnnxResult`] is the matching `Result` alias.
///
/// Variants concerning a specific graph node carry a [`Node`] (its name and op type) so the
/// failing node can be identified. Where a payload field's meaning cannot be read off its
/// type, the note is marked as an assumption to confirm at the construction site.
#[derive(Debug)]
pub enum OnnxError {
    /// An I/O operation failed. Holds the path involved and the underlying error
    /// (this is what `ToOnnxLoadResult::to_onnx_result` produces).
    IO(PathBuf, io::Error),

    /// An external-data path was rejected as not "normal".
    /// NOTE(review): the exact rule is enforced where this is constructed — confirm there.
    NonNormalExternalDataPath(PathBuf),
    /// The given path was required to have a parent path.
    MustHaveParentPath(PathBuf),
    /// Input dimensions could not be turned into a shape.
    /// NOTE(review): the `String`/`usize` presumably identify the offending input — confirm.
    FailedToShapeInput(Vec<OnnxDimValue>, String, usize),

    /// A required protobuf field was absent; holds the field name
    /// (this is what `UnwrapProto::unwrap_proto` produces for `None`).
    MissingProtoField(&'static str),

    /// Some inputs of the node were left over.
    /// NOTE(review): the `Vec<usize>` presumably lists their input indices — confirm.
    LeftoverInputs(Node, Vec<usize>),
    /// Some attributes of the node were left over; presumably lists their names.
    LeftoverAttributes(Node, Vec<String>),

    /// The node's arguments were rejected; the `String` is presumably a description.
    InvalidOperationArgs(Node, String),
    /// A node input refers to something that does not exist.
    /// NOTE(review): `usize`/`String` are presumably the input index and referenced name.
    InputNodeDoesNotExist(Node, usize, String),
    /// The node lacks an input.
    /// NOTE(review): the two `usize`s are presumably the missing index and the input count.
    MissingInput(Node, usize, usize),
    /// An attribute is absent.
    /// NOTE(review): presumably (attribute name, expected type, attribute names present).
    MissingAttribute(Node, String, AttributeType, Vec<String>),
    /// An attribute has an unexpected type.
    /// NOTE(review): which `AttributeType` is expected vs. actual is unverified — confirm.
    UnexpectedAttributeType(Node, String, AttributeType, AttributeType),
    /// An attribute read as a boolean holds an integer that was not accepted as one.
    InvalidAttributeBool(Node, String, i64),

    /// The node's operation is not supported.
    UnsupportedOperation(Node),

    /// A non-float output is not supported; the `String` presumably names it.
    UnsupportedNonFloatOutput(String),
    /// A data type is not supported; the `String` presumably names the affected value.
    UnsupportedType(String, DataType),

    /// An N-dimensional convolution is not supported; `usize` is presumably N.
    UnsupportedNdConvolution(Node, usize),

    /// A partial shape is not supported for this node; the `String` presumably describes it.
    UnsupportedPartialShape(Node, String),
    /// A shape is not supported for this node; the `String` presumably describes it.
    UnsupportedShape(Node, String),

    /// An element-wise combination of operands is not supported.
    /// NOTE(review): the two `String`s presumably describe the operands — confirm.
    UnsupportedElementWiseCombination(Node, String, String),

    //TODO node/operand info
    /// A non-batch value was expected; the `String` presumably describes what was found.
    ExpectedNonBatchValue(String),
    /// Wraps an [`AsShapeError`]; created by the `From<AsShapeError>` impl so `?` converts it.
    ExpectedSizeError(AsShapeError),

    /// The node's auto-pad value is invalid; presumably holds the value given.
    InvalidAutoPadValue(Node, String),
    /// Padding values differ.
    /// NOTE(review): presumably (axis, begin pad, end pad) — confirm.
    DifferentPadding(Node, usize, i64, i64),
    /// Pooling does not divide evenly.
    /// NOTE(review): presumably (input shape, pooling sizes) — confirm.
    NonDividingPooling(Node, Shape, Vec<i64>),
}
61
62impl From<AsShapeError> for OnnxError {
63    fn from(e: AsShapeError) -> Self {
64        OnnxError::ExpectedSizeError(e)
65    }
66}
67
/// Converts a fallible I/O result into an [`OnnxResult`], recording the path being accessed.
pub trait ToOnnxLoadResult {
    /// The success type of the resulting [`OnnxResult`].
    type T;
    /// Wraps a failure together with `path` (for `Result<T, io::Error>` this yields
    /// [`OnnxError::IO`] and leaves an `Ok` value untouched).
    fn to_onnx_result(self, path: impl AsRef<Path>) -> OnnxResult<Self::T>;
}
72
73impl<T> ToOnnxLoadResult for Result<T, io::Error> {
74    type T = T;
75    fn to_onnx_result(self, path: impl AsRef<Path>) -> OnnxResult<T> {
76        self.map_err(|e| OnnxError::IO(path.as_ref().to_owned(), e))
77    }
78}
79
/// Turns an optional protobuf field into an [`OnnxResult`], reporting which field was absent.
pub trait UnwrapProto {
    /// The value type obtained when the field is present.
    type T;
    /// Returns the present value, or (for `Option<T>`) [`OnnxError::MissingProtoField`]
    /// carrying `field`.
    fn unwrap_proto(self, field: &'static str) -> OnnxResult<Self::T>;
}
84
85impl<T> UnwrapProto for Option<T> {
86    type T = T;
87    fn unwrap_proto(self, field: &'static str) -> OnnxResult<T> {
88        self.ok_or(OnnxError::MissingProtoField(field))
89    }
90}
91
92impl<S: AsRef<str>> Node<S> {
93    pub fn to_owned(self) -> Node<String> {
94        Node {
95            name: self.name.as_ref().to_owned(),
96            op_type: self.op_type.as_ref().to_owned(),
97        }
98    }
99}
100
101impl Display for OnnxError {
102    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103        write!(f, "{:?}", self)
104    }
105}
106
107impl Error for OnnxError {}