phago_distributed/coordinator/
shard_registry.rs1use crate::types::{ShardId, ShardInfo, ShardStatus};
7use std::collections::HashMap;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10#[derive(Debug, Clone)]
12pub struct RegisteredShard {
13 pub info: ShardInfo,
15 pub status: ShardStatus,
17 pub memory_bytes: u64,
19}
20
21impl RegisteredShard {
22 pub fn new(info: ShardInfo) -> Self {
24 Self {
25 info,
26 status: ShardStatus::Online,
27 memory_bytes: 0,
28 }
29 }
30}
31
32pub struct ShardRegistry {
38 shards: HashMap<ShardId, RegisteredShard>,
40 next_id: u32,
42 heartbeat_timeout_ms: u64,
44}
45
46impl ShardRegistry {
47 pub fn new() -> Self {
49 Self {
50 shards: HashMap::new(),
51 next_id: 0,
52 heartbeat_timeout_ms: 30_000, }
54 }
55
56 pub fn with_heartbeat_timeout(timeout_ms: u64) -> Self {
58 Self {
59 shards: HashMap::new(),
60 next_id: 0,
61 heartbeat_timeout_ms: timeout_ms,
62 }
63 }
64
65 pub fn register(&mut self, info: ShardInfo) -> ShardId {
70 let id = ShardId::new(self.next_id);
71 self.next_id += 1;
72
73 let mut registered = RegisteredShard::new(info);
74 registered.info.id = id;
75 registered.info.last_heartbeat = Self::current_timestamp();
76 registered.status = ShardStatus::Online;
77
78 self.shards.insert(id, registered);
79 id
80 }
81
82 pub fn register_with_id(&mut self, info: ShardInfo, id: ShardId) -> ShardId {
87 let mut registered = RegisteredShard::new(info);
88 registered.info.id = id;
89 registered.info.last_heartbeat = Self::current_timestamp();
90 registered.status = ShardStatus::Online;
91
92 self.shards.insert(id, registered);
93
94 if id.0 >= self.next_id {
96 self.next_id = id.0 + 1;
97 }
98
99 id
100 }
101
102 pub fn get(&self, id: &ShardId) -> Option<&ShardInfo> {
104 self.shards.get(id).map(|r| &r.info)
105 }
106
107 pub fn get_registered(&self, id: &ShardId) -> Option<&RegisteredShard> {
109 self.shards.get(id)
110 }
111
112 pub fn get_registered_mut(&mut self, id: &ShardId) -> Option<&mut RegisteredShard> {
114 self.shards.get_mut(id)
115 }
116
117 pub fn remove(&mut self, id: &ShardId) -> Option<ShardInfo> {
119 self.shards.remove(id).map(|r| r.info)
120 }
121
122 pub fn all(&self) -> Vec<ShardInfo> {
124 self.shards.values().map(|r| r.info.clone()).collect()
125 }
126
127 pub fn all_ids(&self) -> Vec<ShardId> {
129 self.shards.keys().copied().collect()
130 }
131
132 pub fn count(&self) -> usize {
134 self.shards.len()
135 }
136
137 pub fn contains(&self, id: &ShardId) -> bool {
139 self.shards.contains_key(id)
140 }
141
142 pub fn heartbeat(&mut self, id: &ShardId) {
144 if let Some(registered) = self.shards.get_mut(id) {
145 registered.info.last_heartbeat = Self::current_timestamp();
146 if registered.status == ShardStatus::Recovering {
148 registered.status = ShardStatus::Online;
149 }
150 }
151 }
152
153 pub fn heartbeat_with_timestamp(&mut self, id: &ShardId, timestamp: u64) {
155 if let Some(registered) = self.shards.get_mut(id) {
156 registered.info.last_heartbeat = timestamp;
157 }
158 }
159
160 pub fn set_status(&mut self, id: &ShardId, status: ShardStatus) {
162 if let Some(registered) = self.shards.get_mut(id) {
163 registered.status = status;
164 }
165 }
166
167 pub fn get_status(&self, id: &ShardId) -> Option<ShardStatus> {
169 self.shards.get(id).map(|r| r.status)
170 }
171
172 pub fn update_metrics(&mut self, id: &ShardId, document_count: usize, memory_bytes: u64) {
174 if let Some(registered) = self.shards.get_mut(id) {
175 registered.info.document_count = document_count;
176 registered.memory_bytes = memory_bytes;
177 }
178 }
179
180 pub fn online_shards(&self) -> Vec<ShardInfo> {
182 self.shards
183 .values()
184 .filter(|r| r.status == ShardStatus::Online)
185 .map(|r| r.info.clone())
186 .collect()
187 }
188
189 pub fn shards_with_status(&self, status: ShardStatus) -> Vec<ShardInfo> {
191 self.shards
192 .values()
193 .filter(|r| r.status == status)
194 .map(|r| r.info.clone())
195 .collect()
196 }
197
198 pub fn check_dead_shards(&mut self) -> Vec<ShardId> {
202 let now = Self::current_timestamp();
203 let timeout = self.heartbeat_timeout_ms;
204 let mut dead_shards = Vec::new();
205
206 for (id, registered) in self.shards.iter_mut() {
207 if registered.status == ShardStatus::Online
208 && now - registered.info.last_heartbeat > timeout
209 {
210 registered.status = ShardStatus::Offline;
211 dead_shards.push(*id);
212 }
213 }
214
215 dead_shards
216 }
217
218 pub fn total_documents(&self) -> u64 {
220 self.shards
221 .values()
222 .map(|r| r.info.document_count as u64)
223 .sum()
224 }
225
226 pub fn total_memory(&self) -> u64 {
228 self.shards.values().map(|r| r.memory_bytes).sum()
229 }
230
231 pub fn least_loaded_shard(&self) -> Option<ShardId> {
233 self.shards
234 .values()
235 .filter(|r| r.status == ShardStatus::Online)
236 .min_by_key(|r| r.info.document_count)
237 .map(|r| r.info.id)
238 }
239
240 fn current_timestamp() -> u64 {
242 SystemTime::now()
243 .duration_since(UNIX_EPOCH)
244 .unwrap_or_default()
245 .as_millis() as u64
246 }
247}
248
249impl Default for ShardRegistry {
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// A placeholder descriptor; `register` overwrites its id.
    fn test_shard_info() -> ShardInfo {
        ShardInfo::new(ShardId::new(0), "127.0.0.1:8080".to_string())
    }

    #[test]
    fn test_registry_creation() {
        let registry = ShardRegistry::new();
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_register_shard() {
        let mut registry = ShardRegistry::new();

        let id = registry.register(test_shard_info());
        assert_eq!(id, ShardId::new(0));
        assert_eq!(registry.count(), 1);

        let id2 = registry.register(test_shard_info());
        assert_eq!(id2, ShardId::new(1));
        assert_eq!(registry.count(), 2);
    }

    #[test]
    fn test_get_shard() {
        let mut registry = ShardRegistry::new();
        let id = registry.register(test_shard_info());

        let retrieved = registry.get(&id).unwrap();
        assert_eq!(retrieved.id, id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));
    }

    #[test]
    fn test_remove_shard() {
        let mut registry = ShardRegistry::new();
        let id = registry.register(test_shard_info());

        assert!(registry.contains(&id));
        let removed = registry.remove(&id);
        assert!(removed.is_some());
        assert!(!registry.contains(&id));
    }

    #[test]
    fn test_set_status() {
        let mut registry = ShardRegistry::new();
        let id = registry.register(test_shard_info());

        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));

        registry.set_status(&id, ShardStatus::Draining);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Draining));
    }

    #[test]
    fn test_update_metrics() {
        let mut registry = ShardRegistry::new();
        let id = registry.register(test_shard_info());

        registry.update_metrics(&id, 100, 1024 * 1024);

        let shard = registry.get(&id).unwrap();
        assert_eq!(shard.document_count, 100);
        let registered = registry.get_registered(&id).unwrap();
        assert_eq!(registered.memory_bytes, 1024 * 1024);
    }

    #[test]
    fn test_online_shards() {
        let mut registry = ShardRegistry::new();

        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());
        let id3 = registry.register(test_shard_info());

        registry.set_status(&id2, ShardStatus::Offline);

        let online = registry.online_shards();
        assert_eq!(online.len(), 2);
        assert!(online.iter().all(|s| s.id != id2));
        assert!(online.iter().any(|s| s.id == id1));
        assert!(online.iter().any(|s| s.id == id3));
    }

    #[test]
    fn test_total_documents() {
        let mut registry = ShardRegistry::new();

        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());

        registry.update_metrics(&id1, 100, 1000);
        registry.update_metrics(&id2, 200, 2000);

        assert_eq!(registry.total_documents(), 300);
        assert_eq!(registry.total_memory(), 3000);
    }

    #[test]
    fn test_least_loaded_shard() {
        let mut registry = ShardRegistry::new();

        let id1 = registry.register(test_shard_info());
        let id2 = registry.register(test_shard_info());
        let id3 = registry.register(test_shard_info());

        registry.update_metrics(&id1, 100, 1000);
        registry.update_metrics(&id2, 50, 500);
        registry.update_metrics(&id3, 200, 2000);

        assert_eq!(registry.least_loaded_shard(), Some(id2));
    }

    #[test]
    fn test_least_loaded_shard_skips_non_online() {
        let mut registry = ShardRegistry::new();
        assert_eq!(registry.least_loaded_shard(), None);

        let busy = registry.register(test_shard_info());
        let idle = registry.register(test_shard_info());
        registry.update_metrics(&busy, 500, 0);
        registry.update_metrics(&idle, 1, 0);
        registry.set_status(&idle, ShardStatus::Offline);

        assert_eq!(registry.least_loaded_shard(), Some(busy));
    }

    #[test]
    fn test_check_dead_shards() {
        let mut registry = ShardRegistry::with_heartbeat_timeout(100);
        let id = registry.register(test_shard_info());

        // Backdate the heartbeat far beyond the 100 ms timeout.
        registry.heartbeat_with_timestamp(&id, 0);

        let dead = registry.check_dead_shards();
        assert_eq!(dead.len(), 1);
        assert_eq!(dead[0], id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Offline));
    }

    #[test]
    fn test_check_dead_shards_ignores_non_online() {
        let mut registry = ShardRegistry::with_heartbeat_timeout(100);
        let id = registry.register(test_shard_info());
        registry.set_status(&id, ShardStatus::Draining);
        registry.heartbeat_with_timestamp(&id, 0);

        assert!(registry.check_dead_shards().is_empty());
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Draining));
    }

    #[test]
    fn test_heartbeat_revives_recovering_shard() {
        let mut registry = ShardRegistry::new();
        let id = registry.register(test_shard_info());
        registry.set_status(&id, ShardStatus::Recovering);

        registry.heartbeat(&id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Online));
    }

    #[test]
    fn test_heartbeat_keeps_offline_shard_offline() {
        let mut registry = ShardRegistry::new();
        let id = registry.register(test_shard_info());
        registry.set_status(&id, ShardStatus::Offline);

        registry.heartbeat(&id);
        assert_eq!(registry.get_status(&id), Some(ShardStatus::Offline));
    }

    #[test]
    fn test_register_with_specific_id() {
        let mut registry = ShardRegistry::new();

        let id = registry.register_with_id(test_shard_info(), ShardId::new(42));
        assert_eq!(id, ShardId::new(42));
        assert!(registry.contains(&id));

        // Auto-assignment continues after the explicitly chosen id.
        let id2 = registry.register(test_shard_info());
        assert_eq!(id2, ShardId::new(43));
    }

    #[test]
    fn test_register_with_lower_id_keeps_counter() {
        let mut registry = ShardRegistry::new();
        registry.register_with_id(test_shard_info(), ShardId::new(10));
        registry.register_with_id(test_shard_info(), ShardId::new(3));

        assert_eq!(registry.register(test_shard_info()), ShardId::new(11));
        assert_eq!(registry.count(), 3);
    }

    #[test]
    fn test_unknown_shard_operations_are_noops() {
        let mut registry = ShardRegistry::new();
        let ghost = ShardId::new(99);

        registry.heartbeat(&ghost);
        registry.heartbeat_with_timestamp(&ghost, 0);
        registry.set_status(&ghost, ShardStatus::Offline);
        registry.update_metrics(&ghost, 1, 1);

        assert!(registry.remove(&ghost).is_none());
        assert_eq!(registry.get_status(&ghost), None);
        assert_eq!(registry.count(), 0);
    }
}