Skip to main content

tract_onnx/ops/array/
mod.rs

1mod compress;
2mod nonzero;
3mod one_hot;
4mod pad;
5mod shape;
6mod slice;
7mod split;
8mod squeeze;
9mod topk;
10mod trilu;
11mod unsqueeze;
12
13use tract_hir::internal::*;
14use tract_hir::ops::array;
15
16use crate::model::{OnnxOpRegister, ParsingContext};
17use crate::pb::*;
18
19pub fn register_all_ops(reg: &mut OnnxOpRegister) {
20    reg.insert("ArrayFeatureExtractor", array_feature_extractor);
21    reg.insert("Compress", compress::compress);
22    reg.insert("Concat", concat);
23    reg.insert("ConstantLike", constant_like);
24    reg.insert("ConstantOfShape", constant_of_shape);
25    reg.insert("Expand", |_, _| Ok((expand(array::MultiBroadcastTo), vec![])));
26    reg.insert("EyeLike", eye_like);
27    reg.insert("Flatten", flatten);
28    reg.insert("Gather", gather);
29    reg.insert("GatherElements", gather_elements);
30    reg.insert("GatherND", gather_nd);
31    reg.insert("NonZero", nonzero::non_zero);
32    reg.insert("OneHot", one_hot::one_hot);
33    reg.insert("Range", |_, _| Ok((expand(array::Range), vec![])));
34    reg.insert("Pad", pad::pad);
35    reg.insert("Reshape", |_, _| Ok((expand(array::Reshape::default()), vec![])));
36    reg.insert("Scatter", scatter_elements);
37    reg.insert("ScatterElements", scatter_elements);
38    reg.insert("ScatterND", |_, _| Ok((Box::new(array::ScatterNd), vec![])));
39    reg.insert("Shape", shape::shape);
40    reg.insert("Size", |_, _| Ok((expand(array::Size::new(DatumType::TDim)), vec![])));
41    reg.insert("Slice", slice::slice);
42    reg.insert("Split", split::split);
43    reg.insert("Squeeze", squeeze::squeeze);
44    reg.insert("Tile", |_, _| Ok((expand(array::Tile), vec![])));
45    reg.insert("TopK", topk::topk);
46    reg.insert("Transpose", transpose);
47    reg.insert("Trilu", trilu::trilu);
48    reg.insert("Unsqueeze", unsqueeze::unsqueeze);
49}
50
51pub fn array_feature_extractor(
52    _ctx: &ParsingContext,
53    _node: &NodeProto,
54) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
55    Ok((expand(array::ArrayFeatureExtractor), vec![]))
56}
57
58pub fn concat(
59    _ctx: &ParsingContext,
60    node: &NodeProto,
61) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
62    let axis = node.get_attr("axis")?;
63    Ok((expand(array::Concat::new(axis)), vec![]))
64}
65
66pub fn constant_like(
67    _ctx: &ParsingContext,
68    node: &NodeProto,
69) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
70    let value = node.get_attr_opt("value")?.unwrap_or(0.);
71    if node.input.len() == 0 {
72        let dt = node.get_attr_opt("dtype")?.unwrap_or(DatumType::F32);
73        let shape: Vec<usize> = node.get_attr_vec("shape")?;
74        let tensor =
75            tensor0(value).cast_to_dt(dt)?.broadcast_scalar_to_shape(&shape)?.into_arc_tensor();
76        Ok((Box::new(tract_hir::ops::konst::Const::new(tensor)?), vec![]))
77    } else {
78        Ok((Box::new(array::ConstantLike::new(value)), vec![]))
79    }
80}
81
82pub fn constant_of_shape(
83    ctx: &ParsingContext,
84    node: &NodeProto,
85) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
86    let mut value = match node.get_attr_opt("value")? {
87        Some(val) => ctx.load_tensor(val)?.into_arc_tensor(),
88        None => rctensor0(0.0),
89    };
90    if value.rank() > 0 {
91        if value.len() != 1 {
92            bail!("Expected scalar (or vector of length 1), got {:?}", value);
93        }
94        value = value.nth(0)?.into_arc_tensor();
95    }
96    Ok((expand(array::ConstantOfShape::new(value)), vec![]))
97}
98
99pub fn eye_like(
100    _ctx: &ParsingContext,
101    node: &NodeProto,
102) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
103    let dt = node.get_attr_opt("dtype")?;
104    let k = node.get_attr_opt("k")?.unwrap_or(0);
105    Ok((Box::new(array::EyeLike::new(dt, k)), vec![]))
106}
107
108pub fn flatten(
109    _ctx: &ParsingContext,
110    node: &NodeProto,
111) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
112    let axis: i64 = node.get_attr_opt("axis")?.unwrap_or(1);
113    Ok((expand(array::Flatten::new(axis)), vec![]))
114}
115
116pub fn gather(
117    _ctx: &ParsingContext,
118    node: &NodeProto,
119) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
120    let axis = node.get_attr_opt("axis")?.unwrap_or(0);
121    Ok((expand(array::Gather::new(axis)), vec![]))
122}
123
124pub fn gather_elements(
125    _ctx: &ParsingContext,
126    node: &NodeProto,
127) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
128    let axis = node.get_attr_opt("axis")?.unwrap_or(0);
129    Ok((expand(array::GatherElements::new(axis)), vec![]))
130}
131
132pub fn gather_nd(
133    _ctx: &ParsingContext,
134    node: &NodeProto,
135) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
136    let batch_dims = node.get_attr_opt("batch_dims")?.unwrap_or(0);
137    Ok((Box::new(array::GatherNd::new(batch_dims)), vec![]))
138}
139
140pub fn scatter_elements(
141    _ctx: &ParsingContext,
142    node: &NodeProto,
143) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
144    let axis = node.get_attr_opt("axis")?.unwrap_or(0);
145    Ok((expand(array::ScatterElements::new(axis)), vec![]))
146}
147
148pub fn transpose(
149    _ctx: &ParsingContext,
150    node: &NodeProto,
151) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
152    let perm = node.get_attr_opt_vec("perm")?;
153    Ok((expand(array::PermuteAxes::new(perm.map(|t| t.into()))), vec![]))
154}