1use indexmap::IndexMap;
2use std::collections::HashSet;
3
/// A partition of the elements `0..n` (with `n == lookup.len()`) into classes.
///
/// The partition is stored in both directions:
/// `partition[i]` is the set of elements belonging to class `i`, and
/// `lookup[x]` is the index of the class that contains element `x`.
/// A consistent value has non-empty, pairwise-disjoint classes covering every
/// element of `0..n`, and `lookup` agrees with `partition`; in debug builds
/// `Partition::new_unchecked` verifies this.
#[derive(Clone, Debug)]
pub struct Partition {
    /// The classes, each a set of element indices in `0..lookup.len()`.
    partition: Vec<HashSet<usize>>,
    /// `lookup[x]` is the index into `partition` of the class holding `x`.
    lookup: Vec<usize>,
}
9
10impl Partition {
11 #[cfg(any(debug_assertions, test))]
12 fn check_state(&self) -> Result<(), &'static str> {
13 use std::collections::HashMap;
14 let mut present = HashMap::new();
15 let n = self.lookup.len();
16 for (idx, part) in self.partition.iter().enumerate() {
17 if part.len() == 0 {
18 return Err("Partition contains an empty part");
19 }
20 for &x in part {
21 if n <= x {
22 return Err("Partition contains element which is too big");
23 }
24 if present.contains_key(&x) {
25 return Err("Duplicate element in partition");
26 }
27 present.insert(x, idx);
28 }
29 }
30 for x in 0..n {
31 if !present.contains_key(&x) {
32 return Err("Missing element from partition");
33 }
34 if present.get(&x).unwrap() != &self.lookup[x] {
35 return Err("Incorrect entry in lookup");
36 }
37 }
38 Ok(())
39 }
40
41 pub fn new_unchecked(partition: Vec<HashSet<usize>>, lookup: Vec<usize>) -> Self {
42 let partition = Self { partition, lookup };
43 #[cfg(debug_assertions)]
44 partition.check_state().unwrap();
45 partition
46 }
47
48 pub fn new_from_function<T: Clone + Eq + std::hash::Hash>(
49 n: usize,
50 f: impl Fn(usize) -> T,
51 ) -> (Self, Vec<T>) {
52 let mut t_lookup = vec![];
53 for x in 0..n {
54 t_lookup.push(f(x));
55 }
56 let mut t_partition = IndexMap::new();
57 for x in 0..n {
58 let t = &t_lookup[x];
59 if !t_partition.contains_key(&t) {
60 t_partition.insert(t, vec![x]);
61 } else {
62 t_partition.get_mut(&t).unwrap().push(x)
63 }
64 }
65
66 let lookup = (0..n)
67 .map(|x| t_partition.get_index_of(&t_lookup[x]).unwrap())
68 .collect();
69 let partition = t_partition
70 .iter()
71 .map(|(_t, part)| part.iter().cloned().collect())
72 .collect();
73
74 let partition = Partition::new_unchecked(partition, lookup);
75 #[cfg(debug_assertions)]
76 partition.check_state().unwrap();
77 (
78 partition,
79 t_partition
80 .into_iter()
81 .map(|(t, _part)| t.clone())
82 .collect(),
83 )
84 }
85
86 pub fn project(&self, x: usize) -> usize {
87 self.lookup[x]
88 }
89
90 pub fn class_containing(&self, x: usize) -> &HashSet<usize> {
91 self.get_class(self.project(x))
92 }
93
94 pub fn get_class(&self, i: usize) -> &HashSet<usize> {
95 &self.partition[i]
96 }
97
98 pub fn num_elements(&self) -> usize {
99 self.lookup.len()
100 }
101
102 pub fn num_classes(&self) -> usize {
103 self.partition.len()
104 }
105
106 pub fn size(&self) -> usize {
107 self.partition.len()
108 }
109}
110
/// State of one position of the restricted growth string maintained by
/// `LexographicPartitionsNumPartsInRange`.
#[derive(Clone, Debug)]
pub struct Element {
    /// Index of the part this position is assigned to.
    x: usize,
    /// Largest part index used at this position or at any earlier position.
    cum_x: usize,
    /// `true` when this position is the first to use part `x`, i.e. `x` is one
    /// more than the largest part index used before it (always `true` at position 0).
    pivot: bool,
}
117
118#[derive(Debug, Clone)]
119pub struct LexographicPartitionsNumPartsInRange {
120 n: usize,
122 min_x: usize,
124 max_x: usize,
125 elements: Vec<Element>,
126 finished: bool,
127}
128
129impl LexographicPartitionsNumPartsInRange {
130 #[cfg(debug_assertions)]
131 fn check(&self) -> Result<(), ()> {
132 if !self.finished {
134 assert_eq!(self.elements.len(), self.n);
135 assert_eq!(self.elements[0].x, 0);
136 assert_eq!(self.elements[0].cum_x, 0);
137 assert_eq!(self.elements[0].pivot, true);
138 let mut cum_max = 0;
139 for i in 1..self.n {
140 if self.elements[i].x <= cum_max {
141 assert_eq!(self.elements[i].cum_x, cum_max);
142 assert_eq!(self.elements[i].pivot, false);
143 } else if self.elements[i].x == cum_max + 1 {
144 cum_max += 1;
145 assert_eq!(self.elements[i].cum_x, cum_max);
146 assert_eq!(self.elements[i].pivot, true);
147 } else {
148 panic!();
149 }
150 }
151 cum_max += 1;
152 assert!(self.min_x <= cum_max);
153 assert!(cum_max <= self.max_x);
154 }
155 Ok(())
156 }
157
158 pub fn new(n: usize, min_x: usize, max_x: usize) -> Self {
159 let mut elements = vec![];
160 for i in 0..n {
161 elements.push(Element {
162 x: 0,
163 cum_x: 0,
164 pivot: i == 0,
165 })
166 }
167 let mut s = Self {
168 n,
169 min_x,
170 max_x,
171 elements,
172 finished: false,
173 };
174 if (n == 0 && min_x > 0) || (n > 0 && max_x == 0) || (n < min_x) || (min_x > max_x) {
175 s.finished = true;
176 }
177 if n > 0 {
178 s.reset_tail(0);
179 }
180 s
181 }
182
183 fn reset_tail(&mut self, j: usize) {
184 let cum_max_j = self.elements[j].cum_x;
185 for i in (j + 1)..self.n {
189 let rev_i = self.n - i;
190 let x = if rev_i <= self.min_x {
191 let x = self.min_x - rev_i;
192 if x > cum_max_j { x } else { 0 }
193 } else {
194 0
195 };
196 self.elements[i] = Element {
197 x,
198 cum_x: if x == 0 { cum_max_j } else { x },
199 pivot: x != 0,
200 };
201 }
202 #[cfg(debug_assertions)]
203 self.check().unwrap();
204 }
205}
206
207impl Iterator for LexographicPartitionsNumPartsInRange {
208 type Item = Vec<usize>;
209
210 fn next(&mut self) -> Option<Self::Item> {
211 if self.finished {
212 None
213 } else {
214 let next = (0..self.n).map(|i| self.elements[i].x).collect();
215 'SEARCH: {
216 for i in (0..self.n).rev() {
217 if !self.elements[i].pivot {
218 let max = self.elements[i].cum_x;
219 let x = &mut self.elements[i].x;
220 if *x + 1 < self.max_x {
221 if *x < max {
222 *x += 1;
223 self.reset_tail(i);
224 break 'SEARCH;
225 } else if *x == max {
226 *x += 1;
227 self.elements[i].cum_x += 1;
228 self.elements[i].pivot = true;
229 self.reset_tail(i);
230 break 'SEARCH;
231 }
232 }
233 }
234 }
235 self.finished = true;
236 }
237 Some(next)
238 }
239 }
240}
241
242pub fn set_partitions_eq(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
243 LexographicPartitionsNumPartsInRange::new(n, x, x)
244}
245
246pub fn set_partitions_le(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
247 LexographicPartitionsNumPartsInRange::new(n, 0, x)
248}
249
250pub fn set_partitions_ge(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
251 LexographicPartitionsNumPartsInRange::new(n, x, n)
252}
253
254pub fn set_partitions_range(
255 n: usize,
256 min_x: usize,
257 max_x: usize,
258) -> impl Iterator<Item = Vec<usize>> {
259 LexographicPartitionsNumPartsInRange::new(n, min_x, max_x)
260}
261
#[cfg(test)]
mod partition_tests {
    use super::*;
    use std::collections::HashSet;

    /// Stirling number of the second kind `S(n, k)`: the number of partitions
    /// of an `n`-element set into exactly `k` non-empty parts.
    fn stirling2(n: usize, k: usize) -> usize {
        // table[i][j] = S(i, j), via S(i, j) = j * S(i - 1, j) + S(i - 1, j - 1).
        let mut table = vec![vec![0usize; k + 1]; n + 1];
        table[0][0] = 1;
        for i in 1..=n {
            for j in 1..=k.min(i) {
                table[i][j] = j * table[i - 1][j] + table[i - 1][j - 1];
            }
        }
        table[n][k]
    }

    /// Number of parts encoded by `rgs`, or `None` if it is not a restricted
    /// growth string. In such a string each entry is at most one more than
    /// the maximum of the entries before it, and the first entry is `0`.
    fn rgs_num_parts(rgs: &[usize]) -> Option<usize> {
        let mut parts = 0;
        for &x in rgs {
            if x > parts {
                return None;
            }
            if x == parts {
                parts += 1;
            }
        }
        Some(parts)
    }

    #[test]
    fn partition_check_bad_state() {
        let bad_states: Vec<(Vec<Vec<usize>>, Vec<usize>)> = vec![
            // Elements 1 and 4 are in no part.
            (vec![vec![0, 2], vec![3, 5]], vec![0, 0, 0, 1, 1, 1]),
            // Elements 2 and 3 are in two parts.
            (
                vec![vec![0, 1, 2, 3], vec![2, 3, 4, 5]],
                vec![0, 0, 0, 0, 1, 1],
            ),
            // lookup sends element 5 to a part that does not exist.
            (vec![vec![0, 1, 2], vec![3, 4, 5]], vec![0, 0, 0, 1, 1, 2]),
            // lookup puts element 2 in part 1, but it is in part 0.
            (vec![vec![0, 1, 2], vec![3, 4, 5]], vec![0, 0, 1, 1, 1, 1]),
        ];
        for (parts, lookup) in bad_states {
            let p = Partition {
                partition: parts
                    .into_iter()
                    .map(|part| part.into_iter().collect())
                    .collect(),
                lookup,
            };
            assert!(p.check_state().is_err(), "expected invalid state: {:?}", p);
        }
    }

    #[test]
    fn partition_check_good_state() {
        let p = Partition::new_unchecked(
            vec![
                vec![0, 2, 4].into_iter().collect(),
                vec![1, 3].into_iter().collect(),
            ],
            vec![0, 1, 0, 1, 0],
        );
        assert!(p.check_state().is_ok());
        assert_eq!(p.num_elements(), 5);
        assert_eq!(p.num_classes(), 2);
        assert_eq!(p.size(), 2);
        assert_eq!(p.project(3), 1);
        let evens: HashSet<usize> = vec![0, 2, 4].into_iter().collect();
        assert_eq!(p.class_containing(2), &evens);
        assert_eq!(p.get_class(0), &evens);
    }

    #[test]
    fn from_function() {
        let (p, ts) = Partition::new_from_function(6, |x| x % 2);
        println!("p = {:?}", p);
        assert_eq!(p.num_elements(), 6);
        assert_eq!(p.num_classes(), 2);
        assert_eq!(ts, vec![0, 1]);
        let odds: HashSet<usize> = vec![1, 3, 5].into_iter().collect();
        assert_eq!(p.class_containing(3), &odds);

        // Classes are numbered by first occurrence, not by value.
        let (p, ts) = Partition::new_from_function(6, |x| 10 - x / 2);
        assert_eq!(ts, vec![10, 9, 8]);
        assert_eq!(
            (0..6).map(|x| p.project(x)).collect::<Vec<_>>(),
            vec![0, 0, 1, 1, 2, 2]
        );
    }

    #[test]
    fn generate_set_partitions() {
        // (n, min_x, max_x, expected number of partitions)
        let cases = [
            (0, 0, 0, 1),
            (0, 1, 1, 0),
            (0, 2, 2, 0),
            (0, 3, 3, 0),
            (1, 0, 0, 0),
            (1, 1, 1, 1),
            (1, 2, 2, 0),
            (1, 3, 3, 0),
            (2, 0, 0, 0),
            (2, 1, 1, 1),
            (2, 2, 2, 1),
            (2, 3, 3, 0),
            (3, 0, 0, 0),
            (3, 1, 1, 1),
            (3, 2, 2, 3),
            (3, 3, 3, 1),
            (4, 5, 3, 0),
        ];
        for &(n, min_x, max_x, expected) in cases.iter() {
            assert_eq!(
                LexographicPartitionsNumPartsInRange::new(n, min_x, max_x).count(),
                expected,
                "n = {}, min_x = {}, max_x = {}",
                n,
                min_x,
                max_x
            );
        }
    }

    #[test]
    fn set_partition_counts_match_stirling_numbers() {
        for n in 0..=6 {
            for lo in 0..=7 {
                assert_eq!(
                    set_partitions_eq(n, lo).count(),
                    stirling2(n, lo),
                    "eq: n = {}, x = {}",
                    n,
                    lo
                );

                let le_expected: usize = (0..=lo).map(|k| stirling2(n, k)).sum();
                assert_eq!(
                    set_partitions_le(n, lo).count(),
                    le_expected,
                    "le: n = {}, x = {}",
                    n,
                    lo
                );

                let ge_expected: usize = (lo..=n).map(|k| stirling2(n, k)).sum();
                assert_eq!(
                    set_partitions_ge(n, lo).count(),
                    ge_expected,
                    "ge: n = {}, x = {}",
                    n,
                    lo
                );

                for hi in 0..=7 {
                    let expected: usize = (lo..=hi).map(|k| stirling2(n, k)).sum();
                    assert_eq!(
                        set_partitions_range(n, lo, hi).count(),
                        expected,
                        "range: n = {}, min_x = {}, max_x = {}",
                        n,
                        lo,
                        hi
                    );
                }
            }
        }
    }

    #[test]
    fn set_partitions_are_sorted_valid_and_in_range() {
        for n in 0..=5 {
            for lo in 0..=6 {
                for hi in lo..=6 {
                    let all: Vec<Vec<usize>> = set_partitions_range(n, lo, hi).collect();
                    assert!(
                        all.windows(2).all(|w| w[0] < w[1]),
                        "not strictly increasing: n = {}, min_x = {}, max_x = {}",
                        n,
                        lo,
                        hi
                    );
                    for rgs in &all {
                        assert_eq!(rgs.len(), n);
                        let parts =
                            rgs_num_parts(rgs).expect("yielded a non restricted growth string");
                        assert!(
                            lo <= parts && parts <= hi,
                            "{:?} has {} parts, outside {}..={}",
                            rgs,
                            parts,
                            lo,
                            hi
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn exhausted_iterator_stays_exhausted() {
        let mut it = set_partitions_eq(3, 2);
        assert_eq!(it.by_ref().count(), 3);
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }
}