fenris_paradis/
lib.rs

1//! paradis
2//! =======
3//!
4//! Parallel processing of disjoint subsets.
5
6pub mod adapter;
7pub mod coloring;
8pub mod slice;
9
10use fenris_nested_vec::NestedVec;
11use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
12use rayon::iter::{IndexedParallelIterator, ParallelIterator};
13use serde::{Deserialize, Serialize};
14use std::cmp::max;
15use std::collections::HashSet;
16use std::fmt;
17use std::fmt::Debug;
18
/// A single subset of indices, paired with the parallel access used to reach the records
/// that the subset refers to.
///
/// The subset is described by its label and by its list of *global* indices (indices into the
/// underlying collection). Records are looked up with *local* indices, i.e. positions within
/// that list.
pub struct SubsetAccess<'data, Access> {
    /// Label associated with this subset.
    subset_label: usize,
    /// Global indices of the records in this subset. An index may appear more than once.
    global_indices: &'data [usize],
    /// Access to the underlying collection.
    access: Access,
}
24
25impl<'data, Access> SubsetAccess<'data, Access> {
26    pub fn global_indices(&self) -> &[usize] {
27        &self.global_indices
28    }
29
30    pub fn label(&self) -> usize {
31        self.subset_label
32    }
33
34    pub fn len(&self) -> usize {
35        self.global_indices().len()
36    }
37
38    pub fn get<'b>(&'b self, local_index: usize) -> <Access as ParallelIndexedAccess<'b>>::Record
39    where
40        'data: 'b,
41        Access: ParallelIndexedAccess<'b>,
42    {
43        let global_index = self.global_indices[local_index];
44        unsafe { self.access.get_unchecked(global_index) }
45    }
46
47    pub fn get_mut<'b>(&'b mut self, local_index: usize) -> <Access as ParallelIndexedAccess<'b>>::RecordMut
48    where
49        'data: 'b,
50        Access: ParallelIndexedAccess<'b>,
51    {
52        let global_index = self.global_indices[local_index];
53        unsafe { self.access.get_unchecked_mut(global_index) }
54    }
55}
56
// TODO: Does this trait need to be unsafe, or does it suffice to have unsafe methods?
/// Facilitates parallel access to (mutable) records stored in a collection.
///
/// The trait provides parallel access to (possibly mutable) *records*, defined by the
/// associated types [`Record`][`ParallelIndexedAccess::Record`] and
/// [`RecordMut`][`ParallelIndexedAccess::RecordMut`].
///
/// # Safety
///
/// An implementor must ensure that it is sound for multiple threads to access a single record
/// *immutably*, provided that no thread accesses the same record mutably.
///
/// An implementor must furthermore ensure that it is sound for multiple threads to access
/// *disjoint* records mutably.
///
/// It is the responsibility of the consumer to ensure that:
///
/// - If any thread accesses a record mutably, then no other thread accesses the same record.
/// - A mutable record is always exclusive, even on a single thread.
///   In particular, a single thread is not permitted to obtain two records associated with the
///   same index in the collection if either record is mutable.
///
/// TODO: Make the invariants more precise
///
/// TODO: Can consider a slightly different API in which each thread must obtain its own access.
pub unsafe trait ParallelIndexedAccess<'record>: Sync + Send + Clone {
    /// The type handed out for immutable access to a single record.
    type Record;
    /// The type handed out for mutable access to a single record.
    type RecordMut;

    /// Obtain immutable access to the record at `index`.
    ///
    /// # Safety
    ///
    /// Implementations may assume that `index` is valid. The caller must ensure that `index`
    /// lies in `[0, len)` for the collection this access was created from, and that no thread
    /// holds a mutable record for the same index while the returned record is alive.
    unsafe fn get_unchecked(&self, index: usize) -> Self::Record;

    /// Obtain mutable access to the record at `index`.
    ///
    /// # Safety
    ///
    /// Implementations may assume that `index` is valid. The caller must ensure that `index`
    /// lies in `[0, len)` for the collection this access was created from, and that the
    /// returned record is the only record (mutable or immutable, on any thread) associated
    /// with `index` while it is alive.
    unsafe fn get_unchecked_mut(&self, index: usize) -> Self::RecordMut;
}
89
/// An indexed collection that exposes parallel indexed access to its contents.
///
/// The typical pattern for a generic parallel algorithm is to take any collection implementing
/// this trait as input. Because [`create_access`][`ParallelIndexedCollection::create_access`]
/// takes the collection by mutable reference, the algorithm is then guaranteed to hold the only
/// parallel access to the collection.
///
/// # Examples
///
/// Let's consider a contrived example in which we want to multiply every number in a list by 2,
/// and we want to multiply even and odd numbers on two separate threads. This is of course a
/// *very bad* idea for a number of reasons, but it serves as a very simple example of a *sound*
/// use of parallelism that nonetheless falls outside of what we can accomplish with safe Rust.
///
/// Although we'll have to use some `unsafe` code to accomplish this, the task sounds fairly
/// trivial. However, we would soon discover that the only way of accessing different
/// (scattered) elements of a slice mutably in parallel is through careful pointer manipulation,
/// which is error prone and hard to get right.
///
/// The abstractions made available by `paradis` significantly simplify this. Since slices
/// implement the [`ParallelIndexedCollection`] trait, we can obtain *parallel access* to their
/// elements, which lets us (unsafely) access any element in the slice without working with
/// pointer arithmetic. We are of course wholly responsible for making sure that we never
/// access the same element from two threads.
///
/// ```rust
/// use fenris_paradis::{ParallelIndexedCollection, ParallelIndexedAccess};
/// use crossbeam::scope;
///
/// fn par_double_all_numbers(numbers: &mut [i32]) {
///     let n = numbers.len();
///
///     // Since creating an access takes a mutable reference to [i32], we know that we are the
///     // only ones to hold a parallel access to the data throughout the program, so we can
///     // soundly manipulate its data in parallel, provided we take some care
///     let access = unsafe { numbers.create_access() };
///
///     // Scoped threads let the spawned closures borrow `access` without requiring 'static.
///     // We use crossbeam::scope here; std::thread::scope (Rust 1.63+) works equally well.
///     scope(|s| {
///         s.spawn(|_| {
///             // Transform the even numbers
///             for i in (0 .. n).step_by(2) {
///                 unsafe { *access.get_unchecked_mut(i) *= 2; }
///             }
///         });
///         s.spawn(|_| {
///             // Transform the odd numbers
///             for i in (1 .. n).step_by(2) {
///                 unsafe { *access.get_unchecked_mut(i) *= 2; }
///             }
///         });
///     }).expect("One of our threads panicked!");
/// }
///
/// fn main() {
///     let mut numbers = [0, 1, 2, 3, 4, 5, 6, 7];
///     par_double_all_numbers(&mut numbers);
///     assert_eq!(numbers, [0, 2, 4, 6, 8, 10, 12, 14]);
/// }
/// ```
///
/// # Safety
///
/// This trait is unsafe because the soundness of consuming code relies on the correctness of
/// the implementation of [`ParallelIndexedCollection::len`].
/// Consumers of this trait are permitted to access records
/// (accessed through [`ParallelIndexedCollection::Access`]) with indices `[0, len)`. Therefore,
/// an incorrect length may lead to unsoundness.
pub unsafe trait ParallelIndexedCollection<'a> {
    /// The parallel access produced by
    /// [`create_access`][`ParallelIndexedCollection::create_access`].
    type Access;

    /// Create a parallel access to the records of this collection.
    ///
    /// # Safety
    ///
    /// The caller is responsible for using the returned access soundly. It must only request
    /// records with indices in `[0, len)`, where `len` is the value reported by
    /// [`len`][`ParallelIndexedCollection::len`]. It must also uphold the exclusivity rules
    /// documented on [`ParallelIndexedAccess`] for every record obtained through the access
    /// (or through any copy of it).
    unsafe fn create_access(&'a mut self) -> Self::Access;

    /// The number of records in the collection. Records may be accessed with indices
    /// `[0, len)`.
    fn len(&self) -> usize;
}
166
167/// A set of subsets of indices, in which the intersection of indices between any two subsets is
168/// empty.
169///
170#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
171pub struct DisjointSubsets {
172    // Store the max global index present in any of the subsets. We need this to
173    // ensure that none of the indices are out of bounds when accessing a storage.
174    max_index: Option<usize>,
175    // Each subset consists of a set of indices. Indices are allowed to overlap within a subset,
176    // but the intersection between the indices of any two subsets must be empty. In other words,
177    // no two subsets share a common index.
178    subsets: NestedVec<usize>,
179    // Store a label for each subset
180    labels: Vec<usize>,
181}
182
/// Error returned when constructing [`DisjointSubsets`] from subsets that share an index.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SubsetsNotDisjointError;

impl fmt::Display for SubsetsNotDisjointError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("subsets are not disjoint: at least one index appears in more than one subset")
    }
}

impl std::error::Error for SubsetsNotDisjointError {}
185
186impl DisjointSubsets {
187    pub fn try_from_disjoint_subsets<Subsets: Into<NestedVec<usize>>>(
188        subsets: Subsets,
189        labels: Vec<usize>,
190    ) -> Result<Self, SubsetsNotDisjointError> {
191        let subsets = subsets.into();
192        assert_eq!(subsets.len(), labels.len(), "Must have exactly one label per subset.");
193
194        let mut max_index = None;
195        let mut global_index_set = HashSet::new();
196        // Subsets are allowed to contain duplicate entries, so we therefore build a local index
197        // set for each subset before checking against and adding them to the global index set.
198        let mut local_index_set = HashSet::new();
199
200        // Verify that subsets are disjoint
201        for subset in subsets.iter() {
202            local_index_set.clear();
203            for idx in subset {
204                if let Some(ref mut current_max) = max_index {
205                    *current_max = max(*current_max, *idx);
206                } else {
207                    max_index = Some(*idx);
208                }
209                local_index_set.insert(*idx);
210            }
211
212            for idx in &local_index_set {
213                let idx_already_present = !global_index_set.insert(*idx);
214                if idx_already_present {
215                    return Err(SubsetsNotDisjointError);
216                }
217            }
218        }
219
220        let disjoint_subsets = DisjointSubsets {
221            max_index,
222            subsets,
223            labels,
224        };
225
226        Ok(disjoint_subsets)
227    }
228
229    pub unsafe fn from_disjoint_subsets_unchecked<Subsets: Into<NestedVec<usize>>>(
230        subsets: Subsets,
231        labels: Vec<usize>,
232        max_index: Option<usize>,
233    ) -> Self {
234        let subsets = subsets.into();
235        assert_eq!(subsets.len(), labels.len(), "Must have exactly one label per subset.");
236        Self {
237            max_index,
238            subsets: subsets.into(),
239            labels,
240        }
241    }
242
243    pub fn subsets(&self) -> &NestedVec<usize> {
244        &self.subsets
245    }
246
247    pub fn into_subsets(self) -> NestedVec<usize> {
248        self.subsets
249    }
250
251    pub fn labels(&self) -> &[usize] {
252        &self.labels
253    }
254
255    /// Create a parallel iterator over the subsets, fetching data from the provided storage.
256    ///
257    /// Panics if any subset contains an index that exceeds the length reported by `storage`.
258    pub fn subsets_par_iter<'a, Storage>(
259        &'a self,
260        storage: &'a mut Storage,
261    ) -> DisjointSubsetsParIter<'a, Storage::Access>
262    where
263        Storage: ?Sized + ParallelIndexedCollection<'a>,
264    {
265        assert!(
266            self.max_index.is_none() || storage.len() > self.max_index.unwrap(),
267            "Subsets contain indices out of bounds."
268        );
269        // Sanity check: if we don't have a max index, then we also cannot have any subsets
270        debug_assert_eq!(self.max_index.is_none(), self.subsets.len() == 0);
271        let access = unsafe { storage.create_access() };
272
273        DisjointSubsetsParIter {
274            access,
275            subsets: &self.subsets,
276            labels: &self.labels,
277        }
278    }
279}
280
281pub struct DisjointSubsetsParIter<'a, Access> {
282    access: Access,
283    subsets: &'a NestedVec<usize>,
284    labels: &'a [usize],
285}
286
287impl<'a, Access: Send + Clone> ParallelIterator for DisjointSubsetsParIter<'a, Access> {
288    type Item = SubsetAccess<'a, Access>;
289
290    fn drive_unindexed<C>(self, consumer: C) -> C::Result
291    where
292        C: UnindexedConsumer<Self::Item>,
293    {
294        bridge(self, consumer)
295    }
296
297    fn opt_len(&self) -> Option<usize> {
298        Some(self.len())
299    }
300}
301
302impl<'a, Access: Send + Clone> IndexedParallelIterator for DisjointSubsetsParIter<'a, Access> {
303    fn len(&self) -> usize {
304        self.subsets.len()
305    }
306
307    fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> <C as Consumer<Self::Item>>::Result {
308        bridge(self, consumer)
309    }
310
311    fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
312        let num_subsets = self.subsets.len();
313        callback.callback(DisjointSubsetsProducer {
314            access: self.access,
315            subsets: &self.subsets,
316            labels: self.labels,
317            range_start_idx: 0,
318            range_len: num_subsets,
319        })
320    }
321}
322
323struct DisjointSubsetsProducer<'a, Access> {
324    access: Access,
325    subsets: &'a NestedVec<usize>,
326    labels: &'a [usize],
327    // Range start/len represents the range represented by this producer
328    range_start_idx: usize,
329    range_len: usize,
330}
331
332impl<'a, Access> Debug for DisjointSubsetsProducer<'a, Access> {
333    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334        f.debug_struct("DisjointSubsetsProducer")
335            // .field("access", &"<not debuggable>")
336            // .field("subsets", &self.subsets)
337            .field("range_start_idx", &self.range_start_idx)
338            .field("range_len", &self.range_len)
339            .finish()
340    }
341}
342
343impl<'a, Access: Send + Clone> Producer for DisjointSubsetsProducer<'a, Access> {
344    type Item = SubsetAccess<'a, Access>;
345    type IntoIter = DisjointSubsetsIter<'a, Access>;
346
347    fn into_iter(self) -> Self::IntoIter {
348        DisjointSubsetsIter {
349            access: self.access.clone(),
350            subsets: self.subsets,
351            labels: self.labels,
352            end: self.range_len + self.range_start_idx,
353            current_idx: self.range_start_idx,
354        }
355    }
356
357    fn split_at(self, index: usize) -> (Self, Self) {
358        let producer_len = self.range_len;
359        assert!(index < producer_len);
360        let global_subset_idx = self.range_start_idx + index;
361
362        let producer_left = DisjointSubsetsProducer {
363            access: self.access.clone(),
364            subsets: self.subsets,
365            labels: self.labels,
366            range_start_idx: self.range_start_idx,
367            range_len: index,
368        };
369
370        let producer_right = DisjointSubsetsProducer {
371            access: self.access,
372            subsets: self.subsets,
373            labels: self.labels,
374            range_start_idx: global_subset_idx,
375            range_len: producer_len - index,
376        };
377
378        (producer_left, producer_right)
379    }
380}
381
382struct DisjointSubsetsIter<'a, Access> {
383    access: Access,
384    subsets: &'a NestedVec<usize>,
385    labels: &'a [usize],
386    // end is an index one-past the end of the iterator
387    end: usize,
388    // The current index that the iterator is at
389    current_idx: usize,
390}
391
392impl<'a, Access> Debug for DisjointSubsetsIter<'a, Access> {
393    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394        f.debug_struct("DisjointSubsetsIter")
395            // .field("access", &"<not debuggable>")
396            // .field("subsets", &self.subsets)
397            .field("end", &self.end)
398            .field("current_idx", &self.current_idx)
399            .finish()
400    }
401}
402
403impl<'a, Access: Clone> Iterator for DisjointSubsetsIter<'a, Access> {
404    type Item = SubsetAccess<'a, Access>;
405
406    fn next(&mut self) -> Option<Self::Item> {
407        if self.current_idx < self.end {
408            let access = SubsetAccess {
409                subset_label: *self.labels.get(self.current_idx).unwrap(),
410                global_indices: self.subsets.get(self.current_idx).unwrap(),
411                access: self.access.clone(),
412            };
413            self.current_idx += 1;
414            Some(access)
415        } else {
416            None
417        }
418    }
419
420    fn size_hint(&self) -> (usize, Option<usize>) {
421        let len = self.end - self.current_idx;
422        (len, Some(len))
423    }
424}
425
426impl<'a, Access: Clone> ExactSizeIterator for DisjointSubsetsIter<'a, Access> {}
427
428impl<'a, Access: Clone> DoubleEndedIterator for DisjointSubsetsIter<'a, Access> {
429    fn next_back(&mut self) -> Option<Self::Item> {
430        if self.end > self.current_idx {
431            let subset_index = self.end - 1;
432            let access = SubsetAccess {
433                subset_label: *self.labels.get(subset_index).unwrap(),
434                global_indices: self.subsets.get(subset_index).unwrap(),
435                access: self.access.clone(),
436            };
437            self.end -= 1;
438            Some(access)
439        } else {
440            None
441        }
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::DisjointSubsets;
    use super::DisjointSubsetsIter;
    use super::ParallelIndexedCollection;
    use fenris_nested_vec::NestedVec;
    use proptest::collection::{btree_set, vec};
    use proptest::prelude::*;
    use rand::rngs::StdRng;
    use rand::seq::SliceRandom;
    use rand::SeedableRng;
    use rayon::iter::{IndexedParallelIterator, ParallelIterator};

    #[test]
    fn test_disjoint_subsets_iter() {
        let subsets_vec = vec![vec![4, 5], vec![1, 2, 3], vec![6, 0]];
        let subset_labels = vec![0, 1, 2];
        let subsets = NestedVec::from(&subsets_vec);

        // Forward iteration only
        {
            // Range is over all subsets
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len(),
                current_idx: 0,
            };

            assert_eq!(iter.len(), 3);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[0].as_slice());
            assert_eq!(iter.len(), 2);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[2].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
        }

        // Forward iteration only
        {
            // Range covers only the last two subsets
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len(),
                current_idx: 1,
            };

            assert_eq!(iter.len(), 2);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[2].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
        }

        // Backward iteration only
        {
            // Range is over all subsets
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len(),
                current_idx: 0,
            };

            assert_eq!(iter.len(), 3);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[2].as_slice());
            assert_eq!(iter.len(), 2);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[0].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
        }

        // Backward iteration only
        {
            // Range covers only the first two subsets
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len() - 1,
                current_idx: 0,
            };

            assert_eq!(iter.len(), 2);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[0].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
        }

        // Forward and backward iteration
        {
            // Range is over all subsets
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len(),
                current_idx: 0,
            };

            assert_eq!(iter.len(), 3);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[0].as_slice());
            assert_eq!(iter.len(), 2);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[2].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 0);
            // Once exhausted, both ends must keep returning None
            assert!(iter.next_back().is_none());
            assert!(iter.next().is_none());
            assert!(iter.next().is_none());
            assert!(iter.next_back().is_none());
            assert!(iter.next_back().is_none());
            assert!(iter.next().is_none());
        }
    }

    #[test]
    fn test_parallel() {
        // Fixed seed so that the test is reproducible
        let mut rng = StdRng::seed_from_u64(458340234234);

        let mut unique_indices: Vec<_> = (0..100000).collect();
        unique_indices.shuffle(&mut rng);

        let chunks: Vec<_> = unique_indices
            .chunks(10)
            .map(|chunk| chunk.to_vec())
            .collect();

        let labels = (0..chunks.len()).collect();

        let disjoint_subsets = DisjointSubsets::try_from_disjoint_subsets(&chunks, labels).unwrap();

        let mut output_par = vec![0; unique_indices.len()];
        disjoint_subsets
            .subsets_par_iter(output_par.as_mut_slice())
            .zip_eq(&chunks)
            // Try to ensure that rayon actually uses multiple threads, otherwise it might
            // decide to run it all sequentially
            .with_max_len(1)
            .for_each(|(mut subset_access, chunk)| {
                assert_eq!(subset_access.global_indices(), chunk.as_slice());
                for i in 0..chunk.len() {
                    *subset_access.get_mut(i) += 1;
                }
            });

        let mut output_seq = vec![0; unique_indices.len()];
        for chunk in &chunks {
            for &idx in chunk {
                output_seq[idx] += 1;
            }
        }

        let expected_output = vec![1; unique_indices.len()];
        assert_eq!(output_seq, expected_output);
        assert_eq!(output_par, expected_output);
    }

    // TODO: Test the strategy itself!
    // TODO: Our current strategy also enforces that the subsets have no duplicate indices,
    // which is explicitly allowed by our algorithms, so we should include this too
    /// Generates disjoint subsets by shuffling `0..n` and cutting it at random split points.
    fn disjoint_subsets_strategy() -> impl Strategy<Value = NestedVec<usize>> {
        let max_num_integers = 20usize;
        (0..max_num_integers)
            .prop_map(|n| (0..n).collect::<Vec<_>>())
            .prop_shuffle()
            .prop_flat_map(|integers| {
                let n = integers.len();
                let num_splits = 0..=n;
                let split_indices = vec(0..n, num_splits);
                (Just(integers), split_indices)
            })
            .prop_map(|(integers, mut split_indices)| {
                let mut subsets = Vec::with_capacity(split_indices.len() + 1);
                split_indices.push(0);
                split_indices.push(integers.len());
                split_indices.sort_unstable();
                for window in split_indices.windows(2) {
                    let idx = window[0];
                    let idx_next = window[1];
                    subsets.push(integers[idx..idx_next].to_vec());
                }
                NestedVec::from(&subsets)
            })
    }

    /// Generates subsets that are guaranteed *not* to be disjoint.
    fn overlapping_subsets_strategy() -> impl Strategy<Value = NestedVec<usize>> {
        // Starting from a set of disjoint subsets, add the same index to at least two distinct
        // subsets, thereby ensuring that the subsets are no longer disjoint
        let max_index = 20usize;
        disjoint_subsets_strategy()
            .prop_filter("Must have more than 1 subset", |subsets| subsets.len() > 1)
            .prop_flat_map(move |subsets| {
                let insertion_index = 0..max_index;
                // Distinct subset indices, all in `0..subsets.len()`
                let subset_index_strategy = btree_set(0..subsets.len(), 2..=subsets.len());
                (Just(subsets), subset_index_strategy, insertion_index)
            })
            .prop_map(|(subsets, subset_indices, insertion_index)| {
                let mut subsets: Vec<Vec<_>> = subsets.into();
                for subset_idx in subset_indices {
                    subsets[subset_idx].push(insertion_index);
                }
                NestedVec::from(subsets)
            })
    }

    proptest! {
        #[test]
        fn can_create_from_disjoint_subsets(
            disjoint_subsets in disjoint_subsets_strategy()
        ) {
            let labels = (0 .. disjoint_subsets.len()).collect();
            let disjoint = DisjointSubsets::try_from_disjoint_subsets(disjoint_subsets, labels);
            prop_assert!(disjoint.is_ok());
        }

        #[test]
        fn refuses_to_create_from_overlapping_subsets(
            subsets in overlapping_subsets_strategy()
        ) {
            let labels = (0 .. subsets.len()).collect();
            let disjoint = DisjointSubsets::try_from_disjoint_subsets(subsets, labels);
            prop_assert!(disjoint.is_err());
        }
    }
}