Skip to main content

tract_nnef/ops/core/
downsample.rs

1use crate::internal::*;
2use crate::ser::*;
3use tract_core::ops::Downsample;
4
5pub fn register(registry: &mut Registry) {
6    registry.register_dumper(ser_downsample);
7    registry.register_primitive(
8        "tract_core_downsample",
9        &[
10            TypeName::Scalar.tensor().named("input"),
11            TypeName::Integer.named("axis"),
12            TypeName::Integer.named("stride"),
13            TypeName::Integer.named("modulo").default(0),
14        ],
15        &[("output", TypeName::Scalar.tensor())],
16        de_downsample,
17    );
18}
19
20fn ser_downsample(
21    ast: &mut IntoAst,
22    node: &TypedNode,
23    op: &Downsample,
24) -> TractResult<Option<Arc<RValue>>> {
25    let wire = ast.mapping[&node.inputs[0]].clone();
26    Ok(Some(invocation(
27        "tract_core_downsample",
28        &[wire],
29        &[
30            ("axis", numeric(op.axis)),
31            ("stride", numeric(op.stride)),
32            ("modulo", numeric(op.modulo)),
33        ],
34    )))
35}
36
37fn de_downsample(
38    builder: &mut ModelBuilder,
39    invocation: &ResolvedInvocation,
40) -> TractResult<Value> {
41    let wire = invocation.named_arg_as(builder, "input")?;
42    let axis = invocation.named_arg_as(builder, "axis")?;
43    let stride = invocation.named_arg_as::<i64>(builder, "stride")? as isize;
44    let modulo = invocation.named_arg_as(builder, "modulo")?;
45    builder.wire(Downsample { axis, stride, modulo }, &[wire])
46}