diskann_quantization/
num.rs1use std::{fmt::Debug, num::NonZeroUsize};
9
10use thiserror::Error;
11
/// A value of type `T` that is strictly greater than `T::default()`.
///
/// For the numeric types this wrapper is used with (e.g. `f32`, `i64`), `T::default()` is
/// zero, so the wrapper encodes "strictly positive". Build one with [`Positive::new`], which
/// validates the invariant, or with the `unsafe` [`Positive::new_unchecked`].
///
/// `#[repr(transparent)]` gives the wrapper exactly the layout of `T`.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Positive<T>(T)
where
    T: PartialOrd + Default + Debug;

19#[derive(Debug, Clone, Copy, Error)]
20#[error("value {:?} is not greater than {:?} (its default value)", .0, T::default())]
21pub struct NotPositiveError<T: Debug + Default>(T);
22
23impl<T> Positive<T>
24where
25 T: PartialOrd + Default + Debug,
26{
27 pub fn new(value: T) -> Result<Self, NotPositiveError<T>> {
29 if value > T::default() {
30 Ok(Self(value))
31 } else {
32 Err(NotPositiveError(value))
33 }
34 }
35
36 pub const unsafe fn new_unchecked(value: T) -> Self {
42 Self(value)
43 }
44
45 pub fn into_inner(self) -> T {
47 self.0
48 }
49}
50
51pub(crate) const POSITIVE_ONE_F32: Positive<f32> = unsafe { Positive::new_unchecked(1.0) };
53
/// A `usize` that is guaranteed to be a non-zero power of two.
///
/// Backed by a [`NonZeroUsize`]. `Eq` and `Hash` are derived alongside `PartialEq` (the inner
/// `NonZeroUsize` supports both) so the type can be used as a hash-map/set key and in
/// `Eq`-bounded generic code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct PowerOfTwo(NonZeroUsize);

58#[derive(Debug, Clone, Copy, Error)]
59#[error("value {0} must be a power of two")]
60#[non_exhaustive]
61pub struct NotPowerOfTwo(usize);
62
63impl PowerOfTwo {
64 pub const fn new(value: usize) -> Result<Self, NotPowerOfTwo> {
66 let v = match NonZeroUsize::new(value) {
67 Some(value) => value,
68 None => return Err(NotPowerOfTwo(value)),
69 };
70 if v.is_power_of_two() {
71 Ok(unsafe { Self::new_unchecked(v) })
73 } else {
74 Err(NotPowerOfTwo(value))
75 }
76 }
77
78 pub const fn next(value: usize) -> Option<Self> {
81 match value.checked_next_power_of_two() {
83 Some(v) => Some(unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(v)) }),
88 None => None,
89 }
90 }
91
92 pub const unsafe fn new_unchecked(value: NonZeroUsize) -> Self {
98 Self(value)
99 }
100
101 pub const fn into_inner(self) -> NonZeroUsize {
103 self.0
104 }
105
106 pub const fn raw(self) -> usize {
108 self.0.get()
109 }
110
111 pub const fn from_align(layout: &std::alloc::Layout) -> Self {
113 unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(layout.align())) }
116 }
117
118 pub const fn alignment_of<T>() -> Self {
120 unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(std::mem::align_of::<T>())) }
123 }
124
125 pub const fn arg_mod(self, lhs: usize) -> usize {
132 lhs & (self.raw() - 1)
133 }
134
135 pub const fn arg_align_offset(self, lhs: usize) -> usize {
143 let m = self.arg_mod(lhs);
144 if m == 0 { 0 } else { self.raw() - m }
145 }
146
147 pub const fn arg_checked_next_multiple_of(self, lhs: usize) -> Option<usize> {
150 let offset = self.arg_align_offset(lhs);
151 lhs.checked_add(offset)
152 }
153}
154
155impl From<PowerOfTwo> for usize {
156 #[inline(always)]
157 fn from(value: PowerOfTwo) -> Self {
158 value.raw()
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the message `NotPositiveError` is expected to display for `value`.
    fn format_not_positive_error<T>(value: T) -> String
    where
        T: Debug + Default,
    {
        format!(
            "value {:?} is not greater than {:?} (its default value)",
            value,
            T::default(),
        )
    }

    #[test]
    fn test_positive_f32() {
        let x = Positive::<f32>::new(1.0);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1.0);

        // Zero is not strictly greater than the default, so it is rejected.
        let x = Positive::<f32>::new(0.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(0.0)
        );

        let x = Positive::<f32>::new(-1.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(-1.0)
        );

        // NaN compares false against everything, so it must be rejected as well.
        assert!(Positive::<f32>::new(f32::NAN).is_err());

        // SAFETY: 1.0 > 0.0 == f32::default().
        let x = unsafe { Positive::<f32>::new_unchecked(1.0) };
        assert_eq!(x.into_inner(), 1.0);

        assert_eq!(POSITIVE_ONE_F32.into_inner(), 1.0);
    }

    #[test]
    fn test_positive_i64() {
        let x = Positive::<i64>::new(1);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1);

        let x = Positive::<i64>::new(0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(0)
        );

        let x = Positive::<i64>::new(-1);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(-1)
        );

        // SAFETY: 1 > 0 == i64::default().
        let x = unsafe { Positive::<i64>::new_unchecked(1) };
        assert_eq!(x.into_inner(), 1);
    }

    #[test]
    fn test_power_of_two() {
        assert!(PowerOfTwo::new(0).is_err());
        assert_eq!(PowerOfTwo::next(0).unwrap(), PowerOfTwo::new(1).unwrap());

        // Stop one short of the top bit so `2 * base - 1` below cannot overflow; the top bit is
        // checked separately after the loop. Deriving the bound from `usize::BITS` (instead of
        // hard-coding 63) keeps the test valid on 32-bit targets.
        for i in 0..usize::BITS - 1 {
            let base = 2usize.pow(i);
            let p = PowerOfTwo::new(base).unwrap();
            assert_eq!(p.into_inner().get(), base);
            assert_eq!(p.raw(), base);
            assert_eq!(PowerOfTwo::new(base).unwrap().raw(), base);
            assert_eq!(<_ as Into<usize>>::into(p), base);

            // For i == 1, `base - 1 == 1` is itself a power of two.
            if i != 1 {
                assert!(PowerOfTwo::new(base - 1).is_err(), "failed for i = {}", i);
                assert_eq!(PowerOfTwo::next(base - 1).unwrap().raw(), base);
            }

            // For i == 0, `base + 1 == 2` is itself a power of two.
            if i != 0 {
                assert!(PowerOfTwo::new(base + 1).is_err(), "failed for i = {}", i);
            }

            assert_eq!(p.arg_mod(0), 0);
            assert_eq!(p.arg_mod(p.raw()), 0);

            assert_eq!(p.arg_align_offset(0), 0);
            assert_eq!(p.arg_align_offset(base), 0);

            assert_eq!(p.arg_checked_next_multiple_of(0), Some(0));
            assert_eq!(p.arg_checked_next_multiple_of(base), Some(base));

            assert_eq!(p.arg_checked_next_multiple_of(1), Some(base));

            if i > 1 {
                assert_eq!(p.arg_mod(base + 1), 1);
                assert_eq!(p.arg_mod(2 * base - 1), base - 1);

                assert_eq!(p.arg_align_offset(base + 1), base - 1);
                assert_eq!(p.arg_align_offset(2 * base - 1), 1);

                assert_eq!(p.arg_checked_next_multiple_of(base + 1), Some(2 * base));
                assert_eq!(p.arg_checked_next_multiple_of(2 * base - 1), Some(2 * base));
            }
        }

        // The largest representable power of two.
        let top = 1usize << (usize::BITS - 1);
        let p = PowerOfTwo::new(top).unwrap();
        assert_eq!(p.raw(), top);
        assert_eq!(PowerOfTwo::next(top).unwrap(), p);
        assert!(PowerOfTwo::new(top + 1).is_err());
        assert_eq!(p.arg_checked_next_multiple_of(top + 1), None);

        assert!(PowerOfTwo::next(top + 1).is_none());
        assert!(PowerOfTwo::next(usize::MAX).is_none());
    }

    #[test]
    fn test_power_of_two_alignment() {
        assert_eq!(PowerOfTwo::alignment_of::<u8>().raw(), 1);
        assert_eq!(
            PowerOfTwo::alignment_of::<u64>().raw(),
            std::mem::align_of::<u64>()
        );
        for align in [1, 2, 8, 64, 4096] {
            let layout = std::alloc::Layout::from_size_align(1, align).unwrap();
            assert_eq!(PowerOfTwo::from_align(&layout).raw(), align);
        }
    }

    #[test]
    fn test_not_power_of_two_message() {
        assert_eq!(
            PowerOfTwo::new(6).unwrap_err().to_string(),
            "value 6 must be a power of two"
        );
        assert_eq!(
            PowerOfTwo::new(0).unwrap_err().to_string(),
            "value 0 must be a power of two"
        );
    }
}