// algebraeon_sets/combinatorics/set_partitions.rs

1use indexmap::IndexMap;
2use std::collections::HashSet;
3
4#[derive(Clone, Debug)]
5pub struct Partition {
6    partition: Vec<HashSet<usize>>, // the partition
7    lookup: Vec<usize>,             // for each element, the index of its part
8}
9
10impl Partition {
11    #[cfg(any(debug_assertions, test))]
12    fn check_state(&self) -> Result<(), &'static str> {
13        use std::collections::HashMap;
14        let mut present = HashMap::new();
15        let n = self.lookup.len();
16        for (idx, part) in self.partition.iter().enumerate() {
17            if part.len() == 0 {
18                return Err("Partition contains an empty part");
19            }
20            for &x in part {
21                if n <= x {
22                    return Err("Partition contains element which is too big");
23                }
24                if present.contains_key(&x) {
25                    return Err("Duplicate element in partition");
26                }
27                present.insert(x, idx);
28            }
29        }
30        for x in 0..n {
31            if !present.contains_key(&x) {
32                return Err("Missing element from partition");
33            }
34            if present.get(&x).unwrap() != &self.lookup[x] {
35                return Err("Incorrect entry in lookup");
36            }
37        }
38        Ok(())
39    }
40
41    pub fn new_unchecked(partition: Vec<HashSet<usize>>, lookup: Vec<usize>) -> Self {
42        let partition = Self { partition, lookup };
43        #[cfg(debug_assertions)]
44        partition.check_state().unwrap();
45        partition
46    }
47
48    pub fn new_from_function<T: Clone + Eq + std::hash::Hash>(
49        n: usize,
50        f: impl Fn(usize) -> T,
51    ) -> (Self, Vec<T>) {
52        let mut t_lookup = vec![];
53        for x in 0..n {
54            t_lookup.push(f(x));
55        }
56        let mut t_partition = IndexMap::new();
57        for x in 0..n {
58            let t = &t_lookup[x];
59            if !t_partition.contains_key(&t) {
60                t_partition.insert(t, vec![x]);
61            } else {
62                t_partition.get_mut(&t).unwrap().push(x)
63            }
64        }
65
66        let lookup = (0..n)
67            .map(|x| t_partition.get_index_of(&t_lookup[x]).unwrap())
68            .collect();
69        let partition = t_partition
70            .iter()
71            .map(|(_t, part)| part.iter().cloned().collect())
72            .collect();
73
74        let partition = Partition::new_unchecked(partition, lookup);
75        #[cfg(debug_assertions)]
76        partition.check_state().unwrap();
77        (
78            partition,
79            t_partition
80                .into_iter()
81                .map(|(t, _part)| t.clone())
82                .collect(),
83        )
84    }
85
86    pub fn project(&self, x: usize) -> usize {
87        self.lookup[x]
88    }
89
90    pub fn class_containing(&self, x: usize) -> &HashSet<usize> {
91        self.get_class(self.project(x))
92    }
93
94    pub fn get_class(&self, i: usize) -> &HashSet<usize> {
95        &self.partition[i]
96    }
97
98    pub fn num_elements(&self) -> usize {
99        self.lookup.len()
100    }
101
102    pub fn num_classes(&self) -> usize {
103        self.partition.len()
104    }
105
106    pub fn size(&self) -> usize {
107        self.partition.len()
108    }
109}
110
/// Per-element state of [`LexographicPartitionsNumPartsInRange`].
#[derive(Debug, Clone)]
pub struct Element {
    /// Index of the part this element is assigned to (its restricted-growth-string entry).
    x: usize,
    /// Largest part index among this element and every element before it.
    cum_x: usize,
    /// Whether this element opens a new part, i.e. `x` exceeds every earlier entry.
    /// Element 0 is always a pivot.
    pivot: bool,
}

/// Iterator over the partitions of `{0, ..., n - 1}` whose number of parts lies in
/// `min_x..=max_x`, in lexicographic order of their restricted growth strings.
///
/// Each item is a `Vec<usize>` of length `n`. Its `i`-th entry is the index of the part
/// containing element `i`, with parts numbered in order of first appearance.
#[derive(Debug, Clone)]
pub struct LexographicPartitionsNumPartsInRange {
    /// Number of elements in the set.
    n: usize,
    /// Minimum number of parts (inclusive).
    min_x: usize,
    /// Maximum number of parts (inclusive).
    max_x: usize,
    /// Current state; the next item yielded is the `x` field of each element.
    elements: Vec<Element>,
    /// Set once every partition has been yielded, or when none exist.
    finished: bool,
}

impl LexographicPartitionsNumPartsInRange {
    /// Debug-only invariant check, panicking on violation. While not finished:
    /// - the state must be a restricted growth string;
    /// - the `cum_x`/`pivot` bookkeeping must be consistent;
    /// - the number of parts must lie in `min_x..=max_x`.
    #[cfg(debug_assertions)]
    fn check(&self) {
        if self.finished {
            return;
        }
        assert_eq!(self.elements.len(), self.n);
        let first = &self.elements[0];
        assert_eq!(first.x, 0);
        assert_eq!(first.cum_x, 0);
        assert!(first.pivot);
        let mut cum_max = 0;
        for (i, e) in self.elements.iter().enumerate().skip(1) {
            if e.x <= cum_max {
                assert_eq!(e.cum_x, cum_max);
                assert!(!e.pivot);
            } else if e.x == cum_max + 1 {
                cum_max += 1;
                assert_eq!(e.cum_x, cum_max);
                assert!(e.pivot);
            } else {
                panic!("element {i} jumps to part {} past running maximum {cum_max}", e.x);
            }
        }
        let num_parts = cum_max + 1;
        assert!(self.min_x <= num_parts);
        assert!(num_parts <= self.max_x);
    }

    /// Iterate over partitions of an `n`-element set into between `min_x` and `max_x`
    /// parts (both inclusive).
    ///
    /// Yields nothing when no such partition exists: `min_x > n`, `min_x > max_x`, or
    /// `n > 0` with `max_x == 0`. For `n == 0` and `min_x == 0` it yields the single
    /// empty partition.
    pub fn new(n: usize, min_x: usize, max_x: usize) -> Self {
        let elements = (0..n)
            .map(|i| Element {
                x: 0,
                cum_x: 0,
                pivot: i == 0,
            })
            .collect();
        // `n < min_x` also covers the `n == 0 && min_x > 0` case.
        let finished = n < min_x || min_x > max_x || (n > 0 && max_x == 0);
        let mut s = Self {
            n,
            min_x,
            max_x,
            elements,
            finished,
        };
        if n > 0 && !finished {
            // Start from the smallest valid string: zeros, then any parts forced by `min_x`.
            s.reset_tail(0);
        }
        s
    }

    /// Overwrite elements `j + 1..n` with the smallest suffix compatible with the prefix
    /// `0..=j`. The suffix is all zeros, except that its final positions open new parts
    /// `cum_x[j] + 1, ..., min_x - 1` when more parts are needed to reach `min_x`.
    fn reset_tail(&mut self, j: usize) {
        let prefix_max = self.elements[j].cum_x;
        for i in (j + 1)..self.n {
            // Number of positions from `i` to the end, inclusive.
            let remaining = self.n - i;
            // The last `k` positions are assigned `min_x - k, ..., min_x - 1`; only values
            // above the prefix maximum actually open parts, everything earlier is 0.
            let x = match self.min_x.checked_sub(remaining) {
                Some(forced) if forced > prefix_max => forced,
                _ => 0,
            };
            self.elements[i] = Element {
                x,
                cum_x: if x == 0 { prefix_max } else { x },
                pivot: x != 0,
            };
        }
        #[cfg(debug_assertions)]
        self.check();
    }

    /// Advance to the lexicographic successor. Returns `false` if there is none.
    ///
    /// Scans for the rightmost non-pivot element whose entry can grow without the part
    /// count exceeding `max_x`, increments it, and resets everything after it.
    fn advance(&mut self) -> bool {
        for i in (0..self.n).rev() {
            let max_x = self.max_x;
            let e = &mut self.elements[i];
            // A pivot already holds the largest value allowed at its position. Raising `x`
            // to `x + 1` implies at least `x + 2` parts, which must not exceed `max_x`.
            if e.pivot || e.x + 1 >= max_x {
                continue;
            }
            if e.x < e.cum_x {
                // Move into a later, already-opened part; the running maximum is unchanged.
                e.x += 1;
            } else if e.x == e.cum_x {
                // Open a brand-new part.
                e.x += 1;
                e.cum_x += 1;
                e.pivot = true;
            } else {
                // Unreachable while the invariants hold (a non-pivot has x <= cum_x).
                continue;
            }
            self.reset_tail(i);
            return true;
        }
        false
    }
}

impl Iterator for LexographicPartitionsNumPartsInRange {
    type Item = Vec<usize>;

    /// Yield the current partition, then move to the next one (or mark the iterator done).
    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        let current = self.elements.iter().map(|e| e.x).collect();
        if !self.advance() {
            self.finished = true;
        }
        Some(current)
    }
}

// Once `finished` is set it is never cleared, so `next` keeps returning `None`.
impl std::iter::FusedIterator for LexographicPartitionsNumPartsInRange {}
241
242pub fn set_partitions_eq(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
243    LexographicPartitionsNumPartsInRange::new(n, x, x)
244}
245
246pub fn set_partitions_le(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
247    LexographicPartitionsNumPartsInRange::new(n, 0, x)
248}
249
250pub fn set_partitions_ge(n: usize, x: usize) -> impl Iterator<Item = Vec<usize>> {
251    LexographicPartitionsNumPartsInRange::new(n, x, n)
252}
253
254pub fn set_partitions_range(
255    n: usize,
256    min_x: usize,
257    max_x: usize,
258) -> impl Iterator<Item = Vec<usize>> {
259    LexographicPartitionsNumPartsInRange::new(n, min_x, max_x)
260}
261
#[cfg(test)]
mod partition_tests {
    use super::*;

    /// Stirling number of the second kind: the number of partitions of an `n`-element set
    /// into exactly `k` non-empty parts.
    fn stirling2(n: usize, k: usize) -> usize {
        match (n, k) {
            (0, 0) => 1,
            (0, _) | (_, 0) => 0,
            _ => k * stirling2(n - 1, k) + stirling2(n - 1, k - 1),
        }
    }

    /// Number of partitions of an `n`-element set with between `min_x` and `max_x` parts.
    fn count_in_range(n: usize, min_x: usize, max_x: usize) -> usize {
        (min_x..=max_x).map(|k| stirling2(n, k)).sum()
    }

    /// Assert that `rgs` is a restricted growth string (each entry at most one more than
    /// the maximum of the entries before it, starting from 0) and return its part count.
    fn num_parts(rgs: &[usize]) -> usize {
        let mut parts = 0;
        for &x in rgs {
            assert!(x <= parts, "{rgs:?} is not a restricted growth string");
            if x == parts {
                parts += 1;
            }
        }
        parts
    }

    /// Build a `Partition` from raw fields, bypassing every check.
    fn raw_partition(parts: &[&[usize]], lookup: &[usize]) -> Partition {
        Partition {
            partition: parts
                .iter()
                .map(|part| part.iter().copied().collect())
                .collect(),
            lookup: lookup.to_vec(),
        }
    }

    #[test]
    fn partition_check_good_state() {
        assert!(raw_partition(&[&[0, 2], &[1]], &[0, 1, 0]).check_state().is_ok());
        assert!(raw_partition(&[], &[]).check_state().is_ok());
    }

    #[test]
    fn partition_check_bad_state() {
        // not a covering set: 1 and 4 appear in no class
        assert!(raw_partition(&[&[0, 2], &[3, 5]], &[0, 0, 0, 1, 1, 1])
            .check_state()
            .is_err());
        // not disjoint: 2 and 3 appear in both classes
        assert!(raw_partition(&[&[0, 1, 2, 3], &[2, 3, 4, 5]], &[0, 0, 0, 0, 1, 1])
            .check_state()
            .is_err());
        // lookup refers to a class that does not exist
        assert!(raw_partition(&[&[0, 1, 2], &[3, 4, 5]], &[0, 0, 0, 1, 1, 2])
            .check_state()
            .is_err());
        // lookup disagrees with the classes
        assert!(raw_partition(&[&[0, 1, 2], &[3, 4, 5]], &[0, 0, 1, 1, 1, 1])
            .check_state()
            .is_err());
        // empty class
        assert!(raw_partition(&[&[0], &[], &[1]], &[0, 2])
            .check_state()
            .is_err());
        // element outside 0..n
        assert!(raw_partition(&[&[0, 1, 7]], &[0, 0]).check_state().is_err());
    }

    #[test]
    fn from_function() {
        let (p, ts) = Partition::new_from_function(6, |x| x % 2);
        assert_eq!(p.num_elements(), 6);
        assert_eq!(p.num_classes(), 2);
        // classes are numbered in order of first appearance
        assert_eq!(ts, vec![0, 1]);
        assert_eq!(p.project(4), 0);
        assert_eq!(p.project(3), 1);
        assert_eq!(
            p.class_containing(5),
            &[1, 3, 5].into_iter().collect::<HashSet<usize>>()
        );
        assert!(p.check_state().is_ok());
    }

    #[test]
    fn from_function_empty() {
        let (p, ts) = Partition::new_from_function(0, |x| x);
        assert_eq!(p.num_elements(), 0);
        assert_eq!(p.num_classes(), 0);
        assert!(ts.is_empty());
    }

    #[test]
    fn generate_set_partitions() {
        // (n, min_x, max_x, expected number of partitions)
        let cases = [
            (0, 0, 0, 1),
            (0, 1, 1, 0),
            (0, 2, 2, 0),
            (0, 3, 3, 0),
            (1, 0, 0, 0),
            (1, 1, 1, 1),
            (1, 2, 2, 0),
            (1, 3, 3, 0),
            (2, 0, 0, 0),
            (2, 1, 1, 1),
            (2, 2, 2, 1),
            (2, 3, 3, 0),
            (3, 0, 0, 0),
            (3, 1, 1, 1),
            (3, 2, 2, 3),
            (3, 3, 3, 1),
            (4, 5, 3, 0),
        ];
        for (n, min_x, max_x, expected) in cases {
            assert_eq!(
                LexographicPartitionsNumPartsInRange::new(n, min_x, max_x).count(),
                expected,
                "n = {n}, min_x = {min_x}, max_x = {max_x}"
            );
        }
    }

    #[test]
    fn generated_partitions_are_valid_sorted_and_complete() {
        for n in 0..=6 {
            for min_x in 0..=n + 1 {
                for max_x in 0..=n + 1 {
                    let all: Vec<Vec<usize>> =
                        LexographicPartitionsNumPartsInRange::new(n, min_x, max_x).collect();
                    assert_eq!(
                        all.len(),
                        count_in_range(n, min_x, max_x),
                        "n = {n}, min_x = {min_x}, max_x = {max_x}"
                    );
                    for rgs in &all {
                        assert_eq!(rgs.len(), n);
                        let parts = num_parts(rgs);
                        assert!(min_x <= parts && parts <= max_x, "{rgs:?} has {parts} parts");
                    }
                    // strictly increasing lexicographic order, hence no duplicates
                    assert!(all.windows(2).all(|w| w[0] < w[1]));
                }
            }
        }
    }

    #[test]
    fn wrapper_functions() {
        assert_eq!(set_partitions_eq(4, 2).count(), stirling2(4, 2));
        assert_eq!(set_partitions_le(4, 2).count(), count_in_range(4, 0, 2));
        assert_eq!(set_partitions_ge(4, 3).count(), count_in_range(4, 3, 4));
        assert_eq!(set_partitions_range(5, 2, 3).count(), count_in_range(5, 2, 3));
        assert_eq!(set_partitions_le(0, 0).count(), 1);
    }
}