1use std::collections::{HashMap, HashSet, VecDeque};
9use std::sync::Arc;
10use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
11
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info, warn};
15
16use super::{DistributedCheckpointer, StateResult};
17
/// Configuration for [`SequenceDeduplicator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeduplicationConfig {
    /// Maximum number of out-of-order entries buffered per source; once the
    /// bound is exceeded the entries at the front of the buffer are evicted.
    pub max_entries_per_source: usize,
    /// Age after which buffered entries are dropped by
    /// `SequenceDeduplicator::expire_old_entries`.
    pub expiry: Duration,
}
28
29impl Default for DeduplicationConfig {
30 fn default() -> Self {
31 Self {
32 max_entries_per_source: 10_000,
33 expiry: Duration::from_secs(3600),
34 }
35 }
36}
37
/// A buffered out-of-order sequence number waiting for the gap below it
/// to be filled.
#[derive(Debug, Clone)]
struct SequenceEntry {
    /// Sequence number that arrived ahead of the source's watermark.
    sequence_number: u64,
    /// When the entry was buffered; used by expiry.
    received_at: Instant,
}
44
/// Tracks per-source sequence numbers to provide exactly-once admission:
/// a message is accepted only the first time its (source, sequence) pair
/// is observed.
pub struct SequenceDeduplicator {
    config: DeduplicationConfig,
    /// Highest contiguous sequence number seen per source (0 = none yet).
    high_watermarks: Arc<RwLock<HashMap<String, u64>>>,
    /// Out-of-order sequence numbers above the watermark, per source.
    pending_sequences: Arc<RwLock<HashMap<String, VecDeque<SequenceEntry>>>>,
    /// Count of messages rejected as duplicates.
    duplicates_rejected: Arc<RwLock<u64>>,
    /// Count of messages accepted as first-time deliveries.
    unique_accepted: Arc<RwLock<u64>>,
}
61
62impl SequenceDeduplicator {
63 pub fn new(config: DeduplicationConfig) -> Self {
65 Self {
66 config,
67 high_watermarks: Arc::new(RwLock::new(HashMap::new())),
68 pending_sequences: Arc::new(RwLock::new(HashMap::new())),
69 duplicates_rejected: Arc::new(RwLock::new(0)),
70 unique_accepted: Arc::new(RwLock::new(0)),
71 }
72 }
73
74 pub fn check_and_record(&self, source_id: &str, sequence_number: u64) -> bool {
79 let mut watermarks = self.high_watermarks.write();
80 let current_watermark = watermarks.entry(source_id.to_string()).or_insert(0);
81
82 if sequence_number <= *current_watermark && *current_watermark > 0 {
84 let pending = self.pending_sequences.read();
86 if let Some(entries) = pending.get(source_id) {
87 if entries.iter().any(|e| e.sequence_number == sequence_number) {
88 *self.duplicates_rejected.write() += 1;
89 return false;
90 }
91 }
92 *self.duplicates_rejected.write() += 1;
93 return false;
94 }
95
96 {
98 let pending = self.pending_sequences.read();
99 if let Some(entries) = pending.get(source_id) {
100 if entries.iter().any(|e| e.sequence_number == sequence_number) {
101 *self.duplicates_rejected.write() += 1;
102 return false;
103 }
104 }
105 }
106
107 if sequence_number == *current_watermark + 1 || *current_watermark == 0 {
109 *current_watermark = sequence_number;
111 drop(watermarks);
113 self.advance_watermark(source_id);
114 } else {
115 drop(watermarks);
117 let mut pending = self.pending_sequences.write();
118 let entries = pending.entry(source_id.to_string()).or_default();
119 entries.push_back(SequenceEntry {
120 sequence_number,
121 received_at: Instant::now(),
122 });
123 while entries.len() > self.config.max_entries_per_source {
125 entries.pop_front();
126 }
127 }
128
129 *self.unique_accepted.write() += 1;
130 true
131 }
132
133 fn advance_watermark(&self, source_id: &str) {
135 let mut watermarks = self.high_watermarks.write();
136 let watermark = watermarks.entry(source_id.to_string()).or_insert(0);
137 let mut pending = self.pending_sequences.write();
138 if let Some(entries) = pending.get_mut(source_id) {
139 entries.make_contiguous().sort_by_key(|e| e.sequence_number);
140 while let Some(front) = entries.front() {
141 if front.sequence_number == *watermark + 1 {
142 *watermark += 1;
143 entries.pop_front();
144 } else {
145 break;
146 }
147 }
148 }
149 }
150
151 pub fn high_watermark(&self, source_id: &str) -> u64 {
153 self.high_watermarks
154 .read()
155 .get(source_id)
156 .copied()
157 .unwrap_or(0)
158 }
159
160 pub fn expire_old_entries(&self) {
162 let now = Instant::now();
163 let mut pending = self.pending_sequences.write();
164 for entries in pending.values_mut() {
165 entries.retain(|e| now.duration_since(e.received_at) < self.config.expiry);
166 }
167 }
168
169 pub fn stats(&self) -> DeduplicationStats {
171 let pending_count: usize = self
172 .pending_sequences
173 .read()
174 .values()
175 .map(|e| e.len())
176 .sum();
177 DeduplicationStats {
178 duplicates_rejected: *self.duplicates_rejected.read(),
179 unique_accepted: *self.unique_accepted.read(),
180 tracked_sources: self.high_watermarks.read().len(),
181 pending_sequences: pending_count,
182 }
183 }
184}
185
/// Counters exposed by [`SequenceDeduplicator::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeduplicationStats {
    /// Messages rejected because their (source, sequence) pair was seen before.
    pub duplicates_rejected: u64,
    /// Messages accepted as first-time deliveries.
    pub unique_accepted: u64,
    /// Number of sources with a recorded high watermark.
    pub tracked_sources: usize,
    /// Total out-of-order entries currently buffered across all sources.
    pub pending_sequences: usize,
}
198
/// Serialized state of a single operator captured inside a checkpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperatorStateSnapshot {
    /// Operator this snapshot belongs to.
    pub operator_id: String,
    /// Opaque serialized operator state.
    pub state_bytes: Vec<u8>,
    /// Checkpoint version this snapshot was taken at.
    pub version: u64,
    /// Capture time, microseconds since the Unix epoch.
    pub created_at: u64,
    /// Length of `state_bytes` at capture time.
    pub size_bytes: usize,
}
215
/// Settings controlling checkpoint cadence, retention, and verification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointConfig {
    /// Minimum time between checkpoints (consulted by `is_checkpoint_due`).
    pub checkpoint_interval: Duration,
    /// How many recent checkpoints to keep; older ones are discarded.
    pub max_retained_checkpoints: usize,
    /// Recompute and compare the Merkle root before restoring.
    pub verify_integrity: bool,
}
226
227impl Default for CheckpointConfig {
228 fn default() -> Self {
229 Self {
230 checkpoint_interval: Duration::from_secs(30),
231 max_retained_checkpoints: 5,
232 verify_integrity: true,
233 }
234 }
235}
236
/// A complete, versioned snapshot of all operator state on a node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateCheckpoint {
    /// Unique id, formatted as "ckpt-{node_id}-{version}".
    pub checkpoint_id: String,
    /// Snapshots keyed by operator id.
    pub operator_snapshots: HashMap<String, OperatorStateSnapshot>,
    /// Monotonically increasing checkpoint version.
    pub version: u64,
    /// Merkle root over the concatenated snapshot payloads (sorted by key).
    pub merkle_root: String,
    /// Creation time, microseconds since the Unix epoch.
    pub created_at: u64,
    /// Whether the checkpoint is usable; always `true` for checkpoints
    /// produced by `take_checkpoint`.
    pub is_complete: bool,
}
253
/// Assignment of one state partition to a processor node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PartitionAssignment {
    /// Partition identifier.
    pub partition_id: String,
    /// Processor node currently owning the partition.
    pub assigned_to: String,
    /// Size of the partition's state; used to cost migration plans.
    pub state_size_bytes: usize,
    /// Relative load metric; not read by the visible planner code —
    /// TODO confirm intended use.
    pub load_score: f64,
}
268
/// A set of partition moves produced by `plan_migration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationPlan {
    /// Individual partition moves, in execution order.
    pub migrations: Vec<MigrationStep>,
    /// Sum of `state_size_bytes` across all steps.
    pub total_bytes_to_transfer: usize,
    /// What triggered the rebalance.
    pub reason: MigrationReason,
}
279
/// A single partition move from one node to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MigrationStep {
    /// Partition being moved.
    pub partition_id: String,
    /// Node holding the partition before the move.
    pub from_node: String,
    /// Node that will own the partition after the move.
    pub to_node: String,
    /// State size to transfer, in bytes.
    pub state_size_bytes: usize,
}
292
/// Why a migration plan was generated.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum MigrationReason {
    /// A new processor joined and should take over some partitions.
    NodeJoined { node_id: String },
    /// A processor left and its partitions must be redistributed.
    NodeLeft { node_id: String },
    /// Partition counts drifted out of balance across processors.
    LoadImbalance,
    /// Operator-initiated rebalance.
    Manual,
}
305
/// Aggregated runtime statistics reported by `DistributedStateManager::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedStateManagerStats {
    /// Total checkpoints taken since startup.
    pub checkpoints_taken: u64,
    /// Total migration plans executed.
    pub migrations_performed: u64,
    /// Partitions currently tracked.
    pub partition_count: usize,
    /// Processor nodes currently registered.
    pub active_processors: usize,
    /// Sum of state sizes across all tracked partitions.
    pub total_state_bytes: usize,
    /// Deduplicator counters.
    pub dedup_stats: DeduplicationStats,
    /// Mean checkpoint duration in milliseconds (0.0 before any checkpoint).
    pub avg_checkpoint_duration_ms: f64,
}
326
/// Coordinates exactly-once ingestion, state checkpointing, and partition
/// rebalancing for a single node in a distributed streaming deployment.
pub struct DistributedStateManager {
    /// Identifier of the local node; embedded in checkpoint ids.
    node_id: String,
    /// Checkpoint cadence/retention/verification settings.
    checkpoint_config: CheckpointConfig,
    /// Per-source sequence-number deduplication.
    deduplicator: SequenceDeduplicator,
    /// Current partition assignments, keyed by partition id.
    partitions: Arc<RwLock<HashMap<String, PartitionAssignment>>>,
    /// Processor nodes eligible to own partitions.
    active_processors: Arc<RwLock<HashSet<String>>>,
    /// Retained checkpoints, newest at the front.
    checkpoints: Arc<RwLock<VecDeque<StateCheckpoint>>>,
    /// Monotonic counter backing checkpoint versions.
    checkpoint_version: Arc<RwLock<u64>>,
    /// Every migration plan that has been executed.
    migration_history: Arc<RwLock<Vec<MigrationPlan>>>,
    /// Total checkpoints taken.
    checkpoints_taken: Arc<RwLock<u64>>,
    /// Total migrations executed.
    migrations_performed: Arc<RwLock<u64>>,
    /// Running sum of checkpoint durations, for the average in `stats`.
    checkpoint_duration_sum_ms: Arc<RwLock<f64>>,
    /// When the last checkpoint completed (`None` before the first).
    last_checkpoint: Arc<RwLock<Option<Instant>>>,
}
360
361impl DistributedStateManager {
362 pub fn new(
364 node_id: impl Into<String>,
365 checkpoint_config: CheckpointConfig,
366 dedup_config: DeduplicationConfig,
367 ) -> Self {
368 Self {
369 node_id: node_id.into(),
370 checkpoint_config,
371 deduplicator: SequenceDeduplicator::new(dedup_config),
372 partitions: Arc::new(RwLock::new(HashMap::new())),
373 active_processors: Arc::new(RwLock::new(HashSet::new())),
374 checkpoints: Arc::new(RwLock::new(VecDeque::new())),
375 checkpoint_version: Arc::new(RwLock::new(0)),
376 migration_history: Arc::new(RwLock::new(Vec::new())),
377 checkpoints_taken: Arc::new(RwLock::new(0)),
378 migrations_performed: Arc::new(RwLock::new(0)),
379 checkpoint_duration_sum_ms: Arc::new(RwLock::new(0.0)),
380 last_checkpoint: Arc::new(RwLock::new(None)),
381 }
382 }
383
384 pub fn node_id(&self) -> &str {
386 &self.node_id
387 }
388
389 pub fn register_processor(&self, node_id: impl Into<String>) {
391 let id = node_id.into();
392 self.active_processors.write().insert(id.clone());
393 info!("Registered processor: {}", id);
394 }
395
396 pub fn remove_processor(&self, node_id: &str) {
398 self.active_processors.write().remove(node_id);
399 info!("Removed processor: {}", node_id);
400 }
401
402 pub fn assign_partition(&self, assignment: PartitionAssignment) {
404 debug!(
405 "Assigning partition {} to {}",
406 assignment.partition_id, assignment.assigned_to
407 );
408 self.partitions
409 .write()
410 .insert(assignment.partition_id.clone(), assignment);
411 }
412
413 pub fn check_exactly_once(&self, source_id: &str, sequence_number: u64) -> bool {
417 self.deduplicator
418 .check_and_record(source_id, sequence_number)
419 }
420
421 pub fn high_watermark(&self, source_id: &str) -> u64 {
423 self.deduplicator.high_watermark(source_id)
424 }
425
426 pub fn take_checkpoint(
430 &self,
431 operator_states: HashMap<String, Vec<u8>>,
432 ) -> StateResult<StateCheckpoint> {
433 let start = Instant::now();
434
435 let mut version = self.checkpoint_version.write();
436 *version += 1;
437 let current_version = *version;
438 drop(version);
439
440 let now_micros = SystemTime::now()
441 .duration_since(UNIX_EPOCH)
442 .unwrap_or_default()
443 .as_micros() as u64;
444
445 let checkpoint_id = format!("ckpt-{}-{}", self.node_id, current_version);
446
447 let mut operator_snapshots = HashMap::new();
449 for (op_id, state_bytes) in operator_states {
450 let size = state_bytes.len();
451 operator_snapshots.insert(
452 op_id.clone(),
453 OperatorStateSnapshot {
454 operator_id: op_id,
455 state_bytes,
456 version: current_version,
457 created_at: now_micros,
458 size_bytes: size,
459 },
460 );
461 }
462
463 let mut all_bytes = Vec::new();
465 let mut sorted_keys: Vec<&String> = operator_snapshots.keys().collect();
466 sorted_keys.sort();
467 for key in sorted_keys {
468 if let Some(snapshot) = operator_snapshots.get(key) {
469 all_bytes.extend_from_slice(&snapshot.state_bytes);
470 }
471 }
472 let merkle_root = DistributedCheckpointer::compute_merkle_root(&all_bytes);
473
474 let checkpoint = StateCheckpoint {
475 checkpoint_id,
476 operator_snapshots,
477 version: current_version,
478 merkle_root,
479 created_at: now_micros,
480 is_complete: true,
481 };
482
483 let max_retained = self.checkpoint_config.max_retained_checkpoints;
485 let mut checkpoints = self.checkpoints.write();
486 checkpoints.push_front(checkpoint.clone());
487 while checkpoints.len() > max_retained {
488 checkpoints.pop_back();
489 }
490
491 *self.checkpoints_taken.write() += 1;
492 *self.last_checkpoint.write() = Some(Instant::now());
493
494 let elapsed = start.elapsed().as_millis() as f64;
495 *self.checkpoint_duration_sum_ms.write() += elapsed;
496
497 info!(
498 "Checkpoint {} taken (version {}, {} operators, {:.1}ms)",
499 checkpoint.checkpoint_id,
500 current_version,
501 checkpoint.operator_snapshots.len(),
502 elapsed
503 );
504
505 Ok(checkpoint)
506 }
507
508 pub fn restore_from_latest(&self) -> Option<HashMap<String, Vec<u8>>> {
512 let checkpoints = self.checkpoints.read();
513 let latest = checkpoints.front()?;
514
515 if self.checkpoint_config.verify_integrity {
516 let mut all_bytes = Vec::new();
518 let mut sorted_keys: Vec<&String> = latest.operator_snapshots.keys().collect();
519 sorted_keys.sort();
520 for key in sorted_keys {
521 if let Some(snapshot) = latest.operator_snapshots.get(key) {
522 all_bytes.extend_from_slice(&snapshot.state_bytes);
523 }
524 }
525 let computed = DistributedCheckpointer::compute_merkle_root(&all_bytes);
526 if computed != latest.merkle_root {
527 warn!("Checkpoint {} failed integrity check", latest.checkpoint_id);
528 return None;
529 }
530 }
531
532 let states: HashMap<String, Vec<u8>> = latest
533 .operator_snapshots
534 .iter()
535 .map(|(k, v)| (k.clone(), v.state_bytes.clone()))
536 .collect();
537 info!(
538 "Restored state from checkpoint {} (version {})",
539 latest.checkpoint_id, latest.version
540 );
541 Some(states)
542 }
543
544 pub fn checkpoints(&self) -> Vec<StateCheckpoint> {
546 self.checkpoints.read().iter().cloned().collect()
547 }
548
549 pub fn is_checkpoint_due(&self) -> bool {
551 let last = self.last_checkpoint.read();
552 match *last {
553 Some(instant) => instant.elapsed() >= self.checkpoint_config.checkpoint_interval,
554 None => true,
555 }
556 }
557
558 pub fn plan_migration(&self, reason: MigrationReason) -> Option<MigrationPlan> {
562 let partitions = self.partitions.read();
563 let processors = self.active_processors.read();
564
565 if processors.is_empty() || partitions.is_empty() {
566 return None;
567 }
568
569 let processor_list: Vec<String> = processors.iter().cloned().collect();
570
571 let mut load_per_processor: HashMap<String, Vec<String>> = HashMap::new();
573 for proc_id in &processor_list {
574 load_per_processor.insert(proc_id.clone(), Vec::new());
575 }
576 for (partition_id, assignment) in partitions.iter() {
577 load_per_processor
578 .entry(assignment.assigned_to.clone())
579 .or_default()
580 .push(partition_id.clone());
581 }
582
583 let total_partitions = partitions.len();
584 let target_per_processor = total_partitions / processor_list.len();
585 let remainder = total_partitions % processor_list.len();
586
587 let mut migrations = Vec::new();
589 let mut donors: Vec<(String, Vec<String>)> = Vec::new();
590 let mut receivers: Vec<(String, usize)> = Vec::new();
591
592 for (i, proc_id) in processor_list.iter().enumerate() {
593 let current_count = load_per_processor
594 .get(proc_id)
595 .map(|v| v.len())
596 .unwrap_or(0);
597 let target = target_per_processor + if i < remainder { 1 } else { 0 };
598 if current_count > target {
599 let excess: Vec<String> = load_per_processor
600 .get(proc_id)
601 .map(|v| v[target..].to_vec())
602 .unwrap_or_default();
603 donors.push((proc_id.clone(), excess));
604 } else if current_count < target {
605 receivers.push((proc_id.clone(), target - current_count));
606 }
607 }
608
609 let mut donor_iter = donors
611 .iter()
612 .flat_map(|(from, parts)| parts.iter().map(move |p| (from.clone(), p.clone())));
613 for (to_node, need) in &receivers {
614 for _ in 0..*need {
615 if let Some((from_node, partition_id)) = donor_iter.next() {
616 let state_size = partitions
617 .get(&partition_id)
618 .map(|a| a.state_size_bytes)
619 .unwrap_or(0);
620 migrations.push(MigrationStep {
621 partition_id,
622 from_node,
623 to_node: to_node.clone(),
624 state_size_bytes: state_size,
625 });
626 }
627 }
628 }
629
630 if migrations.is_empty() {
631 return None;
632 }
633
634 let total_bytes = migrations.iter().map(|m| m.state_size_bytes).sum();
635
636 Some(MigrationPlan {
637 migrations,
638 total_bytes_to_transfer: total_bytes,
639 reason,
640 })
641 }
642
643 pub fn execute_migration(&self, plan: &MigrationPlan) -> usize {
647 let mut partitions = self.partitions.write();
648 let mut migrated = 0;
649
650 for step in &plan.migrations {
651 if let Some(assignment) = partitions.get_mut(&step.partition_id) {
652 assignment.assigned_to = step.to_node.clone();
653 migrated += 1;
654 debug!(
655 "Migrated partition {} from {} to {}",
656 step.partition_id, step.from_node, step.to_node
657 );
658 }
659 }
660
661 *self.migrations_performed.write() += 1;
662 self.migration_history.write().push(plan.clone());
663 info!(
664 "Migration complete: {} partitions moved ({} bytes)",
665 migrated, plan.total_bytes_to_transfer
666 );
667 migrated
668 }
669
670 pub fn handle_node_joined(&self, node_id: &str) -> Option<MigrationPlan> {
672 self.register_processor(node_id);
673 self.plan_migration(MigrationReason::NodeJoined {
674 node_id: node_id.to_string(),
675 })
676 }
677
678 pub fn handle_node_left(&self, node_id: &str) -> Option<MigrationPlan> {
680 self.remove_processor(node_id);
681 self.plan_migration(MigrationReason::NodeLeft {
683 node_id: node_id.to_string(),
684 })
685 }
686
687 pub fn partition_assignments(&self) -> Vec<PartitionAssignment> {
689 self.partitions.read().values().cloned().collect()
690 }
691
692 pub fn active_processors(&self) -> Vec<String> {
694 self.active_processors.read().iter().cloned().collect()
695 }
696
697 pub fn migration_history(&self) -> Vec<MigrationPlan> {
699 self.migration_history.read().clone()
700 }
701
702 pub fn stats(&self) -> DistributedStateManagerStats {
704 let checkpoints_taken = *self.checkpoints_taken.read();
705 let avg_duration = if checkpoints_taken > 0 {
706 *self.checkpoint_duration_sum_ms.read() / checkpoints_taken as f64
707 } else {
708 0.0
709 };
710
711 let total_state_bytes: usize = self
712 .partitions
713 .read()
714 .values()
715 .map(|p| p.state_size_bytes)
716 .sum();
717
718 DistributedStateManagerStats {
719 checkpoints_taken,
720 migrations_performed: *self.migrations_performed.read(),
721 partition_count: self.partitions.read().len(),
722 active_processors: self.active_processors.read().len(),
723 total_state_bytes,
724 dedup_stats: self.deduplicator.stats(),
725 avg_checkpoint_duration_ms: avg_duration,
726 }
727 }
728}
729
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager with default checkpoint and dedup settings.
    fn make_manager() -> DistributedStateManager {
        DistributedStateManager::new(
            "node-1",
            CheckpointConfig::default(),
            DeduplicationConfig::default(),
        )
    }

    /// Shorthand for building a partition assignment.
    fn assignment(id: &str, owner: &str, bytes: usize, load: f64) -> PartitionAssignment {
        PartitionAssignment {
            partition_id: id.to_string(),
            assigned_to: owner.to_string(),
            state_size_bytes: bytes,
            load_score: load,
        }
    }

    #[test]
    fn test_dedup_first_message_accepted() {
        let dedupe = SequenceDeduplicator::new(DeduplicationConfig::default());
        assert!(dedupe.check_and_record("src-1", 1));
    }

    #[test]
    fn test_dedup_duplicate_rejected() {
        let dedupe = SequenceDeduplicator::new(DeduplicationConfig::default());
        assert!(dedupe.check_and_record("src-1", 1));
        assert!(!dedupe.check_and_record("src-1", 1));
    }

    #[test]
    fn test_dedup_sequential_messages() {
        let dedupe = SequenceDeduplicator::new(DeduplicationConfig::default());
        for seq in 1..=10 {
            assert!(dedupe.check_and_record("src-1", seq));
        }
        assert_eq!(dedupe.high_watermark("src-1"), 10);
    }

    #[test]
    fn test_dedup_out_of_order_accepted() {
        let dedupe = SequenceDeduplicator::new(DeduplicationConfig::default());
        // 3 arrives before 2; once 2 fills the gap the watermark jumps to 3.
        assert!(dedupe.check_and_record("src-1", 1));
        assert!(dedupe.check_and_record("src-1", 3));
        assert!(dedupe.check_and_record("src-1", 2));
        assert_eq!(dedupe.high_watermark("src-1"), 3);
    }

    #[test]
    fn test_dedup_multiple_sources() {
        let dedupe = SequenceDeduplicator::new(DeduplicationConfig::default());
        assert!(dedupe.check_and_record("src-a", 1));
        assert!(dedupe.check_and_record("src-b", 1));
        assert!(!dedupe.check_and_record("src-a", 1));
        assert!(dedupe.check_and_record("src-a", 2));
    }

    #[test]
    fn test_dedup_stats() {
        let dedupe = SequenceDeduplicator::new(DeduplicationConfig::default());
        dedupe.check_and_record("src-1", 1);
        dedupe.check_and_record("src-1", 1); // duplicate
        dedupe.check_and_record("src-2", 1);

        let stats = dedupe.stats();
        assert_eq!(stats.unique_accepted, 2);
        assert_eq!(stats.duplicates_rejected, 1);
        assert_eq!(stats.tracked_sources, 2);
    }

    #[test]
    fn test_dedup_expire_old_entries() {
        let dedupe = SequenceDeduplicator::new(DeduplicationConfig {
            max_entries_per_source: 100,
            expiry: Duration::from_millis(1),
        });
        dedupe.check_and_record("src-1", 1);
        dedupe.check_and_record("src-1", 5); // buffered out-of-order
        std::thread::sleep(Duration::from_millis(5));
        dedupe.expire_old_entries();
        assert_eq!(dedupe.stats().pending_sequences, 0);
    }

    #[test]
    fn test_take_checkpoint() {
        let manager = make_manager();
        let mut states = HashMap::new();
        states.insert("op-1".to_string(), b"state-1".to_vec());
        states.insert("op-2".to_string(), b"state-2".to_vec());

        let ckpt = manager
            .take_checkpoint(states)
            .expect("checkpoint should succeed");
        assert_eq!(ckpt.operator_snapshots.len(), 2);
        assert!(ckpt.is_complete);
        assert!(!ckpt.merkle_root.is_empty());
        assert_eq!(ckpt.version, 1);
    }

    #[test]
    fn test_restore_from_latest() {
        let manager = make_manager();
        let mut states = HashMap::new();
        states.insert("op-1".to_string(), b"data-a".to_vec());
        manager
            .take_checkpoint(states)
            .expect("checkpoint should succeed");

        let restored = manager.restore_from_latest().expect("should restore");
        assert_eq!(restored.get("op-1"), Some(&b"data-a".to_vec()));
    }

    #[test]
    fn test_checkpoint_retention() {
        let config = CheckpointConfig {
            max_retained_checkpoints: 2,
            ..Default::default()
        };
        let manager =
            DistributedStateManager::new("node-1", config, DeduplicationConfig::default());

        for i in 0..5 {
            let mut states = HashMap::new();
            states.insert("op".to_string(), format!("state-{}", i).into_bytes());
            manager.take_checkpoint(states).expect("should succeed");
        }

        let retained = manager.checkpoints();
        assert_eq!(retained.len(), 2);
        // Newest first: the fifth checkpoint sits at index 0.
        assert_eq!(retained[0].version, 5);
    }

    #[test]
    fn test_checkpoint_integrity_verification() {
        let manager = make_manager();
        let mut states = HashMap::new();
        states.insert("op-1".to_string(), b"my-data".to_vec());
        manager.take_checkpoint(states).expect("should succeed");

        // Untampered checkpoint must pass the Merkle-root verification.
        assert!(manager.restore_from_latest().is_some());
    }

    #[test]
    fn test_is_checkpoint_due() {
        let config = CheckpointConfig {
            checkpoint_interval: Duration::from_millis(10),
            ..Default::default()
        };
        let manager =
            DistributedStateManager::new("node-1", config, DeduplicationConfig::default());
        assert!(manager.is_checkpoint_due());

        let mut states = HashMap::new();
        states.insert("op".to_string(), b"data".to_vec());
        manager.take_checkpoint(states).expect("should succeed");
        assert!(!manager.is_checkpoint_due());

        std::thread::sleep(Duration::from_millis(15));
        assert!(manager.is_checkpoint_due());
    }

    #[test]
    fn test_migration_plan_on_node_join() {
        let manager = make_manager();
        manager.register_processor("proc-1");
        for i in 0..4 {
            manager.assign_partition(assignment(&format!("p-{}", i), "proc-1", 1024, 0.5));
        }

        let plan = manager.handle_node_joined("proc-2");
        assert!(plan.is_some(), "should generate migration plan");
        let plan = plan.expect("plan exists");
        assert!(!plan.migrations.is_empty());
        assert_eq!(
            plan.reason,
            MigrationReason::NodeJoined {
                node_id: "proc-2".to_string()
            }
        );
    }

    #[test]
    fn test_migration_plan_balanced_no_migration() {
        let manager = make_manager();
        manager.register_processor("proc-1");
        manager.register_processor("proc-2");
        manager.assign_partition(assignment("p-0", "proc-1", 1024, 0.5));
        manager.assign_partition(assignment("p-1", "proc-2", 1024, 0.5));

        let plan = manager.plan_migration(MigrationReason::Manual);
        assert!(plan.is_none(), "balanced assignment needs no migration");
    }

    #[test]
    fn test_execute_migration() {
        let manager = make_manager();
        manager.register_processor("proc-1");
        manager.register_processor("proc-2");
        for i in 0..4 {
            manager.assign_partition(assignment(&format!("p-{}", i), "proc-1", 512, 0.5));
        }

        let plan = manager
            .plan_migration(MigrationReason::LoadImbalance)
            .expect("should plan migration");
        let migrated = manager.execute_migration(&plan);
        assert!(migrated > 0);

        let proc2_count = manager
            .partition_assignments()
            .iter()
            .filter(|a| a.assigned_to == "proc-2")
            .count();
        assert!(proc2_count > 0, "proc-2 should have partitions now");
    }

    #[test]
    fn test_handle_node_left() {
        let manager = make_manager();
        manager.register_processor("proc-1");
        manager.register_processor("proc-2");
        manager.assign_partition(assignment("p-0", "proc-1", 1024, 0.3));
        manager.assign_partition(assignment("p-1", "proc-2", 1024, 0.3));

        if let Some(plan) = manager.handle_node_left("proc-2") {
            manager.execute_migration(&plan);
        }
        assert!(!manager.active_processors().contains(&"proc-2".to_string()));
    }

    #[test]
    fn test_manager_exactly_once() {
        let manager = make_manager();
        assert!(manager.check_exactly_once("stream-1", 1));
        assert!(manager.check_exactly_once("stream-1", 2));
        assert!(!manager.check_exactly_once("stream-1", 1)); // replay rejected
        assert!(manager.check_exactly_once("stream-1", 3));
        assert_eq!(manager.high_watermark("stream-1"), 3);
    }

    #[test]
    fn test_manager_stats() {
        let manager = make_manager();
        manager.register_processor("proc-1");
        manager.assign_partition(assignment("p-0", "proc-1", 2048, 0.5));
        manager.check_exactly_once("src-1", 1);

        let mut states = HashMap::new();
        states.insert("op-1".to_string(), b"state".to_vec());
        manager.take_checkpoint(states).expect("should succeed");

        let stats = manager.stats();
        assert_eq!(stats.checkpoints_taken, 1);
        assert_eq!(stats.partition_count, 1);
        assert_eq!(stats.active_processors, 1);
        assert_eq!(stats.total_state_bytes, 2048);
        assert_eq!(stats.dedup_stats.unique_accepted, 1);
    }

    #[test]
    fn test_manager_multiple_checkpoints_restore_latest() {
        let manager = make_manager();

        for payload in ["version-1", "version-2"].iter() {
            let mut states = HashMap::new();
            states.insert("op".to_string(), payload.as_bytes().to_vec());
            manager.take_checkpoint(states).expect("should succeed");
        }

        let restored = manager.restore_from_latest().expect("should restore");
        assert_eq!(restored.get("op"), Some(&b"version-2".to_vec()));
    }

    #[test]
    fn test_migration_history() {
        let manager = make_manager();
        manager.register_processor("proc-1");
        for i in 0..4 {
            manager.assign_partition(assignment(&format!("p-{}", i), "proc-1", 256, 0.5));
        }
        manager.register_processor("proc-2");
        if let Some(plan) = manager.plan_migration(MigrationReason::LoadImbalance) {
            manager.execute_migration(&plan);
        }

        assert_eq!(manager.migration_history().len(), 1);
    }
}