tract_pulse/ops/cnn/
pools.rs

1use crate::internal::*;
2use tract_core::num_traits::Zero;
3use tract_core::ops::cnn::{MaxPool, PaddingSpec, PoolSpec, SumPool};
4
// Associates each pooling operator with its pulsifier function defined in this file.
// NOTE(review): `register_all!` is defined elsewhere; presumably it hooks these
// functions into the crate's pulsifier registry -- confirm against the macro.
register_all!(MaxPool: pulsify_max_pool, SumPool: pulsify_sum_pool);
6
7fn pulsify_max_pool(
8    op: &MaxPool,
9    source: &TypedModel,
10    node: &TypedNode,
11    target: &mut PulsedModel,
12    mapping: &HashMap<OutletId, OutletId>,
13    _symbol: &Symbol,
14    _pulse: &TDim,
15) -> TractResult<Option<TVec<OutletId>>> {
16    fn min_value<D: Datum + tract_core::num_traits::Bounded>() -> Tensor {
17        tensor0(D::min_value())
18    }
19    let fact = target.outlet_fact(mapping[&node.inputs[0]])?;
20    let min = dispatch_numbers!(min_value(fact.datum_type)());
21    if let Some((wire, pool_spec)) =
22        pulsify_pooled_input(&op.pool_spec, source, node, target, mapping, Some(min))?
23    {
24        Ok(Some(target.wire_node(&node.name, MaxPool { pool_spec, ..op.clone() }, &[wire])?))
25    } else {
26        Ok(None)
27    }
28}
29
30fn pulsify_sum_pool(
31    op: &SumPool,
32    source: &TypedModel,
33    node: &TypedNode,
34    target: &mut PulsedModel,
35    mapping: &HashMap<OutletId, OutletId>,
36    _symbol: &Symbol,
37    _pulse: &TDim,
38) -> TractResult<Option<TVec<OutletId>>> {
39    if let Some((wire, pool_spec)) =
40        pulsify_pooled_input(&op.pool_spec, source, node, target, mapping, None)?
41    {
42        Ok(Some(target.wire_node(&node.name, SumPool { pool_spec, ..op.clone() }, &[wire])?))
43    } else {
44        Ok(None)
45    }
46}
47
48impl PulsedOp for SumPool {
49    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
50        pulsed_output_facts(&self.pool_spec, inputs, inputs[0].datum_type)
51    }
52
53    as_op!();
54    pulsed_op_to_typed_op!();
55}
56
57impl PulsedOp for MaxPool {
58    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
59        let mut facts = pulsed_output_facts(&self.pool_spec, inputs, inputs[0].datum_type)?;
60        if let Some(idt) = self.with_index_outputs {
61            facts.push(facts[0].clone());
62            facts[1].datum_type = idt;
63        }
64        Ok(facts)
65    }
66
67    as_op!();
68    pulsed_op_to_typed_op!();
69}
70
71pub fn pulsed_output_facts(
72    spec: &PoolSpec,
73    inputs: &[&PulsedFact],
74    output_dt: DatumType,
75) -> TractResult<TVec<PulsedFact>> {
76    let ishape = spec.data_format.shape(&inputs[0].shape)?;
77    let computed = spec.padding.compute(
78        ishape.hw_dims(),
79        &spec.kernel_shape,
80        &spec.dilations(),
81        &spec.strides(),
82    );
83    let spatial_dims = computed.into_iter().map(|d| d.convoluted).collect::<TVec<TDim>>();
84    let oshape = spec.data_format.from_n_c_hw(
85        ishape.n().cloned().unwrap_or_else(|| 1.to_dim()),
86        spec.output_channels.into(),
87        spatial_dims,
88    )?;
89    let mut fact = inputs[0].clone();
90    let stream = fact.stream.as_mut().unwrap();
91    let input_shape = spec.data_format.shape(&*fact.shape)?;
92    let geo_axis = stream.axis - input_shape.h_axis();
93    let dilation = spec.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1);
94    let kernel_len = (spec.kernel_shape[geo_axis] - 1) * dilation;
95    let stride = spec.strides.as_ref().and_then(|v| v.get(geo_axis).cloned()).unwrap_or(1);
96    stream.delay /= stride;
97    stream.dim = (stream.dim.clone() - kernel_len.to_dim()).div_ceil(stride as _);
98    fact.shape = oshape.shape.into();
99    fact.datum_type = output_dt;
100    Ok(tvec!(fact))
101}
102
103pub fn pulsify_pooled_input(
104    spec: &PoolSpec,
105    _source: &TypedModel,
106    node: &TypedNode,
107    target: &mut PulsedModel,
108    mapping: &HashMap<OutletId, OutletId>,
109    padding_value: Option<Tensor>,
110) -> TractResult<Option<(OutletId, PoolSpec)>> {
111    let mut wire = mapping[&node.inputs[0]];
112    let input_fact: PulsedFact = target.outlet_fact(wire)?.clone();
113    let input_stream = input_fact.stream.as_ref().unwrap();
114    let input_shape = spec.data_format.shape(input_fact.shape.clone())?;
115    if Some(input_stream.axis) == input_shape.n_axis() {
116        return Ok(None);
117    }
118    if input_stream.axis == input_shape.c_axis() {
119        bail!("Can not pulsify cnn pooling ops along the input channel axis");
120    }
121
122    let geo_axis = input_stream.axis - input_shape.h_axis();
123    let stride = spec.strides.as_ref().and_then(|v| v.get(geo_axis).cloned()).unwrap_or(1);
124    let pulse = input_fact.pulse().unwrap();
125    if !(pulse.to_owned() % (stride as i64)).is_zero() {
126        bail!("Pulsification requires pulse ({}) to be a stride ({}) multiple", pulse, stride)
127    }
128
129    let dilation = spec.dilations.as_ref().map(|d| d[geo_axis]).unwrap_or(1);
130    let kernel_len = (spec.kernel_shape[geo_axis] - 1) * dilation;
131    let overlap = (kernel_len + 1).saturating_sub(stride);
132
133    let computed_padding = spec.padding.compute_one(
134        geo_axis,
135        &input_stream.dim,
136        spec.kernel_shape[geo_axis],
137        spec.dilation(geo_axis),
138        spec.stride(geo_axis),
139    );
140
141    let before = computed_padding.pad_before.to_usize()?;
142    let early = input_stream.delay as isize + overlap as isize - before as isize;
143    let mut extra_delay = if early < 0 { (-early) as usize } else { 0 };
144    let delayed_input = input_stream.delay + overlap + extra_delay - before;
145    let misalignment = delayed_input % stride;
146    if misalignment > 0 {
147        extra_delay += stride - misalignment;
148    }
149
150    if overlap > 0 || extra_delay > 0 {
151        wire = target.wire_node(
152            format!("{}.delay", node.name),
153            tract_pulse_opl::ops::Delay::new_typed(
154                &(&input_fact).into(),
155                input_stream.axis,
156                extra_delay,
157                overlap,
158            ),
159            &[wire],
160        )?[0];
161    }
162
163    let has_padding =
164        !computed_padding.pad_before.is_zero() || !computed_padding.pad_after.is_zero();
165
166    if has_padding {
167        use tract_core::ops::array::PadMode;
168        let value = if let Some(tensor) = padding_value {
169            tensor.into_arc_tensor()
170        } else {
171            bail!("No padding value for streaming pool operation");
172        };
173        let op = tract_pulse_opl::ops::PulsePad {
174            axis: input_stream.axis,
175            before,
176            after: computed_padding.pad_after,
177            begin_input: input_stream.delay + extra_delay + overlap,
178            end_input: input_stream.dim.clone()
179                + input_stream.delay
180                + extra_delay
181                + overlap.to_dim(),
182            mode: PadMode::Constant(value),
183            overlap,
184        };
185        wire = target.wire_node(format!("{}.pulse-pad", node.name), op, &[wire])?[0];
186    }
187
188    if has_padding {
189        let mut bef = tvec!();
190        let mut aft = tvec!();
191        for ix in 0..input_shape.hw_rank() {
192            if ix == geo_axis {
193                bef.push(0);
194                aft.push(0);
195            } else {
196                let c = spec.padding.compute_one(
197                    ix,
198                    &input_shape.hw_dims()[ix],
199                    spec.kernel_shape[ix],
200                    spec.dilations()[ix],
201                    spec.strides()[ix],
202                );
203                bef.push(c.pad_before.to_usize()?);
204                aft.push(c.pad_after.to_usize()?);
205            };
206        }
207        Ok(Some((
208            wire,
209            PoolSpec { padding: PaddingSpec::ExplicitOnnxPool(bef, aft, false), ..spec.clone() },
210        )))
211    } else {
212        Ok(Some((wire, spec.clone())))
213    }
214}
215
#[cfg(test)]
mod test {
    use tract_pulse_opl::tract_core::ops::cnn::{Conv, PoolSpec};
    use tract_pulse_opl::tract_nnef::internal::*;

    use crate::model::{PulsedModel, PulsedModelExt};

    /// A kernel-2 convolution with explicit padding of 1 before and 0 after,
    /// pulsed with a pulse of 1, must expose an output stream with zero delay.
    #[test]
    fn left_padded_conv_wo_delay() -> TractResult<()> {
        let mut model = TypedModel::default();
        let stream_sym = model.symbols.sym("S");
        let stream_dim = stream_sym.to_dim();

        // Inputs of the convolution: streamed data, constant kernel and bias.
        let source = model.add_source("source", f32::fact(dims!(1, stream_dim)))?;
        let kernel = model.add_const("kernel", rctensor3(&[[[1f32, 2f32]]]))?;
        let bias = model.add_const("bias", rctensor0(0f32))?;

        let padding =
            tract_core::ops::cnn::PaddingSpec::ExplicitOnnxPool(tvec![1], tvec![0], false);
        let pool_spec = PoolSpec {
            data_format: tract_core::ops::nn::DataFormat::CHW,
            dilations: None,
            strides: None,
            kernel_shape: tvec![2],
            padding,
            input_channels: 1,
            output_channels: 1,
        };
        let conv_op = Conv {
            pool_spec,
            kernel_fmt: tract_core::ops::cnn::KernelFormat::OIHW,
            group: 1,
            q_params: None,
        };
        let conv = model.wire_node("conv", conv_op, &[source, kernel, bias])?;
        model.set_output_outlets(&conv)?;

        let pulsed = PulsedModel::new(&model, stream_sym, &1.to_dim())?;
        let output_stream = pulsed.output_fact(0)?.stream.as_ref().unwrap();
        assert_eq!(output_stream.delay, 0);
        Ok(())
    }
}