Skip to main content

diskann_wide/
bitmask.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use super::{
7    arch,
8    constant::{Const, SupportedLaneCount},
9    splitjoin::{LoHi, SplitJoin},
10};
11
12/// A lane-wise mask represented as a bit-mask.
13///
14/// The representation for this type is the smallest unsigned integer capable of holding
15/// `N` bits.
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub struct BitMask<const N: usize, A: arch::Sealed = arch::Current>(
18    pub <Const<N> as SupportedLaneCount>::BitMaskType,
19    A,
20)
21where
22    Const<N>: SupportedLaneCount;
23
24impl<const N: usize, A> BitMask<N, A>
25where
26    Const<N>: SupportedLaneCount,
27    A: arch::Sealed,
28{
29    pub fn as_scalar(self) -> BitMask<N, arch::emulated::Scalar> {
30        BitMask::<N, arch::emulated::Scalar>(self.0, arch::emulated::Scalar)
31    }
32
33    pub fn as_current(self) -> BitMask<N, arch::Current> {
34        BitMask::<N, arch::Current>(self.0, arch::current())
35    }
36
37    pub fn as_arch<B>(self, arch: B) -> BitMask<N, B>
38    where
39        B: arch::Sealed,
40    {
41        BitMask(self.0, arch)
42    }
43
44    pub(crate) fn get_arch(self) -> A {
45        self.1
46    }
47}
48
49/// Perform a potentially lossy conversion from a raw integer.
50///
51/// The associated constant `NARROWING` can be queried to check if the conversion is allowed
52/// to narrow from the provided integer.
53///
54/// Narrowing conversions will only retain the lower bits.
55pub trait FromInt<I, A: arch::Sealed> {
56    /// Will the conversion only sample from the lower-order bits of the provided integer.
57    const NARROWING: bool;
58    /// Turn an integer into an instance of `Self`.
59    fn from_int(arch: A, value: I) -> Self;
60}
61
62impl<A: arch::Sealed> FromInt<u8, A> for BitMask<1, A> {
63    const NARROWING: bool = true;
64    fn from_int(arch: A, value: u8) -> Self {
65        Self(value & 0x1, arch)
66    }
67}
68
69impl<A: arch::Sealed> FromInt<u8, A> for BitMask<2, A> {
70    const NARROWING: bool = true;
71    fn from_int(arch: A, value: u8) -> Self {
72        Self(value & 0x3, arch)
73    }
74}
75
76impl<A: arch::Sealed> FromInt<u8, A> for BitMask<4, A> {
77    const NARROWING: bool = true;
78    fn from_int(arch: A, value: u8) -> Self {
79        Self(value & 0xF, arch)
80    }
81}
82
83impl<A: arch::Sealed> FromInt<u8, A> for BitMask<8, A> {
84    const NARROWING: bool = false;
85    fn from_int(arch: A, value: u8) -> Self {
86        Self(value, arch)
87    }
88}
89
90impl<A: arch::Sealed> FromInt<u16, A> for BitMask<16, A> {
91    const NARROWING: bool = false;
92    fn from_int(arch: A, value: u16) -> Self {
93        Self(value, arch)
94    }
95}
96
97impl<A: arch::Sealed> FromInt<u32, A> for BitMask<32, A> {
98    const NARROWING: bool = false;
99    fn from_int(arch: A, value: u32) -> Self {
100        Self(value, arch)
101    }
102}
103
104impl<A: arch::Sealed> FromInt<u64, A> for BitMask<64, A> {
105    const NARROWING: bool = false;
106    fn from_int(arch: A, value: u64) -> Self {
107        Self(value, arch)
108    }
109}
110
/// Implements [`SplitJoin`] for `BitMask<$from, A>` in terms of `BitMask<$to, A>`.
///
/// Arguments:
/// * `$from` - lane count of the full mask.
/// * `$to` - lane count of each half; also the shift distance of the upper half.
/// * `$mask` - a literal with exactly the low `$to` bits set.
/// * `$full` - the integer type backing `BitMask<$from, _>`.
/// * `$half` - the integer type backing `BitMask<$to, _>`.
///
/// `split` puts the low-order `$to` bits in `lo` and the next `$to` bits in `hi`.
/// `join` reassembles them in the same layout.
macro_rules! splitjoin {
    ($from:literal, $to:literal, $mask:literal, $full:ty, $half:ty) => {
        impl<A: arch::Sealed> SplitJoin for BitMask<$from, A> {
            type Halved = BitMask<$to, A>;

            fn split(self) -> LoHi<Self::Halved> {
                let arch = self.1;
                LoHi {
                    lo: Self::Halved::from_int(arch, (self.0 & $mask) as $half),
                    hi: Self::Halved::from_int(arch, ((self.0 >> $to) & $mask) as $half),
                }
            }

            fn join(lohi: LoHi<Self::Halved>) -> Self {
                // The architecture token is taken from the low half.
                let arch = lohi.lo.1;
                let lo: $full = lohi.lo.0.into();
                let hi: $full = lohi.hi.0.into();
                // The payload field of `BitMask` is public, so a half backed by a wider
                // integer than its lane count (e.g. `BitMask::<4, _>(0xFF, ..)`) may carry
                // stray high bits. Mask both halves so those bits cannot leak into the
                // other half's lanes. For full-width halves this masking is a no-op.
                Self(((hi & $mask) << $to) | (lo & $mask), arch)
            }
        }
    };
}
132
133splitjoin!(2, 1, 0x1, u8, u8);
134splitjoin!(4, 2, 0x3, u8, u8);
135splitjoin!(8, 4, 0xf, u8, u8);
136splitjoin!(16, 8, 0xff, u16, u8);
137splitjoin!(32, 16, 0xffff, u32, u16);
138splitjoin!(64, 32, 0xffff_ffff, u64, u32);