use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Maps arbitrary byte keys onto a fixed number of shards by hashing.
///
/// The shard count is chosen at construction time and is always at least one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Router {
    /// Number of shards keys are distributed over; `new` rejects zero.
    shard_count: u16,
}

impl Router {
    /// Creates a router that distributes keys over `shard_count` shards.
    ///
    /// # Panics
    ///
    /// Panics if `shard_count` is zero.
    pub fn new(shard_count: u16) -> Self {
        assert!(shard_count > 0, "shard_count must be greater than zero");
        Self { shard_count }
    }

    /// Returns the shard id for `key`; the result is always less than the shard count.
    ///
    /// The key is hashed with [`DefaultHasher`] and the hash is reduced modulo the
    /// shard count, so routing the same key twice on the same router yields the
    /// same shard.
    ///
    /// NOTE(review): the standard library leaves `DefaultHasher`'s algorithm
    /// unspecified and allows it to change between Rust releases. If shard
    /// assignments are persisted or shared between binaries built with different
    /// toolchains, confirm that this is acceptable.
    #[inline]
    pub fn route(&self, key: &[u8]) -> u16 {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let bucket = hasher.finish() % u64::from(self.shard_count);
        // `bucket` is strictly less than `shard_count`, which itself fits in a
        // `u16`, so this narrowing cast can never truncate.
        bucket as u16
    }

    /// Returns the name of shard `shard_id`, formatted as `shard_NN`.
    ///
    /// The id is zero-padded to a minimum width of two digits; ids of 100 or
    /// more are written with as many digits as they need.
    ///
    /// # Panics
    ///
    /// Panics if `shard_id` is not less than the router's shard count.
    pub fn shard_name(&self, shard_id: u16) -> String {
        assert!(
            shard_id < self.shard_count,
            "shard_id {} out of range (max: {})",
            shard_id,
            // Cannot underflow: `new` guarantees `shard_count >= 1`.
            self.shard_count - 1
        );
        format!("shard_{:02}", shard_id)
    }

    /// Returns the number of shards this router distributes keys over.
    pub fn shard_count(&self) -> u16 {
        self.shard_count
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Routes the synthetic key `key_<index>` through `router`.
    fn route_numbered_key(router: &Router, index: u32) -> u16 {
        router.route(format!("key_{}", index).as_bytes())
    }

    /// The same key must land on the same shard every time it is routed.
    #[test]
    fn test_deterministic_routing() {
        let router = Router::new(16);
        let key = b"user_123";
        let observed = [router.route(key), router.route(key), router.route(key)];
        assert_eq!(observed[0], observed[1]);
        assert_eq!(observed[1], observed[2]);
    }

    /// Every routed key must map to a shard id below the shard count.
    #[test]
    fn test_routes_within_range() {
        let router = Router::new(16);
        for shard in (0..1000).map(|index| route_numbered_key(&router, index)) {
            assert!(shard < 16, "shard {} out of range", shard);
        }
    }

    /// With 10,000 keys over 16 shards, each shard should receive roughly 625 keys.
    #[test]
    fn test_uniform_distribution() {
        let router = Router::new(16);
        let mut counts = [0usize; 16];
        for index in 0..10_000 {
            let shard = route_numbered_key(&router, index);
            counts[usize::from(shard)] += 1;
        }
        for (shard_id, count) in counts.iter().enumerate() {
            // Accept strictly more than 500 and strictly fewer than 750 keys.
            assert!(
                (501..=749).contains(count),
                "Shard {} has uneven distribution: {} keys",
                shard_id,
                count
            );
        }
    }

    /// Shard names are zero-padded to two digits.
    #[test]
    fn test_shard_name_format() {
        let router = Router::new(100);
        let expectations = [
            (0u16, "shard_00"),
            (9, "shard_09"),
            (10, "shard_10"),
            (99, "shard_99"),
        ];
        for &(shard_id, expected) in expectations.iter() {
            assert_eq!(router.shard_name(shard_id), expected);
        }
    }

    /// One hundred keys over four shards must hit every shard at least once.
    #[test]
    fn test_different_keys_may_collide() {
        let router = Router::new(4);
        let assigned_shards: HashSet<u16> = (0..100)
            .map(|index| route_numbered_key(&router, index))
            .collect();
        assert_eq!(assigned_shards.len(), 4);
    }

    /// Constructing a router with zero shards must panic.
    #[test]
    #[should_panic(expected = "shard_count must be greater than zero")]
    fn test_panics_on_zero_shards() {
        let _ = Router::new(0);
    }

    /// Asking for the name of a shard id equal to the shard count must panic.
    #[test]
    #[should_panic(expected = "out of range")]
    fn test_shard_name_panics_on_invalid_id() {
        let router = Router::new(16);
        let _ = router.shard_name(16);
    }

    /// A single-shard router can only ever answer with shard 0.
    #[test]
    fn test_single_shard_routes_everything_to_zero() {
        let router = Router::new(1);
        for index in 0..100 {
            assert_eq!(route_numbered_key(&router, index), 0);
        }
    }

    /// Routing stays in range for power-of-two shard counts from 2 to 256.
    #[test]
    fn test_power_of_two_shards() {
        for exponent in 1u32..=8 {
            let shard_count = 1u16 << exponent;
            let shard = Router::new(shard_count).route(b"test_key");
            assert!(shard < shard_count);
        }
    }

    /// Routing stays in range for shard counts that are not powers of two.
    #[test]
    fn test_non_power_of_two_shards() {
        let shard_counts: [u16; 6] = [3, 5, 7, 13, 17, 100];
        for shard_count in shard_counts.iter().copied() {
            let shard = Router::new(shard_count).route(b"test_key");
            assert!(shard < shard_count);
        }
    }
}