tract_nnef/ops/core/
submodel.rs1use tract_core::ops::submodel::SubmodelOp;
2
3use crate::internal::*;
4
5pub fn register(registry: &mut Registry) {
6 registry.register_dumper(ser_submodel);
7 registry.register_primitive(
8 "tract_core_submodel",
9 &[TypeName::Scalar.tensor().array().named("input"), TypeName::String.named("label")],
10 &[("outputs", TypeName::Any.tensor().array())],
11 de_submodel,
12 );
13}
14
15fn de_submodel(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
16 let wires: TVec<OutletId> = invocation.named_arg_as(builder, "input")?;
17 let label: String = invocation.named_arg_as(builder, "label")?;
18 let model: TypedModel = builder
19 .proto_model
20 .resources
21 .get(label.as_str())
22 .with_context(|| anyhow!("{} not found in model builder loaded resources", label.as_str()))?
23 .clone()
24 .downcast_arc::<TypedModelResource>()
25 .map_err(|_| anyhow!("Error while downcasting typed model resource"))
26 .map(|r| r.0.clone())
27 .with_context(|| anyhow!("Error while loading typed model resource"))?;
28
29 let op: Box<dyn TypedOp> = Box::new(SubmodelOp::new(Box::new(model), &label)?);
30
31 builder.model.wire_node(label, op, &wires).map(Value::from)
32}
33
34fn ser_submodel(
35 ast: &mut IntoAst,
36 node: &TypedNode,
37 op: &SubmodelOp,
38) -> TractResult<Option<Arc<RValue>>> {
39 let input = tvec![ast.mapping[&node.inputs[0]].clone()];
40 let invoke = invocation("tract_core_submodel", &input, &[("label", string(op.label()))]);
41 ast.resources.insert(op.label().to_string(), Arc::new(TypedModelResource(op.model().clone())));
42 Ok(Some(invoke))
43}