algebraeon_sets/combinatorics/set_partitions.rs

1use indexmap::IndexMap;
2use itertools::Itertools;
3use std::collections::HashSet;
4
5#[derive(Clone, Debug)]
6pub struct Partition {
7    partition: Vec<HashSet<usize>>, // the partition
8    lookup: Vec<usize>,             // for each element, the index of its part
9}
10
11impl Partition {
12    #[cfg(any(debug_assertions, test))]
13    fn check_state(&self) -> Result<(), &'static str> {
14        use std::collections::HashMap;
15        let mut present = HashMap::new();
16        let n = self.lookup.len();
17        for (idx, part) in self.partition.iter().enumerate() {
18            if part.is_empty() {
19                return Err("Partition contains an empty part");
20            }
21            for &x in part {
22                if n <= x {
23                    return Err("Partition contains element which is too big");
24                }
25                if present.contains_key(&x) {
26                    return Err("Duplicate element in partition");
27                }
28                present.insert(x, idx);
29            }
30        }
31        for x in 0..n {
32            if !present.contains_key(&x) {
33                return Err("Missing element from partition");
34            }
35            if present.get(&x).unwrap() != &self.lookup[x] {
36                return Err("Incorrect entry in lookup");
37            }
38        }
39        Ok(())
40    }
41
42    pub fn new_unchecked(partition: Vec<HashSet<usize>>, lookup: Vec<usize>) -> Self {
43        let partition = Self { partition, lookup };
44        #[cfg(debug_assertions)]
45        partition.check_state().unwrap();
46        partition
47    }
48
49    pub fn new_from_function<T: Clone + Eq + std::hash::Hash>(
50        n: usize,
51        f: impl Fn(usize) -> T,
52    ) -> (Self, Vec<T>) {
53        let mut t_lookup = vec![];
54        for x in 0..n {
55            t_lookup.push(f(x));
56        }
57        let mut t_partition: IndexMap<_, Vec<usize>> = IndexMap::new();
58        #[allow(clippy::needless_range_loop)]
59        for x in 0..n {
60            let t = &t_lookup[x];
61            if t_partition.contains_key(&t) {
62                t_partition.get_mut(&t).unwrap().push(x);
63            } else {
64                t_partition.insert(t, vec![x]);
65            }
66        }
67
68        let lookup = (0..n)
69            .map(|x| t_partition.get_index_of(&t_lookup[x]).unwrap())
70            .collect();
71        let partition = t_partition
72            .iter()
73            .map(|(_t, part)| part.iter().copied().collect())
74            .collect();
75
76        let partition = Partition::new_unchecked(partition, lookup);
77        #[cfg(debug_assertions)]
78        partition.check_state().unwrap();
79        (
80            partition,
81            t_partition
82                .into_iter()
83                .map(|(t, _part)| t.clone())
84                .collect(),
85        )
86    }
87
88    pub fn project(&self, x: usize) -> usize {
89        self.lookup[x]
90    }
91
92    pub fn class_containing(&self, x: usize) -> &HashSet<usize> {
93        self.get_class(self.project(x))
94    }
95
96    pub fn get_class(&self, i: usize) -> &HashSet<usize> {
97        &self.partition[i]
98    }
99
100    pub fn num_elements(&self) -> usize {
101        self.lookup.len()
102    }
103
104    pub fn num_classes(&self) -> usize {
105        self.partition.len()
106    }
107
108    pub fn size(&self) -> usize {
109        self.partition.len()
110    }
111}
112
/// Per-element state of [`LexicographicPartitionsNumPartsInRange`].
#[derive(Debug, Clone)]
pub struct Element {
    /// Index of the part this element is assigned to.
    x: usize,
    /// Largest part index used by this element or any element before it.
    cum_x: usize,
    /// Whether this element opens a new part: the first element always does, and any
    /// later element does exactly when `x` is one more than the largest part index
    /// used before it.
    pivot: bool,
}

/// Iterator over the set partitions of `{0, …, n - 1}` whose number of parts lies in
/// `min_x..=max_x`.
///
/// Each partition is yielded as a `Vec<usize>` of length `n` assigning every element
/// the index of its part. Parts are numbered in order of first appearance, so each
/// partition is produced exactly once. Partitions are produced in increasing
/// lexicographic order of these vectors. Once exhausted, the iterator keeps returning
/// `None`.
#[derive(Debug, Clone)]
pub struct LexicographicPartitionsNumPartsInRange {
    /// Number of elements in the set being partitioned.
    n: usize,
    /// Minimum number of parts (inclusive).
    min_x: usize,
    /// Maximum number of parts (inclusive).
    max_x: usize,
    /// The partition that `next` will yield; always has length `n`.
    elements: Vec<Element>,
    /// Set once every partition has been yielded (or when none exist).
    finished: bool,
}

impl LexicographicPartitionsNumPartsInRange {
    /// Assert the internal invariants (debug builds only).
    ///
    /// # Panics
    /// If the pending partition is not numbered in order of first appearance, if the
    /// bookkeeping fields disagree with it, or if its number of parts is outside
    /// `min_x..=max_x`.
    #[cfg(debug_assertions)]
    fn check(&self) {
        if self.finished {
            return;
        }
        assert_eq!(self.elements.len(), self.n);
        assert_eq!(self.elements[0].x, 0);
        assert_eq!(self.elements[0].cum_x, 0);
        assert!(self.elements[0].pivot);
        let mut cum_max = 0;
        for element in &self.elements[1..] {
            if element.x <= cum_max {
                assert_eq!(element.cum_x, cum_max);
                assert!(!element.pivot);
            } else if element.x == cum_max + 1 {
                cum_max += 1;
                assert_eq!(element.cum_x, cum_max);
                assert!(element.pivot);
            } else {
                panic!(
                    "element assigned to part {} but only parts 0..={} are available",
                    element.x,
                    cum_max + 1
                );
            }
        }
        let num_parts = cum_max + 1;
        assert!(self.min_x <= num_parts);
        assert!(num_parts <= self.max_x);
    }

    /// Iterate over the partitions of an `n`-element set into between `min_x` and
    /// `max_x` parts (inclusive).
    pub fn new(n: usize, min_x: usize, max_x: usize) -> Self {
        let elements = (0..n)
            .map(|i| Element {
                x: 0,
                cum_x: 0,
                pivot: i == 0,
            })
            .collect();
        // No partition exists when the range of part counts is empty or cannot be met:
        // the empty set has exactly zero parts, a non-empty set has between 1 and `n`.
        let finished =
            (n == 0 && min_x > 0) || (n > 0 && max_x == 0) || n < min_x || min_x > max_x;
        let mut s = Self {
            n,
            min_x,
            max_x,
            elements,
            finished,
        };
        if n > 0 {
            s.reset_tail(0);
        }
        s
    }

    /// Overwrite elements `j + 1..n` with the lexicographically smallest assignment
    /// that keeps elements `0..=j` and still reaches at least `min_x` parts.
    ///
    /// Each trailing element goes in part `0` unless it is forced to open a new part;
    /// the forced elements are the last ones, which open parts `…, min_x - 2,
    /// min_x - 1` in turn.
    fn reset_tail(&mut self, j: usize) {
        let cum_max_j = self.elements[j].cum_x;
        for i in (j + 1)..self.n {
            // Number of elements from `i` to the end, inclusive.
            let rev_i = self.n - i;
            // If the remaining `rev_i` elements are only just enough to reach `min_x`
            // parts, element `i` must take part `min_x - rev_i` — unless that part is
            // already in use, in which case part 0 still leaves enough room.
            let x = if rev_i <= self.min_x {
                let x = self.min_x - rev_i;
                if x > cum_max_j { x } else { 0 }
            } else {
                0
            };
            self.elements[i] = Element {
                x,
                cum_x: if x == 0 { cum_max_j } else { x },
                pivot: x != 0,
            };
        }
        #[cfg(debug_assertions)]
        self.check();
    }
}

impl Iterator for LexicographicPartitionsNumPartsInRange {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        let current: Vec<usize> = self.elements.iter().map(|e| e.x).collect();
        // The successor bumps the rightmost element that can move to a higher part
        // without exceeding `max_x` parts. Pivots are excluded: raising one would
        // skip a part number.
        let bump = (0..self.n).rev().find(|&i| {
            let e = &self.elements[i];
            !e.pivot && e.x + 1 < self.max_x && e.x <= e.cum_x
        });
        match bump {
            Some(i) => {
                let element = &mut self.elements[i];
                if element.x == element.cum_x {
                    // Moving past the largest part used so far opens a new part.
                    element.cum_x += 1;
                    element.pivot = true;
                }
                element.x += 1;
                self.reset_tail(i);
            }
            None => self.finished = true,
        }
        Some(current)
    }
}

impl std::iter::FusedIterator for LexicographicPartitionsNumPartsInRange {}
245
246pub fn set_partitions_eq(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
247    LexicographicPartitionsNumPartsInRange::new(n, x, x)
248}
249
250pub fn set_partitions_le(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
251    LexicographicPartitionsNumPartsInRange::new(n, 0, x)
252}
253
254pub fn set_partitions_ge(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
255    LexicographicPartitionsNumPartsInRange::new(n, x, n)
256}
257
258pub fn set_partitions_range(
259    n: usize,
260    min_x: usize,
261    max_x: usize,
262) -> impl Iterator<Item = Vec<usize>> {
263    LexicographicPartitionsNumPartsInRange::new(n, min_x, max_x)
264}
265
266pub fn set_compositions_eq(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
267    (0..x).permutations(x).flat_map(move |perm| {
268        set_partitions_eq(n, x)
269            .map(move |partition| partition.into_iter().map(|i| perm[i]).collect())
270    })
271}
272
#[cfg(test)]
mod partition_tests {
    use super::*;

    #[test]
    fn partition_check_bad_state() {
        // A valid partition passes.
        let p = Partition {
            partition: vec![
                vec![0, 2, 4].into_iter().collect(),
                vec![1, 3, 5].into_iter().collect(),
            ],
            lookup: vec![0, 1, 0, 1, 0, 1],
        };
        assert_eq!(p.check_state(), Ok(()));

        // Not a covering set: elements 1 and 4 are in no class.
        let p = Partition {
            partition: vec![
                vec![0, 2].into_iter().collect(),
                vec![3, 5].into_iter().collect(),
            ],
            lookup: vec![0, 0, 0, 1, 1, 1],
        };
        assert_eq!(p.check_state(), Err("Missing element from partition"));

        // Not disjoint: elements 2 and 3 are in both classes.
        let p = Partition {
            partition: vec![
                vec![0, 1, 2, 3].into_iter().collect(),
                vec![2, 3, 4, 5].into_iter().collect(),
            ],
            lookup: vec![0, 0, 0, 0, 1, 1],
        };
        assert_eq!(p.check_state(), Err("Duplicate element in partition"));

        // Lookup value refers to a class that does not exist.
        let p = Partition {
            partition: vec![
                vec![0, 1, 2].into_iter().collect(),
                vec![3, 4, 5].into_iter().collect(),
            ],
            lookup: vec![0, 0, 0, 1, 1, 2],
        };
        assert_eq!(p.check_state(), Err("Incorrect entry in lookup"));

        // Lookup value disagrees with the classes.
        let p = Partition {
            partition: vec![
                vec![0, 1, 2].into_iter().collect(),
                vec![3, 4, 5].into_iter().collect(),
            ],
            lookup: vec![0, 0, 1, 1, 1, 1],
        };
        assert_eq!(p.check_state(), Err("Incorrect entry in lookup"));
    }

    #[test]
    fn from_function() {
        let (p, ts) = Partition::new_from_function(6, |x| x % 2);
        println!("p = {:?}", p);
        assert_eq!(p.num_elements(), 6);
        assert_eq!(p.num_classes(), 2);
        // Classes are numbered in order of first appearance.
        assert_eq!(ts, vec![0, 1]);
        assert_eq!(p.project(0), 0);
        assert_eq!(p.project(1), 1);
        assert_eq!(p.project(4), 0);
        let evens: HashSet<usize> = vec![0, 2, 4].into_iter().collect();
        assert_eq!(p.class_containing(2), &evens);
    }

    #[test]
    fn generate_set_partitions() {
        // (n, x, number of partitions of an n-set into exactly x parts)
        let expected = [
            (0, 0, 1),
            (0, 1, 0),
            (0, 2, 0),
            (0, 3, 0),
            (1, 0, 0),
            (1, 1, 1),
            (1, 2, 0),
            (1, 3, 0),
            (2, 0, 0),
            (2, 1, 1),
            (2, 2, 1),
            (2, 3, 0),
            (3, 0, 0),
            (3, 1, 1),
            (3, 2, 3),
            (3, 3, 1),
        ];
        for (n, x, count) in expected {
            assert_eq!(
                LexicographicPartitionsNumPartsInRange::new(n, x, x).count(),
                count,
                "n = {n}, x = {x}"
            );
        }

        // Empty range of part counts.
        assert_eq!(
            LexicographicPartitionsNumPartsInRange::new(4, 5, 3).count(),
            0
        );

        // All partitions of a 5-set (Bell number B_5).
        assert_eq!(set_partitions_le(5, 5).count(), 52);
        assert_eq!(set_partitions_ge(5, 1).count(), 52);
        assert_eq!(set_partitions_range(5, 2, 3).count(), 40);
    }

    #[test]
    fn set_partitions_are_in_lexicographic_order() {
        assert_eq!(
            set_partitions_eq(3, 2).collect::<Vec<_>>(),
            vec![vec![0, 0, 1], vec![0, 1, 0], vec![0, 1, 1]]
        );
    }

    #[test]
    fn generate_set_compositions() {
        // x! * S(n, x)
        assert_eq!(set_compositions_eq(3, 2).count(), 6);
        assert_eq!(set_compositions_eq(3, 3).count(), 6);
        assert_eq!(set_compositions_eq(2, 5).count(), 0);
        assert_eq!(set_compositions_eq(3, 0).count(), 0);
    }
}