onnx_shape_inference/
proto.rs

1use onnx_pb::ModelProto;
2use prost::Message;
3
4use crate::shape_inference_proto;
5
6/// Error type.
7#[derive(Debug)]
8pub enum Error {
9    /// Decode error.
10    Decode(prost::DecodeError),
11
12    /// Encode error.
13    Encode(prost::EncodeError),
14}
15
16/// Infers model shapes.
17pub fn shape_inference(model: &ModelProto) -> Result<ModelProto, Error> {
18    let mut body = Vec::new();
19    model.encode(&mut body).map_err(|e| Error::Encode(e))?;
20    let inferred = shape_inference_proto(body.as_slice());
21    ModelProto::decode(inferred.as_slice()).map_err(|e| Error::Decode(e))
22}