1use bytes::Bytes;
64use rkyv::{
65 api::high::{HighDeserializer, HighSerializer, HighValidator},
66 bytecheck::CheckBytes,
67 rancor::Error as RkyvError,
68 ser::allocator::ArenaHandle,
69 util::AlignedVec,
70 Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize,
71};
72use std::collections::BTreeMap;
73use std::ops::Bound;
74use std::ops::Range;
75
/// Computes the smallest byte string that sorts strictly after every key
/// beginning with `prefix`, for use as an exclusive upper bound in range scans.
///
/// The result is `prefix` truncated after its right-most byte that is not
/// `0xFF`, with that byte incremented by one.
///
/// Returns `None` when no such bound exists: the prefix is empty, or it
/// consists solely of `0xFF` bytes.
fn prefix_successor(prefix: &[u8]) -> Option<Vec<u8>> {
    // Right-most byte that can still be incremented; trailing 0xFF bytes are dropped.
    let pivot = prefix.iter().rposition(|&byte| byte != u8::MAX)?;
    let mut upper = prefix[..=pivot].to_vec();
    upper[pivot] += 1;
    Some(upper)
}
96
97pub trait StateStore: Send {
119 fn get(&self, key: &[u8]) -> Option<Bytes>;
127
128 fn put(&mut self, key: &[u8], value: &[u8]) -> Result<(), StateError>;
137
138 fn delete(&mut self, key: &[u8]) -> Result<(), StateError>;
146
147 fn prefix_scan<'a>(&'a self, prefix: &'a [u8])
157 -> Box<dyn Iterator<Item = (Bytes, Bytes)> + 'a>;
158
159 fn range_scan<'a>(
169 &'a self,
170 range: Range<&'a [u8]>,
171 ) -> Box<dyn Iterator<Item = (Bytes, Bytes)> + 'a>;
172
173 fn contains(&self, key: &[u8]) -> bool {
177 self.get(key).is_some()
178 }
179
180 fn size_bytes(&self) -> usize;
185
186 fn len(&self) -> usize;
188
189 fn is_empty(&self) -> bool {
191 self.len() == 0
192 }
193
194 fn snapshot(&self) -> StateSnapshot;
205
206 fn restore(&mut self, snapshot: StateSnapshot);
211
212 fn clear(&mut self);
214
215 fn flush(&mut self) -> Result<(), StateError> {
224 Ok(()) }
226
227 fn get_or_insert(&mut self, key: &[u8], default: &[u8]) -> Result<Bytes, StateError> {
235 if let Some(value) = self.get(key) {
236 Ok(value)
237 } else {
238 self.put(key, default)?;
239 Ok(Bytes::copy_from_slice(default))
240 }
241 }
242}
243
244pub trait StateStoreExt: StateStore {
268 fn get_typed<T>(&self, key: &[u8]) -> Result<Option<T>, StateError>
277 where
278 T: Archive,
279 T::Archived: for<'a> CheckBytes<HighValidator<'a, RkyvError>>
280 + RkyvDeserialize<T, HighDeserializer<RkyvError>>,
281 {
282 match self.get(key) {
283 Some(bytes) => {
284 let archived = rkyv::access::<T::Archived, RkyvError>(&bytes)
285 .map_err(|e| StateError::Serialization(e.to_string()))?;
286 let value = rkyv::deserialize::<T, RkyvError>(archived)
287 .map_err(|e| StateError::Serialization(e.to_string()))?;
288 Ok(Some(value))
289 }
290 None => Ok(None),
291 }
292 }
293
294 fn put_typed<T>(&mut self, key: &[u8], value: &T) -> Result<(), StateError>
302 where
303 T: for<'a> RkyvSerialize<HighSerializer<AlignedVec, ArenaHandle<'a>, RkyvError>>,
304 {
305 let bytes = rkyv::to_bytes::<RkyvError>(value)
306 .map_err(|e| StateError::Serialization(e.to_string()))?;
307 self.put(key, &bytes)
308 }
309
310 fn update<F>(&mut self, key: &[u8], f: F) -> Result<(), StateError>
319 where
320 F: FnOnce(Option<Bytes>) -> Option<Vec<u8>>,
321 {
322 let current = self.get(key);
323 match f(current) {
324 Some(new_value) => self.put(key, &new_value),
325 None => self.delete(key),
326 }
327 }
328}
329
330impl<T: StateStore + ?Sized> StateStoreExt for T {}
332
333#[derive(Debug, Clone, Archive, RkyvSerialize, RkyvDeserialize)]
340pub struct StateSnapshot {
341 data: Vec<(Vec<u8>, Vec<u8>)>,
343 timestamp_ns: u64,
345 version: u32,
347}
348
349impl StateSnapshot {
350 #[must_use]
352 #[allow(clippy::cast_possible_truncation)]
353 pub fn new(data: Vec<(Vec<u8>, Vec<u8>)>) -> Self {
354 Self {
355 data,
356 timestamp_ns: std::time::SystemTime::now()
358 .duration_since(std::time::UNIX_EPOCH)
359 .map(|d| d.as_nanos() as u64)
360 .unwrap_or(0),
361 version: 1,
362 }
363 }
364
365 #[must_use]
367 pub fn data(&self) -> &[(Vec<u8>, Vec<u8>)] {
368 &self.data
369 }
370
371 #[must_use]
373 pub fn timestamp_ns(&self) -> u64 {
374 self.timestamp_ns
375 }
376
377 #[must_use]
379 pub fn len(&self) -> usize {
380 self.data.len()
381 }
382
383 #[must_use]
385 pub fn is_empty(&self) -> bool {
386 self.data.is_empty()
387 }
388
389 #[must_use]
391 pub fn size_bytes(&self) -> usize {
392 self.data.iter().map(|(k, v)| k.len() + v.len()).sum()
393 }
394
395 pub fn to_bytes(&self) -> Result<AlignedVec, StateError> {
403 rkyv::to_bytes::<RkyvError>(self).map_err(|e| StateError::Serialization(e.to_string()))
404 }
405
406 pub fn from_bytes(bytes: &[u8]) -> Result<Self, StateError> {
414 let archived = rkyv::access::<<Self as Archive>::Archived, RkyvError>(bytes)
415 .map_err(|e| StateError::Serialization(e.to_string()))?;
416 rkyv::deserialize::<Self, RkyvError>(archived)
417 .map_err(|e| StateError::Serialization(e.to_string()))
418 }
419}
420
421pub struct InMemoryStore {
440 data: BTreeMap<Vec<u8>, Bytes>,
442 size_bytes: usize,
444}
445
446impl InMemoryStore {
447 #[must_use]
449 pub fn new() -> Self {
450 Self {
451 data: BTreeMap::new(),
452 size_bytes: 0,
453 }
454 }
455
456 #[must_use]
461 pub fn with_capacity(_capacity: usize) -> Self {
462 Self::new()
463 }
464
465 #[must_use]
470 pub fn capacity(&self) -> usize {
471 self.data.len()
472 }
473
474 pub fn shrink_to_fit(&mut self) {
479 }
481}
482
483impl Default for InMemoryStore {
484 fn default() -> Self {
485 Self::new()
486 }
487}
488
489impl StateStore for InMemoryStore {
490 #[inline]
491 fn get(&self, key: &[u8]) -> Option<Bytes> {
492 self.data.get(key).cloned()
493 }
494
495 #[inline]
496 fn put(&mut self, key: &[u8], value: &[u8]) -> Result<(), StateError> {
497 let value_bytes = Bytes::copy_from_slice(value);
498
499 match self.data.entry(key.to_vec()) {
501 std::collections::btree_map::Entry::Occupied(mut entry) => {
502 self.size_bytes -= entry.get().len();
503 self.size_bytes += value.len();
504 *entry.get_mut() = value_bytes;
505 }
506 std::collections::btree_map::Entry::Vacant(entry) => {
507 self.size_bytes += key.len() + value.len();
508 entry.insert(value_bytes);
509 }
510 }
511 Ok(())
512 }
513
514 fn delete(&mut self, key: &[u8]) -> Result<(), StateError> {
515 if let Some(old_value) = self.data.remove(key) {
516 self.size_bytes -= key.len() + old_value.len();
517 }
518 Ok(())
519 }
520
521 fn prefix_scan<'a>(
522 &'a self,
523 prefix: &'a [u8],
524 ) -> Box<dyn Iterator<Item = (Bytes, Bytes)> + 'a> {
525 if prefix.is_empty() {
526 return Box::new(
528 self.data
529 .iter()
530 .map(|(k, v)| (Bytes::copy_from_slice(k), v.clone())),
531 );
532 }
533 if let Some(end) = prefix_successor(prefix) {
534 Box::new(
535 self.data
536 .range::<[u8], _>((Bound::Included(prefix), Bound::Excluded(end.as_slice())))
537 .map(|(k, v)| (Bytes::copy_from_slice(k), v.clone())),
538 )
539 } else {
540 Box::new(
542 self.data
543 .range::<[u8], _>((Bound::Included(prefix), Bound::Unbounded))
544 .map(|(k, v)| (Bytes::copy_from_slice(k), v.clone())),
545 )
546 }
547 }
548
549 fn range_scan<'a>(
550 &'a self,
551 range: Range<&'a [u8]>,
552 ) -> Box<dyn Iterator<Item = (Bytes, Bytes)> + 'a> {
553 Box::new(
554 self.data
555 .range::<[u8], _>((Bound::Included(range.start), Bound::Excluded(range.end)))
556 .map(|(k, v)| (Bytes::copy_from_slice(k), v.clone())),
557 )
558 }
559
560 #[inline]
561 fn contains(&self, key: &[u8]) -> bool {
562 self.data.contains_key(key.as_ref())
563 }
564
565 fn size_bytes(&self) -> usize {
566 self.size_bytes
567 }
568
569 fn len(&self) -> usize {
570 self.data.len()
571 }
572
573 fn snapshot(&self) -> StateSnapshot {
574 let data: Vec<(Vec<u8>, Vec<u8>)> = self
575 .data
576 .iter()
577 .map(|(k, v)| (k.clone(), v.to_vec()))
578 .collect();
579 StateSnapshot::new(data)
580 }
581
582 fn restore(&mut self, snapshot: StateSnapshot) {
583 self.data.clear();
584 self.size_bytes = 0;
585
586 for (key, value) in snapshot.data {
587 self.size_bytes += key.len() + value.len();
588 self.data.insert(key, Bytes::from(value));
589 }
590 }
591
592 fn clear(&mut self) {
593 self.data.clear();
594 self.size_bytes = 0;
595 }
596}
597
598#[derive(Debug, thiserror::Error)]
600pub enum StateError {
601 #[error("I/O error: {0}")]
603 Io(#[from] std::io::Error),
604
605 #[error("Serialization error: {0}")]
607 Serialization(String),
608
609 #[error("Deserialization error: {0}")]
611 Deserialization(String),
612
613 #[error("Corruption error: {0}")]
615 Corruption(String),
616
617 #[error("Operation not supported: {0}")]
619 NotSupported(String),
620
621 #[error("Key not found")]
623 KeyNotFound,
624
625 #[error("Store capacity exceeded: {0}")]
627 CapacityExceeded(String),
628}
629
630mod mmap;
631
632pub mod changelog_aware;
634
635pub use self::StateError as Error;
637pub use changelog_aware::{ChangelogAwareStore, ChangelogSink};
638pub use mmap::MmapStateStore;
639
#[cfg(test)]
mod tests {
    use super::*;

    /// Insert, overwrite and delete a single key.
    #[test]
    fn test_in_memory_store_basic() {
        let mut store = InMemoryStore::new();

        // Fresh insert.
        store.put(b"key1", b"value1").unwrap();
        assert_eq!(store.get(b"key1").unwrap(), Bytes::from("value1"));
        assert_eq!(store.len(), 1);

        // Overwriting keeps a single entry.
        store.put(b"key1", b"value2").unwrap();
        assert_eq!(store.get(b"key1").unwrap(), Bytes::from("value2"));
        assert_eq!(store.len(), 1);

        // Delete removes it.
        store.delete(b"key1").unwrap();
        assert!(store.get(b"key1").is_none());
        assert_eq!(store.len(), 0);

        // Deleting a key that was never stored is not an error.
        store.delete(b"nonexistent").unwrap();
    }

    /// `contains` tracks inserts and deletes.
    #[test]
    fn test_contains() {
        let mut store = InMemoryStore::new();
        assert!(!store.contains(b"key1"));

        store.put(b"key1", b"value1").unwrap();
        assert!(store.contains(b"key1"));

        store.delete(b"key1").unwrap();
        assert!(!store.contains(b"key1"));
    }

    /// Prefix scans return only matching keys; the empty prefix matches all.
    #[test]
    fn test_prefix_scan() {
        let mut store = InMemoryStore::new();
        store.put(b"prefix:1", b"value1").unwrap();
        store.put(b"prefix:2", b"value2").unwrap();
        store.put(b"prefix:10", b"value10").unwrap();
        store.put(b"other:1", b"value3").unwrap();

        let matching: Vec<_> = store.prefix_scan(b"prefix:").collect();
        assert_eq!(matching.len(), 3);
        assert!(matching.iter().all(|(key, _)| key.starts_with(b"prefix:")));

        let everything: Vec<_> = store.prefix_scan(b"").collect();
        assert_eq!(everything.len(), 4);
    }

    /// Range scans include the start key and exclude the end key.
    #[test]
    fn test_range_scan() {
        let mut store = InMemoryStore::new();
        store.put(b"a", b"1").unwrap();
        store.put(b"b", b"2").unwrap();
        store.put(b"c", b"3").unwrap();
        store.put(b"d", b"4").unwrap();

        let hits: Vec<_> = store.range_scan(b"b".as_slice()..b"d".as_slice()).collect();
        assert_eq!(hits.len(), 2);

        let keys: Vec<_> = hits.iter().map(|(key, _)| key.as_ref()).collect();
        assert!(keys.contains(&b"b".as_slice()));
        assert!(keys.contains(&b"c".as_slice()));
    }

    /// Restoring a snapshot undoes every later modification.
    #[test]
    fn test_snapshot_and_restore() {
        let mut store = InMemoryStore::new();
        store.put(b"key1", b"value1").unwrap();
        store.put(b"key2", b"value2").unwrap();

        let checkpoint = store.snapshot();
        assert_eq!(checkpoint.len(), 2);

        // Diverge from the snapshot: modify, add and remove.
        store.put(b"key1", b"modified").unwrap();
        store.put(b"key3", b"value3").unwrap();
        store.delete(b"key2").unwrap();

        assert_eq!(store.len(), 2);
        assert_eq!(store.get(b"key1").unwrap(), Bytes::from("modified"));

        store.restore(checkpoint);

        assert_eq!(store.len(), 2);
        assert_eq!(store.get(b"key1").unwrap(), Bytes::from("value1"));
        assert_eq!(store.get(b"key2").unwrap(), Bytes::from("value2"));
        assert!(store.get(b"key3").is_none());
    }

    /// A snapshot survives a round trip through its byte encoding.
    #[test]
    fn test_snapshot_serialization() {
        let mut store = InMemoryStore::new();
        store.put(b"key1", b"value1").unwrap();
        store.put(b"key2", b"value2").unwrap();

        let original = store.snapshot();

        let encoded = original.to_bytes().unwrap();
        let decoded = StateSnapshot::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.data(), original.data());
    }

    /// Typed values round-trip through `put_typed` / `get_typed`.
    #[test]
    fn test_typed_access() {
        let mut store = InMemoryStore::new();

        // Integer.
        store.put_typed(b"count", &42u64).unwrap();
        let count: u64 = store.get_typed(b"count").unwrap().unwrap();
        assert_eq!(count, 42);

        // String.
        store.put_typed(b"name", &String::from("alice")).unwrap();
        let name: String = store.get_typed(b"name").unwrap().unwrap();
        assert_eq!(name, "alice");

        // Vector.
        let nums = vec![1i64, 2, 3, 4, 5];
        store.put_typed(b"nums", &nums).unwrap();
        let decoded: Vec<i64> = store.get_typed(b"nums").unwrap().unwrap();
        assert_eq!(decoded, nums);

        // Missing key.
        let absent: Option<u64> = store.get_typed(b"missing").unwrap();
        assert!(absent.is_none());
    }

    /// `get_or_insert` stores the default once and then returns what is stored.
    #[test]
    fn test_get_or_insert() {
        let mut store = InMemoryStore::new();

        // Absent key: the default is inserted and returned.
        let first = store.get_or_insert(b"key1", b"default").unwrap();
        assert_eq!(first, Bytes::from("default"));
        assert_eq!(store.len(), 1);

        // Present key: the stored value wins over the default.
        store.put(b"key1", b"modified").unwrap();
        let second = store.get_or_insert(b"key1", b"default").unwrap();
        assert_eq!(second, Bytes::from("modified"));
    }

    /// `update` can both rewrite and delete an entry.
    #[test]
    fn test_update() {
        let mut store = InMemoryStore::new();
        store.put(b"counter", b"\x00\x00\x00\x00").unwrap();

        // Increment the little-endian counter.
        store
            .update(b"counter", |current| {
                let val = current.map_or(0u32, |b| {
                    u32::from_le_bytes(b.as_ref().try_into().unwrap_or([0; 4]))
                });
                Some((val + 1).to_le_bytes().to_vec())
            })
            .unwrap();

        let raw = store.get(b"counter").unwrap();
        let counter = u32::from_le_bytes(raw.as_ref().try_into().unwrap());
        assert_eq!(counter, 1);

        // Returning `None` deletes the key.
        store.update(b"counter", |_| None).unwrap();
        assert!(store.get(b"counter").is_none());
    }

    /// `size_bytes` follows inserts, overwrites, deletes and clears.
    #[test]
    fn test_size_tracking() {
        let mut store = InMemoryStore::new();
        assert_eq!(store.size_bytes(), 0);

        store.put(b"key1", b"value1").unwrap();
        assert_eq!(store.size_bytes(), 4 + 6);

        store.put(b"key2", b"value2").unwrap();
        assert_eq!(store.size_bytes(), (4 + 6) * 2);

        // Overwrite key1 with a shorter value.
        store.put(b"key1", b"v1").unwrap();
        assert_eq!(store.size_bytes(), 4 + 2 + 4 + 6);

        store.delete(b"key1").unwrap();
        assert_eq!(store.size_bytes(), 4 + 6);

        store.clear();
        assert_eq!(store.size_bytes(), 0);
    }

    /// `with_capacity` yields an empty store.
    #[test]
    fn test_with_capacity() {
        let store = InMemoryStore::with_capacity(1000);
        assert_eq!(store.capacity(), 0);
        assert!(store.is_empty());
    }

    /// `clear` removes all entries and resets the size accounting.
    #[test]
    fn test_clear() {
        let mut store = InMemoryStore::new();
        store.put(b"key1", b"value1").unwrap();
        store.put(b"key2", b"value2").unwrap();

        assert_eq!(store.len(), 2);
        assert!(store.size_bytes() > 0);

        store.clear();

        assert_eq!(store.len(), 0);
        assert_eq!(store.size_bytes(), 0);
        assert!(store.get(b"key1").is_none());
    }

    /// Covers ordinary, empty, all-0xFF and trailing-0xFF prefixes.
    #[test]
    fn test_prefix_successor() {
        // Ordinary case: bump the last byte.
        assert_eq!(prefix_successor(b"abc"), Some(b"abd".to_vec()));

        // Empty prefix has no successor.
        assert_eq!(prefix_successor(b""), None);

        // All 0xFF has no successor.
        assert_eq!(prefix_successor(&[0xFF, 0xFF, 0xFF]), None);

        // Trailing 0xFF bytes are dropped before incrementing.
        assert_eq!(prefix_successor(&[0x01, 0xFF]), Some(vec![0x02]));
        assert_eq!(
            prefix_successor(&[0x01, 0x02, 0xFF]),
            Some(vec![0x01, 0x03])
        );

        // Single-byte prefixes.
        assert_eq!(prefix_successor(&[0x00]), Some(vec![0x01]));
        assert_eq!(prefix_successor(&[0xFE]), Some(vec![0xFF]));
        assert_eq!(prefix_successor(&[0xFF]), None);
    }

    /// Prefix scans work on non-textual keys.
    #[test]
    fn test_prefix_scan_binary_keys() {
        let mut store = InMemoryStore::new();

        let prefix_a = [0x00, 0x01];
        let prefix_b = [0x00, 0x02];

        store.put(&[0x00, 0x01, 0xAA], b"val1").unwrap();
        store.put(&[0x00, 0x01, 0xBB], b"val2").unwrap();
        store.put(&[0x00, 0x02, 0xCC], b"val3").unwrap();
        store.put(&[0x00, 0x02, 0xDD], b"val4").unwrap();
        store.put(&[0x01, 0x01, 0xEE], b"val5").unwrap();

        let under_a: Vec<_> = store.prefix_scan(&prefix_a).collect();
        assert_eq!(under_a.len(), 2);
        assert!(under_a.iter().all(|(key, _)| key.starts_with(&prefix_a)));

        let under_b: Vec<_> = store.prefix_scan(&prefix_b).collect();
        assert_eq!(under_b.len(), 2);
        assert!(under_b.iter().all(|(key, _)| key.starts_with(&prefix_b)));

        // All-0xFF prefix: no successor exists, and nothing is stored there.
        let under_ff: Vec<_> = store.prefix_scan(&[0xFF, 0xFF]).collect();
        assert_eq!(under_ff.len(), 0);
    }

    /// Prefix scans yield keys in ascending order regardless of insert order.
    #[test]
    fn test_prefix_scan_returns_sorted() {
        let mut store = InMemoryStore::new();
        store.put(b"prefix:c", b"3").unwrap();
        store.put(b"prefix:a", b"1").unwrap();
        store.put(b"prefix:b", b"2").unwrap();

        let keys: Vec<_> = store
            .prefix_scan(b"prefix:")
            .map(|(key, _)| key.as_ref().to_vec())
            .collect();
        assert_eq!(
            keys,
            vec![
                b"prefix:a".to_vec(),
                b"prefix:b".to_vec(),
                b"prefix:c".to_vec()
            ]
        );
    }
}