tract_pulse_opl/
pad.rs

1use tract_core::ndarray::*;
2use tract_core::ops::array::PadMode;
3use tract_nnef::internal::*;
4use tract_nnef::ser::tdim;
5use tract_nnef::tract_core::ops::OpStateFreeze;
6
7pub fn register(registry: &mut Registry) {
8    registry.register_primitive(
9        "tract_pulse_pulse_pad",
10        &[
11            TypeName::Scalar.tensor().named("input"),
12            TypeName::Integer.named("axis"),
13            TypeName::Integer.named("before"),
14            TypeName::Integer.named("after"),
15            TypeName::Integer.named("begin_input"),
16            TypeName::Integer.named("end_input"),
17            TypeName::String.named("border"),
18            TypeName::Scalar.named("value"),
19            TypeName::Integer.named("overlap"),
20        ],
21        &[("output", TypeName::Scalar.tensor())],
22        deser,
23    );
24    registry.register_dumper(ser)
25}
26
27fn ser(ast: &mut IntoAst, node: &TypedNode, op: &PulsePad) -> TractResult<Option<Arc<RValue>>> {
28    let wire = ast.mapping[&node.inputs[0]].clone();
29    let dt = ast.model.outlet_fact(node.inputs[0])?.datum_type;
30    let (border, value) = tract_nnef::ops::nnef::ser::pad_mode(&op.mode, dt)?;
31    let mut params = vec![
32        ("axis", numeric(op.axis)),
33        ("before", numeric(op.before)),
34        ("begin_input", numeric(op.begin_input)),
35        ("overlap", numeric(op.overlap)),
36        ("after", tdim(&op.after)),
37        ("end_input", tdim(&op.end_input)),
38    ];
39    params.push(("border", string(border)));
40    if let Some(value) = value {
41        params.push(("value", value));
42    }
43    Ok(Some(invocation("tract_pulse_pulse_pad", &[wire], &params)))
44}
45
46fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
47    let wire = invocation.named_arg_as(builder, "input")?;
48    let axis = invocation.named_arg_as(builder, "axis")?;
49    let before = invocation.named_arg_as(builder, "before")?;
50    let begin_input = invocation.named_arg_as(builder, "begin_input")?;
51    let overlap = invocation.named_arg_as(builder, "overlap")?;
52    let border = invocation.named_arg_as::<String>(builder, "border")?;
53    let value: Tensor = tensor0(invocation.named_arg_as::<f32>(builder, "value")?);
54    let (after, end_input) = builder.allowing_new_symbols(|builder| {
55        TractResult::Ok((
56            invocation.named_arg_as(builder, "after")?,
57            invocation.named_arg_as(builder, "end_input")?,
58        ))
59    })?;
60
61    let mode = tract_nnef::ops::nnef::deser::pad_mode(&border, value)?;
62    let op = PulsePad { axis, before, after, begin_input, end_input, mode, overlap };
63    builder.wire(op, &[wire])
64}
65
66pub(crate) unsafe fn fill_slice_constant<T: Datum + Copy>(
67    data: &mut Tensor,
68    constant: &Tensor,
69    axis: usize,
70    range: std::ops::Range<usize>,
71) {
72    unsafe {
73        let c = constant.to_scalar_unchecked::<T>();
74        data.to_array_view_mut_unchecked::<T>().slice_axis_mut(Axis(axis), range.into()).fill(*c);
75    }
76}
77
78unsafe fn fill_slice_with_frame<T: Datum + Copy>(
79    data: &mut Tensor,
80    axis: usize,
81    valid: &Tensor,
82    range: std::ops::Range<usize>,
83) {
84    unsafe {
85        let mut data = data.to_array_view_mut_unchecked::<T>();
86        let valid = valid.to_array_view_unchecked::<T>();
87        for i in range {
88            data.slice_axis_mut(Axis(axis), (i..i + 1).into()).assign(&valid);
89        }
90    }
91}
92
93#[derive(Debug, Clone, Default, Hash)]
94struct PulsePadOpState {
95    current_pos: usize,
96    last_valid_frame: Option<Tensor>,
97}
98
99impl OpState for PulsePadOpState {
100    fn eval(
101        &mut self,
102        session: &mut SessionState,
103        op: &dyn Op,
104        inputs: TVec<TValue>,
105    ) -> TractResult<TVec<TValue>> {
106        let input = args_1!(inputs).into_tensor();
107        let op = op.downcast_ref::<PulsePad>().ok_or_else(|| format_err!("Wrong Op type"))?;
108        let tensor = self.pad(session, op, input)?;
109        Ok(tvec!(tensor.into_tvalue()))
110    }
111}
112
113impl PulsePadOpState {
114    unsafe fn save_frame<T: Datum + Copy>(&mut self, op: &PulsePad, input: &Tensor, frame: usize) {
115        let data = unsafe { input.to_array_view_unchecked::<T>() };
116        self.last_valid_frame =
117            Some(data.index_axis(Axis(op.axis), frame).to_owned().into_tensor());
118    }
119
120    fn pad(
121        &mut self,
122        session: &SessionState,
123        op: &PulsePad,
124        mut input: Tensor,
125    ) -> TractResult<Tensor> {
126        let pulse = input.shape()[op.axis];
127        let pulse_begin = self.current_pos;
128        let pulse_end = self.current_pos + pulse;
129        self.current_pos += pulse - op.overlap;
130        let end_input =
131            op.end_input.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
132        let after = op.after.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
133
134        if let PadMode::Edge = op.mode {
135            if after != 0 && pulse_begin < end_input {
136                let latest_valid_frame = (end_input - pulse_begin).min(pulse) - 1;
137                unsafe {
138                    dispatch_copy_by_size!(Self::save_frame(input.datum_type())(
139                        self,
140                        op,
141                        &input,
142                        latest_valid_frame
143                    ))
144                }
145            }
146        }
147
148        // pulse is entirely in valid input, just forward
149        if pulse_begin >= op.begin_input && pulse_end <= end_input {
150            return Ok(input);
151        }
152        // pulse is entirely before or after output is valid, just forward
153        if pulse_end <= op.begin_input - op.before || pulse_begin >= end_input.saturating_add(after)
154        {
155            return Ok(input);
156        }
157
158        if pulse_begin < op.begin_input {
159            let fill_up_to = (op.begin_input - pulse_begin).min(pulse);
160            match &op.mode {
161                PadMode::Constant(c) => unsafe {
162                    dispatch_copy_by_size!(fill_slice_constant(input.datum_type())(
163                        &mut input,
164                        c,
165                        op.axis,
166                        0..fill_up_to
167                    ))
168                },
169                PadMode::Edge => {
170                    let frame = input.slice(op.axis, fill_up_to, fill_up_to + 1)?;
171                    unsafe {
172                        dispatch_copy_by_size!(fill_slice_with_frame(input.datum_type())(
173                            &mut input,
174                            op.axis,
175                            &frame,
176                            0..fill_up_to
177                        ))
178                    }
179                }
180                _ => unimplemented!(),
181            }
182        }
183        if pulse_end > end_input && after > 0 {
184            let fill_from = pulse - (pulse_end - end_input).min(pulse);
185            match &op.mode {
186                PadMode::Constant(c) => unsafe {
187                    dispatch_copy_by_size!(fill_slice_constant(input.datum_type())(
188                        &mut input,
189                        c,
190                        op.axis,
191                        fill_from..pulse
192                    ))
193                },
194                PadMode::Edge => {
195                    let last_frame = self.last_valid_frame.as_ref().unwrap();
196                    unsafe {
197                        dispatch_copy_by_size!(fill_slice_with_frame(input.datum_type())(
198                            &mut input,
199                            op.axis,
200                            last_frame,
201                            fill_from..pulse
202                        ))
203                    }
204                }
205                _ => unimplemented!(),
206            }
207        }
208
209        Ok(input)
210    }
211}
212
213#[derive(Debug, Clone, Default, Hash)]
214pub struct PulsePad {
215    pub axis: usize,
216    pub before: usize,
217    pub after: TDim,
218    pub begin_input: usize,
219    pub end_input: TDim,
220    pub mode: PadMode,
221    pub overlap: usize,
222}
223
224impl Op for PulsePad {
225    fn name(&self) -> StaticName {
226        "PulsePad".into()
227    }
228
229    fn info(&self) -> TractResult<Vec<String>> {
230        Ok(vec![format!(
231            "Mode: {:?}, axis: {} before: {} after: {}",
232            self.mode, self.axis, self.before, self.after,
233        )])
234    }
235
236    op_as_typed_op!();
237}
238
239impl EvalOp for PulsePad {
240    fn is_stateless(&self) -> bool {
241        false
242    }
243
244    fn state(
245        &self,
246        _session: &mut SessionState,
247        _node_id: usize,
248    ) -> TractResult<Option<Box<dyn OpState>>> {
249        Ok(Some(Box::<PulsePadOpState>::default()))
250    }
251}
252
253impl TypedOp for PulsePad {
254    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
255        Ok(tvec!(inputs[0].clone()))
256    }
257
258    as_op!();
259}
260
261#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
262struct FrozenPulsePadOpState {
263    current_pos: usize,
264    last_valid_frame: Option<Arc<Tensor>>,
265}
266
267impl OpStateFreeze for PulsePadOpState {
268    fn freeze(&self) -> Box<dyn FrozenOpState> {
269        Box::new(FrozenPulsePadOpState {
270            current_pos: self.current_pos,
271            last_valid_frame: self.last_valid_frame.as_ref().map(|t| t.clone().into_arc_tensor()),
272        })
273    }
274}
275
276impl FrozenOpState for FrozenPulsePadOpState {
277    fn unfreeze(&self) -> Box<dyn OpState> {
278        Box::new(PulsePadOpState {
279            current_pos: self.current_pos,
280            last_valid_frame: self.last_valid_frame.as_ref().map(|t| t.clone().into_tensor()),
281        })
282    }
283}