1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
use onnx_pb::ModelProto;
use prost::Message;
use crate::shape_inference_proto;
/// Errors returned by [`shape_inference`].
///
/// Each variant wraps the underlying prost error, which is also exposed via
/// [`std::error::Error::source`].
#[derive(Debug)]
pub enum Error {
    /// Decoding protobuf bytes into a `ModelProto` failed.
    Decode(prost::DecodeError),
    /// Encoding a `ModelProto` into protobuf bytes failed.
    Encode(prost::EncodeError),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Decode(e) => write!(f, "failed to decode ModelProto: {}", e),
            Error::Encode(e) => write!(f, "failed to encode ModelProto: {}", e),
        }
    }
}

impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Expose the wrapped prost error so error-chain reporters can show it.
        match self {
            Error::Decode(e) => Some(e),
            Error::Encode(e) => Some(e),
        }
    }
}
pub fn shape_inference(model: &ModelProto) -> Result<ModelProto, Error> {
let mut body = Vec::new();
model.encode(&mut body).map_err(|e| Error::Encode(e))?;
let inferred = shape_inference_proto(body.as_slice());
ModelProto::decode(inferred.as_slice()).map_err(|e| Error::Decode(e))
}