// Source file: oxirs_stream/state/distributed_state.rs
1//! # Distributed State Store
2//!
3//! Consistent state store for stateful stream operators across partitions.
4//! Supports: key-value state, list state, map state, aggregating state.
5
6use crate::error::StreamError;
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::Instant;
10
11// ─── Partition Key ────────────────────────────────────────────────────────────
12
/// Unique identifier for a state partition.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StatePartitionKey {
    pub operator_id: String,
    pub partition_id: u32,
    pub subtask_index: u32,
}

impl StatePartitionKey {
    /// Construct a partition key from its three components.
    pub fn new(operator_id: impl Into<String>, partition_id: u32, subtask_index: u32) -> Self {
        let operator_id = operator_id.into();
        StatePartitionKey {
            operator_id,
            partition_id,
            subtask_index,
        }
    }

    /// Render this key as the byte prefix `"<operator>:<partition>:<subtask>:"`
    /// used to namespace state entries in the backend.
    pub fn to_prefix(&self) -> Vec<u8> {
        let rendered = format!(
            "{}:{}:{}:",
            self.operator_id, self.partition_id, self.subtask_index
        );
        rendered.into_bytes()
    }
}
40
41// ─── StateBackend trait ───────────────────────────────────────────────────────
42
/// Pluggable storage backend for state operators.
///
/// Implementations must be `Send + Sync` so they can be shared across async
/// tasks and threads. Every method takes `&self`, so implementations are
/// expected to provide their own interior mutability / synchronization
/// (see `InMemoryStateBackend`, which uses `RwLock`s).
pub trait StateBackend: Send + Sync {
    /// Return the value for the given key, or `None` if absent.
    fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, StreamError>;

    /// Insert or overwrite a key-value pair.
    fn put(&self, key: &[u8], value: &[u8]) -> Result<(), StreamError>;

    /// Remove a key. Returns `true` if the key previously existed.
    fn delete(&self, key: &[u8]) -> Result<bool, StreamError>;

    /// Return all entries whose key starts with `prefix`.
    ///
    /// An empty prefix matches every entry; result ordering is
    /// implementation-defined (the in-memory backend iterates a `HashMap`).
    #[allow(clippy::type_complexity)]
    fn range_scan(&self, prefix: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>, StreamError>;

    /// Serialize the current state into an opaque byte snapshot tagged with
    /// `checkpoint_id`. The snapshot must round-trip through [`Self::restore`].
    fn checkpoint(&self, checkpoint_id: u64) -> Result<Vec<u8>, StreamError>;

    /// Replace current state with the content of a previously-created snapshot.
    fn restore(&self, snapshot: &[u8]) -> Result<(), StreamError>;

    /// Approximate heap/disk footprint in bytes.
    fn size_bytes(&self) -> usize;
}
71
72// ─── Snapshot encoding ────────────────────────────────────────────────────────
73//
74// Binary format (all integers little-endian):
75//
76//   [u64 checkpoint_id]
77//   [u64 entry_count]
78//   { [u32 key_len] [key_bytes…] [u32 val_len] [val_bytes…] } × entry_count
79
/// Serialize `data` into the binary snapshot layout documented above
/// (16-byte header followed by length-prefixed key/value pairs).
fn encode_snapshot(checkpoint_id: u64, data: &HashMap<Vec<u8>, Vec<u8>>) -> Vec<u8> {
    // Pre-size the buffer: 16-byte header plus, per entry, two u32 length
    // fields (8 bytes) and the raw key/value bytes.
    let payload: usize = data.iter().map(|(k, v)| 8 + k.len() + v.len()).sum();
    let mut buf = Vec::with_capacity(16 + payload);

    buf.extend_from_slice(&checkpoint_id.to_le_bytes());
    buf.extend_from_slice(&(data.len() as u64).to_le_bytes());

    for (key, value) in data {
        buf.extend_from_slice(&(key.len() as u32).to_le_bytes());
        buf.extend_from_slice(key);
        buf.extend_from_slice(&(value.len() as u32).to_le_bytes());
        buf.extend_from_slice(value);
    }

    buf
}
93
94/// Read a `u64` from `buf[offset..offset+8]`, returning a descriptive error on
95/// failure.
96#[inline]
97fn read_u64(buf: &[u8], offset: usize, field: &str) -> Result<u64, StreamError> {
98    buf.get(offset..offset + 8)
99        .ok_or_else(|| StreamError::Deserialization(format!("snapshot truncated reading {field}")))?
100        .try_into()
101        .map(u64::from_le_bytes)
102        .map_err(|_| StreamError::Deserialization(format!("bad bytes for {field}")))
103}
104
105/// Read a `u32` from `buf[offset..offset+4]`.
106#[inline]
107fn read_u32(buf: &[u8], offset: usize, field: &str) -> Result<u32, StreamError> {
108    buf.get(offset..offset + 4)
109        .ok_or_else(|| StreamError::Deserialization(format!("snapshot truncated reading {field}")))?
110        .try_into()
111        .map(u32::from_le_bytes)
112        .map_err(|_| StreamError::Deserialization(format!("bad bytes for {field}")))
113}
114
115#[allow(clippy::type_complexity)]
116fn decode_snapshot(snapshot: &[u8]) -> Result<(u64, HashMap<Vec<u8>, Vec<u8>>), StreamError> {
117    if snapshot.len() < 16 {
118        return Err(StreamError::Deserialization(
119            "snapshot too short to contain header".into(),
120        ));
121    }
122
123    let checkpoint_id = read_u64(snapshot, 0, "checkpoint_id")?;
124    let entry_count = read_u64(snapshot, 8, "entry_count")? as usize;
125
126    let mut pos = 16usize;
127    let mut data = HashMap::with_capacity(entry_count);
128
129    for i in 0..entry_count {
130        let key_len = read_u32(snapshot, pos, &format!("key_len[{i}]"))? as usize;
131        pos += 4;
132
133        let key = snapshot
134            .get(pos..pos + key_len)
135            .ok_or_else(|| {
136                StreamError::Deserialization(format!("snapshot truncated at key data[{i}]"))
137            })?
138            .to_vec();
139        pos += key_len;
140
141        let val_len = read_u32(snapshot, pos, &format!("val_len[{i}]"))? as usize;
142        pos += 4;
143
144        let val = snapshot
145            .get(pos..pos + val_len)
146            .ok_or_else(|| {
147                StreamError::Deserialization(format!("snapshot truncated at val data[{i}]"))
148            })?
149            .to_vec();
150        pos += val_len;
151
152        data.insert(key, val);
153    }
154
155    Ok((checkpoint_id, data))
156}
157
158// ─── In-memory backend ────────────────────────────────────────────────────────
159
/// In-memory `StateBackend` — fast but not durable across process restarts.
///
/// All trait methods take `&self`; interior mutability comes from the
/// `RwLock`s, so one instance can safely be shared across threads.
pub struct InMemoryStateBackend {
    /// The live key-value state.
    // NOTE(review): the inner `Arc`s look redundant — this struct is not
    // `Clone`, so plain `RwLock<…>` fields would suffice; confirm before
    // simplifying.
    data: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
    /// Monotonically increasing logical version, incremented on every write.
    version: Arc<RwLock<u64>>,
}
166
167impl InMemoryStateBackend {
168    /// Create an empty in-memory backend.
169    pub fn new() -> Self {
170        Self {
171            data: Arc::new(RwLock::new(HashMap::new())),
172            version: Arc::new(RwLock::new(0)),
173        }
174    }
175
176    /// Current logical version.
177    pub fn version(&self) -> Result<u64, StreamError> {
178        self.version
179            .read()
180            .map(|g| *g)
181            .map_err(|e| StreamError::Other(format!("version lock poisoned: {e}")))
182    }
183
184    fn bump_version(&self) -> Result<(), StreamError> {
185        let mut ver = self
186            .version
187            .write()
188            .map_err(|e| StreamError::Other(format!("version write-lock poisoned: {e}")))?;
189        *ver += 1;
190        Ok(())
191    }
192}
193
impl Default for InMemoryStateBackend {
    /// Equivalent to [`InMemoryStateBackend::new`]: empty store, version 0.
    fn default() -> Self {
        Self::new()
    }
}
199
200impl StateBackend for InMemoryStateBackend {
201    fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, StreamError> {
202        let data = self
203            .data
204            .read()
205            .map_err(|e| StreamError::Other(format!("data read-lock poisoned: {e}")))?;
206        Ok(data.get(key).cloned())
207    }
208
209    fn put(&self, key: &[u8], value: &[u8]) -> Result<(), StreamError> {
210        {
211            let mut data = self
212                .data
213                .write()
214                .map_err(|e| StreamError::Other(format!("data write-lock poisoned: {e}")))?;
215            data.insert(key.to_vec(), value.to_vec());
216        }
217        self.bump_version()
218    }
219
220    fn delete(&self, key: &[u8]) -> Result<bool, StreamError> {
221        let existed = {
222            let mut data = self
223                .data
224                .write()
225                .map_err(|e| StreamError::Other(format!("data write-lock poisoned: {e}")))?;
226            data.remove(key).is_some()
227        };
228        if existed {
229            self.bump_version()?;
230        }
231        Ok(existed)
232    }
233
234    fn range_scan(&self, prefix: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>, StreamError> {
235        let data = self
236            .data
237            .read()
238            .map_err(|e| StreamError::Other(format!("data read-lock poisoned: {e}")))?;
239        let results = data
240            .iter()
241            .filter(|(k, _)| k.starts_with(prefix))
242            .map(|(k, v)| (k.clone(), v.clone()))
243            .collect();
244        Ok(results)
245    }
246
247    fn checkpoint(&self, checkpoint_id: u64) -> Result<Vec<u8>, StreamError> {
248        let data = self
249            .data
250            .read()
251            .map_err(|e| StreamError::Other(format!("data read-lock poisoned: {e}")))?;
252        Ok(encode_snapshot(checkpoint_id, &data))
253    }
254
255    fn restore(&self, snapshot: &[u8]) -> Result<(), StreamError> {
256        let (_checkpoint_id, restored) = decode_snapshot(snapshot)?;
257        {
258            let mut data = self
259                .data
260                .write()
261                .map_err(|e| StreamError::Other(format!("data write-lock poisoned: {e}")))?;
262            *data = restored;
263        }
264        self.bump_version()
265    }
266
267    fn size_bytes(&self) -> usize {
268        // If the lock is poisoned we return 0 (best-effort metric).
269        match self.data.read() {
270            Ok(data) => data.iter().map(|(k, v)| k.len() + v.len()).sum(),
271            Err(_) => 0,
272        }
273    }
274}
275
276// ─── Keyed state store ────────────────────────────────────────────────────────
277
/// Typed, partitioned key-value state handle for stateful stream operators.
///
/// Serialization is provided by caller-supplied function pointers to keep this
/// crate free of hard-coded codec dependencies.
pub struct KeyedStateStore<K, V> {
    /// Namespace prefix applied to every key written through this handle.
    partition: StatePartitionKey,
    /// Shared storage backend (may be shared with other handles/partitions).
    backend: Arc<dyn StateBackend>,
    /// Serializes a user key; appended after the partition prefix.
    key_serializer: fn(&K) -> Vec<u8>,
    /// Serializes values before they are written to the backend.
    value_serializer: fn(&V) -> Vec<u8>,
    /// Inverse of `value_serializer`; failures surface from `get`.
    value_deserializer: fn(&[u8]) -> Result<V, StreamError>,
    // NOTE(review): the fn-pointer fields already mention K and V, so this
    // PhantomData is likely redundant — confirm before removing.
    _phantom: std::marker::PhantomData<(K, V)>,
}
290
291impl<K: std::fmt::Debug, V: std::fmt::Debug + Clone> KeyedStateStore<K, V> {
292    /// Create a new keyed store backed by `backend`.
293    pub fn new(
294        partition: StatePartitionKey,
295        backend: Arc<dyn StateBackend>,
296        key_ser: fn(&K) -> Vec<u8>,
297        val_ser: fn(&V) -> Vec<u8>,
298        val_de: fn(&[u8]) -> Result<V, StreamError>,
299    ) -> Self {
300        Self {
301            partition,
302            backend,
303            key_serializer: key_ser,
304            value_serializer: val_ser,
305            value_deserializer: val_de,
306            _phantom: std::marker::PhantomData,
307        }
308    }
309
310    /// Build the fully-namespaced storage key for `key`.
311    fn storage_key(&self, key: &K) -> Vec<u8> {
312        let mut prefix = self.partition.to_prefix();
313        prefix.extend_from_slice(&(self.key_serializer)(key));
314        prefix
315    }
316
317    /// Return the value for `key`, or `None` if absent.
318    pub fn get(&self, key: &K) -> Result<Option<V>, StreamError> {
319        match self.backend.get(&self.storage_key(key))? {
320            None => Ok(None),
321            Some(bytes) => (self.value_deserializer)(&bytes).map(Some),
322        }
323    }
324
325    /// Store `value` under `key`.
326    pub fn put(&self, key: &K, value: V) -> Result<(), StreamError> {
327        let bytes = (self.value_serializer)(&value);
328        self.backend.put(&self.storage_key(key), &bytes)
329    }
330
331    /// Remove `key`. Returns `true` if it existed.
332    pub fn delete(&self, key: &K) -> Result<bool, StreamError> {
333        self.backend.delete(&self.storage_key(key))
334    }
335
336    /// Atomic read-modify-write.  `updater` receives the current value (or
337    /// `None`) and returns the new value to store.
338    pub fn update_or_default(
339        &self,
340        key: &K,
341        updater: impl FnOnce(Option<V>) -> V,
342    ) -> Result<V, StreamError> {
343        let current = self.get(key)?;
344        let new_value = updater(current);
345        self.put(key, new_value.clone())?;
346        Ok(new_value)
347    }
348}
349
350// ─── Aggregating state ────────────────────────────────────────────────────────
351
/// Aggregating state that folds incoming values into a running accumulator.
///
/// Typical use-cases: running sum, count, min/max, HyperLogLog cardinality.
pub struct AggregatingState<In, Out> {
    /// Namespace prefix for the accumulator's backend key.
    partition: StatePartitionKey,
    /// Shared storage backend.
    backend: Arc<dyn StateBackend>,
    /// Fixed key used to store the accumulator within the backend.
    aggregate_key: Vec<u8>,
    /// `combine_fn(accumulator, new_input) -> new_accumulator`
    combine_fn: fn(Out, In) -> Out,
    /// Value returned when no accumulator has been written yet.
    default: Out,
    /// Serializes the accumulator for storage.
    serializer: fn(&Out) -> Vec<u8>,
    /// Inverse of `serializer`; failures surface from `get`/`add`.
    deserializer: fn(&[u8]) -> Result<Out, StreamError>,
    /// Ties `In` to the type; it only appears in `combine_fn`'s input position.
    _phantom: std::marker::PhantomData<In>,
}
368
369impl<In, Out: Clone> AggregatingState<In, Out> {
370    /// Create a new aggregating state descriptor.
371    #[allow(clippy::too_many_arguments)]
372    pub fn new(
373        partition: StatePartitionKey,
374        backend: Arc<dyn StateBackend>,
375        aggregate_key: Vec<u8>,
376        combine_fn: fn(Out, In) -> Out,
377        default: Out,
378        serializer: fn(&Out) -> Vec<u8>,
379        deserializer: fn(&[u8]) -> Result<Out, StreamError>,
380    ) -> Self {
381        Self {
382            partition,
383            backend,
384            aggregate_key,
385            combine_fn,
386            default,
387            serializer,
388            deserializer,
389            _phantom: std::marker::PhantomData,
390        }
391    }
392
393    fn storage_key(&self) -> Vec<u8> {
394        let mut prefix = self.partition.to_prefix();
395        prefix.extend_from_slice(&self.aggregate_key);
396        prefix
397    }
398
399    fn read_accumulator(&self) -> Result<Out, StreamError> {
400        match self.backend.get(&self.storage_key())? {
401            None => Ok(self.default.clone()),
402            Some(bytes) => (self.deserializer)(&bytes),
403        }
404    }
405
406    /// Fold `value` into the accumulator.
407    pub fn add(&self, value: In) -> Result<(), StreamError> {
408        let current = self.read_accumulator()?;
409        let new_acc = (self.combine_fn)(current, value);
410        self.backend
411            .put(&self.storage_key(), &(self.serializer)(&new_acc))
412    }
413
414    /// Return the current accumulator value.
415    pub fn get(&self) -> Result<Out, StreamError> {
416        self.read_accumulator()
417    }
418
419    /// Reset the accumulator (delete from backend so default is returned).
420    pub fn clear(&self) -> Result<(), StreamError> {
421        self.backend.delete(&self.storage_key()).map(|_| ())
422    }
423}
424
425// ─── Stats ────────────────────────────────────────────────────────────────────
426
/// Point-in-time metrics for a state backend.
#[derive(Debug, Clone)]
pub struct StateBackendStats {
    /// Approximate footprint as reported by `StateBackend::size_bytes`.
    pub size_bytes: usize,
    /// When the sample was taken (monotonic clock).
    pub collected_at: Instant,
}
433
434impl StateBackendStats {
435    /// Collect current stats from a backend.
436    pub fn collect(backend: &dyn StateBackend) -> Self {
437        Self {
438            size_bytes: backend.size_bytes(),
439            collected_at: Instant::now(),
440        }
441    }
442}
443
444// ─── Tests ────────────────────────────────────────────────────────────────────
445
#[cfg(test)]
mod tests {
    //! Unit tests covering the in-memory backend, the snapshot codec, and
    //! the typed state handles layered on top.
    use super::*;

    // Helper serializers / deserializers used only in tests.
    // They take `&String` (not `&str`) because the store's fn-pointer fields
    // require exactly `fn(&K) -> Vec<u8>` with `K = String`.
    fn str_key_ser(k: &String) -> Vec<u8> {
        k.as_bytes().to_vec()
    }

    fn i64_ser(v: &i64) -> Vec<u8> {
        v.to_le_bytes().to_vec()
    }

    fn i64_de(b: &[u8]) -> Result<i64, StreamError> {
        if b.len() < 8 {
            return Err(StreamError::Deserialization("i64 needs 8 bytes".into()));
        }
        let arr: [u8; 8] = b[..8]
            .try_into()
            .map_err(|_| StreamError::Deserialization("i64 slice error".into()))?;
        Ok(i64::from_le_bytes(arr))
    }

    fn u64_ser(v: &u64) -> Vec<u8> {
        v.to_le_bytes().to_vec()
    }

    fn u64_de(b: &[u8]) -> Result<u64, StreamError> {
        if b.len() < 8 {
            return Err(StreamError::Deserialization("u64 needs 8 bytes".into()));
        }
        let arr: [u8; 8] = b[..8]
            .try_into()
            .map_err(|_| StreamError::Deserialization("u64 slice error".into()))?;
        Ok(u64::from_le_bytes(arr))
    }

    // Default partition used by tests that don't care about namespacing.
    fn partition() -> StatePartitionKey {
        StatePartitionKey::new("op1", 0, 0)
    }

    #[test]
    fn test_backend_put_get_delete() {
        let backend = InMemoryStateBackend::new();

        backend.put(b"hello", b"world").unwrap();
        let val = backend.get(b"hello").unwrap();
        assert_eq!(val.as_deref(), Some(b"world".as_ref()));

        // First delete reports the key existed; afterwards it is gone.
        let existed = backend.delete(b"hello").unwrap();
        assert!(existed);

        assert!(backend.get(b"hello").unwrap().is_none());

        // Deleting an absent key is not an error — it just returns `false`.
        let not_found = backend.delete(b"missing").unwrap();
        assert!(!not_found);
    }

    #[test]
    fn test_backend_range_scan() {
        let backend = InMemoryStateBackend::new();

        backend.put(b"ns:a", b"1").unwrap();
        backend.put(b"ns:b", b"2").unwrap();
        backend.put(b"other:c", b"3").unwrap();

        // Only the two `ns:`-prefixed keys match.
        let results = backend.range_scan(b"ns:").unwrap();
        assert_eq!(results.len(), 2);

        // The empty prefix matches every entry.
        let all = backend.range_scan(b"").unwrap();
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn test_backend_checkpoint_restore() {
        let backend = InMemoryStateBackend::new();

        backend.put(b"k1", b"v1").unwrap();
        backend.put(b"k2", b"v2").unwrap();

        let snapshot = backend.checkpoint(42).unwrap();
        assert!(!snapshot.is_empty());

        // Corrupt the live state.
        backend.delete(b"k1").unwrap();
        backend.put(b"k2", b"changed").unwrap();
        backend.put(b"k3", b"new").unwrap();

        // Restore to snapshot: deletion undone, update reverted, new key gone.
        backend.restore(&snapshot).unwrap();

        assert_eq!(backend.get(b"k1").unwrap().as_deref(), Some(b"v1".as_ref()));
        assert_eq!(backend.get(b"k2").unwrap().as_deref(), Some(b"v2".as_ref()));
        assert!(backend.get(b"k3").unwrap().is_none());
    }

    #[test]
    fn test_backend_size_bytes() {
        let backend = InMemoryStateBackend::new();
        assert_eq!(backend.size_bytes(), 0);

        // 3 key bytes + 3 value bytes.
        backend.put(b"abc", b"def").unwrap();
        assert_eq!(backend.size_bytes(), 6);
    }

    #[test]
    fn test_keyed_state_store_basic() {
        let backend = Arc::new(InMemoryStateBackend::new());
        let store: KeyedStateStore<String, i64> =
            KeyedStateStore::new(partition(), backend, str_key_ser, i64_ser, i64_de);

        let key = "counter".to_string();

        assert!(store.get(&key).unwrap().is_none());

        store.put(&key, 10).unwrap();
        assert_eq!(store.get(&key).unwrap(), Some(10));

        // Read-modify-write through the updater closure.
        let new_val = store
            .update_or_default(&key, |cur| cur.unwrap_or(0) + 5)
            .unwrap();
        assert_eq!(new_val, 15);
        assert_eq!(store.get(&key).unwrap(), Some(15));

        assert!(store.delete(&key).unwrap());
        assert!(store.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_aggregating_state_sum() {
        let backend = Arc::new(InMemoryStateBackend::new());

        fn combine(acc: u64, x: u64) -> u64 {
            acc + x
        }

        let agg: AggregatingState<u64, u64> = AggregatingState::new(
            partition(),
            backend,
            b"total".to_vec(),
            combine,
            0u64,
            u64_ser,
            u64_de,
        );

        // No accumulator written yet -> the default (0) is returned.
        assert_eq!(agg.get().unwrap(), 0);

        agg.add(10).unwrap();
        agg.add(20).unwrap();
        agg.add(5).unwrap();

        assert_eq!(agg.get().unwrap(), 35);

        // clear() deletes the accumulator, so the default comes back.
        agg.clear().unwrap();
        assert_eq!(agg.get().unwrap(), 0);
    }

    #[test]
    fn test_partition_namespacing_isolation() {
        let backend = Arc::new(InMemoryStateBackend::new());

        // Same operator and partition, different subtask -> distinct prefixes.
        let p1 = StatePartitionKey::new("op", 0, 0);
        let p2 = StatePartitionKey::new("op", 0, 1);

        let store1: KeyedStateStore<String, i64> =
            KeyedStateStore::new(p1, backend.clone(), str_key_ser, i64_ser, i64_de);
        let store2: KeyedStateStore<String, i64> =
            KeyedStateStore::new(p2, backend, str_key_ser, i64_ser, i64_de);

        // Identical user key, but each store only sees its own value.
        let key = "x".to_string();
        store1.put(&key, 1).unwrap();
        store2.put(&key, 2).unwrap();

        assert_eq!(store1.get(&key).unwrap(), Some(1));
        assert_eq!(store2.get(&key).unwrap(), Some(2));
    }

    #[test]
    fn test_snapshot_round_trip_empty() {
        let backend = InMemoryStateBackend::new();
        let snapshot = backend.checkpoint(0).unwrap();

        let new_backend = InMemoryStateBackend::new();
        new_backend.restore(&snapshot).unwrap();
        assert_eq!(new_backend.size_bytes(), 0);
    }

    #[test]
    fn test_decode_snapshot_too_short() {
        // 5 bytes cannot hold the 16-byte header.
        let result = decode_snapshot(b"short");
        assert!(result.is_err());
    }

    #[test]
    fn test_version_bumps_on_write() {
        let backend = InMemoryStateBackend::new();
        let v0 = backend.version().unwrap();
        backend.put(b"k", b"v").unwrap();
        let v1 = backend.version().unwrap();
        assert!(v1 > v0);
        backend.delete(b"k").unwrap();
        let v2 = backend.version().unwrap();
        assert!(v2 > v1);
    }

    #[test]
    fn test_state_backend_stats() {
        let backend = InMemoryStateBackend::new();
        backend.put(b"key", b"value").unwrap();
        let stats = StateBackendStats::collect(&backend);
        assert_eq!(stats.size_bytes, 8); // 3 + 5
    }
}