Skip to main content

oxirs_stream/state/
distributed.rs

1//! # Distributed State Management
2//!
3//! Distributed state coordination for stream processors across partitions.
//! Routes each key to a partition with a deterministic FNV-1a hash taken modulo the partition count.
5
6use std::collections::HashMap;
7
8// ─── PartitionStateValue ──────────────────────────────────────────────────────
9
/// Value types storable in a distributed state partition.
///
/// Derives `PartialEq` but not `Eq`: the `Float` and `Gauge` variants carry
/// `f64` payloads, which do not implement `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub enum PartitionStateValue {
    /// Signed 64-bit integer value.
    Integer(i64),
    /// 64-bit floating-point value.
    Float(f64),
    /// Arbitrary byte payload.
    Bytes(Vec<u8>),
    /// Owned string value.
    StringVal(String),
    /// Unsigned 64-bit count.
    Counter(u64),
    /// Floating-point reading paired with a timestamp.
    /// NOTE(review): the timestamp unit is not fixed by this type — confirm with writers.
    Gauge { value: f64, timestamp: i64 },
}
20
21// ─── StatePartition ───────────────────────────────────────────────────────────
22
23/// A single partition of distributed state.
24#[derive(Debug, Clone)]
25pub struct StatePartition {
26    pub partition_id: u32,
27    pub state: HashMap<String, PartitionStateValue>,
28    pub version: u64,
29    pub last_checkpointed: i64,
30}
31
32impl StatePartition {
33    pub fn new(partition_id: u32) -> Self {
34        Self {
35            partition_id,
36            state: HashMap::new(),
37            version: 0,
38            last_checkpointed: 0,
39        }
40    }
41
42    fn bump_version(&mut self) -> u64 {
43        self.version += 1;
44        self.version
45    }
46}
47
48// ─── StateCoordinator ─────────────────────────────────────────────────────────
49
/// Holds this node's identity together with the list of peer nodes it knows.
///
/// This type only stores the names; it performs no replication itself.
#[derive(Debug, Clone)]
pub struct StateCoordinator {
    /// Identifier of the local node.
    pub node_id: String,
    /// Peer identifiers in the order they were added.
    pub peers: Vec<String>,
}

impl StateCoordinator {
    /// Creates a coordinator for `node_id` with an empty peer list.
    pub fn new(node_id: impl Into<String>) -> Self {
        let node_id: String = node_id.into();
        StateCoordinator {
            node_id,
            peers: vec![],
        }
    }

    /// Appends `peer` to the peer list. No de-duplication is performed, so
    /// adding the same name twice stores it twice.
    pub fn add_peer(&mut self, peer: impl Into<String>) {
        let peer_name: String = peer.into();
        self.peers.push(peer_name);
    }
}
69
70// ─── DistributedStateStore ────────────────────────────────────────────────────
71
72/// Distributed state store with consistent FNV-1a hashing for key-to-partition mapping.
73///
74/// Each key is routed to exactly one partition. Replication is handled by
75/// `replicate_to` which returns a snapshot of a partition for a peer node.
76pub struct DistributedStateStore {
77    pub(crate) partitions: Vec<StatePartition>,
78    replication_factor: usize,
79    coordinator: StateCoordinator,
80}
81
82impl DistributedStateStore {
83    /// Create a new store with `partition_count` partitions.
84    pub fn new(partition_count: u32, replication_factor: usize) -> Self {
85        let partitions = (0..partition_count).map(StatePartition::new).collect();
86        Self {
87            partitions,
88            replication_factor,
89            coordinator: StateCoordinator::new("local"),
90        }
91    }
92
93    /// FNV-1a 64-bit hash for consistent partition routing.
94    fn fnv_hash(key: &str) -> u64 {
95        const FNV_OFFSET: u64 = 14_695_981_039_346_656_037;
96        const FNV_PRIME: u64 = 1_099_511_628_211;
97        let mut hash = FNV_OFFSET;
98        for byte in key.as_bytes() {
99            hash ^= *byte as u64;
100            hash = hash.wrapping_mul(FNV_PRIME);
101        }
102        hash
103    }
104
105    /// Determine which partition a key belongs to (consistent hashing).
106    pub fn partition_for(&self, key: &str) -> u32 {
107        let count = self.partitions.len() as u64;
108        if count == 0 {
109            return 0;
110        }
111        (Self::fnv_hash(key) % count) as u32
112    }
113
114    /// Get a value by key, or `None` if absent.
115    pub fn get(&self, key: &str) -> Option<&PartitionStateValue> {
116        let pid = self.partition_for(key) as usize;
117        self.partitions.get(pid)?.state.get(key)
118    }
119
120    /// Set a key-value pair. Returns the new partition version number.
121    pub fn set(&mut self, key: &str, value: PartitionStateValue) -> u64 {
122        let pid = self.partition_for(key) as usize;
123        let partition = &mut self.partitions[pid];
124        partition.state.insert(key.to_string(), value);
125        partition.bump_version()
126    }
127
128    /// Delete a key. Returns `true` if the key previously existed.
129    pub fn delete(&mut self, key: &str) -> bool {
130        let pid = self.partition_for(key) as usize;
131        match self.partitions.get_mut(pid) {
132            Some(partition) => partition.state.remove(key).is_some(),
133            None => false,
134        }
135    }
136
137    /// Return all key-value pairs from `partition_id` for replication to `peer`.
138    pub fn replicate_to(
139        &self,
140        _peer: &str,
141        partition_id: u32,
142    ) -> Vec<(String, PartitionStateValue)> {
143        self.partitions
144            .iter()
145            .find(|p| p.partition_id == partition_id)
146            .map(|p| {
147                p.state
148                    .iter()
149                    .map(|(k, v)| (k.clone(), v.clone()))
150                    .collect()
151            })
152            .unwrap_or_default()
153    }
154
155    /// Snapshot a partition (clone it) and record `last_checkpointed` timestamp.
156    pub fn checkpoint_partition(&mut self, partition_id: u32) -> StatePartition {
157        let now_ms = std::time::SystemTime::now()
158            .duration_since(std::time::UNIX_EPOCH)
159            .map(|d| d.as_millis() as i64)
160            .unwrap_or(0);
161        let partition = self
162            .partitions
163            .iter_mut()
164            .find(|p| p.partition_id == partition_id)
165            .expect("partition_id out of range");
166        partition.last_checkpointed = now_ms;
167        partition.clone()
168    }
169
170    /// Restore a partition from a previously-created checkpoint snapshot.
171    pub fn restore_partition(&mut self, partition: StatePartition) {
172        if let Some(p) = self
173            .partitions
174            .iter_mut()
175            .find(|p| p.partition_id == partition.partition_id)
176        {
177            *p = partition;
178        }
179    }
180
181    /// Total number of partitions in this store.
182    pub fn partition_count(&self) -> u32 {
183        self.partitions.len() as u32
184    }
185
186    /// Total number of keys across all partitions.
187    pub fn total_keys(&self) -> usize {
188        self.partitions.iter().map(|p| p.state.len()).sum()
189    }
190
191    /// The configured replication factor.
192    pub fn replication_factor(&self) -> usize {
193        self.replication_factor
194    }
195
196    /// Read-only reference to the coordinator.
197    pub fn coordinator(&self) -> &StateCoordinator {
198        &self.coordinator
199    }
200
201    /// Mutable reference to the coordinator.
202    pub fn coordinator_mut(&mut self) -> &mut StateCoordinator {
203        &mut self.coordinator
204    }
205}
206
207// ─── StateAggregator ──────────────────────────────────────────────────────────
208
209/// High-level aggregator built on top of `DistributedStateStore`.
210///
211/// Provides common streaming aggregation patterns: increment counters,
212/// running float sums, gauges, and windowed event counts.
213pub struct StateAggregator {
214    store: DistributedStateStore,
215}
216
217impl StateAggregator {
218    /// Create an aggregator backed by a store with `partition_count` partitions.
219    pub fn new(partition_count: u32) -> Self {
220        Self {
221            store: DistributedStateStore::new(partition_count, 1),
222        }
223    }
224
225    /// Increment an integer counter by `by`. Returns the updated value.
226    pub fn increment(&mut self, key: &str, by: i64) -> i64 {
227        let current = match self.store.get(key) {
228            Some(PartitionStateValue::Integer(v)) => *v,
229            Some(PartitionStateValue::Counter(v)) => *v as i64,
230            _ => 0,
231        };
232        let next = current + by;
233        self.store.set(key, PartitionStateValue::Integer(next));
234        next
235    }
236
237    /// Add `value` to a running float sum. Returns the updated sum.
238    pub fn accumulate(&mut self, key: &str, value: f64) -> f64 {
239        let current = match self.store.get(key) {
240            Some(PartitionStateValue::Float(v)) => *v,
241            _ => 0.0,
242        };
243        let next = current + value;
244        self.store.set(key, PartitionStateValue::Float(next));
245        next
246    }
247
248    /// Update a gauge value (timestamped float).
249    pub fn update_gauge(&mut self, key: &str, value: f64) {
250        let timestamp = std::time::SystemTime::now()
251            .duration_since(std::time::UNIX_EPOCH)
252            .map(|d| d.as_millis() as i64)
253            .unwrap_or(0);
254        self.store
255            .set(key, PartitionStateValue::Gauge { value, timestamp });
256    }
257
258    /// Count events within a named window. Uses a composite key `window_key:event_key`.
259    /// Returns the updated count.
260    pub fn window_count(&mut self, window_key: &str, event_key: &str) -> u64 {
261        let key = format!("{window_key}:{event_key}");
262        let current = match self.store.get(&key) {
263            Some(PartitionStateValue::Counter(v)) => *v,
264            _ => 0,
265        };
266        let next = current + 1;
267        self.store.set(&key, PartitionStateValue::Counter(next));
268        next
269    }
270
271    /// Merge all state from `other` store into this aggregator's store.
272    pub fn merge_from(&mut self, other: &DistributedStateStore) {
273        for partition in &other.partitions {
274            for (key, value) in &partition.state {
275                self.store.set(key, value.clone());
276            }
277        }
278    }
279
280    /// Read-only access to the underlying store.
281    pub fn store(&self) -> &DistributedStateStore {
282        &self.store
283    }
284}
285
286// ─── Tests ────────────────────────────────────────────────────────────────────
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store with `partition_count` partitions and replication factor 1.
    fn store_with(partition_count: u32) -> DistributedStateStore {
        DistributedStateStore::new(partition_count, 1)
    }

    // ── DistributedStateStore ──────────────────────────────────────────────────

    /// A fresh store reports its configuration and holds no keys.
    #[test]
    fn test_new_store_empty() {
        let store = store_with(4);
        let observed = (
            store.partition_count(),
            store.total_keys(),
            store.replication_factor(),
        );
        assert_eq!(observed, (4, 0, 1));
    }

    /// A string value written with `set` is read back unchanged.
    #[test]
    fn test_set_and_get_string() {
        let mut store = store_with(4);
        store.set("hello", PartitionStateValue::StringVal(String::from("world")));
        let fetched = store.get("hello");
        if let Some(PartitionStateValue::StringVal(s)) = fetched {
            assert_eq!(s, "world");
        } else {
            panic!("unexpected: {other:?}", other = fetched);
        }
    }

    /// Two writes to the same key yield strictly increasing versions.
    #[test]
    fn test_set_returns_version_increases() {
        let mut store = store_with(4);
        let first = store.set("k", PartitionStateValue::Integer(1));
        let second = store.set("k", PartitionStateValue::Integer(2));
        assert!(first < second, "version must increase on each write");
    }

    /// Deleting a present key reports success and removes it.
    #[test]
    fn test_delete_existing_key() {
        let mut store = store_with(4);
        store.set("k", PartitionStateValue::Counter(10));
        let removed = store.delete("k");
        assert!(removed, "delete should return true for existing key");
        assert_eq!(store.get("k"), None);
    }

    /// Deleting an absent key reports `false`.
    #[test]
    fn test_delete_missing_key() {
        let mut store = store_with(4);
        let removed = store.delete("nonexistent");
        assert!(!removed);
    }

    /// Routing the same key twice gives the same in-range partition.
    #[test]
    fn test_partition_for_deterministic() {
        let store = store_with(8);
        let first = store.partition_for("my_key");
        let second = store.partition_for("my_key");
        assert_eq!(first, second, "same key must always map to same partition");
        assert!((0..8).contains(&first));
    }

    /// Eight distinct keys do not all collapse onto one partition.
    #[test]
    fn test_partition_for_distributes_across_partitions() {
        let store = store_with(8);
        let keys = [
            "alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta",
        ];
        let mut seen = std::collections::HashSet::new();
        for key in &keys {
            seen.insert(store.partition_for(key));
        }
        assert!(
            seen.len() >= 2,
            "8 keys over 8 partitions must use at least 2 different partitions"
        );
    }

    /// `total_keys` follows inserts and deletes.
    #[test]
    fn test_total_keys_after_operations() {
        let mut store = store_with(4);
        for (key, number) in &[("a", 1_i64), ("b", 2), ("c", 3)] {
            store.set(key, PartitionStateValue::Integer(*number));
        }
        assert_eq!(store.total_keys(), 3);
        store.delete("b");
        assert_eq!(store.total_keys(), 2);
    }

    /// The replica of a key's partition contains that key.
    #[test]
    fn test_replicate_to_returns_partition_contents() {
        let mut store = DistributedStateStore::new(4, 2);
        store.set("key1", PartitionStateValue::Integer(42));
        let target = store.partition_for("key1");
        let replica = store.replicate_to("peer-node", target);
        let replica_keys: Vec<&str> = replica.iter().map(|(k, _)| k.as_str()).collect();
        assert!(!replica_keys.is_empty());
        assert!(replica_keys.contains(&"key1"));
    }

    /// Replicating an unknown partition id yields nothing.
    #[test]
    fn test_replicate_to_nonexistent_partition() {
        let replica = store_with(4).replicate_to("peer", 99);
        assert!(replica.is_empty());
    }

    /// Restoring a checkpoint undoes a later write to the same partition.
    #[test]
    fn test_checkpoint_and_restore() {
        let expected_val = 42.5_f64;
        let mut store = store_with(4);
        store.set("x", PartitionStateValue::Float(expected_val));

        let snapshot = store.checkpoint_partition(store.partition_for("x"));
        assert!(
            snapshot.last_checkpointed > 0,
            "last_checkpointed must be set"
        );

        // Overwrite the value, then roll back to the snapshot.
        store.set("x", PartitionStateValue::Float(0.0));
        store.restore_partition(snapshot);

        let restored = store.get("x");
        if let Some(PartitionStateValue::Float(v)) = restored {
            assert!((v - expected_val).abs() < 1e-9);
        } else {
            panic!("unexpected after restore: {other:?}", other = restored);
        }
    }

    /// The built-in coordinator is named "local" and starts without peers.
    #[test]
    fn test_coordinator_default_node_id() {
        let store = store_with(2);
        let coordinator = store.coordinator();
        assert_eq!(coordinator.node_id, "local");
        assert!(coordinator.peers.is_empty());
    }

    /// Peers added through the mutable accessor are recorded.
    #[test]
    fn test_coordinator_add_peers() {
        let mut store = store_with(2);
        for peer in &["node-2", "node-3"] {
            store.coordinator_mut().add_peer(*peer);
        }
        assert_eq!(store.coordinator().peers.len(), 2);
    }

    /// Every value variant can be stored under its own key.
    #[test]
    fn test_all_value_variants() {
        let entries = vec![
            ("int_k", PartitionStateValue::Integer(-10)),
            ("float_k", PartitionStateValue::Float(2.5)),
            ("bytes_k", PartitionStateValue::Bytes(vec![1, 2, 3])),
            ("str_k", PartitionStateValue::StringVal("hi".to_string())),
            ("ctr_k", PartitionStateValue::Counter(99)),
            (
                "gauge_k",
                PartitionStateValue::Gauge {
                    value: 1.0,
                    timestamp: 1000,
                },
            ),
        ];
        let mut store = store_with(8);
        for (key, value) in entries {
            store.set(key, value);
        }
        assert_eq!(store.total_keys(), 6);
    }

    /// With a single partition every key routes to partition 0.
    #[test]
    fn test_single_partition_all_keys_same_partition() {
        let store = store_with(1);
        for key in &["anything", "other_key"] {
            assert_eq!(store.partition_for(key), 0);
        }
    }

    /// A second write to a key replaces the first value.
    #[test]
    fn test_overwrite_value() {
        let mut store = store_with(4);
        for number in 1..=2 {
            store.set("key", PartitionStateValue::Integer(number));
        }
        let fetched = store.get("key");
        if let Some(PartitionStateValue::Integer(v)) = fetched {
            assert_eq!(*v, 2);
        } else {
            panic!("unexpected: {other:?}", other = fetched);
        }
    }

    /// `StatePartition::new` produces an empty, unversioned partition.
    #[test]
    fn test_state_partition_new() {
        let StatePartition {
            partition_id,
            state,
            version,
            last_checkpointed,
        } = StatePartition::new(5);
        assert_eq!(partition_id, 5);
        assert_eq!(version, 0);
        assert!(state.is_empty());
        assert_eq!(last_checkpointed, 0);
    }

    // ── StateAggregator ────────────────────────────────────────────────────────

    /// Positive increments accumulate.
    #[test]
    fn test_aggregator_increment_positive() {
        let mut agg = StateAggregator::new(4);
        let after_first = agg.increment("counter", 5);
        assert_eq!(after_first, 5);
        let after_second = agg.increment("counter", 3);
        assert_eq!(after_second, 8);
    }

    /// A negative increment subtracts from the running total.
    #[test]
    fn test_aggregator_increment_negative() {
        let mut agg = StateAggregator::new(4);
        agg.increment("counter", 10);
        let total = agg.increment("counter", -2);
        assert_eq!(total, 8);
    }

    /// Float accumulation returns the running sum after each step.
    #[test]
    fn test_aggregator_accumulate_floats() {
        let mut agg = StateAggregator::new(4);
        let first = agg.accumulate("sum", 1.5);
        let second = agg.accumulate("sum", 2.5);
        for (observed, expected) in &[(first, 1.5_f64), (second, 4.0)] {
            assert!((observed - expected).abs() < 1e-9);
        }
    }

    /// A gauge update stores the supplied reading.
    #[test]
    fn test_aggregator_update_gauge() {
        let mut agg = StateAggregator::new(4);
        agg.update_gauge("temperature", 98.6);
        let reading = agg.store().get("temperature");
        if let Some(PartitionStateValue::Gauge { value, .. }) = reading {
            assert!((value - 98.6).abs() < 1e-9);
        } else {
            panic!("unexpected: {other:?}", other = reading);
        }
    }

    /// Window counters are independent per (window, event) pair.
    #[test]
    fn test_aggregator_window_count_isolated() {
        let mut agg = StateAggregator::new(4);
        let steps: [(&str, &str, u64); 4] = [
            ("win-1", "click", 1),
            ("win-1", "click", 2),
            ("win-1", "view", 1),
            ("win-2", "click", 1),
        ];
        for (window, event, expected) in &steps {
            assert_eq!(agg.window_count(window, event), *expected);
        }
    }

    /// Merging copies entries from another store into the aggregator.
    #[test]
    fn test_aggregator_merge_from() {
        let mut source = store_with(4);
        source.set("shared_key", PartitionStateValue::Integer(100));

        let mut agg = StateAggregator::new(4);
        agg.merge_from(&source);

        let merged = agg.store().get("shared_key");
        if let Some(PartitionStateValue::Integer(v)) = merged {
            assert_eq!(*v, 100);
        } else {
            panic!("unexpected: {other:?}", other = merged);
        }
    }

    /// The store accessor exposes the freshly built backing store.
    #[test]
    fn test_aggregator_store_accessor() {
        let agg = StateAggregator::new(4);
        let backing = agg.store();
        assert_eq!(backing.partition_count(), 4);
        assert_eq!(backing.total_keys(), 0);
    }
}