1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
use tract_hir::ops::element_wise::ElementWiseOp; use tract_nnef::internal::*; pub fn tract_nnef_onnx_registry() -> Registry { let mut registry: Registry = Registry::new("tract_onnx"); macro_rules! dumper { ($op:ty, $path: path) => { registry.register_dumper(TypeId::of::<$op>(), |ast, node| { $path(ast, node, node.op().downcast_ref::<$op>().unwrap()) }) }; }; dumper!(crate::ops::nn::lrn::Lrn, lrn_dump); registry.register_primitive("tract_onnx_lrn", &lrn_parameters(), lrn_load); registry.register_element_wise( "tract_onnx_isinf", TypeId::of::<crate::ops::math::IsInf>(), isinf_dump, isinf_parameters(), isinf_load, ); registry.register_unit_element_wise("tract_onnx_erf", &crate::ops::math::Erf {}); registry.register_unit_element_wise("tract_onnx_is_nan", &crate::ops::math::IsNan {}); registry } pub fn lrn_parameters() -> Vec<Parameter> { vec![ TypeName::Scalar.tensor().named("input"), TypeName::Scalar.named("alpha").default(0.0001), TypeName::Scalar.named("beta").default(0.75), TypeName::Scalar.named("bias").default(1.0), TypeName::Integer.named("size"), ] } pub fn lrn_dump( ast: &mut IntoAst, node: &TypedNode, lrn: &crate::ops::nn::lrn::Lrn, ) -> TractResult<Option<Arc<RValue>>> { let input = ast.mapping[&node.inputs[0]].clone(); Ok(Some(invocation( "tract_onnx_lrn", &[input], &[ ("alpha", numeric(lrn.alpha)), ("beta", numeric(lrn.beta)), ("bias", numeric(lrn.bias)), ("size", numeric(lrn.size)), ], ))) } pub fn lrn_load( builder: &mut ModelBuilder, invocation: &ResolvedInvocation, ) -> TractResult<TVec<OutletId>> { let input = invocation.named_arg_as(builder, "input")?; let alpha = invocation.named_arg_as(builder, "alpha")?; let beta = invocation.named_arg_as(builder, "beta")?; let bias = invocation.named_arg_as(builder, "bias")?; let size = invocation.named_arg_as(builder, "size")?; let op = crate::ops::nn::lrn::Lrn { alpha, beta, bias, size }; builder.wire(op, &[input]) } pub fn isinf_parameters() -> Vec<Parameter> { vec![ TypeName::Scalar.tensor().named("input"), TypeName::Logical.named("detect_positive").default(true), TypeName::Logical.named("detect_negative").default(true), ] } pub fn isinf_dump(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> { let op = node.op_as::<ElementWiseOp>().unwrap().0.downcast_ref::<crate::ops::math::IsInf>().unwrap(); let input = ast.mapping[&node.inputs[0]].clone(); Ok(Some(invocation( "tract_onnx_isinf", &[input], &[ ("detect_negative", logical(op.detect_negative)), ("detect_positive", logical(op.detect_positive)), ], ))) } pub fn isinf_load( builder: &mut ModelBuilder, invocation: &ResolvedInvocation, ) -> TractResult<TVec<OutletId>> { let input = invocation.named_arg_as(builder, "input")?; let detect_positive = invocation.named_arg_as(builder, "detect_positive")?; let detect_negative = invocation.named_arg_as(builder, "detect_negative")?; let op = crate::ops::math::IsInf { detect_negative, detect_positive }; builder.wire(ElementWiseOp(Box::new(op)), &[input]) }