diskann_quantization/
num.rs1use std::{fmt::Debug, num::NonZeroUsize};
9
10use thiserror::Error;
11
/// A value of type `T` that is strictly greater than `T::default()`.
///
/// For the numeric types `T::default()` is zero, so this wraps a strictly positive number.
/// Checked construction goes through [`Positive::new`]. That constructor uses a strict `>`
/// comparison, so for floating-point `T` it also rejects `NaN`, because `NaN > 0.0` is false.
///
/// No trait bounds are placed on the struct itself; the constructors carry the bounds they
/// need. The derived comparison and hash impls apply only when `T` supports them. For
/// example, `Positive<i64>` is `Ord + Hash`, while `Positive<f32>` is only `PartialOrd`.
///
/// `#[repr(transparent)]` guarantees `Positive<T>` has the same layout as `T`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct Positive<T>(T);
18
19#[derive(Debug, Clone, Copy, Error)]
20#[error("value {:?} is not greater than {:?} (its default value)", .0, T::default())]
21pub struct NotPositiveError<T: Debug + Default>(T);
22
23impl<T> Positive<T>
24where
25 T: PartialOrd + Default + Debug,
26{
27 pub fn new(value: T) -> Result<Self, NotPositiveError<T>> {
29 if value > T::default() {
30 Ok(Self(value))
31 } else {
32 Err(NotPositiveError(value))
33 }
34 }
35
36 pub const unsafe fn new_unchecked(value: T) -> Self {
42 Self(value)
43 }
44
45 pub fn into_inner(self) -> T {
47 self.0
48 }
49}
50
51pub(crate) const POSITIVE_ONE_F32: Positive<f32> = unsafe { Positive::new_unchecked(1.0) };
53
/// A `usize` that is guaranteed to be a non-zero power of two.
///
/// It is stored as a [`NonZeroUsize`] inside a `#[repr(transparent)]` wrapper. That gives it
/// the same layout as `NonZeroUsize`, and `Option<PowerOfTwo>` is no larger than `usize`.
///
/// Ordering and hashing follow the numeric value. This allows, for example, taking the
/// larger of two alignments with `max`, or using a `PowerOfTwo` as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct PowerOfTwo(NonZeroUsize);
57
58#[derive(Debug, Clone, Copy, Error)]
59#[error("value {0} must be a power of two")]
60#[non_exhaustive]
61pub struct NotPowerOfTwo(usize);
62
63impl PowerOfTwo {
64 pub const fn new(value: usize) -> Result<Self, NotPowerOfTwo> {
66 let v = match NonZeroUsize::new(value) {
67 Some(value) => value,
68 None => return Err(NotPowerOfTwo(value)),
69 };
70 if v.is_power_of_two() {
71 Ok(unsafe { Self::new_unchecked(v) })
73 } else {
74 Err(NotPowerOfTwo(value))
75 }
76 }
77
78 pub const fn next(value: usize) -> Option<Self> {
81 match value.checked_next_power_of_two() {
83 Some(v) => Some(unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(v)) }),
88 None => None,
89 }
90 }
91
92 pub const unsafe fn new_unchecked(value: NonZeroUsize) -> Self {
98 Self(value)
99 }
100
101 pub const fn into_inner(self) -> NonZeroUsize {
103 self.0
104 }
105
106 pub const fn raw(self) -> usize {
108 self.0.get()
109 }
110
111 pub const fn from_align(layout: &std::alloc::Layout) -> Self {
113 unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(layout.align())) }
116 }
117
118 pub const fn alignment_of<T>() -> Self {
120 unsafe { Self::new_unchecked(NonZeroUsize::new_unchecked(std::mem::align_of::<T>())) }
123 }
124
125 pub const fn arg_mod(self, lhs: usize) -> usize {
132 lhs & (self.raw() - 1)
133 }
134
135 pub const fn arg_align_offset(self, lhs: usize) -> usize {
143 let m = self.arg_mod(lhs);
144 if m == 0 {
145 0
146 } else {
147 self.raw() - m
148 }
149 }
150
151 pub const fn arg_checked_next_multiple_of(self, lhs: usize) -> Option<usize> {
154 let offset = self.arg_align_offset(lhs);
155 lhs.checked_add(offset)
156 }
157}
158
159impl From<PowerOfTwo> for usize {
160 #[inline(always)]
161 fn from(value: PowerOfTwo) -> Self {
162 value.raw()
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the message `NotPositiveError<T>` is expected to display for `value`.
    fn format_not_positive_error<T>(value: T) -> String
    where
        T: Debug + Default,
    {
        format!(
            "value {:?} is not greater than {:?} (its default value)",
            value,
            T::default(),
        )
    }

    #[test]
    fn test_positive_f32() {
        let x = Positive::<f32>::new(1.0);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1.0);

        // Zero is rejected: the comparison against `f32::default()` is strict.
        let x = Positive::<f32>::new(0.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(0.0)
        );

        // Negative values are rejected.
        let x = Positive::<f32>::new(-1.0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(-1.0)
        );

        // NaN is rejected because `NaN > 0.0` is false.
        let x = Positive::<f32>::new(f32::NAN);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<f32>(f32::NAN)
        );

        let x = unsafe { Positive::<f32>::new_unchecked(1.0) };
        assert_eq!(x.into_inner(), 1.0);

        assert_eq!(POSITIVE_ONE_F32.into_inner(), 1.0);
    }

    #[test]
    fn test_positive_i64() {
        let x = Positive::<i64>::new(1);
        assert!(x.is_ok());
        let x = x.unwrap();
        assert_eq!(x.into_inner(), 1);

        let x = Positive::<i64>::new(0);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(0)
        );

        let x = Positive::<i64>::new(-1);
        assert!(x.is_err());
        assert_eq!(
            x.unwrap_err().to_string(),
            format_not_positive_error::<i64>(-1)
        );

        assert_eq!(Positive::<i64>::new(i64::MAX).unwrap().into_inner(), i64::MAX);

        let x = unsafe { Positive::<i64>::new_unchecked(1) };
        assert_eq!(x.into_inner(), 1);
    }

    #[test]
    fn test_power_of_two() {
        assert!(PowerOfTwo::new(0).is_err());
        assert_eq!(PowerOfTwo::next(0).unwrap(), PowerOfTwo::new(1).unwrap());
        assert_eq!(
            PowerOfTwo::new(3).unwrap_err().to_string(),
            "value 3 must be a power of two"
        );

        // Every power of two except the top bit, so that `2 * base - 1` below cannot
        // overflow. `usize::BITS` keeps this correct on 32-bit targets as well as 64-bit ones.
        for i in 0..usize::BITS - 1 {
            let base = 1usize << i;
            let p = PowerOfTwo::new(base).unwrap();
            assert_eq!(p.into_inner().get(), base);
            assert_eq!(p.raw(), base);
            assert_eq!(PowerOfTwo::new(base).unwrap().raw(), base);
            assert_eq!(<_ as Into<usize>>::into(p), base);

            // For i == 1, `base - 1 == 1` is itself a power of two.
            if i != 1 {
                assert!(PowerOfTwo::new(base - 1).is_err(), "failed for i = {}", i);
                assert_eq!(PowerOfTwo::next(base - 1).unwrap().raw(), base);
            }

            // For i == 0, `base + 1 == 2` is itself a power of two.
            if i != 0 {
                assert!(PowerOfTwo::new(base + 1).is_err(), "failed for i = {}", i);
            }

            assert_eq!(p.arg_mod(0), 0);
            assert_eq!(p.arg_mod(p.raw()), 0);

            assert_eq!(p.arg_align_offset(0), 0);
            assert_eq!(p.arg_align_offset(base), 0);

            assert_eq!(p.arg_checked_next_multiple_of(0), Some(0));
            assert_eq!(p.arg_checked_next_multiple_of(base), Some(base));

            assert_eq!(p.arg_checked_next_multiple_of(1), Some(base));

            if i > 1 {
                assert_eq!(p.arg_mod(base + 1), 1);
                assert_eq!(p.arg_mod(2 * base - 1), base - 1);

                assert_eq!(p.arg_align_offset(base + 1), base - 1);
                assert_eq!(p.arg_align_offset(2 * base - 1), 1);

                assert_eq!(p.arg_checked_next_multiple_of(base + 1), Some(2 * base));
                assert_eq!(p.arg_checked_next_multiple_of(2 * base - 1), Some(2 * base));
            }
        }

        // The largest representable power of two, and anything above it.
        let top = 1usize << (usize::BITS - 1);
        assert_eq!(PowerOfTwo::new(top).unwrap().raw(), top);
        assert!(PowerOfTwo::next(top + 1).is_none());
        assert!(PowerOfTwo::next(usize::MAX).is_none());
    }

    #[test]
    fn test_power_of_two_alignment() {
        assert_eq!(PowerOfTwo::alignment_of::<u8>().raw(), 1);
        assert_eq!(
            PowerOfTwo::alignment_of::<u64>().raw(),
            std::mem::align_of::<u64>()
        );

        let layout = std::alloc::Layout::from_size_align(24, 16).unwrap();
        assert_eq!(PowerOfTwo::from_align(&layout).raw(), 16);
        assert_eq!(
            PowerOfTwo::from_align(&std::alloc::Layout::new::<u32>()),
            PowerOfTwo::alignment_of::<u32>()
        );
    }
}