libzeropool_rs/
sparse_array.rs

1use std::{convert::TryFrom, marker::PhantomData, ops::RangeBounds};
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use kvdb::{DBTransaction, KeyValueDB};
5use kvdb_memorydb::InMemory as MemoryDatabase;
6#[cfg(feature = "native")]
7use kvdb_persy::PersyDatabase as NativeDatabase;
8#[cfg(feature = "web")]
9use kvdb_web::Database as WebDatabase;
10
11/// A persistent sparse array built on top of kvdb
12pub struct SparseArray<D: KeyValueDB, T: BorshSerialize + BorshDeserialize> {
13    pub db: D,
14    _phantom: PhantomData<T>,
15}
16
/// A [`SparseArray`] stored in the `kvdb_web` database (available with the `web` feature).
#[cfg(feature = "web")]
pub type WebSparseArray<T> = SparseArray<WebDatabase, T>;

/// A [`SparseArray`] stored in a `kvdb_persy` database (available with the `native` feature).
#[cfg(feature = "native")]
pub type NativeSparseArray<T> = SparseArray<NativeDatabase, T>;
22
#[cfg(feature = "web")]
impl<T> SparseArray<WebDatabase, T>
where
    T: BorshSerialize + BorshDeserialize,
{
    /// Opens the web database called `name` with a single column and wraps it.
    ///
    /// # Panics
    ///
    /// Panics if the database cannot be opened. The panic message names the
    /// database and includes the underlying error, so failures can be diagnosed.
    /// The return type is kept infallible for compatibility with existing callers.
    pub async fn new_web(name: &str) -> SparseArray<WebDatabase, T> {
        let db = WebDatabase::open(name.to_owned(), 1)
            .await
            .unwrap_or_else(|err| panic!("failed to open web database `{}`: {:?}", name, err));

        SparseArray {
            db,
            _phantom: PhantomData,
        }
    }
}
37
#[cfg(feature = "native")]
impl<T> SparseArray<NativeDatabase, T>
where
    T: BorshSerialize + BorshDeserialize,
{
    /// Opens the Persy-backed database at `path` and wraps it.
    ///
    /// The `1` passed to `open` is presumably the column count (all accessors use
    /// column `0`) — NOTE(review): confirm against `kvdb_persy`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while opening the database.
    pub fn new_native(path: &str) -> std::io::Result<SparseArray<NativeDatabase, T>> {
        let database = NativeDatabase::open(path, 1, &[])?;
        let array = SparseArray {
            db: database,
            _phantom: PhantomData,
        };
        Ok(array)
    }
}
52
53impl<T> SparseArray<MemoryDatabase, T>
54where
55    T: BorshSerialize + BorshDeserialize,
56{
57    pub fn new_test() -> SparseArray<MemoryDatabase, T> {
58        let db = kvdb_memorydb::create(1);
59
60        SparseArray {
61            db,
62            _phantom: Default::default(),
63        }
64    }
65}
66
67impl<D: KeyValueDB, T> SparseArray<D, T>
68where
69    D: KeyValueDB,
70    T: BorshSerialize + BorshDeserialize + 'static,
71{
72    pub fn new(db: D) -> SparseArray<D, T> {
73        SparseArray {
74            db,
75            _phantom: Default::default(),
76        }
77    }
78
79    pub fn get(&self, index: u64) -> Option<T> {
80        let key = index.to_be_bytes();
81
82        self.db
83            .get(0, &key)
84            .unwrap()
85            .map(|data| T::try_from_slice(data.as_slice()).unwrap())
86    }
87
88    pub fn iter(&self) -> SparseArrayIter<T> {
89        SparseArrayIter {
90            inner: Box::new(self.db.iter(0).map(|res| res.unwrap())),
91            _phantom: Default::default(),
92        }
93    }
94
95    pub fn iter_slice<R>(&self, range: R) -> impl Iterator<Item = (u64, T)> + '_
96    where
97        R: RangeBounds<u64> + 'static,
98    {
99        self.iter().filter(move |(index, _)| range.contains(index))
100    }
101
102    pub fn set(&self, index: u64, data: &T) {
103        let mut batch = self.db.transaction();
104        self.set_batched(index, data, &mut batch);
105        self.db.write(batch).unwrap();
106    }
107
108    pub fn remove(&self, index: u64) {
109        let mut batch = self.db.transaction();
110        let key = index.to_be_bytes();
111        batch.delete(0, &key);
112        self.db.write(batch).unwrap();
113    }
114
115    pub fn remove_all_after(&self, index: u64) {
116        let mut batch = self.db.transaction();
117
118        for (index, _) in self.iter_slice(index..) {
119            let key = index.to_be_bytes();
120            batch.delete(0, &key);
121        }
122
123        self.db.write(batch).unwrap();
124    }
125
126    // FIXME: Crazy inefficient, replace or improve kvdb
127    pub fn count(&self) -> usize {
128        self.db.iter(0).count()
129    }
130
131    pub fn set_multiple<'a, I>(&self, items: I)
132    where
133        I: IntoIterator<Item = &'a (u64, T)>,
134    {
135        let mut batch = self.db.transaction();
136
137        for (index, item) in items {
138            self.set_batched(*index, item, &mut batch);
139        }
140
141        self.db.write(batch).unwrap();
142    }
143
144    fn set_batched(&self, index: u64, data: &T, batch: &mut DBTransaction) {
145        let key = index.to_be_bytes();
146        let data = data.try_to_vec().unwrap();
147
148        batch.put(0, &key, &data);
149    }
150}
151
152pub struct SparseArrayIter<'a, T: BorshDeserialize> {
153    inner: Box<dyn Iterator<Item = (smallvec::SmallVec<[u8; 32]>, Vec<u8>)> + 'a>,
154    _phantom: PhantomData<T>,
155}
156
157impl<'a, T: BorshDeserialize> Iterator for SparseArrayIter<'a, T> {
158    type Item = (u64, T);
159
160    fn next(&mut self) -> Option<Self::Item> {
161        self.inner.next().map(|(key, value)| {
162            let key = TryFrom::try_from(key.as_ref()).unwrap();
163            let index = u64::from_be_bytes(key);
164            let data = T::try_from_slice(&value).unwrap();
165
166            (index, data)
167        })
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_array_iter_slice() {
        let array = SparseArray::new_test();
        for &(index, value) in [(1u64, 1u32), (3, 2), (412345, 3)].iter() {
            array.set(index, &value);
        }

        // The raw column and the typed iterator must both see all three entries.
        assert_eq!(array.db.iter(0).count(), 3, "inner");
        assert_eq!(array.iter().count(), 3, "iter");

        // (range, expected number of hits, assertion label)
        let cases = [
            (0u64..=412345, 3usize, "all"),
            (1..=412345, 3, "from 1"),
            (2..=412345, 2, "from 2"),
            (2..=412344, 1, "from 2 except last"),
        ];
        for (range, expected, label) in cases.iter().cloned() {
            assert_eq!(array.iter_slice(range).count(), expected, "{}", label);
        }
    }
}