use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hash, Hasher};
use siphasher::sip::SipHasher13;
/// Largest shard count a selector will use (2^16 = 65,536).
///
/// `clamp_shards` caps any requested count above this value.
pub const MAX_SHARDS: usize = 1 << 16;
/// Picks a shard index for hashable keys using a 128-bit hash key.
///
/// A selector is a small `Copy` value: a shard count plus two 64-bit key
/// halves. Two selectors compare equal only when the shard count and both
/// key halves match.
#[derive(Copy, Clone, Eq, PartialEq)]
#[must_use]
pub struct ShardSelector {
    /// Number of shards keys are distributed over. The constructors in this
    /// file pass the requested count through `clamp_shards` before storing it.
    shards: usize,
    /// First half of the hash key.
    k0: u64,
    /// Second half of the hash key.
    k1: u64,
}
impl ShardSelector {
    /// Builds a selector whose hash keys are derived deterministically from
    /// `seed`.
    ///
    /// `shards` is clamped by `clamp_shards`, so `0` becomes `1` and anything
    /// above [`MAX_SHARDS`] becomes [`MAX_SHARDS`]. The same `(shards, seed)`
    /// pair always produces an equal selector.
    pub fn new(shards: usize, seed: u64) -> Self {
        let keys = derive_keys(seed);
        Self {
            shards: clamp_shards(shards),
            k0: keys.0,
            k1: keys.1,
        }
    }

    /// Builds a selector with freshly generated hash keys from
    /// `random_key_pair`.
    ///
    /// `shards` is clamped exactly as in [`ShardSelector::new`].
    pub fn randomized(shards: usize) -> Self {
        let keys = random_key_pair();
        Self {
            shards: clamp_shards(shards),
            k0: keys.0,
            k1: keys.1,
        }
    }

    /// Returns the (already clamped) number of shards.
    pub fn shard_count(&self) -> usize {
        self.shards
    }

    /// Maps `key` to a shard index in `0..self.shard_count()`.
    ///
    /// The key is hashed with a `SipHasher13` keyed by this selector's two
    /// key halves, then the 64-bit digest is reduced to the shard range.
    pub fn shard_for_key<K>(&self, key: &K) -> usize
    where
        K: Hash + ?Sized,
    {
        let digest = {
            let mut state = SipHasher13::new_with_keys(self.k0, self.k1);
            key.hash(&mut state);
            state.finish()
        };
        // Multiply-shift range reduction: the top 64 bits of the 128-bit
        // product `digest * shards` are always strictly less than `shards`,
        // so no modulo is needed.
        let scaled = u128::from(digest) * self.shards as u128;
        (scaled >> 64) as usize
    }
}
impl Default for ShardSelector {
fn default() -> Self {
Self::randomized(1)
}
}
/// Debug output shows the shard count but never the hash keys; a fixed
/// `"<redacted>"` placeholder is printed in their place.
impl core::fmt::Debug for ShardSelector {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("ShardSelector");
        out.field("shards", &self.shards);
        // Deliberately a constant string: `k0`/`k1` are not read here.
        out.field("keys", &"<redacted>");
        out.finish()
    }
}
/// Restricts a requested shard count to the inclusive range
/// `1..=MAX_SHARDS`.
///
/// A request of `0` becomes `1`; a request above [`MAX_SHARDS`] becomes
/// [`MAX_SHARDS`]; everything else is returned unchanged.
#[inline]
fn clamp_shards(shards: usize) -> usize {
    // At least one shard is required so every key has somewhere to go.
    const MIN_SHARDS: usize = 1;
    usize::clamp(shards, MIN_SHARDS, MAX_SHARDS)
}
/// Expands a single 64-bit `seed` into the `(k0, k1)` hash key pair.
///
/// Each half is `splitmix64(seed ^ tweak)` with a different tweak constant,
/// so the two halves are computed from different mixer inputs.
///
/// Note: `splitmix64` maps `0` to `0`, so a seed exactly equal to one of the
/// tweak constants yields `0` for the corresponding key half.
#[inline]
fn derive_keys(seed: u64) -> (u64, u64) {
    const K0_TWEAK: u64 = 0x9E37_79B9_7F4A_7C15;
    const K1_TWEAK: u64 = 0xBF58_476D_1CE4_E5B9;
    (splitmix64(seed ^ K0_TWEAK), splitmix64(seed ^ K1_TWEAK))
}
/// Scrambles a 64-bit value with two xor-shift/multiply rounds followed by a
/// final xor-shift.
///
/// This is only the mixing stage: nothing is added to the input beforehand,
/// so an input of `0` produces `0`. Multiplications wrap on overflow.
#[inline]
fn splitmix64(x: u64) -> u64 {
    let round1 = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    let round2 = (round1 ^ (round1 >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    round2 ^ (round2 >> 31)
}
/// Produces a fresh `(k0, k1)` key pair from the standard library's randomly
/// keyed hasher state.
///
/// A new `RandomState` is created on every call and used to hash two fixed,
/// distinct probe values (all-zero bits and all-one bits); the two digests
/// become the key halves.
fn random_key_pair() -> (u64, u64) {
    let state = RandomState::new();
    let [k0, k1] = [0u64, u64::MAX].map(|probe| state.hash_one(probe));
    (k0, k1)
}
#[cfg(test)]
mod tests {
    //! Example-based unit tests for `ShardSelector` and its key helpers.
    use super::*;

    /// The same key on the same selector always lands on the same, in-range
    /// shard.
    #[test]
    fn shard_selector_is_deterministic() {
        let selector = ShardSelector::new(8, 123);
        let first = selector.shard_for_key(&"key");
        let second = selector.shard_for_key(&"key");
        assert_eq!(first, second);
        assert!(first < selector.shard_count());
    }

    /// Requests above `MAX_SHARDS` are capped at `MAX_SHARDS`.
    #[test]
    fn shards_clamped_to_max_shards() {
        let oversized = ShardSelector::new(usize::MAX, 0);
        assert_eq!(oversized.shard_count(), MAX_SHARDS);
    }

    /// Repeated `randomized` construction must not keep producing one key
    /// pair.
    #[test]
    fn randomized_selectors_differ_across_construction() {
        let keys: std::collections::HashSet<(u64, u64)> = (0..32)
            .map(|_| {
                let s = ShardSelector::randomized(8);
                (s.k0, s.k1)
            })
            .collect();
        assert!(
            keys.len() > 1,
            "randomized selectors produced identical keys across 32 constructions"
        );
    }

    /// `Debug` output contains the placeholder and none of the seed or key
    /// renderings checked below.
    #[test]
    fn debug_impl_does_not_leak_key_material() {
        let selector = ShardSelector::new(4, 0xDEAD_BEEF_CAFE_F00D);
        let rendered = format!("{selector:?}");
        assert!(rendered.contains("<redacted>"));
        // Fragments of the seed: hex in both cases, and 0xDEADBEEF in decimal.
        for fragment in ["DEAD", "dead", "3735928559"] {
            assert!(!rendered.contains(fragment));
        }
        // Decimal renderings of the derived key halves.
        for half in [selector.k0, selector.k1] {
            assert!(!rendered.contains(&half.to_string()));
        }
    }

    /// A zero seed still yields two non-zero, distinct key halves.
    #[test]
    fn derive_keys_is_non_trivial_for_zero_seed() {
        let (k0, k1) = derive_keys(0);
        assert_ne!(k0, 0, "seed=0 must not produce k0=0");
        assert_ne!(k1, 0, "seed=0 must not produce k1=0");
        assert_ne!(k0, k1, "derived keys must differ from each other");
    }

    /// With 1024 sequential keys and 16 shards, every shard is hit at least
    /// once for this fixed seed.
    #[test]
    fn fastrange_reduction_covers_full_range() {
        let selector = ShardSelector::new(16, 0x1234_5678_9ABC_DEF0);
        let mut hits = [false; 16];
        (0..1024u32).for_each(|i| hits[selector.shard_for_key(&i)] = true);
        assert!(
            !hits.contains(&false),
            "fastrange did not cover all shards: {hits:?}"
        );
    }
}
#[cfg(test)]
mod property_tests {
    //! Property-based tests for `ShardSelector`.
    //!
    //! Every property is skipped under Miri via `#[cfg_attr(miri, ignore)]`.
    use super::*;
    use proptest::prelude::*;

    // Determinism: repeated lookups on one selector agree.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_deterministic_mapping(
            shard_count in 1usize..64,
            seed in any::<u64>(),
            key in any::<u32>()
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            let shard1 = selector.shard_for_key(&key);
            let shard2 = selector.shard_for_key(&key);
            let shard3 = selector.shard_for_key(&key);
            prop_assert_eq!(shard1, shard2);
            prop_assert_eq!(shard2, shard3);
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_deterministic_batch(
            shard_count in 1usize..64,
            seed in any::<u64>(),
            keys in prop::collection::vec(any::<u32>(), 0..50)
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            let shards1: Vec<_> = keys.iter().map(|k| selector.shard_for_key(k)).collect();
            let shards2: Vec<_> = keys.iter().map(|k| selector.shard_for_key(k)).collect();
            prop_assert_eq!(shards1, shards2);
        }
    }

    // Range: every returned index is below the shard count.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_shard_in_range(
            shard_count in 1usize..128,
            seed in any::<u64>(),
            key in any::<u64>()
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            let shard = selector.shard_for_key(&key);
            prop_assert!(shard < shard_count);
            prop_assert!(shard < selector.shard_count());
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_all_keys_valid_range(
            shard_count in 1usize..64,
            seed in any::<u64>(),
            keys in prop::collection::vec(any::<u32>(), 0..100)
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            for key in keys {
                let shard = selector.shard_for_key(&key);
                prop_assert!(shard < shard_count);
            }
        }
    }

    // Shard-count handling: pass-through inside the valid range, clamping outside it.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_shard_count_matches(
            shard_count in 1usize..128,
            seed in any::<u64>()
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            prop_assert_eq!(selector.shard_count(), shard_count);
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_zero_shards_clamped(seed in any::<u64>()) {
            let selector = ShardSelector::new(0, seed);
            prop_assert_eq!(selector.shard_count(), 1);
            for i in 0..10u32 {
                let shard = selector.shard_for_key(&i);
                prop_assert_eq!(shard, 0);
            }
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_oversized_shards_clamped(
            shard_count in (MAX_SHARDS + 1)..=usize::MAX,
            seed in any::<u64>()
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            prop_assert_eq!(selector.shard_count(), MAX_SHARDS);
        }
    }

    // Single shard: the only possible answer is 0.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_single_shard_returns_zero(
            seed in any::<u64>(),
            keys in prop::collection::vec(any::<u32>(), 0..50)
        ) {
            let selector = ShardSelector::new(1, seed);
            for key in keys {
                let shard = selector.shard_for_key(&key);
                prop_assert_eq!(shard, 0);
            }
        }
    }

    // Seed sensitivity.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_different_seeds_different_selectors(
            shard_count in 1usize..64,
            seed1 in any::<u64>(),
            seed2 in any::<u64>()
        ) {
            prop_assume!(seed1 != seed2);
            let selector1 = ShardSelector::new(shard_count, seed1);
            let selector2 = ShardSelector::new(shard_count, seed2);
            prop_assert_ne!(selector1, selector2);
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_seed_affects_mapping(
            shard_count in 2usize..16,
            seed1 in any::<u64>(),
            seed2 in any::<u64>(),
            keys in prop::collection::vec(any::<u32>(), 10..50)
        ) {
            prop_assume!(seed1 != seed2);
            let selector1 = ShardSelector::new(shard_count, seed1);
            let selector2 = ShardSelector::new(shard_count, seed2);
            // Different seeds must give different hash keys. The individual
            // shard assignments are NOT required to differ (two keyed hashes
            // can agree on any finite key set), so only the key material is
            // compared here.
            prop_assert!(
                (selector1.k0, selector1.k1) != (selector2.k0, selector2.k1),
                "distinct seeds produced identical key material"
            );
            // Both selectors must still produce in-range shards for every key.
            for key in &keys {
                prop_assert!(selector1.shard_for_key(key) < shard_count);
                prop_assert!(selector2.shard_for_key(key) < shard_count);
            }
        }
    }

    // Distribution sanity check.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_keys_use_shards(
            shard_count in 2usize..16,
            seed in any::<u64>(),
            keys in prop::collection::vec(any::<u32>(), 20..100)
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            let mut shard_counts = vec![0usize; shard_count];
            for key in &keys {
                let shard = selector.shard_for_key(key);
                shard_counts[shard] += 1;
            }
            let used_shards = shard_counts.iter().filter(|&&count| count > 0).count();
            prop_assert!(used_shards > 0);
            // Statistical expectation, not a hard guarantee: with at least
            // twice as many distinct keys as shards, they should not all land
            // on one shard.
            let unique_keys: std::collections::HashSet<_> = keys.iter().collect();
            if unique_keys.len() >= shard_count * 2 {
                prop_assert!(used_shards > 1);
            }
        }
    }

    // Key types: integers of two widths and owned strings.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_works_with_u32(
            shard_count in 1usize..32,
            seed in any::<u64>(),
            keys in prop::collection::vec(any::<u32>(), 0..30)
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            for key in keys {
                let shard = selector.shard_for_key(&key);
                prop_assert!(shard < shard_count);
            }
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_works_with_u64(
            shard_count in 1usize..32,
            seed in any::<u64>(),
            keys in prop::collection::vec(any::<u64>(), 0..30)
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            for key in keys {
                let shard = selector.shard_for_key(&key);
                prop_assert!(shard < shard_count);
            }
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_works_with_string(
            shard_count in 1usize..32,
            seed in any::<u64>(),
            keys in prop::collection::vec("[a-z]{1,10}", 0..30)
        ) {
            let selector = ShardSelector::new(shard_count, seed);
            for key in keys {
                let shard = selector.shard_for_key(&key);
                prop_assert!(shard < shard_count);
            }
        }
    }

    // Default construction: one shard, so everything maps to 0.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_default_single_shard(keys in prop::collection::vec(any::<u32>(), 0..30)) {
            let selector = ShardSelector::default();
            prop_assert_eq!(selector.shard_count(), 1);
            for key in keys {
                let shard = selector.shard_for_key(&key);
                prop_assert_eq!(shard, 0);
            }
        }
    }

    // Equality semantics.
    proptest! {
        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_same_config_equal(
            shard_count in 1usize..64,
            seed in any::<u64>()
        ) {
            let selector1 = ShardSelector::new(shard_count, seed);
            let selector2 = ShardSelector::new(shard_count, seed);
            prop_assert_eq!(selector1, selector2);
        }

        #[cfg_attr(miri, ignore)]
        #[test]
        fn prop_different_shard_count_not_equal(
            shard_count1 in 1usize..32,
            shard_count2 in 32usize..64,
            seed in any::<u64>()
        ) {
            // The two ranges do not overlap, so the counts always differ.
            let selector1 = ShardSelector::new(shard_count1, seed);
            let selector2 = ShardSelector::new(shard_count2, seed);
            prop_assert_ne!(selector1, selector2);
        }
    }
}