1use std::{convert::TryFrom, marker::PhantomData, ops::RangeBounds};
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use kvdb::{DBTransaction, KeyValueDB};
5use kvdb_memorydb::InMemory as MemoryDatabase;
6#[cfg(feature = "native")]
7use kvdb_persy::PersyDatabase as NativeDatabase;
8#[cfg(feature = "web")]
9use kvdb_web::Database as WebDatabase;
10
/// A sparse, `u64`-indexed array persisted in a key-value database.
///
/// Entries live in column `0` of `db`. Each key is the big-endian byte encoding of the
/// element index, and each value is the Borsh encoding of `T`.
///
/// No trait bounds are placed on the struct itself. The `impl` blocks that need
/// `KeyValueDB` / Borsh state those bounds themselves.
pub struct SparseArray<D, T> {
    /// The underlying key-value store.
    pub db: D,
    /// Records the element type without storing a `T`.
    _phantom: PhantomData<T>,
}
16
/// A [`SparseArray`] backed by the `kvdb_web` database (only with the `web` feature).
#[cfg(feature = "web")]
pub type WebSparseArray<T> = SparseArray<WebDatabase, T>;

/// A [`SparseArray`] backed by the `kvdb_persy` database (only with the `native`
/// feature).
#[cfg(feature = "native")]
pub type NativeSparseArray<T> = SparseArray<NativeDatabase, T>;
22
#[cfg(feature = "web")]
impl<T> SparseArray<WebDatabase, T>
where
    T: BorshSerialize + BorshDeserialize,
{
    /// Opens the `kvdb_web` database called `name` and wraps it in a `SparseArray`.
    ///
    /// `1` is passed to `open` as the second argument. This is presumably the column
    /// count, since the array only ever uses column `0` (TODO confirm).
    ///
    /// # Panics
    ///
    /// Panics if opening the database fails.
    pub async fn new_web(name: &str) -> SparseArray<WebDatabase, T> {
        let opened = WebDatabase::open(name.to_owned(), 1).await;

        Self {
            db: opened.unwrap(),
            _phantom: PhantomData,
        }
    }
}
37
#[cfg(feature = "native")]
impl<T> SparseArray<NativeDatabase, T>
where
    T: BorshSerialize + BorshDeserialize,
{
    /// Opens the `kvdb_persy` database at `path` and wraps it in a `SparseArray`.
    ///
    /// `1` is passed to `open` as the column argument, and no extra options (`&[]`)
    /// are given.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `NativeDatabase::open`.
    pub fn new_native(path: &str) -> std::io::Result<SparseArray<NativeDatabase, T>> {
        let opened = NativeDatabase::open(path, 1, &[])?;

        Ok(Self {
            db: opened,
            _phantom: PhantomData,
        })
    }
}
52
53impl<T> SparseArray<MemoryDatabase, T>
54where
55 T: BorshSerialize + BorshDeserialize,
56{
57 pub fn new_test() -> SparseArray<MemoryDatabase, T> {
58 let db = kvdb_memorydb::create(1);
59
60 SparseArray {
61 db,
62 _phantom: Default::default(),
63 }
64 }
65}
66
67impl<D: KeyValueDB, T> SparseArray<D, T>
68where
69 D: KeyValueDB,
70 T: BorshSerialize + BorshDeserialize + 'static,
71{
72 pub fn new(db: D) -> SparseArray<D, T> {
73 SparseArray {
74 db,
75 _phantom: Default::default(),
76 }
77 }
78
79 pub fn get(&self, index: u64) -> Option<T> {
80 let key = index.to_be_bytes();
81
82 self.db
83 .get(0, &key)
84 .unwrap()
85 .map(|data| T::try_from_slice(data.as_slice()).unwrap())
86 }
87
88 pub fn iter(&self) -> SparseArrayIter<T> {
89 SparseArrayIter {
90 inner: Box::new(self.db.iter(0).map(|res| res.unwrap())),
91 _phantom: Default::default(),
92 }
93 }
94
95 pub fn iter_slice<R>(&self, range: R) -> impl Iterator<Item = (u64, T)> + '_
96 where
97 R: RangeBounds<u64> + 'static,
98 {
99 self.iter().filter(move |(index, _)| range.contains(index))
100 }
101
102 pub fn set(&self, index: u64, data: &T) {
103 let mut batch = self.db.transaction();
104 self.set_batched(index, data, &mut batch);
105 self.db.write(batch).unwrap();
106 }
107
108 pub fn remove(&self, index: u64) {
109 let mut batch = self.db.transaction();
110 let key = index.to_be_bytes();
111 batch.delete(0, &key);
112 self.db.write(batch).unwrap();
113 }
114
115 pub fn remove_all_after(&self, index: u64) {
116 let mut batch = self.db.transaction();
117
118 for (index, _) in self.iter_slice(index..) {
119 let key = index.to_be_bytes();
120 batch.delete(0, &key);
121 }
122
123 self.db.write(batch).unwrap();
124 }
125
126 pub fn count(&self) -> usize {
128 self.db.iter(0).count()
129 }
130
131 pub fn set_multiple<'a, I>(&self, items: I)
132 where
133 I: IntoIterator<Item = &'a (u64, T)>,
134 {
135 let mut batch = self.db.transaction();
136
137 for (index, item) in items {
138 self.set_batched(*index, item, &mut batch);
139 }
140
141 self.db.write(batch).unwrap();
142 }
143
144 fn set_batched(&self, index: u64, data: &T, batch: &mut DBTransaction) {
145 let key = index.to_be_bytes();
146 let data = data.try_to_vec().unwrap();
147
148 batch.put(0, &key, &data);
149 }
150}
151
152pub struct SparseArrayIter<'a, T: BorshDeserialize> {
153 inner: Box<dyn Iterator<Item = (smallvec::SmallVec<[u8; 32]>, Vec<u8>)> + 'a>,
154 _phantom: PhantomData<T>,
155}
156
157impl<'a, T: BorshDeserialize> Iterator for SparseArrayIter<'a, T> {
158 type Item = (u64, T);
159
160 fn next(&mut self) -> Option<Self::Item> {
161 self.inner.next().map(|(key, value)| {
162 let key = TryFrom::try_from(key.as_ref()).unwrap();
163 let index = u64::from_be_bytes(key);
164 let data = T::try_from_slice(&value).unwrap();
165
166 (index, data)
167 })
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// `iter_slice` must respect both range ends when indices are sparse.
    #[test]
    fn test_sparse_array_iter_slice() {
        let a = SparseArray::new_test();
        a.set(1, &1u32);
        a.set(3, &2);
        a.set(412345, &3);

        assert_eq!(a.db.iter(0).count(), 3, "inner");
        assert_eq!(a.iter().count(), 3, "iter");

        assert_eq!(a.iter_slice(0..=412345).count(), 3, "all");
        assert_eq!(a.iter_slice(1..=412345).count(), 3, "from 1");
        assert_eq!(a.iter_slice(2..=412345).count(), 2, "from 2");
        assert_eq!(a.iter_slice(2..=412344).count(), 1, "from 2 except last");
    }

    /// Covers `set_multiple`, `get`, `remove`, `count` and the inclusive cutoff of
    /// `remove_all_after`.
    #[test]
    fn test_sparse_array_get_remove() {
        let a = SparseArray::new_test();
        a.set_multiple(&[(1u64, 10u32), (5u64, 50u32), (9u64, 90u32)]);

        assert_eq!(a.count(), 3, "count after set_multiple");
        assert_eq!(a.get(5), Some(50), "get present");
        assert_eq!(a.get(2), None, "get missing");

        a.remove(5);
        assert_eq!(a.get(5), None, "get after remove");
        assert_eq!(a.count(), 2, "count after remove");

        // `remove_all_after` also removes the entry at the given index itself.
        a.set(5, &55);
        a.remove_all_after(5);
        assert_eq!(a.get(1), Some(10), "below cutoff is kept");
        assert_eq!(a.get(5), None, "cutoff index is removed");
        assert_eq!(a.get(9), None, "above cutoff is removed");
        assert_eq!(a.count(), 1, "count after remove_all_after");
    }
}