oxirs_embed/distributed_training/
shard_manager.rs1use serde::{Deserialize, Serialize};
27use std::collections::hash_map::DefaultHasher;
28use std::collections::HashMap;
29use std::hash::{Hash, Hasher};
30
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
33pub enum ShardingStrategy {
34 #[default]
39 EntityHash,
40 RoundRobin,
45}
46
47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
53pub struct ShardAssignment {
54 pub num_shards: usize,
56 pub buckets: Vec<Vec<String>>,
58}
59
60impl ShardAssignment {
61 pub fn total(&self) -> usize {
63 self.buckets.iter().map(|b| b.len()).sum()
64 }
65
66 pub fn shard_of(&self, id: &str) -> Option<usize> {
68 self.buckets.iter().enumerate().find_map(|(i, b)| {
69 if b.iter().any(|s| s == id) {
70 Some(i)
71 } else {
72 None
73 }
74 })
75 }
76}
77
78#[derive(Debug, Clone)]
80pub struct ModelShardManager {
81 num_shards: usize,
82 strategy: ShardingStrategy,
83 rr_index: HashMap<String, usize>,
86}
87
88impl ModelShardManager {
89 pub fn new(num_shards: usize, strategy: ShardingStrategy) -> Self {
96 Self {
97 num_shards: num_shards.max(1),
98 strategy,
99 rr_index: HashMap::new(),
100 }
101 }
102
103 pub fn num_shards(&self) -> usize {
105 self.num_shards
106 }
107
108 pub fn strategy(&self) -> ShardingStrategy {
110 self.strategy
111 }
112
113 pub fn shard_for(&self, id: &str) -> usize {
119 match self.strategy {
120 ShardingStrategy::EntityHash => self.hash_shard(id),
121 ShardingStrategy::RoundRobin => {
122 self.rr_index
126 .get(id)
127 .copied()
128 .unwrap_or_else(|| self.hash_shard(id))
129 }
130 }
131 }
132
133 pub fn partition<I, S>(&mut self, ids: I) -> ShardAssignment
139 where
140 I: IntoIterator<Item = S>,
141 S: Into<String>,
142 {
143 let mut buckets: Vec<Vec<String>> = (0..self.num_shards).map(|_| Vec::new()).collect();
144
145 match self.strategy {
146 ShardingStrategy::EntityHash => {
147 for raw in ids {
148 let id: String = raw.into();
149 let shard = self.hash_shard(&id);
150 buckets[shard].push(id);
151 }
152 }
153 ShardingStrategy::RoundRobin => {
154 let mut next: usize = 0;
155 for raw in ids {
156 let id: String = raw.into();
157 let shard = *self.rr_index.entry(id.clone()).or_insert_with(|| {
158 let s = next % self.num_shards;
159 next += 1;
160 s
161 });
162 buckets[shard].push(id);
163 }
164 }
165 }
166
167 ShardAssignment {
168 num_shards: self.num_shards,
169 buckets,
170 }
171 }
172
173 pub fn reshard(&mut self, prior: &ShardAssignment) -> ShardAssignment {
178 let flat: Vec<String> = prior.buckets.iter().flatten().cloned().collect();
180 self.partition(flat)
181 }
182
183 fn hash_shard(&self, id: &str) -> usize {
184 const SEED: u64 = 0x517c_c1b7_2722_0a95;
188 let mut h = DefaultHasher::new();
189 SEED.hash(&mut h);
190 id.hash(&mut h);
191 (h.finish() as usize) % self.num_shards
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` ids of the form `<prefix><index>`.
    fn make_ids(prefix: &str, count: usize) -> Vec<String> {
        (0..count).map(|i| format!("{prefix}{i}")).collect()
    }

    #[test]
    fn shard_manager_default_strategy_is_entity_hash() {
        let mgr = ModelShardManager::new(4, ShardingStrategy::default());
        assert_eq!(mgr.strategy(), ShardingStrategy::EntityHash);
        assert_eq!(mgr.num_shards(), 4);
    }

    #[test]
    fn shard_for_deterministic_and_bounded() {
        let mgr = ModelShardManager::new(8, ShardingStrategy::EntityHash);
        let id = "http://example.org/Alice";
        let s1 = mgr.shard_for(id);
        let s2 = mgr.shard_for(id);
        assert_eq!(s1, s2, "shard_for must be deterministic");
        assert!(s1 < 8, "shard index must be bounded by num_shards");
    }

    #[test]
    fn shard_for_zero_shards_coerced_to_one() {
        let mgr = ModelShardManager::new(0, ShardingStrategy::EntityHash);
        assert_eq!(mgr.num_shards(), 1);
        assert_eq!(mgr.shard_for("anything"), 0);
    }

    #[test]
    fn partition_roundrobin_buckets_evenly() {
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::RoundRobin);
        let a = mgr.partition(make_ids("e", 16));
        assert_eq!(a.num_shards, 4);
        assert_eq!(a.total(), 16);
        for b in &a.buckets {
            assert_eq!(b.len(), 4);
        }
    }

    #[test]
    fn partition_hash_total_equals_input_size() {
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let a = mgr.partition(make_ids("entity_", 100));
        assert_eq!(a.total(), 100);
    }

    #[test]
    fn partition_hash_distributes_across_shards() {
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        // 200 ids over 4 shards: an empty shard would mean the hash is not
        // spreading ids at all.
        let a = mgr.partition(make_ids("entity_", 200));
        for (i, b) in a.buckets.iter().enumerate() {
            assert!(
                !b.is_empty(),
                "shard {i} got no entities — distribution failed"
            );
        }
    }

    #[test]
    fn shard_assignment_shard_of_lookup() {
        let mut mgr = ModelShardManager::new(2, ShardingStrategy::RoundRobin);
        let a = mgr.partition(vec!["a", "b", "c", "d"]);
        assert_eq!(a.shard_of("a"), Some(0));
        assert_eq!(a.shard_of("b"), Some(1));
        assert_eq!(a.shard_of("missing"), None);
    }

    #[test]
    fn reshard_preserves_total_count_after_resize() {
        let mut mgr_small = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let small = mgr_small.partition(make_ids("e", 32));
        assert_eq!(small.total(), 32);

        let mut mgr_big = ModelShardManager::new(8, ShardingStrategy::EntityHash);
        let big = mgr_big.reshard(&small);
        assert_eq!(big.num_shards, 8);
        assert_eq!(big.total(), 32);
    }

    #[test]
    fn reshard_routes_each_id_to_its_new_shard() {
        let mut mgr2 = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let prior = mgr2.partition(make_ids("entity:", 50));

        let mut mgr5 = ModelShardManager::new(5, ShardingStrategy::EntityHash);
        let after = mgr5.reshard(&prior);

        // Every id must sit in the bucket that `shard_for` reports for it.
        for (i, bucket) in after.buckets.iter().enumerate() {
            for id in bucket {
                assert_eq!(mgr5.shard_for(id), i);
            }
        }
    }

    #[test]
    fn partition_stable_across_managers_with_same_shard_count() {
        let ids = make_ids("e_", 30);
        let mut a = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let mut b = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let pa = a.partition(ids.clone());
        let pb = b.partition(ids);
        assert_eq!(
            pa, pb,
            "two managers with same config must produce same shards"
        );
    }

    #[test]
    fn partition_unstable_when_shard_count_changes() {
        let ids = make_ids("k_", 50);
        let mut mgr2 = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let mut mgr4 = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let p2 = mgr2.partition(ids.iter().cloned());
        let p4 = mgr4.partition(ids.iter().cloned());
        assert_ne!(p2.num_shards, p4.num_shards);
        // Comparing shard counts alone is always true for 2 vs 4; also check
        // that at least one id really ends up on a different shard.
        assert!(
            ids.iter().any(|id| p2.shard_of(id) != p4.shard_of(id)),
            "changing the shard count must move at least one id"
        );
    }

    #[test]
    fn shard_assignment_serialization() {
        let mut mgr = ModelShardManager::new(3, ShardingStrategy::EntityHash);
        let a = mgr.partition(vec!["x", "y", "z"]);
        let json = serde_json::to_string(&a).expect("serialize");
        let a2: ShardAssignment = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(a, a2);
    }
}