tract_pulse_opl/
concat.rs1use std::ops::Range;
2use tract_nnef::internal::*;
3use tract_nnef::tract_core::trivial_op_state_freeeze;
4
5#[derive(Debug, Clone, Hash)]
7pub struct PulsedSameAxisConcat {
8 axis: usize,
9 pre_slice: Tensor,
10 post_slice: Tensor,
11 input_delay: usize,
12 input_len: TDim,
13}
14
15impl Op for PulsedSameAxisConcat {
16 fn name(&self) -> Cow<str> {
17 "PulsedSameAxisConcat".into()
18 }
19
20 op_as_typed_op!();
21}
22
23impl EvalOp for PulsedSameAxisConcat {
24 fn is_stateless(&self) -> bool {
25 true
26 }
27
28 fn state(
29 &self,
30 _session: &mut SessionState,
31 _node_id: usize,
32 ) -> TractResult<Option<Box<dyn OpState>>> {
33 Ok(Some(Box::<PulsedSameAxisConcatState>::default()))
34 }
35}
36
37impl TypedOp for PulsedSameAxisConcat {
38 as_op!();
39
40 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
41 Ok(tvec!(inputs[0].clone()))
42 }
43}
44
/// Per-session evaluation state of [`PulsedSameAxisConcat`].
///
/// `Default` starts the stream at position 0.
#[derive(Clone, Debug, Default)]
pub struct PulsedSameAxisConcatState {
    // Stream position (along the op's axis) of the first frame of the next pulse, i.e. the
    // total number of frames seen by previous `eval` calls.
    current_pos: usize,
}
// The `freeeze` spelling matches the macro name as imported at the top of the file.
// NOTE(review): the name suggests a trivial (clone-based) freeze impl; confirm in tract-core.
trivial_op_state_freeeze!(PulsedSameAxisConcatState);
50
51impl OpState for PulsedSameAxisConcatState {
52 fn eval(
53 &mut self,
54 session: &mut SessionState,
55 op: &dyn Op,
56 inputs: TVec<TValue>,
57 ) -> TractResult<TVec<TValue>> {
58 let op = op
59 .downcast_ref::<PulsedSameAxisConcat>()
60 .ok_or_else(|| format_err!("Wrong Op type"))?;
61 let input = args_1!(inputs);
62 let mut data = input.into_tensor();
63 let pulse = data.shape()[op.axis];
64 let current_pos = self.current_pos;
65 self.current_pos += pulse;
66
67 let pre_length = op.pre_slice.shape()[op.axis];
68 let pre_offset = op.input_delay - pre_length;
69 overwrite_part_of_pulse(op.axis, &mut data, current_pos, &op.pre_slice, pre_offset)?;
70 if let Ok(l) = op.input_len.eval(&session.resolved_symbols).to_usize() {
71 let post_offset = op.input_delay + l;
72 overwrite_part_of_pulse(op.axis, &mut data, current_pos, &op.post_slice, post_offset)?;
73 }
74
75 Ok(tvec!(data.into_tvalue()))
76 }
77}
78
79pub fn overwrite_part_of_pulse(
80 axis: usize,
81 pulse_data: &mut Tensor,
82 current_pos: usize,
83 const_data: &Tensor,
84 const_offset: usize,
85) -> TractResult<()> {
86 let pulse = pulse_data.shape()[axis];
87 let const_length = const_data.shape()[axis];
88 let const_range = const_offset..const_offset + const_length;
89 let pulse_range = current_pos..current_pos + pulse;
90
91 match range_in_range(&pulse_range, &const_range) {
92 RangeInRange::Before(_) | RangeInRange::After(_) => (),
93 RangeInRange::Begin(offset) => {
94 pulse_data.assign_slice(offset..pulse, const_data, 0..pulse - offset, axis)?;
96 }
97 RangeInRange::Contain(offset) => {
98 pulse_data.assign_slice(
100 offset..offset + const_length,
101 const_data,
102 0..const_length,
103 axis,
104 )?;
105 }
106 RangeInRange::Inside(offset) => {
107 pulse_data.assign_slice(0..pulse, const_data, offset..offset + pulse, axis)?;
109 }
110 RangeInRange::End(offset) => {
111 pulse_data.assign_slice(
113 0..const_length - offset,
114 const_data,
115 offset..const_length,
116 axis,
117 )?;
118 }
119 }
120 Ok(())
121}
122
/// How a `needle` range sits relative to a `haystack` range (see [`range_in_range`]).
#[derive(Copy, Clone, Debug)]
// The payloads of `Before`/`After` are not read in this file (callers match them with `_`).
#[allow(dead_code)]
pub enum RangeInRange {
    /// Needle ends at or before the haystack start; payload is the gap
    /// `haystack.start - needle.end`.
    Before(usize),
    /// Needle starts before the haystack and ends strictly inside it; payload is
    /// `haystack.start - needle.start`.
    Begin(usize),
    /// Needle starts before the haystack and ends at or after its end; payload is
    /// `haystack.start - needle.start`.
    Contain(usize),
    /// Needle lies within the haystack; payload is `needle.start - haystack.start`.
    Inside(usize),
    /// Needle starts inside the haystack and ends after it; payload is
    /// `needle.start - haystack.start`.
    End(usize),
    /// Needle starts at or after the haystack end; payload is the gap
    /// `needle.start - haystack.end`.
    After(usize),
}

/// Classifies the position of `needle` relative to `haystack` (both half-open ranges).
///
/// The payload of each variant is the offset needed to line the two ranges up, as
/// documented on [`RangeInRange`].
pub fn range_in_range(needle: &Range<usize>, haystack: &Range<usize>) -> RangeInRange {
    let (n_start, n_end) = (needle.start, needle.end);
    let (h_start, h_end) = (haystack.start, haystack.end);

    // Entirely on the left: no overlap.
    if n_end <= h_start {
        return RangeInRange::Before(h_start - n_end);
    }
    // Starts on the left and overlaps the haystack.
    if n_start < h_start {
        let lead = h_start - n_start;
        return if n_end >= h_end {
            RangeInRange::Contain(lead)
        } else {
            RangeInRange::Begin(lead)
        };
    }
    // Entirely on the right: no overlap.
    if n_start >= h_end {
        return RangeInRange::After(n_start - h_end);
    }
    // Starts inside the haystack.
    let skip = n_start - h_start;
    if n_end > h_end {
        RangeInRange::End(skip)
    } else {
        RangeInRange::Inside(skip)
    }
}