phago_distributed/coordinator/
mod.rs1mod shard_registry;
11mod tick_barrier;
12
13pub use shard_registry::{RegisteredShard, ShardRegistry};
14pub use tick_barrier::TickBarrier;
15
16use crate::hashing::ConsistentHashRing;
17use crate::types::*;
18use phago_core::types::{DocumentId, Tick};
19use std::collections::HashMap;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::sync::Arc;
22use tokio::sync::RwLock;
23
24pub struct Coordinator {
55 shards: Arc<RwLock<ShardRegistry>>,
57 current_tick: Arc<AtomicU64>,
59 barrier: Arc<TickBarrier>,
61 hash_ring: Arc<RwLock<ConsistentHashRing>>,
63 config: DistributedConfig,
65}
66
67impl Coordinator {
68 pub fn new(num_shards: u32) -> Self {
74 Self {
75 shards: Arc::new(RwLock::new(ShardRegistry::new())),
76 current_tick: Arc::new(AtomicU64::new(0)),
77 barrier: Arc::new(TickBarrier::new(num_shards as usize)),
78 hash_ring: Arc::new(RwLock::new(ConsistentHashRing::new(num_shards))),
79 config: DistributedConfig {
80 num_shards,
81 ..Default::default()
82 },
83 }
84 }
85
86 pub fn with_config(config: DistributedConfig) -> Self {
88 let num_shards = config.num_shards;
89 Self {
90 shards: Arc::new(RwLock::new(ShardRegistry::new())),
91 current_tick: Arc::new(AtomicU64::new(0)),
92 barrier: Arc::new(TickBarrier::new(num_shards as usize)),
93 hash_ring: Arc::new(RwLock::new(ConsistentHashRing::with_virtual_nodes(
94 num_shards,
95 config.virtual_nodes_per_shard,
96 ))),
97 config,
98 }
99 }
100
101 pub async fn register_shard(&self, info: ShardInfo) -> DistributedResult<ShardId> {
114 let mut registry = self.shards.write().await;
115 let shard_id = registry.register(info);
116
117 let mut ring = self.hash_ring.write().await;
119 ring.add_shard(shard_id);
120
121 self.barrier.set_shard_count(registry.count()).await;
123
124 Ok(shard_id)
125 }
126
127 pub async fn deregister_shard(&self, shard_id: ShardId) -> DistributedResult<()> {
132 let mut registry = self.shards.write().await;
133
134 if registry.remove(&shard_id).is_none() {
135 return Err(DistributedError::ShardNotFound(shard_id));
136 }
137
138 let mut ring = self.hash_ring.write().await;
140 ring.remove_shard(shard_id);
141
142 self.barrier.set_shard_count(registry.count()).await;
144
145 Ok(())
146 }
147
148 pub async fn route_document(&self, doc_id: &DocumentId) -> ShardId {
154 let ring = self.hash_ring.read().await;
155 ring.get_shard(doc_id)
156 }
157
158 pub async fn get_replica_shards(&self, doc_id: &DocumentId) -> Vec<ShardId> {
163 let ring = self.hash_ring.read().await;
164 ring.get_replica_shards(doc_id, self.config.replication_factor as usize)
165 }
166
167 pub async fn phase_complete(
173 &self,
174 shard_id: ShardId,
175 phase: TickPhase,
176 tick: Tick,
177 ) -> DistributedResult<()> {
178 self.barrier.complete(shard_id, phase, tick).await
179 }
180
181 pub async fn wait_for_phase(&self, phase: TickPhase, tick: Tick) -> DistributedResult<()> {
186 self.barrier.wait_all(phase, tick).await
187 }
188
189 pub async fn advance_tick(&self) -> Tick {
194 let new_tick = self.current_tick.fetch_add(1, Ordering::SeqCst) + 1;
195 self.barrier.reset_for_tick(new_tick).await;
196 new_tick
197 }
198
199 pub fn current_tick(&self) -> Tick {
201 self.current_tick.load(Ordering::SeqCst)
202 }
203
204 pub fn aggregate_global_df(
218 &self,
219 local_dfs: Vec<HashMap<String, u64>>,
220 ) -> HashMap<String, u64> {
221 let mut global_df = HashMap::new();
222 for local in local_dfs {
223 for (term, count) in local {
224 *global_df.entry(term).or_insert(0) += count;
225 }
226 }
227 global_df
228 }
229
230 pub async fn all_shards(&self) -> Vec<ShardInfo> {
232 let registry = self.shards.read().await;
233 registry.all()
234 }
235
236 pub async fn online_shards(&self) -> Vec<ShardInfo> {
238 let registry = self.shards.read().await;
239 registry.online_shards()
240 }
241
242 pub async fn get_shard(&self, shard_id: ShardId) -> Option<ShardInfo> {
244 let registry = self.shards.read().await;
245 registry.get(&shard_id).cloned()
246 }
247
248 pub async fn shard_heartbeat(&self, shard_id: ShardId) {
252 let mut registry = self.shards.write().await;
253 registry.heartbeat(&shard_id);
254 }
255
256 pub async fn check_shard_health(&self) -> Vec<ShardId> {
260 let mut registry = self.shards.write().await;
261 registry.check_dead_shards()
262 }
263
264 pub async fn update_shard_metrics(
266 &self,
267 shard_id: ShardId,
268 document_count: usize,
269 memory_bytes: u64,
270 ) {
271 let mut registry = self.shards.write().await;
272 registry.update_metrics(&shard_id, document_count, memory_bytes);
273 }
274
275 pub async fn total_documents(&self) -> u64 {
277 let registry = self.shards.read().await;
278 registry.total_documents()
279 }
280
281 pub async fn cluster_stats(&self) -> ClusterStats {
283 let registry = self.shards.read().await;
284 ClusterStats {
285 total_shards: registry.count() as u32,
286 online_shards: registry.online_shards().len() as u32,
287 total_documents: registry.total_documents(),
288 total_memory_bytes: registry.total_memory(),
289 current_tick: self.current_tick(),
290 }
291 }
292
293 pub fn config(&self) -> &DistributedConfig {
295 &self.config
296 }
297
298 pub async fn shard_count(&self) -> u32 {
300 let ring = self.hash_ring.read().await;
301 ring.shard_count()
302 }
303}
304
305#[derive(Debug, Clone)]
307pub struct ClusterStats {
308 pub total_shards: u32,
310 pub online_shards: u32,
312 pub total_documents: u64,
314 pub total_memory_bytes: u64,
316 pub current_tick: Tick,
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shard description. The embedded id is always 0; the registry
    /// assigns the real id on registration.
    fn test_shard_info() -> ShardInfo {
        ShardInfo::new(ShardId::new(0), "127.0.0.1:8080".to_string())
    }

    #[tokio::test]
    async fn test_coordinator_creation() {
        let coord = Coordinator::new(3);
        assert_eq!(coord.current_tick(), 0);
        // The ring is pre-populated with the configured shard count.
        assert_eq!(coord.shard_count().await, 3);
    }

    #[tokio::test]
    async fn test_register_shard() {
        let coord = Coordinator::new(3);
        let info = test_shard_info();

        let shard_id = coord.register_shard(info).await.unwrap();
        assert_eq!(shard_id, ShardId::new(0));

        let shards = coord.all_shards().await;
        assert_eq!(shards.len(), 1);
    }

    #[tokio::test]
    async fn test_document_routing() {
        let coord = Coordinator::new(3);

        // Register one shard per ring slot.
        for _ in 0..3 {
            coord.register_shard(test_shard_info()).await.unwrap();
        }

        let doc_id = DocumentId::from_seed(42);

        // Routing the same document twice must be deterministic.
        let shard1 = coord.route_document(&doc_id).await;
        let shard2 = coord.route_document(&doc_id).await;
        assert_eq!(shard1, shard2);
    }

    #[tokio::test]
    async fn test_advance_tick() {
        let coord = Coordinator::new(1);

        assert_eq!(coord.current_tick(), 0);

        let tick1 = coord.advance_tick().await;
        assert_eq!(tick1, 1);
        assert_eq!(coord.current_tick(), 1);

        let tick2 = coord.advance_tick().await;
        assert_eq!(tick2, 2);
        assert_eq!(coord.current_tick(), 2);
    }

    // `aggregate_global_df` is synchronous, so no async runtime is needed.
    #[test]
    fn test_aggregate_global_df() {
        let coord = Coordinator::new(2);

        let local1 = HashMap::from([("hello".to_string(), 5), ("world".to_string(), 3)]);
        let local2 = HashMap::from([("hello".to_string(), 2), ("rust".to_string(), 7)]);

        let global = coord.aggregate_global_df(vec![local1, local2]);

        assert_eq!(global.get("hello"), Some(&7));
        assert_eq!(global.get("world"), Some(&3));
        assert_eq!(global.get("rust"), Some(&7));
        assert_eq!(global.len(), 3);
    }

    #[tokio::test]
    async fn test_deregister_shard() {
        let coord = Coordinator::new(3);

        let id1 = coord.register_shard(test_shard_info()).await.unwrap();
        let id2 = coord.register_shard(test_shard_info()).await.unwrap();
        assert_ne!(id1, id2, "registry must hand out distinct ids");

        assert_eq!(coord.all_shards().await.len(), 2);

        coord.deregister_shard(id1).await.unwrap();
        assert_eq!(coord.all_shards().await.len(), 1);
        assert!(coord.get_shard(id1).await.is_none());
        assert!(coord.get_shard(id2).await.is_some());

        // Removing an already-removed shard must report that specific shard.
        let result = coord.deregister_shard(id1).await;
        assert!(matches!(
            result,
            Err(DistributedError::ShardNotFound(id)) if id == id1
        ));
    }

    #[tokio::test]
    async fn test_cluster_stats() {
        let coord = Coordinator::new(3);

        coord.register_shard(test_shard_info()).await.unwrap();
        let id2 = coord.register_shard(test_shard_info()).await.unwrap();

        coord.update_shard_metrics(id2, 100, 1024).await;

        let stats = coord.cluster_stats().await;
        assert_eq!(stats.total_shards, 2);
        assert_eq!(stats.online_shards, 2);
        assert_eq!(stats.total_documents, 100);
        assert_eq!(stats.total_memory_bytes, 1024);
    }

    #[tokio::test]
    async fn test_replica_shards() {
        let config = DistributedConfig {
            num_shards: 5,
            replication_factor: 2,
            ..Default::default()
        };
        let coord = Coordinator::with_config(config);

        let doc_id = DocumentId::from_seed(42);
        let replicas = coord.get_replica_shards(&doc_id).await;

        // replication_factor = 2 yields the owner plus two replicas.
        assert_eq!(replicas.len(), 3);

        // All copies must land on distinct shards.
        let unique: std::collections::HashSet<_> = replicas.iter().collect();
        assert_eq!(unique.len(), 3);
    }
}