logo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
use super::SizeValue;
use core::num::NonZeroU64;

/// Helper type for alignment calculations.
///
/// Wraps a non-zero `u64`. Both constructors in this module only ever produce
/// powers of two, so the wrapped value can be relied upon to be one.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AlignmentValue(NonZeroU64);

impl AlignmentValue {
    /// Creates an alignment from `val`.
    ///
    /// # Panics
    ///
    /// Panics if `val` is not a power of two (which includes `0`).
    pub const fn new(val: u64) -> Self {
        if !val.is_power_of_two() {
            panic!("Alignment must be a power of 2!");
        }
        // SAFETY: This is safe since 0 is not a power of 2
        Self(unsafe { NonZeroU64::new_unchecked(val) })
    }

    /// Returns an alignment that is the smallest power of two greater than or equal to the
    /// passed in `size` (a `size` that already is a power of two is returned unchanged)
    ///
    /// # Panics
    ///
    /// Panics if the next power of two does not fit in a `u64`.
    pub const fn from_next_power_of_two_size(size: SizeValue) -> Self {
        match size.get().checked_next_power_of_two() {
            None => panic!("Overflow occured while getting the next power of 2!"),
            Some(val) => {
                // SAFETY: This is safe since a power of two is never 0
                Self(unsafe { NonZeroU64::new_unchecked(val) })
            }
        }
    }

    /// Returns the alignment as a plain `u64`
    pub const fn get(&self) -> u64 {
        self.0.get()
    }

    /// Returns the max alignment from an array of alignments
    ///
    /// # Panics
    ///
    /// Panics if `input` is empty, since there is no element to return.
    pub const fn max<const N: usize>(input: [AlignmentValue; N]) -> AlignmentValue {
        // Fail with an explicit message instead of an opaque out-of-bounds index panic.
        if N == 0 {
            panic!("Cannot get the max alignment of an empty array!");
        }

        let mut max = input[0];
        let mut i = 1;

        // Manual index loop: iterators are not usable inside a `const fn`.
        while i < N {
            if input[i].get() > max.get() {
                max = input[i];
            }

            i += 1;
        }

        max
    }

    /// Returns true if `n` is a multiple of this alignment
    pub const fn is_aligned(&self, n: u64) -> bool {
        n % self.get() == 0
    }

    /// Returns the amount of padding needed so that `n + padding` will be a multiple of this alignment
    ///
    /// Returns `0` when `n` is already aligned.
    pub const fn padding_needed_for(&self, n: u64) -> u64 {
        let r = n % self.get();
        if r > 0 {
            self.get() - r
        } else {
            0
        }
    }

    /// Will round up the given `n` so that the returned value will be a multiple of this alignment
    ///
    /// # Panics
    ///
    /// Panics if the rounded up value does not fit in a `u64`.
    pub const fn round_up(&self, n: u64) -> u64 {
        // A plain `+` would silently wrap around in builds without overflow checks and
        // hand back a value smaller than `n`; always fail loudly instead.
        match n.checked_add(self.padding_needed_for(n)) {
            Some(rounded) => rounded,
            None => panic!("Overflow occurred while rounding up to the alignment!"),
        }
    }

    /// Will round up the given `n` so that the returned value will be a multiple of this alignment
    ///
    /// # Panics
    ///
    /// Panics if the rounded up value does not fit in a `u64`.
    pub const fn round_up_size(&self, n: SizeValue) -> SizeValue {
        SizeValue::new(self.round_up(n.get()))
    }
}

#[cfg(test)]
mod test {
    use super::{AlignmentValue, SizeValue};

    /// A power of two is accepted and handed back unchanged by `get`.
    #[test]
    fn new() {
        let alignment = AlignmentValue::new(4);
        assert_eq!(4, alignment.get());
    }

    /// A value that is not a power of two is rejected.
    #[test]
    #[should_panic]
    fn new_panic() {
        let _ = AlignmentValue::new(3);
    }

    /// A size of 7 is bumped to the alignment 8.
    #[test]
    fn from_next_power_of_two_size() {
        let expected = AlignmentValue::new(8);
        let actual = AlignmentValue::from_next_power_of_two_size(SizeValue::new(7));
        assert_eq!(expected, actual);
    }

    /// The next power of two after `u64::MAX` is not representable.
    #[test]
    #[should_panic]
    fn from_next_power_of_two_size_panic() {
        let huge = SizeValue::new(u64::MAX);
        let _ = AlignmentValue::from_next_power_of_two_size(huge);
    }

    /// The largest alignment of the array is selected.
    #[test]
    fn max() {
        let candidates = [
            AlignmentValue::new(2),
            AlignmentValue::new(8),
            AlignmentValue::new(32),
        ];
        assert_eq!(AlignmentValue::new(32), AlignmentValue::max(candidates));
    }

    /// Multiples of the alignment are aligned, other values are not.
    #[test]
    fn is_aligned() {
        let eight = AlignmentValue::new(8);
        assert!(eight.is_aligned(32));
        assert!(!eight.is_aligned(9));
    }

    /// The padding is the distance to the next multiple of the alignment.
    #[test]
    fn padding_needed_for() {
        let eight = AlignmentValue::new(8);
        assert_eq!(1, eight.padding_needed_for(7));
        assert_eq!(16 - 9, eight.padding_needed_for(9));
    }

    /// Both the `u64` and the `SizeValue` flavor round up to a multiple of the alignment.
    #[test]
    fn round_up() {
        assert_eq!(24, AlignmentValue::new(8).round_up(20));

        let rounded = AlignmentValue::new(16).round_up_size(SizeValue::new(7));
        assert_eq!(SizeValue::new(16), rounded);
    }

    /// Exercises the derived `Clone`, `PartialEq` and `Debug` implementations.
    #[test]
    fn derived_traits() {
        let alignment = AlignmentValue::new(8);
        // `clone` is called on purpose, even though the type is `Copy`, to cover the derive.
        #[allow(clippy::clone_on_copy)]
        let alignment_clone = alignment.clone();

        assert!(alignment == alignment_clone);

        let debug_output = format!("{:?}", alignment);
        assert_eq!(debug_output, "AlignmentValue(8)");
    }
}