Skip to main content

phago_distributed/coordinator/
shard_registry.rs

//! Shard registry for tracking active shards.
//!
//! This module maintains a registry of all active shards in the distributed
//! cluster, including their status, heartbeat information, and metrics.
5
6use crate::types::{ShardId, ShardInfo, ShardStatus};
7use std::collections::HashMap;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10/// Extended shard info with status tracking.
11#[derive(Debug, Clone)]
12pub struct RegisteredShard {
13    /// Basic shard information.
14    pub info: ShardInfo,
15    /// Current status of the shard.
16    pub status: ShardStatus,
17    /// Memory usage in bytes.
18    pub memory_bytes: u64,
19}
20
21impl RegisteredShard {
22    /// Create a new registered shard from info.
23    pub fn new(info: ShardInfo) -> Self {
24        Self {
25            info,
26            status: ShardStatus::Online,
27            memory_bytes: 0,
28        }
29    }
30}
31
32/// Registry of active shards in the cluster.
33///
34/// The registry tracks all shards, their current status, and health metrics.
35/// It is used by the coordinator to manage the cluster topology and route
36/// requests to healthy shards.
37pub struct ShardRegistry {
38    /// Map of shard IDs to their information.
39    shards: HashMap<ShardId, RegisteredShard>,
40    /// Counter for assigning new shard IDs.
41    next_id: u32,
42    /// Timeout for considering a shard dead (milliseconds).
43    heartbeat_timeout_ms: u64,
44}
45
46impl ShardRegistry {
47    /// Create a new empty shard registry.
48    pub fn new() -> Self {
49        Self {
50            shards: HashMap::new(),
51            next_id: 0,
52            heartbeat_timeout_ms: 30_000, // 30 seconds default
53        }
54    }
55
56    /// Create a registry with custom heartbeat timeout.
57    pub fn with_heartbeat_timeout(timeout_ms: u64) -> Self {
58        Self {
59            shards: HashMap::new(),
60            next_id: 0,
61            heartbeat_timeout_ms: timeout_ms,
62        }
63    }
64
65    /// Register a new shard and return its assigned ID.
66    ///
67    /// The shard will be assigned a unique ID and added to the registry.
68    /// Its initial status will be set to Online.
69    pub fn register(&mut self, info: ShardInfo) -> ShardId {
70        let id = ShardId::new(self.next_id);
71        self.next_id += 1;
72
73        let mut registered = RegisteredShard::new(info);
74        registered.info.id = id;
75        registered.info.last_heartbeat = Self::current_timestamp();
76        registered.status = ShardStatus::Online;
77
78        self.shards.insert(id, registered);
79        id
80    }
81
82    /// Register a shard with a specific ID.
83    ///
84    /// This is useful when restoring state or in deterministic testing.
85    /// The next_id counter will be updated if necessary.
86    pub fn register_with_id(&mut self, info: ShardInfo, id: ShardId) -> ShardId {
87        let mut registered = RegisteredShard::new(info);
88        registered.info.id = id;
89        registered.info.last_heartbeat = Self::current_timestamp();
90        registered.status = ShardStatus::Online;
91
92        self.shards.insert(id, registered);
93
94        // Update next_id to avoid conflicts
95        if id.0 >= self.next_id {
96            self.next_id = id.0 + 1;
97        }
98
99        id
100    }
101
102    /// Get shard info by ID.
103    pub fn get(&self, id: &ShardId) -> Option<&ShardInfo> {
104        self.shards.get(id).map(|r| &r.info)
105    }
106
107    /// Get registered shard by ID (includes status).
108    pub fn get_registered(&self, id: &ShardId) -> Option<&RegisteredShard> {
109        self.shards.get(id)
110    }
111
112    /// Get mutable registered shard by ID.
113    pub fn get_registered_mut(&mut self, id: &ShardId) -> Option<&mut RegisteredShard> {
114        self.shards.get_mut(id)
115    }
116
117    /// Remove a shard from the registry.
118    pub fn remove(&mut self, id: &ShardId) -> Option<ShardInfo> {
119        self.shards.remove(id).map(|r| r.info)
120    }
121
122    /// Get all shard infos.
123    pub fn all(&self) -> Vec<ShardInfo> {
124        self.shards.values().map(|r| r.info.clone()).collect()
125    }
126
127    /// Get all shard IDs.
128    pub fn all_ids(&self) -> Vec<ShardId> {
129        self.shards.keys().copied().collect()
130    }
131
132    /// Get the number of registered shards.
133    pub fn count(&self) -> usize {
134        self.shards.len()
135    }
136
137    /// Check if a shard exists.
138    pub fn contains(&self, id: &ShardId) -> bool {
139        self.shards.contains_key(id)
140    }
141
142    /// Update heartbeat timestamp for a shard.
143    pub fn heartbeat(&mut self, id: &ShardId) {
144        if let Some(registered) = self.shards.get_mut(id) {
145            registered.info.last_heartbeat = Self::current_timestamp();
146            // Restore online status if it was marked as recovering
147            if registered.status == ShardStatus::Recovering {
148                registered.status = ShardStatus::Online;
149            }
150        }
151    }
152
153    /// Update heartbeat with explicit timestamp (for testing or remote sync).
154    pub fn heartbeat_with_timestamp(&mut self, id: &ShardId, timestamp: u64) {
155        if let Some(registered) = self.shards.get_mut(id) {
156            registered.info.last_heartbeat = timestamp;
157        }
158    }
159
160    /// Update shard status.
161    pub fn set_status(&mut self, id: &ShardId, status: ShardStatus) {
162        if let Some(registered) = self.shards.get_mut(id) {
163            registered.status = status;
164        }
165    }
166
167    /// Get the status of a shard.
168    pub fn get_status(&self, id: &ShardId) -> Option<ShardStatus> {
169        self.shards.get(id).map(|r| r.status)
170    }
171
172    /// Update shard metrics.
173    pub fn update_metrics(&mut self, id: &ShardId, document_count: usize, memory_bytes: u64) {
174        if let Some(registered) = self.shards.get_mut(id) {
175            registered.info.document_count = document_count;
176            registered.memory_bytes = memory_bytes;
177        }
178    }
179
180    /// Get all online shards.
181    pub fn online_shards(&self) -> Vec<ShardInfo> {
182        self.shards
183            .values()
184            .filter(|r| r.status == ShardStatus::Online)
185            .map(|r| r.info.clone())
186            .collect()
187    }
188
189    /// Get all shards with a specific status.
190    pub fn shards_with_status(&self, status: ShardStatus) -> Vec<ShardInfo> {
191        self.shards
192            .values()
193            .filter(|r| r.status == status)
194            .map(|r| r.info.clone())
195            .collect()
196    }
197
198    /// Check for and mark dead shards based on heartbeat timeout.
199    ///
200    /// Returns the IDs of shards that were marked as offline.
201    pub fn check_dead_shards(&mut self) -> Vec<ShardId> {
202        let now = Self::current_timestamp();
203        let timeout = self.heartbeat_timeout_ms;
204        let mut dead_shards = Vec::new();
205
206        for (id, registered) in self.shards.iter_mut() {
207            if registered.status == ShardStatus::Online
208                && now - registered.info.last_heartbeat > timeout
209            {
210                registered.status = ShardStatus::Offline;
211                dead_shards.push(*id);
212            }
213        }
214
215        dead_shards
216    }
217
218    /// Get total document count across all shards.
219    pub fn total_documents(&self) -> u64 {
220        self.shards
221            .values()
222            .map(|r| r.info.document_count as u64)
223            .sum()
224    }
225
226    /// Get total memory usage across all shards.
227    pub fn total_memory(&self) -> u64 {
228        self.shards.values().map(|r| r.memory_bytes).sum()
229    }
230
231    /// Get the shard with the least documents (for load balancing).
232    pub fn least_loaded_shard(&self) -> Option<ShardId> {
233        self.shards
234            .values()
235            .filter(|r| r.status == ShardStatus::Online)
236            .min_by_key(|r| r.info.document_count)
237            .map(|r| r.info.id)
238    }
239
240    /// Get current Unix timestamp in milliseconds.
241    fn current_timestamp() -> u64 {
242        SystemTime::now()
243            .duration_since(UNIX_EPOCH)
244            .unwrap_or_default()
245            .as_millis() as u64
246    }
247}
248
249impl Default for ShardRegistry {
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a throwaway shard description; the registry overwrites its ID on registration.
    fn test_shard_info() -> ShardInfo {
        ShardInfo::new(ShardId::new(0), "127.0.0.1:8080".to_string())
    }

    #[test]
    fn test_registry_creation() {
        let registry = ShardRegistry::new();
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_register_shard() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();

        let id = registry.register(info);
        assert_eq!(id, ShardId::new(0));
        assert_eq!(registry.count(), 1);

        let info2 = test_shard_info();
        let id2 = registry.register(info2);
        assert_eq!(id2, ShardId::new(1));
        assert_eq!(registry.count(), 2);
    }

    #[test]
    fn test_get_shard() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);

        let retrieved = registry.get(&id).unwrap();
        assert_eq!(retrieved.id, id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));
    }

    #[test]
    fn test_remove_shard() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);

        assert!(registry.contains(&id));
        let removed = registry.remove(&id);
        assert!(removed.is_some());
        assert!(!registry.contains(&id));
    }

    #[test]
    fn test_set_status() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);

        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));

        registry.set_status(&id, ShardStatus::Draining);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Draining));
    }

    #[test]
    fn test_update_metrics() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();
        let id = registry.register(info);

        registry.update_metrics(&id, 100, 1024 * 1024);

        let shard = registry.get(&id).unwrap();
        assert_eq!(shard.document_count, 100);
        let registered = registry.get_registered(&id).unwrap();
        assert_eq!(registered.memory_bytes, 1024 * 1024);
    }

    #[test]
    fn test_online_shards() {
        let mut registry = ShardRegistry::new();

        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());
        let id3 = registry.register(test_shard_info());

        registry.set_status(&id2, ShardStatus::Offline);

        let online = registry.online_shards();
        assert_eq!(online.len(), 2);
        assert!(online.iter().all(|s| s.id != id2));
        // The shards that stayed online must both be reported.
        assert!(online.iter().any(|s| s.id == id1));
        assert!(online.iter().any(|s| s.id == id3));
    }

    #[test]
    fn test_shards_with_status() {
        let mut registry = ShardRegistry::new();

        let id1 = registry.register(test_shard_info());
        let _id2 = registry.register(test_shard_info());

        registry.set_status(&id1, ShardStatus::Draining);

        let draining = registry.shards_with_status(ShardStatus::Draining);
        assert_eq!(draining.len(), 1);
        assert_eq!(draining[0].id, id1);
        assert!(registry.shards_with_status(ShardStatus::Offline).is_empty());
    }

    #[test]
    fn test_total_documents() {
        let mut registry = ShardRegistry::new();

        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());

        registry.update_metrics(&id1, 100, 1000);
        registry.update_metrics(&id2, 200, 2000);

        assert_eq!(registry.total_documents(), 300);
        assert_eq!(registry.total_memory(), 3000);
    }

    #[test]
    fn test_least_loaded_shard() {
        let mut registry = ShardRegistry::new();

        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());
        let id3 = registry.register(test_shard_info());

        registry.update_metrics(&id1, 100, 1000);
        registry.update_metrics(&id2, 50, 500);
        registry.update_metrics(&id3, 200, 2000);

        assert_eq!(registry.least_loaded_shard(), Some(id2));
    }

    #[test]
    fn test_heartbeat_restores_recovering_only() {
        let mut registry = ShardRegistry::new();
        let id = registry.register(test_shard_info());

        registry.set_status(&id, ShardStatus::Recovering);
        registry.heartbeat(&id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));

        // A heartbeat alone does not bring an offline shard back.
        registry.set_status(&id, ShardStatus::Offline);
        registry.heartbeat(&id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Offline));
    }

    #[test]
    fn test_check_dead_shards() {
        let mut registry = ShardRegistry::with_heartbeat_timeout(100);
        let info = test_shard_info();
        let id = registry.register(info);

        // Set heartbeat to a very old timestamp
        registry.heartbeat_with_timestamp(&id, 0);

        let dead = registry.check_dead_shards();
        assert_eq!(dead.len(), 1);
        assert_eq!(dead[0], id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Offline));
    }

    #[test]
    fn test_check_dead_shards_keeps_fresh_shard() {
        // One-hour timeout: a shard registered just now cannot have expired.
        let mut registry = ShardRegistry::with_heartbeat_timeout(3_600_000);
        let id = registry.register(test_shard_info());

        assert!(registry.check_dead_shards().is_empty());
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));
    }

    #[test]
    fn test_register_with_specific_id() {
        let mut registry = ShardRegistry::new();
        let info = test_shard_info();

        let id = registry.register_with_id(info, ShardId::new(42));
        assert_eq!(id, ShardId::new(42));
        assert!(registry.contains(&id));

        // Next auto-assigned ID should be 43
        let info2 = test_shard_info();
        let id2 = registry.register(info2);
        assert_eq!(id2, ShardId::new(43));
    }
}