1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
//! ONNX node attribute helpers.

use onnx_pb::{attribute_proto::AttributeType, AttributeProto, GraphProto, TensorProto};

use crate::nodes::Axes;

/// Value of an ONNX node attribute.
///
/// Each variant carries the payload for one attribute kind. Common Rust types
/// (`f32`, `i64`, `bool`, `&str`, `Vec<…>`, tensor/graph protos, …) convert
/// into this type through `From`, so callers can hand plain values to
/// `make_attribute`.
#[derive(Debug, Clone, PartialEq)]
pub enum Attribute {
    /// A single 32-bit float.
    Float(f32),
    /// A list of 32-bit floats.
    Floats(Vec<f32>),
    /// A single 64-bit integer; `bool` values convert to `1` / `0`.
    Int(i64),
    /// A list of 64-bit integers; axis lists convert into this variant.
    Ints(Vec<i64>),
    /// Raw bytes, stored in the proto's `s` field under the STRING type.
    Bytes(Vec<u8>),
    /// A single string, stored in the proto's `s` field under the STRING type.
    String(String),
    /// A list of strings.
    Strings(Vec<String>),
    /// A single tensor.
    Tensor(TensorProto),
    /// A list of tensors.
    Tensors(Vec<TensorProto>),
    /// A single graph.
    Graph(GraphProto),
    /// A list of graphs.
    Graphs(Vec<GraphProto>),
}

// Generates `impl From<$ty> for Attribute` that moves the value, unchanged,
// into the `Attribute::$variant` tuple variant.
macro_rules! attr_converter {
    ($variant:ident, $ty:ty) => {
        impl From<$ty> for Attribute {
            fn from(value: $ty) -> Self {
                Self::$variant(value)
            }
        }
    };
}

impl From<bool> for Attribute {
    fn from(v: bool) -> Self {
        Attribute::Int(if v { 1 } else { 0 })
    }
}

impl From<Axes> for Attribute {
    fn from(axes: Axes) -> Self {
        Attribute::Ints(axes.0)
    }
}

// One direct `From` conversion per `Attribute` variant, in declaration order.
attr_converter!(Float, f32);
attr_converter!(Floats, Vec<f32>);
attr_converter!(Int, i64);
attr_converter!(Ints, Vec<i64>);
attr_converter!(Bytes, Vec<u8>);
attr_converter!(String, String);
attr_converter!(Strings, Vec<String>);
attr_converter!(Tensor, TensorProto);
attr_converter!(Tensors, Vec<TensorProto>);
attr_converter!(Graph, GraphProto);
attr_converter!(Graphs, Vec<GraphProto>);

impl From<&str> for Attribute {
    /// Copies a borrowed string slice into an owned [`Attribute::String`].
    fn from(text: &str) -> Self {
        Self::String(String::from(text))
    }
}

impl From<Vec<&str>> for Attribute {
    /// Copies each borrowed slice into an owned `String`, producing an
    /// [`Attribute::Strings`] in the same order.
    fn from(items: Vec<&str>) -> Self {
        let owned: Vec<String> = items.into_iter().map(String::from).collect();
        Self::Strings(owned)
    }
}

/// Builds an `AttributeProto` called `name` from any value convertible into
/// an [`Attribute`].
///
/// The payload is moved into the proto field that matches its variant
/// (`f`, `floats`, `i`, `ints`, `s`, `strings`, `t`, `tensors`, `g`, `graphs`)
/// and `type` is set to the corresponding `AttributeType`. Every other field
/// keeps its default value.
pub(crate) fn make_attribute<S: Into<String>, A: Into<Attribute>>(
    name: S,
    attribute: A,
) -> AttributeProto {
    let base = AttributeProto {
        name: name.into(),
        ..AttributeProto::default()
    };
    let (kind, proto) = match attribute.into() {
        Attribute::Float(f) => (AttributeType::Float, AttributeProto { f, ..base }),
        Attribute::Floats(floats) => (AttributeType::Floats, AttributeProto { floats, ..base }),
        Attribute::Int(i) => (AttributeType::Int, AttributeProto { i, ..base }),
        Attribute::Ints(ints) => (AttributeType::Ints, AttributeProto { ints, ..base }),
        // Raw bytes share the `s` slot and STRING type with text strings.
        Attribute::Bytes(s) => (AttributeType::String, AttributeProto { s, ..base }),
        Attribute::String(text) => (
            AttributeType::String,
            AttributeProto {
                s: text.into(),
                ..base
            },
        ),
        Attribute::Strings(texts) => (
            AttributeType::Strings,
            AttributeProto {
                strings: texts.into_iter().map(Into::into).collect(),
                ..base
            },
        ),
        Attribute::Graph(graph) => (
            AttributeType::Graph,
            AttributeProto {
                g: Some(graph),
                ..base
            },
        ),
        Attribute::Graphs(graphs) => (AttributeType::Graphs, AttributeProto { graphs, ..base }),
        Attribute::Tensor(tensor) => (
            AttributeType::Tensor,
            AttributeProto {
                t: Some(tensor),
                ..base
            },
        ),
        Attribute::Tensors(tensors) => {
            (AttributeType::Tensors, AttributeProto { tensors, ..base })
        }
    };
    AttributeProto {
        r#type: kind as i32,
        ..proto
    }
}