Skip to main content

diskann_quantization/
num.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

//! Number types with limited dynamic range.
7
8use std::{fmt::Debug, num::NonZeroUsize};
9
10use thiserror::Error;
11
/// A value of type `T` that is known to be strictly greater than `T::default()`.
///
/// For the primitive numeric types `T::default()` is zero, so this wrapper carries the
/// invariant "greater than zero". Build one with the checked [`Positive::new`] or, when the
/// caller can vouch for the value, the `unsafe` [`Positive::new_unchecked`].
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(transparent)]
pub struct Positive<T: PartialOrd + Default + Debug>(T);
18
19#[derive(Debug, Clone, Copy, Error)]
20#[error("value {:?} is not greater than {:?} (its default value)", .0, T::default())]
21pub struct NotPositiveError<T: Debug + Default>(T);
22
23impl<T> Positive<T>
24where
25    T: PartialOrd + Default + Debug,
26{
27    /// Create a new `Positive` if the given value is greater than 0 (`T::default()`);
28    pub fn new(value: T) -> Result<Self, NotPositiveError<T>> {
29        if value > T::default() {
30            Ok(Self(value))
31        } else {
32            Err(NotPositiveError(value))
33        }
34    }
35
36    /// Create a new `Positive` without checking whether the value is greater than 0.
37    ///
38    /// # Safety
39    ///
40    /// The value must be greater than `T::default()`.
41    pub const unsafe fn new_unchecked(value: T) -> Self {
42        Self(value)
43    }
44
45    /// Consume `self` and return the inner value.
46    pub fn into_inner(self) -> T {
47        self.0
48    }
49}
50
51// SAFETY: 1.0 is positive.
52pub(crate) const POSITIVE_ONE_F32: Positive<f32> = unsafe { Positive::new_unchecked(1.0) };
53
/// A `usize` that is known to be a power of two (and therefore non-zero).
///
/// This is convenient for alignments and other moduli, where `%` and rounding can be
/// computed with bit masks. Build one with [`PowerOfTwo::new`], [`PowerOfTwo::next`],
/// [`PowerOfTwo::from_align`] or [`PowerOfTwo::alignment_of`].
///
/// Because it wraps a `NonZeroUsize`, `Option<PowerOfTwo>` is the same size as `usize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct PowerOfTwo(NonZeroUsize);
57
58#[derive(Debug, Clone, Copy, Error)]
59#[error("value {0} must be a power of two")]
60#[non_exhaustive]
61pub struct NotPowerOfTwo(usize);
62
63impl PowerOfTwo {
64    /// Create a new `PowerOfTwo` if the given value is greater a power of two.
65    pub const fn new(value: usize) -> Result<Self, NotPowerOfTwo> {
66        let v = match NonZeroUsize::new(value) {
67            Some(value) => value,
68            None => return Err(NotPowerOfTwo(value)),
69        };
70        if v.is_power_of_two() {
71            // Safety: We just checked.
72            Ok(unsafe { Self::new_unchecked(v) })
73        } else {
74            Err(NotPowerOfTwo(value))
75        }
76    }
77
78    /// Return the smallest power of two greater than or equal to `value`. If the next
79    /// power of two is greater than `usize::MAX`, `None` is returned.
80    pub const fn next(value: usize) -> Option<Self> {
81        // Note: use `match` instead of `Option::map` for `const`-compatibility.
82        match value.checked_next_power_of_two() {
83            // SAFETY: We trust the implementation of `usize::checked_next_power_of_two` since:
84            //
85            // * 0 can never be a power of two and thus cannot be returned.
86            // * If it succeeds, the returned value should be a power of two.
87            Some(v) => Some(unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(v)) }),
88            None => None,
89        }
90    }
91
92    /// Create a new `PowerOfTwo` without checking whether the value is a power of two.
93    ///
94    /// # Safety
95    ///
96    /// The value must be a power of two.
97    pub const unsafe fn new_unchecked(value: NonZeroUsize) -> Self {
98        Self(value)
99    }
100
101    /// Consume `self` and return the inner value.
102    pub const fn into_inner(self) -> NonZeroUsize {
103        self.0
104    }
105
106    /// Consume `self` and return the inner value as a `usize`.
107    pub const fn raw(self) -> usize {
108        self.0.get()
109    }
110
111    /// Construct `self` from the alignment in `layout`.
112    pub const fn from_align(layout: &std::alloc::Layout) -> Self {
113        // SAFETY: Alignment is guaranteed to be a power of two:
114        // - <https://doc.rust-lang.org/beta/std/alloc/struct.Layout.html#method.align>
115        unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(layout.align())) }
116    }
117
118    /// Return the alignment of `T` as a power of two.
119    pub const fn alignment_of<T>() -> Self {
120        // SAFETY: Alignment is guaranteed to be a power of two:
121        // - <https://doc.rust-lang.org/beta/std/alloc/struct.Layout.html#method.align>
122        unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(std::mem::align_of::<T>())) }
123    }
124
125    /// Compute the operation `lhs % self`.
126    ///
127    /// # Note
128    ///
129    /// The argument order of this function is reversed from the typical `align_offset`
130    /// method in the standard library.
131    pub const fn arg_mod(self, lhs: usize) -> usize {
132        lhs & (self.raw() - 1)
133    }
134
135    /// Compute the amount `x` that would have to be added to `lhs` so the quantity
136    /// `lhs + x` is a multiple of `self`.
137    ///
138    /// # Note
139    ///
140    /// The argument order of this function is reversed from the typical `align_offset`
141    /// method in the standard library.
142    pub const fn arg_align_offset(self, lhs: usize) -> usize {
143        let m = self.arg_mod(lhs);
144        if m == 0 { 0 } else { self.raw() - m }
145    }
146
147    /// Calculate the smallest value greater than or equal to `lhs` that is a multiple of
148    /// `self`. Return `None` if the operation would result in an overflow.
149    pub const fn arg_checked_next_multiple_of(self, lhs: usize) -> Option<usize> {
150        let offset = self.arg_align_offset(lhs);
151        lhs.checked_add(offset)
152    }
153}
154
155impl From<PowerOfTwo> for usize {
156    #[inline(always)]
157    fn from(value: PowerOfTwo) -> Self {
158        value.raw()
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the message that `NotPositiveError` is expected to display for `value`.
    fn format_not_positive_error<T>(value: T) -> String
    where
        T: Debug + Default,
    {
        format!(
            "value {:?} is not greater than {:?} (its default value)",
            value,
            T::default(),
        )
    }

    #[test]
    fn test_positive_f32() {
        let x = Positive::<f32>::new(1.0);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1.0);

        // Using 0 should return an error.
        let x = Positive::<f32>::new(0.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(0.0)
        );

        // Using -1 should return an error.
        let x = Positive::<f32>::new(-1.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(-1.0)
        );

        // NaN is unordered with respect to 0.0, so the `>` check rejects it.
        assert!(Positive::<f32>::new(f32::NAN).is_err());

        // SAFETY: 1.0 is greater than zero.
        let x = unsafe { Positive::<f32>::new_unchecked(1.0) };
        assert_eq!(x.into_inner(), 1.0);
    }

    #[test]
    fn test_positive_i64() {
        let x = Positive::<i64>::new(1);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1);

        // Using 0 should return an error.
        let x = Positive::<i64>::new(0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(0)
        );

        // Using -1 should return an error.
        let x = Positive::<i64>::new(-1);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(-1)
        );

        // SAFETY: 1 is greater than zero.
        let x = unsafe { Positive::<i64>::new_unchecked(1) };
        assert_eq!(x.into_inner(), 1);
    }

    #[test]
    fn test_positive_one_f32() {
        assert_eq!(POSITIVE_ONE_F32.into_inner(), 1.0);
    }

    #[test]
    fn test_power_of_two() {
        assert!(PowerOfTwo::new(0).is_err());
        assert_eq!(PowerOfTwo::next(0).unwrap(), PowerOfTwo::new(1).unwrap());
        assert_eq!(
            PowerOfTwo::new(3).unwrap_err().to_string(),
            "value 3 must be a power of two"
        );

        // Visit every power of two except the largest, for which `2 * base - 1` below would
        // overflow. Deriving the bound from `usize::BITS` (instead of hard-coding 63) keeps
        // the test valid on targets whose pointer width is not 64 bits.
        for i in 0..usize::BITS - 1 {
            let base = 2usize.pow(i);
            let p = PowerOfTwo::new(base).unwrap();
            assert_eq!(p.into_inner().get(), base);
            assert_eq!(p.raw(), base);
            assert_eq!(PowerOfTwo::new(base).unwrap().raw(), base);
            assert_eq!(<_ as Into<usize>>::into(p), base);

            // For `i == 1`, `base - 1 == 1` is itself a power of two.
            if i != 1 {
                assert!(PowerOfTwo::new(base - 1).is_err(), "failed for i = {}", i);
                assert_eq!(PowerOfTwo::next(base - 1).unwrap().raw(), base);
            }

            // For `i == 0`, `base + 1 == 2` is itself a power of two.
            if i != 0 {
                assert!(PowerOfTwo::new(base + 1).is_err(), "failed for i = {}", i);
                // Rounding `usize::MAX` up to any multiple larger than 1 overflows.
                assert_eq!(p.arg_checked_next_multiple_of(usize::MAX), None);
            }

            assert_eq!(p.arg_mod(0), 0);
            assert_eq!(p.arg_mod(p.raw()), 0);

            assert_eq!(p.arg_align_offset(0), 0);
            assert_eq!(p.arg_align_offset(base), 0);

            assert_eq!(p.arg_checked_next_multiple_of(0), Some(0));
            assert_eq!(p.arg_checked_next_multiple_of(base), Some(base));

            assert_eq!(p.arg_checked_next_multiple_of(1), Some(base));

            if i > 1 {
                assert_eq!(p.arg_mod(base + 1), 1);
                assert_eq!(p.arg_mod(2 * base - 1), base - 1);

                assert_eq!(p.arg_align_offset(base + 1), base - 1);
                assert_eq!(p.arg_align_offset(2 * base - 1), 1);

                assert_eq!(p.arg_checked_next_multiple_of(base + 1), Some(2 * base));
                assert_eq!(p.arg_checked_next_multiple_of(2 * base - 1), Some(2 * base));
            }
        }

        // The largest representable power of two.
        let top = 1usize << (usize::BITS - 1);
        let p = PowerOfTwo::new(top).unwrap();
        assert_eq!(p.raw(), top);
        assert_eq!(PowerOfTwo::next(top - 1).unwrap(), p);
        assert_eq!(PowerOfTwo::next(top).unwrap(), p);
        assert_eq!(p.arg_checked_next_multiple_of(top + 1), None);

        assert!(PowerOfTwo::next(top + 1).is_none());
        assert!(PowerOfTwo::next(usize::MAX).is_none());
    }

    #[test]
    fn test_power_of_two_alignment() {
        assert_eq!(PowerOfTwo::alignment_of::<u8>().raw(), 1);
        assert_eq!(
            PowerOfTwo::alignment_of::<u64>().raw(),
            std::mem::align_of::<u64>()
        );

        let layout = std::alloc::Layout::from_size_align(24, 64).unwrap();
        assert_eq!(PowerOfTwo::from_align(&layout).raw(), 64);
        assert_eq!(
            PowerOfTwo::from_align(&std::alloc::Layout::new::<u32>()),
            PowerOfTwo::alignment_of::<u32>()
        );
    }
}