1use crate::internal::*;
2use crate::ops::cnn::PaddingSpec;
3use crate::ops::nn::{DataFormat, DataShape};
4use ndarray::prelude::*;
5
6use super::PatchAxis;
7
8use std::fmt::Debug;
9use std::ops::Range;
10
11use tract_itertools::{izip, Itertools};
12
/// Geometry of a sliding-window operation over the spatial axes of an input, before the
/// padding is resolved. Build one with [`PatchSpec::for_full_shape`] or
/// [`PatchSpec::for_data_shape`], adjust it with the `with_*` builders, then turn it into
/// a [`Patch`] with [`PatchSpec::into_patch`].
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PatchSpec {
    /// Spatial input dimensions (the `hw_dims()` of the data shape).
    pub input_shape: TVec<usize>,
    /// Storage stride of the innermost spatial input axis (`w_stride()` of the data shape
    /// when built through `for_data_shape`).
    pub input_inner_stride: usize,
    /// Storage stride of the innermost spatial output axis.
    pub output_inner_stride: usize,
    /// Kernel extent along each spatial axis.
    pub kernel_shape: TVec<usize>,
    /// Step between consecutive windows along each spatial axis.
    pub strides: TVec<usize>,
    /// Spacing between consecutive kernel taps along each spatial axis.
    pub dilations: TVec<usize>,
    /// Padding policy, resolved per axis by `PaddingSpec::compute` in `into_patch`.
    pub padding: PaddingSpec,
}
23
24impl Debug for PatchSpec {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        write!(
27            f,
28            "input: {} kernel: {} strides: {} dil: {} pad: {:?}",
29            self.input_shape.iter().join(","),
30            self.kernel_shape.iter().join(","),
31            self.strides.iter().join(","),
32            self.dilations.iter().join(","),
33            self.padding
34        )
35    }
36}
37
38impl PatchSpec {
39    pub fn for_full_shape(
40        data_format: DataFormat,
41        input_full_shape: &[usize],
42    ) -> TractResult<PatchSpec> {
43        let shape = data_format.shape(input_full_shape.into())?;
44        Ok(Self::for_data_shape(shape))
45    }
46
47    pub fn for_data_shape(data_shape: DataShape) -> PatchSpec {
48        let input_shape: TVec<usize> = data_shape.hw_dims().into();
49        PatchSpec {
50            kernel_shape: tvec!(1; input_shape.len()),
51            input_inner_stride: *data_shape.w_stride(),
52            output_inner_stride: 1,
53            strides: tvec!(1; input_shape.len()),
54            dilations: tvec!(1; input_shape.len()),
55            padding: PaddingSpec::Valid,
56            input_shape,
57        }
58    }
59
60    pub fn with_kernel_shape(self, kernel_shape: TVec<usize>) -> PatchSpec {
61        PatchSpec { kernel_shape, ..self }
62    }
63
64    pub fn with_dilations(self, dilations: TVec<usize>) -> PatchSpec {
65        PatchSpec { dilations, ..self }
66    }
67
68    pub fn with_strides(self, strides: TVec<usize>) -> PatchSpec {
69        PatchSpec { strides, ..self }
70    }
71
72    pub fn with_padding(self, padding: PaddingSpec) -> PatchSpec {
73        PatchSpec { padding, ..self }
74    }
75
76    pub fn with_output_inner_stride(self, output_inner_stride: usize) -> PatchSpec {
77        PatchSpec { output_inner_stride, ..self }
78    }
79
80    pub fn into_patch(self) -> Patch {
81        let dims = self.padding.compute(
82            &self.input_shape,
83            &self.kernel_shape,
84            &self.dilations,
85            &self.strides,
86        );
87        let output: TVec<usize> = dims.iter().map(|d| d.convoluted).collect();
88        let pad_before: TVec<usize> = dims.iter().map(|d| d.pad_before).collect();
89        let pad_after: TVec<usize> = dims.iter().map(|d| d.pad_after).collect();
90
91        let data_field: Vec<isize> = ::ndarray::indices(&*self.kernel_shape)
92            .into_iter()
93            .flat_map(|coords| {
94                #[allow(clippy::unnecessary_to_owned)] coords
96                    .slice()
97                    .to_vec()
98                    .into_iter()
99                    .enumerate()
100                    .map(|(ix, c)| (c * self.dilations[ix]) as isize - pad_before[ix] as isize)
101            })
102            .collect();
103        let data_field = Array2::from_shape_vec(
104            (self.kernel_shape.iter().cloned().product(), self.kernel_shape.len()),
105            data_field,
106        )
107        .unwrap();
108        let data_field_min_max: TVec<_> = data_field
109            .columns()
110            .into_iter()
111            .map(|col| (col.iter().min().cloned().unwrap(), col.iter().max().cloned().unwrap()))
112            .collect();
113
114        fn strides(shape: &[usize], inner: usize) -> TVec<isize> {
115            let mut strides: TVec<isize> = tvec![inner as isize];
116            for dim in shape.iter().skip(1).rev() {
117                let previous = *strides.last().unwrap();
118                strides.push(*dim as isize * previous);
119            }
120            strides.reverse();
121            strides
122        }
123
124        let input_storage_strides = strides(&self.input_shape, self.input_inner_stride);
125        let output_storage_strides = strides(&output, self.output_inner_stride);
126
127        let standard_layout_data_field: Vec<isize> = data_field
128            .outer_iter()
129            .map(|coords| izip!(coords, &input_storage_strides).map(|(a, b)| a * b).sum::<isize>())
130            .collect();
131
132        let regions: TVec<TVec<_>> = dims
134            .iter()
135            .enumerate()
136            .map(|(ix, d)| {
137                PatchAxis {
138                    input_dim: self.input_shape[ix],
139                    kernel_dim: self.kernel_shape[ix],
140                    pad_before: d.pad_before,
141                    pad_after: d.pad_after,
142                    output_dim: d.convoluted,
143                    stride: self.strides[ix],
144                    dilation: self.dilations[ix],
145                }
146                .regions()
147            })
148            .collect::<TVec<_>>();
149
150        let zone_strides = strides(®ions.iter().map(|d| d.len()).collect::<TVec<_>>(), 1);
151
152        let zones: Vec<Zone> = regions
153            .iter()
154            .multi_cartesian_product()
155            .map(|regions| Zone {
156                input_zone_offset: 0,
157                output_ranges: regions.iter().map(|reg| reg.range.clone()).collect(),
158                output_shape: regions.iter().map(|reg| reg.range.end - reg.range.start).collect(),
159                output_zone_offset: izip!(®ions, &output_storage_strides)
160                    .map(|(reg, &stride)| reg.range.start as isize * stride)
161                    .sum::<isize>(),
162                valid: regions.iter().all(|reg| reg.mask.is_none()),
163                values_offsets: izip!(
164                    0..,
165                    ndarray::indices(&*self.kernel_shape),
166                    &standard_layout_data_field
167                )
168                .filter(|(_ix, coords, _offset)| {
169                    izip!(coords.slice(), ®ions)
170                        .all(|(&x, axis)| !axis.mask.as_ref().map(|mask| mask[x]).unwrap_or(false))
171                })
172                .map(|(ix, _coords, &window_offset)| (ix, window_offset))
173                .collect(),
174            })
175            .collect();
176
177        let valid_zone = zones.iter().position(|z| z.valid);
178
179        let mut valid_output_zone = tvec!();
180        let mut invalid_output_zones = tvec!();
181        for ix in 0..self.input_shape.len() {
182            let min_max = data_field_min_max[ix];
183            let min = (-min_max.0 as usize).divceil(self.strides[ix]);
184            let max =
185                (self.input_shape[ix].saturating_sub(min_max.1 as usize)).divceil(self.strides[ix]);
186            if min != 0 {
187                let mut invalid = valid_output_zone.clone();
188                invalid.push(0..min);
189                while invalid.len() < output.len() {
190                    invalid.push(0..output[invalid.len()])
191                }
192                invalid_output_zones.push(invalid);
193            }
194            if max < output[ix] {
195                let mut invalid = valid_output_zone.clone();
196                invalid.push(max..output[ix]);
197                while invalid.len() < output.len() {
198                    invalid.push(0..output[invalid.len()])
199                }
200                invalid_output_zones.push(invalid);
201            }
202            valid_output_zone.push(min..max)
203        }
204
205        let op_strides_times_input_storage_strides =
206            izip!(&self.strides, &input_storage_strides).map(|(a, b)| *a as isize * b).collect();
207
208        Patch {
209            spec: self,
210            padded: pad_before.iter().any(|&p| p != 0) || pad_after.iter().any(|&p| p != 0),
211            pad_before,
212            pad_after,
213            output_shape: output,
214            data_field,
215            data_field_min_max,
216            standard_layout_data_field,
217            input_storage_strides,
218            output_storage_strides,
219            op_strides_times_input_storage_strides,
220            valid_output_zone,
221            invalid_output_zones,
222            zones,
223            valid_zone_id: valid_zone,
224            zone_strides,
225        }
226    }
227}
228
/// A [`PatchSpec`] with its padding resolved and all the tables used to walk the output
/// positions and their kernel windows. Built by [`PatchSpec::into_patch`].
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Patch {
    /// The spec this patch was built from.
    pub spec: PatchSpec,
    /// Padding before each spatial axis, as resolved by `PaddingSpec::compute`.
    pub pad_before: TVec<usize>,
    /// Padding after each spatial axis, as resolved by `PaddingSpec::compute`.
    pub pad_after: TVec<usize>,
    /// True when any `pad_before` or `pad_after` entry is non-zero.
    pub padded: bool,
    /// Spatial output dimensions.
    pub output_shape: TVec<usize>,
    /// One row per kernel tap (row-major over the kernel), one column per axis:
    /// `tap * dilation - pad_before`, the tap's input coordinate relative to
    /// `output_coord * stride`.
    pub data_field: Array2<isize>,
    /// Per-axis (min, max) over the corresponding `data_field` column.
    pub data_field_min_max: TVec<(isize, isize)>,
    /// Per-tap input storage offset: each `data_field` row dotted with
    /// `input_storage_strides`.
    pub standard_layout_data_field: Vec<isize>,
    /// Per axis, `strides[ix] * input_storage_strides[ix]`: how far the input storage
    /// offset moves when the output coordinate on that axis grows by one.
    pub op_strides_times_input_storage_strides: TVec<isize>,
    /// Per-axis output range in which every kernel tap lands inside the input.
    pub valid_output_zone: TVec<Range<usize>>,
    /// Boxes covering the output outside `valid_output_zone`. For each axis with such
    /// positions, earlier axes span their valid range, this axis spans the out-of-range
    /// part and later axes span the whole output.
    pub invalid_output_zones: TVec<TVec<Range<usize>>>,
    /// One zone per combination of the per-axis `PatchAxis::regions()`.
    pub zones: Vec<Zone>,
    /// Index in `zones` of the first zone whose regions carry no mask, if any.
    pub valid_zone_id: Option<usize>,
    /// Row-major strides over the grid of zones (number of regions per axis).
    pub zone_strides: TVec<isize>,
    /// Row-major input storage strides, innermost axis scaled by `input_inner_stride`.
    pub input_storage_strides: TVec<isize>,
    /// Row-major output storage strides, innermost axis scaled by `output_inner_stride`.
    pub output_storage_strides: TVec<isize>,
}
248
249impl Debug for Patch {
250    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        write!(f, "{:?}", self.spec)
252    }
253}
254
255impl Patch {
256    #[inline]
257    pub fn rank(&self) -> usize {
258        self.spec.input_shape.len()
259    }
260
261    unsafe fn is_valid(&self, coords: &[usize]) -> bool {
262        unsafe {
263            for ix in 0..self.rank() {
264                let c = *coords.get_unchecked(ix) as isize;
265                let strides = *self.spec.strides.get_unchecked(ix) as isize;
266                let pos = c * strides;
267                let min_max = self.data_field_min_max.get_unchecked(ix);
268                if pos + min_max.0 < 0
269                    || pos + min_max.1 >= *self.spec.input_shape.get_unchecked(ix) as isize
270                {
271                    return false;
272                }
273            }
274            true
275        }
276    }
277
278    pub fn valid_zone(&self) -> Option<&Zone> {
279        self.valid_zone_id.map(|id| &self.zones[id])
280    }
281
282    #[inline]
283    pub fn visit_output(&self, mut acceptor: impl FnMut(&Scanner)) {
284        if self.zones.len() == 0 {
285            return;
286        }
287        let mut scanner = Scanner::new(self);
288        while !scanner.done() {
289            acceptor(&scanner);
290            scanner.next();
291        }
292    }
293
294    pub fn centers_offsets(&self) -> Vec<isize> {
295        if self.zones.len() == 0 {
296            return vec![];
297        }
298        let mut scanner = Scanner::new(self);
299        let len = self.output_shape.iter().cloned().product();
300        let mut v = Vec::with_capacity(len);
301        for _ in 0..len {
302            v.push(scanner.input_center_offset);
303            scanner.next()
304        }
305        v
306    }
307
308    pub fn at<'p>(&'p self, coords: &[usize]) -> PatchIterator<'p> {
309        self.at_hint(coords, None)
310    }
311
312    pub fn at_hint<'p>(&'p self, coords: &[usize], hint: Option<bool>) -> PatchIterator<'p> {
313        unsafe {
314            assert_eq!(coords.len(), self.spec.kernel_shape.len());
315            let mut center = 0;
316            for i in 0..self.op_strides_times_input_storage_strides.len() {
317                center += *self.op_strides_times_input_storage_strides.get_unchecked(i)
318                    * *coords.get_unchecked(i) as isize;
319            }
320            let valid = hint.unwrap_or_else(|| !self.padded || self.is_valid(coords));
321            if valid {
322                PatchIterator::Fast(FastPatchIterator { patch: self, center, item: 0 })
323            } else {
324                let mut input_patch_center: TVec<_> = coords.into();
325                input_patch_center
326                    .iter_mut()
327                    .zip(self.spec.strides.iter())
328                    .for_each(|(a, &b)| *a *= b);
329                PatchIterator::Safe(SafePatchIterator {
330                    patch: self,
331                    item: 0,
332                    input_patch_center,
333                    center,
334                })
335            }
336        }
337    }
338
339    pub fn global_offset_for(&self, coords: &[usize], patch_index: usize) -> usize {
340        assert_eq!(coords.len(), self.spec.kernel_shape.len());
341        let center = izip!(coords, &self.op_strides_times_input_storage_strides)
342            .map(|(a, b)| *a as isize * *b)
343            .sum::<isize>();
344        (center + self.standard_layout_data_field[patch_index]) as usize
345    }
346}
347
/// A box of output positions that all use the same list of kernel taps
/// (`values_offsets`). `PatchSpec::into_patch` builds one from each combination of
/// per-axis `PatchAxis` regions.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Zone {
    /// True when none of this zone's regions carries a mask, so no kernel tap is dropped.
    pub valid: bool,
    /// Set to 0 by `PatchSpec::into_patch`.
    pub input_zone_offset: isize,
    /// Output storage offset of the zone's first position.
    pub output_zone_offset: isize,
    /// Per-axis output coordinate range covered by the zone.
    pub output_ranges: Box<[Range<usize>]>,
    /// Per-axis length of `output_ranges`.
    pub output_shape: Box<[usize]>,
    /// `(kernel tap index, tap input storage offset)` for every tap kept in this zone.
    /// The offset is relative to a window's center offset.
    pub values_offsets: Box<[(usize, isize)]>,
}
358
359impl Zone {
360    pub fn contains_output(&self, coords: &[usize]) -> bool {
361        self.output_ranges.iter().zip(coords).all(|(range, &x)| x >= range.start && x < range.end)
362    }
363
364    #[inline]
365    pub fn visit_output(&self, patch: &Patch, mut acceptor: impl FnMut(&ZoneScanner)) {
366        let mut scanner = ZoneScanner::new(self, patch);
367        while !scanner.done() {
368            acceptor(&scanner);
369            scanner.next();
370        }
371    }
372}
373
/// Cursor over the output positions of a single [`Zone`]. It runs a fast inner loop along
/// one axis and odometer-style carries along the others.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ZoneScanner<'p> {
    /// Patch the zone belongs to.
    pub patch: &'p Patch,
    /// Zone being scanned.
    pub zone: &'p Zone,
    /// Output storage offset of the current position.
    pub output_offset: isize,
    /// Current output coordinates.
    pub output_coords: Box<[usize]>,
    /// Input storage offset of the current window (tap offsets not added).
    pub input_center_offset: isize,
    /// Axis iterated by the inner loop.
    pub inner_loop_axis: usize,
    /// Length of the zone's range along `inner_loop_axis`.
    pub inner_loop_len: usize,
    /// The zone's output range along `inner_loop_axis`.
    pub inner_loop_output_range: Range<usize>,
    /// Output storage stride of `inner_loop_axis`.
    pub inner_loop_output_stride: isize,
    /// Input storage offset increment per inner-loop step (op stride times input
    /// storage stride).
    pub inner_loop_input_full_stride: isize,
    /// Set once every position of the zone has been visited.
    pub done: bool,
}
388
impl<'p> ZoneScanner<'p> {
    /// Creates a scanner on the first position of `zone` (the start of each range).
    ///
    /// The inner loop axis is the zone axis with the largest extent (`max_by_key`, so the
    /// last one on ties).
    ///
    /// # Panics
    /// Panics if the zone has no axes.
    pub fn new(zone: &'p Zone, patch: &'p Patch) -> ZoneScanner<'p> {
        let inner_loop_axis =
            zone.output_shape.iter().enumerate().max_by_key(|(_, dim)| *dim).unwrap().0;
        let inner_loop_output_range = zone.output_ranges[inner_loop_axis].clone();
        let inner_loop_output_stride = patch.output_storage_strides[inner_loop_axis];
        let inner_loop_input_full_stride =
            patch.op_strides_times_input_storage_strides[inner_loop_axis];
        let mut scan = ZoneScanner {
            patch,
            zone,
            output_offset: 0,
            input_center_offset: 0,
            inner_loop_axis,
            inner_loop_len: inner_loop_output_range.len(),
            inner_loop_output_range,
            inner_loop_output_stride,
            inner_loop_input_full_stride,
            output_coords: zone.output_ranges.iter().map(|r| r.start).collect(),
            done: false,
        };
        // Offsets were zero-initialised above; derive them from the start coordinates.
        scan.refresh_dependent();
        scan
    }

    /// Yields `(kernel tap index, input storage offset)` for every tap kept in the zone,
    /// at the current output position.
    #[inline]
    pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
        self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
    }

    /// Advances every axis except the inner loop one, odometer-style with the last axis
    /// fastest. Axes that run past their range are rewound to its start. Offsets are
    /// recomputed on success, and `done` is set when all those axes wrap.
    ///
    /// # Safety
    /// `output_coords` and `zone.output_ranges` are indexed unchecked over
    /// `0..patch.rank()`, so both must have at least that many entries.
    pub unsafe fn next_non_inner_axis(&mut self) {
        unsafe {
            let rank = self.patch.rank();
            let inner_loop_axis = self.inner_loop_axis;
            for axis in (0..rank).rev() {
                if axis == inner_loop_axis {
                    continue;
                }
                *self.output_coords.get_unchecked_mut(axis) += 1;
                if *self.output_coords.get_unchecked_mut(axis)
                    < self.zone.output_ranges.get_unchecked(axis).end
                {
                    self.refresh_dependent();
                    return;
                }
                // This axis wrapped: rewind it and carry into the next outer axis.
                *self.output_coords.get_unchecked_mut(axis) =
                    self.zone.output_ranges.get_unchecked(axis).start;
            }
            self.done = true;
        }
    }

    /// Rewinds the scanner to the zone's first position and clears `done`.
    ///
    /// # Safety
    /// `zone.output_ranges` is indexed unchecked for every entry of `output_coords`, so it
    /// must be at least as long.
    pub unsafe fn reset(&mut self) {
        unsafe {
            self.output_offset = 0;
            self.input_center_offset = 0;
            for ix in 0..self.output_coords.len() {
                *self.output_coords.get_unchecked_mut(ix) =
                    self.zone.output_ranges.get_unchecked(ix).start;
            }
            self.done = false;
            self.refresh_dependent()
        }
    }

    /// Recomputes `input_center_offset` and `output_offset` from `output_coords`. They
    /// are dot products with the patch's op-stride-times-input-storage-stride vector and
    /// with its output storage strides.
    // NOTE(review): kept out of line, presumably so the hot `next` stays small; confirm.
    #[inline(never)]
    fn refresh_dependent(&mut self) {
        self.input_center_offset = self
            .patch
            .op_strides_times_input_storage_strides
            .iter()
            .zip(self.output_coords.iter())
            .map(|(a, b)| *a * *b as isize)
            .sum();
        self.output_offset = self
            .patch
            .output_storage_strides
            .iter()
            .zip(self.output_coords.iter())
            .map(|(a, b)| a * *b as isize)
            .sum();
    }

    /// Moves to the next position of the zone. Within the inner loop range only the
    /// offsets are bumped. At the end of that range the inner coordinate is rewound and
    /// the other axes advance through `next_non_inner_axis`.
    #[inline]
    pub fn next(&mut self) {
        let inner_loop_axis = self.inner_loop_axis;
        // SAFETY: `inner_loop_axis` was chosen in `new` as a valid index of the zone's
        // axes, and `output_coords` has one entry per zone axis.
        unsafe {
            *self.output_coords.get_unchecked_mut(inner_loop_axis) += 1;
            if *self.output_coords.get_unchecked(inner_loop_axis) < self.inner_loop_output_range.end
            {
                self.input_center_offset += self.inner_loop_input_full_stride;
                self.output_offset += self.inner_loop_output_stride;
            } else {
                *self.output_coords.get_unchecked_mut(inner_loop_axis) =
                    self.inner_loop_output_range.start;
                self.next_non_inner_axis();
            }
        }
    }

    /// True once every position of the zone has been visited.
    pub fn done(&self) -> bool {
        self.done
    }
}
493
/// Row-major cursor over every output position of a [`Patch`]. It tracks which [`Zone`]
/// contains the current position.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Scanner<'p> {
    /// Patch being scanned.
    pub patch: &'p Patch,
    /// Index of `zone` in `patch.zones`.
    pub zone_id: usize,
    /// Per-axis coordinates of the current zone in the zone grid.
    pub zone_coords: TVec<usize>,
    /// Zone containing the current output position.
    pub zone: &'p Zone,
    /// Output storage offset of the current position.
    pub output_offset: isize,
    /// Current output coordinates.
    pub output_coords: TVec<usize>,
    /// Per-axis `output_coords * strides`.
    pub input_coords: TVec<usize>,
    /// Input storage offset of `input_coords` (tap offsets not added).
    pub input_center_offset: isize,
    /// Set once every output position has been visited.
    done: bool,
}
506
impl<'p> Scanner<'p> {
    /// Creates a scanner on output position (0, ..., 0), in zone 0.
    ///
    /// Panics if `patch.zones` is empty (callers check this first).
    fn new(patch: &'p Patch) -> Scanner<'p> {
        let rank = patch.rank();
        let zone = &patch.zones[0];
        Scanner {
            patch,
            zone_coords: tvec!(0; rank),
            zone,
            zone_id: 0,
            output_offset: 0,
            input_center_offset: 0,
            input_coords: tvec!(0; rank),
            output_coords: tvec!(0; rank),
            done: false,
        }
    }

    /// Number of kernel taps kept by the current zone.
    #[inline]
    pub fn valid_count(&self) -> usize {
        self.zone.values_offsets.len()
    }

    /// Input storage offsets of the current zone's kept taps, at the current position.
    #[inline]
    pub fn valid_offsets(&self) -> impl Iterator<Item = isize> + '_ {
        self.zone.values_offsets.iter().map(move |pair| pair.1 + self.input_center_offset)
    }

    /// `(kernel tap index, input storage offset)` for the current zone's kept taps, at the
    /// current position.
    #[inline]
    pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
        self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
    }

    /// Advances to the next output position in row-major order and switches zone when the
    /// position leaves the current one. `done` is set after the last position.
    ///
    /// `output_offset` only ever grows by `output_inner_stride`. This matches the
    /// contiguous row-major `output_storage_strides` built by `into_patch`. The rank must
    /// be at least 1.
    #[inline]
    pub fn next(&mut self) {
        let rank = self.patch.rank();
        let inner_dim = rank - 1;
        // SAFETY: every per-axis vector read here has `rank` entries by construction, and
        // `zone_id` stays inside `patch.zones` because the zone grid follows
        // `zone_strides`.
        unsafe {
            // Step along the innermost axis.
            *self.output_coords.get_unchecked_mut(inner_dim) += 1;
            *self.input_coords.get_unchecked_mut(inner_dim) +=
                *self.patch.spec.strides.get_unchecked(inner_dim);
            self.output_offset += self.patch.spec.output_inner_stride as isize;
            self.input_center_offset +=
                self.patch.op_strides_times_input_storage_strides.get_unchecked(inner_dim);
            // Still inside the current zone: nothing else to update.
            if *self.output_coords.get_unchecked(inner_dim)
                < self.zone.output_ranges.get_unchecked(inner_dim).end
            {
                return;
            }
            if self.output_coords.get_unchecked(inner_dim)
                < self.patch.output_shape.get_unchecked(inner_dim)
            {
                // Still in the row: move to the adjacent zone along the inner axis (zone
                // stride 1 there).
                self.zone_id += 1;
                *self.zone_coords.get_unchecked_mut(inner_dim) += 1;
                self.zone = self.patch.zones.get_unchecked(self.zone_id);
            } else {
                // End of row: carry into the outer axes, rewinding inner ones to 0.
                for axis in (0..rank - 1).rev() {
                    *self.output_coords.get_unchecked_mut(axis + 1) = 0;
                    *self.input_coords.get_unchecked_mut(axis + 1) = 0;
                    *self.output_coords.get_unchecked_mut(axis) += 1;
                    *self.input_coords.get_unchecked_mut(axis) +=
                        self.patch.spec.strides.get_unchecked(axis);
                    *self.zone_coords.get_unchecked_mut(axis + 1) = 0;
                    // Crossing the current zone's boundary on this axis moves to the
                    // next zone row.
                    if *self.output_coords.get_unchecked(axis)
                        == self.zone.output_ranges.get_unchecked(axis).end
                    {
                        *self.zone_coords.get_unchecked_mut(axis) += 1;
                    }
                    if *self.output_coords.get_unchecked(axis)
                        < *self.patch.output_shape.get_unchecked(axis)
                    {
                        break;
                    }
                }
                if self.output_coords.get_unchecked(0) == self.patch.output_shape.get_unchecked(0) {
                    self.done = true;
                    return;
                }
                // Rebuild zone id and input offset from the per-axis coordinates.
                self.zone_id = 0;
                self.input_center_offset = 0;
                for i in 0..rank {
                    self.zone_id += *self.zone_coords.get_unchecked(i)
                        * *self.patch.zone_strides.get_unchecked(i) as usize;
                    self.input_center_offset += *self.input_coords.get_unchecked(i) as isize
                        * *self.patch.input_storage_strides.get_unchecked(i);
                }
                self.zone = self.patch.zones.get_unchecked(self.zone_id);
            }
        }
    }

    /// True once every output position has been visited.
    pub fn done(&self) -> bool {
        self.done
    }
}
601
/// Iterator over the kernel taps of one window, returned by [`Patch::at`] and
/// [`Patch::at_hint`]. Items are `Some(input storage offset)` or `None` for a padded tap.
#[derive(Debug)]
pub enum PatchIterator<'p> {
    /// Unchecked path: every tap is yielded as `Some`.
    Fast(FastPatchIterator<'p>),
    /// Checked path: taps outside the input are yielded as `None`.
    Safe(SafePatchIterator<'p>),
}
607
608impl Iterator for PatchIterator<'_> {
609    type Item = Option<isize>;
610    #[inline(always)]
611    fn next(&mut self) -> Option<Option<isize>> {
612        match self {
613            PatchIterator::Fast(it) => it.next(),
614            PatchIterator::Safe(it) => it.next(),
615        }
616    }
617}
618
/// Unchecked tap iterator: yields `center + tap offset` for every kernel tap.
#[derive(Debug)]
pub struct FastPatchIterator<'p> {
    /// Patch providing the tap offsets.
    patch: &'p Patch,
    /// Input storage offset of the window.
    center: isize,
    /// Index of the next tap to yield.
    item: usize,
}
625
626impl Iterator for FastPatchIterator<'_> {
627    type Item = Option<isize>;
628    #[inline(always)]
629    fn next(&mut self) -> Option<Option<isize>> {
630        if self.item == self.patch.standard_layout_data_field.len() {
631            return None;
632        }
633        unsafe {
634            let position =
635                self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
636            self.item += 1;
637            Some(Some(position))
638        }
639    }
640}
641
/// Bounds-checked tap iterator: yields `None` for taps falling in the padding.
#[derive(Debug)]
pub struct SafePatchIterator<'p> {
    /// Patch providing the tap offsets and `data_field`.
    patch: &'p Patch,
    /// Index of the next tap to yield.
    item: usize,
    /// Per-axis input coordinate of the window origin (output coords times strides).
    input_patch_center: TVec<usize>,
    /// Input storage offset of the window.
    center: isize,
}
649
650impl Iterator for SafePatchIterator<'_> {
651    type Item = Option<isize>;
652    fn next(&mut self) -> Option<Option<isize>> {
653        unsafe {
654            if self.item == self.patch.standard_layout_data_field.len() {
655                return None;
656            }
657            let input_shape = &self.patch.spec.input_shape;
658            let img_offset = self.patch.data_field.as_ptr().add(self.item * input_shape.len());
659
660            for ix in 0..input_shape.len() {
661                let pos = *self.input_patch_center.get_unchecked(ix) as isize + *img_offset.add(ix);
662                if pos < 0 || pos as usize >= *input_shape.get_unchecked(ix) {
663                    self.item += 1;
664                    return Some(None);
665                }
666            }
667            let position =
668                self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
669            self.item += 1;
670            Some(Some(position))
671        }
672    }
673}
674
675#[cfg(test)]
676pub mod test {
677    use super::*;
678    use crate::ops::nn::DataFormat::*;
679    use proptest::prelude::*;
680    use proptest::*;
681
682    fn compute_output_spatial_dim(
683        input: usize,
684        dilation: usize,
685        kdim: usize,
686        pad_before: usize,
687        bad_after: usize,
688        stride: usize,
689    ) -> usize {
690        let patch = PatchSpec::for_full_shape(NCHW, &[1, 1, input])
691            .unwrap()
692            .with_dilations(tvec!(dilation))
693            .with_kernel_shape(tvec!(kdim))
694            .with_padding(PaddingSpec::ExplicitOnnxPool(tvec![pad_before], tvec![bad_after], true))
695            .with_strides(tvec![stride])
696            .into_patch();
697        patch.output_shape[0]
698    }
699
700    #[test]
701    fn basic() {
702        assert_eq!(compute_output_spatial_dim(5, 1, 3, 0, 0, 1), 3);
703    }
704
705    #[test]
706    fn strides() {
707        assert_eq!(compute_output_spatial_dim(7, 1, 3, 0, 0, 2), 3);
708    }
709
710    #[test]
711    fn padding() {
712        assert_eq!(compute_output_spatial_dim(5, 1, 3, 1, 1, 1), 5);
713    }
714
715    #[test]
716    fn strides_and_padding() {
717        assert_eq!(compute_output_spatial_dim(7, 1, 3, 1, 1, 2), 4);
718    }
719
720    fn field(kdim: &[usize], dilations: &[usize]) -> Array2<isize> {
721        let patch =
722            PatchSpec::for_data_shape(NCHW.from_n_c_hw(1, 1, tvec![10; kdim.len()]).unwrap())
723                .with_dilations(dilations.into())
724                .with_kernel_shape(kdim.into())
725                .with_strides(tvec![1; kdim.len()])
726                .into_patch();
727        patch.data_field
728    }
729
730    #[test]
731    fn test_field() {
732        assert_eq!(field(&[3], &[1]), arr2(&[[0], [1], [2]]));
733        assert_eq!(field(&[3], &[2]), arr2(&[[0], [2], [4]]));
734        assert_eq!(field(&[2, 2], &[1, 1]), arr2(&[[0, 0], [0, 1], [1, 0], [1, 1]]));
735        assert_eq!(field(&[2, 2], &[2, 1]), arr2(&[[0, 0], [0, 1], [2, 0], [2, 1]]));
736    }
737
738    pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
739        let len = shape.iter().product::<usize>();
740        let shape = shape.to_vec();
741        proptest::collection::vec(any::<i8>().prop_map(|i| i as f32), len..=len)
742            .prop_map(move |vec| ArrayD::from_shape_vec(shape.clone(), vec).unwrap().into_tensor())
743            .boxed()
744    }
745
    /// A generated 2-D case: a patch, an input tensor matching it, and the layout used to
    /// read that tensor. Checked against a reference sum pooling.
    #[derive(Debug)]
    struct Problem {
        patch: Patch,
        input: Tensor,
        data_format: DataFormat,
    }
752
    impl Arbitrary for Problem {
        type Parameters = ();
        type Strategy = BoxedStrategy<Problem>;
        /// Draws the layout, dilations, channel count, kernel shape, padding and strides,
        /// then an input whose spatial size is at least the dilated kernel extent, and
        /// builds the matching patch.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            (
                // p.0: data layout
                prop_oneof!(Just(NCHW), Just(NHWC)),
                // p.1: dilations (h, w)
                (1usize..3, 1usize..3),
                // p.2: channel count
                1usize..3,
                // p.3: kernel shape (h, w)
                (1usize..3, 1usize..3),
                // p.4: padding policy
                prop_oneof![
                    Just(PaddingSpec::Valid),
                    Just(PaddingSpec::SameLower),
                    Just(PaddingSpec::SameUpper)
                ],
                // p.5: strides (h, w)
                (1usize..4, 1usize..4),
            )
                .prop_flat_map(|p| {
                    let dil = p.1;
                    let ks = p.3;
                    let strides = p.5;
                    // Dilated kernel extent per axis; the input is drawn from there up to
                    // three strides beyond.
                    let min_size: (usize, usize) = (1 + (ks.0 - 1) * dil.0, 1 + (ks.1 - 1) * dil.1);
                    (
                        Just(p),
                        (min_size.0..min_size.0 + strides.0 * 3),
                        (min_size.1..min_size.1 + strides.1 * 3),
                    )
                })
                .prop_flat_map(|(p, h, w)| {
                    let input_shape = p.0.from_n_c_hw(1, p.2, [h, w]).unwrap();
                    let input = tensor(&input_shape.shape);
                    (Just(p), input)
                })
                .prop_map(|((fmt, dil, c, ks, pad, strides), input)| {
                    // Channels-last layouts interleave channels between spatial positions.
                    let output_inner_stride = if fmt.c_is_last() { c } else { 1 };
                    Problem {
                        patch: PatchSpec::for_full_shape(fmt, input.shape())
                            .unwrap()
                            .with_dilations(tvec!(dil.0, dil.1))
                            .with_kernel_shape(tvec!(ks.0, ks.1))
                            .with_padding(pad)
                            .with_strides(tvec![strides.0, strides.1])
                            .with_output_inner_stride(output_inner_stride)
                            .into_patch(),
                        input,
                        data_format: fmt,
                    }
                })
                .boxed()
        }
    }
803
804    impl Problem {
        /// The input tensor's shape interpreted with the problem's data format.
        fn input_shape(&self) -> DataShape {
            self.data_format.shape(self.input.shape().into()).unwrap()
        }
808
809        fn output_shape(&self) -> DataShape {
810            self.data_format
811                .from_n_c_hw(
812                    self.input_shape().n().cloned().unwrap_or(1),
813                    *self.input_shape().c(),
814                    &*self.patch.output_shape,
815                )
816                .unwrap()
817        }
818
        /// Oracle: a direct, coordinate-by-coordinate sum pooling. For every output
        /// position and kernel tap it computes `o * stride + k * dilation - pad_before`
        /// per axis. Taps landing outside the input are skipped, so padding adds nothing.
        fn reference_sumpool(&self) -> Tensor {
            let input_shape = self.input_shape();
            let output_shape = self.output_shape();
            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
            for geo_out in tract_ndarray::indices(output_shape.hw_dims()) {
                for geo_ker in tract_ndarray::indices(&*self.patch.spec.kernel_shape) {
                    let geo_in: TVec<isize> = izip!(
                        geo_out.slice(),
                        geo_ker.slice(),
                        &self.patch.spec.strides,
                        &self.patch.spec.dilations,
                        &self.patch.pad_before
                    )
                    .map(|(o, k, s, d, p)| (o * s + k * d) as isize - *p as isize)
                    .collect();
                    if izip!(&geo_in, input_shape.hw_dims())
                        .any(|(g, i)| *g >= *i as isize || *g < 0)
                    {
                        continue;
                    }
                    let geo_in: TVec<usize> = geo_in.into_iter().map(|x| x as usize).collect();
                    for c in 0..*output_shape.c() {
                        // Full coordinates with batch index 0; `.shape` is used as the index.
                        let ocoords = self.data_format.from_n_c_hw(0, c, geo_out.slice()).unwrap();
                        let icoords = self.data_format.from_n_c_hw(0, c, &geo_in).unwrap();
                        output.to_array_view_mut::<f32>().unwrap()[&*ocoords.shape] +=
                            self.input.to_array_view::<f32>().unwrap()[&*icoords.shape];
                    }
                }
            }
            output
        }
850
851        fn check_visitor(&self) {
852            let input_shape = self.input_shape();
853            let output_shape = self.output_shape();
854            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
855            self.patch.visit_output(|visitor| {
856                for (_k, offset_in) in visitor.valid_offsets_ker_in() {
857                    for c in 0..*output_shape.c() {
858                        output.as_slice_mut::<f32>().unwrap()
859                            [visitor.output_offset as usize + c * output_shape.c_stride()] +=
860                            self.input.as_slice::<f32>().unwrap()
861                                [offset_in as usize + c * input_shape.c_stride()];
862                    }
863                }
864            });
865            assert_eq!(output, self.reference_sumpool());
866        }
867
868        fn check_zone_visitor(&self) {
869            let input_shape = self.input_shape();
870            let output_shape = self.output_shape();
871            let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
872            for zone in &self.patch.zones {
873                zone.visit_output(&self.patch, |visitor| {
874                    for (_k, offset_in) in visitor.valid_offsets_ker_in() {
875                        for c in 0..*output_shape.c() {
876                            output.as_slice_mut::<f32>().unwrap()
877                                [visitor.output_offset as usize + c * output_shape.c_stride()] +=
878                                self.input.as_slice::<f32>().unwrap()
879                                    [offset_in as usize + c * input_shape.c_stride()];
880                        }
881                    }
882                });
883            }
884            assert_eq!(output, self.reference_sumpool());
885        }
886
887        fn check_zoning(&self) {
888            fn in_zone(full_coords: &[usize], h_axis: usize, zone: &[Range<usize>]) -> bool {
889                for a in 0..zone.len() {
890                    if full_coords[h_axis + a] < zone[a].start
891                        || full_coords[h_axis + a] >= zone[a].end
892                    {
893                        return false;
894                    }
895                }
896                true
897            }
898
899            let valid_zone = &self.patch.valid_output_zone;
900            let invalid_zones = &self.patch.invalid_output_zones;
901            let output_full_shape = self.output_shape();
902            let h_axis = self.input_shape().h_axis();
903            for coords in ndarray::indices(&*output_full_shape.shape) {
904                let inside_valid = in_zone(coords.slice(), h_axis, valid_zone);
905                let invalid_count =
906                    invalid_zones.iter().filter(|z| in_zone(coords.slice(), h_axis, z)).count();
907                unsafe {
908                    assert_eq!(
909                        inside_valid,
910                        self.patch.is_valid(&coords.slice()[self.input_shape().hw_axes()]),
911                        "coords {:?}, valid_zone: {:?} inside_valid: {:?}",
912                        coords.slice(),
913                        valid_zone,
914                        inside_valid
915                    );
916                }
917                if inside_valid {
918                    assert_eq!(invalid_count, 0);
919                } else {
920                    assert_eq!(
921                        invalid_count,
922                        1,
923                        "coords {:?}, valid_zone: {:?} inside_valid: {:?} invalid_zones: {:?}",
924                        coords.slice(),
925                        valid_zone,
926                        inside_valid,
927                        invalid_zones
928                    );
929                }
930            }
931        }
932    }
933
934    proptest! {
935        #[test]
936        fn test_visitor(pb in any::<Problem>()) {
937            pb.check_visitor();
938        }
939
940        #[test]
941        fn test_zone_visitor(pb in any::<Problem>()) {
942            pb.check_zone_visitor();
943        }
944
945        #[test]
946        fn test_zoning(pb in any::<Problem>()) {
947            pb.check_zoning();
948        }
949    }
950
951    #[test]
952    fn test_visitor_1() {
953        let input_shape = NCHW.from_n_c_hw(1, 1, [2, 2]).unwrap();
954        let input = Tensor::zero::<f32>(&input_shape.shape).unwrap();
955        let patch = PatchSpec::for_data_shape(input_shape.clone())
956            .with_kernel_shape(tvec![2, 1])
957            .with_padding(PaddingSpec::SameLower)
958            .with_strides(tvec![1, 2])
959            .into_patch();
960        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
961    }
962
963    #[test]
964    fn test_visitor_2() {
965        let input_shape = NCHW.from_n_c_hw(1, 2, [1, 1]).unwrap();
966        let input = tensor4(&[[[[0.]], [[1f32]]]]);
967        assert_eq!(input.shape(), &*input_shape.shape);
968        let patch =
969            PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
970        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
971    }
972
973    #[test]
974    fn test_visitor_3() {
975        let input_shape = NHWC.from_n_c_hw(1, 2, [2, 1]).unwrap();
976        let input = tensor4(&[[[[0., 0.]], [[1., 0f32]]]]);
977        assert_eq!(input.shape(), &*input_shape.shape);
978        let patch =
979            PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
980        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
981    }
982
983    #[test]
984    fn test_visitor_4() {
985        let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
986        let input = tensor4(&[[[[0., 1f32]]]]);
987        assert_eq!(input.shape(), &*input_shape.shape);
988        let patch = PatchSpec::for_data_shape(input_shape.clone())
989            .with_kernel_shape(tvec!(1, 2))
990            .with_output_inner_stride(1)
991            .with_padding(PaddingSpec::SameLower)
992            .into_patch();
993        Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
994    }
995
996    #[test]
997    fn test_zone_visitor_1() {
998        let input_shape = NCHW.from_n_c_hw(1, 1, [2, 1]).unwrap();
999        let input = tensor4(&[[[[0.], [1f32]]]]);
1000        assert_eq!(input.shape(), &*input_shape.shape);
1001        let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
1002        Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
1003    }
1004
1005    #[test]
1006    fn test_zone_visitor_2() {
1007        let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
1008        let input = tensor4(&[[[[0., 1f32]]]]);
1009        assert_eq!(input.shape(), &*input_shape.shape);
1010        let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
1011        Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
1012    }
1013}