onnx_shape_inference/
proto.rs1use onnx_pb::ModelProto;
2use prost::Message;
3
4use crate::shape_inference_proto;
5
6#[derive(Debug)]
8pub enum Error {
9 Decode(prost::DecodeError),
11
12 Encode(prost::EncodeError),
14}
15
16pub fn shape_inference(model: &ModelProto) -> Result<ModelProto, Error> {
18 let mut body = Vec::new();
19 model.encode(&mut body).map_err(|e| Error::Encode(e))?;
20 let inferred = shape_inference_proto(body.as_slice());
21 ModelProto::decode(inferred.as_slice()).map_err(|e| Error::Decode(e))
22}