use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
/// Policy a [`ModelShardManager`] uses to map entity ids to shard indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ShardingStrategy {
    /// Shard is derived from a seeded hash of the id, taken modulo the shard count.
    /// This is the `Default` variant.
    #[default]
    EntityHash,
    /// Ids are dealt to shards in rotation as `ModelShardManager::partition` first
    /// encounters them; each id's shard is then remembered by the manager.
    RoundRobin,
}
/// Result of partitioning a set of entity ids: the ids placed in each shard.
///
/// As produced by `ModelShardManager::partition`, `buckets.len() == num_shards` and
/// `buckets[i]` holds the ids assigned to shard `i`.
/// NOTE(review): the derived `Deserialize` does not validate that invariant, so a
/// hand-edited or foreign payload may violate it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShardAssignment {
    /// Number of shards the ids were split across.
    pub num_shards: usize,
    /// One list of ids per shard, indexed by shard number.
    pub buckets: Vec<Vec<String>>,
}
impl ShardAssignment {
    /// Returns how many ids are stored across all buckets.
    ///
    /// Each occurrence counts, so an id present more than once is counted more than once.
    pub fn total(&self) -> usize {
        let mut count = 0;
        for bucket in &self.buckets {
            count += bucket.len();
        }
        count
    }

    /// Returns the index of the first bucket that contains `id`, or `None` if no bucket does.
    ///
    /// This is a linear scan over every stored id.
    pub fn shard_of(&self, id: &str) -> Option<usize> {
        self.buckets
            .iter()
            .position(|bucket| bucket.iter().any(|member| member == id))
    }
}
/// Assigns entity ids to a fixed number of shards according to a [`ShardingStrategy`].
#[derive(Debug, Clone)]
pub struct ModelShardManager {
    /// Shard count; `new` coerces it to at least 1.
    num_shards: usize,
    /// Policy consulted by `shard_for` and `partition`.
    strategy: ShardingStrategy,
    /// Shard recorded for each id placed by `partition` under `RoundRobin`
    /// (not consulted under `EntityHash`).
    rr_index: HashMap<String, usize>,
}
impl ModelShardManager {
    /// Creates a manager that spreads ids over `num_shards` shards using `strategy`.
    ///
    /// A `num_shards` of `0` is coerced to `1`, so every shard index computed by this
    /// manager is taken modulo a non-zero count.
    pub fn new(num_shards: usize, strategy: ShardingStrategy) -> Self {
        Self {
            num_shards: num_shards.max(1),
            strategy,
            rr_index: HashMap::new(),
        }
    }

    /// Returns the shard count (always at least `1`).
    pub fn num_shards(&self) -> usize {
        self.num_shards
    }

    /// Returns the configured sharding strategy.
    pub fn strategy(&self) -> ShardingStrategy {
        self.strategy
    }

    /// Returns the shard index for `id`, always in `0..self.num_shards()`.
    ///
    /// * [`ShardingStrategy::EntityHash`]: a pure function of `id` and the shard count.
    /// * [`ShardingStrategy::RoundRobin`]: the shard recorded when
    ///   [`partition`](Self::partition) first placed `id`. Ids this manager has never
    ///   partitioned fall back to the hash-based shard.
    pub fn shard_for(&self, id: &str) -> usize {
        match self.strategy {
            ShardingStrategy::EntityHash => self.hash_shard(id),
            ShardingStrategy::RoundRobin => self
                .rr_index
                .get(id)
                .copied()
                .unwrap_or_else(|| self.hash_shard(id)),
        }
    }

    /// Splits `ids` into `self.num_shards()` buckets according to the configured strategy.
    ///
    /// Within each bucket, ids keep their input order. Duplicate ids are not collapsed:
    /// every occurrence is placed, so [`ShardAssignment::total`] counts them all.
    ///
    /// Under [`ShardingStrategy::RoundRobin`], the first placement of an id is recorded.
    /// Later calls and [`shard_for`](Self::shard_for) reuse that placement. Ids new to this
    /// manager continue the rotation where the previous call left off, so repeated calls
    /// stay balanced instead of each one restarting at shard 0.
    pub fn partition<I, S>(&mut self, ids: I) -> ShardAssignment
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        // Copied out so the round-robin closure below does not borrow `self` while
        // `self.rr_index` is mutably borrowed.
        let num_shards = self.num_shards;
        let mut buckets: Vec<Vec<String>> = vec![Vec::new(); num_shards];
        match self.strategy {
            ShardingStrategy::EntityHash => {
                for raw in ids {
                    let id: String = raw.into();
                    let shard = self.hash_shard(&id);
                    buckets[shard].push(id);
                }
            }
            ShardingStrategy::RoundRobin => {
                // Each newly recorded id advances the rotation by exactly one, and this is
                // the only place `rr_index` grows. Its length is therefore how far the
                // rotation has already advanced, and resuming from it keeps successive
                // calls balanced.
                let mut next = self.rr_index.len();
                for raw in ids {
                    let id: String = raw.into();
                    let shard = *self.rr_index.entry(id.clone()).or_insert_with(|| {
                        let s = next % num_shards;
                        next += 1;
                        s
                    });
                    buckets[shard].push(id);
                }
            }
        }
        ShardAssignment {
            num_shards,
            buckets,
        }
    }

    /// Re-partitions every id in `prior` (in bucket order) across this manager's shards.
    ///
    /// `prior` may have been produced with a different shard count. Under
    /// [`ShardingStrategy::RoundRobin`], ids this manager has already recorded keep their
    /// recorded shard.
    pub fn reshard(&mut self, prior: &ShardAssignment) -> ShardAssignment {
        self.partition(prior.buckets.iter().flatten().cloned())
    }

    /// Maps `id` to a shard using a seeded [`DefaultHasher`], reduced modulo the shard count.
    ///
    /// NOTE: the standard library documents `DefaultHasher`'s algorithm as unspecified and
    /// subject to change between Rust releases. These indices should therefore not be
    /// assumed stable across toolchain upgrades (relevant if serialized assignments are
    /// compared against freshly computed ones).
    fn hash_shard(&self, id: &str) -> usize {
        const SEED: u64 = 0x517c_c1b7_2722_0a95;
        let mut hasher = DefaultHasher::new();
        SEED.hash(&mut hasher);
        id.hash(&mut hasher);
        // Reduce in u64 before narrowing. Casting the hash to usize first would truncate it
        // to 32 bits on 32-bit targets and place ids on different shards than on 64-bit
        // targets. The remainder is < num_shards, so the final cast is lossless.
        (hasher.finish() % self.num_shards as u64) as usize
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shard_manager_default_strategy_is_entity_hash() {
        let mgr = ModelShardManager::new(4, ShardingStrategy::default());
        assert_eq!(mgr.strategy(), ShardingStrategy::EntityHash);
        assert_eq!(mgr.num_shards(), 4);
    }

    #[test]
    fn shard_for_deterministic_and_bounded() {
        let mgr = ModelShardManager::new(8, ShardingStrategy::EntityHash);
        let id = "http://example.org/Alice";
        let s1 = mgr.shard_for(id);
        let s2 = mgr.shard_for(id);
        assert_eq!(s1, s2, "shard_for must be deterministic");
        assert!(s1 < 8, "shard index must be bounded by num_shards");
    }

    #[test]
    fn shard_for_zero_shards_coerced_to_one() {
        let mgr = ModelShardManager::new(0, ShardingStrategy::EntityHash);
        assert_eq!(mgr.num_shards(), 1);
        assert_eq!(mgr.shard_for("anything"), 0);
    }

    #[test]
    fn partition_roundrobin_buckets_evenly() {
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::RoundRobin);
        let ids: Vec<String> = (0..16).map(|i| format!("e{i}")).collect();
        let a = mgr.partition(ids);
        assert_eq!(a.num_shards, 4);
        assert_eq!(a.total(), 16);
        for b in &a.buckets {
            assert_eq!(b.len(), 4);
        }
    }

    #[test]
    fn partition_hash_total_equals_input_size() {
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let ids: Vec<String> = (0..100).map(|i| format!("entity_{i}")).collect();
        let a = mgr.partition(ids);
        assert_eq!(a.total(), 100);
    }

    #[test]
    fn partition_hash_distributes_across_shards() {
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let ids: Vec<String> = (0..200).map(|i| format!("entity_{i}")).collect();
        let a = mgr.partition(ids);
        for (i, b) in a.buckets.iter().enumerate() {
            assert!(
                !b.is_empty(),
                "shard {i} got no entities — distribution failed"
            );
        }
    }

    #[test]
    fn shard_assignment_shard_of_lookup() {
        let mut mgr = ModelShardManager::new(2, ShardingStrategy::RoundRobin);
        let a = mgr.partition(vec!["a", "b", "c", "d"]);
        assert_eq!(a.shard_of("a"), Some(0));
        assert_eq!(a.shard_of("b"), Some(1));
        assert_eq!(a.shard_of("missing"), None);
    }

    #[test]
    fn reshard_preserves_total_count_after_resize() {
        let mut mgr_small = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let ids: Vec<String> = (0..32).map(|i| format!("e{i}")).collect();
        let small = mgr_small.partition(ids.clone());
        assert_eq!(small.total(), 32);
        let mut mgr_big = ModelShardManager::new(8, ShardingStrategy::EntityHash);
        let big = mgr_big.reshard(&small);
        assert_eq!(big.num_shards, 8);
        assert_eq!(big.total(), 32);
    }

    #[test]
    fn reshard_routes_each_id_to_its_new_shard() {
        let ids: Vec<String> = (0..50).map(|i| format!("entity:{i}")).collect();
        let mut mgr2 = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let prior = mgr2.partition(ids);
        let mut mgr5 = ModelShardManager::new(5, ShardingStrategy::EntityHash);
        let after = mgr5.reshard(&prior);
        for (i, bucket) in after.buckets.iter().enumerate() {
            for id in bucket {
                assert_eq!(mgr5.shard_for(id), i);
            }
        }
    }

    #[test]
    fn partition_stable_across_managers_with_same_shard_count() {
        let ids: Vec<String> = (0..30).map(|i| format!("e_{i}")).collect();
        let mut a = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let mut b = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let pa = a.partition(ids.clone());
        let pb = b.partition(ids);
        assert_eq!(
            pa, pb,
            "two managers with same config must produce same shards"
        );
    }

    #[test]
    fn partition_unstable_when_shard_count_changes() {
        let ids: Vec<String> = (0..50).map(|i| format!("k_{i}")).collect();
        let mut mgr2 = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let mut mgr4 = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let p2 = mgr2.partition(ids.clone());
        let p4 = mgr4.partition(ids.clone());
        assert_ne!(p2.num_shards, p4.num_shards);
        // Modulo placement is not stable under resize. An id moves whenever its hash is
        // 2 or 3 mod 4, so with 50 ids at least one must change shard index.
        let moved = ids
            .iter()
            .filter(|id| p2.shard_of(id) != p4.shard_of(id))
            .count();
        assert!(
            moved > 0,
            "expected at least one id to change shard when resizing 2 -> 4"
        );
    }

    #[test]
    fn shard_assignment_serialization() {
        let mut mgr = ModelShardManager::new(3, ShardingStrategy::EntityHash);
        let a = mgr.partition(vec!["x", "y", "z"]);
        let json = serde_json::to_string(&a).expect("serialize");
        let a2: ShardAssignment = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(a, a2);
    }
}