1use indexmap::IndexMap;
2use itertools::Itertools;
3use std::collections::HashSet;
4
5#[derive(Clone, Debug)]
6pub struct Partition {
7 partition: Vec<HashSet<usize>>, lookup: Vec<usize>, }
10
11impl Partition {
12 #[cfg(any(debug_assertions, test))]
13 fn check_state(&self) -> Result<(), &'static str> {
14 use std::collections::HashMap;
15 let mut present = HashMap::new();
16 let n = self.lookup.len();
17 for (idx, part) in self.partition.iter().enumerate() {
18 if part.is_empty() {
19 return Err("Partition contains an empty part");
20 }
21 for &x in part {
22 if n <= x {
23 return Err("Partition contains element which is too big");
24 }
25 if present.contains_key(&x) {
26 return Err("Duplicate element in partition");
27 }
28 present.insert(x, idx);
29 }
30 }
31 for x in 0..n {
32 if !present.contains_key(&x) {
33 return Err("Missing element from partition");
34 }
35 if present.get(&x).unwrap() != &self.lookup[x] {
36 return Err("Incorrect entry in lookup");
37 }
38 }
39 Ok(())
40 }
41
42 pub fn new_unchecked(partition: Vec<HashSet<usize>>, lookup: Vec<usize>) -> Self {
43 let partition = Self { partition, lookup };
44 #[cfg(debug_assertions)]
45 partition.check_state().unwrap();
46 partition
47 }
48
49 pub fn new_from_function<T: Clone + Eq + std::hash::Hash>(
50 n: usize,
51 f: impl Fn(usize) -> T,
52 ) -> (Self, Vec<T>) {
53 let mut t_lookup = vec![];
54 for x in 0..n {
55 t_lookup.push(f(x));
56 }
57 let mut t_partition: IndexMap<_, Vec<usize>> = IndexMap::new();
58 #[allow(clippy::needless_range_loop)]
59 for x in 0..n {
60 let t = &t_lookup[x];
61 if t_partition.contains_key(&t) {
62 t_partition.get_mut(&t).unwrap().push(x);
63 } else {
64 t_partition.insert(t, vec![x]);
65 }
66 }
67
68 let lookup = (0..n)
69 .map(|x| t_partition.get_index_of(&t_lookup[x]).unwrap())
70 .collect();
71 let partition = t_partition
72 .iter()
73 .map(|(_t, part)| part.iter().copied().collect())
74 .collect();
75
76 let partition = Partition::new_unchecked(partition, lookup);
77 #[cfg(debug_assertions)]
78 partition.check_state().unwrap();
79 (
80 partition,
81 t_partition
82 .into_iter()
83 .map(|(t, _part)| t.clone())
84 .collect(),
85 )
86 }
87
88 pub fn project(&self, x: usize) -> usize {
89 self.lookup[x]
90 }
91
92 pub fn class_containing(&self, x: usize) -> &HashSet<usize> {
93 self.get_class(self.project(x))
94 }
95
96 pub fn get_class(&self, i: usize) -> &HashSet<usize> {
97 &self.partition[i]
98 }
99
100 pub fn num_elements(&self) -> usize {
101 self.lookup.len()
102 }
103
104 pub fn num_classes(&self) -> usize {
105 self.partition.len()
106 }
107
108 pub fn size(&self) -> usize {
109 self.partition.len()
110 }
111}
112
/// One position of the restricted growth string maintained by
/// [`LexicographicPartitionsNumPartsInRange`].
#[derive(Debug, Clone)]
pub struct Element {
    /// Class label assigned to this position.
    x: usize,
    /// Largest label used at this position or any earlier one.
    cum_x: usize,
    /// `true` iff this position introduces a new label, i.e. `x` is one more
    /// than every earlier label. Always `true` for position 0.
    pivot: bool,
}

/// Iterator over the set partitions of `{0, ..., n - 1}` whose number of
/// classes lies in `min_x..=max_x`.
///
/// Each partition is yielded as a restricted growth string. This is a `Vec`
/// whose entry `i` is the class label of element `i`. Labels are numbered in
/// order of first appearance, so entry 0 is 0 and each entry is at most one
/// more than the maximum of the entries before it. Partitions are yielded in
/// lexicographic order of these strings.
#[derive(Debug, Clone)]
pub struct LexicographicPartitionsNumPartsInRange {
    /// Number of elements being partitioned.
    n: usize,
    /// Minimum number of classes (inclusive).
    min_x: usize,
    /// Maximum number of classes (inclusive).
    max_x: usize,
    /// The current string; always has length `n`.
    elements: Vec<Element>,
    /// Set once every partition has been yielded, or when none exist.
    finished: bool,
}

impl LexicographicPartitionsNumPartsInRange {
    /// Assert that the current state is a valid restricted growth string with
    /// between `min_x` and `max_x` classes.
    ///
    /// Only compiled in debug builds; a violated invariant is a bug and panics.
    #[cfg(debug_assertions)]
    fn check(&self) {
        if self.finished {
            return;
        }
        assert_eq!(self.elements.len(), self.n);
        let (first, rest) = match self.elements.split_first() {
            Some(split) => split,
            // The empty string: nothing to check position-wise.
            None => return,
        };
        assert_eq!(first.x, 0);
        assert_eq!(first.cum_x, 0);
        assert!(first.pivot);
        let mut cum_max = 0;
        for e in rest {
            if e.x <= cum_max {
                assert_eq!(e.cum_x, cum_max);
                assert!(!e.pivot);
            } else {
                assert_eq!(e.x, cum_max + 1, "restricted growth string skips a label");
                cum_max += 1;
                assert_eq!(e.cum_x, cum_max);
                assert!(e.pivot);
            }
        }
        // Labels run from 0 to cum_max, so there are cum_max + 1 classes.
        let num_classes = cum_max + 1;
        assert!(self.min_x <= num_classes);
        assert!(num_classes <= self.max_x);
    }

    /// Create an iterator over the set partitions of `{0, ..., n - 1}` with at
    /// least `min_x` and at most `max_x` classes.
    ///
    /// If no such partition exists the iterator is empty. For `n == 0` the
    /// single empty partition (0 classes) is yielded iff `min_x == 0`.
    pub fn new(n: usize, min_x: usize, max_x: usize) -> Self {
        let elements = (0..n)
            .map(|i| Element {
                x: 0,
                cum_x: 0,
                pivot: i == 0,
            })
            .collect();
        // No partition exists in three cases. There may be fewer elements
        // than the required number of classes. The range of class counts may
        // be empty. Or a non-empty set may be asked to split into zero
        // classes. (`n == 0 && min_x > 0` is already covered by `n < min_x`.)
        let finished = n < min_x || min_x > max_x || (n > 0 && max_x == 0);
        let mut s = Self {
            n,
            min_x,
            max_x,
            elements,
            finished,
        };
        if n > 0 {
            // Start from the lexicographically smallest valid string.
            s.reset_tail(0);
        }
        s
    }

    /// Reset positions `j + 1..n` to the lexicographically smallest suffix that
    /// still reaches `min_x` classes, given the prefix up to and including `j`.
    ///
    /// The suffix is all zeros, except that, when needed, its final positions
    /// introduce the new labels `cum_max + 1, ..., min_x - 1` in order.
    fn reset_tail(&mut self, j: usize) {
        let cum_max_j = self.elements[j].cum_x;
        for i in (j + 1)..self.n {
            // Number of positions from `i` to the end, inclusive.
            let remaining = self.n - i;
            // The label this position must carry so that the last `remaining`
            // positions bring the class count up to `min_x`. It is used only
            // when it is a genuinely new label.
            let x = match self.min_x.checked_sub(remaining) {
                Some(needed) if needed > cum_max_j => needed,
                _ => 0,
            };
            self.elements[i] = Element {
                x,
                cum_x: if x == 0 { cum_max_j } else { x },
                pivot: x != 0,
            };
        }
        #[cfg(debug_assertions)]
        self.check();
    }

    /// Move `elements` to the lexicographically next valid string.
    ///
    /// Returns `false`, leaving `elements` untouched, if the current string is
    /// the last one.
    fn advance(&mut self) -> bool {
        let max_x = self.max_x;
        for i in (0..self.n).rev() {
            let e = &mut self.elements[i];
            // A pivot already carries the largest label allowed at its
            // position. No label may reach `max_x`, since labels are 0-based.
            if e.pivot || e.x + 1 >= max_x {
                continue;
            }
            // Non-pivot positions always satisfy `x <= cum_x`. Raising the
            // label past the running maximum opens a new class here.
            if e.x == e.cum_x {
                e.cum_x += 1;
                e.pivot = true;
            }
            e.x += 1;
            self.reset_tail(i);
            return true;
        }
        false
    }
}

impl Iterator for LexicographicPartitionsNumPartsInRange {
    type Item = Vec<usize>;

    /// Yield the current partition, then advance to the next one in
    /// lexicographic order. If none remains, the iterator is marked finished.
    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        let current = self.elements.iter().map(|e| e.x).collect();
        if !self.advance() {
            self.finished = true;
        }
        Some(current)
    }
}
245
246pub fn set_partitions_eq(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
247 LexicographicPartitionsNumPartsInRange::new(n, x, x)
248}
249
250pub fn set_partitions_le(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
251 LexicographicPartitionsNumPartsInRange::new(n, 0, x)
252}
253
254pub fn set_partitions_ge(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
255 LexicographicPartitionsNumPartsInRange::new(n, x, n)
256}
257
258pub fn set_partitions_range(
259 n: usize,
260 min_x: usize,
261 max_x: usize,
262) -> impl Iterator<Item = Vec<usize>> {
263 LexicographicPartitionsNumPartsInRange::new(n, min_x, max_x)
264}
265
266pub fn set_compositions_eq(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
267 (0..x).permutations(x).flat_map(move |perm| {
268 set_partitions_eq(n, x)
269 .map(move |partition| partition.into_iter().map(|i| perm[i]).collect())
270 })
271}
272
#[cfg(test)]
mod partition_tests {
    use super::*;

    /// Build a `Partition` straight from raw class lists and a lookup table.
    /// This bypasses every constructor, so invalid states can be tested.
    fn raw_partition(classes: &[&[usize]], lookup: &[usize]) -> Partition {
        Partition {
            partition: classes
                .iter()
                .map(|class| class.iter().copied().collect())
                .collect(),
            lookup: lookup.to_vec(),
        }
    }

    #[test]
    fn partition_check_good_state() {
        let p = raw_partition(&[&[0, 1, 2], &[3, 4, 5]], &[0, 0, 0, 1, 1, 1]);
        assert!(p.check_state().is_ok());
    }

    #[test]
    fn partition_check_bad_state() {
        // Elements 1 and 4 belong to no class.
        let p = raw_partition(&[&[0, 2], &[3, 5]], &[0, 0, 0, 1, 1, 1]);
        assert!(p.check_state().is_err());

        // Elements 2 and 3 belong to both classes.
        let p = raw_partition(&[&[0, 1, 2, 3], &[2, 3, 4, 5]], &[0, 0, 0, 0, 1, 1]);
        assert!(p.check_state().is_err());

        // `lookup` refers to a class index that does not exist.
        let p = raw_partition(&[&[0, 1, 2], &[3, 4, 5]], &[0, 0, 0, 1, 1, 2]);
        assert!(p.check_state().is_err());

        // `lookup` disagrees with the classes for element 2.
        let p = raw_partition(&[&[0, 1, 2], &[3, 4, 5]], &[0, 0, 1, 1, 1, 1]);
        assert!(p.check_state().is_err());
    }

    #[test]
    fn from_function() {
        let (p, ts) = Partition::new_from_function(6, |x| x % 2);
        assert_eq!(p.num_elements(), 6);
        assert_eq!(p.num_classes(), 2);
        // Classes are numbered by first appearance, so element 0 is in class 0.
        assert_eq!(ts, vec![0, 1]);
        assert_eq!(p.project(0), 0);
        assert_eq!(p.project(1), 1);
        let mut odds: Vec<usize> = p.class_containing(3).iter().copied().collect();
        odds.sort_unstable();
        assert_eq!(odds, vec![1, 3, 5]);
    }

    #[test]
    fn generate_set_partitions() {
        // (n, min_x, max_x, number of partitions with between min_x and max_x classes)
        let cases: &[(usize, usize, usize, usize)] = &[
            (0, 0, 0, 1),
            (0, 1, 1, 0),
            (0, 2, 2, 0),
            (0, 3, 3, 0),
            (1, 0, 0, 0),
            (1, 1, 1, 1),
            (1, 2, 2, 0),
            (1, 3, 3, 0),
            (2, 0, 0, 0),
            (2, 1, 1, 1),
            (2, 2, 2, 1),
            (2, 3, 3, 0),
            (3, 0, 0, 0),
            (3, 1, 1, 1),
            (3, 2, 2, 3),
            (3, 3, 3, 1),
            // An empty range of class counts.
            (4, 5, 3, 0),
            // Bell numbers: every partition of 4 and of 5 elements.
            (4, 0, 4, 15),
            (5, 1, 5, 52),
            // Stirling numbers of the second kind: S(5, 2) + S(5, 3) = 15 + 25.
            (5, 2, 3, 40),
        ];
        for &(n, min_x, max_x, expected) in cases {
            let count = LexicographicPartitionsNumPartsInRange::new(n, min_x, max_x).count();
            assert_eq!(
                count, expected,
                "n = {}, min_x = {}, max_x = {}",
                n, min_x, max_x
            );
        }
    }

    #[test]
    fn set_partitions_are_lexicographic() {
        let got: Vec<Vec<usize>> = set_partitions_eq(3, 2).collect();
        assert_eq!(got, vec![vec![0, 0, 1], vec![0, 1, 0], vec![0, 1, 1]]);
    }

    #[test]
    fn compositions_are_surjections() {
        // Every map from 3 elements onto 2 labels: 2^3 - 2 = 6 of them.
        let mut got: Vec<Vec<usize>> = set_compositions_eq(3, 2).collect();
        got.sort();
        let expected = vec![
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![0, 1, 1],
            vec![1, 0, 0],
            vec![1, 0, 1],
            vec![1, 1, 0],
        ];
        assert_eq!(got, expected);
    }
}