1use crate::{SimResult, SimulationError, SimulationMetrics};
11use serde::{Deserialize, Serialize};
12use std::collections::VecDeque;
13use std::sync::{Arc, Mutex};
14
/// Identifier of a cluster node (also used as its index/rank in this module).
pub type NodeId = usize;

/// Monotonically increasing identifier assigned to each queued message.
pub type MessageId = u64;
/// Metadata describing a single node participating in the simulated cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique identifier of this node within the cluster.
    pub id: NodeId,
    /// Address label for the node (a symbolic name such as "node-0" here).
    pub address: String,
    /// Rank of the node; rank 0 acts as the coordinator.
    pub rank: usize,
    /// Total number of nodes in the cluster this node belongs to.
    pub total_nodes: usize,
    /// Normalized load in [0.0, 1.0], maintained by `update_load`.
    pub load: f64,
    /// Number of simulation entities currently assigned to this node.
    pub entity_count: usize,
    /// Current lifecycle status of the node.
    pub status: NodeStatus,
}
39
/// Lifecycle states a cluster node can be in.
///
/// Only `Idle` is assigned within this file (at construction); the other
/// variants are set by callers, e.g. via `ClusterCoordinator::update_node_status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// No work assigned yet (initial state).
    Idle,
    /// Actively processing work.
    Active,
    /// Blocked waiting on other nodes (exact semantics defined by callers).
    Waiting,
    /// The node has failed and is unavailable.
    Failed,
    /// The node is coming back after a failure.
    Recovering,
}
54
55impl NodeInfo {
56 pub fn new(id: NodeId, address: String, rank: usize, total_nodes: usize) -> Self {
58 NodeInfo {
59 id,
60 address,
61 rank,
62 total_nodes,
63 load: 0.0,
64 entity_count: 0,
65 status: NodeStatus::Idle,
66 }
67 }
68
69 pub fn is_coordinator(&self) -> bool {
71 self.rank == 0
72 }
73
74 pub fn update_load(&mut self, max_entities_per_node: usize) {
76 if max_entities_per_node > 0 {
77 self.load = (self.entity_count as f64) / (max_entities_per_node as f64);
78 self.load = self.load.min(1.0);
79 }
80 }
81}
82
/// A group of entities assigned as a unit to a single node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityPartition {
    /// Globally unique partition id (assigned by `PartitionManager`).
    pub id: usize,
    /// Node this partition is currently assigned to.
    pub node_id: NodeId,
    /// Ids of the entities contained in this partition.
    pub entity_ids: Vec<String>,
    /// Cached entity count, kept in sync by the `add_*` methods.
    pub size: usize,
}
95
96impl EntityPartition {
97 pub fn new(id: usize, node_id: NodeId) -> Self {
99 EntityPartition {
100 id,
101 node_id,
102 entity_ids: Vec::new(),
103 size: 0,
104 }
105 }
106
107 pub fn add_entity(&mut self, entity_id: String) {
109 self.entity_ids.push(entity_id);
110 self.size += 1;
111 }
112
113 pub fn add_entities(&mut self, entity_ids: Vec<String>) {
115 self.size += entity_ids.len();
116 self.entity_ids.extend(entity_ids);
117 }
118}
119
/// Strategies for assigning entities to node partitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PartitionStrategy {
    /// Deal entities out one at a time across the nodes.
    RoundRobin,
    /// Place each entity by a deterministic hash of its id.
    Hash,
    /// Split the entity list into contiguous chunks, one per node.
    Range,
    /// Load-aware placement (currently falls back to round-robin in
    /// `PartitionManager::create_partitions`).
    LoadBalanced,
    /// Locality-aware placement (currently falls back to round-robin in
    /// `PartitionManager::create_partitions`).
    Geographic,
}
134
/// Creates and tracks entity partitions according to a `PartitionStrategy`.
#[derive(Debug)]
pub struct PartitionManager {
    /// Strategy applied when creating new partitions.
    strategy: PartitionStrategy,
    /// Every partition created so far, accumulated across all
    /// `create_partitions` calls.
    partitions: Vec<EntityPartition>,
    /// Next globally unique partition id to hand out.
    next_partition_id: usize,
}
145
146impl PartitionManager {
147 pub fn new(strategy: PartitionStrategy) -> Self {
149 PartitionManager {
150 strategy,
151 partitions: Vec::new(),
152 next_partition_id: 0,
153 }
154 }
155
156 pub fn create_partitions(
158 &mut self,
159 entity_ids: &[String],
160 num_nodes: usize,
161 ) -> SimResult<Vec<EntityPartition>> {
162 if num_nodes == 0 {
163 return Err(SimulationError::InvalidParameter(
164 "Number of nodes must be greater than 0".to_string(),
165 ));
166 }
167
168 let mut partitions = Vec::with_capacity(num_nodes);
169 for node_id in 0..num_nodes {
170 partitions.push(EntityPartition::new(self.next_partition_id, node_id));
171 self.next_partition_id += 1;
172 }
173
174 match self.strategy {
176 PartitionStrategy::RoundRobin => {
177 for (i, entity_id) in entity_ids.iter().enumerate() {
178 let partition_idx = i % num_nodes;
179 partitions[partition_idx].add_entity(entity_id.clone());
180 }
181 }
182 PartitionStrategy::Hash => {
183 for entity_id in entity_ids {
184 let hash = Self::hash_string(entity_id);
185 let partition_idx = (hash as usize) % num_nodes;
186 partitions[partition_idx].add_entity(entity_id.clone());
187 }
188 }
189 PartitionStrategy::Range => {
190 let chunk_size = entity_ids.len().div_ceil(num_nodes);
191 for (i, entity_id) in entity_ids.iter().enumerate() {
192 let partition_idx = i / chunk_size;
193 let partition_idx = partition_idx.min(num_nodes - 1);
194 partitions[partition_idx].add_entity(entity_id.clone());
195 }
196 }
197 PartitionStrategy::LoadBalanced | PartitionStrategy::Geographic => {
198 for (i, entity_id) in entity_ids.iter().enumerate() {
201 let partition_idx = i % num_nodes;
202 partitions[partition_idx].add_entity(entity_id.clone());
203 }
204 }
205 }
206
207 self.partitions.extend(partitions.clone());
208 Ok(partitions)
209 }
210
211 fn hash_string(s: &str) -> u64 {
213 let mut hash = 0u64;
214 for byte in s.bytes() {
215 hash = hash.wrapping_mul(31).wrapping_add(byte as u64);
216 }
217 hash
218 }
219
220 pub fn get_partition(&self, entity_id: &str) -> Option<&EntityPartition> {
222 self.partitions
223 .iter()
224 .find(|p| p.entity_ids.contains(&entity_id.to_string()))
225 }
226
227 pub fn get_node_partitions(&self, node_id: NodeId) -> Vec<&EntityPartition> {
229 self.partitions
230 .iter()
231 .filter(|p| p.node_id == node_id)
232 .collect()
233 }
234
235 pub fn partition_count(&self) -> usize {
237 self.partitions.len()
238 }
239}
240
/// Payload kinds exchanged between cluster nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusterMessageType {
    /// Synchronization barrier notification.
    Barrier,
    /// Transfer of entity ids between nodes.
    EntityData(Vec<String>),
    /// Simulation metrics reported back to the coordinator.
    Results(SimulationMetrics),
    /// Notification that a load-rebalancing pass was triggered.
    LoadBalance,
    /// Checkpoint request/notification.
    Checkpoint,
    /// A node reporting its new status.
    StatusUpdate(NodeStatus),
    /// Free-form, application-defined message.
    Custom(String),
}
259
/// A single message in the cluster's in-memory message queue.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Queue-assigned, monotonically increasing id.
    pub id: MessageId,
    /// Node that sent the message.
    pub source: NodeId,
    /// Target node, or `None` for a broadcast visible to every node.
    pub destination: Option<NodeId>,
    /// The payload.
    pub message_type: ClusterMessageType,
    /// Creation time in whole seconds since the Unix epoch.
    pub timestamp: u64,
}
274
275impl Message {
276 pub fn new(
278 id: MessageId,
279 source: NodeId,
280 destination: Option<NodeId>,
281 message_type: ClusterMessageType,
282 ) -> Self {
283 Message {
284 id,
285 source,
286 destination,
287 message_type,
288 timestamp: std::time::SystemTime::now()
289 .duration_since(std::time::UNIX_EPOCH)
290 .unwrap()
291 .as_secs(),
292 }
293 }
294
295 pub fn is_broadcast(&self) -> bool {
297 self.destination.is_none()
298 }
299}
300
/// Thread-safe, in-process message queue shared between simulated nodes.
///
/// Both fields are behind `Arc<Mutex<..>>` so the queue can be shared and
/// mutated through `&self` methods.
#[derive(Debug)]
pub struct MessageQueue {
    /// FIFO buffer of pending messages.
    queue: Arc<Mutex<VecDeque<Message>>>,
    /// Counter holding the next message id to assign.
    next_id: Arc<Mutex<MessageId>>,
}
309
310impl MessageQueue {
311 pub fn new() -> Self {
313 MessageQueue {
314 queue: Arc::new(Mutex::new(VecDeque::new())),
315 next_id: Arc::new(Mutex::new(0)),
316 }
317 }
318
319 pub fn send(
321 &self,
322 source: NodeId,
323 destination: Option<NodeId>,
324 message_type: ClusterMessageType,
325 ) -> SimResult<MessageId> {
326 let mut next_id = self.next_id.lock().unwrap();
327 let id = *next_id;
328 *next_id += 1;
329 drop(next_id);
330
331 let message = Message::new(id, source, destination, message_type);
332 let mut queue = self.queue.lock().unwrap();
333 queue.push_back(message);
334 Ok(id)
335 }
336
337 pub fn receive(&self, node_id: NodeId) -> Option<Message> {
339 let mut queue = self.queue.lock().unwrap();
340 let position = queue
341 .iter()
342 .position(|msg| msg.destination == Some(node_id) || msg.destination.is_none());
343
344 position.and_then(|pos| queue.remove(pos))
345 }
346
347 pub fn peek(&self, node_id: NodeId) -> Option<Message> {
349 let queue = self.queue.lock().unwrap();
350 queue
351 .iter()
352 .find(|msg| msg.destination == Some(node_id) || msg.destination.is_none())
353 .cloned()
354 }
355
356 pub fn size(&self) -> usize {
358 self.queue.lock().unwrap().len()
359 }
360
361 pub fn clear(&self) {
363 self.queue.lock().unwrap().clear();
364 }
365}
366
367impl Default for MessageQueue {
368 fn default() -> Self {
369 Self::new()
370 }
371}
372
/// Policies deciding when the cluster should rebalance load.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LoadBalanceStrategy {
    /// Never rebalance.
    None,
    /// Rebalance on every check.
    Periodic,
    /// Rebalance when any node's load exceeds the configured threshold.
    Dynamic,
    /// Rebalance when the max-min load spread exceeds the imbalance limit.
    WorkStealing,
}
385
/// Detects load imbalance and plans partition moves between nodes.
#[derive(Debug)]
pub struct LoadBalancer {
    /// When rebalancing should be considered.
    strategy: LoadBalanceStrategy,
    /// Per-node load ceiling used by the `Dynamic` strategy; clamped to
    /// [0.0, 1.0] at construction.
    threshold: f64,
    /// Minimum max-min load spread before moves are planned; defaults to 0.2
    /// and is clamped to [0.0, 1.0] by `set_imbalance_threshold`.
    min_imbalance: f64,
}
396
397impl LoadBalancer {
398 pub fn new(strategy: LoadBalanceStrategy, threshold: f64) -> Self {
400 LoadBalancer {
401 strategy,
402 threshold: threshold.clamp(0.0, 1.0),
403 min_imbalance: 0.2, }
405 }
406
407 pub fn needs_rebalancing(&self, nodes: &[NodeInfo]) -> bool {
409 if nodes.is_empty() || self.strategy == LoadBalanceStrategy::None {
410 return false;
411 }
412
413 let max_load = nodes.iter().map(|n| n.load).fold(0.0, f64::max);
414 let min_load = nodes.iter().map(|n| n.load).fold(1.0, f64::min);
415 let imbalance = max_load - min_load;
416
417 match self.strategy {
418 LoadBalanceStrategy::None => false,
419 LoadBalanceStrategy::Periodic => true, LoadBalanceStrategy::Dynamic => max_load > self.threshold,
421 LoadBalanceStrategy::WorkStealing => imbalance > self.min_imbalance,
422 }
423 }
424
425 pub fn calculate_rebalance(
427 &self,
428 nodes: &[NodeInfo],
429 partitions: &[EntityPartition],
430 ) -> Vec<(usize, NodeId, NodeId)> {
431 let mut moves = Vec::new();
433
434 if nodes.len() < 2 || partitions.is_empty() {
435 return moves;
436 }
437
438 let avg_load = nodes.iter().map(|n| n.load).sum::<f64>() / nodes.len() as f64;
440 let mut overloaded: Vec<_> = nodes.iter().filter(|n| n.load > avg_load).collect();
441 let mut underloaded: Vec<_> = nodes.iter().filter(|n| n.load < avg_load).collect();
442
443 overloaded.sort_by(|a, b| b.load.partial_cmp(&a.load).unwrap());
444 underloaded.sort_by(|a, b| a.load.partial_cmp(&b.load).unwrap());
445
446 for overloaded_node in &overloaded {
448 let node_partitions: Vec<_> = partitions
449 .iter()
450 .filter(|p| p.node_id == overloaded_node.id)
451 .collect();
452
453 for partition in node_partitions {
454 if let Some(underloaded_node) = underloaded.first()
455 && overloaded_node.load - underloaded_node.load > self.min_imbalance
456 {
457 moves.push((partition.id, overloaded_node.id, underloaded_node.id));
458
459 if underloaded.len() > 1 {
461 underloaded.remove(0);
462 }
463 }
464 }
465 }
466
467 moves
468 }
469
470 pub fn set_imbalance_threshold(&mut self, threshold: f64) {
472 self.min_imbalance = threshold.clamp(0.0, 1.0);
473 }
474}
475
/// Top-level facade tying together nodes, partitioning, messaging and load
/// balancing for a simulated cluster.
#[derive(Debug)]
pub struct ClusterCoordinator {
    /// All nodes in the cluster, created in id order.
    nodes: Vec<NodeInfo>,
    /// Creates and tracks entity partitions.
    partition_manager: PartitionManager,
    /// Shared in-process message queue.
    message_queue: MessageQueue,
    /// Imbalance detection and move planning.
    load_balancer: LoadBalancer,
    /// Clone of node 0 taken at construction; messages are sent on its
    /// behalf. NOTE(review): this snapshot is not refreshed if node 0's
    /// entry in `nodes` later changes.
    coordinator_node: NodeInfo,
}
490
491impl ClusterCoordinator {
492 pub fn new(
494 num_nodes: usize,
495 partition_strategy: PartitionStrategy,
496 load_balance_strategy: LoadBalanceStrategy,
497 ) -> Self {
498 let mut nodes = Vec::with_capacity(num_nodes);
499 for i in 0..num_nodes {
500 nodes.push(NodeInfo::new(i, format!("node-{}", i), i, num_nodes));
501 }
502
503 let coordinator_node = nodes[0].clone();
504
505 ClusterCoordinator {
506 nodes,
507 partition_manager: PartitionManager::new(partition_strategy),
508 message_queue: MessageQueue::new(),
509 load_balancer: LoadBalancer::new(load_balance_strategy, 0.8),
510 coordinator_node,
511 }
512 }
513
514 pub fn distribute_entities(&mut self, entity_ids: &[String]) -> SimResult<()> {
516 let partitions = self
517 .partition_manager
518 .create_partitions(entity_ids, self.nodes.len())?;
519
520 for partition in &partitions {
522 if let Some(node) = self.nodes.iter_mut().find(|n| n.id == partition.node_id) {
523 node.entity_count += partition.size;
524 }
525 }
526
527 let max_entities = entity_ids.len() / self.nodes.len().max(1);
529 for node in &mut self.nodes {
530 node.update_load(max_entities);
531 }
532
533 Ok(())
534 }
535
536 pub fn send_message(
538 &self,
539 destination: Option<NodeId>,
540 message_type: ClusterMessageType,
541 ) -> SimResult<MessageId> {
542 self.message_queue
543 .send(self.coordinator_node.id, destination, message_type)
544 }
545
546 pub fn receive_message(&self) -> Option<Message> {
548 self.message_queue.receive(self.coordinator_node.id)
549 }
550
551 pub fn barrier(&self) -> SimResult<()> {
553 self.send_message(None, ClusterMessageType::Barrier)?;
555 Ok(())
556 }
557
558 pub fn rebalance_if_needed(&mut self) -> SimResult<Vec<(usize, NodeId, NodeId)>> {
560 if !self.load_balancer.needs_rebalancing(&self.nodes) {
561 return Ok(Vec::new());
562 }
563
564 let moves = self
565 .load_balancer
566 .calculate_rebalance(&self.nodes, &self.partition_manager.partitions);
567
568 if !moves.is_empty() {
570 self.send_message(None, ClusterMessageType::LoadBalance)?;
571 }
572
573 Ok(moves)
574 }
575
576 pub fn get_node(&self, node_id: NodeId) -> Option<&NodeInfo> {
578 self.nodes.iter().find(|n| n.id == node_id)
579 }
580
581 pub fn nodes(&self) -> &[NodeInfo] {
583 &self.nodes
584 }
585
586 pub fn partition_manager(&self) -> &PartitionManager {
588 &self.partition_manager
589 }
590
591 pub fn update_node_status(&mut self, node_id: NodeId, status: NodeStatus) -> SimResult<()> {
593 if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
594 node.status = status;
595 Ok(())
596 } else {
597 Err(SimulationError::InvalidParameter(format!(
598 "Node {} not found",
599 node_id
600 )))
601 }
602 }
603
604 pub fn num_nodes(&self) -> usize {
606 self.nodes.len()
607 }
608}
609
/// Unit tests for the cluster coordination primitives: node metadata,
/// partitioning strategies, the in-process message queue, load balancing,
/// and the top-level coordinator facade.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- NodeInfo ----

    #[test]
    fn test_node_info_creation() {
        let node = NodeInfo::new(1, "node-1".to_string(), 1, 4);
        assert_eq!(node.id, 1);
        assert_eq!(node.rank, 1);
        assert_eq!(node.total_nodes, 4);
        assert!(!node.is_coordinator());
        assert_eq!(node.status, NodeStatus::Idle);
    }

    #[test]
    fn test_coordinator_node() {
        let node = NodeInfo::new(0, "coordinator".to_string(), 0, 4);
        assert!(node.is_coordinator());
    }

    #[test]
    fn test_node_load_update() {
        let mut node = NodeInfo::new(1, "node-1".to_string(), 1, 4);
        node.entity_count = 50;
        node.update_load(100);
        assert_eq!(node.load, 0.5);
    }

    // ---- EntityPartition / PartitionManager ----

    #[test]
    fn test_entity_partition() {
        let mut partition = EntityPartition::new(0, 1);
        partition.add_entity("entity-1".to_string());
        partition.add_entity("entity-2".to_string());
        assert_eq!(partition.size, 2);
        assert_eq!(partition.entity_ids.len(), 2);
    }

    #[test]
    fn test_partition_manager_round_robin() {
        let mut manager = PartitionManager::new(PartitionStrategy::RoundRobin);
        let entity_ids: Vec<String> = (0..10).map(|i| format!("entity-{}", i)).collect();
        let partitions = manager.create_partitions(&entity_ids, 3).unwrap();

        // Every entity lands in exactly one of the three partitions.
        assert_eq!(partitions.len(), 3);
        assert_eq!(
            partitions[0].size + partitions[1].size + partitions[2].size,
            10
        );
    }

    #[test]
    fn test_partition_manager_hash() {
        let mut manager = PartitionManager::new(PartitionStrategy::Hash);
        let entity_ids: Vec<String> = (0..10).map(|i| format!("entity-{}", i)).collect();
        let partitions = manager.create_partitions(&entity_ids, 3).unwrap();

        // Hash placement is uneven but must not lose or duplicate entities.
        assert_eq!(partitions.len(), 3);
        let total: usize = partitions.iter().map(|p| p.size).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_partition_manager_range() {
        let mut manager = PartitionManager::new(PartitionStrategy::Range);
        let entity_ids: Vec<String> = (0..9).map(|i| format!("entity-{}", i)).collect();
        let partitions = manager.create_partitions(&entity_ids, 3).unwrap();

        // 9 entities over 3 nodes split into equal contiguous chunks.
        assert_eq!(partitions.len(), 3);
        assert_eq!(partitions[0].size, 3);
        assert_eq!(partitions[1].size, 3);
        assert_eq!(partitions[2].size, 3);
    }

    // ---- Message / MessageQueue ----

    #[test]
    fn test_message_creation() {
        let msg = Message::new(1, 0, Some(1), ClusterMessageType::Barrier);
        assert_eq!(msg.id, 1);
        assert_eq!(msg.source, 0);
        assert_eq!(msg.destination, Some(1));
        assert!(!msg.is_broadcast());
    }

    #[test]
    fn test_broadcast_message() {
        let msg = Message::new(1, 0, None, ClusterMessageType::Barrier);
        assert!(msg.is_broadcast());
    }

    #[test]
    fn test_message_queue() {
        let queue = MessageQueue::new();
        let id = queue.send(0, Some(1), ClusterMessageType::Barrier).unwrap();
        assert_eq!(id, 0);
        assert_eq!(queue.size(), 1);

        let msg = queue.receive(1).unwrap();
        assert_eq!(msg.id, 0);
        assert_eq!(queue.size(), 0);
    }

    #[test]
    fn test_message_queue_broadcast() {
        let queue = MessageQueue::new();
        queue.send(0, None, ClusterMessageType::Barrier).unwrap();

        // A broadcast is visible to any node id.
        let msg1 = queue.peek(1).unwrap();
        assert_eq!(msg1.source, 0);

        let msg2 = queue.receive(2).unwrap();
        assert_eq!(msg2.source, 0);
    }

    // ---- LoadBalancer ----

    #[test]
    fn test_load_balancer_no_rebalancing() {
        let balancer = LoadBalancer::new(LoadBalanceStrategy::None, 0.8);
        let nodes = vec![
            NodeInfo::new(0, "node-0".to_string(), 0, 2),
            NodeInfo::new(1, "node-1".to_string(), 1, 2),
        ];
        assert!(!balancer.needs_rebalancing(&nodes));
    }

    #[test]
    fn test_load_balancer_dynamic() {
        let balancer = LoadBalancer::new(LoadBalanceStrategy::Dynamic, 0.5);
        let mut nodes = vec![
            NodeInfo::new(0, "node-0".to_string(), 0, 2),
            NodeInfo::new(1, "node-1".to_string(), 1, 2),
        ];
        // Node 0's load (0.8) exceeds the 0.5 threshold.
        nodes[0].load = 0.8;
        nodes[1].load = 0.2;

        assert!(balancer.needs_rebalancing(&nodes));
    }

    // ---- ClusterCoordinator ----

    #[test]
    fn test_cluster_coordinator() {
        let coordinator = ClusterCoordinator::new(
            4,
            PartitionStrategy::RoundRobin,
            LoadBalanceStrategy::Dynamic,
        );
        assert_eq!(coordinator.num_nodes(), 4);
        assert!(coordinator.get_node(0).is_some());
    }

    #[test]
    fn test_cluster_coordinator_distribute_entities() {
        let mut coordinator =
            ClusterCoordinator::new(3, PartitionStrategy::RoundRobin, LoadBalanceStrategy::None);
        let entity_ids: Vec<String> = (0..12).map(|i| format!("entity-{}", i)).collect();

        coordinator.distribute_entities(&entity_ids).unwrap();

        // 12 entities round-robined over 3 nodes -> 4 each.
        for node in coordinator.nodes() {
            assert_eq!(node.entity_count, 4);
        }
    }

    #[test]
    fn test_cluster_coordinator_messaging() {
        let coordinator =
            ClusterCoordinator::new(2, PartitionStrategy::RoundRobin, LoadBalanceStrategy::None);
        let msg_id = coordinator
            .send_message(Some(1), ClusterMessageType::Barrier)
            .unwrap();
        assert_eq!(msg_id, 0);
    }

    #[test]
    fn test_cluster_coordinator_barrier() {
        let coordinator =
            ClusterCoordinator::new(3, PartitionStrategy::RoundRobin, LoadBalanceStrategy::None);
        coordinator.barrier().unwrap();
        assert_eq!(coordinator.message_queue.size(), 1);
    }

    #[test]
    fn test_cluster_coordinator_update_node_status() {
        let mut coordinator =
            ClusterCoordinator::new(2, PartitionStrategy::RoundRobin, LoadBalanceStrategy::None);
        coordinator
            .update_node_status(1, NodeStatus::Active)
            .unwrap();
        assert_eq!(coordinator.get_node(1).unwrap().status, NodeStatus::Active);
    }

    #[test]
    fn test_partition_manager_get_partition() {
        let mut manager = PartitionManager::new(PartitionStrategy::RoundRobin);
        let entity_ids = vec!["entity-1".to_string(), "entity-2".to_string()];
        manager.create_partitions(&entity_ids, 2).unwrap();

        let partition = manager.get_partition("entity-1");
        assert!(partition.is_some());
    }

    #[test]
    fn test_partition_manager_get_node_partitions() {
        let mut manager = PartitionManager::new(PartitionStrategy::RoundRobin);
        let entity_ids: Vec<String> = (0..6).map(|i| format!("entity-{}", i)).collect();
        manager.create_partitions(&entity_ids, 2).unwrap();

        let partitions = manager.get_node_partitions(0);
        assert!(!partitions.is_empty());
    }

    #[test]
    fn test_load_balancer_calculate_rebalance() {
        let balancer = LoadBalancer::new(LoadBalanceStrategy::WorkStealing, 0.8);
        let mut nodes = vec![
            NodeInfo::new(0, "node-0".to_string(), 0, 2),
            NodeInfo::new(1, "node-1".to_string(), 1, 2),
        ];
        // Spread of 0.8 exceeds the default 0.2 minimum imbalance.
        nodes[0].load = 0.9;
        nodes[1].load = 0.1;

        let mut partitions = vec![
            EntityPartition::new(0, 0),
            EntityPartition::new(1, 0),
            EntityPartition::new(2, 1),
        ];
        partitions[0].size = 10;
        partitions[1].size = 10;
        partitions[2].size = 2;

        let moves = balancer.calculate_rebalance(&nodes, &partitions);
        assert!(!moves.is_empty());
    }
}