1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use onnx_pb::{attribute_proto::AttributeType, AttributeProto, GraphProto, TensorProto};
use crate::nodes::Axes;
/// A typed ONNX attribute value.
///
/// Each variant carries the payload for one kind of ONNX attribute. Values are
/// lowered into an `AttributeProto` by `make_attribute`; note that both `Bytes`
/// and `String` end up in the proto's `s` field and are tagged as a string
/// attribute.
///
/// `Debug`, `Clone` and `PartialEq` are derived so attributes can be logged,
/// duplicated and compared. The protobuf payload types are expected to provide
/// these impls (prost messages do).
#[derive(Debug, Clone, PartialEq)]
pub enum Attribute {
    /// A single 32-bit float.
    Float(f32),
    /// A list of 32-bit floats.
    Floats(Vec<f32>),
    /// A single 64-bit integer (also used for booleans: 1 / 0).
    Int(i64),
    /// A list of 64-bit integers.
    Ints(Vec<i64>),
    /// Raw bytes, stored like a string attribute.
    Bytes(Vec<u8>),
    /// A UTF-8 string.
    String(String),
    /// A list of UTF-8 strings.
    Strings(Vec<String>),
    /// A single tensor.
    Tensor(TensorProto),
    /// A list of tensors.
    Tensors(Vec<TensorProto>),
    /// A single sub-graph.
    Graph(GraphProto),
    /// A list of sub-graphs.
    Graphs(Vec<GraphProto>),
}
/// Generates `impl From<$ty> for Attribute` that wraps the value in
/// `Attribute::$variant` unchanged.
macro_rules! attr_converter {
    ($variant:ident, $ty:ty) => {
        impl From<$ty> for Attribute {
            fn from(value: $ty) -> Self {
                Self::$variant(value)
            }
        }
    };
}
impl From<bool> for Attribute {
fn from(v: bool) -> Self {
Attribute::Int(if v { 1 } else { 0 })
}
}
impl From<Axes> for Attribute {
fn from(axes: Axes) -> Self {
Attribute::Ints(axes.0)
}
}
// One identity `From` impl per `Attribute` variant, listed in declaration order.
attr_converter!(Float, f32);
attr_converter!(Floats, Vec<f32>);
attr_converter!(Int, i64);
attr_converter!(Ints, Vec<i64>);
attr_converter!(Bytes, Vec<u8>);
attr_converter!(String, String);
attr_converter!(Strings, Vec<String>);
attr_converter!(Tensor, TensorProto);
attr_converter!(Tensors, Vec<TensorProto>);
attr_converter!(Graph, GraphProto);
attr_converter!(Graphs, Vec<GraphProto>);
/// Copies a string slice into an owned `String` attribute.
impl From<&str> for Attribute {
    fn from(text: &str) -> Self {
        Self::String(text.to_owned())
    }
}
/// Copies each string slice into an owned `String`, producing a `Strings`
/// attribute with the same element order.
impl From<Vec<&str>> for Attribute {
    fn from(items: Vec<&str>) -> Self {
        let owned: Vec<String> = items.into_iter().map(String::from).collect();
        Self::Strings(owned)
    }
}
/// Builds an `AttributeProto` named `name` from any value convertible into an
/// [`Attribute`].
///
/// Exactly one payload field is populated (`f`, `floats`, `i`, `ints`, `s`,
/// `strings`, `g`, `graphs`, `t` or `tensors`). `type` is set to the matching
/// `AttributeType` discriminant, and every other field keeps its default value.
/// `Attribute::Bytes` and `Attribute::String` both land in `s` and are tagged as
/// `AttributeType::String`.
pub(crate) fn make_attribute<S: Into<String>, A: Into<Attribute>>(
    name: S,
    attribute: A,
) -> AttributeProto {
    // Start from a default proto that only carries the name; each arm below
    // fills in its payload field and type tag on top of it.
    let base = AttributeProto {
        name: name.into(),
        ..AttributeProto::default()
    };
    match attribute.into() {
        Attribute::Float(f) => AttributeProto {
            f,
            r#type: AttributeType::Float as i32,
            ..base
        },
        Attribute::Floats(floats) => AttributeProto {
            floats,
            r#type: AttributeType::Floats as i32,
            ..base
        },
        Attribute::Int(i) => AttributeProto {
            i,
            r#type: AttributeType::Int as i32,
            ..base
        },
        Attribute::Ints(ints) => AttributeProto {
            ints,
            r#type: AttributeType::Ints as i32,
            ..base
        },
        // ONNX has no dedicated bytes type; raw bytes travel in the string slot.
        Attribute::Bytes(raw) => AttributeProto {
            s: raw,
            r#type: AttributeType::String as i32,
            ..base
        },
        Attribute::String(text) => AttributeProto {
            s: text.into(),
            r#type: AttributeType::String as i32,
            ..base
        },
        Attribute::Strings(list) => AttributeProto {
            strings: list.into_iter().map(Into::into).collect(),
            r#type: AttributeType::Strings as i32,
            ..base
        },
        Attribute::Graph(graph) => AttributeProto {
            g: Some(graph),
            r#type: AttributeType::Graph as i32,
            ..base
        },
        Attribute::Graphs(graphs) => AttributeProto {
            graphs,
            r#type: AttributeType::Graphs as i32,
            ..base
        },
        Attribute::Tensor(tensor) => AttributeProto {
            t: Some(tensor),
            r#type: AttributeType::Tensor as i32,
            ..base
        },
        Attribute::Tensors(tensors) => AttributeProto {
            tensors,
            r#type: AttributeType::Tensors as i32,
            ..base
        },
    }
}