1use crate::error::StreamError;
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use std::time::Instant;
10
11#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// Identifies one logical state partition: an (operator, partition, subtask)
/// triple used to namespace keys inside a shared state backend.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StatePartitionKey {
    pub operator_id: String,
    pub partition_id: u32,
    pub subtask_index: u32,
}

impl StatePartitionKey {
    /// Builds a partition key; `operator_id` accepts anything convertible
    /// into a `String`.
    pub fn new(operator_id: impl Into<String>, partition_id: u32, subtask_index: u32) -> Self {
        StatePartitionKey {
            operator_id: operator_id.into(),
            partition_id,
            subtask_index,
        }
    }

    /// Renders the `"operator:partition:subtask:"` byte prefix that
    /// namespaces storage keys for this partition.
    ///
    /// NOTE(review): an `operator_id` containing `':'` could produce a prefix
    /// that collides with another partition's — confirm ids are colon-free
    /// upstream.
    pub fn to_prefix(&self) -> Vec<u8> {
        let rendered = format!(
            "{}:{}:{}:",
            self.operator_id, self.partition_id, self.subtask_index
        );
        rendered.into_bytes()
    }
}
40
/// Pluggable byte-oriented key/value store used for operator state.
///
/// Implementations must be shareable across threads (`Send + Sync`); all
/// methods take `&self`, so interior mutability is expected.
pub trait StateBackend: Send + Sync {
    /// Returns the value stored under `key`, or `None` if absent.
    fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, StreamError>;

    /// Inserts or overwrites the value under `key`.
    fn put(&self, key: &[u8], value: &[u8]) -> Result<(), StreamError>;

    /// Removes `key`; returns `true` if an entry was actually present.
    fn delete(&self, key: &[u8]) -> Result<bool, StreamError>;

    /// Returns all `(key, value)` pairs whose key starts with `prefix`.
    /// Result ordering is implementation-defined (the in-memory backend
    /// iterates a `HashMap`, so it is unordered).
    #[allow(clippy::type_complexity)]
    fn range_scan(&self, prefix: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>, StreamError>;

    /// Serializes the current state together with `checkpoint_id` into an
    /// opaque snapshot buffer (see `encode_snapshot` for the byte layout).
    fn checkpoint(&self, checkpoint_id: u64) -> Result<Vec<u8>, StreamError>;

    /// Replaces the entire state with the contents of `snapshot`.
    fn restore(&self, snapshot: &[u8]) -> Result<(), StreamError>;

    /// Approximate stored payload size; the in-memory backend reports the
    /// sum of key and value byte lengths.
    fn size_bytes(&self) -> usize;
}
71
/// Serializes `data` into a snapshot buffer.
///
/// Layout (all integers little-endian):
///   `[checkpoint_id: u64][entry_count: u64]` followed, per entry, by
///   `[key_len: u32][key bytes][val_len: u32][val bytes]`.
///
/// Entries are written in ascending key order so that encoding the same map
/// always yields identical bytes (raw `HashMap` iteration order varies per
/// process, which made snapshots of identical state non-reproducible).
/// `decode_snapshot` does not depend on entry order, so this is compatible.
fn encode_snapshot(checkpoint_id: u64, data: &HashMap<Vec<u8>, Vec<u8>>) -> Vec<u8> {
    // 16-byte header + 8 bytes of length fields per entry + payload bytes.
    let entries_size: usize = data.iter().map(|(k, v)| 8 + k.len() + v.len()).sum();
    let mut out = Vec::with_capacity(16 + entries_size);
    out.extend_from_slice(&checkpoint_id.to_le_bytes());
    out.extend_from_slice(&(data.len() as u64).to_le_bytes());

    // Sort by key for deterministic output.
    let mut entries: Vec<(&Vec<u8>, &Vec<u8>)> = data.iter().collect();
    entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

    for (k, v) in entries {
        out.extend_from_slice(&(k.len() as u32).to_le_bytes());
        out.extend_from_slice(k);
        out.extend_from_slice(&(v.len() as u32).to_le_bytes());
        out.extend_from_slice(v);
    }
    out
}
93
94#[inline]
97fn read_u64(buf: &[u8], offset: usize, field: &str) -> Result<u64, StreamError> {
98 buf.get(offset..offset + 8)
99 .ok_or_else(|| StreamError::Deserialization(format!("snapshot truncated reading {field}")))?
100 .try_into()
101 .map(u64::from_le_bytes)
102 .map_err(|_| StreamError::Deserialization(format!("bad bytes for {field}")))
103}
104
105#[inline]
107fn read_u32(buf: &[u8], offset: usize, field: &str) -> Result<u32, StreamError> {
108 buf.get(offset..offset + 4)
109 .ok_or_else(|| StreamError::Deserialization(format!("snapshot truncated reading {field}")))?
110 .try_into()
111 .map(u32::from_le_bytes)
112 .map_err(|_| StreamError::Deserialization(format!("bad bytes for {field}")))
113}
114
115#[allow(clippy::type_complexity)]
116fn decode_snapshot(snapshot: &[u8]) -> Result<(u64, HashMap<Vec<u8>, Vec<u8>>), StreamError> {
117 if snapshot.len() < 16 {
118 return Err(StreamError::Deserialization(
119 "snapshot too short to contain header".into(),
120 ));
121 }
122
123 let checkpoint_id = read_u64(snapshot, 0, "checkpoint_id")?;
124 let entry_count = read_u64(snapshot, 8, "entry_count")? as usize;
125
126 let mut pos = 16usize;
127 let mut data = HashMap::with_capacity(entry_count);
128
129 for i in 0..entry_count {
130 let key_len = read_u32(snapshot, pos, &format!("key_len[{i}]"))? as usize;
131 pos += 4;
132
133 let key = snapshot
134 .get(pos..pos + key_len)
135 .ok_or_else(|| {
136 StreamError::Deserialization(format!("snapshot truncated at key data[{i}]"))
137 })?
138 .to_vec();
139 pos += key_len;
140
141 let val_len = read_u32(snapshot, pos, &format!("val_len[{i}]"))? as usize;
142 pos += 4;
143
144 let val = snapshot
145 .get(pos..pos + val_len)
146 .ok_or_else(|| {
147 StreamError::Deserialization(format!("snapshot truncated at val data[{i}]"))
148 })?
149 .to_vec();
150 pos += val_len;
151
152 data.insert(key, val);
153 }
154
155 Ok((checkpoint_id, data))
156}
157
/// Heap-backed [`StateBackend`] guarded by `RwLock`s; values are cloned out
/// on reads.
pub struct InMemoryStateBackend {
    // All state entries.
    data: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
    // Monotonic counter bumped on every successful mutation
    // (put / effective delete / restore).
    version: Arc<RwLock<u64>>,
}
166
167impl InMemoryStateBackend {
168 pub fn new() -> Self {
170 Self {
171 data: Arc::new(RwLock::new(HashMap::new())),
172 version: Arc::new(RwLock::new(0)),
173 }
174 }
175
176 pub fn version(&self) -> Result<u64, StreamError> {
178 self.version
179 .read()
180 .map(|g| *g)
181 .map_err(|e| StreamError::Other(format!("version lock poisoned: {e}")))
182 }
183
184 fn bump_version(&self) -> Result<(), StreamError> {
185 let mut ver = self
186 .version
187 .write()
188 .map_err(|e| StreamError::Other(format!("version write-lock poisoned: {e}")))?;
189 *ver += 1;
190 Ok(())
191 }
192}
193
impl Default for InMemoryStateBackend {
    /// Equivalent to [`InMemoryStateBackend::new`].
    fn default() -> Self {
        Self::new()
    }
}
199
200impl StateBackend for InMemoryStateBackend {
201 fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>, StreamError> {
202 let data = self
203 .data
204 .read()
205 .map_err(|e| StreamError::Other(format!("data read-lock poisoned: {e}")))?;
206 Ok(data.get(key).cloned())
207 }
208
209 fn put(&self, key: &[u8], value: &[u8]) -> Result<(), StreamError> {
210 {
211 let mut data = self
212 .data
213 .write()
214 .map_err(|e| StreamError::Other(format!("data write-lock poisoned: {e}")))?;
215 data.insert(key.to_vec(), value.to_vec());
216 }
217 self.bump_version()
218 }
219
220 fn delete(&self, key: &[u8]) -> Result<bool, StreamError> {
221 let existed = {
222 let mut data = self
223 .data
224 .write()
225 .map_err(|e| StreamError::Other(format!("data write-lock poisoned: {e}")))?;
226 data.remove(key).is_some()
227 };
228 if existed {
229 self.bump_version()?;
230 }
231 Ok(existed)
232 }
233
234 fn range_scan(&self, prefix: &[u8]) -> Result<Vec<(Vec<u8>, Vec<u8>)>, StreamError> {
235 let data = self
236 .data
237 .read()
238 .map_err(|e| StreamError::Other(format!("data read-lock poisoned: {e}")))?;
239 let results = data
240 .iter()
241 .filter(|(k, _)| k.starts_with(prefix))
242 .map(|(k, v)| (k.clone(), v.clone()))
243 .collect();
244 Ok(results)
245 }
246
247 fn checkpoint(&self, checkpoint_id: u64) -> Result<Vec<u8>, StreamError> {
248 let data = self
249 .data
250 .read()
251 .map_err(|e| StreamError::Other(format!("data read-lock poisoned: {e}")))?;
252 Ok(encode_snapshot(checkpoint_id, &data))
253 }
254
255 fn restore(&self, snapshot: &[u8]) -> Result<(), StreamError> {
256 let (_checkpoint_id, restored) = decode_snapshot(snapshot)?;
257 {
258 let mut data = self
259 .data
260 .write()
261 .map_err(|e| StreamError::Other(format!("data write-lock poisoned: {e}")))?;
262 *data = restored;
263 }
264 self.bump_version()
265 }
266
267 fn size_bytes(&self) -> usize {
268 match self.data.read() {
270 Ok(data) => data.iter().map(|(k, v)| k.len() + v.len()).sum(),
271 Err(_) => 0,
272 }
273 }
274}
275
/// Typed key/value view over a [`StateBackend`], scoped to one partition.
///
/// Keys and values are converted to/from bytes with plain `fn` pointers
/// supplied at construction time.
pub struct KeyedStateStore<K, V> {
    partition: StatePartitionKey,
    backend: Arc<dyn StateBackend>,
    // Serializes a user key into the byte suffix appended to the partition prefix.
    key_serializer: fn(&K) -> Vec<u8>,
    value_serializer: fn(&V) -> Vec<u8>,
    value_deserializer: fn(&[u8]) -> Result<V, StreamError>,
    // NOTE(review): redundant — K and V already occur in the fn-pointer
    // fields above; kept because construction sites set it.
    _phantom: std::marker::PhantomData<(K, V)>,
}
290
291impl<K: std::fmt::Debug, V: std::fmt::Debug + Clone> KeyedStateStore<K, V> {
292 pub fn new(
294 partition: StatePartitionKey,
295 backend: Arc<dyn StateBackend>,
296 key_ser: fn(&K) -> Vec<u8>,
297 val_ser: fn(&V) -> Vec<u8>,
298 val_de: fn(&[u8]) -> Result<V, StreamError>,
299 ) -> Self {
300 Self {
301 partition,
302 backend,
303 key_serializer: key_ser,
304 value_serializer: val_ser,
305 value_deserializer: val_de,
306 _phantom: std::marker::PhantomData,
307 }
308 }
309
310 fn storage_key(&self, key: &K) -> Vec<u8> {
312 let mut prefix = self.partition.to_prefix();
313 prefix.extend_from_slice(&(self.key_serializer)(key));
314 prefix
315 }
316
317 pub fn get(&self, key: &K) -> Result<Option<V>, StreamError> {
319 match self.backend.get(&self.storage_key(key))? {
320 None => Ok(None),
321 Some(bytes) => (self.value_deserializer)(&bytes).map(Some),
322 }
323 }
324
325 pub fn put(&self, key: &K, value: V) -> Result<(), StreamError> {
327 let bytes = (self.value_serializer)(&value);
328 self.backend.put(&self.storage_key(key), &bytes)
329 }
330
331 pub fn delete(&self, key: &K) -> Result<bool, StreamError> {
333 self.backend.delete(&self.storage_key(key))
334 }
335
336 pub fn update_or_default(
339 &self,
340 key: &K,
341 updater: impl FnOnce(Option<V>) -> V,
342 ) -> Result<V, StreamError> {
343 let current = self.get(key)?;
344 let new_value = updater(current);
345 self.put(key, new_value.clone())?;
346 Ok(new_value)
347 }
348}
349
/// Single-slot aggregate (e.g. a running sum) persisted in a [`StateBackend`].
///
/// Holds one accumulator of type `Out` under `aggregate_key` inside the
/// partition's namespace; each input `In` is folded in via `combine_fn`.
pub struct AggregatingState<In, Out> {
    partition: StatePartitionKey,
    backend: Arc<dyn StateBackend>,
    // Suffix appended to the partition prefix to form the storage key.
    aggregate_key: Vec<u8>,
    combine_fn: fn(Out, In) -> Out,
    // Cloned and returned when no accumulator has been stored yet.
    default: Out,
    serializer: fn(&Out) -> Vec<u8>,
    deserializer: fn(&[u8]) -> Result<Out, StreamError>,
    _phantom: std::marker::PhantomData<In>,
}
368
369impl<In, Out: Clone> AggregatingState<In, Out> {
370 #[allow(clippy::too_many_arguments)]
372 pub fn new(
373 partition: StatePartitionKey,
374 backend: Arc<dyn StateBackend>,
375 aggregate_key: Vec<u8>,
376 combine_fn: fn(Out, In) -> Out,
377 default: Out,
378 serializer: fn(&Out) -> Vec<u8>,
379 deserializer: fn(&[u8]) -> Result<Out, StreamError>,
380 ) -> Self {
381 Self {
382 partition,
383 backend,
384 aggregate_key,
385 combine_fn,
386 default,
387 serializer,
388 deserializer,
389 _phantom: std::marker::PhantomData,
390 }
391 }
392
393 fn storage_key(&self) -> Vec<u8> {
394 let mut prefix = self.partition.to_prefix();
395 prefix.extend_from_slice(&self.aggregate_key);
396 prefix
397 }
398
399 fn read_accumulator(&self) -> Result<Out, StreamError> {
400 match self.backend.get(&self.storage_key())? {
401 None => Ok(self.default.clone()),
402 Some(bytes) => (self.deserializer)(&bytes),
403 }
404 }
405
406 pub fn add(&self, value: In) -> Result<(), StreamError> {
408 let current = self.read_accumulator()?;
409 let new_acc = (self.combine_fn)(current, value);
410 self.backend
411 .put(&self.storage_key(), &(self.serializer)(&new_acc))
412 }
413
414 pub fn get(&self) -> Result<Out, StreamError> {
416 self.read_accumulator()
417 }
418
419 pub fn clear(&self) -> Result<(), StreamError> {
421 self.backend.delete(&self.storage_key()).map(|_| ())
422 }
423}
424
/// Point-in-time measurement of a backend's stored payload size.
#[derive(Debug, Clone)]
pub struct StateBackendStats {
    pub size_bytes: usize,
    // When the measurement was taken.
    pub collected_at: Instant,
}
433
434impl StateBackendStats {
435 pub fn collect(backend: &dyn StateBackend) -> Self {
437 Self {
438 size_bytes: backend.size_bytes(),
439 collected_at: Instant::now(),
440 }
441 }
442}
443
// Unit tests covering: backend CRUD, range scans, checkpoint/restore
// round-trips, size accounting, the typed state wrappers, partition
// isolation, snapshot decoding errors, and version bumping.
#[cfg(test)]
mod tests {
    use super::*;

    // --- serializer helpers for String keys and i64/u64 values ---

    fn str_key_ser(k: &String) -> Vec<u8> {
        k.as_bytes().to_vec()
    }

    fn i64_ser(v: &i64) -> Vec<u8> {
        v.to_le_bytes().to_vec()
    }

    fn i64_de(b: &[u8]) -> Result<i64, StreamError> {
        if b.len() < 8 {
            return Err(StreamError::Deserialization("i64 needs 8 bytes".into()));
        }
        let arr: [u8; 8] = b[..8]
            .try_into()
            .map_err(|_| StreamError::Deserialization("i64 slice error".into()))?;
        Ok(i64::from_le_bytes(arr))
    }

    fn u64_ser(v: &u64) -> Vec<u8> {
        v.to_le_bytes().to_vec()
    }

    fn u64_de(b: &[u8]) -> Result<u64, StreamError> {
        if b.len() < 8 {
            return Err(StreamError::Deserialization("u64 needs 8 bytes".into()));
        }
        let arr: [u8; 8] = b[..8]
            .try_into()
            .map_err(|_| StreamError::Deserialization("u64 slice error".into()))?;
        Ok(u64::from_le_bytes(arr))
    }

    // Common partition used by the wrapper tests.
    fn partition() -> StatePartitionKey {
        StatePartitionKey::new("op1", 0, 0)
    }

    #[test]
    fn test_backend_put_get_delete() {
        let backend = InMemoryStateBackend::new();

        backend.put(b"hello", b"world").unwrap();
        let val = backend.get(b"hello").unwrap();
        assert_eq!(val.as_deref(), Some(b"world".as_ref()));

        let existed = backend.delete(b"hello").unwrap();
        assert!(existed);

        assert!(backend.get(b"hello").unwrap().is_none());

        // Deleting an absent key reports false, not an error.
        let not_found = backend.delete(b"missing").unwrap();
        assert!(!not_found);
    }

    #[test]
    fn test_backend_range_scan() {
        let backend = InMemoryStateBackend::new();

        backend.put(b"ns:a", b"1").unwrap();
        backend.put(b"ns:b", b"2").unwrap();
        backend.put(b"other:c", b"3").unwrap();

        let results = backend.range_scan(b"ns:").unwrap();
        assert_eq!(results.len(), 2);

        // An empty prefix matches every key.
        let all = backend.range_scan(b"").unwrap();
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn test_backend_checkpoint_restore() {
        let backend = InMemoryStateBackend::new();

        backend.put(b"k1", b"v1").unwrap();
        backend.put(b"k2", b"v2").unwrap();

        let snapshot = backend.checkpoint(42).unwrap();
        assert!(!snapshot.is_empty());

        // Mutate after the checkpoint; restore must undo all three changes.
        backend.delete(b"k1").unwrap();
        backend.put(b"k2", b"changed").unwrap();
        backend.put(b"k3", b"new").unwrap();

        backend.restore(&snapshot).unwrap();

        assert_eq!(backend.get(b"k1").unwrap().as_deref(), Some(b"v1".as_ref()));
        assert_eq!(backend.get(b"k2").unwrap().as_deref(), Some(b"v2".as_ref()));
        assert!(backend.get(b"k3").unwrap().is_none());
    }

    #[test]
    fn test_backend_size_bytes() {
        let backend = InMemoryStateBackend::new();
        assert_eq!(backend.size_bytes(), 0);

        // 3 key bytes + 3 value bytes.
        backend.put(b"abc", b"def").unwrap();
        assert_eq!(backend.size_bytes(), 6);
    }

    #[test]
    fn test_keyed_state_store_basic() {
        let backend = Arc::new(InMemoryStateBackend::new());
        let store: KeyedStateStore<String, i64> =
            KeyedStateStore::new(partition(), backend, str_key_ser, i64_ser, i64_de);

        let key = "counter".to_string();

        assert!(store.get(&key).unwrap().is_none());

        store.put(&key, 10).unwrap();
        assert_eq!(store.get(&key).unwrap(), Some(10));

        let new_val = store
            .update_or_default(&key, |cur| cur.unwrap_or(0) + 5)
            .unwrap();
        assert_eq!(new_val, 15);
        assert_eq!(store.get(&key).unwrap(), Some(15));

        assert!(store.delete(&key).unwrap());
        assert!(store.get(&key).unwrap().is_none());
    }

    #[test]
    fn test_aggregating_state_sum() {
        let backend = Arc::new(InMemoryStateBackend::new());

        fn combine(acc: u64, x: u64) -> u64 {
            acc + x
        }

        let agg: AggregatingState<u64, u64> = AggregatingState::new(
            partition(),
            backend,
            b"total".to_vec(),
            combine,
            0u64,
            u64_ser,
            u64_de,
        );

        // Unset accumulator reads back as the default.
        assert_eq!(agg.get().unwrap(), 0);

        agg.add(10).unwrap();
        agg.add(20).unwrap();
        agg.add(5).unwrap();

        assert_eq!(agg.get().unwrap(), 35);

        // clear() resets back to the configured default.
        agg.clear().unwrap();
        assert_eq!(agg.get().unwrap(), 0);
    }

    #[test]
    fn test_partition_namespacing_isolation() {
        let backend = Arc::new(InMemoryStateBackend::new());

        // Same operator and partition, different subtask index.
        let p1 = StatePartitionKey::new("op", 0, 0);
        let p2 = StatePartitionKey::new("op", 0, 1);

        let store1: KeyedStateStore<String, i64> =
            KeyedStateStore::new(p1, backend.clone(), str_key_ser, i64_ser, i64_de);
        let store2: KeyedStateStore<String, i64> =
            KeyedStateStore::new(p2, backend, str_key_ser, i64_ser, i64_de);

        // Identical user key must not collide across partitions.
        let key = "x".to_string();
        store1.put(&key, 1).unwrap();
        store2.put(&key, 2).unwrap();

        assert_eq!(store1.get(&key).unwrap(), Some(1));
        assert_eq!(store2.get(&key).unwrap(), Some(2));
    }

    #[test]
    fn test_snapshot_round_trip_empty() {
        let backend = InMemoryStateBackend::new();
        let snapshot = backend.checkpoint(0).unwrap();

        let new_backend = InMemoryStateBackend::new();
        new_backend.restore(&snapshot).unwrap();
        assert_eq!(new_backend.size_bytes(), 0);
    }

    #[test]
    fn test_decode_snapshot_too_short() {
        // 5 bytes is shorter than the 16-byte header.
        let result = decode_snapshot(b"short");
        assert!(result.is_err());
    }

    #[test]
    fn test_version_bumps_on_write() {
        let backend = InMemoryStateBackend::new();
        let v0 = backend.version().unwrap();
        backend.put(b"k", b"v").unwrap();
        let v1 = backend.version().unwrap();
        assert!(v1 > v0);
        backend.delete(b"k").unwrap();
        let v2 = backend.version().unwrap();
        assert!(v2 > v1);
    }

    #[test]
    fn test_state_backend_stats() {
        let backend = InMemoryStateBackend::new();
        backend.put(b"key", b"value").unwrap();
        let stats = StateBackendStats::collect(&backend);
        // "key" (3) + "value" (5) bytes.
        assert_eq!(stats.size_bytes, 8); }
}