Skip to main content

tract_linalg/frame/
pack.rs

1use std::alloc::Layout;
2use std::fmt::{Debug, Display};
3use std::marker::PhantomData;
4use std::ops::Range;
5use tract_data::internal::*;
6
7use crate::mmm::{
8    EagerPackedInput, MMMInputFormat, MMMInputValue, PackedExoticFact, PackedMatrixStorage,
9};
10
11use crate::WeightType;
12
13#[derive(Clone, Eq, PartialEq, Hash)]
14pub struct PackedFormat {
15    pub dt: DatumType,
16    pub r: usize,
17    pub alignment_bytes: usize,
18    pub end_padding_record: usize,
19}
20
21impl MMMInputFormat for PackedFormat {
22    fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor> {
23        let packed = PackedFormat::pack_tensor(self, t, k_axis, mn_axis)?;
24        Ok(PackedMatrixStorage::new(packed).into_tensor(t.datum_type()))
25    }
26
27    fn prepare_one(
28        &self,
29        t: &Tensor,
30        k_axis: usize,
31        mn_axis: usize,
32    ) -> TractResult<Box<dyn MMMInputValue>> {
33        PackedFormat::pack_tensor(self, t, k_axis, mn_axis)
34    }
35
36    fn precursor(&self) -> WeightType {
37        WeightType::Plain(self.dt)
38    }
39
40    fn r(&self) -> usize {
41        self.r
42    }
43
44    fn k_alignment(&self) -> usize {
45        1
46    }
47
48    #[allow(clippy::collapsible_if)]
49    fn merge_with<'o, 'a: 'o, 'b: 'o>(
50        &'a self,
51        other: &'b dyn MMMInputFormat,
52    ) -> Option<&'o dyn MMMInputFormat> {
53        if let Some(other) = other.downcast_ref::<PackedFormat>() {
54            if self.r == other.r && self.dt == other.dt {
55                if self.alignment_bytes % other.alignment_bytes == 0
56                    && self.end_padding_record >= other.end_padding_record
57                {
58                    return Some(self);
59                }
60                if other.alignment_bytes % self.alignment_bytes == 0
61                    && other.end_padding_record >= self.end_padding_record
62                {
63                    return Some(other);
64                }
65            }
66        }
67        None
68    }
69
70    fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
71        self.len(k, mn) * self.dt.size_of()
72    }
73
74    fn extract_at_mn_f16(
75        &self,
76        data: &EagerPackedInput,
77        mn: usize,
78        slice: &mut [f16],
79    ) -> TractResult<()> {
80        ensure!(data.format().dyn_eq(self));
81        ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
82        unsafe {
83            let ptr = data.packed.as_ptr().add(
84                (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
85            );
86            for (i, slot) in slice.iter_mut().enumerate() {
87                let ptr = ptr.add(i * self.dt.size_of() * self.r);
88                *slot = if self.dt == f16::datum_type() {
89                    *(ptr as *const f16)
90                } else if self.dt == f32::datum_type() {
91                    f16::from_f32(*(ptr as *const f32))
92                } else {
93                    bail!("Unexpected DT {:?}", self.dt)
94                }
95            }
96        }
97        Ok(())
98    }
99
100    fn extract_at_mn_f32(
101        &self,
102        data: &EagerPackedInput,
103        mn: usize,
104        slice: &mut [f32],
105    ) -> TractResult<()> {
106        ensure!(data.format().dyn_eq(self));
107        ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
108        unsafe {
109            let ptr = data.packed.as_ptr().add(
110                (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
111            );
112            for (i, slot) in slice.iter_mut().enumerate() {
113                let ptr = ptr.add(i * self.dt.size_of() * self.r);
114                *slot = if self.dt == f16::datum_type() {
115                    (*(ptr as *const f16)).to_f32()
116                } else if self.dt == f32::datum_type() {
117                    *(ptr as *const f32)
118                } else {
119                    bail!("Unexpected DT {:?}", self.dt)
120                }
121            }
122        }
123        Ok(())
124    }
125}
126
127impl Display for PackedFormat {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        write!(f, "Packed{:?}[{}]", self.dt, self.r)
130    }
131}
132
133impl Debug for PackedFormat {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        write!(
136            f,
137            "Packed{:?}[{}]@{}+{}",
138            self.dt, self.r, self.alignment_bytes, self.end_padding_record
139        )
140    }
141}
142
143impl PackedFormat {
144    pub const fn new(dt: DatumType, nr: usize, alignment_bytes: usize) -> PackedFormat {
145        PackedFormat { dt, r: nr, alignment_bytes, end_padding_record: 1 }
146    }
147
148    pub const fn with_end_padding_record(self, end_padding_record: usize) -> Self {
149        PackedFormat { end_padding_record, ..self }
150    }
151
152    #[inline]
153    pub fn align(self, alignment: usize) -> Self {
154        Self { alignment_bytes: alignment, ..self }
155    }
156
157    #[inline]
158    pub fn alignment(&self) -> usize {
159        self.alignment_bytes
160    }
161
162    #[inline]
163    pub fn panel_width(&self) -> usize {
164        self.r
165    }
166
167    #[inline]
168    pub fn len<D: DimLike>(&self, k: D, n: D) -> D {
169        n.divceil(self.r) * self.single_panel_len(k)
170    }
171
172    #[inline]
173    pub fn single_panel_len<D: DimLike>(&self, k: D) -> D {
174        ((k + self.end_padding_record) * self.r).divceil(self.alignment()) * self.alignment()
175    }
176
177    #[inline]
178    pub fn single_panel_layout(&self, k: usize, item_size: usize) -> Layout {
179        Layout::from_size_align(self.single_panel_len(k) * item_size, self.alignment()).unwrap()
180    }
181
182    pub fn pack_tensor(
183        &self,
184        t: &Tensor,
185        k_axis: usize,
186        mn_axis: usize,
187    ) -> TractResult<Box<dyn MMMInputValue>> {
188        ensure!(t.datum_type().is_copy());
189        ensure!(
190            t.datum_type().unquantized() == self.dt.unquantized(),
191            "Attempting to pack for {self} tensor {t:?}"
192        );
193        let k = t.shape()[k_axis];
194        let mn = t.shape()[mn_axis];
195        let packed_len = self.len(k, mn);
196        let panel_len = self.single_panel_len(k);
197        let panel_bytes = panel_len * t.datum_type().size_of();
198        let strides = t.strides();
199        unsafe {
200            let mut packed = Blob::new_for_size_and_align(
201                t.datum_type().size_of() * packed_len,
202                self.alignment_bytes,
203            );
204            if cfg!(debug_assertions) {
205                packed.as_bytes_mut().fill(0u8);
206            }
207            dispatch_copy!(Self::pack_t(t.datum_type())(
208                self,
209                packed.as_mut_ptr() as _,
210                t.as_ptr_unchecked(),
211                mn,
212                strides[k_axis],
213                strides[mn_axis],
214                0..k,
215                0..mn
216            ));
217            Ok(Box::new(EagerPackedInput {
218                fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
219                packed: packed.into(),
220                panel_bytes,
221                mn,
222            }))
223        }
224    }
225
226    pub fn pack_tensor_view(
227        &self,
228        t: &TensorView,
229        k_axis: usize,
230        mn_axis: usize,
231    ) -> TractResult<Box<dyn MMMInputValue>> {
232        ensure!(
233            t.datum_type().unquantized() == self.dt.unquantized(),
234            "Attempting to pack for {self} tensor view {t:?}"
235        );
236        let k = t.shape()[k_axis];
237        let mn = t.shape()[mn_axis];
238        let packed_len = self.len(k, mn);
239        let panel_len = self.single_panel_len(k);
240        let panel_bytes = panel_len * t.datum_type().size_of();
241        let strides = t.strides();
242        unsafe {
243            let mut packed = Blob::new_for_size_and_align(
244                t.datum_type().size_of() * packed_len,
245                self.alignment_bytes,
246            );
247            if cfg!(debug_assertions) {
248                packed.as_bytes_mut().fill(0u8);
249            }
250            dispatch_copy!(Self::pack_t(t.datum_type())(
251                self,
252                packed.as_mut_ptr() as _,
253                t.as_ptr_unchecked(),
254                mn,
255                strides[k_axis],
256                strides[mn_axis],
257                0..k,
258                0..mn
259            ));
260            Ok(Box::new(EagerPackedInput {
261                fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
262                packed: packed.into(),
263                panel_bytes,
264                mn,
265            }))
266        }
267    }
268
269    pub unsafe fn pack<'a, 'b>(
270        &self,
271        pb: impl std::borrow::BorrowMut<TensorView<'a>>,
272        b: impl std::borrow::Borrow<TensorView<'b>>,
273        k_axis: usize,
274        mn_axis: usize,
275    ) {
276        let k = b.borrow().shape()[k_axis];
277        let mn = b.borrow().shape()[mn_axis];
278        unsafe { self.pack_segment(pb, b, k_axis, mn_axis, 0..k, 0..mn) };
279    }
280
281
282    #[allow(clippy::too_many_arguments)]
283    #[rustfmt::skip]
284    pub unsafe fn pack_t<T: Datum + Copy>(
285        &self,
286        pb: *mut T,
287        b: *const T,
288        mn: usize,
289        k_stride: isize,
290        mn_stride: isize,
291        k_range: Range<usize>,
292        mn_range: Range<usize>,
293        ) { unsafe {
294        if k_range.len() == 0 || mn_range.len() == 0 {
295            return
296        }
297        if self.r == 1 && k_stride == 1 && mn == 1 {
298            pb.copy_from_nonoverlapping(b.add(k_range.start), k_range.len())
299        } else if mn_stride == 1 {
300            let size_of = T::datum_type().size_of();
301            let rbytes = self.r * size_of;
302            let mn_valid_end = mn_range.end.min(mn);
303            let mn_range_bytes = mn_range.start * size_of..mn_valid_end * size_of;
304            let k_stride_bytes = k_stride * size_of as isize;
305            let bb = b as *const u8;
306            let pbb = pb as *mut u8;
307            let panel_len = self.single_panel_len(k_range.len()) * size_of;
308            match rbytes {
309                16 => pack_mn_major::<[u8; 16]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
310                24 => pack_mn_major::<[u8; 24]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
311                32 => pack_mn_major::<[u8; 32]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
312                48 => pack_mn_major::<[u8; 48]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
313                64 => pack_mn_major::<[u8; 64]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
314                _ => {
315                    let mut packer = self.write_with_k_outer(pb, k_range.len(), mn_range.len());
316                    for k in k_range {
317                        for x in mn_range.start..mn_valid_end {
318                            packer.write(*b.offset(x as isize + k_stride * k as isize))
319                        }
320                        for _x in mn_valid_end..mn_range.end {
321                            packer.write(T::default())
322                        }
323                    }
324                }
325            }
326        } else if k_stride == 1 {
327            let mut packer = self.write_with_k_inner(pb, k_range.len(), mn);
328            let mn_valid_end = mn_range.end.min(mn);
329            for x in mn_range.start..mn_valid_end {
330                for k in k_range.clone() {
331                    packer.write(*b.offset(x as isize * mn_stride + k as isize))
332                }
333            }
334            // just ignore invalid mn_range
335        } else {
336            let mut packer = self.write_with_k_outer(pb, k_range.len(), mn);
337            let mn_valid_end = mn_range.end.min(mn);
338            for k in k_range {
339                for x in mn_range.start..mn_valid_end {
340                    packer.write(*b.offset(x as isize * mn_stride + k_stride * k as isize))
341                }
342                for _x in mn_valid_end..mn_range.end {
343                    packer.write(T::default())
344                }
345            }
346        }
347    }}
348
349    #[inline]
350    pub unsafe fn pack_segment<'a, 'b>(
351        &self,
352        mut pb: impl std::borrow::BorrowMut<TensorView<'a>>,
353        b: impl std::borrow::Borrow<TensorView<'b>>,
354        k_axis: usize,
355        mn_axis: usize,
356        k_range: Range<usize>,
357        mn_range: Range<usize>,
358    ) {
359        debug_assert!(pb.borrow().len() >= self.len(k_range.len(), mn_range.len()));
360        let pb = pb.borrow_mut();
361        let b = b.borrow();
362        let dt = pb.datum_type();
363        unsafe {
364            dispatch_copy!(Self::pack_t(dt)(
365                self,
366                pb.as_ptr_mut_unchecked(),
367                b.as_ptr_unchecked(),
368                b.shape()[mn_axis],
369                b.strides()[k_axis],
370                b.strides()[mn_axis],
371                k_range,
372                mn_range
373            ));
374        }
375    }
376
377    pub fn write_with_k_outer<'p, T: Copy + Debug>(
378        &self,
379        pb: *mut T,
380        k: usize,
381        mn: usize,
382    ) -> KOutWriter<'p, T> {
383        KOutWriter::new(pb, self.r, self.single_panel_len(k), mn, k)
384    }
385
386    pub fn write_single_panel_with_k_outer<'p, T: Copy + Debug>(
387        &self,
388        pb: *mut T,
389    ) -> KOutSinglePanelWriter<'p, T> {
390        KOutSinglePanelWriter::new(pb)
391    }
392
393    pub fn write_with_k_inner<'p, T: Copy + Debug>(
394        &self,
395        pb: *mut T,
396        k: usize,
397        mn: usize,
398    ) -> KInWriter<'p, T> {
399        let panel_len = self.single_panel_len(k);
400        KInWriter::new(pb, panel_len, self.r, mn, k)
401    }
402}
403
/// Sink for the items produced by a packing loop, in the order the loop emits them.
pub trait PackingWriter<T: Copy> {
    /// Stores one item, then advances to the next destination slot.
    fn write(&mut self, t: T);
}
407
/// Writer for a single packed panel: items are stored contiguously, in write order,
/// starting at `ptr`.
///
/// NOTE(review): capacity is not tracked; whoever creates it must make sure the
/// destination has room for every item written.
#[derive(Debug)]
pub struct KOutSinglePanelWriter<'p, T: Copy + Debug> {
    /// Destination of the next write.
    ptr: *mut T,
    _phantom: PhantomData<&'p T>,
}
416
417impl<'p, T> KOutSinglePanelWriter<'p, T>
418where
419    T: Copy + std::fmt::Debug,
420{
421    pub fn new(ptr: *mut T) -> KOutSinglePanelWriter<'p, T> {
422        KOutSinglePanelWriter { ptr, _phantom: PhantomData }
423    }
424}
425
426impl<T> PackingWriter<T> for KOutSinglePanelWriter<'_, T>
427where
428    T: Copy + std::fmt::Debug,
429{
430    #[inline(always)]
431    fn write(&mut self, t: T) {
432        unsafe {
433            *self.ptr = t;
434            self.ptr = self.ptr.offset(1);
435        }
436    }
437}
438
/// Packing writer fed in "k outer, mn inner" order: items arrive one k row at a time,
/// walking the mn axis, and are scattered across panels of `panel_width` lanes.
#[derive(Debug)]
pub struct KOutWriter<'p, T: Copy + Debug> {
    /// Destination of the next write.
    ptr: *mut T,
    /// Number of panels the mn axis is split into.
    panels: usize,
    /// Lanes per panel.
    panel_width: usize,
    /// Lanes in the last (possibly partial) panel.
    last_panel_width: usize,
    /// Writes left before leaving the current panel's row.
    remain: usize,
    /// Index of the panel being written.
    current_panel: usize,
    /// Pointer jump (in items) from the end of a panel row to the same row of the next panel.
    next_panel: isize,
    /// Pointer jump (in items) from the end of the last panel's row to the next row of panel 0.
    next_lane: isize,
    _phantom: PhantomData<&'p T>,
}
454
455impl<'p, T> KOutWriter<'p, T>
456where
457    T: Copy + std::fmt::Debug,
458{
459    pub fn new(
460        ptr: *mut T,
461        panel_width: usize,
462        panel_len: usize,
463        mn: usize,
464        _k: usize,
465    ) -> KOutWriter<'p, T> {
466        let panels = mn.divceil(panel_width);
467        let last_panel_width = mn - (panels - 1) * panel_width;
468        KOutWriter {
469            ptr,
470            panels,
471            panel_width,
472            last_panel_width,
473            remain: if panels > 1 { panel_width } else { last_panel_width },
474            current_panel: 0,
475            next_panel: (panel_len - panel_width) as isize,
476            next_lane: (panel_width - last_panel_width) as isize
477                - (panel_len * (panels - 1)) as isize,
478            _phantom: PhantomData,
479        }
480    }
481}
482
483impl<T> PackingWriter<T> for KOutWriter<'_, T>
484where
485    T: Copy + std::fmt::Debug,
486{
487    #[inline(always)]
488    fn write(&mut self, t: T) {
489        unsafe {
490            *self.ptr = t;
491            self.remain -= 1;
492            self.ptr = self.ptr.offset(1);
493            if self.remain == 0 {
494                self.current_panel += 1;
495                if self.current_panel == self.panels {
496                    self.ptr = self.ptr.offset(self.next_lane);
497                    self.current_panel = 0;
498                } else {
499                    self.ptr = self.ptr.offset(self.next_panel);
500                }
501                if self.current_panel == self.panels - 1 {
502                    self.remain = self.last_panel_width;
503                } else {
504                    self.remain = self.panel_width;
505                }
506            }
507        }
508    }
509}
510
/// Packing writer fed in "k inner" order: all k items of one mn lane, then the next lane,
/// scattered into panels of `panel_width` lanes.
#[derive(Debug)]
pub struct KInWriter<'p, T: Copy + Debug> {
    /// Destination of the next write.
    ptr: *mut T,
    /// Items per lane.
    k: usize,
    /// Number of panels the mn axis is split into.
    panels: usize,
    /// Lanes per panel.
    panel_width: usize,
    /// Lanes in the last (possibly partial) panel.
    last_panel_width: usize,
    /// Writes left in the current lane.
    remain_on_k: usize,
    /// Lanes left in the current panel, counting the one being written.
    remain_on_mn: usize,
    /// Index of the panel being written.
    current_panel: usize,
    /// Pointer jump (in items) from the end of a lane to the top of the next lane.
    next_mn_offset: isize,
    /// Pointer jump (in items) from the end of a panel's last lane to the next panel.
    next_panel_offset: isize,
    _phantom: PhantomData<&'p T>,
}
528
529impl<'p, T> KInWriter<'p, T>
530where
531    T: Copy + Debug,
532{
533    pub fn new(
534        ptr: *mut T,
535        panel_len: usize,
536        panel_width: usize,
537        mn: usize,
538        k: usize,
539    ) -> KInWriter<'p, T> {
540        let panels = mn.divceil(panel_width);
541        let last_panel_width = mn - (panels - 1) * panel_width;
542        KInWriter {
543            ptr,
544            k,
545            panels,
546            panel_width,
547            last_panel_width,
548            remain_on_k: k,
549            remain_on_mn: if panels == 1 { last_panel_width } else { panel_width },
550            current_panel: 0,
551            next_mn_offset: 1 - (k * panel_width) as isize,
552            next_panel_offset: panel_len as isize - (k * panel_width + panel_width - 1) as isize,
553            //                 ^ next panel     ^    ^ rewind left ^   ^ rewind up   ^
554            _phantom: PhantomData,
555        }
556    }
557}
558
559impl<T> PackingWriter<T> for KInWriter<'_, T>
560where
561    T: Copy + std::fmt::Debug,
562{
563    #[inline(always)]
564    fn write(&mut self, t: T) {
565        unsafe {
566            *self.ptr = t;
567            self.remain_on_k -= 1;
568            self.ptr = self.ptr.add(self.panel_width);
569            if self.remain_on_k == 0 {
570                self.remain_on_k = self.k;
571                self.remain_on_mn -= 1;
572                if self.remain_on_mn > 0 {
573                    self.ptr = self.ptr.offset(self.next_mn_offset);
574                } else {
575                    self.ptr = self.ptr.offset(self.next_panel_offset);
576                    self.current_panel += 1;
577                    if self.current_panel == self.panels - 1 {
578                        self.remain_on_mn = self.last_panel_width;
579                    } else {
580                        self.remain_on_mn = self.panel_width;
581                    }
582                }
583            }
584        }
585    }
586}
587
/// Copies an mn-contiguous byte source into panels of `size_of::<Chunk>()` bytes per record.
///
/// For each `k` in `k_range` (at position `kk` from its start), the bytes `mn_range_bytes`
/// of source row `k` (at `b + k * k_stride_bytes`) are cut into chunks of
/// `size_of::<Chunk>()` bytes. Chunk `p` is copied to
/// `packed + p * panel_len + kk * size_of::<Chunk>()`. A trailing partial chunk is copied
/// as is, and the rest of its slot is left untouched. `Chunk` only supplies its size, which
/// makes the copy length a compile-time constant (presumably for a specialised copy).
///
/// # Safety
/// The source bytes read must be valid. `packed` must hold
/// `mn_range_bytes.len().div_ceil(size_of::<Chunk>())` panels of `panel_len` bytes, with
/// `panel_len >= k_range.len() * size_of::<Chunk>()`.
#[inline(never)]
unsafe fn pack_mn_major<Chunk: Copy>(
    b: *const u8,
    packed: *mut u8,
    panel_len: usize,
    k_stride_bytes: isize,
    mn_range_bytes: Range<usize>,
    k_range: Range<usize>,
) {
    let chunk = std::mem::size_of::<Chunk>();
    let full_chunks = mn_range_bytes.len() / chunk;
    let tail = mn_range_bytes.len() % chunk;
    if full_chunks == 0 && tail == 0 {
        return;
    }
    for (kk, k) in k_range.enumerate() {
        // Addresses are derived from the row bases for each chunk actually copied. The
        // previous version bumped pointers past the last chunk, which could step beyond
        // the end of `packed`: undefined behaviour for `add`, even if never dereferenced.
        unsafe {
            let src_row = b.offset(k as isize * k_stride_bytes + mn_range_bytes.start as isize);
            let dst_row = packed.add(kk * chunk);
            for p in 0..full_chunks {
                dst_row.add(p * panel_len).copy_from_nonoverlapping(src_row.add(p * chunk), chunk);
            }
            if tail > 0 {
                dst_row
                    .add(full_chunks * panel_len)
                    .copy_from_nonoverlapping(src_row.add(full_chunks * chunk), tail);
            }
        }
    }
}
617
618pub trait Packing {
619    fn packing(r: usize) -> PackedFormat;
620}
621
622impl<D: Datum> Packing for D {
623    fn packing(r: usize) -> PackedFormat {
624        PackedFormat::new(Self::datum_type(), r, vector_size())
625    }
626}
627
#[cfg(test)]
mod test {
    use std::ops::Range;

    use proptest::prelude::*;
    use tract_data::internal::num_integer::Integer;
    use tract_data::internal::tract_ndarray::Zip;
    use tract_data::internal::*;
    use tract_ndarray::prelude::*;

    /// One packing scenario. A `k x mn` matrix of distinct `u32`s is stored as `(mn, k)`
    /// when `is_a` and as `(k, mn)` otherwise. Its `k_range x mn_range` sub-block is packed
    /// into panels of width `r`, each panel padded to a multiple of `align_panel` items,
    /// with no end padding record.
    #[derive(Debug)]
    struct PackProblem {
        k: usize,
        mn: usize,
        is_a: bool,
        r: usize,
        k_range: Range<usize>,
        mn_range: Range<usize>,
        align_panel: usize,
    }

    /// Positional constructor used by the table-style tests below.
    /// Argument order: k, mn, is_a, r, k_range, mn_range, align_panel.
    fn problem(
        k: usize,
        mn: usize,
        is_a: bool,
        r: usize,
        k_range: Range<usize>,
        mn_range: Range<usize>,
        align_panel: usize,
    ) -> PackProblem {
        PackProblem { k, mn, is_a, r, k_range, mn_range, align_panel }
    }

    impl PackProblem {
        /// The source matrix, filled with `0, 1, 2, ...` in row-major order.
        fn input(&self) -> Array2<u32> {
            let shape = if self.is_a { (self.mn, self.k) } else { (self.k, self.mn) };
            let values: Vec<u32> = (0..(self.k * self.mn) as u32).collect();
            Array2::from_shape_vec(shape, values).unwrap()
        }

        /// Number of panels covering `mn_range`.
        fn panels(&self) -> usize {
            self.mn_range.len().divceil(self.r)
        }

        /// Expected items per panel.
        fn panel_len(&self) -> usize {
            Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel)
        }

        /// Source coordinates of item `z` in panel `panel`, and whether they are in range.
        fn locate(&self, panel: usize, z: usize) -> ((usize, usize), bool) {
            let k = self.k_range.start + z / self.r;
            let mn = self.mn_range.start + panel * self.r + z % self.r;
            let in_range = k < self.k_range.end.min(self.k) && mn < self.mn_range.end.min(self.mn);
            let coords = if self.is_a { (mn, k) } else { (k, mn) };
            (coords, in_range)
        }

        /// Runs the real packer and reshapes its output to `(panels, panel_len)`.
        fn packer(&self) -> Array2<u32> {
            let format = super::PackedFormat::new(u32::datum_type(), self.r, self.align_panel)
                .with_end_padding_record(0);
            let source = self.input().into_tensor();
            let item_count = format.len(self.k_range.len(), self.mn_range.len());
            let mut packed = Tensor::zero::<u32>(&[item_count]).unwrap();
            unsafe {
                format.pack_segment(
                    packed.view_mut(),
                    source.view(),
                    self.is_a as usize,
                    !self.is_a as usize,
                    self.k_range.clone(),
                    self.mn_range.clone(),
                )
            };
            let panel_len = format.single_panel_len(self.k_range.len());
            packed
                .into_plain_array::<u32>()
                .unwrap()
                .into_shape_with_order((self.panels(), panel_len))
                .unwrap()
        }

        /// Expected packed data, computed item by item from the source matrix.
        fn reference(&self) -> Array2<u32> {
            let source = self.input();
            Array2::from_shape_fn([self.panels(), self.panel_len()], |(panel, z)| {
                let (coords, _) = self.locate(panel, z);
                source.get(coords).copied().unwrap_or(0)
            })
        }

        /// Mask of packed slots that hold real (non-padding) data.
        fn valid(&self) -> Array2<bool> {
            Array2::from_shape_fn([self.panels(), self.panel_len()], |(panel, z)| {
                self.locate(panel, z).1
            })
        }

        /// Compares packer output and reference on the valid slots only.
        fn check(&self) {
            let valid = self.valid();
            let mask = |mut values: Array2<u32>| {
                Zip::from(&mut values).and(&valid).for_each(|value, &keep| {
                    if !keep {
                        *value = u32::MAX;
                    }
                });
                values
            };
            let packed = mask(self.packer());
            let expected = mask(self.reference());
            assert_eq!(packed, expected);
        }
    }

    impl Arbitrary for PackProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<PackProblem>;

        fn arbitrary_with(_args: ()) -> Self::Strategy {
            let dims = (any::<bool>(), 1usize..9, 1usize..20, 1usize..20);
            dims.prop_flat_map(|(is_a, r, k, mn)| {
                (
                    Just((is_a, r, k, mn)),
                    sub_range_strat(0..k),
                    sub_range_strat(0..mn),
                    1usize..5,
                )
            })
            .prop_map(|((is_a, r, k, mn), k_range, mn_range, align_panel)| {
                problem(k, mn, is_a, r, k_range, mn_range, align_panel)
            })
            .boxed()
        }
    }

    /// Non-empty sub-ranges of `range`: trims `trimmed` items in total, `from_left` of
    /// them from the start and the rest from the end.
    fn sub_range_strat(range: Range<usize>) -> BoxedStrategy<Range<usize>> {
        (0..range.len())
            .prop_flat_map(|trimmed| (Just(trimmed), 0..=trimmed))
            .prop_map(move |(trimmed, from_left)| {
                let start = range.start + from_left;
                let end = range.end - (trimmed - from_left);
                start..end
            })
            .boxed()
    }

    proptest::proptest! {
        #[test]
        fn prop(pb in any::<PackProblem>()) {
            pb.check();
        }

        #[test]
        fn subrange_prop(_range in sub_range_strat(0..20)) {
        }
    }

    #[test]
    fn simple_b_1() {
        problem(2, 1, false, 1, 0..2, 0..1, 1).check();
    }

    #[test]
    fn simple_b_2() {
        problem(2, 2, false, 1, 0..2, 0..2, 1).check();
    }

    #[test]
    fn simple_b_3() {
        problem(2, 1, false, 4, 0..2, 0..1, 1).check();
    }

    #[test]
    fn simple_b_4() {
        problem(1, 3, false, 2, 0..1, 0..3, 1).check();
    }

    #[test]
    fn simple_a_1() {
        problem(2, 2, true, 1, 0..2, 0..2, 1).check();
    }

    #[test]
    fn simple_a_2() {
        problem(2, 3, true, 2, 0..2, 0..3, 1).check();
    }

    #[test]
    fn range_k_0() {
        problem(2, 1, false, 1, 1..2, 0..1, 1).check();
    }

    #[test]
    fn range_k_1() {
        problem(2, 2, false, 1, 0..2, 0..1, 1).check();
    }

    #[test]
    fn range_k_2() {
        problem(2, 1, false, 6, 1..2, 0..1, 1).check();
    }

    #[test]
    fn range_mn_0() {
        problem(1, 2, false, 2, 0..1, 0..1, 1).check();
    }

    #[test]
    fn range_b_4() {
        problem(1, 2, false, 6, 0..1, 1..2, 1).check();
    }

    #[test]
    fn range_b_5() {
        problem(1, 7, false, 6, 0..1, 1..7, 1).check();
    }

    #[test]
    fn align_a_1() {
        problem(2, 2, true, 1, 0..1, 0..2, 2).check();
    }

    #[test]
    fn align_b_1() {
        problem(1, 1, false, 1, 0..1, 0..1, 2).check();
    }

    #[test]
    fn align_b_2() {
        problem(3, 1, false, 1, 0..3, 0..1, 2).check();
    }

    #[test]
    fn align_b_3() {
        problem(1, 1, false, 3, 0..1, 0..1, 2).check();
    }

    #[test]
    fn align_b_4() {
        problem(2, 1, false, 1, 0..1, 0..1, 2).check();
    }

    #[test]
    fn align_b_5() {
        problem(1, 5, false, 4, 0..1, 0..5, 3).check();
    }
}