toolbox_rs/
partition.rs

1use core::cmp::max;
2use serde::{Deserialize, Serialize};
3use std::{
4    fmt::Display,
5    hash::Hash,
6    ops::{BitAnd, BitOr},
7};
8
/// Represents an ID in the hierarchical partition id scheme. The root has ID 1;
/// the children of an ID are obtained by shifting it to the left by one bit and
/// adding 0 (left child) or 1 (right child). The parent-child relationship can
/// thus be queried in constant time.
#[derive(Serialize, Deserialize, Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PartitionID(u32);
14
15impl PartitionID {
16    /// Returns the root id
17    pub fn root() -> PartitionID {
18        PartitionID(1)
19    }
20
21    /// Returns the parent of a given ID.
22    /// Note that the parent of the root id is always 1
23    pub fn parent(&self) -> PartitionID {
24        let new_id = max(1, self.0 >> 1);
25        PartitionID::new(new_id)
26    }
27
28    pub fn parent_at_level(&self, level: u32) -> PartitionID {
29        let parent = self.0 & (0xffff_ffff ^ ((1 << level) - 1));
30        PartitionID::new(parent)
31    }
32
33    /// Returns a left-right ordered tuple of children for a given ID
34    pub fn children(&self) -> (PartitionID, PartitionID) {
35        let temp = self.0 << 1;
36        (PartitionID(temp), PartitionID(temp + 1))
37    }
38
39    /// Returns the left child of a ID
40    pub fn left_child(&self) -> PartitionID {
41        let temp = self.0 << 1;
42        PartitionID(temp)
43    }
44
45    /// Returns the right child of a ID
46    pub fn right_child(&self) -> PartitionID {
47        let temp = self.0 << 1;
48        PartitionID(temp + 1)
49    }
50
51    /// Transform ID to its left-most descendant k levels down
52    pub fn inplace_leftmost_descendant(&mut self, k: usize) {
53        self.0 <<= k;
54    }
55
56    /// Transform ID to its right-most descendant k levels down
57    pub fn inplace_rightmost_descendant(&mut self, k: usize) {
58        self.inplace_leftmost_descendant(k);
59        self.0 += (1 << k) - 1;
60    }
61
62    /// Transform the ID into its left child
63    pub fn inplace_left_child(&mut self) {
64        self.inplace_leftmost_descendant(1);
65    }
66
67    /// Transform the ID into its right child
68    pub fn inplace_right_child(&mut self) {
69        self.inplace_rightmost_descendant(1);
70    }
71
72    /// Returns a new PartitionID from an u32
73    pub fn new(id: u32) -> Self {
74        // the id scheme is designed in a way that the number of leading zeros is always odd
75        debug_assert!(id != 0);
76        PartitionID(id)
77    }
78
79    /// The level in this scheme is defined by the the number of leading zeroes.
80    pub fn level(&self) -> u8 {
81        // magic number 31 := 32 - 1, as 1 is the root's ID
82        (31 - self.0.leading_zeros()).try_into().unwrap()
83    }
84
85    /// Returns whether the ID id a left child
86    pub fn is_left_child(&self) -> bool {
87        self.0 % 2 == 0
88    }
89
90    /// Returns whether the ID id a right child
91    pub fn is_right_child(&self) -> bool {
92        self.0 % 2 == 1
93    }
94
95    // Returns the lowest common ancestor of this and the other ID
96    pub fn lowest_common_ancestor(&self, other: &PartitionID) -> PartitionID {
97        let mut left = *self;
98        let mut right = *other;
99
100        let left_level = left.level();
101        let right_level = right.level();
102
103        if left_level > right_level {
104            left.0 >>= left_level - right_level;
105        }
106        if right_level > left_level {
107            right.0 >>= right_level - left_level;
108        }
109
110        while left != right {
111            left = left.parent();
112            right = right.parent();
113        }
114        left
115    }
116
117    pub fn extract_bit(&self, index: usize) -> bool {
118        let mask = 1 << index;
119        mask & self.0 > 0
120    }
121}
122
123impl Display for PartitionID {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        write!(f, "{}", self.0)
126    }
127}
128
129impl From<PartitionID> for usize {
130    fn from(s: PartitionID) -> usize {
131        s.0.try_into().unwrap()
132    }
133}
134
135impl core::fmt::Binary for PartitionID {
136    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
137        let val = self.0;
138        core::fmt::Binary::fmt(&val, f) // delegate to u32's implementation
139    }
140}
141
142impl BitAnd for PartitionID {
143    type Output = Self;
144
145    fn bitand(self, rhs: Self) -> Self::Output {
146        Self(self.0 & rhs.0)
147    }
148}
149
150impl BitOr for PartitionID {
151    type Output = Self;
152
153    fn bitor(self, rhs: Self) -> Self::Output {
154        Self(self.0 | rhs.0)
155    }
156}
157
#[cfg(test)]
mod tests {

    use crate::partition::PartitionID;

    /// The parent is obtained by dropping the lowest bit.
    #[test]
    fn parent_id() {
        let id = PartitionID::new(4);
        assert_eq!(id.parent(), PartitionID::new(2));
    }

    /// An id constructed from 1 is the root and thus its own parent.
    #[test]
    fn new_id() {
        let id = PartitionID::new(1);
        assert_eq!(id, PartitionID::root());
        assert_eq!(id.parent(), PartitionID::root());
    }

    #[test]
    fn children_ids() {
        let id = PartitionID::new(0b0101_0101_0101_0101u32);
        assert_eq!(id.level(), 14);
        let (child0, child1) = id.children();
        assert_eq!(child0, PartitionID::new(0b1010_1010_1010_1010u32));
        assert_eq!(child1, PartitionID::new(0b1010_1010_1010_1011u32));
    }

    #[test]
    fn level() {
        let root = PartitionID::root();
        assert_eq!(root.level(), 0);

        let (child0, child1) = root.children();
        assert_eq!(child0.level(), 1);
        assert_eq!(child1.level(), 1);
    }

    #[test]
    fn root_parent() {
        let root = PartitionID::root();
        let roots_parent = root.parent();
        assert_eq!(root, roots_parent);
    }

    /// `children` has to agree with the single-child accessors.
    #[test]
    fn left_right_childs() {
        let id = PartitionID(12345);
        let (left_child, right_child) = id.children();
        assert_eq!(left_child, id.left_child());
        assert_eq!(right_child, id.right_child());
    }

    /// Left and right children are told apart by their lowest bit.
    #[test]
    fn is_left_right_child() {
        let id = PartitionID(12345);
        let (left_child, right_child) = id.children();
        assert!(left_child.is_left_child());
        assert!(!left_child.is_right_child());
        assert!(right_child.is_right_child());
        assert!(!right_child.is_left_child());
    }

    #[test]
    fn inplace_left_child() {
        let mut id = PartitionID(12345);
        let (left_child, _) = id.children();
        id.inplace_left_child();
        assert_eq!(left_child, id);
    }

    #[test]
    fn inplace_right_child() {
        let mut id = PartitionID(12345);
        let (_, right_child) = id.children();
        id.inplace_right_child();
        assert_eq!(right_child, id);
    }

    #[test]
    fn into_usize() {
        let id = PartitionID(12345);
        let id_usize = usize::from(id);
        assert_eq!(12345, id_usize);
    }

    /// Descending k levels at once equals k single steps to the left.
    #[test]
    fn inplace_leftmost_descendant() {
        let root = PartitionID(1);
        let mut stepwise = root;
        for k in 1..30 {
            let mut jumped = root;
            jumped.inplace_leftmost_descendant(k);
            stepwise = stepwise.left_child();
            assert_eq!(stepwise, jumped);
        }
    }

    /// Descending k levels at once equals k single steps to the right.
    #[test]
    fn inplace_rightmost_descendant() {
        let root = PartitionID(1);
        let mut stepwise = root;
        for k in 1..30 {
            let mut jumped = root;
            jumped.inplace_rightmost_descendant(k);
            stepwise = stepwise.right_child();
            assert_eq!(stepwise, jumped);
        }
    }

    /// The `Display` output is the decimal raw value and thus parses back.
    #[test]
    fn display() {
        for i in 0..100 {
            let id = PartitionID(i);
            let string = format!("{id}");
            let recast_id = PartitionID(string.parse::<u32>().unwrap());
            assert_eq!(id, recast_id);
        }
    }

    /// Ids compare equal if and only if their raw values are equal.
    #[test]
    fn partial_eq() {
        for i in 0..100 {
            let id = PartitionID(i);
            assert_eq!(id, PartitionID(i));
            assert_ne!(id, PartitionID(i + 1));
        }
    }

    #[test]
    fn parent_at_level() {
        let id = PartitionID::new(0xffff_ffff);
        let levels = [0, 3, 9, 15, 20];
        let results = [
            PartitionID::new(0b11111111111111111111111111111111),
            PartitionID::new(0b11111111111111111111111111111000),
            PartitionID::new(0b11111111111111111111111000000000),
            PartitionID::new(0b11111111111111111000000000000000),
            PartitionID::new(0b11111111111100000000000000000000),
        ];
        for (level, expected) in levels.iter().zip(results.iter()) {
            assert_eq!(id.parent_at_level(*level), *expected);
        }
    }

    #[test]
    fn binary_trait() {
        let id = PartitionID::new(0xffff_ffff);
        let levels = [0, 3, 9, 15, 20];
        let results = [
            "0b11111111111111111111111111111111",
            "0b11111111111111111111111111111000",
            "0b11111111111111111111111000000000",
            "0b11111111111111111000000000000000",
            "0b11111111111100000000000000000000",
        ];
        for (level, expected) in levels.iter().zip(results.iter()) {
            assert_eq!(format!("{:#032b}", id.parent_at_level(*level)), *expected);
        }
    }

    #[test]
    fn lowest_common_ancestor() {
        // siblings meet at their direct parent
        let a = PartitionID(0b1000);
        let b = PartitionID(0b1001);
        assert_eq!(a.lowest_common_ancestor(&b), b.lowest_common_ancestor(&a));

        let expected = PartitionID(0b100);
        assert_eq!(a.lowest_common_ancestor(&b), expected);

        // ids from different subtrees of the root meet at the root
        let a = PartitionID(0b1001);
        let b = PartitionID(0b1111);
        assert_eq!(a.lowest_common_ancestor(&b), b.lowest_common_ancestor(&a));

        assert_eq!(a.lowest_common_ancestor(&b), PartitionID::root());
    }

    #[test]
    fn bitand() {
        let a = PartitionID(0b1000);
        let b = PartitionID(0b1001);
        assert_eq!(PartitionID(0b1000), a & b);
    }

    #[test]
    fn bitor() {
        let a = PartitionID(0b1000);
        let b = PartitionID(0b1001);
        assert_eq!(PartitionID(0b1001), a | b);
    }

    #[test]
    fn extract_bit() {
        let a = PartitionID(0b1001);
        assert!(a.extract_bit(0));
        assert!(!a.extract_bit(1));
        assert!(!a.extract_bit(2));
        assert!(a.extract_bit(3));
        assert!(!a.extract_bit(4));

        let a = PartitionID(0b100000000100000001000);
        // bits 3, 11 and 20 are set; 0, 7 and 15 are probes for unset bits
        assert!(!a.extract_bit(0));
        assert!(a.extract_bit(3));
        assert!(!a.extract_bit(7));
        assert!(a.extract_bit(11));
        assert!(!a.extract_bit(15));
        assert!(a.extract_bit(20));
    }
}