// onnx_helpers/builder/value.rs

//! Value info builder for ONNX `ValueInfoProto` messages.
2
3use onnx_pb::{
4    tensor_proto::DataType,
5    tensor_shape_proto::Dimension,
6    type_proto::{self, Tensor},
7    TensorShapeProto, TypeProto, ValueInfoProto,
8};
9
10use crate::{
11    builder::{Bag, Marker, Node},
12    nodes,
13};
14
15/// Value info builder.
16#[derive(Default, Clone)]
17pub struct Value {
18    name: String,
19    elem_type: DataType,
20    shape: Vec<Dimension>,
21    doc_string: Option<String>,
22    pub(crate) bag: Option<Bag>,
23    pub(crate) marker: Option<Marker>,
24}
25
26impl Value {
27    /// Creates a new builder.
28    #[inline]
29    pub fn new<S: Into<String>>(name: S) -> Self {
30        Value {
31            name: name.into(),
32            ..Value::default()
33        }
34    }
35
36    /// Sets value name.
37    #[inline]
38    pub fn name<S: Into<String>>(mut self, name: S) -> Self {
39        self.name = name.into();
40        self
41    }
42
43    /// Sets value element type.
44    #[inline]
45    pub fn typed<T: Into<DataType>>(mut self, elem_type: T) -> Self {
46        self.elem_type = elem_type.into();
47        self
48    }
49
50    /// Sets value shape.
51    #[inline]
52    pub fn shape<D: Into<Dimension>>(mut self, shape: Vec<D>) -> Self {
53        self.shape = shape.into_iter().map(|dim| dim.into()).collect();
54        self
55    }
56
57    /// Inserts value dimension.
58    #[inline]
59    pub fn dim<D: Into<Dimension>>(mut self, dim: D) -> Self {
60        self.shape.push(dim.into());
61        self
62    }
63
64    /// Creates node for input.
65    /// Requires builder to be bagged.
66    #[inline]
67    pub fn node(self) -> nodes::Node {
68        let mut node = Node::named(self.name.clone()).build();
69        node.bag = self.bag.clone();
70        let marker = self.marker.as_ref().unwrap().clone();
71        let mut bag: Bag = self.bag.as_ref().unwrap().clone();
72        let value = self.build();
73        bag.value(value, marker);
74        node
75    }
76
77    /// Builds the value info.
78    #[inline]
79    pub fn build(self) -> ValueInfoProto {
80        ValueInfoProto {
81            name: self.name,
82            r#type: Some(TypeProto {
83                denotation: String::default(),
84                value: Some(type_proto::Value::TensorType(Tensor {
85                    shape: Some(TensorShapeProto { dim: self.shape }),
86                    elem_type: self.elem_type as i32,
87                })),
88            }),
89            doc_string: self.doc_string.unwrap_or_default(),
90        }
91    }
92}
93
94impl Into<ValueInfoProto> for Value {
95    fn into(self) -> ValueInfoProto {
96        self.build()
97    }
98}