1use std::fs::remove_dir_all;
3use std::marker::PhantomData;
4use rocksdb::{DBWithThreadMode, MultiThreaded, DBIteratorWithThreadMode, BoundColumnFamily, Options, DBPinnableSlice};
5
/// Types whose serialized form always occupies the same, fixed number of bytes.
///
/// Used by the `Vec<D>` decoder to split a concatenated buffer into per-element slices.
pub trait ConstantSize {
    /// Width in bytes of one serialized value.
    const LEN: usize;
}

impl ConstantSize for u32 {
    // A `u32` is encoded as its four raw bytes.
    const LEN: usize = std::mem::size_of::<u32>();
}
13
/// Conversion of a value into the raw bytes handed to RocksDB.
pub trait Serialize {
    /// Encodes `self` as an owned byte vector.
    fn serialize(&self) -> Vec<u8>;
}

impl Serialize for String {
    /// The string's UTF-8 bytes, with no length prefix or terminator.
    fn serialize(&self) -> Vec<u8> {
        Vec::from(self.as_str().as_bytes())
    }
}

impl Serialize for u32 {
    /// Four bytes, little-endian.
    fn serialize(&self) -> Vec<u8> {
        Vec::from(u32::to_le_bytes(*self))
    }
}

impl<S: Serialize> Serialize for Vec<S> {
    /// Plain concatenation of every element's encoding, in order.
    ///
    /// No element count or separators are written, so decoding relies on the
    /// element type having a fixed width.
    fn serialize(&self) -> Vec<u8> {
        self.iter().flat_map(|element| element.serialize()).collect()
    }
}
37
/// Reconstruction of a value from the raw bytes read back from RocksDB.
pub trait Deserialize {
    /// Decodes a value from `buf`.
    fn deserialize(buf: &[u8]) -> Self;
}

impl Deserialize for String {
    /// Interprets `buf` as UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is not valid UTF-8.
    fn deserialize(buf: &[u8]) -> Self {
        let owned = Vec::from(buf);
        String::from_utf8(owned).unwrap()
    }
}

impl Deserialize for u32 {
    /// Reads a little-endian `u32`.
    ///
    /// # Panics
    ///
    /// Panics unless `buf` is exactly four bytes long.
    fn deserialize(buf: &[u8]) -> Self {
        assert_eq!(buf.len(), 4);
        let mut raw = [0u8; 4];
        raw.copy_from_slice(buf);
        Self::from_le_bytes(raw)
    }
}
55
56impl<D> Deserialize for Vec<D>
57 where D: Deserialize + ConstantSize,
58{
59 fn deserialize(buf: &[u8]) -> Self {
60 let mut offset = 0usize;
61 let mut ret = Vec::new();
62 while offset < buf.len() {
63 ret.push(D::deserialize(&buf[offset..offset+D::LEN]));
64 offset += D::LEN;
65 }
66 ret
67 }
68}
69
70#[derive(Debug, Clone, Default)]
71pub struct Empty {}
72
73impl ConstantSize for Empty {
74 const LEN: usize = 0;
75}
76
77impl Serialize for Empty {
78 fn serialize(&self) -> Vec<u8> {
79 Vec::new()
80 }
81}
82
83impl Deserialize for Empty {
84 fn deserialize(_buf: &[u8]) -> Self {
85 Empty {}
86 }
87}
88
89type Rocks = DBWithThreadMode<MultiThreaded>;
90
91pub struct RocksDBIterator<'a, K, V>
92 where K: Serialize + Deserialize, V: Serialize + Deserialize,
93{
94 base: DBIteratorWithThreadMode<'a, Rocks>,
95 _k: PhantomData<fn() -> K>,
96 _v: PhantomData<fn() -> V>,
97}
98
99impl<'a, K, V> RocksDBIterator<'a, K, V>
100 where K: Serialize + Deserialize, V: Serialize + Deserialize,
101{
102 pub fn new(base: DBIteratorWithThreadMode<'a, Rocks>) -> Self {
103 Self {
104 base,
105 _k: PhantomData,
106 _v: PhantomData,
107 }
108 }
109}
110
111impl<'a, K, V> Iterator for RocksDBIterator<'a, K, V>
112 where K: Serialize + Deserialize, V: Serialize + Deserialize,
113{
114 type Item = (K, V);
115 fn next(&mut self) -> Option<Self::Item> {
116 self.base.next().map(|(key, value)| (K::deserialize(&key), V::deserialize(&value)))
117 }
118}
119
120pub struct RocksDBPrefixIterator<'a, K, V>
121 where K: Serialize + Deserialize,
122 V: Serialize + Deserialize,
123{
124 base: DBIteratorWithThreadMode<'a, Rocks>,
125 prefix: Vec<u8>,
126 _k: PhantomData<fn() -> K>,
127 _v: PhantomData<fn() -> V>,
128}
129
130impl<'a, K, V> RocksDBPrefixIterator<'a, K, V>
131 where K: Serialize + Deserialize,
132 V: Serialize + Deserialize,
133{
134 pub fn new(base: DBIteratorWithThreadMode<'a, Rocks>, prefix: Vec<u8>) -> Self
135 {
136 Self {
137 base,
138 prefix,
139 _k: PhantomData,
140 _v: PhantomData,
141 }
142 }
143}
144
145impl<'a, K, V> Iterator for RocksDBPrefixIterator<'a, K, V>
146 where K: Serialize + Deserialize,
147 V: Serialize + Deserialize,
148{
149 type Item = (K, V);
150 fn next(&mut self) -> Option<Self::Item> {
151 match self.base.next() {
152 Some((key, value)) => {
153 if self.prefix != key[0..self.prefix.len()] {
154 None
155 } else {
156 Some((K::deserialize(&key), V::deserialize(&value)))
157 }
158 },
159 None => None,
160 }
161 }
162}
163
164pub struct RocksDBColumnFamily<'a, K, V>
165 where K: Serialize + Deserialize + 'static,
166 V: Serialize + Deserialize + 'static,
167{
168 base: &'a RocksDB<Empty, Empty>,
169 name: String,
170 cf: BoundColumnFamily<'a>,
171 _k: PhantomData<fn() -> K>,
172 _v: PhantomData<fn() -> V>,
173}
174
175impl<'a, K, V> RocksDBColumnFamily<'a, K, V>
176 where K: Serialize + Deserialize + 'static,
177 V: Serialize + Deserialize + 'static,
178{
179 pub fn new(base: &'a RocksDB<Empty, Empty>, name: &str) -> Self {
180 let cf = match base.db.cf_handle(name) {
181 Some(cf) => cf,
182 None => {
183 let mut opts = Options::default();
184 opts.set_max_open_files(100);
185 opts.create_if_missing(true);
186 base.db.create_cf(name, &opts).unwrap();
187 base.db.cf_handle(name).unwrap()
188 },
189 };
190 Self {
191 base,
192 name: name.to_string(),
193 cf,
194 _k: PhantomData,
195 _v: PhantomData,
196 }
197 }
198 pub fn name(&self) -> &str {
199 self.name.as_str()
200 }
201 pub fn get(&self, key: &K) -> Option<V> {
202 self.base.db.get_pinned_cf(self.cf, key.serialize()).unwrap().map(|value| V::deserialize(&value))
203 }
204 pub fn put(&self, key: &K, value: &V) {
205 self.base.db.put_cf(self.cf, key.serialize(), value.serialize()).unwrap();
206 }
207 pub fn delete(&self, key: &K) {
208 self.base.db.delete_cf(self.cf, key.serialize()).unwrap();
209 }
210 pub fn iter(&self) -> RocksDBIterator<'_, K, V> {
211 RocksDBIterator::new(self.base.db.iterator_cf(self.cf, rocksdb::IteratorMode::Start))
212 }
213 pub fn prefix_iter(&self, prefix: Vec<u8>) -> RocksDBPrefixIterator<'_, K, V> {
214 RocksDBPrefixIterator::new(self.base.db.prefix_iterator_cf(self.cf, prefix.clone()), prefix)
215 }
216}
217
218#[derive(Debug)]
219pub struct RocksDB<K, V>
220 where K: Serialize + Deserialize + 'static,
221 V: Serialize + Deserialize + 'static,
222{
223 temporary: bool,
224 path: String,
225 db: Rocks,
226 _k: PhantomData<fn() -> K>,
227 _v: PhantomData<fn() -> V>,
228}
229
230impl<K, V> RocksDB<K, V>
231 where K: Serialize + Deserialize + 'static,
232 V: Serialize + Deserialize + 'static,
233{
234 pub fn new(path: &str, temporary: bool) -> Self {
235 if temporary && std::path::Path::new(path).exists() {
236 remove_dir_all(path).unwrap();
237 }
238 let mut opts = Options::default();
239 opts.set_max_open_files(100);
240 opts.create_if_missing(true);
241 let db = Rocks::open(&opts, path).expect("Failed to open the database.");
242 Self {
243 temporary,
244 path: path.to_string(),
245 db,
246 _k: PhantomData,
247 _v: PhantomData,
248 }
249 }
250 pub fn get(&self, key: &K) -> Option<V> {
251 self.db.get_pinned(key.serialize()).unwrap().map(|value| V::deserialize(&value))
252 }
253 pub fn get_raw(&self, key: &K) -> Option<DBPinnableSlice<'_>> {
254 self.db.get_pinned(key.serialize()).unwrap()
255 }
256 pub fn multi_get<I: IntoIterator<Item = K>>(&self, keys: I) -> Vec<Option<V>> {
257 let keys: Vec<Vec<u8>> = keys.into_iter().map(|key| key.serialize()).collect();
258 self.db.multi_get(keys).unwrap().iter().map(|value| {
259 if value.is_empty() {
260 None
261 } else {
262 Some(V::deserialize(value))
263 }
264 }).collect()
265 }
266 pub fn put(&self, key: &K, value: &V) {
267 self.db.put(key.serialize(), value.serialize()).unwrap();
268 }
269 pub fn delete(&self, key: &K) {
270 self.db.delete(key.serialize()).unwrap();
271 }
272 pub fn iter(&self) -> RocksDBIterator<'_, K, V> {
273 RocksDBIterator::new(self.db.iterator(rocksdb::IteratorMode::Start))
274 }
275 pub fn prefix_iter(&self, prefix: Vec<u8>) -> RocksDBPrefixIterator<'_, K, V> {
276 RocksDBPrefixIterator::new(self.db.prefix_iterator(prefix.clone()), prefix)
277 }
278 pub fn purge(&self) {
279 remove_dir_all(&self.path).unwrap();
280 }
281}
282
283impl<K, V> Drop for RocksDB<K, V>
284 where K: Serialize + Deserialize + 'static,
285 V: Serialize + Deserialize + 'static,
286{
287 fn drop(&mut self) {
288 if self.temporary {
289 self.purge();
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rocks_db() {
        let db = RocksDB::<String, Vec<u32>>::new("/tmp/chainseeker/test_rocks_db", true);
        let (bar, bar_value) = ("bar".to_string(), vec![3939, 4649]);
        let (foo, foo_value) = ("foo".to_string(), vec![1234, 5678]);
        db.put(&bar, &bar_value);
        db.put(&foo, &foo_value);

        // Point lookups.
        assert_eq!(db.get(&bar), Some(bar_value.clone()));
        assert_eq!(db.get(&foo), Some(foo_value.clone()));

        // A full scan yields both entries in key order.
        let everything: Vec<(String, Vec<u32>)> = db.iter().collect();
        assert_eq!(
            everything,
            vec![(bar.clone(), bar_value.clone()), (foo.clone(), foo_value.clone())]
        );

        // A prefix scan for "bar" stops before "foo".
        let matching: Vec<(String, Vec<u32>)> =
            db.prefix_iter(bar.as_bytes().to_vec()).collect();
        assert_eq!(matching, vec![(bar.clone(), bar_value)]);

        // Deletion is visible to both single and batched lookups.
        db.delete(&bar);
        assert_eq!(db.get(&bar), None);
        assert_eq!(db.multi_get(vec![bar, foo]), vec![None, Some(foo_value)]);
    }

    #[test]
    fn rocks_db_cf() {
        let db = RocksDB::<Empty, Empty>::new("/tmp/chainseeker/test_rocks_db_cf", true);
        let first = RocksDBColumnFamily::<u32, u32>::new(&db, "cf1");
        let second = RocksDBColumnFamily::<u32, u32>::new(&db, "cf2");
        let key = 114514;
        first.put(&key, &12345);
        second.put(&key, &67890);
        // The same key holds independent values in each column family.
        assert_eq!(first.get(&key), Some(12345));
        assert_eq!(second.get(&key), Some(67890));
    }
}