Skip to main content

oxirs_embed/distributed_training/
shard_manager.rs

1//! Model shard manager: partitions an embedding table by entity-ID hash.
2//!
3//! [`ModelShardManager`] hashes a stable string ID (typically an entity URI or
4//! relation IRI) to one of `num_shards` parameter-server shards.  The mapping is
5//! deterministic and stateless: callers can reliably route reads and gradient
6//! pushes for a given entity to the same shard across processes, without any
7//! coordination, provided every party uses the same `num_shards`.
8//!
9//! The hashing is intentionally based on `std::collections::hash_map::DefaultHasher`
10//! seeded with a fixed key — this gives bit-identical results between Rust
11//! processes built from the same compiler version, which is good enough for
12//! the in-process prototype we ship here.  For a real production system you
13//! would want a cryptographic hash (e.g. SHA-256) or a known fingerprint (e.g.
14//! xxHash) so that the partitioning is stable across language ecosystems too.
15//!
16//! ```
17//! use oxirs_embed::distributed_training::{ModelShardManager, ShardingStrategy};
18//!
19//! let mgr = ModelShardManager::new(4, ShardingStrategy::EntityHash);
20//! let s_alice = mgr.shard_for("http://example.org/Alice");
21//! let s_alice2 = mgr.shard_for("http://example.org/Alice");
22//! assert_eq!(s_alice, s_alice2);                  // deterministic
23//! assert!(s_alice < 4);                           // bounded by num_shards
24//! ```
25
26use serde::{Deserialize, Serialize};
27use std::collections::hash_map::DefaultHasher;
28use std::collections::HashMap;
29use std::hash::{Hash, Hasher};
30
31/// How a [`ModelShardManager`] decides which shard owns a given entity.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
33pub enum ShardingStrategy {
34    /// Hash the entity ID with a fixed seed and take `hash mod num_shards`.
35    ///
36    /// The default — gives uniformly distributed shards under the assumption
37    /// that the underlying hash is well-mixed.
38    #[default]
39    EntityHash,
40    /// Round-robin assignment in **lexicographic** order of insertion.
41    ///
42    /// Used by tests that need a known mapping; not recommended in production
43    /// because shard load depends on insertion order.
44    RoundRobin,
45}
46
47/// Result of running [`ModelShardManager::partition`] over a known set of IDs.
48///
49/// A [`ShardAssignment`] is a vector of `num_shards` buckets, where each bucket
50/// holds the IDs that map to that shard.  Bucket ordering follows the input
51/// order so the assignment is stable for a given input + `num_shards` pair.
52#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
53pub struct ShardAssignment {
54    /// Number of shards (length of `buckets`).
55    pub num_shards: usize,
56    /// `buckets[i]` is the list of IDs that map to shard `i`.
57    pub buckets: Vec<Vec<String>>,
58}
59
60impl ShardAssignment {
61    /// Total number of IDs across all shards.
62    pub fn total(&self) -> usize {
63        self.buckets.iter().map(|b| b.len()).sum()
64    }
65
66    /// Look up which shard an ID was assigned to, or `None` if it isn't present.
67    pub fn shard_of(&self, id: &str) -> Option<usize> {
68        self.buckets.iter().enumerate().find_map(|(i, b)| {
69            if b.iter().any(|s| s == id) {
70                Some(i)
71            } else {
72                None
73            }
74        })
75    }
76}
77
78/// Partitions embedding tables (entity *and* relation) by stable hash of the ID.
79#[derive(Debug, Clone)]
80pub struct ModelShardManager {
81    num_shards: usize,
82    strategy: ShardingStrategy,
83    /// Stable round-robin index, used only when [`ShardingStrategy::RoundRobin`]
84    /// is selected.  Filled lazily via [`Self::partition`].
85    rr_index: HashMap<String, usize>,
86}
87
88impl ModelShardManager {
89    /// Create a new shard manager with `num_shards` shards.
90    ///
91    /// # Panics
92    ///
93    /// Does not panic.  If `num_shards == 0` it is silently coerced to `1`
94    /// (single-shard / no sharding) so the manager is always usable.
95    pub fn new(num_shards: usize, strategy: ShardingStrategy) -> Self {
96        Self {
97            num_shards: num_shards.max(1),
98            strategy,
99            rr_index: HashMap::new(),
100        }
101    }
102
103    /// Number of shards this manager partitions across.
104    pub fn num_shards(&self) -> usize {
105        self.num_shards
106    }
107
108    /// Hashing strategy.
109    pub fn strategy(&self) -> ShardingStrategy {
110        self.strategy
111    }
112
113    /// Compute the shard index for `id`.
114    ///
115    /// Always returns a value in `0..self.num_shards()`.  Calling this method
116    /// repeatedly with the same `id` and same `num_shards` always yields the
117    /// same answer — it is the basis of the parameter-server routing scheme.
118    pub fn shard_for(&self, id: &str) -> usize {
119        match self.strategy {
120            ShardingStrategy::EntityHash => self.hash_shard(id),
121            ShardingStrategy::RoundRobin => {
122                // Read-only round-robin: fall back to hash if we haven't seen
123                // this ID yet via `partition()`.  Avoids surprising callers
124                // that mix `partition` / `shard_for`.
125                self.rr_index
126                    .get(id)
127                    .copied()
128                    .unwrap_or_else(|| self.hash_shard(id))
129            }
130        }
131    }
132
133    /// Partition a list of IDs into shards in input order.
134    ///
135    /// In `EntityHash` mode this is equivalent to `shard_for` on each ID.  In
136    /// `RoundRobin` mode it also populates the manager's internal round-robin
137    /// table so that subsequent `shard_for` calls give consistent answers.
138    pub fn partition<I, S>(&mut self, ids: I) -> ShardAssignment
139    where
140        I: IntoIterator<Item = S>,
141        S: Into<String>,
142    {
143        let mut buckets: Vec<Vec<String>> = (0..self.num_shards).map(|_| Vec::new()).collect();
144
145        match self.strategy {
146            ShardingStrategy::EntityHash => {
147                for raw in ids {
148                    let id: String = raw.into();
149                    let shard = self.hash_shard(&id);
150                    buckets[shard].push(id);
151                }
152            }
153            ShardingStrategy::RoundRobin => {
154                let mut next: usize = 0;
155                for raw in ids {
156                    let id: String = raw.into();
157                    let shard = *self.rr_index.entry(id.clone()).or_insert_with(|| {
158                        let s = next % self.num_shards;
159                        next += 1;
160                        s
161                    });
162                    buckets[shard].push(id);
163                }
164            }
165        }
166
167        ShardAssignment {
168            num_shards: self.num_shards,
169            buckets,
170        }
171    }
172
173    /// Re-shard an existing assignment after `num_shards` changes.
174    ///
175    /// Returns a fresh [`ShardAssignment`] computed by routing every existing
176    /// ID through the new manager.  Used by elastic scaling tests.
177    pub fn reshard(&mut self, prior: &ShardAssignment) -> ShardAssignment {
178        // Flatten in deterministic order, then partition again.
179        let flat: Vec<String> = prior.buckets.iter().flatten().cloned().collect();
180        self.partition(flat)
181    }
182
183    fn hash_shard(&self, id: &str) -> usize {
184        // Mix in a constant seed so two managers with the same num_shards
185        // hash identical IDs to identical buckets, even across processes
186        // (provided they are built with the same Rust toolchain version).
187        const SEED: u64 = 0x517c_c1b7_2722_0a95;
188        let mut h = DefaultHasher::new();
189        SEED.hash(&mut h);
190        id.hash(&mut h);
191        (h.finish() as usize) % self.num_shards
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `count` IDs of the form `<prefix><index>`.
    fn make_ids(prefix: &str, count: usize) -> Vec<String> {
        (0..count).map(|i| format!("{prefix}{i}")).collect()
    }

    #[test]
    fn shard_manager_default_strategy_is_entity_hash() {
        let mgr = ModelShardManager::new(4, ShardingStrategy::default());
        assert_eq!(mgr.strategy(), ShardingStrategy::EntityHash);
        assert_eq!(mgr.num_shards(), 4);
    }

    #[test]
    fn shard_for_deterministic_and_bounded() {
        let mgr = ModelShardManager::new(8, ShardingStrategy::EntityHash);
        let id = "http://example.org/Alice";
        let first = mgr.shard_for(id);
        let second = mgr.shard_for(id);
        assert_eq!(first, second, "shard_for must be deterministic");
        assert!(first < 8, "shard index must be bounded by num_shards");
    }

    #[test]
    fn shard_for_zero_shards_coerced_to_one() {
        let mgr = ModelShardManager::new(0, ShardingStrategy::EntityHash);
        assert_eq!(mgr.num_shards(), 1);
        assert_eq!(mgr.shard_for("anything"), 0);
    }

    #[test]
    fn partition_roundrobin_buckets_evenly() {
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::RoundRobin);
        let assignment = mgr.partition(make_ids("e", 16));
        assert_eq!(assignment.num_shards, 4);
        assert_eq!(assignment.total(), 16);
        for bucket in &assignment.buckets {
            assert_eq!(bucket.len(), 4);
        }
    }

    #[test]
    fn partition_hash_total_equals_input_size() {
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let assignment = mgr.partition(make_ids("entity_", 100));
        assert_eq!(assignment.total(), 100);
    }

    #[test]
    fn partition_hash_distributes_across_shards() {
        // With 200 IDs across 4 shards, every shard should see at least one.
        let mut mgr = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let assignment = mgr.partition(make_ids("entity_", 200));
        for (i, bucket) in assignment.buckets.iter().enumerate() {
            assert!(
                !bucket.is_empty(),
                "shard {i} got no entities — distribution failed"
            );
        }
    }

    #[test]
    fn shard_assignment_shard_of_lookup() {
        let mut mgr = ModelShardManager::new(2, ShardingStrategy::RoundRobin);
        let assignment = mgr.partition(vec!["a", "b", "c", "d"]);
        assert_eq!(assignment.shard_of("a"), Some(0));
        assert_eq!(assignment.shard_of("b"), Some(1));
        // The rotation wraps after `num_shards` IDs.
        assert_eq!(assignment.shard_of("c"), Some(0));
        assert_eq!(assignment.shard_of("d"), Some(1));
        assert_eq!(assignment.shard_of("missing"), None);
    }

    #[test]
    fn reshard_preserves_total_count_after_resize() {
        let mut mgr_small = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let small = mgr_small.partition(make_ids("e", 32));
        assert_eq!(small.total(), 32);

        let mut mgr_big = ModelShardManager::new(8, ShardingStrategy::EntityHash);
        let big = mgr_big.reshard(&small);
        assert_eq!(big.num_shards, 8);
        assert_eq!(big.total(), 32);
    }

    #[test]
    fn reshard_routes_each_id_to_its_new_shard() {
        let mut mgr2 = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let prior = mgr2.partition(make_ids("entity:", 50));

        let mut mgr5 = ModelShardManager::new(5, ShardingStrategy::EntityHash);
        let after = mgr5.reshard(&prior);

        // Every ID must end up exactly where mgr5.shard_for(id) says.
        for (i, bucket) in after.buckets.iter().enumerate() {
            for id in bucket {
                assert_eq!(mgr5.shard_for(id), i);
            }
        }
    }

    #[test]
    fn partition_stable_across_managers_with_same_shard_count() {
        let ids = make_ids("e_", 30);
        let mut a = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let mut b = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let pa = a.partition(ids.clone());
        let pb = b.partition(ids);
        assert_eq!(
            pa, pb,
            "two managers with same config must produce same shards"
        );
    }

    #[test]
    fn partition_unstable_when_shard_count_changes() {
        // Sanity: changing num_shards should change the partition (otherwise
        // hash_shard is broken or num_shards is being ignored).
        let ids = make_ids("k_", 50);
        let mut mgr2 = ModelShardManager::new(2, ShardingStrategy::EntityHash);
        let mut mgr4 = ModelShardManager::new(4, ShardingStrategy::EntityHash);
        let p2 = mgr2.partition(ids.clone());
        let p4 = mgr4.partition(ids.clone());
        assert_ne!(p2.num_shards, p4.num_shards);
        // Comparing shard counts alone is trivially true; also require that
        // at least one ID actually lands on a different shard.
        assert!(
            ids.iter().any(|id| p2.shard_of(id) != p4.shard_of(id)),
            "no ID moved when num_shards changed from 2 to 4"
        );
    }

    #[test]
    fn shard_assignment_serialization() {
        let mut mgr = ModelShardManager::new(3, ShardingStrategy::EntityHash);
        let original = mgr.partition(vec!["x", "y", "z"]);
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: ShardAssignment = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(original, restored);
    }
}
329}