// libzkbob_rs/sparse_array.rs

1use std::{convert::TryFrom, marker::PhantomData, ops::RangeInclusive};
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use kvdb::{DBTransaction, KeyValueDB};
5use kvdb_memorydb::InMemory as MemoryDatabase;
6#[cfg(feature = "native")]
7use kvdb_rocksdb::{Database as NativeDatabase, DatabaseConfig};
8#[cfg(feature = "web")]
9use kvdb_web::Database as WebDatabase;
10
/// A persistent sparse array built on top of kvdb.
///
/// Elements of type `T` are stored in column `0` of the wrapped database. The key is the `u64`
/// index encoded as 8 big-endian bytes, and the value is the Borsh-serialized element.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than on the struct, so
/// the type can be named (for example in aliases) without repeating the bounds.
pub struct SparseArray<D, T> {
    /// The underlying key-value database.
    pub db: D,
    _phantom: PhantomData<T>,
}
16
/// Sparse array backed by `kvdb_web::Database` (only with the `web` feature).
#[cfg(feature = "web")]
pub type WebSparseArray<T> = SparseArray<WebDatabase, T>;

/// Sparse array backed by `kvdb_rocksdb::Database` (only with the `native` feature).
#[cfg(feature = "native")]
pub type NativeSparseArray<T> = SparseArray<NativeDatabase, T>;
22
#[cfg(feature = "web")]
impl<T> SparseArray<WebDatabase, T>
where
    T: BorshSerialize + BorshDeserialize,
{
    /// Opens the `kvdb_web` database called `name` and wraps it in a sparse array.
    ///
    /// The `1` passed to `open` is presumably the column count; `SparseArray` only ever uses
    /// column `0`. NOTE(review): confirm against the `kvdb_web` version in use.
    ///
    /// # Panics
    ///
    /// Panics if the underlying database cannot be opened.
    pub async fn new_web(name: &str) -> SparseArray<WebDatabase, T> {
        // The signature is infallible, so an open failure cannot be returned to the caller.
        // Fail loudly and say which database failed, instead of panicking with a bare `unwrap`.
        let db = WebDatabase::open(name.to_owned(), 1)
            .await
            .unwrap_or_else(|err| panic!("failed to open web database `{}`: {:?}", name, err));

        SparseArray {
            db,
            _phantom: PhantomData,
        }
    }
}
37
#[cfg(feature = "native")]
impl<T> SparseArray<NativeDatabase, T>
where
    T: BorshSerialize + BorshDeserialize,
{
    /// Opens the `kvdb_rocksdb` database at `path` with `config` and wraps it in a sparse array.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while opening the database.
    pub fn new_native(
        config: &DatabaseConfig,
        path: &str,
    ) -> std::io::Result<SparseArray<NativeDatabase, T>> {
        let database = NativeDatabase::open(config, path)?;

        Ok(Self {
            db: database,
            _phantom: PhantomData,
        })
    }
}
55
56impl<T> SparseArray<MemoryDatabase, T>
57where
58    T: BorshSerialize + BorshDeserialize,
59{
60    pub fn new_test() -> SparseArray<MemoryDatabase, T> {
61        let db = kvdb_memorydb::create(1);
62
63        SparseArray {
64            db,
65            _phantom: Default::default(),
66        }
67    }
68}
69
70impl<D: KeyValueDB, T> SparseArray<D, T>
71where
72    D: KeyValueDB,
73    T: BorshSerialize + BorshDeserialize + 'static,
74{
75    pub fn new(db: D) -> SparseArray<D, T> {
76        SparseArray {
77            db,
78            _phantom: Default::default(),
79        }
80    }
81
82    pub fn get(&self, index: u64) -> Option<T> {
83        let key = index.to_be_bytes();
84
85        self.db
86            .get(0, &key)
87            .unwrap()
88            .map(|data| T::try_from_slice(data.as_slice()).unwrap())
89    }
90
91    pub fn iter(&self) -> SparseArrayIter<T> {
92        SparseArrayIter {
93            inner: self.db.iter(0),
94            _phantom: Default::default(),
95        }
96    }
97
98    pub fn iter_slice(&self, range: RangeInclusive<u64>) -> impl Iterator<Item = (u64, T)> + '_ {
99        self.iter().filter(move |(index, _)| range.contains(index))
100    }
101
102    pub fn set(&self, index: u64, data: &T) {
103        let mut batch = self.db.transaction();
104        self.set_batched(index, data, &mut batch);
105        self.db.write(batch).unwrap();
106    }
107
108    pub fn remove(&self, index: u64) {
109        let mut batch = self.db.transaction();
110        let key = index.to_be_bytes();
111        batch.delete(0, &key);
112        self.db.write(batch).unwrap();
113    }
114
115    pub fn remove_from(&self, from_index: u64) {
116        let mut batch = self.db.transaction();
117        for (index, _) in self.iter() {
118            if index >= from_index {
119                let key = index.to_be_bytes();
120                batch.delete(0, &key);
121            }
122        }
123        self.db.write(batch).unwrap();
124    }
125
126    pub fn remove_all(&self) {
127        let mut batch = self.db.transaction();
128        //batch.delete_prefix(0, &[][..]);
129        self.db
130            .iter(0)
131            .for_each(|(key, _)| {
132                batch.delete(0_u32, &key);
133            });
134        self.db.write(batch).unwrap();
135    }
136
137    // FIXME: Crazy inefficient, replace or improve kvdb
138    pub fn count(&self) -> usize {
139        self.db.iter(0).count()
140    }
141
142    pub fn set_multiple<'a, I>(&self, items: I)
143    where
144        I: IntoIterator<Item = &'a (u64, T)>,
145    {
146        let mut batch = self.db.transaction();
147
148        for (index, item) in items {
149            self.set_batched(*index, item, &mut batch);
150        }
151
152        self.db.write(batch).unwrap();
153    }
154
155    fn set_batched(&self, index: u64, data: &T, batch: &mut DBTransaction) {
156        let key = index.to_be_bytes();
157        let data = data.try_to_vec().unwrap();
158
159        batch.put(0, &key, &data);
160    }
161}
162
163pub struct SparseArrayIter<'a, T: BorshDeserialize> {
164    inner: Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a>,
165    _phantom: PhantomData<T>,
166}
167
168impl<'a, T: BorshDeserialize> Iterator for SparseArrayIter<'a, T> {
169    type Item = (u64, T);
170
171    fn next(&mut self) -> Option<Self::Item> {
172        self.inner.next().map(|(key, value)| {
173            let key = TryFrom::try_from(key.as_ref()).unwrap();
174            let index = u64::from_be_bytes(key);
175            let data = T::try_from_slice(&value).unwrap();
176
177            (index, data)
178        })
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the stored indices of `array` in ascending order.
    fn indices<D: KeyValueDB>(array: &SparseArray<D, u32>) -> Vec<u64> {
        let mut found: Vec<u64> = array.iter().map(|(index, _)| index).collect();
        found.sort_unstable();
        found
    }

    #[test]
    fn test_sparse_array_iter_slice() {
        let a = SparseArray::new_test();
        a.set(1, &1u32);
        a.set(3, &2);
        a.set(412345, &3);

        assert_eq!(a.db.iter(0).count(), 3, "inner");
        assert_eq!(a.iter().count(), 3, "iter");

        assert_eq!(a.iter_slice(0..=412345).count(), 3, "all");
        assert_eq!(a.iter_slice(1..=412345).count(), 3, "from 1");
        assert_eq!(a.iter_slice(2..=412345).count(), 2, "from 2");
        assert_eq!(a.iter_slice(2..=412344).count(), 1, "from 2 except last");
        assert_eq!(a.iter_slice(4..=412344).count(), 0, "empty window");

        let mut middle: Vec<(u64, u32)> = a.iter_slice(2..=412344).collect();
        middle.sort_unstable();
        assert_eq!(middle, vec![(3, 2)], "values survive the slice");
    }

    #[test]
    fn test_sparse_array_get_and_remove() {
        let a = SparseArray::new_test();
        a.set(7, &70u32);
        assert_eq!(a.get(7), Some(70));
        assert_eq!(a.get(8), None);

        a.set(7, &71);
        assert_eq!(a.get(7), Some(71), "set overwrites");

        a.remove(7);
        assert_eq!(a.get(7), None);
        assert_eq!(a.count(), 0);
    }

    #[test]
    fn test_sparse_array_set_multiple() {
        let a = SparseArray::new_test();
        a.set_multiple(&[(2, 20u32), (4, 40), (6, 60)]);

        assert_eq!(a.count(), 3);
        assert_eq!(a.get(4), Some(40));
        assert_eq!(indices(&a), vec![2, 4, 6]);
    }

    #[test]
    fn test_sparse_array_remove() {
        let a = SparseArray::new_test();
        a.set(1, &1u32);
        a.set(3, &2);
        a.set(10, &3);
        a.set(20, &4);
        a.set(25, &5);
        a.set(100, &6);

        a.remove_from(10);
        assert_eq!(a.iter_slice(0..=100).count(), 2);
        assert_eq!(indices(&a), vec![1, 3], "only indices below 10 survive");

        a.remove_all();
        assert_eq!(a.iter_slice(0..=100).count(), 0);
        assert_eq!(a.count(), 0);
    }
}