1use crate::internal::*;
2use tract_core::num_traits::Zero;
3use tract_core::ops::cnn::{MaxPool, PaddingSpec, PoolSpec, SumPool};
4
// Hook the pooling ops into pulsification via the crate's `register_all!` macro:
// `MaxPool` nodes are handled by `pulsify_max_pool`, `SumPool` nodes by `pulsify_sum_pool`.
register_all!(MaxPool: pulsify_max_pool, SumPool: pulsify_sum_pool);
6
7fn pulsify_max_pool(
8 op: &MaxPool,
9 source: &TypedModel,
10 node: &TypedNode,
11 target: &mut PulsedModel,
12 mapping: &HashMap<OutletId, OutletId>,
13 _symbol: &Symbol,
14 _pulse: &TDim,
15) -> TractResult<Option<TVec<OutletId>>> {
16 fn min_value<D: Datum + tract_core::num_traits::Bounded>() -> Tensor {
17 tensor0(D::min_value())
18 }
19 let fact = target.outlet_fact(mapping[&node.inputs[0]])?;
20 let min = dispatch_numbers!(min_value(fact.datum_type)());
21 if let Some((wire, pool_spec)) =
22 pulsify_pooled_input(&op.pool_spec, source, node, target, mapping, Some(min))?
23 {
24 Ok(Some(target.wire_node(&node.name, MaxPool { pool_spec, ..op.clone() }, &[wire])?))
25 } else {
26 Ok(None)
27 }
28}
29
30fn pulsify_sum_pool(
31 op: &SumPool,
32 source: &TypedModel,
33 node: &TypedNode,
34 target: &mut PulsedModel,
35 mapping: &HashMap<OutletId, OutletId>,
36 _symbol: &Symbol,
37 _pulse: &TDim,
38) -> TractResult<Option<TVec<OutletId>>> {
39 if let Some((wire, pool_spec)) =
40 pulsify_pooled_input(&op.pool_spec, source, node, target, mapping, None)?
41 {
42 Ok(Some(target.wire_node(&node.name, SumPool { pool_spec, ..op.clone() }, &[wire])?))
43 } else {
44 Ok(None)
45 }
46}
47
48impl PulsedOp for SumPool {
49 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
50 pulsed_output_facts(&self.pool_spec, inputs, inputs[0].datum_type)
51 }
52
53 as_op!();
54 pulsed_op_to_typed_op!();
55}
56
57impl PulsedOp for MaxPool {
58 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
59 let mut facts = pulsed_output_facts(&self.pool_spec, inputs, inputs[0].datum_type)?;
60 if let Some(idt) = self.with_index_outputs {
61 facts.push(facts[0].clone());
62 facts[1].datum_type = idt;
63 }
64 Ok(facts)
65 }
66
67 as_op!();
68 pulsed_op_to_typed_op!();
69}
70
71pub fn pulsed_output_facts(
72 spec: &PoolSpec,
73 inputs: &[&PulsedFact],
74 output_dt: DatumType,
75) -> TractResult<TVec<PulsedFact>> {
76 let ishape = spec.data_format.shape(&inputs[0].shape)?;
77 let computed = spec.padding.compute(
78 ishape.hw_dims(),
79 &spec.kernel_shape,
80 &spec.dilations(),
81 &spec.strides(),
82 );
83 let spatial_dims = computed.into_iter().map(|d| d.convoluted).collect::<TVec<TDim>>();
84 let oshape = spec.data_format.from_n_c_hw(
85 ishape.n().cloned().unwrap_or_else(|| 1.to_dim()),
86 spec.output_channels.into(),
87 spatial_dims,
88 )?;
89 let mut fact = inputs[0].clone();
90 let stream = fact.stream.as_mut().unwrap();
91 let input_shape = spec.data_format.shape(&*fact.shape)?;
92 let geo_axis = stream.axis - input_shape.h_axis();
93 let dilation = spec.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1);
94 let kernel_len = (spec.kernel_shape[geo_axis] - 1) * dilation;
95 let stride = spec.strides.as_ref().and_then(|v| v.get(geo_axis).cloned()).unwrap_or(1);
96 stream.delay /= stride;
97 stream.dim = (stream.dim.clone() - kernel_len.to_dim()).div_ceil(stride as _);
98 fact.shape = oshape.shape.into();
99 fact.datum_type = output_dt;
100 Ok(tvec!(fact))
101}
102
103pub fn pulsify_pooled_input(
104 spec: &PoolSpec,
105 _source: &TypedModel,
106 node: &TypedNode,
107 target: &mut PulsedModel,
108 mapping: &HashMap<OutletId, OutletId>,
109 padding_value: Option<Tensor>,
110) -> TractResult<Option<(OutletId, PoolSpec)>> {
111 let mut wire = mapping[&node.inputs[0]];
112 let input_fact: PulsedFact = target.outlet_fact(wire)?.clone();
113 let input_stream = input_fact.stream.as_ref().unwrap();
114 let input_shape = spec.data_format.shape(input_fact.shape.clone())?;
115 if Some(input_stream.axis) == input_shape.n_axis() {
116 return Ok(None);
117 }
118 if input_stream.axis == input_shape.c_axis() {
119 bail!("Can not pulsify cnn pooling ops along the input channel axis");
120 }
121
122 let geo_axis = input_stream.axis - input_shape.h_axis();
123 let stride = spec.strides.as_ref().and_then(|v| v.get(geo_axis).cloned()).unwrap_or(1);
124 let pulse = input_fact.pulse().unwrap();
125 if !(pulse.to_owned() % (stride as i64)).is_zero() {
126 bail!("Pulsification requires pulse ({}) to be a stride ({}) multiple", pulse, stride)
127 }
128
129 let dilation = spec.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1);
130 let kernel_len = (spec.kernel_shape[geo_axis] - 1) * dilation;
131 let overlap = (kernel_len + 1).saturating_sub(stride);
132
133 let computed_padding = spec.padding.compute_one(
134 geo_axis,
135 &input_stream.dim,
136 spec.kernel_shape[geo_axis],
137 spec.dilation(geo_axis),
138 spec.stride(geo_axis),
139 );
140
141 let before = computed_padding.pad_before.to_usize()?;
142 let early = input_stream.delay as isize + overlap as isize - before as isize;
143 let mut extra_delay = if early < 0 { (-early) as usize } else { 0 };
144 let delayed_input = input_stream.delay + overlap + extra_delay - before;
145 let misalignment = delayed_input % stride;
146 if misalignment > 0 {
147 extra_delay += stride - misalignment;
148 }
149
150 if overlap > 0 || extra_delay > 0 {
151 wire = target.wire_node(
152 format!("{}.delay", node.name),
153 tract_pulse_opl::ops::Delay::new_typed(
154 &(&input_fact).into(),
155 input_stream.axis,
156 extra_delay,
157 overlap,
158 ),
159 &[wire],
160 )?[0];
161 }
162
163 let has_padding =
164 !computed_padding.pad_before.is_zero() || !computed_padding.pad_after.is_zero();
165
166 if has_padding {
167 use tract_core::ops::array::PadMode;
168 let value = if let Some(tensor) = padding_value {
169 tensor.into_arc_tensor()
170 } else {
171 bail!("No padding value for streaming pool operation");
172 };
173 let op = tract_pulse_opl::ops::PulsePad {
174 axis: input_stream.axis,
175 before,
176 after: computed_padding.pad_after,
177 begin_input: input_stream.delay + extra_delay + overlap,
178 end_input: input_stream.dim.clone()
179 + input_stream.delay
180 + extra_delay
181 + overlap.to_dim(),
182 mode: PadMode::Constant(value),
183 overlap,
184 };
185 wire = target.wire_node(format!("{}.pulse-pad", node.name), op, &[wire])?[0];
186 }
187
188 if has_padding {
189 let mut bef = tvec!();
190 let mut aft = tvec!();
191 for ix in 0..input_shape.hw_rank() {
192 if ix == geo_axis {
193 bef.push(0);
194 aft.push(0);
195 } else {
196 let c = spec.padding.compute_one(
197 ix,
198 &input_shape.hw_dims()[ix],
199 spec.kernel_shape[ix],
200 spec.dilations()[ix],
201 spec.strides()[ix],
202 );
203 bef.push(c.pad_before.to_usize()?);
204 aft.push(c.pad_after.to_usize()?);
205 };
206 }
207 Ok(Some((
208 wire,
209 PoolSpec { padding: PaddingSpec::ExplicitOnnxPool(bef, aft, false), ..spec.clone() },
210 )))
211 } else {
212 Ok(Some((wire, spec.clone())))
213 }
214}
215
#[cfg(test)]
mod test {
    use tract_pulse_opl::tract_core::ops::cnn::{Conv, PoolSpec};
    use tract_pulse_opl::tract_nnef::internal::*;

    use crate::model::{PulsedModel, PulsedModelExt};

    /// Pulsifies (pulse of 1) a single-channel `CHW` convolution whose two-tap kernel is
    /// padded by one element before the streaming axis and none after. Checks that the
    /// pulsed output stream reports a delay of 0.
    #[test]
    fn left_padded_conv_wo_delay() -> TractResult<()> {
        let mut model = TypedModel::default();
        let sym = model.symbols.sym("S");
        let seq_len = sym.to_dim();

        let input = model.add_source("source", f32::fact(dims!(1, seq_len)))?;
        let weights = model.add_const("kernel", rctensor3(&[[[1f32, 2f32]]]))?;
        let bias = model.add_const("bias", rctensor0(0f32))?;

        // One element of explicit padding before the stream, none after.
        let padding =
            tract_core::ops::cnn::PaddingSpec::ExplicitOnnxPool(tvec![1], tvec![0], false);
        let pool_spec = PoolSpec {
            data_format: tract_core::ops::nn::DataFormat::CHW,
            dilations: None,
            strides: None,
            kernel_shape: tvec![2],
            padding,
            input_channels: 1,
            output_channels: 1,
        };
        let conv_op = Conv {
            pool_spec,
            kernel_fmt: tract_core::ops::cnn::KernelFormat::OIHW,
            group: 1,
            q_params: None,
        };
        let outputs = model.wire_node("conv", conv_op, &[input, weights, bias])?;
        model.set_output_outlets(&outputs)?;

        let pulsed = PulsedModel::new(&model, sym, &1.to_dim())?;
        let stream = pulsed.output_fact(0)?.stream.as_ref().unwrap();
        assert_eq!(stream.delay, 0);
        Ok(())
    }
}