tract_pulse_opl/
concat.rs1use std::ops::Range;
2use tract_nnef::internal::*;
3use tract_nnef::tract_core::trivial_op_state_freeeze;
4
5#[derive(Debug, Clone, Hash)]
7pub struct PulsedSameAxisConcat {
8 pub axis: usize,
9 pub input_delay: usize,
10 pub input_len: TDim,
11}
12
13impl Op for PulsedSameAxisConcat {
14 fn name(&self) -> StaticName {
15 "PulsedSameAxisConcat".into()
16 }
17
18 op_as_typed_op!();
19}
20
21impl EvalOp for PulsedSameAxisConcat {
22 fn is_stateless(&self) -> bool {
23 false
24 }
25
26 fn state(
27 &self,
28 _session: &mut SessionState,
29 _node_id: usize,
30 ) -> TractResult<Option<Box<dyn OpState>>> {
31 Ok(Some(Box::<PulsedSameAxisConcatState>::default()))
32 }
33}
34
35impl TypedOp for PulsedSameAxisConcat {
36 as_op!();
37
38 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
39 Ok(tvec!(inputs[0].clone()))
40 }
41}
42
43#[derive(Clone, Debug, Default)]
44pub struct PulsedSameAxisConcatState {
45 current_pos: usize,
46}
47trivial_op_state_freeeze!(PulsedSameAxisConcatState);
48
49impl OpState for PulsedSameAxisConcatState {
50 fn eval(
51 &mut self,
52 session: &mut SessionState,
53 op: &dyn Op,
54 inputs: TVec<TValue>,
55 ) -> TractResult<TVec<TValue>> {
56 let op = op
57 .downcast_ref::<PulsedSameAxisConcat>()
58 .ok_or_else(|| format_err!("Wrong Op type"))?;
59 let (pre, input, post) = args_3!(inputs);
60 let mut data = input.into_tensor();
61 let pulse = data.shape()[op.axis];
62 let current_pos = self.current_pos;
63 self.current_pos += pulse;
64
65 let pre_length = pre.shape()[op.axis];
66 let pre_offset = op.input_delay - pre_length;
67 overwrite_part_of_pulse(op.axis, &mut data, current_pos, &pre, pre_offset)?;
68 if let Ok(l) = op.input_len.eval(&session.resolved_symbols).to_usize() {
69 let post_offset = op.input_delay + l;
70 overwrite_part_of_pulse(op.axis, &mut data, current_pos, &post, post_offset)?;
71 }
72
73 Ok(tvec!(data.into_tvalue()))
74 }
75}
76
77pub fn overwrite_part_of_pulse(
78 axis: usize,
79 pulse_data: &mut Tensor,
80 current_pos: usize,
81 const_data: &Tensor,
82 const_offset: usize,
83) -> TractResult<()> {
84 let pulse = pulse_data.shape()[axis];
85 let const_length = const_data.shape()[axis];
86 let const_range = const_offset..const_offset + const_length;
87 let pulse_range = current_pos..current_pos + pulse;
88
89 match range_in_range(&pulse_range, &const_range) {
90 RangeInRange::Before(_) | RangeInRange::After(_) => (),
91 RangeInRange::Begin(offset) => {
92 pulse_data.assign_slice(offset..pulse, const_data, 0..pulse - offset, axis)?;
94 }
95 RangeInRange::Contain(offset) => {
96 pulse_data.assign_slice(
98 offset..offset + const_length,
99 const_data,
100 0..const_length,
101 axis,
102 )?;
103 }
104 RangeInRange::Inside(offset) => {
105 pulse_data.assign_slice(0..pulse, const_data, offset..offset + pulse, axis)?;
107 }
108 RangeInRange::End(offset) => {
109 pulse_data.assign_slice(
111 0..const_length - offset,
112 const_data,
113 offset..const_length,
114 axis,
115 )?;
116 }
117 }
118 Ok(())
119}
120
/// Position of a "needle" interval relative to a "haystack" interval, as
/// computed by [`range_in_range`].
#[derive(Copy, Clone, Debug)]
#[allow(dead_code)]
pub enum RangeInRange {
    /// Needle ends at or before the haystack start; payload is the gap
    /// `haystack.start - needle.end`.
    Before(usize),
    /// Needle starts before the haystack and ends strictly before the haystack
    /// end; payload is `haystack.start - needle.start`.
    Begin(usize),
    /// Needle starts before the haystack and ends at or after the haystack end;
    /// payload is `haystack.start - needle.start`.
    Contain(usize),
    /// Needle starts at or after the haystack start and ends at or before the
    /// haystack end; payload is `needle.start - haystack.start`.
    Inside(usize),
    /// Needle starts inside the haystack and ends after the haystack end;
    /// payload is `needle.start - haystack.start`.
    End(usize),
    /// Needle starts at or after the haystack end; payload is the gap
    /// `needle.start - haystack.end`.
    After(usize),
}

/// Classifies where `needle` lies relative to `haystack`.
///
/// The payload of the returned variant is the distance between the two
/// relevant interval bounds, as documented on each [`RangeInRange`] variant.
pub fn range_in_range(needle: &Range<usize>, haystack: &Range<usize>) -> RangeInRange {
    // Needle is entirely to the left of the haystack.
    if needle.end <= haystack.start {
        return RangeInRange::Before(haystack.start - needle.end);
    }

    // Needle starts to the left of the haystack and reaches into it.
    if needle.start < haystack.start {
        let lead = haystack.start - needle.start;
        return if needle.end < haystack.end {
            RangeInRange::Begin(lead)
        } else {
            RangeInRange::Contain(lead)
        };
    }

    // Needle is entirely to the right of the haystack.
    if needle.start >= haystack.end {
        return RangeInRange::After(needle.start - haystack.end);
    }

    // Needle starts inside the haystack.
    let inset = needle.start - haystack.start;
    if needle.end > haystack.end {
        RangeInRange::End(inset)
    } else {
        RangeInRange::Inside(inset)
    }
}