1use std::{convert::TryFrom, marker::PhantomData, ops::RangeInclusive};
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use kvdb::{DBTransaction, KeyValueDB};
5use kvdb_memorydb::InMemory as MemoryDatabase;
6#[cfg(feature = "native")]
7use kvdb_rocksdb::{Database as NativeDatabase, DatabaseConfig};
8#[cfg(feature = "web")]
9use kvdb_web::Database as WebDatabase;
10
11pub struct SparseArray<D: KeyValueDB, T: BorshSerialize + BorshDeserialize> {
13 pub db: D,
14 _phantom: PhantomData<T>,
15}
16
/// `SparseArray` backed by a `kvdb_web::Database` (only with the `web` feature).
#[cfg(feature = "web")]
pub type WebSparseArray<T> = SparseArray<WebDatabase, T>;

/// `SparseArray` backed by a `kvdb_rocksdb::Database` (only with the `native` feature).
#[cfg(feature = "native")]
pub type NativeSparseArray<T> = SparseArray<NativeDatabase, T>;
22
#[cfg(feature = "web")]
impl<T> SparseArray<WebDatabase, T>
where
    T: BorshSerialize + BorshDeserialize,
{
    /// Opens the web database called `name` and wraps it in a `SparseArray`.
    ///
    /// The second argument to `WebDatabase::open` is `1` — presumably the column count; the array
    /// only ever uses column 0. TODO confirm against the `kvdb_web` version in use.
    ///
    /// # Panics
    /// Panics if opening the database fails.
    pub async fn new_web(name: &str) -> SparseArray<WebDatabase, T> {
        let opened = WebDatabase::open(name.to_owned(), 1).await;

        Self {
            db: opened.unwrap(),
            _phantom: PhantomData,
        }
    }
}
37
#[cfg(feature = "native")]
impl<T> SparseArray<NativeDatabase, T>
where
    T: BorshSerialize + BorshDeserialize,
{
    /// Opens the RocksDB-backed database at `path` with `config` and wraps it in a `SparseArray`.
    ///
    /// # Errors
    /// Returns whatever I/O error `NativeDatabase::open` reports.
    pub fn new_native(
        config: &DatabaseConfig,
        path: &str,
    ) -> std::io::Result<SparseArray<NativeDatabase, T>> {
        NativeDatabase::open(config, path).map(|opened| Self {
            db: opened,
            _phantom: PhantomData,
        })
    }
}
55
56impl<T> SparseArray<MemoryDatabase, T>
57where
58 T: BorshSerialize + BorshDeserialize,
59{
60 pub fn new_test() -> SparseArray<MemoryDatabase, T> {
61 let db = kvdb_memorydb::create(1);
62
63 SparseArray {
64 db,
65 _phantom: Default::default(),
66 }
67 }
68}
69
70impl<D: KeyValueDB, T> SparseArray<D, T>
71where
72 D: KeyValueDB,
73 T: BorshSerialize + BorshDeserialize + 'static,
74{
75 pub fn new(db: D) -> SparseArray<D, T> {
76 SparseArray {
77 db,
78 _phantom: Default::default(),
79 }
80 }
81
82 pub fn get(&self, index: u64) -> Option<T> {
83 let key = index.to_be_bytes();
84
85 self.db
86 .get(0, &key)
87 .unwrap()
88 .map(|data| T::try_from_slice(data.as_slice()).unwrap())
89 }
90
91 pub fn iter(&self) -> SparseArrayIter<T> {
92 SparseArrayIter {
93 inner: self.db.iter(0),
94 _phantom: Default::default(),
95 }
96 }
97
98 pub fn iter_slice(&self, range: RangeInclusive<u64>) -> impl Iterator<Item = (u64, T)> + '_ {
99 self.iter().filter(move |(index, _)| range.contains(index))
100 }
101
102 pub fn set(&self, index: u64, data: &T) {
103 let mut batch = self.db.transaction();
104 self.set_batched(index, data, &mut batch);
105 self.db.write(batch).unwrap();
106 }
107
108 pub fn remove(&self, index: u64) {
109 let mut batch = self.db.transaction();
110 let key = index.to_be_bytes();
111 batch.delete(0, &key);
112 self.db.write(batch).unwrap();
113 }
114
115 pub fn remove_from(&self, from_index: u64) {
116 let mut batch = self.db.transaction();
117 for (index, _) in self.iter() {
118 if index >= from_index {
119 let key = index.to_be_bytes();
120 batch.delete(0, &key);
121 }
122 }
123 self.db.write(batch).unwrap();
124 }
125
126 pub fn remove_all(&self) {
127 let mut batch = self.db.transaction();
128 self.db
130 .iter(0)
131 .for_each(|(key, _)| {
132 batch.delete(0_u32, &key);
133 });
134 self.db.write(batch).unwrap();
135 }
136
137 pub fn count(&self) -> usize {
139 self.db.iter(0).count()
140 }
141
142 pub fn set_multiple<'a, I>(&self, items: I)
143 where
144 I: IntoIterator<Item = &'a (u64, T)>,
145 {
146 let mut batch = self.db.transaction();
147
148 for (index, item) in items {
149 self.set_batched(*index, item, &mut batch);
150 }
151
152 self.db.write(batch).unwrap();
153 }
154
155 fn set_batched(&self, index: u64, data: &T, batch: &mut DBTransaction) {
156 let key = index.to_be_bytes();
157 let data = data.try_to_vec().unwrap();
158
159 batch.put(0, &key, &data);
160 }
161}
162
163pub struct SparseArrayIter<'a, T: BorshDeserialize> {
164 inner: Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a>,
165 _phantom: PhantomData<T>,
166}
167
168impl<'a, T: BorshDeserialize> Iterator for SparseArrayIter<'a, T> {
169 type Item = (u64, T);
170
171 fn next(&mut self) -> Option<Self::Item> {
172 self.inner.next().map(|(key, value)| {
173 let key = TryFrom::try_from(key.as_ref()).unwrap();
174 let index = u64::from_be_bytes(key);
175 let data = T::try_from_slice(&value).unwrap();
176
177 (index, data)
178 })
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_array_iter_slice() {
        let a = SparseArray::new_test();
        a.set(1, &1u32);
        a.set(3, &2);
        a.set(412345, &3);

        assert_eq!(a.db.iter(0).count(), 3, "inner");
        assert_eq!(a.iter().count(), 3, "iter");

        assert_eq!(a.iter_slice(0..=412345).count(), 3, "all");
        assert_eq!(a.iter_slice(1..=412345).count(), 3, "from 1");
        assert_eq!(a.iter_slice(2..=412345).count(), 2, "from 2");
        assert_eq!(a.iter_slice(2..=412344).count(), 1, "from 2 except last");
    }

    #[test]
    fn test_sparse_array_remove() {
        let a = SparseArray::new_test();
        a.set(1, &1u32);
        a.set(3, &2);
        a.set(10, &3);
        a.set(20, &4);
        a.set(25, &5);
        a.set(100, &6);

        a.remove_from(10);
        assert_eq!(a.iter_slice(0..=100).count(), 2);
        // Check *which* entries survive, not just how many; index 10 itself must be gone.
        let mut remaining: Vec<u64> = a.iter().map(|(index, _)| index).collect();
        remaining.sort_unstable();
        assert_eq!(remaining, vec![1, 3], "only indices below from_index survive");

        a.remove_all();
        assert_eq!(a.iter_slice(0..=100).count(), 0);
        assert_eq!(a.count(), 0);
    }

    #[test]
    fn test_sparse_array_get_set_overwrite() {
        let a: SparseArray<MemoryDatabase, u32> = SparseArray::new_test();
        assert_eq!(a.get(7), None, "empty slot");

        a.set(7, &42);
        assert_eq!(a.get(7), Some(42));

        a.set(7, &43);
        assert_eq!(a.get(7), Some(43), "overwrite replaces the value");
        assert_eq!(a.count(), 1, "overwrite does not add an entry");
    }

    #[test]
    fn test_sparse_array_remove_single_and_set_multiple() {
        let a: SparseArray<MemoryDatabase, u32> = SparseArray::new_test();
        a.set_multiple(&[(0u64, 10u32), (5, 50), (u64::MAX, 99)]);
        assert_eq!(a.count(), 3);
        assert_eq!(a.get(u64::MAX), Some(99), "largest index round-trips");

        a.remove(5);
        assert_eq!(a.get(5), None);
        a.remove(6); // removing an absent index is a no-op
        assert_eq!(a.count(), 2);

        a.remove_from(u64::MAX);
        assert_eq!(a.get(u64::MAX), None, "remove_from is inclusive");
        assert_eq!(a.get(0), Some(10));
    }
}
218}