power_consistent_hash/
lib.rs

1use std::arch::asm;
2
3use pcg32::Pcg32;
4use thiserror::Error;
5use tracing::trace;
6
7mod pcg32;
8
/// State of the power consistent hash algorithm.
///
/// Built by [`PowerConsistentHasher::try_new`]; all derived quantities are
/// precomputed there so hashing a key needs no further setup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PowerConsistentHasher {
    // number of buckets, n
    n: u32,
    // m - 1, where m is the smallest power of two >= n
    m_minus_one: u32,
    // m/2 - 1
    m_half_minus_one: u32,
}
18
19#[derive(Debug, Error)]
20pub enum PowerConsistentHasherError {
21    #[error("at least 2 buckets are required for consistent hashing")]
22    NotEnoughBuckets,
23}
24
25impl PowerConsistentHasher {
26    pub fn try_new(num_of_buckets: u32) -> Result<Self, PowerConsistentHasherError> {
27        if num_of_buckets < 2 {
28            return Err(PowerConsistentHasherError::NotEnoughBuckets);
29        }
30
31        // closest larger power of 2
32        let mut m = num_of_buckets;
33        // https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
34        m -= 1;
35        m |= m >> 1;
36        m |= m >> 2;
37        m |= m >> 4;
38        m |= m >> 8;
39        m |= m >> 16;
40        let m_minus_one = m;
41        m += 1;
42        let m_half_minus_one = (m >> 1) - 1;
43
44        trace!(
45            n = num_of_buckets,
46            upper_power_of_two = m,
47            "PowerConsistentHasher is initialized"
48        );
49
50        Ok(Self {
51            n: num_of_buckets,
52            m_minus_one,
53            m_half_minus_one,
54        })
55    }
56
57    #[cfg(feature = "seahash")]
58    /// Hash contiguous bytes buffer
59    pub fn hash_bytes(&self, buf: &[u8]) -> u32 {
60        let key = seahash::hash(buf);
61        self.hash_u64(key)
62    }
63
64    /// Hash u64 consistently.
65    ///
66    /// Used when keys are sufficiently distributed over u64 range
67    pub fn hash_u64(&self, key: u64) -> u32 {
68        let (r1, maybe_rng) = consistent_hash_power_of_two(key, self.m_minus_one);
69
70        if r1 < self.n {
71            trace!(r1 = r1, "Choice in [0; m) range");
72            return r1;
73        }
74        let mut rng = maybe_rng.expect("when r1 is not 0 rng has to be initialized");
75        rng.step();
76        let r2 = g(self.n, self.m_half_minus_one, rng);
77        trace!(
78            r2 = r2,
79            m_half_minus_one = self.m_half_minus_one,
80            "Calculated r2"
81        );
82
83        if r2 > self.m_half_minus_one {
84            trace!(r2 = r2, "Choice in [m/2; n) range");
85            return r2;
86        }
87        let (r, _) = consistent_hash_power_of_two(key, self.m_half_minus_one);
88        trace!(
89            r = r,
90            m_half_minus_one = self.m_half_minus_one,
91            "Choice in [0, m/2) range"
92        );
93        r
94    }
95}
96
97// f function, maps key to uniform integer range [0; m - 1]
98//
99// m has to be power of 2
100// Key should countain reasonbly random bits. key width >= log2(m).
101fn consistent_hash_power_of_two(key: u64, m_minus_one: u32) -> (u32, Option<Pcg32>) {
102    trace!(
103        key = key,
104        m_minus_one = m_minus_one,
105        "consistent_hash_power_of_two"
106    );
107
108    let log2bits = (key & m_minus_one as u64) as u32;
109    trace!(log2bits = log2bits, "log2bits");
110    if log2bits == 0 {
111        return (0, None);
112    }
113    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
114    // SAFETY: log2bits is not zero, bsr instruction will output definite result
115    let msb_set = unsafe { msb_bit_index(log2bits) };
116    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
117    const MAX_BIT_INDEX: u32 = std::mem::size_of::<u64>() as u32 * 2 - 1;
118    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
119    let msb_set = MAX_BIT_INDEX - log2bits.leading_zeros();
120
121    let mut rng = Pcg32::new(key, msb_set as u64);
122    // 2 ** msb_set, h >= 1
123    let h = 1_u32 << msb_set;
124    trace!(h = h, "Power of 2");
125    // rand integer in [h; 2h - 1] range
126    let r = h.wrapping_add(rng.next_u32() & h.wrapping_sub(1));
127    (r, Some(rng))
128}
129
/// Returns the index of the most significant set bit of `n`, using `bsr`.
///
/// Compiled only on x86/x86_64, the targets that have the `bsr` instruction.
///
/// # Safety
///
/// `n` must be non-zero: for a zero source operand `bsr` leaves its
/// destination register undefined.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe fn msb_bit_index(n: u32) -> u32 {
    debug_assert!(n != 0, "bsr is undefined for a zero operand");
    let bsr: u32;
    // `bsr` writes EFLAGS, so `preserves_flags` is deliberately not requested.
    asm!("bsr {:e}, {:e}", lateout(reg) bsr, in(reg) n, options(pure, nomem, nostack));
    bsr
}
135
136// returns integer in range [s, n) with a weighted probability
137//
138// U > (x+1)/(j+1), U is random in (0, 1)
139//
140// Transform to use integer PRNG:
141//
142// (j + 1) * U * u32::MAX > (x + 1) * u32::MAX
143// (j + 1) * (1 + rand_32) > (x + 1) * u32::MAX
144//
145// min_j + 1 = (x + 1) * u32::MAX.div_euclid(1 + rand32) + 1
146// min_j = ((x + 1) * u32::MAX.div_euclid(1 + rand32)
147// min_j = scaled_x.div_euclid(1 + rand32)
148//      where scaled_x = (x + 1) * u32::MAX
149//
150// r = min_j
151//
152// if min_j >= n
153// then
154// n * (1 + rand_32) <= scaled_x
155//
156// and x - result
157//
158// if doesn't hold then compute min_j, x = min_j and continue
159
160fn g(n: u32, s: u32, mut rng: Pcg32) -> u32 {
161    let mut x = s; // x < n
162    let n = n as u64;
163
164    loop {
165        let scaled_x: u64 = (x as u64 + 1) * u32::MAX as u64;
166        let rnd_plus_one = rng.next_u32() as u64 + 1;
167
168        // if x >= n then scaled_x >= (n + 1) * u32::MAX
169        // n * rnd_plus_one <= n * (1 + u32::MAX)
170        //                  <= n * u32::MAX + n <= (n + 1) * u32::MAX <= scaled_x
171        //                                          It always holds for 32 bit n
172        if n * rnd_plus_one <= scaled_x {
173            break;
174        }
175        // x < n
176
177        // n * rnd_plus_one > scaled_x
178        // n > scaled_x / rnd_plus_one
179        // n > scaled_x.div_euclid(rnd_plus_one)
180        // n > min_j = r = new_x
181        // thus x is set to r
182        rng.step();
183        // new x is not less then the previous x and less then n
184        x = scaled_x.div_euclid(rnd_plus_one) as u32;
185        debug_assert!((x as u64) < n);
186    }
187    x
188}
189
190#[cfg(test)]
191mod tests;