// oxirs_stream/distributed_state/manager.rs
//! # Distributed State Manager
//!
//! Coordinates distributed state across stream processors with:
//! - Periodic state checkpointing (snapshots of operator state)
//! - Exactly-once semantics via sequence-number deduplication
//! - State migration when processors join or leave
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::sync::Arc;
10use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
11
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info, warn};
15
16use super::{DistributedCheckpointer, StateResult};
17
18// ─── Exactly-Once Deduplication ──────────────────────────────────────────────
19
20/// Configuration for the deduplication log
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct DeduplicationConfig {
23    /// Maximum number of sequence entries to track per source
24    pub max_entries_per_source: usize,
25    /// Expire entries older than this duration
26    pub expiry: Duration,
27}
28
29impl Default for DeduplicationConfig {
30    fn default() -> Self {
31        Self {
32            max_entries_per_source: 10_000,
33            expiry: Duration::from_secs(3600),
34        }
35    }
36}
37
/// An out-of-order sequence number parked until the gap below it is filled,
/// stamped with its arrival time so it can be expired.
#[derive(Debug, Clone)]
struct SequenceEntry {
    sequence_number: u64,
    received_at: Instant,
}
44
45/// Deduplication log using sequence numbers for exactly-once semantics.
46///
47/// Each source stream maintains a monotonically increasing sequence number.
48/// The log tracks the highest contiguous sequence received, plus a set of
49/// out-of-order sequences, to detect and reject duplicates.
50pub struct SequenceDeduplicator {
51    config: DeduplicationConfig,
52    /// Per-source: highest contiguous sequence number received
53    high_watermarks: Arc<RwLock<HashMap<String, u64>>>,
54    /// Per-source: out-of-order sequence numbers above the high watermark
55    pending_sequences: Arc<RwLock<HashMap<String, VecDeque<SequenceEntry>>>>,
56    /// Total duplicates rejected
57    duplicates_rejected: Arc<RwLock<u64>>,
58    /// Total unique messages accepted
59    unique_accepted: Arc<RwLock<u64>>,
60}
61
62impl SequenceDeduplicator {
63    /// Creates a new deduplicator with the given configuration.
64    pub fn new(config: DeduplicationConfig) -> Self {
65        Self {
66            config,
67            high_watermarks: Arc::new(RwLock::new(HashMap::new())),
68            pending_sequences: Arc::new(RwLock::new(HashMap::new())),
69            duplicates_rejected: Arc::new(RwLock::new(0)),
70            unique_accepted: Arc::new(RwLock::new(0)),
71        }
72    }
73
74    /// Checks whether a message is a duplicate.
75    ///
76    /// Returns `true` if the message is **new** (not a duplicate) and should
77    /// be processed. Returns `false` if it is a duplicate.
78    pub fn check_and_record(&self, source_id: &str, sequence_number: u64) -> bool {
79        let mut watermarks = self.high_watermarks.write();
80        let current_watermark = watermarks.entry(source_id.to_string()).or_insert(0);
81
82        // If below or equal to the high watermark, it is a duplicate
83        if sequence_number <= *current_watermark && *current_watermark > 0 {
84            // Could be in pending (out-of-order), check
85            let pending = self.pending_sequences.read();
86            if let Some(entries) = pending.get(source_id) {
87                if entries.iter().any(|e| e.sequence_number == sequence_number) {
88                    *self.duplicates_rejected.write() += 1;
89                    return false;
90                }
91            }
92            *self.duplicates_rejected.write() += 1;
93            return false;
94        }
95
96        // Check if it is already in pending
97        {
98            let pending = self.pending_sequences.read();
99            if let Some(entries) = pending.get(source_id) {
100                if entries.iter().any(|e| e.sequence_number == sequence_number) {
101                    *self.duplicates_rejected.write() += 1;
102                    return false;
103                }
104            }
105        }
106
107        // Record the new sequence
108        if sequence_number == *current_watermark + 1 || *current_watermark == 0 {
109            // Contiguous: advance the watermark
110            *current_watermark = sequence_number;
111            // Advance through any pending that are now contiguous
112            drop(watermarks);
113            self.advance_watermark(source_id);
114        } else {
115            // Out-of-order: add to pending
116            drop(watermarks);
117            let mut pending = self.pending_sequences.write();
118            let entries = pending.entry(source_id.to_string()).or_default();
119            entries.push_back(SequenceEntry {
120                sequence_number,
121                received_at: Instant::now(),
122            });
123            // Cap the pending entries
124            while entries.len() > self.config.max_entries_per_source {
125                entries.pop_front();
126            }
127        }
128
129        *self.unique_accepted.write() += 1;
130        true
131    }
132
133    /// Advances the watermark by consuming contiguous pending sequences.
134    fn advance_watermark(&self, source_id: &str) {
135        let mut watermarks = self.high_watermarks.write();
136        let watermark = watermarks.entry(source_id.to_string()).or_insert(0);
137        let mut pending = self.pending_sequences.write();
138        if let Some(entries) = pending.get_mut(source_id) {
139            entries.make_contiguous().sort_by_key(|e| e.sequence_number);
140            while let Some(front) = entries.front() {
141                if front.sequence_number == *watermark + 1 {
142                    *watermark += 1;
143                    entries.pop_front();
144                } else {
145                    break;
146                }
147            }
148        }
149    }
150
151    /// Returns the high watermark for a source.
152    pub fn high_watermark(&self, source_id: &str) -> u64 {
153        self.high_watermarks
154            .read()
155            .get(source_id)
156            .copied()
157            .unwrap_or(0)
158    }
159
160    /// Expires old pending entries.
161    pub fn expire_old_entries(&self) {
162        let now = Instant::now();
163        let mut pending = self.pending_sequences.write();
164        for entries in pending.values_mut() {
165            entries.retain(|e| now.duration_since(e.received_at) < self.config.expiry);
166        }
167    }
168
169    /// Returns deduplication statistics.
170    pub fn stats(&self) -> DeduplicationStats {
171        let pending_count: usize = self
172            .pending_sequences
173            .read()
174            .values()
175            .map(|e| e.len())
176            .sum();
177        DeduplicationStats {
178            duplicates_rejected: *self.duplicates_rejected.read(),
179            unique_accepted: *self.unique_accepted.read(),
180            tracked_sources: self.high_watermarks.read().len(),
181            pending_sequences: pending_count,
182        }
183    }
184}
185
186/// Statistics for the deduplication log
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct DeduplicationStats {
189    /// Total duplicate messages rejected
190    pub duplicates_rejected: u64,
191    /// Total unique messages accepted
192    pub unique_accepted: u64,
193    /// Number of sources being tracked
194    pub tracked_sources: usize,
195    /// Total out-of-order sequences pending
196    pub pending_sequences: usize,
197}
198
199// ─── Operator State Snapshot ─────────────────────────────────────────────────
200
201/// A snapshot of a single operator's state
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct OperatorStateSnapshot {
204    /// Operator identifier
205    pub operator_id: String,
206    /// Serialized state bytes
207    pub state_bytes: Vec<u8>,
208    /// State version (monotonically increasing)
209    pub version: u64,
210    /// Timestamp of snapshot
211    pub created_at: u64,
212    /// Size in bytes
213    pub size_bytes: usize,
214}
215
216/// Configuration for periodic checkpointing
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct CheckpointConfig {
219    /// Interval between checkpoints
220    pub checkpoint_interval: Duration,
221    /// Maximum number of retained checkpoints
222    pub max_retained_checkpoints: usize,
223    /// Whether to verify checkpoint integrity
224    pub verify_integrity: bool,
225}
226
227impl Default for CheckpointConfig {
228    fn default() -> Self {
229        Self {
230            checkpoint_interval: Duration::from_secs(30),
231            max_retained_checkpoints: 5,
232            verify_integrity: true,
233        }
234    }
235}
236
237/// A complete checkpoint containing all operator snapshots
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct StateCheckpoint {
240    /// Checkpoint identifier
241    pub checkpoint_id: String,
242    /// All operator snapshots
243    pub operator_snapshots: HashMap<String, OperatorStateSnapshot>,
244    /// Checkpoint version
245    pub version: u64,
246    /// Merkle root for integrity verification
247    pub merkle_root: String,
248    /// Creation timestamp (microseconds since UNIX epoch)
249    pub created_at: u64,
250    /// Whether checkpoint is complete (all operators contributed)
251    pub is_complete: bool,
252}
253
254// ─── State Migration ─────────────────────────────────────────────────────────
255
256/// Describes a state partition assignment
257#[derive(Debug, Clone, Serialize, Deserialize)]
258pub struct PartitionAssignment {
259    /// Partition identifier
260    pub partition_id: String,
261    /// Currently assigned processor node
262    pub assigned_to: String,
263    /// State size in bytes (approximate)
264    pub state_size_bytes: usize,
265    /// Load score (0.0 to 1.0)
266    pub load_score: f64,
267}
268
269/// A migration plan describing how to rebalance state
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct MigrationPlan {
272    /// Migrations to execute
273    pub migrations: Vec<MigrationStep>,
274    /// Estimated total bytes to transfer
275    pub total_bytes_to_transfer: usize,
276    /// Reason for migration
277    pub reason: MigrationReason,
278}
279
280/// A single step in a migration plan
281#[derive(Debug, Clone, Serialize, Deserialize)]
282pub struct MigrationStep {
283    /// Partition being migrated
284    pub partition_id: String,
285    /// Source processor
286    pub from_node: String,
287    /// Target processor
288    pub to_node: String,
289    /// Estimated state size
290    pub state_size_bytes: usize,
291}
292
293/// Reason a migration was triggered
294#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
295pub enum MigrationReason {
296    /// A new processor joined the cluster
297    NodeJoined { node_id: String },
298    /// A processor left the cluster
299    NodeLeft { node_id: String },
300    /// Load imbalance detected
301    LoadImbalance,
302    /// Manual trigger
303    Manual,
304}
305
306// ─── Distributed State Manager ───────────────────────────────────────────────
307
308/// Statistics for the distributed state manager
309#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct DistributedStateManagerStats {
311    /// Total checkpoints taken
312    pub checkpoints_taken: u64,
313    /// Total state migrations performed
314    pub migrations_performed: u64,
315    /// Current number of partitions
316    pub partition_count: usize,
317    /// Current number of active processors
318    pub active_processors: usize,
319    /// Total state size across all partitions (bytes)
320    pub total_state_bytes: usize,
321    /// Deduplication statistics
322    pub dedup_stats: DeduplicationStats,
323    /// Average checkpoint duration (milliseconds)
324    pub avg_checkpoint_duration_ms: f64,
325}
326
327/// The main distributed state manager that coordinates state across
328/// stream processors.
329///
330/// Provides:
331/// - Periodic state checkpointing via operator snapshots
332/// - Exactly-once semantics via sequence-number deduplication
333/// - State migration when processors join or leave the cluster
334pub struct DistributedStateManager {
335    /// This node's identifier
336    node_id: String,
337    /// Checkpoint configuration
338    checkpoint_config: CheckpointConfig,
339    /// Sequence deduplicator for exactly-once semantics
340    deduplicator: SequenceDeduplicator,
341    /// Current partition assignments
342    partitions: Arc<RwLock<HashMap<String, PartitionAssignment>>>,
343    /// Active processor nodes
344    active_processors: Arc<RwLock<HashSet<String>>>,
345    /// Stored checkpoints (most recent first)
346    checkpoints: Arc<RwLock<VecDeque<StateCheckpoint>>>,
347    /// Current checkpoint version counter
348    checkpoint_version: Arc<RwLock<u64>>,
349    /// Migration history
350    migration_history: Arc<RwLock<Vec<MigrationPlan>>>,
351    /// Total checkpoints taken
352    checkpoints_taken: Arc<RwLock<u64>>,
353    /// Total migrations performed
354    migrations_performed: Arc<RwLock<u64>>,
355    /// Checkpoint duration accumulator
356    checkpoint_duration_sum_ms: Arc<RwLock<f64>>,
357    /// Last checkpoint time
358    last_checkpoint: Arc<RwLock<Option<Instant>>>,
359}
360
361impl DistributedStateManager {
362    /// Creates a new distributed state manager.
363    pub fn new(
364        node_id: impl Into<String>,
365        checkpoint_config: CheckpointConfig,
366        dedup_config: DeduplicationConfig,
367    ) -> Self {
368        Self {
369            node_id: node_id.into(),
370            checkpoint_config,
371            deduplicator: SequenceDeduplicator::new(dedup_config),
372            partitions: Arc::new(RwLock::new(HashMap::new())),
373            active_processors: Arc::new(RwLock::new(HashSet::new())),
374            checkpoints: Arc::new(RwLock::new(VecDeque::new())),
375            checkpoint_version: Arc::new(RwLock::new(0)),
376            migration_history: Arc::new(RwLock::new(Vec::new())),
377            checkpoints_taken: Arc::new(RwLock::new(0)),
378            migrations_performed: Arc::new(RwLock::new(0)),
379            checkpoint_duration_sum_ms: Arc::new(RwLock::new(0.0)),
380            last_checkpoint: Arc::new(RwLock::new(None)),
381        }
382    }
383
384    /// Returns the node ID of this manager.
385    pub fn node_id(&self) -> &str {
386        &self.node_id
387    }
388
389    /// Registers a processor node as active.
390    pub fn register_processor(&self, node_id: impl Into<String>) {
391        let id = node_id.into();
392        self.active_processors.write().insert(id.clone());
393        info!("Registered processor: {}", id);
394    }
395
396    /// Removes a processor node.
397    pub fn remove_processor(&self, node_id: &str) {
398        self.active_processors.write().remove(node_id);
399        info!("Removed processor: {}", node_id);
400    }
401
402    /// Assigns a partition to a processor.
403    pub fn assign_partition(&self, assignment: PartitionAssignment) {
404        debug!(
405            "Assigning partition {} to {}",
406            assignment.partition_id, assignment.assigned_to
407        );
408        self.partitions
409            .write()
410            .insert(assignment.partition_id.clone(), assignment);
411    }
412
413    /// Checks and records a message for exactly-once processing.
414    ///
415    /// Returns `true` if the message is new and should be processed.
416    pub fn check_exactly_once(&self, source_id: &str, sequence_number: u64) -> bool {
417        self.deduplicator
418            .check_and_record(source_id, sequence_number)
419    }
420
421    /// Returns the high watermark for a source.
422    pub fn high_watermark(&self, source_id: &str) -> u64 {
423        self.deduplicator.high_watermark(source_id)
424    }
425
426    /// Takes a checkpoint of the given operator states.
427    ///
428    /// Returns the checkpoint if successful.
429    pub fn take_checkpoint(
430        &self,
431        operator_states: HashMap<String, Vec<u8>>,
432    ) -> StateResult<StateCheckpoint> {
433        let start = Instant::now();
434
435        let mut version = self.checkpoint_version.write();
436        *version += 1;
437        let current_version = *version;
438        drop(version);
439
440        let now_micros = SystemTime::now()
441            .duration_since(UNIX_EPOCH)
442            .unwrap_or_default()
443            .as_micros() as u64;
444
445        let checkpoint_id = format!("ckpt-{}-{}", self.node_id, current_version);
446
447        // Build operator snapshots
448        let mut operator_snapshots = HashMap::new();
449        for (op_id, state_bytes) in operator_states {
450            let size = state_bytes.len();
451            operator_snapshots.insert(
452                op_id.clone(),
453                OperatorStateSnapshot {
454                    operator_id: op_id,
455                    state_bytes,
456                    version: current_version,
457                    created_at: now_micros,
458                    size_bytes: size,
459                },
460            );
461        }
462
463        // Compute merkle root over all operator states
464        let mut all_bytes = Vec::new();
465        let mut sorted_keys: Vec<&String> = operator_snapshots.keys().collect();
466        sorted_keys.sort();
467        for key in sorted_keys {
468            if let Some(snapshot) = operator_snapshots.get(key) {
469                all_bytes.extend_from_slice(&snapshot.state_bytes);
470            }
471        }
472        let merkle_root = DistributedCheckpointer::compute_merkle_root(&all_bytes);
473
474        let checkpoint = StateCheckpoint {
475            checkpoint_id,
476            operator_snapshots,
477            version: current_version,
478            merkle_root,
479            created_at: now_micros,
480            is_complete: true,
481        };
482
483        // Store the checkpoint
484        let max_retained = self.checkpoint_config.max_retained_checkpoints;
485        let mut checkpoints = self.checkpoints.write();
486        checkpoints.push_front(checkpoint.clone());
487        while checkpoints.len() > max_retained {
488            checkpoints.pop_back();
489        }
490
491        *self.checkpoints_taken.write() += 1;
492        *self.last_checkpoint.write() = Some(Instant::now());
493
494        let elapsed = start.elapsed().as_millis() as f64;
495        *self.checkpoint_duration_sum_ms.write() += elapsed;
496
497        info!(
498            "Checkpoint {} taken (version {}, {} operators, {:.1}ms)",
499            checkpoint.checkpoint_id,
500            current_version,
501            checkpoint.operator_snapshots.len(),
502            elapsed
503        );
504
505        Ok(checkpoint)
506    }
507
508    /// Restores state from the latest checkpoint.
509    ///
510    /// Returns the operator states map if a checkpoint exists.
511    pub fn restore_from_latest(&self) -> Option<HashMap<String, Vec<u8>>> {
512        let checkpoints = self.checkpoints.read();
513        let latest = checkpoints.front()?;
514
515        if self.checkpoint_config.verify_integrity {
516            // Verify merkle root
517            let mut all_bytes = Vec::new();
518            let mut sorted_keys: Vec<&String> = latest.operator_snapshots.keys().collect();
519            sorted_keys.sort();
520            for key in sorted_keys {
521                if let Some(snapshot) = latest.operator_snapshots.get(key) {
522                    all_bytes.extend_from_slice(&snapshot.state_bytes);
523                }
524            }
525            let computed = DistributedCheckpointer::compute_merkle_root(&all_bytes);
526            if computed != latest.merkle_root {
527                warn!("Checkpoint {} failed integrity check", latest.checkpoint_id);
528                return None;
529            }
530        }
531
532        let states: HashMap<String, Vec<u8>> = latest
533            .operator_snapshots
534            .iter()
535            .map(|(k, v)| (k.clone(), v.state_bytes.clone()))
536            .collect();
537        info!(
538            "Restored state from checkpoint {} (version {})",
539            latest.checkpoint_id, latest.version
540        );
541        Some(states)
542    }
543
544    /// Returns all stored checkpoints (most recent first).
545    pub fn checkpoints(&self) -> Vec<StateCheckpoint> {
546        self.checkpoints.read().iter().cloned().collect()
547    }
548
549    /// Returns whether a checkpoint is due based on the configured interval.
550    pub fn is_checkpoint_due(&self) -> bool {
551        let last = self.last_checkpoint.read();
552        match *last {
553            Some(instant) => instant.elapsed() >= self.checkpoint_config.checkpoint_interval,
554            None => true,
555        }
556    }
557
558    /// Plans a migration based on the current partition assignments and processor set.
559    ///
560    /// Returns `None` if no migration is needed.
561    pub fn plan_migration(&self, reason: MigrationReason) -> Option<MigrationPlan> {
562        let partitions = self.partitions.read();
563        let processors = self.active_processors.read();
564
565        if processors.is_empty() || partitions.is_empty() {
566            return None;
567        }
568
569        let processor_list: Vec<String> = processors.iter().cloned().collect();
570
571        // Build current load per processor
572        let mut load_per_processor: HashMap<String, Vec<String>> = HashMap::new();
573        for proc_id in &processor_list {
574            load_per_processor.insert(proc_id.clone(), Vec::new());
575        }
576        for (partition_id, assignment) in partitions.iter() {
577            load_per_processor
578                .entry(assignment.assigned_to.clone())
579                .or_default()
580                .push(partition_id.clone());
581        }
582
583        let total_partitions = partitions.len();
584        let target_per_processor = total_partitions / processor_list.len();
585        let remainder = total_partitions % processor_list.len();
586
587        // Find overloaded and underloaded processors
588        let mut migrations = Vec::new();
589        let mut donors: Vec<(String, Vec<String>)> = Vec::new();
590        let mut receivers: Vec<(String, usize)> = Vec::new();
591
592        for (i, proc_id) in processor_list.iter().enumerate() {
593            let current_count = load_per_processor
594                .get(proc_id)
595                .map(|v| v.len())
596                .unwrap_or(0);
597            let target = target_per_processor + if i < remainder { 1 } else { 0 };
598            if current_count > target {
599                let excess: Vec<String> = load_per_processor
600                    .get(proc_id)
601                    .map(|v| v[target..].to_vec())
602                    .unwrap_or_default();
603                donors.push((proc_id.clone(), excess));
604            } else if current_count < target {
605                receivers.push((proc_id.clone(), target - current_count));
606            }
607        }
608
609        // Match donors with receivers
610        let mut donor_iter = donors
611            .iter()
612            .flat_map(|(from, parts)| parts.iter().map(move |p| (from.clone(), p.clone())));
613        for (to_node, need) in &receivers {
614            for _ in 0..*need {
615                if let Some((from_node, partition_id)) = donor_iter.next() {
616                    let state_size = partitions
617                        .get(&partition_id)
618                        .map(|a| a.state_size_bytes)
619                        .unwrap_or(0);
620                    migrations.push(MigrationStep {
621                        partition_id,
622                        from_node,
623                        to_node: to_node.clone(),
624                        state_size_bytes: state_size,
625                    });
626                }
627            }
628        }
629
630        if migrations.is_empty() {
631            return None;
632        }
633
634        let total_bytes = migrations.iter().map(|m| m.state_size_bytes).sum();
635
636        Some(MigrationPlan {
637            migrations,
638            total_bytes_to_transfer: total_bytes,
639            reason,
640        })
641    }
642
643    /// Executes a migration plan by updating partition assignments.
644    ///
645    /// Returns the number of partitions migrated.
646    pub fn execute_migration(&self, plan: &MigrationPlan) -> usize {
647        let mut partitions = self.partitions.write();
648        let mut migrated = 0;
649
650        for step in &plan.migrations {
651            if let Some(assignment) = partitions.get_mut(&step.partition_id) {
652                assignment.assigned_to = step.to_node.clone();
653                migrated += 1;
654                debug!(
655                    "Migrated partition {} from {} to {}",
656                    step.partition_id, step.from_node, step.to_node
657                );
658            }
659        }
660
661        *self.migrations_performed.write() += 1;
662        self.migration_history.write().push(plan.clone());
663        info!(
664            "Migration complete: {} partitions moved ({} bytes)",
665            migrated, plan.total_bytes_to_transfer
666        );
667        migrated
668    }
669
670    /// Handles a node joining the cluster: registers it and optionally migrates.
671    pub fn handle_node_joined(&self, node_id: &str) -> Option<MigrationPlan> {
672        self.register_processor(node_id);
673        self.plan_migration(MigrationReason::NodeJoined {
674            node_id: node_id.to_string(),
675        })
676    }
677
678    /// Handles a node leaving the cluster: reassigns its partitions.
679    pub fn handle_node_left(&self, node_id: &str) -> Option<MigrationPlan> {
680        self.remove_processor(node_id);
681        // Reassign partitions from the departed node
682        self.plan_migration(MigrationReason::NodeLeft {
683            node_id: node_id.to_string(),
684        })
685    }
686
687    /// Returns current partition assignments.
688    pub fn partition_assignments(&self) -> Vec<PartitionAssignment> {
689        self.partitions.read().values().cloned().collect()
690    }
691
692    /// Returns active processor node IDs.
693    pub fn active_processors(&self) -> Vec<String> {
694        self.active_processors.read().iter().cloned().collect()
695    }
696
697    /// Returns migration history.
698    pub fn migration_history(&self) -> Vec<MigrationPlan> {
699        self.migration_history.read().clone()
700    }
701
702    /// Returns comprehensive statistics.
703    pub fn stats(&self) -> DistributedStateManagerStats {
704        let checkpoints_taken = *self.checkpoints_taken.read();
705        let avg_duration = if checkpoints_taken > 0 {
706            *self.checkpoint_duration_sum_ms.read() / checkpoints_taken as f64
707        } else {
708            0.0
709        };
710
711        let total_state_bytes: usize = self
712            .partitions
713            .read()
714            .values()
715            .map(|p| p.state_size_bytes)
716            .sum();
717
718        DistributedStateManagerStats {
719            checkpoints_taken,
720            migrations_performed: *self.migrations_performed.read(),
721            partition_count: self.partitions.read().len(),
722            active_processors: self.active_processors.read().len(),
723            total_state_bytes,
724            dedup_stats: self.deduplicator.stats(),
725            avg_checkpoint_duration_ms: avg_duration,
726        }
727    }
728}
729
730// ─── Tests ───────────────────────────────────────────────────────────────────
731
732#[cfg(test)]
733mod tests {
734    use super::*;
735
736    fn make_manager() -> DistributedStateManager {
737        DistributedStateManager::new(
738            "node-1",
739            CheckpointConfig::default(),
740            DeduplicationConfig::default(),
741        )
742    }
743
744    // ── Deduplication Tests ──────────────────────────────────────────────────
745
746    #[test]
747    fn test_dedup_first_message_accepted() {
748        let dedup = SequenceDeduplicator::new(DeduplicationConfig::default());
749        assert!(dedup.check_and_record("src-1", 1));
750    }
751
752    #[test]
753    fn test_dedup_duplicate_rejected() {
754        let dedup = SequenceDeduplicator::new(DeduplicationConfig::default());
755        assert!(dedup.check_and_record("src-1", 1));
756        assert!(!dedup.check_and_record("src-1", 1));
757    }
758
759    #[test]
760    fn test_dedup_sequential_messages() {
761        let dedup = SequenceDeduplicator::new(DeduplicationConfig::default());
762        for i in 1..=10 {
763            assert!(dedup.check_and_record("src-1", i));
764        }
765        assert_eq!(dedup.high_watermark("src-1"), 10);
766    }
767
768    #[test]
769    fn test_dedup_out_of_order_accepted() {
770        let dedup = SequenceDeduplicator::new(DeduplicationConfig::default());
771        assert!(dedup.check_and_record("src-1", 1));
772        assert!(dedup.check_and_record("src-1", 3)); // out of order
773        assert!(dedup.check_and_record("src-1", 2)); // fills the gap
774        assert_eq!(dedup.high_watermark("src-1"), 3);
775    }
776
777    #[test]
778    fn test_dedup_multiple_sources() {
779        let dedup = SequenceDeduplicator::new(DeduplicationConfig::default());
780        assert!(dedup.check_and_record("src-a", 1));
781        assert!(dedup.check_and_record("src-b", 1));
782        assert!(!dedup.check_and_record("src-a", 1));
783        assert!(dedup.check_and_record("src-a", 2));
784    }
785
786    #[test]
787    fn test_dedup_stats() {
788        let dedup = SequenceDeduplicator::new(DeduplicationConfig::default());
789        dedup.check_and_record("src-1", 1);
790        dedup.check_and_record("src-1", 1); // duplicate
791        dedup.check_and_record("src-2", 1);
792
793        let stats = dedup.stats();
794        assert_eq!(stats.unique_accepted, 2);
795        assert_eq!(stats.duplicates_rejected, 1);
796        assert_eq!(stats.tracked_sources, 2);
797    }
798
799    #[test]
800    fn test_dedup_expire_old_entries() {
801        let config = DeduplicationConfig {
802            max_entries_per_source: 100,
803            expiry: Duration::from_millis(1),
804        };
805        let dedup = SequenceDeduplicator::new(config);
806        dedup.check_and_record("src-1", 1);
807        dedup.check_and_record("src-1", 5); // out of order, goes to pending
808        std::thread::sleep(Duration::from_millis(5));
809        dedup.expire_old_entries();
810        let stats = dedup.stats();
811        assert_eq!(stats.pending_sequences, 0);
812    }
813
814    // ── Checkpoint Tests ─────────────────────────────────────────────────────
815
816    #[test]
817    fn test_take_checkpoint() {
818        let mgr = make_manager();
819        let mut states = HashMap::new();
820        states.insert("op-1".to_string(), b"state-1".to_vec());
821        states.insert("op-2".to_string(), b"state-2".to_vec());
822
823        let ckpt = mgr
824            .take_checkpoint(states)
825            .expect("checkpoint should succeed");
826        assert_eq!(ckpt.operator_snapshots.len(), 2);
827        assert!(ckpt.is_complete);
828        assert!(!ckpt.merkle_root.is_empty());
829        assert_eq!(ckpt.version, 1);
830    }
831
832    #[test]
833    fn test_restore_from_latest() {
834        let mgr = make_manager();
835        let mut states = HashMap::new();
836        states.insert("op-1".to_string(), b"data-a".to_vec());
837        mgr.take_checkpoint(states)
838            .expect("checkpoint should succeed");
839
840        let restored = mgr.restore_from_latest().expect("should restore");
841        assert_eq!(restored.get("op-1"), Some(&b"data-a".to_vec()));
842    }
843
844    #[test]
845    fn test_checkpoint_retention() {
846        let config = CheckpointConfig {
847            max_retained_checkpoints: 2,
848            ..Default::default()
849        };
850        let mgr = DistributedStateManager::new("node-1", config, DeduplicationConfig::default());
851
852        for i in 0..5 {
853            let mut states = HashMap::new();
854            states.insert("op".to_string(), format!("state-{}", i).into_bytes());
855            mgr.take_checkpoint(states).expect("should succeed");
856        }
857
858        let checkpoints = mgr.checkpoints();
859        assert_eq!(checkpoints.len(), 2);
860        // Most recent should be first
861        assert_eq!(checkpoints[0].version, 5);
862    }
863
864    #[test]
865    fn test_checkpoint_integrity_verification() {
866        let mgr = make_manager();
867        let mut states = HashMap::new();
868        states.insert("op-1".to_string(), b"my-data".to_vec());
869        mgr.take_checkpoint(states).expect("should succeed");
870
871        // Restore should work with valid integrity
872        let restored = mgr.restore_from_latest();
873        assert!(restored.is_some());
874    }
875
876    #[test]
877    fn test_is_checkpoint_due() {
878        let config = CheckpointConfig {
879            checkpoint_interval: Duration::from_millis(10),
880            ..Default::default()
881        };
882        let mgr = DistributedStateManager::new("node-1", config, DeduplicationConfig::default());
883        assert!(mgr.is_checkpoint_due());
884
885        let mut states = HashMap::new();
886        states.insert("op".to_string(), b"data".to_vec());
887        mgr.take_checkpoint(states).expect("should succeed");
888        assert!(!mgr.is_checkpoint_due());
889
890        std::thread::sleep(Duration::from_millis(15));
891        assert!(mgr.is_checkpoint_due());
892    }
893
894    // ── Migration Tests ──────────────────────────────────────────────────────
895
896    #[test]
897    fn test_migration_plan_on_node_join() {
898        let mgr = make_manager();
899        mgr.register_processor("proc-1");
900        for i in 0..4 {
901            mgr.assign_partition(PartitionAssignment {
902                partition_id: format!("p-{}", i),
903                assigned_to: "proc-1".to_string(),
904                state_size_bytes: 1024,
905                load_score: 0.5,
906            });
907        }
908
909        // A second processor joins
910        let plan = mgr.handle_node_joined("proc-2");
911        assert!(plan.is_some(), "should generate migration plan");
912        let plan = plan.expect("plan exists");
913        assert!(!plan.migrations.is_empty());
914        assert_eq!(
915            plan.reason,
916            MigrationReason::NodeJoined {
917                node_id: "proc-2".to_string()
918            }
919        );
920    }
921
922    #[test]
923    fn test_migration_plan_balanced_no_migration() {
924        let mgr = make_manager();
925        mgr.register_processor("proc-1");
926        mgr.register_processor("proc-2");
927        mgr.assign_partition(PartitionAssignment {
928            partition_id: "p-0".to_string(),
929            assigned_to: "proc-1".to_string(),
930            state_size_bytes: 1024,
931            load_score: 0.5,
932        });
933        mgr.assign_partition(PartitionAssignment {
934            partition_id: "p-1".to_string(),
935            assigned_to: "proc-2".to_string(),
936            state_size_bytes: 1024,
937            load_score: 0.5,
938        });
939
940        let plan = mgr.plan_migration(MigrationReason::Manual);
941        assert!(plan.is_none(), "balanced assignment needs no migration");
942    }
943
944    #[test]
945    fn test_execute_migration() {
946        let mgr = make_manager();
947        mgr.register_processor("proc-1");
948        mgr.register_processor("proc-2");
949        for i in 0..4 {
950            mgr.assign_partition(PartitionAssignment {
951                partition_id: format!("p-{}", i),
952                assigned_to: "proc-1".to_string(),
953                state_size_bytes: 512,
954                load_score: 0.5,
955            });
956        }
957
958        let plan = mgr
959            .plan_migration(MigrationReason::LoadImbalance)
960            .expect("should plan migration");
961        let migrated = mgr.execute_migration(&plan);
962        assert!(migrated > 0);
963
964        // Verify assignments changed
965        let assignments = mgr.partition_assignments();
966        let proc2_count = assignments
967            .iter()
968            .filter(|a| a.assigned_to == "proc-2")
969            .count();
970        assert!(proc2_count > 0, "proc-2 should have partitions now");
971    }
972
973    #[test]
974    fn test_handle_node_left() {
975        let mgr = make_manager();
976        mgr.register_processor("proc-1");
977        mgr.register_processor("proc-2");
978        mgr.assign_partition(PartitionAssignment {
979            partition_id: "p-0".to_string(),
980            assigned_to: "proc-1".to_string(),
981            state_size_bytes: 1024,
982            load_score: 0.3,
983        });
984        mgr.assign_partition(PartitionAssignment {
985            partition_id: "p-1".to_string(),
986            assigned_to: "proc-2".to_string(),
987            state_size_bytes: 1024,
988            load_score: 0.3,
989        });
990
991        // proc-2 leaves
992        let plan = mgr.handle_node_left("proc-2");
993        // With only proc-1 remaining and p-1 still assigned to proc-2
994        // plan should exist to move p-1 to proc-1
995        if let Some(plan) = plan {
996            mgr.execute_migration(&plan);
997        }
998        let procs = mgr.active_processors();
999        assert!(!procs.contains(&"proc-2".to_string()));
1000    }
1001
1002    // ── Manager Integration Tests ────────────────────────────────────────────
1003
1004    #[test]
1005    fn test_manager_exactly_once() {
1006        let mgr = make_manager();
1007        assert!(mgr.check_exactly_once("stream-1", 1));
1008        assert!(mgr.check_exactly_once("stream-1", 2));
1009        assert!(!mgr.check_exactly_once("stream-1", 1)); // duplicate
1010        assert!(mgr.check_exactly_once("stream-1", 3));
1011        assert_eq!(mgr.high_watermark("stream-1"), 3);
1012    }
1013
1014    #[test]
1015    fn test_manager_stats() {
1016        let mgr = make_manager();
1017        mgr.register_processor("proc-1");
1018        mgr.assign_partition(PartitionAssignment {
1019            partition_id: "p-0".to_string(),
1020            assigned_to: "proc-1".to_string(),
1021            state_size_bytes: 2048,
1022            load_score: 0.5,
1023        });
1024        mgr.check_exactly_once("src-1", 1);
1025
1026        let mut states = HashMap::new();
1027        states.insert("op-1".to_string(), b"state".to_vec());
1028        mgr.take_checkpoint(states).expect("should succeed");
1029
1030        let stats = mgr.stats();
1031        assert_eq!(stats.checkpoints_taken, 1);
1032        assert_eq!(stats.partition_count, 1);
1033        assert_eq!(stats.active_processors, 1);
1034        assert_eq!(stats.total_state_bytes, 2048);
1035        assert_eq!(stats.dedup_stats.unique_accepted, 1);
1036    }
1037
1038    #[test]
1039    fn test_manager_multiple_checkpoints_restore_latest() {
1040        let mgr = make_manager();
1041
1042        let mut states1 = HashMap::new();
1043        states1.insert("op".to_string(), b"version-1".to_vec());
1044        mgr.take_checkpoint(states1).expect("should succeed");
1045
1046        let mut states2 = HashMap::new();
1047        states2.insert("op".to_string(), b"version-2".to_vec());
1048        mgr.take_checkpoint(states2).expect("should succeed");
1049
1050        let restored = mgr.restore_from_latest().expect("should restore");
1051        assert_eq!(restored.get("op"), Some(&b"version-2".to_vec()));
1052    }
1053
1054    #[test]
1055    fn test_migration_history() {
1056        let mgr = make_manager();
1057        mgr.register_processor("proc-1");
1058        for i in 0..4 {
1059            mgr.assign_partition(PartitionAssignment {
1060                partition_id: format!("p-{}", i),
1061                assigned_to: "proc-1".to_string(),
1062                state_size_bytes: 256,
1063                load_score: 0.5,
1064            });
1065        }
1066        mgr.register_processor("proc-2");
1067        if let Some(plan) = mgr.plan_migration(MigrationReason::LoadImbalance) {
1068            mgr.execute_migration(&plan);
1069        }
1070
1071        let history = mgr.migration_history();
1072        assert_eq!(history.len(), 1);
1073    }
1074}