1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
use crate::internal::*;
use crate::ser::*;
use tract_core::ops;

/// Hooks the `ops::Downsample` operator into `registry`.
///
/// Two things are registered:
/// * `ser_downsample` as the dumper for the `TypeId` of `ops::Downsample`;
/// * the `tract_core_downsample` primitive, deserialized by `de_downsample`,
///   declared with a scalar tensor `input`, integer `axis` and `stride`, and
///   an integer `modulo` that defaults to `0`.
pub fn register(registry: &mut Registry) {
    registry.register_dumper(TypeId::of::<ops::Downsample>(), ser_downsample);

    let parameters = [
        TypeName::Scalar.tensor().named("input"),
        TypeName::Integer.named("axis"),
        TypeName::Integer.named("stride"),
        TypeName::Integer.named("modulo").default(0),
    ];
    registry.register_primitive("tract_core_downsample", &parameters, de_downsample);
}

/// Serializes a node holding an `ops::Downsample` as a
/// `tract_core_downsample` invocation.
///
/// The invocation takes the wire mapped to the node's first input as its only
/// positional argument, and the op's `axis`, `stride` and `modulo` fields as
/// named arguments.
///
/// # Panics
///
/// Panics if the node's op is not an `ops::Downsample` (the downcast is
/// unwrapped). Indexing `node.inputs[0]` and `ast.mapping[..]` may also panic
/// if the node has no input or the input has no mapped wire.
fn ser_downsample(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
    let downsample = node.op().downcast_ref::<ops::Downsample>().unwrap();
    let input = ast.mapping[&node.inputs[0]].clone();

    let attributes = [
        ("axis", numeric(downsample.axis)),
        ("stride", numeric(downsample.stride)),
        ("modulo", numeric(downsample.modulo)),
    ];
    let call = invocation("tract_core_downsample", &[input], &attributes);
    Ok(Some(call))
}

/// Builds an `ops::Downsample` from a resolved `tract_core_downsample`
/// invocation and wires it into the model under construction.
///
/// The named arguments `input`, `axis`, `stride` and `modulo` are read in that
/// order; the first one that fails to resolve is the error returned.
fn de_downsample(
    builder: &mut ModelBuilder,
    invocation: &ResolvedInvocation,
) -> TractResult<TVec<OutletId>> {
    let input = invocation.named_arg_as(builder, "input")?;
    let axis = invocation.named_arg_as(builder, "axis")?;
    // `stride` is read as an i64 and narrowed with `as`, so it is signed and
    // would wrap silently on targets where `isize` is narrower than 64 bits.
    let stride: i64 = invocation.named_arg_as(builder, "stride")?;
    let modulo = invocation.named_arg_as(builder, "modulo")?;

    let op = ops::Downsample { axis, stride: stride as isize, modulo };
    builder.wire(op, &[input])
}