Skip to main content

diskann_quantization/
num.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Number types with limited dynamic range.
7
8use std::{fmt::Debug, num::NonZeroUsize};
9
10use thiserror::Error;
11
/// A wrapper around a `T` that is meant to compare strictly greater than `T::default()`
/// (i.e. greater than zero for the built-in numeric types).
///
/// Build one with [`Positive::new`], which checks the invariant, or with the `unsafe`
/// [`Positive::new_unchecked`], which leaves the check to the caller.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(transparent)]
pub struct Positive<T: PartialOrd + Default + Debug>(T);
18
/// Error returned by [`Positive::new`] when the supplied value does not compare strictly
/// greater than `T::default()`.
///
/// The rejected value is kept in the (private) tuple field and is rendered, together
/// with `T::default()`, in the `Display` message generated by the `#[error]` attribute.
#[derive(Debug, Clone, Copy, Error)]
#[error("value {:?} is not greater than {:?} (its default value)", .0, T::default())]
pub struct NotPositiveError<T: Debug + Default>(T);
22
23impl<T> Positive<T>
24where
25    T: PartialOrd + Default + Debug,
26{
27    /// Create a new `Positive` if the given value is greater than 0 (`T::default()`);
28    pub fn new(value: T) -> Result<Self, NotPositiveError<T>> {
29        if value > T::default() {
30            Ok(Self(value))
31        } else {
32            Err(NotPositiveError(value))
33        }
34    }
35
36    /// Create a new `Positive` without checking whether the value is greater than 0.
37    ///
38    /// # Safety
39    ///
40    /// The value must be greater than `T::default()`.
41    pub const unsafe fn new_unchecked(value: T) -> Self {
42        Self(value)
43    }
44
45    /// Consume `self` and return the inner value.
46    pub fn into_inner(self) -> T {
47        self.0
48    }
49}
50
51// SAFETY: 1.0 is positive.
52pub(crate) const POSITIVE_ONE_F32: Positive<f32> = unsafe { Positive::new_unchecked(1.0) };
53
/// A `usize` that is a non-zero power of two.
///
/// Safe construction (`new`, `next`, `from_align`, `alignment_of`, the `P*`/`V*`
/// constants) always upholds the invariant; `new_unchecked` is `unsafe` and leaves it to
/// the caller. Because the value is a power of two, modulus and rounding operations can
/// be done with bit masks (see [`PowerOfTwo::arg_mod`]).
///
/// Ordering, equality and hashing follow the underlying numeric value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct PowerOfTwo(NonZeroUsize);
57
/// Error returned by [`PowerOfTwo::new`] when the given value is zero or is not a power
/// of two.
///
/// The rejected value is kept in the (private) tuple field and appears in the `Display`
/// message generated by the `#[error]` attribute.
#[derive(Debug, Clone, Copy, Error)]
#[error("value {0} must be a power of two")]
#[non_exhaustive]
pub struct NotPowerOfTwo(usize);
62
/// Generate a pair of documented associated constants for each `(Pk, Vn) => k` entry:
///
/// * `Pk` is the power of two `1 << k`, built through the `const` constructor
///   `Self::new`. A non-power-of-two value hits the `panic!` during const evaluation,
///   which surfaces as a compile error wherever the constant is evaluated.
/// * `Vn` is an alias of `Pk`, named after its numeric value `n`.
///
/// The second arm accepts a comma-separated list (trailing comma optional) and expands
/// each entry through the first arm.
macro_rules! constants {
    (($pow:ident, $value:ident) => $shift:literal) => {
        #[doc = concat!("The power of two `1 << ", stringify!($shift), "`.")]
        pub const $pow: Self = match Self::new(1 << $shift) {
            Ok(v) => v,
            Err(_) => panic!("not a power of two"),
        };

        #[doc = concat!("Alias for [`Self::", stringify!($pow), "`].")]
        pub const $value: Self = Self::$pow;
    };
    ($(($pow:ident, $value:ident) => $shift:literal),+ $(,)?) => {
        $(constants!(($pow, $value) => $shift);)+
    };
}
76
// Powers of two `2^0` through `2^31`. Each `(Pk, Vn) => k` entry defines `Pk = 1 << k`
// and the alias `Vn` (named after the numeric value). `1 << 31` must fit in a `usize`,
// so this block assumes a pointer width of at least 32 bits.
impl PowerOfTwo {
    constants! {
        (P0, V1) => 0,
        (P1, V2) => 1,
        (P2, V4) => 2,
        (P3, V8) => 3,
        (P4, V16) => 4,
        (P5, V32) => 5,
        (P6, V64) => 6,
        (P7, V128) => 7,
        (P8, V256) => 8,
        (P9, V512) => 9,
        (P10, V1024) => 10,
        (P11, V2048) => 11,
        (P12, V4096) => 12,
        (P13, V8192) => 13,
        (P14, V16384) => 14,
        (P15, V32768) => 15,
        (P16, V65536) => 16,
        (P17, V131072) => 17,
        (P18, V262144) => 18,
        (P19, V524288) => 19,
        (P20, V1048576) => 20,
        (P21, V2097152) => 21,
        (P22, V4194304) => 22,
        (P23, V8388608) => 23,
        (P24, V16777216) => 24,
        (P25, V33554432) => 25,
        (P26, V67108864) => 26,
        (P27, V134217728) => 27,
        (P28, V268435456) => 28,
        (P29, V536870912) => 29,
        (P30, V1073741824) => 30,
        (P31, V2147483648) => 31,
    }
}
113
// Powers of two `2^32` through `2^63`. These values only fit in a 64-bit `usize`, hence
// the `target_pointer_width` gate.
#[cfg(target_pointer_width = "64")]
impl PowerOfTwo {
    constants! {
        (P32, V4294967296) => 32,
        (P33, V8589934592) => 33,
        (P34, V17179869184) => 34,
        (P35, V34359738368) => 35,
        (P36, V68719476736) => 36,
        (P37, V137438953472) => 37,
        (P38, V274877906944) => 38,
        (P39, V549755813888) => 39,
        (P40, V1099511627776) => 40,
        (P41, V2199023255552) => 41,
        (P42, V4398046511104) => 42,
        (P43, V8796093022208) => 43,
        (P44, V17592186044416) => 44,
        (P45, V35184372088832) => 45,
        (P46, V70368744177664) => 46,
        (P47, V140737488355328) => 47,
        (P48, V281474976710656) => 48,
        (P49, V562949953421312) => 49,
        (P50, V1125899906842624) => 50,
        (P51, V2251799813685248) => 51,
        (P52, V4503599627370496) => 52,
        (P53, V9007199254740992) => 53,
        (P54, V18014398509481984) => 54,
        (P55, V36028797018963968) => 55,
        (P56, V72057594037927936) => 56,
        (P57, V144115188075855872) => 57,
        (P58, V288230376151711744) => 58,
        (P59, V576460752303423488) => 59,
        (P60, V1152921504606846976) => 60,
        (P61, V2305843009213693952) => 61,
        (P62, V4611686018427387904) => 62,
        (P63, V9223372036854775808) => 63,
    }
}
151
152impl PowerOfTwo {
153    /// Create a new `PowerOfTwo` if the given value is greater a power of two.
154    pub const fn new(value: usize) -> Result<Self, NotPowerOfTwo> {
155        let v = match NonZeroUsize::new(value) {
156            Some(value) => value,
157            None => return Err(NotPowerOfTwo(value)),
158        };
159        if v.is_power_of_two() {
160            // Safety: We just checked.
161            Ok(unsafe { Self::new_unchecked(v) })
162        } else {
163            Err(NotPowerOfTwo(value))
164        }
165    }
166
167    /// Return the smallest power of two greater than or equal to `value`. If the next
168    /// power of two is greater than `usize::MAX`, `None` is returned.
169    pub const fn next(value: usize) -> Option<Self> {
170        // Note: use `match` instead of `Option::map` for `const`-compatibility.
171        match value.checked_next_power_of_two() {
172            // SAFETY: We trust the implementation of `usize::checked_next_power_of_two` since:
173            //
174            // * 0 can never be a power of two and thus cannot be returned.
175            // * If it succeeds, the returned value should be a power of two.
176            Some(v) => Some(unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(v)) }),
177            None => None,
178        }
179    }
180
181    /// Create a new `PowerOfTwo` without checking whether the value is a power of two.
182    ///
183    /// # Safety
184    ///
185    /// The value must be a power of two.
186    pub const unsafe fn new_unchecked(value: NonZeroUsize) -> Self {
187        Self(value)
188    }
189
190    /// Consume `self` and return the inner value.
191    pub const fn into_inner(self) -> NonZeroUsize {
192        self.0
193    }
194
195    /// Consume `self` and return the inner value as a `usize`.
196    pub const fn raw(self) -> usize {
197        self.0.get()
198    }
199
200    /// Construct `self` from the alignment in `layout`.
201    pub const fn from_align(layout: &std::alloc::Layout) -> Self {
202        // SAFETY: Alignment is guaranteed to be a power of two:
203        // - <https://doc.rust-lang.org/beta/std/alloc/struct.Layout.html#method.align>
204        unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(layout.align())) }
205    }
206
207    /// Return the alignment of `T` as a power of two.
208    pub const fn alignment_of<T>() -> Self {
209        // SAFETY: Alignment is guaranteed to be a power of two:
210        // - <https://doc.rust-lang.org/beta/std/alloc/struct.Layout.html#method.align>
211        unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(std::mem::align_of::<T>())) }
212    }
213
214    /// Compute the operation `lhs % self`.
215    ///
216    /// # Note
217    ///
218    /// The argument order of this function is reversed from the typical `align_offset`
219    /// method in the standard library.
220    pub const fn arg_mod(self, lhs: usize) -> usize {
221        lhs & (self.raw() - 1)
222    }
223
224    /// Compute the amount `x` that would have to be added to `lhs` so the quantity
225    /// `lhs + x` is a multiple of `self`.
226    ///
227    /// # Note
228    ///
229    /// The argument order of this function is reversed from the typical `align_offset`
230    /// method in the standard library.
231    pub const fn arg_align_offset(self, lhs: usize) -> usize {
232        let m = self.arg_mod(lhs);
233        if m == 0 { 0 } else { self.raw() - m }
234    }
235
236    /// Calculate the smallest value greater than or equal to `lhs` that is a multiple of
237    /// `self`. Return `None` if the operation would result in an overflow.
238    pub const fn arg_checked_next_multiple_of(self, lhs: usize) -> Option<usize> {
239        let offset = self.arg_align_offset(lhs);
240        lhs.checked_add(offset)
241    }
242}
243
244impl From<PowerOfTwo> for usize {
245    #[inline(always)]
246    fn from(value: PowerOfTwo) -> Self {
247        value.raw()
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the message `NotPositiveError` is expected to display for `value`.
    fn format_not_positive_error<T>(value: T) -> String
    where
        T: Debug + Default,
    {
        format!(
            "value {:?} is not greater than {:?} (its default value)",
            value,
            T::default(),
        )
    }

    #[test]
    fn test_positive_f32() {
        let x = Positive::<f32>::new(1.0);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1.0);

        // Using 0 should return an error.
        let x = Positive::<f32>::new(0.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(0.0)
        );

        // Using -1 should return an error.
        let x = Positive::<f32>::new(-1.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(-1.0)
        );

        // NaN compares false against everything, so it must be rejected too.
        assert!(Positive::<f32>::new(f32::NAN).is_err());

        // SAFETY: 1.0 is greater than zero.
        let x = unsafe { Positive::<f32>::new_unchecked(1.0) };
        assert_eq!(x.into_inner(), 1.0);
    }

    #[test]
    fn test_positive_i64() {
        let x = Positive::<i64>::new(1);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1);

        // Using 0 should return an error.
        let x = Positive::<i64>::new(0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(0)
        );

        // Using -1 should return an error.
        let x = Positive::<i64>::new(-1);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(-1)
        );

        // SAFETY: 1 is greater than zero.
        let x = unsafe { Positive::<i64>::new_unchecked(1) };
        assert_eq!(x.into_inner(), 1);
    }

    #[test]
    fn test_positive_one_f32() {
        assert_eq!(POSITIVE_ONE_F32.into_inner(), 1.0);
    }

    #[test]
    fn test_power_of_two() {
        assert!(PowerOfTwo::new(0).is_err());
        assert_eq!(PowerOfTwo::next(0).unwrap(), PowerOfTwo::new(1).unwrap());
        assert_eq!(
            PowerOfTwo::new(3).unwrap_err().to_string(),
            "value 3 must be a power of two"
        );

        // Iterate over every power of two except the top bit: the `i > 1` checks below
        // compute `2 * base`, which would overflow for `base == 1 << (usize::BITS - 1)`.
        // Using `usize::BITS` (instead of a hard-coded 63) keeps the test valid on
        // 32-bit targets. The top bit is covered after the loop.
        for i in 0..usize::BITS - 1 {
            let base = 2usize.pow(i);
            let p = PowerOfTwo::new(base).unwrap();
            assert_eq!(p.into_inner().get(), base);
            assert_eq!(p.raw(), base);
            assert_eq!(PowerOfTwo::new(base).unwrap().raw(), base);
            assert_eq!(<_ as Into<usize>>::into(p), base);

            // For i == 1, `base - 1 == 1` is itself a power of two.
            if i != 1 {
                assert!(PowerOfTwo::new(base - 1).is_err(), "failed for i = {}", i);
                assert_eq!(PowerOfTwo::next(base - 1).unwrap().raw(), base);
            }

            // For i == 0, `base + 1 == 2` is itself a power of two.
            if i != 0 {
                assert!(PowerOfTwo::new(base + 1).is_err(), "failed for i = {}", i);
            }

            assert_eq!(p.arg_mod(0), 0);
            assert_eq!(p.arg_mod(p.raw()), 0);

            assert_eq!(p.arg_align_offset(0), 0);
            assert_eq!(p.arg_align_offset(base), 0);

            assert_eq!(p.arg_checked_next_multiple_of(0), Some(0));
            assert_eq!(p.arg_checked_next_multiple_of(base), Some(base));

            assert_eq!(p.arg_checked_next_multiple_of(1), Some(base));

            if i > 1 {
                assert_eq!(p.arg_mod(base + 1), 1);
                assert_eq!(p.arg_mod(2 * base - 1), base - 1);

                assert_eq!(p.arg_align_offset(base + 1), base - 1);
                assert_eq!(p.arg_align_offset(2 * base - 1), 1);

                assert_eq!(p.arg_checked_next_multiple_of(base + 1), Some(2 * base));
                assert_eq!(p.arg_checked_next_multiple_of(2 * base - 1), Some(2 * base));
            }
        }

        // The largest representable power of two for this target's pointer width.
        let top = 1usize << (usize::BITS - 1);
        assert_eq!(PowerOfTwo::new(top).unwrap().raw(), top);
        assert_eq!(PowerOfTwo::next(top).unwrap().raw(), top);
        assert!(PowerOfTwo::next(top + 1).is_none());
        assert!(PowerOfTwo::next(usize::MAX).is_none());
    }

    #[test]
    fn test_power_of_two_constants() {
        assert_eq!(PowerOfTwo::P0.raw(), 1);
        assert_eq!(PowerOfTwo::V1, PowerOfTwo::P0);
        assert_eq!(PowerOfTwo::P10, PowerOfTwo::V1024);
        assert_eq!(PowerOfTwo::V1024.raw(), 1024);
        assert_eq!(PowerOfTwo::P31.raw(), 1usize << 31);

        #[cfg(target_pointer_width = "64")]
        {
            assert_eq!(PowerOfTwo::V4294967296.raw(), 1usize << 32);
            assert_eq!(PowerOfTwo::P63.raw(), 1usize << 63);
        }
    }

    #[test]
    fn test_power_of_two_alignment() {
        assert_eq!(PowerOfTwo::alignment_of::<u8>().raw(), 1);
        assert_eq!(
            PowerOfTwo::alignment_of::<u64>().raw(),
            std::mem::align_of::<u64>()
        );

        let layout = std::alloc::Layout::from_size_align(24, 8).unwrap();
        assert_eq!(PowerOfTwo::from_align(&layout).raw(), 8);
    }
}