tract_nnef/ops/core/
downsample.rs1use crate::internal::*;
2use crate::ser::*;
3use tract_core::ops::Downsample;
4
5pub fn register(registry: &mut Registry) {
6 registry.register_dumper(ser_downsample);
7 registry.register_primitive(
8 "tract_core_downsample",
9 &[
10 TypeName::Scalar.tensor().named("input"),
11 TypeName::Integer.named("axis"),
12 TypeName::Integer.named("stride"),
13 TypeName::Integer.named("modulo").default(0),
14 ],
15 &[("output", TypeName::Scalar.tensor())],
16 de_downsample,
17 );
18}
19
20fn ser_downsample(
21 ast: &mut IntoAst,
22 node: &TypedNode,
23 op: &Downsample,
24) -> TractResult<Option<Arc<RValue>>> {
25 let wire = ast.mapping[&node.inputs[0]].clone();
26 Ok(Some(invocation(
27 "tract_core_downsample",
28 &[wire],
29 &[
30 ("axis", numeric(op.axis)),
31 ("stride", numeric(op.stride)),
32 ("modulo", numeric(op.modulo)),
33 ],
34 )))
35}
36
37fn de_downsample(
38 builder: &mut ModelBuilder,
39 invocation: &ResolvedInvocation,
40) -> TractResult<Value> {
41 let wire = invocation.named_arg_as(builder, "input")?;
42 let axis = invocation.named_arg_as(builder, "axis")?;
43 let stride = invocation.named_arg_as::<i64>(builder, "stride")? as isize;
44 let modulo = invocation.named_arg_as(builder, "modulo")?;
45 builder.wire(Downsample { axis, stride, modulo }, &[wire])
46}