1use std::alloc::Layout;
2use std::fmt::{Debug, Display};
3use std::marker::PhantomData;
4use std::ops::Range;
5use tract_data::internal::*;
6
7use crate::mmm::{
8 EagerPackedInput, MMMInputFormat, MMMInputValue, PackedExoticFact, PackedMatrixStorage,
9};
10
11use crate::WeightType;
12
13#[derive(Clone, Eq, PartialEq, Hash)]
14pub struct PackedFormat {
15 pub dt: DatumType,
16 pub r: usize,
17 pub alignment_bytes: usize,
18 pub end_padding_record: usize,
19}
20
21impl MMMInputFormat for PackedFormat {
22 fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor> {
23 let packed = PackedFormat::pack_tensor(self, t, k_axis, mn_axis)?;
24 Ok(PackedMatrixStorage::new(packed).into_tensor(t.datum_type()))
25 }
26
27 fn prepare_one(
28 &self,
29 t: &Tensor,
30 k_axis: usize,
31 mn_axis: usize,
32 ) -> TractResult<Box<dyn MMMInputValue>> {
33 PackedFormat::pack_tensor(self, t, k_axis, mn_axis)
34 }
35
36 fn precursor(&self) -> WeightType {
37 WeightType::Plain(self.dt)
38 }
39
40 fn r(&self) -> usize {
41 self.r
42 }
43
44 fn k_alignment(&self) -> usize {
45 1
46 }
47
48 #[allow(clippy::collapsible_if)]
49 fn merge_with<'o, 'a: 'o, 'b: 'o>(
50 &'a self,
51 other: &'b dyn MMMInputFormat,
52 ) -> Option<&'o dyn MMMInputFormat> {
53 if let Some(other) = other.downcast_ref::<PackedFormat>() {
54 if self.r == other.r && self.dt == other.dt {
55 if self.alignment_bytes % other.alignment_bytes == 0
56 && self.end_padding_record >= other.end_padding_record
57 {
58 return Some(self);
59 }
60 if other.alignment_bytes % self.alignment_bytes == 0
61 && other.end_padding_record >= self.end_padding_record
62 {
63 return Some(other);
64 }
65 }
66 }
67 None
68 }
69
70 fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
71 self.len(k, mn) * self.dt.size_of()
72 }
73
74 fn extract_at_mn_f16(
75 &self,
76 data: &EagerPackedInput,
77 mn: usize,
78 slice: &mut [f16],
79 ) -> TractResult<()> {
80 ensure!(data.format().dyn_eq(self));
81 ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
82 unsafe {
83 let ptr = data.packed.as_ptr().add(
84 (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
85 );
86 for (i, slot) in slice.iter_mut().enumerate() {
87 let ptr = ptr.add(i * self.dt.size_of() * self.r);
88 *slot = if self.dt == f16::datum_type() {
89 *(ptr as *const f16)
90 } else if self.dt == f32::datum_type() {
91 f16::from_f32(*(ptr as *const f32))
92 } else {
93 bail!("Unexpected DT {:?}", self.dt)
94 }
95 }
96 }
97 Ok(())
98 }
99
100 fn extract_at_mn_f32(
101 &self,
102 data: &EagerPackedInput,
103 mn: usize,
104 slice: &mut [f32],
105 ) -> TractResult<()> {
106 ensure!(data.format().dyn_eq(self));
107 ensure!(self.len(data.k(), data.mn()) * self.dt.size_of() == data.packed.len());
108 unsafe {
109 let ptr = data.packed.as_ptr().add(
110 (self.single_panel_len(data.k()) * (mn / self.r) + mn % self.r) * self.dt.size_of(),
111 );
112 for (i, slot) in slice.iter_mut().enumerate() {
113 let ptr = ptr.add(i * self.dt.size_of() * self.r);
114 *slot = if self.dt == f16::datum_type() {
115 (*(ptr as *const f16)).to_f32()
116 } else if self.dt == f32::datum_type() {
117 *(ptr as *const f32)
118 } else {
119 bail!("Unexpected DT {:?}", self.dt)
120 }
121 }
122 }
123 Ok(())
124 }
125}
126
127impl Display for PackedFormat {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 write!(f, "Packed{:?}[{}]", self.dt, self.r)
130 }
131}
132
133impl Debug for PackedFormat {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 write!(
136 f,
137 "Packed{:?}[{}]@{}+{}",
138 self.dt, self.r, self.alignment_bytes, self.end_padding_record
139 )
140 }
141}
142
143impl PackedFormat {
144 pub const fn new(dt: DatumType, nr: usize, alignment_bytes: usize) -> PackedFormat {
145 PackedFormat { dt, r: nr, alignment_bytes, end_padding_record: 1 }
146 }
147
148 pub const fn with_end_padding_record(self, end_padding_record: usize) -> Self {
149 PackedFormat { end_padding_record, ..self }
150 }
151
152 #[inline]
153 pub fn align(self, alignment: usize) -> Self {
154 Self { alignment_bytes: alignment, ..self }
155 }
156
157 #[inline]
158 pub fn alignment(&self) -> usize {
159 self.alignment_bytes
160 }
161
162 #[inline]
163 pub fn panel_width(&self) -> usize {
164 self.r
165 }
166
167 #[inline]
168 pub fn len<D: DimLike>(&self, k: D, n: D) -> D {
169 n.divceil(self.r) * self.single_panel_len(k)
170 }
171
172 #[inline]
173 pub fn single_panel_len<D: DimLike>(&self, k: D) -> D {
174 ((k + self.end_padding_record) * self.r).divceil(self.alignment()) * self.alignment()
175 }
176
177 #[inline]
178 pub fn single_panel_layout(&self, k: usize, item_size: usize) -> Layout {
179 Layout::from_size_align(self.single_panel_len(k) * item_size, self.alignment()).unwrap()
180 }
181
182 pub fn pack_tensor(
183 &self,
184 t: &Tensor,
185 k_axis: usize,
186 mn_axis: usize,
187 ) -> TractResult<Box<dyn MMMInputValue>> {
188 ensure!(t.datum_type().is_copy());
189 ensure!(
190 t.datum_type().unquantized() == self.dt.unquantized(),
191 "Attempting to pack for {self} tensor {t:?}"
192 );
193 let k = t.shape()[k_axis];
194 let mn = t.shape()[mn_axis];
195 let packed_len = self.len(k, mn);
196 let panel_len = self.single_panel_len(k);
197 let panel_bytes = panel_len * t.datum_type().size_of();
198 let strides = t.strides();
199 unsafe {
200 let mut packed = Blob::new_for_size_and_align(
201 t.datum_type().size_of() * packed_len,
202 self.alignment_bytes,
203 );
204 if cfg!(debug_assertions) {
205 packed.as_bytes_mut().fill(0u8);
206 }
207 dispatch_copy!(Self::pack_t(t.datum_type())(
208 self,
209 packed.as_mut_ptr() as _,
210 t.as_ptr_unchecked(),
211 mn,
212 strides[k_axis],
213 strides[mn_axis],
214 0..k,
215 0..mn
216 ));
217 Ok(Box::new(EagerPackedInput {
218 fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
219 packed: packed.into(),
220 panel_bytes,
221 mn,
222 }))
223 }
224 }
225
226 pub fn pack_tensor_view(
227 &self,
228 t: &TensorView,
229 k_axis: usize,
230 mn_axis: usize,
231 ) -> TractResult<Box<dyn MMMInputValue>> {
232 ensure!(
233 t.datum_type().unquantized() == self.dt.unquantized(),
234 "Attempting to pack for {self} tensor view {t:?}"
235 );
236 let k = t.shape()[k_axis];
237 let mn = t.shape()[mn_axis];
238 let packed_len = self.len(k, mn);
239 let panel_len = self.single_panel_len(k);
240 let panel_bytes = panel_len * t.datum_type().size_of();
241 let strides = t.strides();
242 unsafe {
243 let mut packed = Blob::new_for_size_and_align(
244 t.datum_type().size_of() * packed_len,
245 self.alignment_bytes,
246 );
247 if cfg!(debug_assertions) {
248 packed.as_bytes_mut().fill(0u8);
249 }
250 dispatch_copy!(Self::pack_t(t.datum_type())(
251 self,
252 packed.as_mut_ptr() as _,
253 t.as_ptr_unchecked(),
254 mn,
255 strides[k_axis],
256 strides[mn_axis],
257 0..k,
258 0..mn
259 ));
260 Ok(Box::new(EagerPackedInput {
261 fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k },
262 packed: packed.into(),
263 panel_bytes,
264 mn,
265 }))
266 }
267 }
268
269 pub unsafe fn pack<'a, 'b>(
270 &self,
271 pb: impl std::borrow::BorrowMut<TensorView<'a>>,
272 b: impl std::borrow::Borrow<TensorView<'b>>,
273 k_axis: usize,
274 mn_axis: usize,
275 ) {
276 let k = b.borrow().shape()[k_axis];
277 let mn = b.borrow().shape()[mn_axis];
278 unsafe { self.pack_segment(pb, b, k_axis, mn_axis, 0..k, 0..mn) };
279 }
280
281
282 #[allow(clippy::too_many_arguments)]
283 #[rustfmt::skip]
284 pub unsafe fn pack_t<T: Datum + Copy>(
285 &self,
286 pb: *mut T,
287 b: *const T,
288 mn: usize,
289 k_stride: isize,
290 mn_stride: isize,
291 k_range: Range<usize>,
292 mn_range: Range<usize>,
293 ) { unsafe {
294 if k_range.len() == 0 || mn_range.len() == 0 {
295 return
296 }
297 if self.r == 1 && k_stride == 1 && mn == 1 {
298 pb.copy_from_nonoverlapping(b.add(k_range.start), k_range.len())
299 } else if mn_stride == 1 {
300 let size_of = T::datum_type().size_of();
301 let rbytes = self.r * size_of;
302 let mn_valid_end = mn_range.end.min(mn);
303 let mn_range_bytes = mn_range.start * size_of..mn_valid_end * size_of;
304 let k_stride_bytes = k_stride * size_of as isize;
305 let bb = b as *const u8;
306 let pbb = pb as *mut u8;
307 let panel_len = self.single_panel_len(k_range.len()) * size_of;
308 match rbytes {
309 16 => pack_mn_major::<[u8; 16]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
310 24 => pack_mn_major::<[u8; 24]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
311 32 => pack_mn_major::<[u8; 32]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
312 48 => pack_mn_major::<[u8; 48]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
313 64 => pack_mn_major::<[u8; 64]>(bb, pbb, panel_len, k_stride_bytes, mn_range_bytes, k_range),
314 _ => {
315 let mut packer = self.write_with_k_outer(pb, k_range.len(), mn_range.len());
316 for k in k_range {
317 for x in mn_range.start..mn_valid_end {
318 packer.write(*b.offset(x as isize + k_stride * k as isize))
319 }
320 for _x in mn_valid_end..mn_range.end {
321 packer.write(T::default())
322 }
323 }
324 }
325 }
326 } else if k_stride == 1 {
327 let mut packer = self.write_with_k_inner(pb, k_range.len(), mn);
328 let mn_valid_end = mn_range.end.min(mn);
329 for x in mn_range.start..mn_valid_end {
330 for k in k_range.clone() {
331 packer.write(*b.offset(x as isize * mn_stride + k as isize))
332 }
333 }
334 } else {
336 let mut packer = self.write_with_k_outer(pb, k_range.len(), mn);
337 let mn_valid_end = mn_range.end.min(mn);
338 for k in k_range {
339 for x in mn_range.start..mn_valid_end {
340 packer.write(*b.offset(x as isize * mn_stride + k_stride * k as isize))
341 }
342 for _x in mn_valid_end..mn_range.end {
343 packer.write(T::default())
344 }
345 }
346 }
347 }}
348
349 #[inline]
350 pub unsafe fn pack_segment<'a, 'b>(
351 &self,
352 mut pb: impl std::borrow::BorrowMut<TensorView<'a>>,
353 b: impl std::borrow::Borrow<TensorView<'b>>,
354 k_axis: usize,
355 mn_axis: usize,
356 k_range: Range<usize>,
357 mn_range: Range<usize>,
358 ) {
359 debug_assert!(pb.borrow().len() >= self.len(k_range.len(), mn_range.len()));
360 let pb = pb.borrow_mut();
361 let b = b.borrow();
362 let dt = pb.datum_type();
363 unsafe {
364 dispatch_copy!(Self::pack_t(dt)(
365 self,
366 pb.as_ptr_mut_unchecked(),
367 b.as_ptr_unchecked(),
368 b.shape()[mn_axis],
369 b.strides()[k_axis],
370 b.strides()[mn_axis],
371 k_range,
372 mn_range
373 ));
374 }
375 }
376
377 pub fn write_with_k_outer<'p, T: Copy + Debug>(
378 &self,
379 pb: *mut T,
380 k: usize,
381 mn: usize,
382 ) -> KOutWriter<'p, T> {
383 KOutWriter::new(pb, self.r, self.single_panel_len(k), mn, k)
384 }
385
386 pub fn write_single_panel_with_k_outer<'p, T: Copy + Debug>(
387 &self,
388 pb: *mut T,
389 ) -> KOutSinglePanelWriter<'p, T> {
390 KOutSinglePanelWriter::new(pb)
391 }
392
393 pub fn write_with_k_inner<'p, T: Copy + Debug>(
394 &self,
395 pb: *mut T,
396 k: usize,
397 mn: usize,
398 ) -> KInWriter<'p, T> {
399 let panel_len = self.single_panel_len(k);
400 KInWriter::new(pb, panel_len, self.r, mn, k)
401 }
402}
403
/// Sink receiving matrix items one at a time, each stored at its next packed position.
pub trait PackingWriter<T: Copy> {
    /// Stores `item` and advances to the following packed slot.
    fn write(&mut self, item: T);
}
407
408#[derive(Debug)]
409pub struct KOutSinglePanelWriter<'p, T>
410where
411 T: Copy + std::fmt::Debug,
412{
413 ptr: *mut T,
414 _phantom: PhantomData<&'p T>,
415}
416
417impl<'p, T> KOutSinglePanelWriter<'p, T>
418where
419 T: Copy + std::fmt::Debug,
420{
421 pub fn new(ptr: *mut T) -> KOutSinglePanelWriter<'p, T> {
422 KOutSinglePanelWriter { ptr, _phantom: PhantomData }
423 }
424}
425
426impl<T> PackingWriter<T> for KOutSinglePanelWriter<'_, T>
427where
428 T: Copy + std::fmt::Debug,
429{
430 #[inline(always)]
431 fn write(&mut self, t: T) {
432 unsafe {
433 *self.ptr = t;
434 self.ptr = self.ptr.offset(1);
435 }
436 }
437}
438
439#[derive(Debug)]
440pub struct KOutWriter<'p, T>
441where
442 T: Copy + std::fmt::Debug,
443{
444 ptr: *mut T,
445 panels: usize,
446 panel_width: usize,
447 last_panel_width: usize,
448 remain: usize,
449 current_panel: usize,
450 next_panel: isize,
451 next_lane: isize,
452 _phantom: PhantomData<&'p T>,
453}
454
455impl<'p, T> KOutWriter<'p, T>
456where
457 T: Copy + std::fmt::Debug,
458{
459 pub fn new(
460 ptr: *mut T,
461 panel_width: usize,
462 panel_len: usize,
463 mn: usize,
464 _k: usize,
465 ) -> KOutWriter<'p, T> {
466 let panels = mn.divceil(panel_width);
467 let last_panel_width = mn - (panels - 1) * panel_width;
468 KOutWriter {
469 ptr,
470 panels,
471 panel_width,
472 last_panel_width,
473 remain: if panels > 1 { panel_width } else { last_panel_width },
474 current_panel: 0,
475 next_panel: (panel_len - panel_width) as isize,
476 next_lane: (panel_width - last_panel_width) as isize
477 - (panel_len * (panels - 1)) as isize,
478 _phantom: PhantomData,
479 }
480 }
481}
482
483impl<T> PackingWriter<T> for KOutWriter<'_, T>
484where
485 T: Copy + std::fmt::Debug,
486{
487 #[inline(always)]
488 fn write(&mut self, t: T) {
489 unsafe {
490 *self.ptr = t;
491 self.remain -= 1;
492 self.ptr = self.ptr.offset(1);
493 if self.remain == 0 {
494 self.current_panel += 1;
495 if self.current_panel == self.panels {
496 self.ptr = self.ptr.offset(self.next_lane);
497 self.current_panel = 0;
498 } else {
499 self.ptr = self.ptr.offset(self.next_panel);
500 }
501 if self.current_panel == self.panels - 1 {
502 self.remain = self.last_panel_width;
503 } else {
504 self.remain = self.panel_width;
505 }
506 }
507 }
508 }
509}
510
511#[derive(Debug)]
512pub struct KInWriter<'p, T>
513where
514 T: Copy + Debug,
515{
516 ptr: *mut T,
517 k: usize,
518 panels: usize,
519 panel_width: usize,
520 last_panel_width: usize,
521 remain_on_k: usize,
522 remain_on_mn: usize,
523 current_panel: usize,
524 next_mn_offset: isize,
525 next_panel_offset: isize,
526 _phantom: PhantomData<&'p T>,
527}
528
529impl<'p, T> KInWriter<'p, T>
530where
531 T: Copy + Debug,
532{
533 pub fn new(
534 ptr: *mut T,
535 panel_len: usize,
536 panel_width: usize,
537 mn: usize,
538 k: usize,
539 ) -> KInWriter<'p, T> {
540 let panels = mn.divceil(panel_width);
541 let last_panel_width = mn - (panels - 1) * panel_width;
542 KInWriter {
543 ptr,
544 k,
545 panels,
546 panel_width,
547 last_panel_width,
548 remain_on_k: k,
549 remain_on_mn: if panels == 1 { last_panel_width } else { panel_width },
550 current_panel: 0,
551 next_mn_offset: 1 - (k * panel_width) as isize,
552 next_panel_offset: panel_len as isize - (k * panel_width + panel_width - 1) as isize,
553 _phantom: PhantomData,
555 }
556 }
557}
558
559impl<T> PackingWriter<T> for KInWriter<'_, T>
560where
561 T: Copy + std::fmt::Debug,
562{
563 #[inline(always)]
564 fn write(&mut self, t: T) {
565 unsafe {
566 *self.ptr = t;
567 self.remain_on_k -= 1;
568 self.ptr = self.ptr.add(self.panel_width);
569 if self.remain_on_k == 0 {
570 self.remain_on_k = self.k;
571 self.remain_on_mn -= 1;
572 if self.remain_on_mn > 0 {
573 self.ptr = self.ptr.offset(self.next_mn_offset);
574 } else {
575 self.ptr = self.ptr.offset(self.next_panel_offset);
576 self.current_panel += 1;
577 if self.current_panel == self.panels - 1 {
578 self.remain_on_mn = self.last_panel_width;
579 } else {
580 self.remain_on_mn = self.panel_width;
581 }
582 }
583 }
584 }
585 }
586}
587
/// Packs an mn-contiguous window by copying whole panel rows as raw bytes.
///
/// The row size is `size_of::<Chunk>()` bytes. For each k in `k_range` (lane `k -
/// k_range.start` of the panels), the bytes `mn_range_bytes` of source row `k` are split into
/// panel-row chunks. Chunk `p` goes to panel `p`, with panels `panel_len` bytes apart. A
/// trailing partial chunk copies only its valid bytes.
///
/// # Safety
///
/// `b` must be readable and `packed` writable over every range touched above.
#[inline(never)]
unsafe fn pack_mn_major<Chunk: Copy>(
    b: *const u8,
    packed: *mut u8,
    panel_len: usize,
    k_stride_bytes: isize,
    mn_range_bytes: Range<usize>,
    k_range: Range<usize>,
) {
    let chunk_bytes = std::mem::size_of::<Chunk>();
    let row_bytes = mn_range_bytes.len();
    let (whole_panes, tail_bytes) = (row_bytes / chunk_bytes, row_bytes % chunk_bytes);
    for (lane, k) in k_range.enumerate() {
        // SAFETY: guaranteed by the caller (see `# Safety`).
        unsafe {
            let src_row = b.offset(k as isize * k_stride_bytes + mn_range_bytes.start as isize);
            let dst_row = packed.add(lane * chunk_bytes);
            for pane in 0..whole_panes {
                dst_row
                    .add(pane * panel_len)
                    .copy_from_nonoverlapping(src_row.add(pane * chunk_bytes), chunk_bytes);
            }
            if tail_bytes > 0 {
                dst_row
                    .add(whole_panes * panel_len)
                    .copy_from_nonoverlapping(src_row.add(whole_panes * chunk_bytes), tail_bytes);
            }
        }
    }
}
617
618pub trait Packing {
619 fn packing(r: usize) -> PackedFormat;
620}
621
622impl<D: Datum> Packing for D {
623 fn packing(r: usize) -> PackedFormat {
624 PackedFormat::new(Self::datum_type(), r, vector_size())
625 }
626}
627
#[cfg(test)]
mod test {
    use std::ops::Range;

    use proptest::prelude::*;
    use tract_data::internal::num_integer::Integer;
    use tract_data::internal::tract_ndarray::Zip;
    use tract_data::internal::*;
    use tract_ndarray::prelude::*;

    /// One packing scenario.
    ///
    /// The input is a `k` x `mn` matrix of u32, stored as (mn, k) when `is_a`. Its
    /// `k_range` x `mn_range` window is packed into panels of `r` lanes, rounded with
    /// `align_panel`.
    #[derive(Debug)]
    struct PackProblem {
        k: usize,
        mn: usize,
        is_a: bool,
        r: usize,
        k_range: Range<usize>,
        mn_range: Range<usize>,
        align_panel: usize,
    }

    impl PackProblem {
        /// Input matrix holding 0, 1, 2, ... in row-major order.
        fn input(&self) -> Array2<u32> {
            let shape = if self.is_a { (self.mn, self.k) } else { (self.k, self.mn) };
            let values: Vec<u32> = (0..(self.k * self.mn) as u32).collect();
            Array2::from_shape_vec(shape, values).unwrap()
        }

        /// Number of panels covering `mn_range`.
        fn panel_count(&self) -> usize {
            self.mn_range.len().divceil(self.r)
        }

        /// Reference panel length: `k_range.len() * r` rounded up to `align_panel`.
        fn reference_panel_len(&self) -> usize {
            Integer::next_multiple_of(&(self.k_range.len() * self.r), &self.align_panel)
        }

        /// Maps position `z` of panel `panel` back to its (k, mn) source coordinates.
        fn source_coords(&self, panel: usize, z: usize) -> (usize, usize) {
            let k = self.k_range.start + z / self.r;
            let mn = self.mn_range.start + panel * self.r + z % self.r;
            (k, mn)
        }

        /// Runs the real packer and reshapes its output as (panel, position in panel).
        fn packer(&self) -> Array2<u32> {
            let format = super::PackedFormat::new(u32::datum_type(), self.r, self.align_panel)
                .with_end_padding_record(0);
            let input = self.input().into_tensor();
            let k_len = self.k_range.len();
            let panel_len = format.single_panel_len(k_len);
            let mut output =
                Tensor::zero::<u32>(&[format.len(k_len, self.mn_range.len())]).unwrap();
            unsafe {
                format.pack_segment(
                    output.view_mut(),
                    input.view(),
                    self.is_a as usize,
                    !self.is_a as usize,
                    self.k_range.clone(),
                    self.mn_range.clone(),
                )
            };
            output
                .into_plain_array::<u32>()
                .unwrap()
                .into_shape_with_order((self.panel_count(), panel_len))
                .unwrap()
        }

        /// Expected packed content, computed directly from the input (0 outside the input).
        fn reference(&self) -> Array2<u32> {
            let input = self.input();
            Array2::from_shape_fn([self.panel_count(), self.reference_panel_len()], |(panel, z)| {
                let (k, mn) = self.source_coords(panel, z);
                let coords = if self.is_a { (mn, k) } else { (k, mn) };
                input.get(coords).copied().unwrap_or(0)
            })
        }

        /// Marks the positions that correspond to real input items inside the window.
        fn valid(&self) -> Array2<bool> {
            Array2::from_shape_fn([self.panel_count(), self.reference_panel_len()], |(panel, z)| {
                let (k, mn) = self.source_coords(panel, z);
                k < self.k_range.end.min(self.k) && mn < self.mn_range.end.min(self.mn)
            })
        }

        /// Compares packer and reference on valid positions only (padding is masked out).
        fn check(&self) {
            let valid = self.valid();
            let mask = |a: &mut Array2<u32>| {
                Zip::from(a).and(&valid).for_each(|p, &v| {
                    if !v {
                        *p = u32::MAX
                    }
                });
            };
            let mut packed = self.packer();
            let mut expected = self.reference();
            mask(&mut packed);
            mask(&mut expected);
            assert_eq!(packed, expected);
        }
    }

    impl Arbitrary for PackProblem {
        type Parameters = ();
        type Strategy = BoxedStrategy<PackProblem>;
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            (any::<bool>(), 1usize..9, 1usize..20, 1usize..20)
                .prop_flat_map(|(is_a, r, k, mn)| {
                    (
                        Just((is_a, r, k, mn)),
                        sub_range_strat(0..k),
                        sub_range_strat(0..mn),
                        1usize..5,
                    )
                })
                .prop_map(|((is_a, r, k, mn), k_range, mn_range, align_panel)| {
                    problem(k, mn, is_a, r, k_range, mn_range, align_panel)
                })
                .boxed()
        }
    }

    /// Strategy for non-empty sub-ranges of `range`.
    ///
    /// `cropped` items are removed in total, `left` of them from the start.
    fn sub_range_strat(range: Range<usize>) -> BoxedStrategy<Range<usize>> {
        (0..range.len())
            .prop_flat_map(|cropped| (Just(cropped), 0..=cropped))
            .prop_map(move |(cropped, left)| {
                let right = cropped - left;
                range.start + left..range.end - right
            })
            .boxed()
    }

    /// Positional constructor keeping the fixed cases below on one line each.
    fn problem(
        k: usize,
        mn: usize,
        is_a: bool,
        r: usize,
        k_range: Range<usize>,
        mn_range: Range<usize>,
        align_panel: usize,
    ) -> PackProblem {
        PackProblem { k, mn, is_a, r, k_range, mn_range, align_panel }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb in any::<PackProblem>()) {
            pb.check();
        }

        #[test]
        fn subrange_prop(_range in sub_range_strat(0..20)) {
        }
    }

    // Arguments: k, mn, is_a, r, k_range, mn_range, align_panel.

    #[test]
    fn simple_b_1() {
        problem(2, 1, false, 1, 0..2, 0..1, 1).check();
    }

    #[test]
    fn simple_b_2() {
        problem(2, 2, false, 1, 0..2, 0..2, 1).check();
    }

    #[test]
    fn simple_b_3() {
        problem(2, 1, false, 4, 0..2, 0..1, 1).check();
    }

    #[test]
    fn simple_b_4() {
        problem(1, 3, false, 2, 0..1, 0..3, 1).check();
    }

    #[test]
    fn simple_a_1() {
        problem(2, 2, true, 1, 0..2, 0..2, 1).check();
    }

    #[test]
    fn simple_a_2() {
        problem(2, 3, true, 2, 0..2, 0..3, 1).check();
    }

    #[test]
    fn range_k_0() {
        problem(2, 1, false, 1, 1..2, 0..1, 1).check();
    }

    #[test]
    fn range_k_1() {
        problem(2, 2, false, 1, 0..2, 0..1, 1).check();
    }

    #[test]
    fn range_k_2() {
        problem(2, 1, false, 6, 1..2, 0..1, 1).check();
    }

    #[test]
    fn range_mn_0() {
        problem(1, 2, false, 2, 0..1, 0..1, 1).check();
    }

    #[test]
    fn range_b_4() {
        problem(1, 2, false, 6, 0..1, 1..2, 1).check();
    }

    #[test]
    fn range_b_5() {
        problem(1, 7, false, 6, 0..1, 1..7, 1).check();
    }

    #[test]
    fn align_a_1() {
        problem(2, 2, true, 1, 0..1, 0..2, 2).check();
    }

    #[test]
    fn align_b_1() {
        problem(1, 1, false, 1, 0..1, 0..1, 2).check();
    }

    #[test]
    fn align_b_2() {
        problem(3, 1, false, 1, 0..3, 0..1, 2).check();
    }

    #[test]
    fn align_b_3() {
        problem(1, 1, false, 3, 0..1, 0..1, 2).check();
    }

    #[test]
    fn align_b_4() {
        problem(2, 1, false, 1, 0..1, 0..1, 2).check();
    }

    #[test]
    fn align_b_5() {
        problem(1, 5, false, 4, 0..1, 0..5, 3).check();
    }
}