1use core::cmp::max;
2use serde::{Deserialize, Serialize};
3use std::{
4 fmt::Display,
5 hash::Hash,
6 ops::{BitAnd, BitOr},
7};
8
9#[derive(Serialize, Deserialize, Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
13pub struct PartitionID(u32);
14
15impl PartitionID {
16 pub fn root() -> PartitionID {
18 PartitionID(1)
19 }
20
21 pub fn parent(&self) -> PartitionID {
24 let new_id = max(1, self.0 >> 1);
25 PartitionID::new(new_id)
26 }
27
28 pub fn parent_at_level(&self, level: u32) -> PartitionID {
29 let parent = self.0 & (0xffff_ffff ^ ((1 << level) - 1));
30 PartitionID::new(parent)
31 }
32
33 pub fn children(&self) -> (PartitionID, PartitionID) {
35 let temp = self.0 << 1;
36 (PartitionID(temp), PartitionID(temp + 1))
37 }
38
39 pub fn left_child(&self) -> PartitionID {
41 let temp = self.0 << 1;
42 PartitionID(temp)
43 }
44
45 pub fn right_child(&self) -> PartitionID {
47 let temp = self.0 << 1;
48 PartitionID(temp + 1)
49 }
50
51 pub fn inplace_leftmost_descendant(&mut self, k: usize) {
53 self.0 <<= k;
54 }
55
56 pub fn inplace_rightmost_descendant(&mut self, k: usize) {
58 self.inplace_leftmost_descendant(k);
59 self.0 += (1 << k) - 1;
60 }
61
62 pub fn inplace_left_child(&mut self) {
64 self.inplace_leftmost_descendant(1);
65 }
66
67 pub fn inplace_right_child(&mut self) {
69 self.inplace_rightmost_descendant(1);
70 }
71
72 pub fn new(id: u32) -> Self {
74 debug_assert!(id != 0);
76 PartitionID(id)
77 }
78
79 pub fn level(&self) -> u8 {
81 (31 - self.0.leading_zeros()).try_into().unwrap()
83 }
84
85 pub fn is_left_child(&self) -> bool {
87 self.0 % 2 == 0
88 }
89
90 pub fn is_right_child(&self) -> bool {
92 self.0 % 2 == 1
93 }
94
95 pub fn lowest_common_ancestor(&self, other: &PartitionID) -> PartitionID {
97 let mut left = *self;
98 let mut right = *other;
99
100 let left_level = left.level();
101 let right_level = right.level();
102
103 if left_level > right_level {
104 left.0 >>= left_level - right_level;
105 }
106 if right_level > left_level {
107 right.0 >>= right_level - left_level;
108 }
109
110 while left != right {
111 left = left.parent();
112 right = right.parent();
113 }
114 left
115 }
116
117 pub fn extract_bit(&self, index: usize) -> bool {
118 let mask = 1 << index;
119 mask & self.0 > 0
120 }
121}
122
123impl Display for PartitionID {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 write!(f, "{}", self.0)
126 }
127}
128
129impl From<PartitionID> for usize {
130 fn from(s: PartitionID) -> usize {
131 s.0.try_into().unwrap()
132 }
133}
134
135impl core::fmt::Binary for PartitionID {
136 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
137 let val = self.0;
138 core::fmt::Binary::fmt(&val, f) }
140}
141
142impl BitAnd for PartitionID {
143 type Output = Self;
144
145 fn bitand(self, rhs: Self) -> Self::Output {
146 Self(self.0 & rhs.0)
147 }
148}
149
150impl BitOr for PartitionID {
151 type Output = Self;
152
153 fn bitor(self, rhs: Self) -> Self::Output {
154 Self(self.0 | rhs.0)
155 }
156}
157
#[cfg(test)]
mod tests {

    use super::PartitionID;

    #[test]
    fn parent_id() {
        let id = PartitionID::new(4);
        assert_eq!(id.parent(), PartitionID::new(2));
    }

    #[test]
    fn new_id() {
        // `new` wraps the raw value unchanged, and `1` is the root.
        let id = PartitionID::new(1);
        assert_eq!(id, PartitionID::root());
        assert_eq!(id.parent(), PartitionID::root());
        assert_eq!(PartitionID::new(42), PartitionID(42));
    }

    #[test]
    fn children_ids() {
        let id = PartitionID::new(0b0101_0101_0101_0101u32);
        assert_eq!(id.level(), 14);
        let (child0, child1) = id.children();
        assert_eq!(child0, PartitionID::new(0b1010_1010_1010_1010u32));
        assert_eq!(child1, PartitionID::new(0b1010_1010_1010_1011u32));
    }

    #[test]
    fn level() {
        let root = PartitionID::root();
        assert_eq!(root.level(), 0);
        let (child0, child1) = root.children();

        assert_eq!(child0.level(), 1);
        assert_eq!(child1.level(), 1);
    }

    #[test]
    fn root_parent() {
        let root = PartitionID::root();
        let roots_parent = root.parent();
        assert_eq!(root, roots_parent);
    }

    #[test]
    fn left_right_childs() {
        let id = PartitionID(12345);
        let (left_child, right_child) = id.children();
        assert_eq!(left_child, id.left_child());
        assert_eq!(right_child, id.right_child());
    }

    #[test]
    fn is_left_right_child() {
        // Each child must report exactly one side.
        let id = PartitionID(12345);
        let (left_child, right_child) = id.children();
        assert!(left_child.is_left_child());
        assert!(!left_child.is_right_child());
        assert!(right_child.is_right_child());
        assert!(!right_child.is_left_child());
    }

    #[test]
    fn inplace_left_child() {
        let mut id = PartitionID(12345);
        let (left_child, _) = id.children();
        id.inplace_left_child();
        assert_eq!(left_child, id);
    }

    #[test]
    fn inplace_right_child() {
        let mut id = PartitionID(12345);
        let (_, right_child) = id.children();
        id.inplace_right_child();
        assert_eq!(right_child, id);
    }

    #[test]
    fn into_usize() {
        let id = PartitionID(12345);
        let id_usize = usize::from(id);
        assert_eq!(12345, id_usize);
    }

    #[test]
    fn inplace_leftmost_descendant() {
        let id = PartitionID(1);
        let mut current = id;
        for i in 1..30 {
            let mut id = id;
            id.inplace_leftmost_descendant(i);
            assert_eq!(current.left_child(), id);
            current = current.left_child();
        }
    }

    #[test]
    fn inplace_rightmost_descendant() {
        let id = PartitionID(1);
        let mut current = id;
        for i in 1..30 {
            let mut id = id;
            id.inplace_rightmost_descendant(i);
            assert_eq!(current.right_child(), id);
            current = current.right_child();
        }
    }

    #[test]
    fn display() {
        for i in 0..100 {
            let id = PartitionID(i);
            let string = format!("{id}");
            let recast_id = PartitionID(string.parse::<u32>().unwrap());
            assert_eq!(id, recast_id);
        }
    }

    #[test]
    fn partial_eq() {
        // Equality and ordering must follow the raw id.
        for i in 1..100 {
            let id = PartitionID(i);
            assert_eq!(id, PartitionID(i));
            assert_ne!(id, PartitionID(i + 1));
            assert!(id < PartitionID(i + 1));
        }
    }

    #[test]
    fn parent_at_level() {
        let id = PartitionID::new(0xffff_ffff);
        let levels = [0, 3, 9, 15, 20];
        let results = [
            PartitionID::new(0b11111111111111111111111111111111),
            PartitionID::new(0b11111111111111111111111111111000),
            PartitionID::new(0b11111111111111111111111000000000),
            PartitionID::new(0b11111111111111111000000000000000),
            PartitionID::new(0b11111111111100000000000000000000),
        ];
        for (level, expected) in levels.iter().zip(results.iter()) {
            assert_eq!(id.parent_at_level(*level), *expected);
        }
    }

    #[test]
    fn binary_trait() {
        let id = PartitionID::new(0xffff_ffff);
        let levels = [0, 3, 9, 15, 20];
        let results = [
            "0b11111111111111111111111111111111",
            "0b11111111111111111111111111111000",
            "0b11111111111111111111111000000000",
            "0b11111111111111111000000000000000",
            "0b11111111111100000000000000000000",
        ];
        for (level, expected) in levels.iter().zip(results.iter()) {
            // Width 34 = 32 binary digits + the `0b` prefix.
            assert_eq!(format!("{:#034b}", id.parent_at_level(*level)), *expected);
        }
    }

    #[test]
    fn lowest_common_ancestor() {
        let a = PartitionID(0b1000);
        let b = PartitionID(0b1001);
        assert_eq!(a.lowest_common_ancestor(&b), b.lowest_common_ancestor(&a));

        let expected = PartitionID(0b100);
        assert_eq!(a.lowest_common_ancestor(&b), expected);

        let a = PartitionID(0b1001);
        let b = PartitionID(0b1111);
        assert_eq!(a.lowest_common_ancestor(&b), b.lowest_common_ancestor(&a));

        assert_eq!(a.lowest_common_ancestor(&b), PartitionID::root());
    }

    #[test]
    fn bitand() {
        let a = PartitionID(0b1000);
        let b = PartitionID(0b1001);
        assert_eq!(PartitionID(0b1000), a & b);
    }

    #[test]
    fn bitor() {
        let a = PartitionID(0b1000);
        let b = PartitionID(0b1001);
        assert_eq!(PartitionID(0b1001), a | b);
    }

    #[test]
    fn extract_bit() {
        let a = PartitionID(0b1001);
        assert!(a.extract_bit(0));
        assert!(!a.extract_bit(1));
        assert!(!a.extract_bit(2));
        assert!(a.extract_bit(3));
        assert!(!a.extract_bit(4));

        let a = PartitionID(0b100000000100000001000);
        assert!(!a.extract_bit(0));
        assert!(a.extract_bit(3));
        assert!(!a.extract_bit(7));
        assert!(a.extract_bit(11));
        assert!(!a.extract_bit(15));
    }
}