1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use tract_hir::ops::element_wise::ElementWiseOp;
use tract_nnef::internal::*;

pub fn tract_nnef_onnx_registry() -> Registry {
    let mut registry: Registry = Registry::new("tract_onnx");
    macro_rules! dumper {
        ($op:ty, $path: path) => {
            registry.register_dumper(TypeId::of::<$op>(), |ast, node| {
                $path(ast, node, node.op().downcast_ref::<$op>().unwrap())
            })
        };
    };
    dumper!(crate::ops::nn::lrn::Lrn, lrn_dump);
    registry.register_primitive("tract_onnx_lrn", &lrn_parameters(), lrn_load);
    registry.register_element_wise(
        "tract_onnx_isinf",
        TypeId::of::<crate::ops::math::IsInf>(),
        isinf_dump,
        isinf_parameters(),
        isinf_load,
    );
    registry.register_unit_element_wise("tract_onnx_erf", &crate::ops::math::Erf {});
    registry.register_unit_element_wise("tract_onnx_is_nan", &crate::ops::math::IsNan {});
    registry
}

/// Declares the NNEF signature of the `tract_onnx_lrn` primitive.
///
/// Parameters, in order: the `input` tensor, then the scalars `alpha` (default `0.0001`),
/// `beta` (default `0.75`) and `bias` (default `1.0`), and finally the integer `size`,
/// which is declared without a default.
pub fn lrn_parameters() -> Vec<Parameter> {
    let input = TypeName::Scalar.tensor().named("input");
    let alpha = TypeName::Scalar.named("alpha").default(0.0001);
    let beta = TypeName::Scalar.named("beta").default(0.75);
    let bias = TypeName::Scalar.named("bias").default(1.0);
    let size = TypeName::Integer.named("size");
    vec![input, alpha, beta, bias, size]
}

/// Serializes an ONNX `Lrn` node as a `tract_onnx_lrn` NNEF invocation.
///
/// The node's first input is taken from the already-dumped `ast.mapping`. The four
/// hyper-parameters (`alpha`, `beta`, `bias`, `size`) become named numeric arguments.
pub fn lrn_dump(
    ast: &mut IntoAst,
    node: &TypedNode,
    lrn: &crate::ops::nn::lrn::Lrn,
) -> TractResult<Option<Arc<RValue>>> {
    let source = ast.mapping[&node.inputs[0]].clone();
    let named_args = [
        ("alpha", numeric(lrn.alpha)),
        ("beta", numeric(lrn.beta)),
        ("bias", numeric(lrn.bias)),
        ("size", numeric(lrn.size)),
    ];
    let call = invocation("tract_onnx_lrn", &[source], &named_args);
    Ok(Some(call))
}

/// Rebuilds an `Lrn` op from a resolved `tract_onnx_lrn` invocation and wires it
/// onto the model being built.
///
/// Arguments are fetched in this order: `input`, `alpha`, `beta`, `bias`, `size`.
/// Struct-literal fields are evaluated in the order they are written, which keeps this order.
pub fn lrn_load(
    builder: &mut ModelBuilder,
    invocation: &ResolvedInvocation,
) -> TractResult<TVec<OutletId>> {
    let source = invocation.named_arg_as(builder, "input")?;
    let lrn = crate::ops::nn::lrn::Lrn {
        alpha: invocation.named_arg_as(builder, "alpha")?,
        beta: invocation.named_arg_as(builder, "beta")?,
        bias: invocation.named_arg_as(builder, "bias")?,
        size: invocation.named_arg_as(builder, "size")?,
    };
    builder.wire(lrn, &[source])
}

/// Declares the NNEF signature of the `tract_onnx_isinf` element-wise op.
///
/// Parameters, in order: the `input` tensor, then the logical flags `detect_positive`
/// and `detect_negative`, both defaulting to `true`.
pub fn isinf_parameters() -> Vec<Parameter> {
    let flag = |name: &'static str| TypeName::Logical.named(name).default(true);
    vec![
        TypeName::Scalar.tensor().named("input"),
        flag("detect_positive"),
        flag("detect_negative"),
    ]
}

/// Serializes an element-wise `IsInf` node as a `tract_onnx_isinf` NNEF invocation.
///
/// Emits the node's first input, followed by the named flags `detect_negative` and
/// `detect_positive`, in that order.
///
/// # Panics
/// Panics if the node is not an `ElementWiseOp` wrapping an `IsInf`. This dumper is
/// presumably only reached for such nodes, since it is registered against `IsInf`'s `TypeId`.
pub fn isinf_dump(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
    let wrapper = node.op_as::<ElementWiseOp>().unwrap();
    let is_inf = wrapper.0.downcast_ref::<crate::ops::math::IsInf>().unwrap();
    let source = ast.mapping[&node.inputs[0]].clone();
    let flags = [
        ("detect_negative", logical(is_inf.detect_negative)),
        ("detect_positive", logical(is_inf.detect_positive)),
    ];
    Ok(Some(invocation("tract_onnx_isinf", &[source], &flags)))
}

/// Rebuilds an `IsInf` op from a resolved `tract_onnx_isinf` invocation and wires it,
/// wrapped in an `ElementWiseOp`, onto the model being built.
///
/// Arguments are fetched in this order: `input`, `detect_positive`, `detect_negative`.
pub fn isinf_load(
    builder: &mut ModelBuilder,
    invocation: &ResolvedInvocation,
) -> TractResult<TVec<OutletId>> {
    let source = invocation.named_arg_as(builder, "input")?;
    let is_inf = crate::ops::math::IsInf {
        detect_positive: invocation.named_arg_as(builder, "detect_positive")?,
        detect_negative: invocation.named_arg_as(builder, "detect_negative")?,
    };
    builder.wire(ElementWiseOp(Box::new(is_inf)), &[source])
}