Skip to main content

phago_distributed/coordinator/
mod.rs

1//! Coordinator for distributed colony orchestration.
2//!
3//! The coordinator is the central component of the distributed colony system,
4//! responsible for:
5//! - Managing the cluster topology via the shard registry
6//! - Routing documents to shards using consistent hashing
7//! - Synchronizing ticks across shards using barriers
8//! - Aggregating global statistics like document frequencies
9
10mod shard_registry;
11mod tick_barrier;
12
13pub use shard_registry::{RegisteredShard, ShardRegistry};
14pub use tick_barrier::TickBarrier;
15
16use crate::hashing::ConsistentHashRing;
17use crate::types::*;
18use phago_core::types::{DocumentId, Tick};
19use std::collections::HashMap;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::sync::Arc;
22use tokio::sync::RwLock;
23
24/// The distributed coordinator.
25///
26/// The coordinator manages the distributed colony by:
27/// - Maintaining a registry of all active shards
28/// - Routing documents to shards using consistent hashing
29/// - Synchronizing tick phases across all shards
30/// - Aggregating global statistics for TF-IDF computation
31///
32/// # Thread Safety
33///
34/// The coordinator is designed for concurrent access. It uses interior
35/// mutability with `RwLock` for the registry and hash ring, and atomics
36/// for the tick counter.
37///
38/// # Example
39///
40/// ```rust,ignore
41/// use phago_distributed::Coordinator;
42/// use phago_distributed::types::{ShardInfo, NodeAddress};
43///
44/// let coordinator = Coordinator::new(3);
45///
46/// // Register a shard
47/// let info = ShardInfo::new(NodeAddress::new("127.0.0.1", 8080));
48/// let shard_id = coordinator.register_shard(info).await?;
49///
50/// // Route a document
51/// let doc_id = DocumentId::new();
52/// let target_shard = coordinator.route_document(&doc_id).await;
53/// ```
54pub struct Coordinator {
55    /// Registry of all shards.
56    shards: Arc<RwLock<ShardRegistry>>,
57    /// Current global tick.
58    current_tick: Arc<AtomicU64>,
59    /// Tick barrier for synchronization.
60    barrier: Arc<TickBarrier>,
61    /// Consistent hash ring for document routing.
62    hash_ring: Arc<RwLock<ConsistentHashRing>>,
63    /// Configuration for the distributed system.
64    config: DistributedConfig,
65}
66
67impl Coordinator {
68    /// Create a new coordinator with the specified number of shards.
69    ///
70    /// The coordinator will initialize the hash ring with the given
71    /// number of shards, but actual shards must be registered before
72    /// they can receive documents.
73    pub fn new(num_shards: u32) -> Self {
74        Self {
75            shards: Arc::new(RwLock::new(ShardRegistry::new())),
76            current_tick: Arc::new(AtomicU64::new(0)),
77            barrier: Arc::new(TickBarrier::new(num_shards as usize)),
78            hash_ring: Arc::new(RwLock::new(ConsistentHashRing::new(num_shards))),
79            config: DistributedConfig {
80                num_shards,
81                ..Default::default()
82            },
83        }
84    }
85
86    /// Create a coordinator with custom configuration.
87    pub fn with_config(config: DistributedConfig) -> Self {
88        let num_shards = config.num_shards;
89        Self {
90            shards: Arc::new(RwLock::new(ShardRegistry::new())),
91            current_tick: Arc::new(AtomicU64::new(0)),
92            barrier: Arc::new(TickBarrier::new(num_shards as usize)),
93            hash_ring: Arc::new(RwLock::new(ConsistentHashRing::with_virtual_nodes(
94                num_shards,
95                config.virtual_nodes_per_shard,
96            ))),
97            config,
98        }
99    }
100
101    /// Register a shard with the coordinator.
102    ///
103    /// The shard will be assigned a unique ID, added to the registry,
104    /// and included in the hash ring for document routing.
105    ///
106    /// # Arguments
107    ///
108    /// * `info` - Information about the shard to register
109    ///
110    /// # Returns
111    ///
112    /// The assigned shard ID.
113    pub async fn register_shard(&self, info: ShardInfo) -> DistributedResult<ShardId> {
114        let mut registry = self.shards.write().await;
115        let shard_id = registry.register(info);
116
117        // Add to hash ring
118        let mut ring = self.hash_ring.write().await;
119        ring.add_shard(shard_id);
120
121        // Update barrier for new shard count
122        self.barrier.set_shard_count(registry.count()).await;
123
124        Ok(shard_id)
125    }
126
127    /// Deregister a shard from the coordinator.
128    ///
129    /// The shard will be removed from the registry and hash ring.
130    /// Documents previously assigned to this shard will be redistributed.
131    pub async fn deregister_shard(&self, shard_id: ShardId) -> DistributedResult<()> {
132        let mut registry = self.shards.write().await;
133
134        if registry.remove(&shard_id).is_none() {
135            return Err(DistributedError::ShardNotFound(shard_id));
136        }
137
138        // Remove from hash ring
139        let mut ring = self.hash_ring.write().await;
140        ring.remove_shard(shard_id);
141
142        // Update barrier
143        self.barrier.set_shard_count(registry.count()).await;
144
145        Ok(())
146    }
147
148    /// Route a document to the appropriate shard.
149    ///
150    /// Uses consistent hashing to determine which shard should store
151    /// the document. The same document will always route to the same
152    /// shard (unless the cluster topology changes).
153    pub async fn route_document(&self, doc_id: &DocumentId) -> ShardId {
154        let ring = self.hash_ring.read().await;
155        ring.get_shard(doc_id)
156    }
157
158    /// Get replica shards for a document.
159    ///
160    /// Returns the primary shard plus additional replica shards based
161    /// on the configured replication factor.
162    pub async fn get_replica_shards(&self, doc_id: &DocumentId) -> Vec<ShardId> {
163        let ring = self.hash_ring.read().await;
164        ring.get_replica_shards(doc_id, self.config.replication_factor as usize)
165    }
166
167    /// Signal that a shard has completed a phase.
168    ///
169    /// This is called by each shard when it finishes a phase of the tick.
170    /// The coordinator tracks progress and releases the barrier when all
171    /// shards have completed.
172    pub async fn phase_complete(
173        &self,
174        shard_id: ShardId,
175        phase: TickPhase,
176        tick: Tick,
177    ) -> DistributedResult<()> {
178        self.barrier.complete(shard_id, phase, tick).await
179    }
180
181    /// Wait for all shards to complete a phase.
182    ///
183    /// Blocks until all registered shards have signaled completion
184    /// of the specified phase.
185    pub async fn wait_for_phase(&self, phase: TickPhase, tick: Tick) -> DistributedResult<()> {
186        self.barrier.wait_all(phase, tick).await
187    }
188
189    /// Advance to the next tick.
190    ///
191    /// This should be called after all phases of the current tick
192    /// are complete. Returns the new tick number.
193    pub async fn advance_tick(&self) -> Tick {
194        let new_tick = self.current_tick.fetch_add(1, Ordering::SeqCst) + 1;
195        self.barrier.reset_for_tick(new_tick).await;
196        new_tick
197    }
198
199    /// Get the current tick number.
200    pub fn current_tick(&self) -> Tick {
201        self.current_tick.load(Ordering::SeqCst)
202    }
203
204    /// Aggregate global document frequencies from all shards.
205    ///
206    /// This is used for computing global TF-IDF scores. Each shard
207    /// provides its local document frequencies, and the coordinator
208    /// sums them to produce global counts.
209    ///
210    /// # Arguments
211    ///
212    /// * `local_dfs` - Vector of term->count maps from each shard
213    ///
214    /// # Returns
215    ///
216    /// A map of term->global_count across all shards.
217    pub fn aggregate_global_df(
218        &self,
219        local_dfs: Vec<HashMap<String, u64>>,
220    ) -> HashMap<String, u64> {
221        let mut global_df = HashMap::new();
222        for local in local_dfs {
223            for (term, count) in local {
224                *global_df.entry(term).or_insert(0) += count;
225            }
226        }
227        global_df
228    }
229
230    /// Get all registered shards.
231    pub async fn all_shards(&self) -> Vec<ShardInfo> {
232        let registry = self.shards.read().await;
233        registry.all()
234    }
235
236    /// Get all online shards.
237    pub async fn online_shards(&self) -> Vec<ShardInfo> {
238        let registry = self.shards.read().await;
239        registry.online_shards()
240    }
241
242    /// Get a specific shard's information.
243    pub async fn get_shard(&self, shard_id: ShardId) -> Option<ShardInfo> {
244        let registry = self.shards.read().await;
245        registry.get(&shard_id).cloned()
246    }
247
248    /// Update heartbeat for a shard.
249    ///
250    /// Called periodically by shards to indicate they are still alive.
251    pub async fn shard_heartbeat(&self, shard_id: ShardId) {
252        let mut registry = self.shards.write().await;
253        registry.heartbeat(&shard_id);
254    }
255
256    /// Check for dead shards and mark them offline.
257    ///
258    /// Returns the IDs of shards that were marked offline.
259    pub async fn check_shard_health(&self) -> Vec<ShardId> {
260        let mut registry = self.shards.write().await;
261        registry.check_dead_shards()
262    }
263
264    /// Update shard metrics.
265    pub async fn update_shard_metrics(
266        &self,
267        shard_id: ShardId,
268        document_count: usize,
269        memory_bytes: u64,
270    ) {
271        let mut registry = self.shards.write().await;
272        registry.update_metrics(&shard_id, document_count, memory_bytes);
273    }
274
275    /// Get the total number of documents across all shards.
276    pub async fn total_documents(&self) -> u64 {
277        let registry = self.shards.read().await;
278        registry.total_documents()
279    }
280
281    /// Get cluster statistics.
282    pub async fn cluster_stats(&self) -> ClusterStats {
283        let registry = self.shards.read().await;
284        ClusterStats {
285            total_shards: registry.count() as u32,
286            online_shards: registry.online_shards().len() as u32,
287            total_documents: registry.total_documents(),
288            total_memory_bytes: registry.total_memory(),
289            current_tick: self.current_tick(),
290        }
291    }
292
293    /// Get the configuration.
294    pub fn config(&self) -> &DistributedConfig {
295        &self.config
296    }
297
298    /// Get the shard count from the hash ring.
299    pub async fn shard_count(&self) -> u32 {
300        let ring = self.hash_ring.read().await;
301        ring.shard_count()
302    }
303}
304
305/// Statistics about the distributed cluster.
306#[derive(Debug, Clone)]
307pub struct ClusterStats {
308    /// Total number of shards in the cluster.
309    pub total_shards: u32,
310    /// Number of currently online shards.
311    pub online_shards: u32,
312    /// Total documents across all shards.
313    pub total_documents: u64,
314    /// Total memory usage across all shards.
315    pub total_memory_bytes: u64,
316    /// Current simulation tick.
317    pub current_tick: Tick,
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a throwaway `ShardInfo`; every test shard shares the same
    /// placeholder ID and address.
    fn test_shard_info() -> ShardInfo {
        ShardInfo::new(ShardId::new(0), "127.0.0.1:8080".to_string())
    }

    /// A fresh coordinator starts at tick 0 with a ring sized from `num_shards`.
    #[tokio::test]
    async fn test_coordinator_creation() {
        let coord = Coordinator::new(3);
        assert_eq!(coord.current_tick(), 0);
        assert_eq!(coord.shard_count().await, 3);
    }

    /// The first registered shard gets ID 0 and shows up in the registry.
    #[tokio::test]
    async fn test_register_shard() {
        let coord = Coordinator::new(3);
        let info = test_shard_info();

        let shard_id = coord.register_shard(info).await.unwrap();
        assert_eq!(shard_id, ShardId::new(0));

        let shards = coord.all_shards().await;
        assert_eq!(shards.len(), 1);
    }

    /// Routing the same document twice yields the same shard.
    #[tokio::test]
    async fn test_document_routing() {
        let coord = Coordinator::new(3);

        // Register some shards
        for _ in 0..3 {
            coord.register_shard(test_shard_info()).await.unwrap();
        }

        let doc_id = DocumentId::from_seed(42);

        // Routing should be consistent
        let shard1 = coord.route_document(&doc_id).await;
        let shard2 = coord.route_document(&doc_id).await;
        assert_eq!(shard1, shard2);
    }

    /// Each `advance_tick` call increments the tick by exactly one.
    #[tokio::test]
    async fn test_advance_tick() {
        let coord = Coordinator::new(1);

        assert_eq!(coord.current_tick(), 0);

        let tick1 = coord.advance_tick().await;
        assert_eq!(tick1, 1);
        assert_eq!(coord.current_tick(), 1);

        let tick2 = coord.advance_tick().await;
        assert_eq!(tick2, 2);
        assert_eq!(coord.current_tick(), 2);
    }

    /// Per-shard document frequencies are summed term by term.
    #[tokio::test]
    async fn test_aggregate_global_df() {
        let coord = Coordinator::new(2);

        let local1 = HashMap::from([("hello".to_string(), 5), ("world".to_string(), 3)]);
        let local2 = HashMap::from([("hello".to_string(), 2), ("rust".to_string(), 7)]);

        let global = coord.aggregate_global_df(vec![local1, local2]);

        assert_eq!(global.get("hello"), Some(&7));
        assert_eq!(global.get("world"), Some(&3));
        assert_eq!(global.get("rust"), Some(&7));
    }

    /// Deregistering removes the shard; a second attempt reports `ShardNotFound`.
    #[tokio::test]
    async fn test_deregister_shard() {
        let coord = Coordinator::new(3);

        let id1 = coord.register_shard(test_shard_info()).await.unwrap();
        // The second shard only needs to exist; its ID is not used.
        coord.register_shard(test_shard_info()).await.unwrap();

        assert_eq!(coord.all_shards().await.len(), 2);

        coord.deregister_shard(id1).await.unwrap();
        assert_eq!(coord.all_shards().await.len(), 1);

        // Deregistering again should fail with the specific not-found error
        let result = coord.deregister_shard(id1).await;
        assert!(matches!(
            result,
            Err(DistributedError::ShardNotFound(id)) if id == id1
        ));
    }

    /// Cluster statistics reflect registered shards and reported metrics.
    #[tokio::test]
    async fn test_cluster_stats() {
        let coord = Coordinator::new(3);

        coord.register_shard(test_shard_info()).await.unwrap();
        let id2 = coord.register_shard(test_shard_info()).await.unwrap();

        coord.update_shard_metrics(id2, 100, 1024).await;

        let stats = coord.cluster_stats().await;
        assert_eq!(stats.total_shards, 2);
        assert_eq!(stats.online_shards, 2);
        assert_eq!(stats.total_documents, 100);
        assert_eq!(stats.total_memory_bytes, 1024);
    }

    /// With a replication factor of 2 a document maps to three distinct shards.
    #[tokio::test]
    async fn test_replica_shards() {
        let config = DistributedConfig {
            num_shards: 5,
            replication_factor: 2,
            ..Default::default()
        };
        let coord = Coordinator::with_config(config);

        let doc_id = DocumentId::from_seed(42);
        let replicas = coord.get_replica_shards(&doc_id).await;

        // Should get primary + 2 replicas = 3 shards
        assert_eq!(replicas.len(), 3);

        // All should be unique
        let unique: std::collections::HashSet<_> = replicas.iter().collect();
        assert_eq!(unique.len(), 3);
    }
}