1use bytes::Bytes;
64use rkyv::{
65 api::high::{HighDeserializer, HighSerializer, HighValidator},
66 bytecheck::CheckBytes,
67 rancor::Error as RkyvError,
68 ser::allocator::ArenaHandle,
69 util::AlignedVec,
70 Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize,
71};
72use std::collections::BTreeMap;
73use std::ops::Bound;
74use std::ops::Range;
75
/// Computes the smallest byte string that sorts strictly after every key
/// beginning with `prefix`.
///
/// The result is `prefix` truncated just after its last byte that is not
/// `0xFF`, with that byte incremented. Returns `None` when no such byte
/// exists, i.e. when `prefix` is empty or consists solely of `0xFF` bytes.
fn prefix_successor(prefix: &[u8]) -> Option<Vec<u8>> {
    // Index of the right-most byte that can be bumped without overflowing.
    let pivot = prefix.iter().rposition(|&byte| byte != 0xFF)?;

    // Everything after the pivot is 0xFF and is dropped from the bound.
    let mut upper = prefix[..=pivot].to_vec();
    upper[pivot] += 1;
    Some(upper)
}
96
97pub trait StateStore: Send {
119 fn get(&self, key: &[u8]) -> Option<Bytes>;
127
128 fn put(&mut self, key: &[u8], value: &[u8]) -> Result<(), StateError>;
137
138 fn delete(&mut self, key: &[u8]) -> Result<(), StateError>;
146
147 fn prefix_scan<'a>(&'a self, prefix: &'a [u8])
157 -> Box<dyn Iterator<Item = (Bytes, Bytes)> + 'a>;
158
159 fn range_scan<'a>(
169 &'a self,
170 range: Range<&'a [u8]>,
171 ) -> Box<dyn Iterator<Item = (Bytes, Bytes)> + 'a>;
172
173 fn contains(&self, key: &[u8]) -> bool {
177 self.get(key).is_some()
178 }
179
180 fn size_bytes(&self) -> usize;
185
186 fn len(&self) -> usize;
188
189 fn is_empty(&self) -> bool {
191 self.len() == 0
192 }
193
194 fn snapshot(&self) -> StateSnapshot;
205
206 fn restore(&mut self, snapshot: StateSnapshot);
211
212 fn clear(&mut self);
214
215 fn flush(&mut self) -> Result<(), StateError> {
224 Ok(()) }
226
227 fn get_or_insert(&mut self, key: &[u8], default: &[u8]) -> Result<Bytes, StateError> {
235 if let Some(value) = self.get(key) {
236 Ok(value)
237 } else {
238 self.put(key, default)?;
239 Ok(Bytes::copy_from_slice(default))
240 }
241 }
242}
243
244pub trait StateStoreExt: StateStore {
268 fn get_typed<T>(&self, key: &[u8]) -> Result<Option<T>, StateError>
277 where
278 T: Archive,
279 T::Archived: for<'a> CheckBytes<HighValidator<'a, RkyvError>>
280 + RkyvDeserialize<T, HighDeserializer<RkyvError>>,
281 {
282 match self.get(key) {
283 Some(bytes) => {
284 let archived = rkyv::access::<T::Archived, RkyvError>(&bytes)
285 .map_err(|e| StateError::Serialization(e.to_string()))?;
286 let value = rkyv::deserialize::<T, RkyvError>(archived)
287 .map_err(|e| StateError::Serialization(e.to_string()))?;
288 Ok(Some(value))
289 }
290 None => Ok(None),
291 }
292 }
293
294 fn put_typed<T>(&mut self, key: &[u8], value: &T) -> Result<(), StateError>
302 where
303 T: for<'a> RkyvSerialize<HighSerializer<AlignedVec, ArenaHandle<'a>, RkyvError>>,
304 {
305 let bytes = rkyv::to_bytes::<RkyvError>(value)
306 .map_err(|e| StateError::Serialization(e.to_string()))?;
307 self.put(key, &bytes)
308 }
309
310 fn update<F>(&mut self, key: &[u8], f: F) -> Result<(), StateError>
319 where
320 F: FnOnce(Option<Bytes>) -> Option<Vec<u8>>,
321 {
322 let current = self.get(key);
323 match f(current) {
324 Some(new_value) => self.put(key, &new_value),
325 None => self.delete(key),
326 }
327 }
328}
329
330impl<T: StateStore + ?Sized> StateStoreExt for T {}
332
333#[derive(Debug, Clone, Archive, RkyvSerialize, RkyvDeserialize)]
340pub struct StateSnapshot {
341 data: Vec<(Vec<u8>, Vec<u8>)>,
343 timestamp_ns: u64,
345 version: u32,
347}
348
349impl StateSnapshot {
350 #[must_use]
352 #[allow(clippy::cast_possible_truncation)]
353 pub fn new(data: Vec<(Vec<u8>, Vec<u8>)>) -> Self {
354 Self {
355 data,
356 timestamp_ns: std::time::SystemTime::now()
358 .duration_since(std::time::UNIX_EPOCH)
359 .map(|d| d.as_nanos() as u64)
360 .unwrap_or(0),
361 version: 1,
362 }
363 }
364
365 #[must_use]
367 pub fn data(&self) -> &[(Vec<u8>, Vec<u8>)] {
368 &self.data
369 }
370
371 #[must_use]
373 pub fn timestamp_ns(&self) -> u64 {
374 self.timestamp_ns
375 }
376
377 #[must_use]
379 pub fn len(&self) -> usize {
380 self.data.len()
381 }
382
383 #[must_use]
385 pub fn is_empty(&self) -> bool {
386 self.data.is_empty()
387 }
388
389 #[must_use]
391 pub fn size_bytes(&self) -> usize {
392 self.data.iter().map(|(k, v)| k.len() + v.len()).sum()
393 }
394
395 pub fn to_bytes(&self) -> Result<AlignedVec, StateError> {
403 rkyv::to_bytes::<RkyvError>(self).map_err(|e| StateError::Serialization(e.to_string()))
404 }
405
406 pub fn from_bytes(bytes: &[u8]) -> Result<Self, StateError> {
414 let archived = rkyv::access::<<Self as Archive>::Archived, RkyvError>(bytes)
415 .map_err(|e| StateError::Serialization(e.to_string()))?;
416 rkyv::deserialize::<Self, RkyvError>(archived)
417 .map_err(|e| StateError::Serialization(e.to_string()))
418 }
419}
420
421pub struct InMemoryStore {
440 data: BTreeMap<Vec<u8>, Bytes>,
442 size_bytes: usize,
444}
445
446impl InMemoryStore {
447 #[must_use]
449 pub fn new() -> Self {
450 Self {
451 data: BTreeMap::new(),
452 size_bytes: 0,
453 }
454 }
455
456 #[must_use]
461 pub fn with_capacity(_capacity: usize) -> Self {
462 Self::new()
463 }
464
465 #[must_use]
470 pub fn capacity(&self) -> usize {
471 self.data.len()
472 }
473
474 pub fn shrink_to_fit(&mut self) {
479 }
481}
482
483impl Default for InMemoryStore {
484 fn default() -> Self {
485 Self::new()
486 }
487}
488
489impl StateStore for InMemoryStore {
490 #[inline]
491 fn get(&self, key: &[u8]) -> Option<Bytes> {
492 self.data.get(key).cloned()
493 }
494
495 #[inline]
496 fn put(&mut self, key: &[u8], value: &[u8]) -> Result<(), StateError> {
497 let value_bytes = Bytes::copy_from_slice(value);
498
499 match self.data.entry(key.to_vec()) {
501 std::collections::btree_map::Entry::Occupied(mut entry) => {
502 self.size_bytes -= entry.get().len();
503 self.size_bytes += value.len();
504 *entry.get_mut() = value_bytes;
505 }
506 std::collections::btree_map::Entry::Vacant(entry) => {
507 self.size_bytes += key.len() + value.len();
508 entry.insert(value_bytes);
509 }
510 }
511 Ok(())
512 }
513
514 fn delete(&mut self, key: &[u8]) -> Result<(), StateError> {
515 if let Some(old_value) = self.data.remove(key) {
516 self.size_bytes -= key.len() + old_value.len();
517 }
518 Ok(())
519 }
520
521 fn prefix_scan<'a>(
522 &'a self,
523 prefix: &'a [u8],
524 ) -> Box<dyn Iterator<Item = (Bytes, Bytes)> + 'a> {
525 if prefix.is_empty() {
526 return Box::new(
528 self.data
529 .iter()
530 .map(|(k, v)| (Bytes::copy_from_slice(k), v.clone())),
531 );
532 }
533 if let Some(end) = prefix_successor(prefix) {
534 Box::new(
535 self.data
536 .range::<[u8], _>((Bound::Included(prefix), Bound::Excluded(end.as_slice())))
537 .map(|(k, v)| (Bytes::copy_from_slice(k), v.clone())),
538 )
539 } else {
540 Box::new(
542 self.data
543 .range::<[u8], _>((Bound::Included(prefix), Bound::Unbounded))
544 .map(|(k, v)| (Bytes::copy_from_slice(k), v.clone())),
545 )
546 }
547 }
548
549 fn range_scan<'a>(
550 &'a self,
551 range: Range<&'a [u8]>,
552 ) -> Box<dyn Iterator<Item = (Bytes, Bytes)> + 'a> {
553 Box::new(
554 self.data
555 .range::<[u8], _>((Bound::Included(range.start), Bound::Excluded(range.end)))
556 .map(|(k, v)| (Bytes::copy_from_slice(k), v.clone())),
557 )
558 }
559
560 #[inline]
561 fn contains(&self, key: &[u8]) -> bool {
562 self.data.contains_key(key.as_ref())
563 }
564
565 fn size_bytes(&self) -> usize {
566 self.size_bytes
567 }
568
569 fn len(&self) -> usize {
570 self.data.len()
571 }
572
573 fn snapshot(&self) -> StateSnapshot {
574 let data: Vec<(Vec<u8>, Vec<u8>)> = self
575 .data
576 .iter()
577 .map(|(k, v)| (k.clone(), v.to_vec()))
578 .collect();
579 StateSnapshot::new(data)
580 }
581
582 fn restore(&mut self, snapshot: StateSnapshot) {
583 self.data.clear();
584 self.size_bytes = 0;
585
586 for (key, value) in snapshot.data {
587 self.size_bytes += key.len() + value.len();
588 self.data.insert(key, Bytes::from(value));
589 }
590 }
591
592 fn clear(&mut self) {
593 self.data.clear();
594 self.size_bytes = 0;
595 }
596}
597
/// Errors reported by state store operations.
///
/// `Display` and `std::error::Error` are derived through `thiserror`; the
/// message for each variant is given by its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum StateError {
    /// An underlying I/O operation failed. Converts automatically from
    /// [`std::io::Error`] via `From`, so `?` works on I/O results.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// An rkyv encoding step failed. In this file the variant is also used
    /// for failures while validating or decoding stored bytes.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// Decoding failed. NOTE(review): nothing visible in this file constructs
    /// this variant; decode failures here are reported as `Serialization`.
    #[error("Deserialization error: {0}")]
    Deserialization(String),

    /// Stored data was found to be corrupt; the payload describes the problem.
    #[error("Corruption error: {0}")]
    Corruption(String),

    /// The requested operation is not supported; the payload names it.
    #[error("Operation not supported: {0}")]
    NotSupported(String),

    /// The requested key does not exist.
    #[error("Key not found")]
    KeyNotFound,

    /// The store cannot accept more data; the payload gives details.
    #[error("Store capacity exceeded: {0}")]
    CapacityExceeded(String),
}
629
630mod mmap;
631
632pub use self::StateError as Error;
634pub use mmap::MmapStateStore;
635
// Unit tests for `InMemoryStore`, the `StateStore`/`StateStoreExt` helpers,
// `StateSnapshot`, and `prefix_successor`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Put, overwrite, and delete are each reflected by `get` and `len`.
    #[test]
    fn test_in_memory_store_basic() {
        let mut store = InMemoryStore::new();

        // Insert a fresh key.
        store.put(b"key1", b"value1").unwrap();
        assert_eq!(store.get(b"key1").unwrap(), Bytes::from("value1"));
        assert_eq!(store.len(), 1);

        // Overwriting keeps the entry count unchanged.
        store.put(b"key1", b"value2").unwrap();
        assert_eq!(store.get(b"key1").unwrap(), Bytes::from("value2"));
        assert_eq!(store.len(), 1);

        // Deleting removes the entry.
        store.delete(b"key1").unwrap();
        assert!(store.get(b"key1").is_none());
        assert_eq!(store.len(), 0);

        // Deleting a key that was never stored succeeds.
        store.delete(b"nonexistent").unwrap();
    }

    /// `contains` tracks insertion and deletion.
    #[test]
    fn test_contains() {
        let mut store = InMemoryStore::new();
        assert!(!store.contains(b"key1"));

        store.put(b"key1", b"value1").unwrap();
        assert!(store.contains(b"key1"));

        store.delete(b"key1").unwrap();
        assert!(!store.contains(b"key1"));
    }

    /// A prefix scan returns only matching keys; an empty prefix returns all.
    #[test]
    fn test_prefix_scan() {
        let mut store = InMemoryStore::new();
        store.put(b"prefix:1", b"value1").unwrap();
        store.put(b"prefix:2", b"value2").unwrap();
        store.put(b"prefix:10", b"value10").unwrap();
        store.put(b"other:1", b"value3").unwrap();

        let results: Vec<_> = store.prefix_scan(b"prefix:").collect();
        assert_eq!(results.len(), 3);

        // Every returned key carries the requested prefix.
        for (key, _) in &results {
            assert!(key.starts_with(b"prefix:"));
        }

        // The empty prefix matches every entry.
        let all: Vec<_> = store.prefix_scan(b"").collect();
        assert_eq!(all.len(), 4);
    }

    /// A range scan includes the start key and excludes the end key.
    #[test]
    fn test_range_scan() {
        let mut store = InMemoryStore::new();
        store.put(b"a", b"1").unwrap();
        store.put(b"b", b"2").unwrap();
        store.put(b"c", b"3").unwrap();
        store.put(b"d", b"4").unwrap();

        let results: Vec<_> = store.range_scan(b"b".as_slice()..b"d".as_slice()).collect();
        assert_eq!(results.len(), 2);

        let keys: Vec<_> = results.iter().map(|(k, _)| k.as_ref()).collect();
        assert!(keys.contains(&b"b".as_slice()));
        assert!(keys.contains(&b"c".as_slice()));
    }

    /// Restoring a snapshot discards every change made after it was taken.
    #[test]
    fn test_snapshot_and_restore() {
        let mut store = InMemoryStore::new();
        store.put(b"key1", b"value1").unwrap();
        store.put(b"key2", b"value2").unwrap();

        // Capture the two-entry state.
        let snapshot = store.snapshot();
        assert_eq!(snapshot.len(), 2);

        // Diverge from the snapshot: modify, add, and remove.
        store.put(b"key1", b"modified").unwrap();
        store.put(b"key3", b"value3").unwrap();
        store.delete(b"key2").unwrap();

        assert_eq!(store.len(), 2);
        assert_eq!(store.get(b"key1").unwrap(), Bytes::from("modified"));

        // Roll back.
        store.restore(snapshot);

        assert_eq!(store.len(), 2);
        assert_eq!(store.get(b"key1").unwrap(), Bytes::from("value1"));
        assert_eq!(store.get(b"key2").unwrap(), Bytes::from("value2"));
        assert!(store.get(b"key3").is_none());
    }

    /// A snapshot survives a `to_bytes` / `from_bytes` round trip.
    #[test]
    fn test_snapshot_serialization() {
        let mut store = InMemoryStore::new();
        store.put(b"key1", b"value1").unwrap();
        store.put(b"key2", b"value2").unwrap();

        let snapshot = store.snapshot();

        // Serialize, then deserialize.
        let bytes = snapshot.to_bytes().unwrap();
        let restored = StateSnapshot::from_bytes(&bytes).unwrap();

        assert_eq!(restored.len(), snapshot.len());
        assert_eq!(restored.data(), snapshot.data());
    }

    /// `put_typed` / `get_typed` round-trip integers, strings, and vectors.
    #[test]
    fn test_typed_access() {
        let mut store = InMemoryStore::new();

        // Integer.
        store.put_typed(b"count", &42u64).unwrap();
        let count: u64 = store.get_typed(b"count").unwrap().unwrap();
        assert_eq!(count, 42);

        // String.
        store.put_typed(b"name", &String::from("alice")).unwrap();
        let name: String = store.get_typed(b"name").unwrap().unwrap();
        assert_eq!(name, "alice");

        // Vector.
        let nums = vec![1i64, 2, 3, 4, 5];
        store.put_typed(b"nums", &nums).unwrap();
        let restored: Vec<i64> = store.get_typed(b"nums").unwrap().unwrap();
        assert_eq!(restored, nums);

        // A missing key yields `None`, not an error.
        let missing: Option<u64> = store.get_typed(b"missing").unwrap();
        assert!(missing.is_none());
    }

    /// `get_or_insert` inserts the default once and then returns the stored value.
    #[test]
    fn test_get_or_insert() {
        let mut store = InMemoryStore::new();

        // Absent key: the default is inserted and returned.
        let value = store.get_or_insert(b"key1", b"default").unwrap();
        assert_eq!(value, Bytes::from("default"));
        assert_eq!(store.len(), 1);

        // Present key: the existing value wins over the default.
        store.put(b"key1", b"modified").unwrap();
        let value = store.get_or_insert(b"key1", b"default").unwrap();
        assert_eq!(value, Bytes::from("modified"));
    }

    /// `update` writes back `Some(..)` results and deletes on `None`.
    #[test]
    fn test_update() {
        let mut store = InMemoryStore::new();
        store.put(b"counter", b"\x00\x00\x00\x00").unwrap();

        // Increment the little-endian u32 counter.
        store
            .update(b"counter", |current| {
                let val = current.map_or(0u32, |b| {
                    u32::from_le_bytes(b.as_ref().try_into().unwrap_or([0; 4]))
                });
                Some((val + 1).to_le_bytes().to_vec())
            })
            .unwrap();

        let bytes = store.get(b"counter").unwrap();
        let val = u32::from_le_bytes(bytes.as_ref().try_into().unwrap());
        assert_eq!(val, 1);

        // Returning `None` from the closure deletes the key.
        store.update(b"counter", |_| None).unwrap();
        assert!(store.get(b"counter").is_none());
    }

    /// `size_bytes` follows inserts, overwrites, deletes, and `clear`.
    #[test]
    fn test_size_tracking() {
        let mut store = InMemoryStore::new();
        assert_eq!(store.size_bytes(), 0);

        store.put(b"key1", b"value1").unwrap();
        // 4-byte key + 6-byte value.
        assert_eq!(store.size_bytes(), 4 + 6);
        store.put(b"key2", b"value2").unwrap();
        assert_eq!(store.size_bytes(), (4 + 6) * 2);

        // Overwriting with a shorter value shrinks the total.
        store.put(b"key1", b"v1").unwrap();
        // key1 -> "v1" (4 + 2) plus key2 -> "value2" (4 + 6).
        assert_eq!(store.size_bytes(), 4 + 2 + 4 + 6);
        store.delete(b"key1").unwrap();
        assert_eq!(store.size_bytes(), 4 + 6);

        store.clear();
        assert_eq!(store.size_bytes(), 0);
    }

    /// `with_capacity` produces an empty store.
    #[test]
    fn test_with_capacity() {
        let store = InMemoryStore::with_capacity(1000);
        // `capacity()` reports the entry count, which is zero for a new store.
        assert_eq!(store.capacity(), 0);
        assert!(store.is_empty());
    }

    /// `clear` drops every entry and resets the size accounting.
    #[test]
    fn test_clear() {
        let mut store = InMemoryStore::new();
        store.put(b"key1", b"value1").unwrap();
        store.put(b"key2", b"value2").unwrap();

        assert_eq!(store.len(), 2);
        assert!(store.size_bytes() > 0);

        store.clear();

        assert_eq!(store.len(), 0);
        assert_eq!(store.size_bytes(), 0);
        assert!(store.get(b"key1").is_none());
    }

    /// `prefix_successor` handles ordinary, empty, and 0xFF-laden prefixes.
    #[test]
    fn test_prefix_successor() {
        // Ordinary case: the last byte is incremented.
        assert_eq!(prefix_successor(b"abc"), Some(b"abd".to_vec()));

        // An empty prefix has no successor.
        assert_eq!(prefix_successor(b""), None);

        // Neither does a prefix made only of 0xFF bytes.
        assert_eq!(prefix_successor(&[0xFF, 0xFF, 0xFF]), None);

        // Trailing 0xFF bytes are dropped and the preceding byte is incremented.
        assert_eq!(prefix_successor(&[0x01, 0xFF]), Some(vec![0x02]));
        assert_eq!(
            prefix_successor(&[0x01, 0x02, 0xFF]),
            Some(vec![0x01, 0x03])
        );

        // Single-byte prefixes.
        assert_eq!(prefix_successor(&[0x00]), Some(vec![0x01]));
        assert_eq!(prefix_successor(&[0xFE]), Some(vec![0xFF]));
        assert_eq!(prefix_successor(&[0xFF]), None);
    }

    /// Prefix scans work on non-textual keys, including an all-0xFF prefix.
    #[test]
    fn test_prefix_scan_binary_keys() {
        let mut store = InMemoryStore::new();

        // Two adjacent binary prefixes.
        let prefix_a = [0x00, 0x01];
        let prefix_b = [0x00, 0x02];
        store.put(&[0x00, 0x01, 0xAA], b"val1").unwrap();
        store.put(&[0x00, 0x01, 0xBB], b"val2").unwrap();
        store.put(&[0x00, 0x02, 0xCC], b"val3").unwrap();
        store.put(&[0x00, 0x02, 0xDD], b"val4").unwrap();
        store.put(&[0x01, 0x01, 0xEE], b"val5").unwrap();

        // Only the keys under prefix_a.
        let results_a: Vec<_> = store.prefix_scan(&prefix_a).collect();
        assert_eq!(results_a.len(), 2);
        for (key, _) in &results_a {
            assert!(key.starts_with(&prefix_a));
        }

        // Only the keys under prefix_b.
        let results_b: Vec<_> = store.prefix_scan(&prefix_b).collect();
        assert_eq!(results_b.len(), 2);
        for (key, _) in &results_b {
            assert!(key.starts_with(&prefix_b));
        }

        // An all-0xFF prefix has no successor; the scan must still be correct.
        let results_ff: Vec<_> = store.prefix_scan(&[0xFF, 0xFF]).collect();
        assert_eq!(results_ff.len(), 0);
    }

    /// Prefix scan results come back in ascending key order regardless of
    /// insertion order.
    #[test]
    fn test_prefix_scan_returns_sorted() {
        let mut store = InMemoryStore::new();
        store.put(b"prefix:c", b"3").unwrap();
        store.put(b"prefix:a", b"1").unwrap();
        store.put(b"prefix:b", b"2").unwrap();

        let results: Vec<_> = store.prefix_scan(b"prefix:").collect();
        let keys: Vec<_> = results.iter().map(|(k, _)| k.as_ref().to_vec()).collect();
        assert_eq!(
            keys,
            vec![
                b"prefix:a".to_vec(),
                b"prefix:b".to_vec(),
                b"prefix:c".to_vec()
            ]
        );
    }
}