tract_pulse_opl/
concat.rs

1use std::ops::Range;
2use tract_nnef::internal::*;
3use tract_nnef::tract_core::trivial_op_state_freeeze;
4
5/// Concat with pulse along concat axis
6#[derive(Debug, Clone, Hash)]
7pub struct PulsedSameAxisConcat {
8    pub axis: usize,
9    pub input_delay: usize,
10    pub input_len: TDim,
11}
12
13impl Op for PulsedSameAxisConcat {
14    fn name(&self) -> StaticName {
15        "PulsedSameAxisConcat".into()
16    }
17
18    op_as_typed_op!();
19}
20
21impl EvalOp for PulsedSameAxisConcat {
22    fn is_stateless(&self) -> bool {
23        false
24    }
25
26    fn state(
27        &self,
28        _session: &mut SessionState,
29        _node_id: usize,
30    ) -> TractResult<Option<Box<dyn OpState>>> {
31        Ok(Some(Box::<PulsedSameAxisConcatState>::default()))
32    }
33}
34
35impl TypedOp for PulsedSameAxisConcat {
36    as_op!();
37
38    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
39        Ok(tvec!(inputs[0].clone()))
40    }
41}
42
43#[derive(Clone, Debug, Default)]
44pub struct PulsedSameAxisConcatState {
45    current_pos: usize,
46}
47trivial_op_state_freeeze!(PulsedSameAxisConcatState);
48
49impl OpState for PulsedSameAxisConcatState {
50    fn eval(
51        &mut self,
52        session: &mut SessionState,
53        op: &dyn Op,
54        inputs: TVec<TValue>,
55    ) -> TractResult<TVec<TValue>> {
56        let op = op
57            .downcast_ref::<PulsedSameAxisConcat>()
58            .ok_or_else(|| format_err!("Wrong Op type"))?;
59        let (pre, input, post) = args_3!(inputs);
60        let mut data = input.into_tensor();
61        let pulse = data.shape()[op.axis];
62        let current_pos = self.current_pos;
63        self.current_pos += pulse;
64
65        let pre_length = pre.shape()[op.axis];
66        let pre_offset = op.input_delay - pre_length;
67        overwrite_part_of_pulse(op.axis, &mut data, current_pos, &pre, pre_offset)?;
68        if let Ok(l) = op.input_len.eval(&session.resolved_symbols).to_usize() {
69            let post_offset = op.input_delay + l;
70            overwrite_part_of_pulse(op.axis, &mut data, current_pos, &post, post_offset)?;
71        }
72
73        Ok(tvec!(data.into_tvalue()))
74    }
75}
76
77pub fn overwrite_part_of_pulse(
78    axis: usize,
79    pulse_data: &mut Tensor,
80    current_pos: usize,
81    const_data: &Tensor,
82    const_offset: usize,
83) -> TractResult<()> {
84    let pulse = pulse_data.shape()[axis];
85    let const_length = const_data.shape()[axis];
86    let const_range = const_offset..const_offset + const_length;
87    let pulse_range = current_pos..current_pos + pulse;
88
89    match range_in_range(&pulse_range, &const_range) {
90        RangeInRange::Before(_) | RangeInRange::After(_) => (),
91        RangeInRange::Begin(offset) => {
92            // ----[<----->HHH]HH----
93            pulse_data.assign_slice(offset..pulse, const_data, 0..pulse - offset, axis)?;
94        }
95        RangeInRange::Contain(offset) => {
96            // ----[<----->HHHHHHH-]---
97            pulse_data.assign_slice(
98                offset..offset + const_length,
99                const_data,
100                0..const_length,
101                axis,
102            )?;
103        }
104        RangeInRange::Inside(offset) => {
105            // ----------<H>[HH]HH----
106            pulse_data.assign_slice(0..pulse, const_data, offset..offset + pulse, axis)?;
107        }
108        RangeInRange::End(offset) => {
109            // --------<HHH>[HHHH-]---
110            pulse_data.assign_slice(
111                0..const_length - offset,
112                const_data,
113                offset..const_length,
114                axis,
115            )?;
116        }
117    }
118    Ok(())
119}
120
/// Position of a `needle` range relative to a `haystack` range, as classified by
/// [`range_in_range`].
///
/// In the diagrams, `[..]` is the needle, `H` marks haystack positions, and `<..>` is the
/// distance carried by the variant.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
// The payloads of some variants (e.g. `Before`/`After`) are not read in this file; keep them
// for diagnostics without triggering dead-code warnings.
#[allow(dead_code)]
pub enum RangeInRange {
    /// Needle ends at or before the haystack start; payload is the gap between them.
    /// ----[--]<-->HHHH----
    Before(usize),
    /// Needle starts before the haystack and ends strictly inside it; payload is the offset,
    /// within the needle, at which the haystack starts.
    /// ----[<----->HHH]HH----
    Begin(usize),
    /// Needle starts before the haystack and ends at or after its end; payload is the offset,
    /// within the needle, at which the haystack starts.
    /// ----[<----->HHHHHHH-]---
    Contain(usize),
    /// Needle lies within the haystack; payload is the offset, within the haystack, at which
    /// the needle starts.
    /// ----------<H>[HH]HH----
    Inside(usize),
    /// Needle starts inside the haystack and ends after it; payload is the offset, within the
    /// haystack, at which the needle starts.
    /// --------<HHH>[HHHH-]---
    End(usize),
    /// Needle starts at or after the haystack end; payload is the gap between them.
    /// --------HHHHHHH<->[--]---
    After(usize),
}

/// Classifies where `needle` sits relative to `haystack`.
///
/// The checks are evaluated in order, which fixes the result for boundary and empty ranges.
/// For example, a needle that starts before an empty haystack and ends after its position is
/// reported as `Contain`, not `After`.
pub fn range_in_range(needle: &Range<usize>, haystack: &Range<usize>) -> RangeInRange {
    if needle.end <= haystack.start {
        RangeInRange::Before(haystack.start - needle.end)
    } else if needle.start < haystack.start {
        // Needle overlaps the haystack start; the branch depends on where the needle ends.
        if needle.end < haystack.end {
            RangeInRange::Begin(haystack.start - needle.start)
        } else {
            RangeInRange::Contain(haystack.start - needle.start)
        }
    } else if needle.start >= haystack.end {
        RangeInRange::After(needle.start - haystack.end)
    } else if needle.end > haystack.end {
        RangeInRange::End(needle.start - haystack.start)
    } else {
        RangeInRange::Inside(needle.start - haystack.start)
    }
}