1pub mod adapter;
7pub mod coloring;
8pub mod slice;
9
10use fenris_nested_vec::NestedVec;
11use rayon::iter::plumbing::{bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer};
12use rayon::iter::{IndexedParallelIterator, ParallelIterator};
13use serde::{Deserialize, Serialize};
14use std::cmp::max;
15use std::collections::HashSet;
16use std::fmt;
17use std::fmt::Debug;
18
/// Access to the records belonging to a single subset yielded by
/// [`DisjointSubsets::subsets_par_iter`].
///
/// Records are addressed by *local* index (the position within the subset), which is translated
/// to the corresponding *global* index of the underlying storage.
pub struct SubsetAccess<'data, Access> {
    /// Label of the subset this access belongs to.
    subset_label: usize,
    /// Global indices (into the underlying storage) making up this subset.
    global_indices: &'data [usize],
    /// Handle used to reach the records of the underlying storage.
    access: Access,
}
24
25impl<'data, Access> SubsetAccess<'data, Access> {
26 pub fn global_indices(&self) -> &[usize] {
27 &self.global_indices
28 }
29
30 pub fn label(&self) -> usize {
31 self.subset_label
32 }
33
34 pub fn len(&self) -> usize {
35 self.global_indices().len()
36 }
37
38 pub fn get<'b>(&'b self, local_index: usize) -> <Access as ParallelIndexedAccess<'b>>::Record
39 where
40 'data: 'b,
41 Access: ParallelIndexedAccess<'b>,
42 {
43 let global_index = self.global_indices[local_index];
44 unsafe { self.access.get_unchecked(global_index) }
45 }
46
47 pub fn get_mut<'b>(&'b mut self, local_index: usize) -> <Access as ParallelIndexedAccess<'b>>::RecordMut
48 where
49 'data: 'b,
50 Access: ParallelIndexedAccess<'b>,
51 {
52 let global_index = self.global_indices[local_index];
53 unsafe { self.access.get_unchecked_mut(global_index) }
54 }
55}
56
/// A cloneable, thread-shareable handle for reading and writing individual records of an
/// indexed collection by index, without bounds or aliasing checks.
///
/// Within this module the callers are [`SubsetAccess::get`] and [`SubsetAccess::get_mut`]. They
/// pass global indices taken from a [`DisjointSubsets`]. Those indices are bounds-checked in
/// [`DisjointSubsets::subsets_par_iter`] and never occur in two different subsets.
///
/// # Safety
///
/// NOTE(review): the full implementor contract is not spelled out in this file. Presumably an
/// implementation must make it sound to call `get_unchecked_mut` concurrently, through clones of
/// the handle on several threads, as long as the indices involved are distinct and in bounds.
/// Confirm against the implementations in the `slice` module.
pub unsafe trait ParallelIndexedAccess<'record>: Sync + Send + Clone {
    /// Type returned by [`get_unchecked`](Self::get_unchecked).
    type Record;
    /// Type returned by [`get_unchecked_mut`](Self::get_unchecked_mut).
    type RecordMut;

    /// Returns the record stored at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds of the underlying collection. This is assumed from how this
    /// module uses it; confirm against the implementations.
    unsafe fn get_unchecked(&self, index: usize) -> Self::Record;

    /// Returns mutable access to the record stored at `index`.
    ///
    /// # Safety
    ///
    /// `index` must be in bounds. No other record for the same `index` may be in use while the
    /// returned value is alive. This is assumed from how this module uses it; confirm against
    /// the implementations.
    unsafe fn get_unchecked_mut(&self, index: usize) -> Self::RecordMut;
}
89
/// A collection that can produce an access handle to its records for parallel, index-based use.
///
/// [`DisjointSubsets::subsets_par_iter`] checks subset indices against [`len`](Self::len). It
/// then calls [`create_access`](Self::create_access) once, and that handle is cloned into every
/// yielded subset.
///
/// # Safety
///
/// NOTE(review): the implementor contract is not stated in this file. Presumably `len` must
/// report the true number of records, because the bounds check in `subsets_par_iter` relies on
/// it alone. Confirm.
pub unsafe trait ParallelIndexedCollection<'a> {
    /// Handle produced by [`create_access`](Self::create_access). Presumably a
    /// [`ParallelIndexedAccess`] implementor; the trait itself places no bound on it.
    type Access;

    /// Creates an access handle that mutably borrows the collection for `'a`.
    ///
    /// # Safety
    ///
    /// Callers must only use the handle with in-bounds indices and must not create aliasing
    /// mutable access to the same record. This is assumed from how `subsets_par_iter` uses it.
    unsafe fn create_access(&'a mut self) -> Self::Access;

    /// Returns the number of records in the collection.
    fn len(&self) -> usize;
}
166
167#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
171pub struct DisjointSubsets {
172 max_index: Option<usize>,
175 subsets: NestedVec<usize>,
179 labels: Vec<usize>,
181}
182
/// Error returned by [`DisjointSubsets::try_from_disjoint_subsets`] when some index is contained
/// in more than one subset.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SubsetsNotDisjointError;

impl fmt::Display for SubsetsNotDisjointError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("subsets are not disjoint: some index appears in more than one subset")
    }
}

impl std::error::Error for SubsetsNotDisjointError {}
185
186impl DisjointSubsets {
187 pub fn try_from_disjoint_subsets<Subsets: Into<NestedVec<usize>>>(
188 subsets: Subsets,
189 labels: Vec<usize>,
190 ) -> Result<Self, SubsetsNotDisjointError> {
191 let subsets = subsets.into();
192 assert_eq!(subsets.len(), labels.len(), "Must have exactly one label per subset.");
193
194 let mut max_index = None;
195 let mut global_index_set = HashSet::new();
196 let mut local_index_set = HashSet::new();
199
200 for subset in subsets.iter() {
202 local_index_set.clear();
203 for idx in subset {
204 if let Some(ref mut current_max) = max_index {
205 *current_max = max(*current_max, *idx);
206 } else {
207 max_index = Some(*idx);
208 }
209 local_index_set.insert(*idx);
210 }
211
212 for idx in &local_index_set {
213 let idx_already_present = !global_index_set.insert(*idx);
214 if idx_already_present {
215 return Err(SubsetsNotDisjointError);
216 }
217 }
218 }
219
220 let disjoint_subsets = DisjointSubsets {
221 max_index,
222 subsets,
223 labels,
224 };
225
226 Ok(disjoint_subsets)
227 }
228
229 pub unsafe fn from_disjoint_subsets_unchecked<Subsets: Into<NestedVec<usize>>>(
230 subsets: Subsets,
231 labels: Vec<usize>,
232 max_index: Option<usize>,
233 ) -> Self {
234 let subsets = subsets.into();
235 assert_eq!(subsets.len(), labels.len(), "Must have exactly one label per subset.");
236 Self {
237 max_index,
238 subsets: subsets.into(),
239 labels,
240 }
241 }
242
243 pub fn subsets(&self) -> &NestedVec<usize> {
244 &self.subsets
245 }
246
247 pub fn into_subsets(self) -> NestedVec<usize> {
248 self.subsets
249 }
250
251 pub fn labels(&self) -> &[usize] {
252 &self.labels
253 }
254
255 pub fn subsets_par_iter<'a, Storage>(
259 &'a self,
260 storage: &'a mut Storage,
261 ) -> DisjointSubsetsParIter<'a, Storage::Access>
262 where
263 Storage: ?Sized + ParallelIndexedCollection<'a>,
264 {
265 assert!(
266 self.max_index.is_none() || storage.len() > self.max_index.unwrap(),
267 "Subsets contain indices out of bounds."
268 );
269 debug_assert_eq!(self.max_index.is_none(), self.subsets.len() == 0);
271 let access = unsafe { storage.create_access() };
272
273 DisjointSubsetsParIter {
274 access,
275 subsets: &self.subsets,
276 labels: &self.labels,
277 }
278 }
279}
280
281pub struct DisjointSubsetsParIter<'a, Access> {
282 access: Access,
283 subsets: &'a NestedVec<usize>,
284 labels: &'a [usize],
285}
286
287impl<'a, Access: Send + Clone> ParallelIterator for DisjointSubsetsParIter<'a, Access> {
288 type Item = SubsetAccess<'a, Access>;
289
290 fn drive_unindexed<C>(self, consumer: C) -> C::Result
291 where
292 C: UnindexedConsumer<Self::Item>,
293 {
294 bridge(self, consumer)
295 }
296
297 fn opt_len(&self) -> Option<usize> {
298 Some(self.len())
299 }
300}
301
302impl<'a, Access: Send + Clone> IndexedParallelIterator for DisjointSubsetsParIter<'a, Access> {
303 fn len(&self) -> usize {
304 self.subsets.len()
305 }
306
307 fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> <C as Consumer<Self::Item>>::Result {
308 bridge(self, consumer)
309 }
310
311 fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
312 let num_subsets = self.subsets.len();
313 callback.callback(DisjointSubsetsProducer {
314 access: self.access,
315 subsets: &self.subsets,
316 labels: self.labels,
317 range_start_idx: 0,
318 range_len: num_subsets,
319 })
320 }
321}
322
323struct DisjointSubsetsProducer<'a, Access> {
324 access: Access,
325 subsets: &'a NestedVec<usize>,
326 labels: &'a [usize],
327 range_start_idx: usize,
329 range_len: usize,
330}
331
332impl<'a, Access> Debug for DisjointSubsetsProducer<'a, Access> {
333 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 f.debug_struct("DisjointSubsetsProducer")
335 .field("range_start_idx", &self.range_start_idx)
338 .field("range_len", &self.range_len)
339 .finish()
340 }
341}
342
343impl<'a, Access: Send + Clone> Producer for DisjointSubsetsProducer<'a, Access> {
344 type Item = SubsetAccess<'a, Access>;
345 type IntoIter = DisjointSubsetsIter<'a, Access>;
346
347 fn into_iter(self) -> Self::IntoIter {
348 DisjointSubsetsIter {
349 access: self.access.clone(),
350 subsets: self.subsets,
351 labels: self.labels,
352 end: self.range_len + self.range_start_idx,
353 current_idx: self.range_start_idx,
354 }
355 }
356
357 fn split_at(self, index: usize) -> (Self, Self) {
358 let producer_len = self.range_len;
359 assert!(index < producer_len);
360 let global_subset_idx = self.range_start_idx + index;
361
362 let producer_left = DisjointSubsetsProducer {
363 access: self.access.clone(),
364 subsets: self.subsets,
365 labels: self.labels,
366 range_start_idx: self.range_start_idx,
367 range_len: index,
368 };
369
370 let producer_right = DisjointSubsetsProducer {
371 access: self.access,
372 subsets: self.subsets,
373 labels: self.labels,
374 range_start_idx: global_subset_idx,
375 range_len: producer_len - index,
376 };
377
378 (producer_left, producer_right)
379 }
380}
381
382struct DisjointSubsetsIter<'a, Access> {
383 access: Access,
384 subsets: &'a NestedVec<usize>,
385 labels: &'a [usize],
386 end: usize,
388 current_idx: usize,
390}
391
392impl<'a, Access> Debug for DisjointSubsetsIter<'a, Access> {
393 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394 f.debug_struct("DisjointSubsetsIter")
395 .field("end", &self.end)
398 .field("current_idx", &self.current_idx)
399 .finish()
400 }
401}
402
403impl<'a, Access: Clone> Iterator for DisjointSubsetsIter<'a, Access> {
404 type Item = SubsetAccess<'a, Access>;
405
406 fn next(&mut self) -> Option<Self::Item> {
407 if self.current_idx < self.end {
408 let access = SubsetAccess {
409 subset_label: *self.labels.get(self.current_idx).unwrap(),
410 global_indices: self.subsets.get(self.current_idx).unwrap(),
411 access: self.access.clone(),
412 };
413 self.current_idx += 1;
414 Some(access)
415 } else {
416 None
417 }
418 }
419
420 fn size_hint(&self) -> (usize, Option<usize>) {
421 let len = self.end - self.current_idx;
422 (len, Some(len))
423 }
424}
425
426impl<'a, Access: Clone> ExactSizeIterator for DisjointSubsetsIter<'a, Access> {}
427
428impl<'a, Access: Clone> DoubleEndedIterator for DisjointSubsetsIter<'a, Access> {
429 fn next_back(&mut self) -> Option<Self::Item> {
430 if self.end > self.current_idx {
431 let subset_index = self.end - 1;
432 let access = SubsetAccess {
433 subset_label: *self.labels.get(subset_index).unwrap(),
434 global_indices: self.subsets.get(subset_index).unwrap(),
435 access: self.access.clone(),
436 };
437 self.end -= 1;
438 Some(access)
439 } else {
440 None
441 }
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::DisjointSubsets;
    use super::DisjointSubsetsIter;
    use super::ParallelIndexedCollection;
    use fenris_nested_vec::NestedVec;
    use proptest::collection::{btree_set, vec};
    use proptest::prelude::*;
    use rand::rngs::StdRng;
    use rand::seq::SliceRandom;
    use rand::SeedableRng;
    use rayon::iter::{IndexedParallelIterator, ParallelIterator};

    /// Exercises the sequential iterator directly: forward, offset start, backward,
    /// shortened end, and interleaved front/back iteration.
    #[test]
    fn test_disjoint_subsets_iter() {
        let subsets_vec = vec![vec![4, 5], vec![1, 2, 3], vec![6, 0]];
        let subset_labels = vec![0, 1, 2];
        let subsets = NestedVec::from(&subsets_vec);

        {
            // Forward iteration over all subsets.
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len(),
                current_idx: 0,
            };

            assert_eq!(iter.len(), 3);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[0].as_slice());
            assert_eq!(iter.len(), 2);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[2].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
        }

        {
            // Forward iteration starting from the second subset.
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len(),
                current_idx: 1,
            };

            assert_eq!(iter.len(), 2);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[2].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
        }

        {
            // Backward iteration over all subsets.
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len(),
                current_idx: 0,
            };

            assert_eq!(iter.len(), 3);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[2].as_slice());
            assert_eq!(iter.len(), 2);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[0].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
        }

        {
            // Backward iteration with an end before the last subset.
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len() - 1,
                current_idx: 0,
            };

            assert_eq!(iter.len(), 2);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[0].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next().is_none());
        }

        {
            // Interleaved front and back iteration; exhausted iterators stay exhausted.
            let mut data = vec![10, 11, 12, 13, 14, 15, 16];
            let data_slice = data.as_mut_slice();

            let access = unsafe { data_slice.create_access() };

            let mut iter = DisjointSubsetsIter {
                access,
                subsets: &subsets,
                labels: &subset_labels,
                end: subsets.len(),
                current_idx: 0,
            };

            assert_eq!(iter.len(), 3);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[0].as_slice());
            assert_eq!(iter.len(), 2);
            let subset_access = iter.next_back().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[2].as_slice());
            assert_eq!(iter.len(), 1);
            let subset_access = iter.next().unwrap();
            assert_eq!(subset_access.global_indices(), subsets_vec[1].as_slice());
            assert_eq!(iter.len(), 0);
            assert!(iter.next_back().is_none());
            assert!(iter.next().is_none());
            assert!(iter.next().is_none());
            assert!(iter.next_back().is_none());
            assert!(iter.next_back().is_none());
            assert!(iter.next().is_none());
        }
    }

    /// Parallel mutation through disjoint subsets must touch every index exactly once,
    /// matching a sequential reference computation.
    #[test]
    fn test_parallel() {
        let mut rng = StdRng::seed_from_u64(458340234234);

        let mut unique_indices: Vec<_> = (0..100000).collect();
        unique_indices.shuffle(&mut rng);

        let chunks: Vec<_> = unique_indices
            .chunks(10)
            .map(|chunk| chunk.to_vec())
            .collect();

        let labels = (0..chunks.len()).collect();

        let disjoint_subsets = DisjointSubsets::try_from_disjoint_subsets(&chunks, labels).unwrap();

        let mut output_par = vec![0; unique_indices.len()];
        disjoint_subsets
            .subsets_par_iter(output_par.as_mut_slice())
            .zip_eq(&chunks)
            // Split as finely as possible so the producer splitting logic is exercised.
            .with_max_len(1)
            .for_each(|(mut subset_access, chunk)| {
                assert_eq!(subset_access.global_indices(), chunk.as_slice());
                for i in 0..chunk.len() {
                    *subset_access.get_mut(i) += 1;
                }
            });

        let mut output_seq = vec![0; unique_indices.len()];
        for chunk in &chunks {
            for &idx in chunk {
                output_seq[idx] += 1;
            }
        }

        let expected_output = vec![1; unique_indices.len()];
        assert_eq!(output_seq, expected_output);
        assert_eq!(output_par, expected_output);
    }

    /// Generates a random partition of `0..n` (for random `n < 20`) into contiguous runs of a
    /// shuffled sequence. The result may include empty subsets.
    fn disjoint_subsets_strategy() -> impl Strategy<Value = NestedVec<usize>> {
        let max_num_integers = 20usize;
        (0..max_num_integers)
            .prop_flat_map(|n| Just((0..n).collect::<Vec<_>>()))
            .prop_shuffle()
            .prop_flat_map(|integers| {
                let n = integers.len();
                let num_splits = 0..=n;
                let split_indices = vec(0..n, num_splits);
                (Just(integers), split_indices)
            })
            .prop_map(|(integers, mut split_indices)| {
                let mut subsets = Vec::with_capacity(split_indices.len() + 1);
                split_indices.push(0);
                split_indices.push(integers.len());
                split_indices.sort_unstable();
                for window in split_indices.windows(2) {
                    let idx = window[0];
                    let idx_next = window[1];
                    subsets.push(integers[idx..idx_next].to_vec());
                }
                NestedVec::from(&subsets)
            })
    }

    /// Takes disjoint subsets and inserts one shared index into at least two of them, so the
    /// result is guaranteed to overlap.
    fn overlapping_subsets_strategy() -> impl Strategy<Value = NestedVec<usize>> {
        let max_index = 20usize;
        disjoint_subsets_strategy()
            .prop_filter("Must have more than 1 subset", |subsets| subsets.len() > 1)
            .prop_flat_map(move |subsets| {
                let insertion_index = 0..max_index;
                let subset_index_strategy = btree_set(0..subsets.len(), 2..=subsets.len());
                (Just(subsets), subset_index_strategy, insertion_index)
            })
            .prop_map(|(subsets, subset_indices, insertion_index)| {
                let mut subsets: Vec<Vec<_>> = subsets.into();
                let num_subsets = subsets.len();
                for subset_idx in subset_indices {
                    subsets[subset_idx % num_subsets].push(insertion_index);
                }
                NestedVec::from(subsets)
            })
    }

    proptest! {
        #[test]
        fn can_create_from_disjoint_subsets(
            disjoint_subsets in disjoint_subsets_strategy()
        ) {
            let labels = (0 .. disjoint_subsets.len()).collect();
            let disjoint = DisjointSubsets::try_from_disjoint_subsets(disjoint_subsets, labels);
            prop_assert!(disjoint.is_ok());
        }

        #[test]
        fn refuses_to_create_from_overlapping_subsets(
            subsets in overlapping_subsets_strategy()
        ) {
            let labels = (0 .. subsets.len()).collect();
            let disjoint = DisjointSubsets::try_from_disjoint_subsets(subsets, labels);
            prop_assert!(disjoint.is_err());
        }
    }
}