tract_pulse_opl/
concat.rs

1use std::ops::Range;
2use tract_nnef::internal::*;
3use tract_nnef::tract_core::trivial_op_state_freeeze;
4
5/// Concat with pulse along concat axis
6#[derive(Debug, Clone, Hash)]
7pub struct PulsedSameAxisConcat {
8    axis: usize,
9    pre_slice: Tensor,
10    post_slice: Tensor,
11    input_delay: usize,
12    input_len: TDim,
13}
14
15impl Op for PulsedSameAxisConcat {
16    fn name(&self) -> Cow<str> {
17        "PulsedSameAxisConcat".into()
18    }
19
20    op_as_typed_op!();
21}
22
23impl EvalOp for PulsedSameAxisConcat {
24    fn is_stateless(&self) -> bool {
25        true
26    }
27
28    fn state(
29        &self,
30        _session: &mut SessionState,
31        _node_id: usize,
32    ) -> TractResult<Option<Box<dyn OpState>>> {
33        Ok(Some(Box::<PulsedSameAxisConcatState>::default()))
34    }
35}
36
37impl TypedOp for PulsedSameAxisConcat {
38    as_op!();
39
40    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
41        Ok(tvec!(inputs[0].clone()))
42    }
43}
44
/// Per-run state of [`PulsedSameAxisConcat`].
#[derive(Clone, Debug, Default)]
pub struct PulsedSameAxisConcatState {
    /// Sum of the pulse lengths (along the op's axis) seen so far, i.e. the stream
    /// position of the first element of the next pulse. Starts at 0 via `Default`.
    current_pos: usize,
}
// Freeze/unfreeze support is generated by the macro imported from tract_core.
trivial_op_state_freeeze!(PulsedSameAxisConcatState);
50
51impl OpState for PulsedSameAxisConcatState {
52    fn eval(
53        &mut self,
54        session: &mut SessionState,
55        op: &dyn Op,
56        inputs: TVec<TValue>,
57    ) -> TractResult<TVec<TValue>> {
58        let op = op
59            .downcast_ref::<PulsedSameAxisConcat>()
60            .ok_or_else(|| format_err!("Wrong Op type"))?;
61        let input = args_1!(inputs);
62        let mut data = input.into_tensor();
63        let pulse = data.shape()[op.axis];
64        let current_pos = self.current_pos;
65        self.current_pos += pulse;
66
67        let pre_length = op.pre_slice.shape()[op.axis];
68        let pre_offset = op.input_delay - pre_length;
69        overwrite_part_of_pulse(op.axis, &mut data, current_pos, &op.pre_slice, pre_offset)?;
70        if let Ok(l) = op.input_len.eval(&session.resolved_symbols).to_usize() {
71            let post_offset = op.input_delay + l;
72            overwrite_part_of_pulse(op.axis, &mut data, current_pos, &op.post_slice, post_offset)?;
73        }
74
75        Ok(tvec!(data.into_tvalue()))
76    }
77}
78
79pub fn overwrite_part_of_pulse(
80    axis: usize,
81    pulse_data: &mut Tensor,
82    current_pos: usize,
83    const_data: &Tensor,
84    const_offset: usize,
85) -> TractResult<()> {
86    let pulse = pulse_data.shape()[axis];
87    let const_length = const_data.shape()[axis];
88    let const_range = const_offset..const_offset + const_length;
89    let pulse_range = current_pos..current_pos + pulse;
90
91    match range_in_range(&pulse_range, &const_range) {
92        RangeInRange::Before(_) | RangeInRange::After(_) => (),
93        RangeInRange::Begin(offset) => {
94            // ----[<----->HHH]HH----
95            pulse_data.assign_slice(offset..pulse, const_data, 0..pulse - offset, axis)?;
96        }
97        RangeInRange::Contain(offset) => {
98            // ----[<----->HHHHHHH-]---
99            pulse_data.assign_slice(
100                offset..offset + const_length,
101                const_data,
102                0..const_length,
103                axis,
104            )?;
105        }
106        RangeInRange::Inside(offset) => {
107            // ----------<H>[HH]HH----
108            pulse_data.assign_slice(0..pulse, const_data, offset..offset + pulse, axis)?;
109        }
110        RangeInRange::End(offset) => {
111            // --------<HHH>[HHHH-]---
112            pulse_data.assign_slice(
113                0..const_length - offset,
114                const_data,
115                offset..const_length,
116                axis,
117            )?;
118        }
119    }
120    Ok(())
121}
122
/// Position of a `needle` range relative to a `haystack` range, as computed by
/// [`range_in_range`].
///
/// In the diagrams `[..]` is the needle, `H` is the haystack and `<..>` marks the span
/// whose length is carried as the variant's payload.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
// Not every payload is read in this file: the `Before` and `After` values are matched
// with `_` by `overwrite_part_of_pulse`.
#[allow(dead_code)]
pub enum RangeInRange {
    /// ----[--]<-->HHHH----
    ///
    /// Needle ends at or before the haystack start; payload is the gap between them.
    Before(usize),
    /// ----[<----->HHH]HH----
    ///
    /// Needle starts before the haystack and ends strictly inside it; payload is how far
    /// into the needle the haystack begins.
    Begin(usize),
    /// ----[<----->HHHHHHH-]---
    ///
    /// Needle starts before the haystack and ends at or after its end; payload is how
    /// far into the needle the haystack begins.
    Contain(usize),
    /// ----------<H>[HH]HH----
    ///
    /// Needle lies within the haystack; payload is how far into the haystack the needle
    /// begins.
    Inside(usize),
    /// --------<HHH>[HHHH-]---
    ///
    /// Needle starts within the haystack and ends after it; payload is how far into the
    /// haystack the needle begins.
    End(usize),
    /// --------HHHHHHH<->[--]---
    ///
    /// Needle starts at or after the haystack end; payload is the gap between them.
    After(usize),
}
139
140pub fn range_in_range(needle: &Range<usize>, haystack: &Range<usize>) -> RangeInRange {
141    if needle.end <= haystack.start {
142        RangeInRange::Before(haystack.start - needle.end)
143    } else if needle.start < haystack.start {
144        if needle.end < haystack.end {
145            RangeInRange::Begin(haystack.start - needle.start)
146        } else {
147            RangeInRange::Contain(haystack.start - needle.start)
148        }
149    } else if needle.start >= haystack.end {
150        RangeInRange::After(needle.start - haystack.end)
151    } else if needle.end > haystack.end {
152        RangeInRange::End(needle.start - haystack.start)
153    } else {
154        RangeInRange::Inside(needle.start - haystack.start)
155    }
156}