1use crate::set::{
4 contiguous_integer_set::ContiguousIntegerSet,
5 ordered_integer_set::OrderedIntegerSet,
6};
7use num::{Integer, ToPrimitive};
8use rayon::iter::{
9 plumbing::{
10 bridge, Consumer, Producer, ProducerCallback, UnindexedConsumer,
11 },
12 IndexedParallelIterator, IntoParallelIterator, ParallelIterator,
13};
14use std::ops::Index;
15
16pub type Partition<T> = OrderedIntegerSet<T>;
17
18#[derive(Clone, Debug, Eq, PartialEq)]
20pub struct IntegerPartitions<T: Copy + Integer + ToPrimitive> {
21 partitions: Vec<Partition<T>>,
22}
23
24impl<T: Copy + Integer + ToPrimitive> IntegerPartitions<T> {
25 pub fn new(partitions: Vec<Partition<T>>) -> IntegerPartitions<T> {
26 IntegerPartitions {
27 partitions,
28 }
29 }
30
31 #[inline]
32 pub fn num_partitions(&self) -> usize {
33 self.partitions.len()
34 }
35
36 pub fn iter(&self) -> IntegerPartitionIter<T> {
38 IntegerPartitionIter {
39 partitions: self.partitions.clone(),
40 current_cursor: 0,
41 end_exclusive: self.partitions.len(),
42 }
43 }
44
45 pub fn union(&self) -> Partition<T> {
48 let intervals: Vec<ContiguousIntegerSet<T>> = self
49 .partitions
50 .iter()
51 .flat_map(|p| p.get_intervals_by_ref().clone())
52 .collect();
53 OrderedIntegerSet::from_contiguous_integer_sets(intervals)
54 }
55}
56
57impl<T: Copy + Integer + ToPrimitive> Index<usize> for IntegerPartitions<T> {
58 type Output = Partition<T>;
59
60 #[inline]
61 fn index(&self, index: usize) -> &Self::Output {
62 &self.partitions[index]
63 }
64}
65
66pub struct IntegerPartitionIter<T: Copy + Integer + ToPrimitive> {
67 partitions: Vec<Partition<T>>,
68 current_cursor: usize,
69 end_exclusive: usize,
70}
71
72impl<T: Copy + Integer + ToPrimitive> IntegerPartitionIter<T> {
73 pub fn clone_with_range(
74 &self,
75 start: usize,
76 end_exclusive: usize,
77 ) -> IntegerPartitionIter<T> {
78 assert!(
79 start <= end_exclusive,
80 "start ({}) has to be <= end_exclusive ({})",
81 start,
82 end_exclusive
83 );
84 IntegerPartitionIter {
85 partitions: self.partitions[start..end_exclusive].to_vec(),
86 current_cursor: 0,
87 end_exclusive: end_exclusive - start,
88 }
89 }
90}
91
92impl<T: Copy + Integer + ToPrimitive> Iterator for IntegerPartitionIter<T> {
93 type Item = Partition<T>;
94
95 fn next(&mut self) -> Option<Self::Item> {
96 if self.current_cursor >= self.end_exclusive {
97 None
98 } else {
99 self.current_cursor += 1;
100 Some(self.partitions[self.current_cursor - 1].clone())
101 }
102 }
103}
104
105impl<T: Copy + Integer + ToPrimitive> ExactSizeIterator
106 for IntegerPartitionIter<T>
107{
108 fn len(&self) -> usize {
109 if self.current_cursor >= self.end_exclusive {
110 0
111 } else {
112 self.end_exclusive - self.current_cursor
113 }
114 }
115}
116
117impl<T: Copy + Integer + ToPrimitive> DoubleEndedIterator
118 for IntegerPartitionIter<T>
119{
120 fn next_back(&mut self) -> Option<Self::Item> {
121 if self.current_cursor >= self.end_exclusive {
122 None
123 } else {
124 self.end_exclusive -= 1;
125 Some(self.partitions[self.end_exclusive].clone())
126 }
127 }
128}
129
130impl<'a, T: Copy + Integer + Send + ToPrimitive> IntoParallelIterator
131 for IntegerPartitionIter<T>
132{
133 type Item = <IntegerPartitionParallelIter<T> as ParallelIterator>::Item;
134 type Iter = IntegerPartitionParallelIter<T>;
135
136 fn into_par_iter(self) -> Self::Iter {
137 IntegerPartitionParallelIter {
138 iter: self,
139 }
140 }
141}
142
143pub struct IntegerPartitionParallelIter<T: Copy + Integer + ToPrimitive> {
144 iter: IntegerPartitionIter<T>,
145}
146
147impl<T: Copy + Integer + Send + ToPrimitive> ParallelIterator
148 for IntegerPartitionParallelIter<T>
149{
150 type Item = <IntegerPartitionIter<T> as Iterator>::Item;
151
152 fn drive_unindexed<C>(self, consumer: C) -> C::Result
153 where
154 C: UnindexedConsumer<Self::Item>, {
155 bridge(self, consumer)
156 }
157
158 fn opt_len(&self) -> Option<usize> {
159 Some(self.iter.len())
160 }
161}
162
163impl<T: Copy + Integer + Send + ToPrimitive> IndexedParallelIterator
164 for IntegerPartitionParallelIter<T>
165{
166 fn len(&self) -> usize {
167 self.iter.len()
168 }
169
170 fn drive<C>(self, consumer: C) -> C::Result
171 where
172 C: Consumer<Self::Item>, {
173 bridge(self, consumer)
174 }
175
176 fn with_producer<CB>(self, callback: CB) -> CB::Output
177 where
178 CB: ProducerCallback<Self::Item>, {
179 callback.callback(IntegerPartitionIterProducer {
180 iter: self.iter,
181 })
182 }
183}
184
185struct IntegerPartitionIterProducer<T: Copy + Integer + ToPrimitive> {
186 iter: IntegerPartitionIter<T>,
187}
188
189impl<T: Copy + Integer + Send + ToPrimitive> Producer
190 for IntegerPartitionIterProducer<T>
191{
192 type IntoIter = IntegerPartitionIter<T>;
193 type Item = <IntegerPartitionIter<T> as Iterator>::Item;
194
195 #[inline]
196 fn into_iter(self) -> Self::IntoIter {
197 self.iter
198 }
199
200 fn split_at(self, index: usize) -> (Self, Self) {
201 (
202 IntegerPartitionIterProducer {
203 iter: self.iter.clone_with_range(0, index),
204 },
205 IntegerPartitionIterProducer {
206 iter: self.iter.clone_with_range(index, self.iter.len()),
207 },
208 )
209 }
210}
211
212impl<T: Copy + Integer + ToPrimitive> IntoIterator
213 for IntegerPartitionIterProducer<T>
214{
215 type IntoIter = IntegerPartitionIter<T>;
216 type Item = <IntegerPartitionIter<T> as Iterator>::Item;
217
218 #[inline]
219 fn into_iter(self) -> Self::IntoIter {
220 self.iter
221 }
222}
223
#[cfg(test)]
mod tests {
    use crate::{
        partition::integer_partitions::{IntegerPartitions, Partition},
        set::{ordered_integer_set::OrderedIntegerSet, traits::Finite},
    };
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    #[test]
    fn test_num_partitions() {
        assert_eq!(IntegerPartitions::<usize>::new(vec![]).num_partitions(), 0);
        assert_eq!(IntegerPartitions::<i32>::new(vec![]).num_partitions(), 0);
        assert_eq!(IntegerPartitions::<i64>::new(vec![]).num_partitions(), 0);
        assert_eq!(
            IntegerPartitions::new(vec![
                Partition::from_slice(&[[0i32, 2]]),
                Partition::from_slice(&[[4, 8], [15, 21]]),
            ])
            .num_partitions(),
            2
        );
        assert_eq!(
            IntegerPartitions::new(vec![Partition::from_slice(&[
                [2usize, 4],
                [5, 6],
                [10, 11]
            ])])
            .num_partitions(),
            1
        );
    }

    #[test]
    fn test_partitions_union() {
        let partitions = IntegerPartitions::<i32>::new(vec![
            Partition::from_slice(&[[1, 3], [8, 9]]),
            Partition::from_slice(&[[4, 5], [10, 14]]),
            Partition::from_slice(&[[21, 24]]),
        ]);
        assert_eq!(
            partitions.union(),
            Partition::<i32>::from_slice(&[
                [1, 3],
                [4, 5],
                [8, 9],
                [10, 14],
                [21, 24]
            ])
        );
    }

    #[test]
    fn test_partitions_iter() {
        macro_rules! test_with_type {
            ($itype:ty) => {
                let partition_list = vec![
                    Partition::<$itype>::from_slice(&[[0, 2], [9, 11]]),
                    Partition::from_slice(&[[4, 8], [15, 21]]),
                ];
                let partitions = IntegerPartitions::new(partition_list.clone());
                // Compare whole sequences: a `zip`-based loop silently passes
                // when the iterator yields too few (or too many) items.
                let collected: Vec<Partition<$itype>> = partitions.iter().collect();
                assert_eq!(collected, partition_list);
                assert_eq!(partitions.iter().len(), partition_list.len());
            };
        }
        test_with_type!(usize);
        test_with_type!(i32);
        test_with_type!(i64);
    }

    #[test]
    fn test_partitions_next_back() {
        let partitions = IntegerPartitions::new(vec![
            OrderedIntegerSet::from_slice(&[[1, 3], [6, 9]]),
            OrderedIntegerSet::from_slice(&[[4, 5], [10, 14]]),
            OrderedIntegerSet::from_slice(&[[15, 20], [25, 26]]),
            OrderedIntegerSet::from_slice(&[[21, 24]]),
        ]);
        assert_eq!(
            partitions.iter().nth_back(2).unwrap(),
            Partition::<i32>::from_slice(&[[4, 5], [10, 14]])
        );
    }

    #[test]
    fn test_integer_partition_exact_size_iter() {
        assert_eq!(IntegerPartitions::<usize>::new(vec![]).iter().len(), 0);
        assert_eq!(
            IntegerPartitions::new(vec![OrderedIntegerSet::from_slice(&[
                [-10, 20],
                [30, 40]
            ]),])
            .iter()
            .len(),
            1
        );

        assert_eq!(
            IntegerPartitions::new(vec![
                OrderedIntegerSet::from_slice(&[[-1, 2], [6, 9]]),
                OrderedIntegerSet::from_slice(&[[10, 14]]),
                OrderedIntegerSet::from_slice(&[[15, 20], [25, 26]]),
                OrderedIntegerSet::from_slice(&[[21, 24]]),
            ])
            .iter()
            .len(),
            4
        );
    }

    #[test]
    fn test_integer_partition_par_iter() {
        let partitions = IntegerPartitions::new(vec![
            OrderedIntegerSet::from_slice(&[[1, 3], [6, 9]]),
            OrderedIntegerSet::from_slice(&[[4, 5], [10, 14]]),
            OrderedIntegerSet::from_slice(&[[15, 20], [25, 26]]),
            OrderedIntegerSet::from_slice(&[[21, 24]]),
        ]);
        let mut iter = partitions.iter();
        assert_eq!(
            iter.next(),
            Some(OrderedIntegerSet::from_slice(&[[1, 3], [6, 9]]))
        );
        assert_eq!(
            iter.next(),
            Some(OrderedIntegerSet::from_slice(&[[4, 5], [10, 14]]))
        );
        assert_eq!(
            iter.next(),
            Some(OrderedIntegerSet::from_slice(&[[15, 20], [25, 26]]))
        );
        assert_eq!(
            iter.next(),
            Some(OrderedIntegerSet::from_slice(&[[21, 24]]))
        );
        assert_eq!(iter.next(), None);

        let num_elements: usize =
            partitions.iter().into_par_iter().map(|p| p.size()).sum();
        assert_eq!(num_elements, 26);

        // The parallel path (which splits via the producer) must yield the
        // same partitions, in the same order, as the sequential iterator.
        let sequential: Vec<Partition<i32>> = partitions.iter().collect();
        let parallel: Vec<Partition<i32>> =
            partitions.iter().into_par_iter().collect();
        assert_eq!(parallel, sequential);
    }
}