// legalis_sim/distributed.rs

//! Distributed Simulation Framework
//!
//! This module provides abstractions and implementations for running simulations
//! across multiple nodes in a distributed system, with support for:
//! - Partition-based entity distribution
//! - Cross-node communication
//! - Dynamic load balancing
//! - Fault-tolerant checkpointing

use crate::{SimResult, SimulationError, SimulationMetrics};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

/// Node identifier in a distributed system.
pub type NodeId = usize;

/// Per-queue, monotonically increasing identifier used to track messages.
pub type MessageId = u64;

/// Static identity and dynamic runtime state of one node in a cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique node identifier.
    pub id: NodeId,
    /// Node hostname or address.
    pub address: String,
    /// Node rank in the cluster (rank 0 is the coordinator).
    pub rank: usize,
    /// Total number of nodes in the cluster.
    pub total_nodes: usize,
    /// Current load in [0.0, 1.0] (0.0 = idle, 1.0 = fully loaded).
    pub load: f64,
    /// Number of entities currently assigned to this node.
    pub entity_count: usize,
    /// Current lifecycle status.
    pub status: NodeStatus,
}

/// Lifecycle status of a node in the cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Node is idle and ready for work.
    Idle,
    /// Node is currently processing.
    Active,
    /// Node is waiting for data.
    Waiting,
    /// Node has failed.
    Failed,
    /// Node is recovering from a failure.
    Recovering,
}

55impl NodeInfo {
56    /// Create a new node
57    pub fn new(id: NodeId, address: String, rank: usize, total_nodes: usize) -> Self {
58        NodeInfo {
59            id,
60            address,
61            rank,
62            total_nodes,
63            load: 0.0,
64            entity_count: 0,
65            status: NodeStatus::Idle,
66        }
67    }
68
69    /// Check if this node is the coordinator
70    pub fn is_coordinator(&self) -> bool {
71        self.rank == 0
72    }
73
74    /// Update load based on entity count
75    pub fn update_load(&mut self, max_entities_per_node: usize) {
76        if max_entities_per_node > 0 {
77            self.load = (self.entity_count as f64) / (max_entities_per_node as f64);
78            self.load = self.load.min(1.0);
79        }
80    }
81}
82
/// A set of entities assigned to a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityPartition {
    /// Partition identifier (unique per [`PartitionManager`]).
    pub id: usize,
    /// Node this partition is assigned to.
    pub node_id: NodeId,
    /// Entity IDs in this partition.
    pub entity_ids: Vec<String>,
    /// Number of entities; kept in sync with `entity_ids.len()` by the
    /// `add_*` methods.
    pub size: usize,
}

96impl EntityPartition {
97    /// Create a new partition
98    pub fn new(id: usize, node_id: NodeId) -> Self {
99        EntityPartition {
100            id,
101            node_id,
102            entity_ids: Vec::new(),
103            size: 0,
104        }
105    }
106
107    /// Add an entity to this partition
108    pub fn add_entity(&mut self, entity_id: String) {
109        self.entity_ids.push(entity_id);
110        self.size += 1;
111    }
112
113    /// Add multiple entities
114    pub fn add_entities(&mut self, entity_ids: Vec<String>) {
115        self.size += entity_ids.len();
116        self.entity_ids.extend(entity_ids);
117    }
118}
119
/// Strategy for distributing entities across nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PartitionStrategy {
    /// Round-robin distribution.
    RoundRobin,
    /// Hash-based partitioning (hash of the entity ID modulo node count).
    Hash,
    /// Range-based partitioning: contiguous chunks in entity-ID order.
    Range,
    /// Load-balanced partitioning (currently falls back to round-robin).
    LoadBalanced,
    /// Geographic partitioning (currently falls back to round-robin).
    Geographic,
}

/// Creates and records entity partitions for distribution across nodes.
#[derive(Debug)]
pub struct PartitionManager {
    /// Strategy used when distributing entities.
    strategy: PartitionStrategy,
    /// Every partition created so far, across all `create_partitions` calls.
    partitions: Vec<EntityPartition>,
    /// Next partition ID to assign.
    next_partition_id: usize,
}

146impl PartitionManager {
147    /// Create a new partition manager
148    pub fn new(strategy: PartitionStrategy) -> Self {
149        PartitionManager {
150            strategy,
151            partitions: Vec::new(),
152            next_partition_id: 0,
153        }
154    }
155
156    /// Create partitions for a set of entity IDs
157    pub fn create_partitions(
158        &mut self,
159        entity_ids: &[String],
160        num_nodes: usize,
161    ) -> SimResult<Vec<EntityPartition>> {
162        if num_nodes == 0 {
163            return Err(SimulationError::InvalidParameter(
164                "Number of nodes must be greater than 0".to_string(),
165            ));
166        }
167
168        let mut partitions = Vec::with_capacity(num_nodes);
169        for node_id in 0..num_nodes {
170            partitions.push(EntityPartition::new(self.next_partition_id, node_id));
171            self.next_partition_id += 1;
172        }
173
174        // Distribute entities according to strategy
175        match self.strategy {
176            PartitionStrategy::RoundRobin => {
177                for (i, entity_id) in entity_ids.iter().enumerate() {
178                    let partition_idx = i % num_nodes;
179                    partitions[partition_idx].add_entity(entity_id.clone());
180                }
181            }
182            PartitionStrategy::Hash => {
183                for entity_id in entity_ids {
184                    let hash = Self::hash_string(entity_id);
185                    let partition_idx = (hash as usize) % num_nodes;
186                    partitions[partition_idx].add_entity(entity_id.clone());
187                }
188            }
189            PartitionStrategy::Range => {
190                let chunk_size = entity_ids.len().div_ceil(num_nodes);
191                for (i, entity_id) in entity_ids.iter().enumerate() {
192                    let partition_idx = i / chunk_size;
193                    let partition_idx = partition_idx.min(num_nodes - 1);
194                    partitions[partition_idx].add_entity(entity_id.clone());
195                }
196            }
197            PartitionStrategy::LoadBalanced | PartitionStrategy::Geographic => {
198                // For now, use round-robin for these strategies
199                // Can be enhanced with actual load balancing later
200                for (i, entity_id) in entity_ids.iter().enumerate() {
201                    let partition_idx = i % num_nodes;
202                    partitions[partition_idx].add_entity(entity_id.clone());
203                }
204            }
205        }
206
207        self.partitions.extend(partitions.clone());
208        Ok(partitions)
209    }
210
211    /// Simple hash function for strings
212    fn hash_string(s: &str) -> u64 {
213        let mut hash = 0u64;
214        for byte in s.bytes() {
215            hash = hash.wrapping_mul(31).wrapping_add(byte as u64);
216        }
217        hash
218    }
219
220    /// Get partition for a specific entity ID
221    pub fn get_partition(&self, entity_id: &str) -> Option<&EntityPartition> {
222        self.partitions
223            .iter()
224            .find(|p| p.entity_ids.contains(&entity_id.to_string()))
225    }
226
227    /// Get all partitions for a specific node
228    pub fn get_node_partitions(&self, node_id: NodeId) -> Vec<&EntityPartition> {
229        self.partitions
230            .iter()
231            .filter(|p| p.node_id == node_id)
232            .collect()
233    }
234
235    /// Get total number of partitions
236    pub fn partition_count(&self) -> usize {
237        self.partitions.len()
238    }
239}
240
/// Message type and payload for cross-node communication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusterMessageType {
    /// Barrier synchronization point.
    Barrier,
    /// Transfer of entity IDs between nodes.
    EntityData(Vec<String>),
    /// Simulation results reported back to the coordinator.
    Results(SimulationMetrics),
    /// Request to rebalance load across nodes.
    LoadBalance,
    /// Trigger a checkpoint.
    Checkpoint,
    /// Node status update.
    StatusUpdate(NodeStatus),
    /// Custom user-defined message.
    Custom(String),
}

/// Envelope for cross-node communication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Queue-assigned message ID.
    pub id: MessageId,
    /// Sending node.
    pub source: NodeId,
    /// Destination node (`None` = broadcast to all nodes).
    pub destination: Option<NodeId>,
    /// Message type and payload.
    pub message_type: ClusterMessageType,
    /// Wall-clock timestamp in whole seconds since the UNIX epoch.
    pub timestamp: u64,
}

275impl Message {
276    /// Create a new message
277    pub fn new(
278        id: MessageId,
279        source: NodeId,
280        destination: Option<NodeId>,
281        message_type: ClusterMessageType,
282    ) -> Self {
283        Message {
284            id,
285            source,
286            destination,
287            message_type,
288            timestamp: std::time::SystemTime::now()
289                .duration_since(std::time::UNIX_EPOCH)
290                .unwrap()
291                .as_secs(),
292        }
293    }
294
295    /// Check if this is a broadcast message
296    pub fn is_broadcast(&self) -> bool {
297        self.destination.is_none()
298    }
299}
300
/// Shared, thread-safe FIFO of [`Message`]s with a monotonically
/// increasing ID counter.
#[derive(Debug)]
pub struct MessageQueue {
    /// Pending messages, oldest first.
    queue: Arc<Mutex<VecDeque<Message>>>,
    /// Next message ID to hand out.
    next_id: Arc<Mutex<MessageId>>,
}

310impl MessageQueue {
311    /// Create a new message queue
312    pub fn new() -> Self {
313        MessageQueue {
314            queue: Arc::new(Mutex::new(VecDeque::new())),
315            next_id: Arc::new(Mutex::new(0)),
316        }
317    }
318
319    /// Send a message
320    pub fn send(
321        &self,
322        source: NodeId,
323        destination: Option<NodeId>,
324        message_type: ClusterMessageType,
325    ) -> SimResult<MessageId> {
326        let mut next_id = self.next_id.lock().unwrap();
327        let id = *next_id;
328        *next_id += 1;
329        drop(next_id);
330
331        let message = Message::new(id, source, destination, message_type);
332        let mut queue = self.queue.lock().unwrap();
333        queue.push_back(message);
334        Ok(id)
335    }
336
337    /// Receive a message (blocking until message available)
338    pub fn receive(&self, node_id: NodeId) -> Option<Message> {
339        let mut queue = self.queue.lock().unwrap();
340        let position = queue
341            .iter()
342            .position(|msg| msg.destination == Some(node_id) || msg.destination.is_none());
343
344        position.and_then(|pos| queue.remove(pos))
345    }
346
347    /// Peek at next message without removing
348    pub fn peek(&self, node_id: NodeId) -> Option<Message> {
349        let queue = self.queue.lock().unwrap();
350        queue
351            .iter()
352            .find(|msg| msg.destination == Some(node_id) || msg.destination.is_none())
353            .cloned()
354    }
355
356    /// Get queue size
357    pub fn size(&self) -> usize {
358        self.queue.lock().unwrap().len()
359    }
360
361    /// Clear the queue
362    pub fn clear(&self) {
363        self.queue.lock().unwrap().clear();
364    }
365}
366
367impl Default for MessageQueue {
368    fn default() -> Self {
369        Self::new()
370    }
371}
372
/// Strategy for rebalancing work across nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LoadBalanceStrategy {
    /// No load balancing.
    None,
    /// Periodic rebalancing at fixed intervals (timing owned by the caller).
    Periodic,
    /// Dynamic rebalancing when any node's load exceeds a threshold.
    Dynamic,
    /// Work stealing when the load spread between nodes is too large.
    WorkStealing,
}

/// Decides when and how to redistribute partitions across nodes.
#[derive(Debug)]
pub struct LoadBalancer {
    /// Load balancing strategy.
    strategy: LoadBalanceStrategy,
    /// Load threshold for triggering dynamic rebalancing (0.0-1.0).
    threshold: f64,
    /// Minimum load spread (max - min) that triggers work stealing.
    min_imbalance: f64,
}

397impl LoadBalancer {
398    /// Create a new load balancer
399    pub fn new(strategy: LoadBalanceStrategy, threshold: f64) -> Self {
400        LoadBalancer {
401            strategy,
402            threshold: threshold.clamp(0.0, 1.0),
403            min_imbalance: 0.2, // 20% imbalance threshold
404        }
405    }
406
407    /// Check if rebalancing is needed
408    pub fn needs_rebalancing(&self, nodes: &[NodeInfo]) -> bool {
409        if nodes.is_empty() || self.strategy == LoadBalanceStrategy::None {
410            return false;
411        }
412
413        let max_load = nodes.iter().map(|n| n.load).fold(0.0, f64::max);
414        let min_load = nodes.iter().map(|n| n.load).fold(1.0, f64::min);
415        let imbalance = max_load - min_load;
416
417        match self.strategy {
418            LoadBalanceStrategy::None => false,
419            LoadBalanceStrategy::Periodic => true, // Caller handles timing
420            LoadBalanceStrategy::Dynamic => max_load > self.threshold,
421            LoadBalanceStrategy::WorkStealing => imbalance > self.min_imbalance,
422        }
423    }
424
425    /// Calculate rebalancing plan
426    pub fn calculate_rebalance(
427        &self,
428        nodes: &[NodeInfo],
429        partitions: &[EntityPartition],
430    ) -> Vec<(usize, NodeId, NodeId)> {
431        // Returns (partition_id, from_node, to_node) tuples
432        let mut moves = Vec::new();
433
434        if nodes.len() < 2 || partitions.is_empty() {
435            return moves;
436        }
437
438        // Find overloaded and underloaded nodes
439        let avg_load = nodes.iter().map(|n| n.load).sum::<f64>() / nodes.len() as f64;
440        let mut overloaded: Vec<_> = nodes.iter().filter(|n| n.load > avg_load).collect();
441        let mut underloaded: Vec<_> = nodes.iter().filter(|n| n.load < avg_load).collect();
442
443        overloaded.sort_by(|a, b| b.load.partial_cmp(&a.load).unwrap());
444        underloaded.sort_by(|a, b| a.load.partial_cmp(&b.load).unwrap());
445
446        // Move partitions from overloaded to underloaded nodes
447        for overloaded_node in &overloaded {
448            let node_partitions: Vec<_> = partitions
449                .iter()
450                .filter(|p| p.node_id == overloaded_node.id)
451                .collect();
452
453            for partition in node_partitions {
454                if let Some(underloaded_node) = underloaded.first()
455                    && overloaded_node.load - underloaded_node.load > self.min_imbalance
456                {
457                    moves.push((partition.id, overloaded_node.id, underloaded_node.id));
458
459                    // Update for next iteration (simplified)
460                    if underloaded.len() > 1 {
461                        underloaded.remove(0);
462                    }
463                }
464            }
465        }
466
467        moves
468    }
469
470    /// Set imbalance threshold
471    pub fn set_imbalance_threshold(&mut self, threshold: f64) {
472        self.min_imbalance = threshold.clamp(0.0, 1.0);
473    }
474}
475
/// Distributed simulation coordinator: owns the node roster, partition
/// assignments, the message queue, and the load balancer.
#[derive(Debug)]
pub struct ClusterCoordinator {
    /// All nodes in the cluster (node 0 is the coordinator).
    nodes: Vec<NodeInfo>,
    /// Partition manager for entity distribution.
    partition_manager: PartitionManager,
    /// In-process message queue for cross-node communication.
    message_queue: MessageQueue,
    /// Load balancer driving rebalancing decisions.
    load_balancer: LoadBalancer,
    /// Snapshot of the coordinator node's info (copied from `nodes[0]`
    /// at construction; not kept in sync afterwards).
    coordinator_node: NodeInfo,
}

491impl ClusterCoordinator {
492    /// Create a new cluster coordinator
493    pub fn new(
494        num_nodes: usize,
495        partition_strategy: PartitionStrategy,
496        load_balance_strategy: LoadBalanceStrategy,
497    ) -> Self {
498        let mut nodes = Vec::with_capacity(num_nodes);
499        for i in 0..num_nodes {
500            nodes.push(NodeInfo::new(i, format!("node-{}", i), i, num_nodes));
501        }
502
503        let coordinator_node = nodes[0].clone();
504
505        ClusterCoordinator {
506            nodes,
507            partition_manager: PartitionManager::new(partition_strategy),
508            message_queue: MessageQueue::new(),
509            load_balancer: LoadBalancer::new(load_balance_strategy, 0.8),
510            coordinator_node,
511        }
512    }
513
514    /// Distribute entities across nodes
515    pub fn distribute_entities(&mut self, entity_ids: &[String]) -> SimResult<()> {
516        let partitions = self
517            .partition_manager
518            .create_partitions(entity_ids, self.nodes.len())?;
519
520        // Update node entity counts
521        for partition in &partitions {
522            if let Some(node) = self.nodes.iter_mut().find(|n| n.id == partition.node_id) {
523                node.entity_count += partition.size;
524            }
525        }
526
527        // Update loads
528        let max_entities = entity_ids.len() / self.nodes.len().max(1);
529        for node in &mut self.nodes {
530            node.update_load(max_entities);
531        }
532
533        Ok(())
534    }
535
536    /// Send a message to a specific node or broadcast
537    pub fn send_message(
538        &self,
539        destination: Option<NodeId>,
540        message_type: ClusterMessageType,
541    ) -> SimResult<MessageId> {
542        self.message_queue
543            .send(self.coordinator_node.id, destination, message_type)
544    }
545
546    /// Receive a message for this coordinator
547    pub fn receive_message(&self) -> Option<Message> {
548        self.message_queue.receive(self.coordinator_node.id)
549    }
550
551    /// Perform barrier synchronization
552    pub fn barrier(&self) -> SimResult<()> {
553        // Send barrier message to all nodes
554        self.send_message(None, ClusterMessageType::Barrier)?;
555        Ok(())
556    }
557
558    /// Check if load balancing is needed and rebalance if necessary
559    pub fn rebalance_if_needed(&mut self) -> SimResult<Vec<(usize, NodeId, NodeId)>> {
560        if !self.load_balancer.needs_rebalancing(&self.nodes) {
561            return Ok(Vec::new());
562        }
563
564        let moves = self
565            .load_balancer
566            .calculate_rebalance(&self.nodes, &self.partition_manager.partitions);
567
568        // Send load balance messages
569        if !moves.is_empty() {
570            self.send_message(None, ClusterMessageType::LoadBalance)?;
571        }
572
573        Ok(moves)
574    }
575
576    /// Get node information
577    pub fn get_node(&self, node_id: NodeId) -> Option<&NodeInfo> {
578        self.nodes.iter().find(|n| n.id == node_id)
579    }
580
581    /// Get all nodes
582    pub fn nodes(&self) -> &[NodeInfo] {
583        &self.nodes
584    }
585
586    /// Get partition manager
587    pub fn partition_manager(&self) -> &PartitionManager {
588        &self.partition_manager
589    }
590
591    /// Update node status
592    pub fn update_node_status(&mut self, node_id: NodeId, status: NodeStatus) -> SimResult<()> {
593        if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
594            node.status = status;
595            Ok(())
596        } else {
597            Err(SimulationError::InvalidParameter(format!(
598                "Node {} not found",
599                node_id
600            )))
601        }
602    }
603
604    /// Get number of nodes
605    pub fn num_nodes(&self) -> usize {
606        self.nodes.len()
607    }
608}
609
#[cfg(test)]
mod tests {
    use super::*;

    // --- NodeInfo ---

    #[test]
    fn test_node_info_creation() {
        let node = NodeInfo::new(1, "node-1".to_string(), 1, 4);
        assert_eq!(node.id, 1);
        assert_eq!(node.rank, 1);
        assert_eq!(node.total_nodes, 4);
        assert!(!node.is_coordinator());
        assert_eq!(node.status, NodeStatus::Idle);
    }

    #[test]
    fn test_coordinator_node() {
        let node = NodeInfo::new(0, "coordinator".to_string(), 0, 4);
        assert!(node.is_coordinator());
    }

    #[test]
    fn test_node_load_update() {
        let mut node = NodeInfo::new(1, "node-1".to_string(), 1, 4);
        node.entity_count = 50;
        node.update_load(100);
        assert_eq!(node.load, 0.5);
    }

    // --- EntityPartition / PartitionManager ---

    #[test]
    fn test_entity_partition() {
        let mut partition = EntityPartition::new(0, 1);
        partition.add_entity("entity-1".to_string());
        partition.add_entity("entity-2".to_string());
        assert_eq!(partition.size, 2);
        assert_eq!(partition.entity_ids.len(), 2);
    }

    #[test]
    fn test_partition_manager_round_robin() {
        let mut manager = PartitionManager::new(PartitionStrategy::RoundRobin);
        let entity_ids: Vec<String> = (0..10).map(|i| format!("entity-{}", i)).collect();
        let partitions = manager.create_partitions(&entity_ids, 3).unwrap();

        assert_eq!(partitions.len(), 3);
        assert_eq!(
            partitions[0].size + partitions[1].size + partitions[2].size,
            10
        );
    }

    #[test]
    fn test_partition_manager_hash() {
        let mut manager = PartitionManager::new(PartitionStrategy::Hash);
        let entity_ids: Vec<String> = (0..10).map(|i| format!("entity-{}", i)).collect();
        let partitions = manager.create_partitions(&entity_ids, 3).unwrap();

        // Hash distribution may be uneven; only the total is guaranteed.
        assert_eq!(partitions.len(), 3);
        let total: usize = partitions.iter().map(|p| p.size).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_partition_manager_range() {
        let mut manager = PartitionManager::new(PartitionStrategy::Range);
        let entity_ids: Vec<String> = (0..9).map(|i| format!("entity-{}", i)).collect();
        let partitions = manager.create_partitions(&entity_ids, 3).unwrap();

        // 9 entities over 3 nodes splits into exact thirds.
        assert_eq!(partitions.len(), 3);
        assert_eq!(partitions[0].size, 3);
        assert_eq!(partitions[1].size, 3);
        assert_eq!(partitions[2].size, 3);
    }

    // --- Message / MessageQueue ---

    #[test]
    fn test_message_creation() {
        let msg = Message::new(1, 0, Some(1), ClusterMessageType::Barrier);
        assert_eq!(msg.id, 1);
        assert_eq!(msg.source, 0);
        assert_eq!(msg.destination, Some(1));
        assert!(!msg.is_broadcast());
    }

    #[test]
    fn test_broadcast_message() {
        let msg = Message::new(1, 0, None, ClusterMessageType::Barrier);
        assert!(msg.is_broadcast());
    }

    #[test]
    fn test_message_queue() {
        let queue = MessageQueue::new();
        let id = queue.send(0, Some(1), ClusterMessageType::Barrier).unwrap();
        assert_eq!(id, 0);
        assert_eq!(queue.size(), 1);

        let msg = queue.receive(1).unwrap();
        assert_eq!(msg.id, 0);
        assert_eq!(queue.size(), 0);
    }

    #[test]
    fn test_message_queue_broadcast() {
        let queue = MessageQueue::new();
        queue.send(0, None, ClusterMessageType::Barrier).unwrap();

        // Any node can receive broadcast
        let msg1 = queue.peek(1).unwrap();
        assert_eq!(msg1.source, 0);

        let msg2 = queue.receive(2).unwrap();
        assert_eq!(msg2.source, 0);
    }

    // --- LoadBalancer ---

    #[test]
    fn test_load_balancer_no_rebalancing() {
        let balancer = LoadBalancer::new(LoadBalanceStrategy::None, 0.8);
        let nodes = vec![
            NodeInfo::new(0, "node-0".to_string(), 0, 2),
            NodeInfo::new(1, "node-1".to_string(), 1, 2),
        ];
        assert!(!balancer.needs_rebalancing(&nodes));
    }

    #[test]
    fn test_load_balancer_dynamic() {
        let balancer = LoadBalancer::new(LoadBalanceStrategy::Dynamic, 0.5);
        let mut nodes = vec![
            NodeInfo::new(0, "node-0".to_string(), 0, 2),
            NodeInfo::new(1, "node-1".to_string(), 1, 2),
        ];
        nodes[0].load = 0.8;
        nodes[1].load = 0.2;

        assert!(balancer.needs_rebalancing(&nodes));
    }

    // --- ClusterCoordinator ---

    #[test]
    fn test_cluster_coordinator() {
        let coordinator = ClusterCoordinator::new(
            4,
            PartitionStrategy::RoundRobin,
            LoadBalanceStrategy::Dynamic,
        );
        assert_eq!(coordinator.num_nodes(), 4);
        assert!(coordinator.get_node(0).is_some());
    }

    #[test]
    fn test_cluster_coordinator_distribute_entities() {
        let mut coordinator =
            ClusterCoordinator::new(3, PartitionStrategy::RoundRobin, LoadBalanceStrategy::None);
        let entity_ids: Vec<String> = (0..12).map(|i| format!("entity-{}", i)).collect();

        coordinator.distribute_entities(&entity_ids).unwrap();

        // Each node should have approximately equal entities
        for node in coordinator.nodes() {
            assert_eq!(node.entity_count, 4);
        }
    }

    #[test]
    fn test_cluster_coordinator_messaging() {
        let coordinator =
            ClusterCoordinator::new(2, PartitionStrategy::RoundRobin, LoadBalanceStrategy::None);
        let msg_id = coordinator
            .send_message(Some(1), ClusterMessageType::Barrier)
            .unwrap();
        assert_eq!(msg_id, 0);
    }

    #[test]
    fn test_cluster_coordinator_barrier() {
        let coordinator =
            ClusterCoordinator::new(3, PartitionStrategy::RoundRobin, LoadBalanceStrategy::None);
        coordinator.barrier().unwrap();
        assert_eq!(coordinator.message_queue.size(), 1);
    }

    #[test]
    fn test_cluster_coordinator_update_node_status() {
        let mut coordinator =
            ClusterCoordinator::new(2, PartitionStrategy::RoundRobin, LoadBalanceStrategy::None);
        coordinator
            .update_node_status(1, NodeStatus::Active)
            .unwrap();
        assert_eq!(coordinator.get_node(1).unwrap().status, NodeStatus::Active);
    }

    #[test]
    fn test_partition_manager_get_partition() {
        let mut manager = PartitionManager::new(PartitionStrategy::RoundRobin);
        let entity_ids = vec!["entity-1".to_string(), "entity-2".to_string()];
        manager.create_partitions(&entity_ids, 2).unwrap();

        let partition = manager.get_partition("entity-1");
        assert!(partition.is_some());
    }

    #[test]
    fn test_partition_manager_get_node_partitions() {
        let mut manager = PartitionManager::new(PartitionStrategy::RoundRobin);
        let entity_ids: Vec<String> = (0..6).map(|i| format!("entity-{}", i)).collect();
        manager.create_partitions(&entity_ids, 2).unwrap();

        let partitions = manager.get_node_partitions(0);
        assert!(!partitions.is_empty());
    }

    #[test]
    fn test_load_balancer_calculate_rebalance() {
        let balancer = LoadBalancer::new(LoadBalanceStrategy::WorkStealing, 0.8);
        let mut nodes = vec![
            NodeInfo::new(0, "node-0".to_string(), 0, 2),
            NodeInfo::new(1, "node-1".to_string(), 1, 2),
        ];
        nodes[0].load = 0.9;
        nodes[1].load = 0.1;

        let mut partitions = vec![
            EntityPartition::new(0, 0),
            EntityPartition::new(1, 0),
            EntityPartition::new(2, 1),
        ];
        partitions[0].size = 10;
        partitions[1].size = 10;
        partitions[2].size = 2;

        let moves = balancer.calculate_rebalance(&nodes, &partitions);
        assert!(!moves.is_empty());
    }
}