tract_linalg/frame/
pack.rs

1use std::alloc::Layout;
2use std::fmt::{Debug, Display};
3use std::marker::PhantomData;
4use std::ops::Range;
5use std::sync::Arc;
6use tract_data::internal::*;
7
8use crate::mmm::{EagerPackedInput, MMMInputFormat, MMMInputValue, PackedOpaqueFact};
9
10use crate::WeightType;
11
12#[derive(Clone, Eq, PartialEq, Hash)]
13pub struct PackedFormat {
14    pub dt: DatumType,
15    pub r: usize,
16    pub alignment_bytes: usize,
17    pub end_padding_record: usize,
18}
19
20impl MMMInputFormat for PackedFormat {
21    fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor> {
22        let packed = PackedFormat::pack_tensor(self, t, k_axis, mn_axis)?;
23        Ok(tensor0(Opaque(Arc::new(packed))))
24    }
25
26    fn prepare_one(
27        &self,
28        t: &Tensor,
29        k_axis: usize,
30        mn_axis: usize,
31    ) -> TractResult<Box<dyn MMMInputValue>> {
32        PackedFormat::pack_tensor(self, t, k_axis, mn_axis)
33    }
34
35    fn precursor(&self) -> WeightType {
36        WeightType::Plain(self.dt)
37    }
38
39    fn r(&self) -> usize {
40        self.r
41    }
42
43    fn k_alignment(&self) -> usize {
44        1
45    }
46
47    fn same_as(&self, other: &dyn MMMInputFormat) -> bool {
48        other.downcast_ref::<Self>().is_some_and(|other| self == other)
49    }
50
51    #[allow(clippy::collapsible_if)]
52    fn merge_with<'o, 'a: 'o, 'b: 'o>(
53        &'a self,
54        other: &'b dyn MMMInputFormat,
55    ) -> Option<&'o dyn MMMInputFormat> {
56        if let Some(other) = other.downcast_ref::<PackedFormat>() {
57            if self.r == other.r && self.dt == other.dt {
58                if self.alignment_bytes % other.alignment_bytes == 0
59                    && self.end_padding_record >= other.end_padding_record
60                {
61                    return Some(self);
62                }
63                if other.alignment_bytes % self.alignment_bytes == 0
64                    && other.end_padding_record >= self.end_padding_record
65                {
66                    return Some(other);
67                }
68            }
69        }
70        None
71    }
72
73    fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
74        self.len(k, mn) * self.dt.size_of()
75    }
76
77    fn extract_at_mn_f16(
78        &self,
79        data: &EagerPackedInput,
80        mn: usize,
81        slice: &mut [f16],
82    ) -> TractResult<()> {
83        ensure!(data.format().same_as(self));
84        ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
85        unsafe {
86            let ptr = data.packed.as_ptr().add(
87                (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
88            );
89            for (i, slot) in slice.iter_mut().enumerate() {
90                let ptr = ptr.add(i * self.dt.size_of() * self.r);
91                *slot = if self.dt == f16::datum_type() {
92                    *(ptr as *const f16)
93                } else if self.dt == f32::datum_type() {
94                    f16::from_f32(*(ptr as *const f32))
95                } else {
96                    bail!("Unexpected DT {:?}", self.dt)
97                }
98            }
99        }
100        Ok(())
101    }
102
103    fn extract_at_mn_f32(
104        &self,
105        data: &EagerPackedInput,
106        mn: usize,
107        slice: &mut [f32],
108    ) -> TractResult<()> {
109        ensure!(data.format().same_as(self));
110        ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
111        unsafe {
112            let ptr = data.packed.as_ptr().add(
113                (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
114            );
115            for (i, slot) in slice.iter_mut().enumerate() {
116                let ptr = ptr.add(i * self.dt.size_of() * self.r);
117                *slot = if self.dt == f16::datum_type() {
118                    (*(ptr as *const f16)).to_f32()
119                } else if self.dt == f32::datum_type() {
120                    *(ptr as *const f32)
121                } else {
122                    bail!("Unexpected DT {:?}", self.dt)
123                }
124            }
125        }
126        Ok(())
127    }
128}
129
130impl Display for PackedFormat {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        write!(f, "Packed{:?}[{}]", self.dt, self.r)
133    }
134}
135
136impl Debug for PackedFormat {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(
139            f,
140            "Packed{:?}[{}]@{}+{}",
141            self.dt, self.r, self.alignment_bytes, self.end_padding_record
142        )
143    }
144}
145
146impl PackedFormat {
147    pub const fn new(dt: DatumType, nr: usize, alignment_bytes: usize) -> PackedFormat {
148        PackedFormat { dt, r: nr, alignment_bytes, end_padding_record: 1 }
149    }
150
151    pub const fn with_end_padding_record(self, end_padding_record: usize) -> Self {
152        PackedFormat { end_padding_record, ..self }
153    }
154
155    #[inline]
156    pub fn align(self, alignment: usize) -> Self {
157        Self { alignment_bytes: alignment, ..self }
158    }
159
160    #[inline]
161    pub fn alignment(&self) -> usize {
162        self.alignment_bytes
163    }
164
165    #[inline]
166    pub fn panel_width(&self) -> usize {
167        self.r
168    }
169
170    #[inline]
171    pub fn len<D: DimLike>(&self, k: D, n: D) -> D {
172        n.divceil(self.r) * self.single_panel_len(k)
173    }
174
175    #[inline]
176    pub fn single_panel_len<D: DimLike>(&self, k: D) -> D {
177        ((k + self.end_padding_record) * self.r).divceil(self.alignment()) * self.alignment()
178    }
179
180    #[inline]
181    pub fn single_panel_layout(&self, k: usize, item_size: usize) -> Layout {
182        Layout::from_size_align(self.single_panel_len(k) * item_size, self.alignment()).unwrap()
183    }
184
185    pub fn pack_tensor(
186        &self,
187        t: &Tensor,
188        k_axis: usize,
189        mn_axis: usize,
190    ) -> TractResult<Box<dyn MMMInputValue>> {
191        ensure!(t.datum_type().is_copy());
192        ensure!(
193            t.datum_type().unquantized() == self.dt.unquantized(),
194            "Attempting to pack for {self} tensor {t:?}"
195        );
196        let k = t.shape()[k_axis];
197        let mn = t.shape()[mn_axis];
198        let packed_len = self.len(k, mn);
199        let panel_len = self.single_panel_len(k);
200        let panel_bytes = panel_len * t.datum_type().size_of();
201        let strides = t.strides();
202        unsafe {
203            let mut packed = Blob::new_for_size_and_align(
204                t.datum_type().size_of() * packed_len,
205                self.alignment_bytes,
206            );
207            if cfg!(debug_assertions) {
208                packed.as_bytes_mut().fill(0u8);
209            }
210            dispatch_copy!(Self::pack_t(t.datum_type())(
211                self,
212                packed.as_mut_ptr() as _,
213                t.as_ptr_unchecked(),
214                mn,
215                strides[k_axis],
216                strides[mn_axis],
217                0..k,
218                0..mn
219            ));
220            Ok(Box::new(EagerPackedInput {
221                fact: PackedOpaqueFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
222                packed: packed.into(),
223                panel_bytes,
224                mn,
225            }))
226        }
227    }
228
229    pub fn pack_tensor_view(
230        &self,
231        t: &TensorView,
232        k_axis: usize,
233        mn_axis: usize,
234    ) -> TractResult<Box<dyn MMMInputValue>> {
235        ensure!(
236            t.datum_type().unquantized() == self.dt.unquantized(),
237            "Attempting to pack for {self} tensor view {t:?}"
238        );
239        let k = t.shape()[k_axis];
240        let mn = t.shape()[mn_axis];
241        let packed_len = self.len(k, mn);
242        let panel_len = self.single_panel_len(k);
243        let panel_bytes = panel_len * t.datum_type().size_of();
244        let strides = t.strides();
245        unsafe {
246            let mut packed = Blob::new_for_size_and_align(
247                t.datum_type().size_of() * packed_len,
248                self.alignment_bytes,
249            );
250            if cfg!(debug_assertions) {
251                packed.as_bytes_mut().fill(0u8);
252            }
253            dispatch_copy!(Self::pack_t(t.datum_type())(
254                self,
255                packed.as_mut_ptr() as _,
256                t.as_ptr_unchecked(),
257                mn,
258                strides[k_axis],
259                strides[mn_axis],
260                0..k,
261                0..mn
262            ));
263            Ok(Box::new(EagerPackedInput {
264                fact: PackedOpaqueFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
265                packed: packed.into(),
266                panel_bytes,
267                mn,
268            }))
269        }
270    }
271
272    pub unsafe fn pack<'a, 'b>(
273        &self,
274        pb: impl std::borrow::BorrowMut<TensorView<'a>>,
275        b: impl std::borrow::Borrow<TensorView<'b>>,
276        k_axis: usize,
277        mn_axis: usize,
278    ) {
279        let k = b.borrow().shape()[k_axis];
280        let mn = b.borrow().shape()[mn_axis];
281        unsafe { self.pack_segment(pb, b, k_axis, mn_axis, 0..k, 0..mn) };
282    }
283
284
285    #[allow(clippy::too_many_arguments)]
286    #[rustfmt::skip]
287    pub unsafe fn pack_t<T: Datum + Copy>(
288        &self,
289        pb: *mut T,
290        b: *const T,
291        mn: usize,
292        k_stride: isize,
293        mn_stride: isize,
294        k_range: Range<usize>,
295        mn_range: Range<usize>,
296        ) { unsafe {
297        if k_range.len() == 0 || mn_range.len() == 0 {
298            return
299        }
300        if self.r == 1 && k_stride == 1 && mn == 1 {
301            pb.copy_from_nonoverlapping(b.add(k_range.start), k_range.len())
302        } else if mn_stride == 1 {
303            let size_of = T::datum_type().size_of();
304            let rbytes = self.r * size_of;
305            let mn_valid_end = mn_range.end.min(mn);
306            let mn_range_bytes = mn_range.start * size_of..mn_valid_end * size_of;
307            let k_stride_bytes = k_stride * size_of as isize;
308            let bb = b as *const u8;
309            let pbb = pb as *mut u8;
310            let panel_len = self.single_panel_len(k_range.len()) * size_of;
311            match rbytes {
312                16 => pack_mn_major::<[u8; 16]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
313                24 => pack_mn_major::<[u8; 24]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
314                32 => pack_mn_major::<[u8; 32]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
315                48 => pack_mn_major::<[u8; 48]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
316                64 => pack_mn_major::<[u8; 64]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
317                _ => {
318                    let mut packer = self.write_with_k_outer(pb, k_range.len(), mn_range.len());
319                    for k in k_range {
320                        for x in mn_range.start..mn_valid_end {
321                            packer.write(*b.offset(x as isize + k_stride * k as isize))
322                        }
323                        for _x in mn_valid_end..mn_range.end {
324                            packer.write(T::default())
325                        }
326                    }
327                }
328            }
329        } else if k_stride == 1 {
330            let mut packer = self.write_with_k_inner(pb, k_range.len(), mn);
331            let mn_valid_end = mn_range.end.min(mn);
332            for x in mn_range.start..mn_valid_end {
333                for k in k_range.clone() {
334                    packer.write(*b.offset(x as isize * mn_stride + k as isize))
335                }
336            }
337            // just ignore invalid mn_range
338        } else {
339            let mut packer = self.write_with_k_outer(pb, k_range.len(), mn);
340            let mn_valid_end = mn_range.end.min(mn);
341            for k in k_range {
342                for x in mn_range.start..mn_valid_end {
343                    packer.write(*b.offset(x as isize * mn_stride + k_stride * k as isize))
344                }
345                for _x in mn_valid_end..mn_range.end {
346                    packer.write(T::default())
347                }
348            }
349        }
350    }}
351
352    #[inline]
353    pub unsafe fn pack_segment<'a, 'b>(
354        &self,
355        mut pb: impl std::borrow::BorrowMut<TensorView<'a>>,
356        b: impl std::borrow::Borrow<TensorView<'b>>,
357        k_axis: usize,
358        mn_axis: usize,
359        k_range: Range<usize>,
360        mn_range: Range<usize>,
361    ) {
362        debug_assert!(pb.borrow().len() >= self.len(k_range.len(), mn_range.len()));
363        let pb = pb.borrow_mut();
364        let b = b.borrow();
365        let dt = pb.datum_type();
366        unsafe {
367            dispatch_copy!(Self::pack_t(dt)(
368                self,
369                pb.as_ptr_mut_unchecked(),
370                b.as_ptr_unchecked(),
371                b.shape()[mn_axis],
372                b.strides()[k_axis],
373                b.strides()[mn_axis],
374                k_range,
375                mn_range
376            ));
377        }
378    }
379
380    pub fn write_with_k_outer<'p, T: Copy + Debug>(
381        &self,
382        pb: *mut T,
383        k: usize,
384        mn: usize,
385    ) -> KOutWriter<'p, T> {
386        KOutWriter::new(pb, self.r, self.single_panel_len(k), mn, k)
387    }
388
389    pub fn write_single_panel_with_k_outer<'p, T: Copy + Debug>(
390        &self,
391        pb: *mut T,
392    ) -> KOutSinglePanelWriter<'p, T> {
393        KOutSinglePanelWriter::new(pb)
394    }
395
396    pub fn write_with_k_inner<'p, T: Copy + Debug>(
397        &self,
398        pb: *mut T,
399        k: usize,
400        mn: usize,
401    ) -> KInWriter<'p, T> {
402        let panel_len = self.single_panel_len(k);
403        KInWriter::new(pb, panel_len, self.r, mn, k)
404    }
405}
406
/// A sink that receives packed values one at a time, in whatever traversal order the
/// implementing writer was built for.
pub trait PackingWriter<T: Copy> {
    /// Stores `t` at the current position, then advances to the next slot.
    fn write(&mut self, t: T);
}

/// Writer for a single panel: values land at consecutive addresses.
#[derive(Debug)]
pub struct KOutSinglePanelWriter<'p, T>
where
    T: Copy + std::fmt::Debug,
{
    ptr: *mut T,
    _phantom: PhantomData<&'p T>,
}

impl<'p, T: Copy + std::fmt::Debug> KOutSinglePanelWriter<'p, T> {
    /// Creates a writer whose first value goes to `ptr`.
    pub fn new(ptr: *mut T) -> Self {
        Self { ptr, _phantom: PhantomData }
    }
}

impl<T: Copy + std::fmt::Debug> PackingWriter<T> for KOutSinglePanelWriter<'_, T> {
    #[inline(always)]
    fn write(&mut self, t: T) {
        // SAFETY: whoever created the writer must provide room for every value written.
        unsafe {
            self.ptr.write(t);
            self.ptr = self.ptr.add(1);
        }
    }
}
441
442#[derive(Debug)]
443pub struct KOutWriter<'p, T>
444where
445    T: Copy + std::fmt::Debug,
446{
447    ptr: *mut T,
448    panels: usize,
449    panel_width: usize,
450    last_panel_width: usize,
451    remain: usize,
452    current_panel: usize,
453    next_panel: isize,
454    next_lane: isize,
455    _phantom: PhantomData<&'p T>,
456}
457
458impl<'p, T> KOutWriter<'p, T>
459where
460    T: Copy + std::fmt::Debug,
461{
462    pub fn new(
463        ptr: *mut T,
464        panel_width: usize,
465        panel_len: usize,
466        mn: usize,
467        _k: usize,
468    ) -> KOutWriter<'p, T> {
469        let panels = mn.divceil(panel_width);
470        let last_panel_width = mn - (panels - 1) * panel_width;
471        KOutWriter {
472            ptr,
473            panels,
474            panel_width,
475            last_panel_width,
476            remain: if panels > 1 { panel_width } else { last_panel_width },
477            current_panel: 0,
478            next_panel: (panel_len - panel_width) as isize,
479            next_lane: (panel_width - last_panel_width) as isize
480                - (panel_len * (panels - 1)) as isize,
481            _phantom: PhantomData,
482        }
483    }
484}
485
486impl<T> PackingWriter<T> for KOutWriter<'_, T>
487where
488    T: Copy + std::fmt::Debug,
489{
490    #[inline(always)]
491    fn write(&mut self, t: T) {
492        unsafe {
493            *self.ptr = t;
494            self.remain -= 1;
495            self.ptr = self.ptr.offset(1);
496            if self.remain == 0 {
497                self.current_panel += 1;
498                if self.current_panel == self.panels {
499                    self.ptr = self.ptr.offset(self.next_lane);
500                    self.current_panel = 0;
501                } else {
502                    self.ptr = self.ptr.offset(self.next_panel);
503                }
504                if self.current_panel == self.panels - 1 {
505                    self.remain = self.last_panel_width;
506                } else {
507                    self.remain = self.panel_width;
508                }
509            }
510        }
511    }
512}
513
514#[derive(Debug)]
515pub struct KInWriter<'p, T>
516where
517    T: Copy + Debug,
518{
519    ptr: *mut T,
520    k: usize,
521    panels: usize,
522    panel_width: usize,
523    last_panel_width: usize,
524    remain_on_k: usize,
525    remain_on_mn: usize,
526    current_panel: usize,
527    next_mn_offset: isize,
528    next_panel_offset: isize,
529    _phantom: PhantomData<&'p T>,
530}
531
532impl<'p, T> KInWriter<'p, T>
533where
534    T: Copy + Debug,
535{
536    pub fn new(
537        ptr: *mut T,
538        panel_len: usize,
539        panel_width: usize,
540        mn: usize,
541        k: usize,
542    ) -> KInWriter<'p, T> {
543        let panels = mn.divceil(panel_width);
544        let last_panel_width = mn - (panels - 1) * panel_width;
545        KInWriter {
546            ptr,
547            k,
548            panels,
549            panel_width,
550            last_panel_width,
551            remain_on_k: k,
552            remain_on_mn: if panels == 1 { last_panel_width } else { panel_width },
553            current_panel: 0,
554            next_mn_offset: 1 - (k * panel_width) as isize,
555            next_panel_offset: panel_len as isize - (k * panel_width + panel_width - 1) as isize,
556            //                 ^ next panel     ^    ^ rewind left ^   ^ rewind up   ^
557            _phantom: PhantomData,
558        }
559    }
560}
561
562impl<T> PackingWriter<T> for KInWriter<'_, T>
563where
564    T: Copy + std::fmt::Debug,
565{
566    #[inline(always)]
567    fn write(&mut self, t: T) {
568        unsafe {
569            *self.ptr = t;
570            self.remain_on_k -= 1;
571            self.ptr = self.ptr.add(self.panel_width);
572            if self.remain_on_k == 0 {
573                self.remain_on_k = self.k;
574                self.remain_on_mn -= 1;
575                if self.remain_on_mn > 0 {
576                    self.ptr = self.ptr.offset(self.next_mn_offset);
577                } else {
578                    self.ptr = self.ptr.offset(self.next_panel_offset);
579                    self.current_panel += 1;
580                    if self.current_panel == self.panels - 1 {
581                        self.remain_on_mn = self.last_panel_width;
582                    } else {
583                        self.remain_on_mn = self.panel_width;
584                    }
585                }
586            }
587        }
588    }
589}
590
/// Byte-wise packer for an mn-contiguous source.
///
/// `Chunk` only supplies the size in bytes of one panel row (`r * item_size`). For each k in
/// `k_range`, bytes `mn_range_bytes` of that source row are split into whole panel rows,
/// stored `panel_len` bytes apart, plus an optional shorter tail. Lanes past the tail are
/// left untouched.
///
/// # Safety
/// `b` must be readable at `k * k_stride_bytes + mn_range_bytes` for every k in `k_range`.
/// `packed` must be writable for every panel row touched.
#[inline(never)]
unsafe fn pack_mn_major<Chunk: Copy>(
    b: *const u8,
    packed: *mut u8,
    panel_len: usize,
    k_stride_bytes: isize,
    mn_range_bytes: Range<usize>,
    k_range: Range<usize>,
) {
    let row_bytes = std::mem::size_of::<Chunk>();
    let total_bytes = mn_range_bytes.len();
    let (whole_rows, tail_bytes) = (total_bytes / row_bytes, total_bytes % row_bytes);
    let mn_offset = mn_range_bytes.start as isize;
    for (slot, k) in k_range.enumerate() {
        unsafe {
            let src = b.offset(k as isize * k_stride_bytes + mn_offset);
            let dst = packed.add(slot * row_bytes);
            for pane in 0..whole_rows {
                dst.add(pane * panel_len)
                    .copy_from_nonoverlapping(src.add(pane * row_bytes), row_bytes);
            }
            if tail_bytes > 0 {
                dst.add(whole_rows * panel_len)
                    .copy_from_nonoverlapping(src.add(whole_rows * row_bytes), tail_bytes);
            }
        }
    }
}
620
621pub trait Packing {
622    fn packing(r: usize) -> PackedFormat;
623}
624
625impl<D: Datum> Packing for D {
626    fn packing(r: usize) -> PackedFormat {
627        PackedFormat::new(Self::datum_type(), r, vector_size())
628    }
629}
630
#[cfg(test)]
mod test {
    use std::ops::Range;

    use proptest::prelude::*;
    use tract_data::internal::num_integer::Integer;
    use tract_data::internal::tract_ndarray::Zip;
    use tract_data::internal::*;
    use tract_ndarray::prelude::*;

    /// A packing scenario.
    ///
    /// The source is a `k` x `mn` u32 matrix laid out `(mn, k)` when `is_a` and `(k, mn)`
    /// otherwise. Its `k_range` x `mn_range` window is packed in panels of `r` lanes, each
    /// panel rounded up to a multiple of `align_panel` items.
    #[derive(Debug)]
    struct PackProblem {
        k: usize,
        mn: usize,
        is_a: bool,
        r: usize,
        k_range: Range<usize>,
        mn_range: Range<usize>,
        align_panel: usize,
    }

    impl PackProblem {
        /// Source matrix holding 0, 1, 2, ... in row-major order.
        fn input(&self) -> Array2<u32> {
            let shape = if self.is_a { (self.mn, self.k) } else { (self.k, self.mn) };
            let values: Vec<u32> = (0..(self.k * self.mn) as u32).collect();
            Array2::from_shape_vec(shape, values).unwrap()
        }

        /// `(panel count, items per panel)` of the expected packed output.
        fn geometry(&self) -> (usize, usize) {
            let panels = self.mn_range.len().divceil(self.r);
            let panel_len =
                Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel);
            (panels, panel_len)
        }

        /// Source `(k, mn)` coordinate of item `z` of panel `panel` in the packed output.
        fn source_of(&self, panel: usize, z: usize) -> (usize, usize) {
            (self.k_range.start + z / self.r, self.mn_range.start + panel * self.r + z % self.r)
        }

        /// Runs the real packer on `input()`; the output is reshaped to `(panels, panel_len)`.
        fn packer(&self) -> Array2<u32> {
            let packing = super::PackedFormat::new(u32::datum_type(), self.r, self.align_panel)
                .with_end_padding_record(0);
            let source = self.input().into_tensor();
            let panels = self.mn_range.len().divceil(self.r);
            let panel_len = packing.single_panel_len(self.k_range.len());
            let total = packing.len(self.k_range.len(), self.mn_range.len());
            let mut packed = Tensor::zero::<u32>(&[total]).unwrap();
            let (k_axis, mn_axis) = if self.is_a { (1, 0) } else { (0, 1) };
            unsafe {
                packing.pack_segment(
                    packed.view_mut(),
                    source.view(),
                    k_axis,
                    mn_axis,
                    self.k_range.clone(),
                    self.mn_range.clone(),
                )
            };
            packed.into_array::<u32>().unwrap().into_shape_with_order((panels, panel_len)).unwrap()
        }

        /// Expected packed output, computed directly from coordinates (0 outside the source).
        fn reference(&self) -> Array2<u32> {
            let input = self.input();
            let (panels, panel_len) = self.geometry();
            Array2::from_shape_fn([panels, panel_len], |(panel, z)| {
                let (k, mn) = self.source_of(panel, z);
                let coords = if self.is_a { (mn, k) } else { (k, mn) };
                input.get(coords).copied().unwrap_or(0)
            })
        }

        /// Marks the packed positions that map to real source data (as opposed to padding).
        fn valid(&self) -> Array2<bool> {
            let (panels, panel_len) = self.geometry();
            let k_end = self.k_range.end.min(self.k);
            let mn_end = self.mn_range.end.min(self.mn);
            Array2::from_shape_fn([panels, panel_len], |(panel, z)| {
                let (k, mn) = self.source_of(panel, z);
                k < k_end && mn < mn_end
            })
        }

        /// Compares the real packer with the reference, ignoring padding positions.
        fn check(&self) {
            let valid = self.valid();
            let mask = |a: &mut Array2<u32>| {
                Zip::from(a).and(&valid).for_each(|p, &v| {
                    if !v {
                        *p = u32::MAX
                    }
                });
            };
            let mut packed = self.packer();
            let mut expected = self.reference();
            mask(&mut packed);
            mask(&mut expected);
            assert_eq!(packed, expected);
        }
    }

    impl Arbitrary for PackProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<PackProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            (any::<bool>(), 1usize..9, 1usize..20, 1usize..20)
                .prop_flat_map(|(is_a, r, k, mn)| {
                    (
                        Just((is_a, r, k, mn)),
                        sub_range_strat(0..k),
                        sub_range_strat(0..mn),
                        1usize..5,
                    )
                })
                .prop_map(|((is_a, r, k, mn), k_range, mn_range, align_panel)| PackProblem {
                    k,
                    mn,
                    is_a,
                    r,
                    k_range,
                    mn_range,
                    align_panel,
                })
                .boxed()
        }
    }

    /// Strategy for non-empty sub-ranges of `range`: it trims `cropped` items in total,
    /// `left` of them from the start.
    fn sub_range_strat(range: Range<usize>) -> BoxedStrategy<Range<usize>> {
        (0..range.len())
            .prop_flat_map(|cropped| (Just(cropped), 0..=cropped))
            .prop_map(move |(cropped, left)| {
                let right = cropped - left;
                range.start + left..range.end - right
            })
            .boxed()
    }

    proptest::proptest! {
        #[test]
        fn prop(pb in any::<PackProblem>()) {
            pb.check();
        }

        #[test]
        fn subrange_prop(_range in sub_range_strat(0..20)) {
        }

    }

    /// Builds a `PackProblem` from its fields (in declaration order) and checks it.
    fn run_case(
        k: usize,
        mn: usize,
        is_a: bool,
        r: usize,
        k_range: Range<usize>,
        mn_range: Range<usize>,
        align_panel: usize,
    ) {
        PackProblem { k, mn, is_a, r, k_range, mn_range, align_panel }.check()
    }

    #[test]
    fn simple_b_1() {
        run_case(2, 1, false, 1, 0..2, 0..1, 1);
    }

    #[test]
    fn simple_b_2() {
        run_case(2, 2, false, 1, 0..2, 0..2, 1);
    }

    #[test]
    fn simple_b_3() {
        run_case(2, 1, false, 4, 0..2, 0..1, 1);
    }

    #[test]
    fn simple_b_4() {
        run_case(1, 3, false, 2, 0..1, 0..3, 1);
    }

    #[test]
    fn simple_a_1() {
        run_case(2, 2, true, 1, 0..2, 0..2, 1);
    }

    #[test]
    fn simple_a_2() {
        run_case(2, 3, true, 2, 0..2, 0..3, 1);
    }

    #[test]
    fn range_k_0() {
        run_case(2, 1, false, 1, 1..2, 0..1, 1);
    }

    #[test]
    fn range_k_1() {
        run_case(2, 2, false, 1, 0..2, 0..1, 1);
    }

    #[test]
    fn range_k_2() {
        run_case(2, 1, false, 6, 1..2, 0..1, 1);
    }

    #[test]
    fn range_mn_0() {
        run_case(1, 2, false, 2, 0..1, 0..1, 1);
    }

    #[test]
    fn range_b_4() {
        run_case(1, 2, false, 6, 0..1, 1..2, 1);
    }

    #[test]
    fn range_b_5() {
        run_case(1, 7, false, 6, 0..1, 1..7, 1);
    }

    #[test]
    fn align_a_1() {
        run_case(2, 2, true, 1, 0..1, 0..2, 2);
    }

    #[test]
    fn align_b_1() {
        run_case(1, 1, false, 1, 0..1, 0..1, 2);
    }

    #[test]
    fn align_b_2() {
        run_case(3, 1, false, 1, 0..3, 0..1, 2);
    }

    #[test]
    fn align_b_3() {
        run_case(1, 1, false, 3, 0..1, 0..1, 2);
    }

    #[test]
    fn align_b_4() {
        run_case(2, 1, false, 1, 0..1, 0..1, 2);
    }

    #[test]
    fn align_b_5() {
        run_case(1, 5, false, 4, 0..1, 0..5, 3);
    }
}