1use crate::model::OnnxOpRegister;
2use crate::model::ParsingContext;
3use crate::pb::*;
4use tract_hir::internal::*;
5use tract_hir::ops;
6use tract_hir::ops::binary::Nary;
7
8mod clip;
9mod gemm;
10mod mat_mul_integer;
11mod pow;
12mod rem;
13
14pub fn register_all_ops(reg: &mut OnnxOpRegister) {
15 reg.insert("Add", |_, _| Ok((ops::math::Add.into_hir(), vec![])));
16 reg.insert("Sub", |_, _| Ok((ops::math::Sub.into_hir(), vec![])));
17 reg.insert("Mul", |_, _| Ok((ops::math::Mul.into_hir(), vec![])));
18 reg.insert("Div", |_, _| Ok((ops::math::Div.into_hir(), vec![])));
19 reg.insert("Mod", rem::rem);
20
21 reg.insert("BitShift", bitshift);
22 reg.insert("BitwiseAnd", |_, _| Ok((ops::logic::BitAnd.into_hir(), vec![])));
23 reg.insert("BitwiseOr", |_, _| Ok((ops::logic::BitOr.into_hir(), vec![])));
24 reg.insert("BitwiseXor", |_, _| Ok((ops::logic::BitXor.into_hir(), vec![])));
25 reg.insert("BitwiseNot", |_, _| Ok((ops::logic::bitnot().into_hir(), vec![])));
26
27 reg.insert("Sum", |_, _| Ok((Box::new(Nary(Box::new(ops::math::Add), false)), vec![])));
28 reg.insert("Max", |_, _| Ok((Box::new(Nary(Box::new(ops::math::Max), false)), vec![])));
29 reg.insert("Min", |_, _| Ok((Box::new(Nary(Box::new(ops::math::Min), false)), vec![])));
30 reg.insert("Mean", |_, _| Ok((Box::new(Nary(Box::new(ops::math::Add), true)), vec![])));
31
32 reg.insert("Abs", |_, _| Ok((ops::math::abs().into_hir(), vec![])));
33 reg.insert("Ceil", |_, _| Ok((ops::math::ceil().into_hir(), vec![])));
34 reg.insert("Floor", |_, _| Ok((ops::math::floor().into_hir(), vec![])));
35 reg.insert("Round", |_, _| Ok((ops::math::round_half_to_even().into_hir(), vec![])));
36 reg.insert("Clip", clip::clip);
37
38 reg.insert("Cos", |_, _| Ok((ops::math::cos().into_hir(), vec![])));
39 reg.insert("Sin", |_, _| Ok((ops::math::sin().into_hir(), vec![])));
40 reg.insert("Tan", |_, _| Ok((ops::math::tan().into_hir(), vec![])));
41 reg.insert("Acos", |_, _| Ok((ops::math::acos().into_hir(), vec![])));
42 reg.insert("Asin", |_, _| Ok((ops::math::asin().into_hir(), vec![])));
43 reg.insert("Atan", |_, _| Ok((ops::math::atan().into_hir(), vec![])));
44
45 reg.insert("Cosh", |_, _| Ok((ops::math::cosh().into_hir(), vec![])));
46 reg.insert("Sinh", |_, _| Ok((ops::math::sinh().into_hir(), vec![])));
47 reg.insert("Tanh", |_, _| Ok((ops::math::tanh().into_hir(), vec![])));
48 reg.insert("Acosh", |_, _| Ok((ops::math::acosh().into_hir(), vec![])));
49 reg.insert("Asinh", |_, _| Ok((ops::math::asinh().into_hir(), vec![])));
50 reg.insert("Atanh", |_, _| Ok((ops::math::atanh().into_hir(), vec![])));
51
52 reg.insert("Erf", |_, _| Ok((ops::math::erf().into_hir(), vec![])));
53 reg.insert("Exp", |_, _| Ok((ops::math::exp().into_hir(), vec![])));
54 reg.insert("Log", |_, _| Ok((ops::math::ln().into_hir(), vec![])));
55 reg.insert("Sqrt", |_, _| Ok((ops::math::sqrt().into_hir(), vec![])));
56 reg.insert("Rsqrt", |_, _| Ok((ops::math::rsqrt().into_hir(), vec![])));
57
58 reg.insert("IsNaN", |_, _| Ok((tract_onnx_opl::is_nan::is_nan().into_hir(), vec![])));
59 reg.insert("IsInf", isinf);
60 reg.insert("Neg", |_, _| Ok((ops::math::neg().into_hir(), vec![])));
61 reg.insert("Sign", |_, _| Ok((ops::math::sign().into_hir(), vec![])));
62 reg.insert("Reciprocal", |_, _| Ok((ops::math::recip().into_hir(), vec![])));
63
64 reg.insert("Pow", pow::pow);
65
66 reg.insert("MatMul", |_, _| Ok((expand(ops::matmul::MatMulInference::default()), vec![])));
67 reg.insert("MatMulInteger", mat_mul_integer::mat_mul_integer);
68 reg.insert("QLinearMatMul", mat_mul_integer::q_linear_mat_mul);
69 reg.insert("Gemm", gemm::gemm);
70}
71
72fn isinf(
73 _ctx: &ParsingContext,
74 node: &NodeProto,
75) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
76 let detect_positive = node.get_attr_opt("detect_positive")?.unwrap_or(1) != 0;
77 let detect_negative = node.get_attr_opt("detect_negative")?.unwrap_or(1) != 0;
78 Ok((tract_onnx_opl::is_inf::is_inf(detect_positive, detect_negative).into_hir(), vec![]))
79}
80
81fn bitshift(
82 _ctx: &ParsingContext,
83 node: &NodeProto,
84) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
85 let op: Box<dyn InferenceOp> = if node.get_attr_opt("direction")?.unwrap_or("LEFT") == "RIGHT" {
86 ops::math::ShiftRight.into_hir()
87 } else {
88 ops::math::ShiftLeft.into_hir()
89 };
90 Ok((op, vec![]))
91}