Skip to main content

diskann_quantization/
num.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Number types with limited dynamic range.
7
8use std::{fmt::Debug, num::NonZeroUsize};
9
10use thiserror::Error;
11
/// A wrapper around a number that is known to be strictly greater than zero.
///
/// "Zero" here means `T::default()`: a value is accepted only if it compares greater
/// than the default value of its type. The wrapper is `#[repr(transparent)]`, so it has
/// the same layout as `T`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(transparent)]
pub struct Positive<T>(T)
where
    T: PartialOrd + Default + Debug;
18
19#[derive(Debug, Clone, Copy, Error)]
20#[error("value {:?} is not greater than {:?} (its default value)", .0, T::default())]
21pub struct NotPositiveError<T: Debug + Default>(T);
22
23impl<T> Positive<T>
24where
25    T: PartialOrd + Default + Debug,
26{
27    /// Create a new `Positive` if the given value is greater than 0 (`T::default()`);
28    pub fn new(value: T) -> Result<Self, NotPositiveError<T>> {
29        if value > T::default() {
30            Ok(Self(value))
31        } else {
32            Err(NotPositiveError(value))
33        }
34    }
35
36    /// Create a new `Positive` without checking whether the value is greater than 0.
37    ///
38    /// # Safety
39    ///
40    /// The value must be greater than `T::default()`.
41    pub const unsafe fn new_unchecked(value: T) -> Self {
42        Self(value)
43    }
44
45    /// Consume `self` and return the inner value.
46    pub fn into_inner(self) -> T {
47        self.0
48    }
49}
50
51// SAFETY: 1.0 is positive.
52pub(crate) const POSITIVE_ONE_F32: Positive<f32> = unsafe { Positive::new_unchecked(1.0) };
53
/// A `usize` that is guaranteed to be a power of two (and therefore non-zero).
///
/// Useful for alignment arithmetic. Construct one with [`PowerOfTwo::new`] (checked),
/// [`PowerOfTwo::next`] (round up), [`PowerOfTwo::from_align`] or
/// [`PowerOfTwo::alignment_of`].
///
/// Equality is exact, so `Eq` and `Hash` are derived alongside `PartialEq`. This lets
/// values be used as keys in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct PowerOfTwo(NonZeroUsize);
57
58#[derive(Debug, Clone, Copy, Error)]
59#[error("value {0} must be a power of two")]
60#[non_exhaustive]
61pub struct NotPowerOfTwo(usize);
62
63impl PowerOfTwo {
64    /// Create a new `PowerOfTwo` if the given value is greater a power of two.
65    pub const fn new(value: usize) -> Result<Self, NotPowerOfTwo> {
66        let v = match NonZeroUsize::new(value) {
67            Some(value) => value,
68            None => return Err(NotPowerOfTwo(value)),
69        };
70        if v.is_power_of_two() {
71            // Safety: We just checked.
72            Ok(unsafe { Self::new_unchecked(v) })
73        } else {
74            Err(NotPowerOfTwo(value))
75        }
76    }
77
78    /// Return the smallest power of two greater than or equal to `value`. If the next
79    /// power of two is greater than `usize::MAX`, `None` is returned.
80    pub const fn next(value: usize) -> Option<Self> {
81        // Note: use `match` instead of `Option::map` for `const`-compatibility.
82        match value.checked_next_power_of_two() {
83            // SAFETY: We trust the implementation of `usize::checked_next_power_of_two` since:
84            //
85            // * 0 can never be a power of two and thus cannot be returned.
86            // * If it succeeds, the returned value should be a power of two.
87            Some(v) => Some(unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(v)) }),
88            None => None,
89        }
90    }
91
92    /// Create a new `PowerOfTwo` without checking whether the value is a power of two.
93    ///
94    /// # Safety
95    ///
96    /// The value must be a power of two.
97    pub const unsafe fn new_unchecked(value: NonZeroUsize) -> Self {
98        Self(value)
99    }
100
101    /// Consume `self` and return the inner value.
102    pub const fn into_inner(self) -> NonZeroUsize {
103        self.0
104    }
105
106    /// Consume `self` and return the inner value as a `usize`.
107    pub const fn raw(self) -> usize {
108        self.0.get()
109    }
110
111    /// Construct `self` from the alignment in `layout`.
112    pub const fn from_align(layout: &std::alloc::Layout) -> Self {
113        // SAFETY: Alignment is guaranteed to be a power of two:
114        // - <https://doc.rust-lang.org/beta/std/alloc/struct.Layout.html#method.align>
115        unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(layout.align())) }
116    }
117
118    /// Return the alignment of `T` as a power of two.
119    pub const fn alignment_of<T>() -> Self {
120        // SAFETY: Alignment is guaranteed to be a power of two:
121        // - <https://doc.rust-lang.org/beta/std/alloc/struct.Layout.html#method.align>
122        unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(std::mem::align_of::<T>())) }
123    }
124
125    /// Compute the operation `lhs % self`.
126    ///
127    /// # Note
128    ///
129    /// The argument order of this function is reversed from the typical `align_offset`
130    /// method in the standard library.
131    pub const fn arg_mod(self, lhs: usize) -> usize {
132        lhs & (self.raw() - 1)
133    }
134
135    /// Compute the amount `x` that would have to be added to `lhs` so the quantity
136    /// `lhs + x` is a multiple of `self`.
137    ///
138    /// # Note
139    ///
140    /// The argument order of this function is reversed from the typical `align_offset`
141    /// method in the standard library.
142    pub const fn arg_align_offset(self, lhs: usize) -> usize {
143        let m = self.arg_mod(lhs);
144        if m == 0 {
145            0
146        } else {
147            self.raw() - m
148        }
149    }
150
151    /// Calculate the smallest value greater than or equal to `lhs` that is a multiple of
152    /// `self`. Return `None` if the operation would result in an overflow.
153    pub const fn arg_checked_next_multiple_of(self, lhs: usize) -> Option<usize> {
154        let offset = self.arg_align_offset(lhs);
155        lhs.checked_add(offset)
156    }
157}
158
159impl From<PowerOfTwo> for usize {
160    #[inline(always)]
161    fn from(value: PowerOfTwo) -> Self {
162        value.raw()
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the message that `NotPositiveError` is expected to display for `value`.
    fn format_not_positive_error<T>(value: T) -> String
    where
        T: Debug + Default,
    {
        format!(
            "value {:?} is not greater than {:?} (its default value)",
            value,
            T::default(),
        )
    }

    #[test]
    fn test_positive_f32() {
        let x = Positive::<f32>::new(1.0);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1.0);

        // Using 0 should return an error.
        let x = Positive::<f32>::new(0.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(0.0)
        );

        // Using -1 should return an error.
        let x = Positive::<f32>::new(-1.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(-1.0)
        );

        // SAFETY: 1.0 is greater than zero.
        let x = unsafe { Positive::<f32>::new_unchecked(1.0) };
        assert_eq!(x.into_inner(), 1.0);
    }

    #[test]
    fn test_positive_i64() {
        let x = Positive::<i64>::new(1);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1);

        // Using 0 should return an error.
        let x = Positive::<i64>::new(0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(0)
        );

        // Using -1 should return an error.
        let x = Positive::<i64>::new(-1);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(-1)
        );

        // SAFETY: 1 is greater than zero.
        let x = unsafe { Positive::<i64>::new_unchecked(1) };
        assert_eq!(x.into_inner(), 1);
    }

    #[test]
    fn test_power_of_two() {
        assert!(PowerOfTwo::new(0).is_err());
        assert_eq!(PowerOfTwo::next(0).unwrap(), PowerOfTwo::new(1).unwrap());

        // The largest power of two representable in a `usize`. It is derived from
        // `usize::BITS` rather than hard-coded as `2^63`, so the test also holds on
        // 32-bit targets, where `2usize.pow(32)` and above would overflow.
        let largest = 1usize << (usize::BITS - 1);

        // `largest` itself is excluded from the loop because the body computes
        // `2 * base - 1`, which would overflow for `base == largest`.
        for i in 0..(usize::BITS - 1) {
            let base = 2usize.pow(i);
            let p = PowerOfTwo::new(base).unwrap();
            assert_eq!(p.into_inner().get(), base);
            assert_eq!(p.raw(), base);
            assert_eq!(PowerOfTwo::new(base).unwrap().raw(), base);
            assert_eq!(<_ as Into<usize>>::into(p), base);

            if i != 1 {
                assert!(PowerOfTwo::new(base - 1).is_err(), "failed for i = {}", i);
                assert_eq!(PowerOfTwo::next(base - 1).unwrap().raw(), base);
            }

            if i != 0 {
                assert!(PowerOfTwo::new(base + 1).is_err(), "failed for i = {}", i);
            }

            assert_eq!(p.arg_mod(0), 0);
            assert_eq!(p.arg_mod(p.raw()), 0);

            assert_eq!(p.arg_align_offset(0), 0);
            assert_eq!(p.arg_align_offset(base), 0);

            assert_eq!(p.arg_checked_next_multiple_of(0), Some(0));
            assert_eq!(p.arg_checked_next_multiple_of(base), Some(base));

            assert_eq!(p.arg_checked_next_multiple_of(1), Some(base));

            if i > 1 {
                assert_eq!(p.arg_mod(base + 1), 1);
                assert_eq!(p.arg_mod(2 * base - 1), base - 1);

                assert_eq!(p.arg_align_offset(base + 1), base - 1);
                assert_eq!(p.arg_align_offset(2 * base - 1), 1);

                assert_eq!(p.arg_checked_next_multiple_of(base + 1), Some(2 * base));
                assert_eq!(p.arg_checked_next_multiple_of(2 * base - 1), Some(2 * base));
            }
        }

        // The largest power of two is still constructible and reachable via `next` ...
        let p = PowerOfTwo::new(largest).unwrap();
        assert_eq!(p.raw(), largest);
        assert_eq!(PowerOfTwo::next(largest).unwrap(), p);
        assert_eq!(PowerOfTwo::next(largest - 1).unwrap(), p);

        // ... but nothing above it can be rounded up to a power of two.
        assert!(PowerOfTwo::next(largest + 1).is_none());
        assert!(PowerOfTwo::next(usize::MAX).is_none());
    }
}