chainseeker_server/
rocks_db.rs

//! An abstraction struct for a key-value store.
2use std::fs::remove_dir_all;
3use std::marker::PhantomData;
4use rocksdb::{DBWithThreadMode, MultiThreaded, DBIteratorWithThreadMode, BoundColumnFamily, Options, DBPinnableSlice};
5
/// Marks a type whose serialized form always occupies the same number of bytes.
pub trait ConstantSize {
    /// Length, in bytes, of one serialized value of the implementing type.
    const LEN: usize;
}
9
10impl ConstantSize for u32 {
11    const LEN: usize = 4;
12}
13
/// Conversion of a value into the raw bytes that are written to the store.
pub trait Serialize {
    /// Returns the byte representation of `self` in a freshly allocated buffer.
    fn serialize(&self) -> Vec<u8>;
}
17
18impl Serialize for String {
19    fn serialize(&self) -> Vec<u8> {
20        self.as_bytes().to_vec()
21    }
22}
23
24impl Serialize for u32 {
25    fn serialize(&self) -> Vec<u8> {
26        self.to_le_bytes().to_vec()
27    }
28}
29
30impl<S> Serialize for Vec<S>
31    where S: Serialize,
32{
33    fn serialize(&self) -> Vec<u8> {
34        self.iter().map(|item| item.serialize()).collect::<Vec<Vec<u8>>>().concat()
35    }
36}
37
/// Reconstruction of a value from the raw bytes read back from the store.
pub trait Deserialize {
    /// Builds a value of the implementing type from `buf`.
    fn deserialize(buf: &[u8]) -> Self;
}
41
42impl Deserialize for String {
43    fn deserialize(buf: &[u8]) -> Self {
44        Self::from_utf8(buf.to_vec()).unwrap()
45    }
46}
47
48impl Deserialize for u32 {
49    fn deserialize(buf: &[u8]) -> Self {
50        assert_eq!(buf.len(), 4);
51        let buf: [u8; 4] = [buf[0], buf[1], buf[2], buf[3]];
52        u32::from_le_bytes(buf)
53    }
54}
55
56impl<D> Deserialize for Vec<D>
57    where D: Deserialize + ConstantSize,
58{
59    fn deserialize(buf: &[u8]) -> Self {
60        let mut offset = 0usize;
61        let mut ret = Vec::new();
62        while offset < buf.len() {
63            ret.push(D::deserialize(&buf[offset..offset+D::LEN]));
64            offset += D::LEN;
65        }
66        ret
67    }
68}
69
/// Field-less placeholder for a key or value type that carries no data.
#[derive(Debug, Clone, Default)]
pub struct Empty {}
72
73impl ConstantSize for Empty {
74    const LEN: usize = 0;
75}
76
77impl Serialize for Empty {
78    fn serialize(&self) -> Vec<u8> {
79        Vec::new()
80    }
81}
82
83impl Deserialize for Empty {
84    fn deserialize(_buf: &[u8]) -> Self {
85        Empty {}
86    }
87}
88
89type Rocks = DBWithThreadMode<MultiThreaded>;
90
91pub struct RocksDBIterator<'a, K, V>
92    where K: Serialize + Deserialize, V: Serialize + Deserialize,
93{
94    base: DBIteratorWithThreadMode<'a, Rocks>,
95    _k: PhantomData<fn() -> K>,
96    _v: PhantomData<fn() -> V>,
97}
98
99impl<'a, K, V> RocksDBIterator<'a, K, V>
100    where K: Serialize + Deserialize, V: Serialize + Deserialize,
101{
102    pub fn new(base: DBIteratorWithThreadMode<'a, Rocks>) -> Self {
103        Self {
104            base,
105            _k: PhantomData,
106            _v: PhantomData,
107        }
108    }
109}
110
111impl<'a, K, V> Iterator for RocksDBIterator<'a, K, V>
112    where K: Serialize + Deserialize, V: Serialize + Deserialize,
113{
114    type Item = (K, V);
115    fn next(&mut self) -> Option<Self::Item> {
116        self.base.next().map(|(key, value)| (K::deserialize(&key), V::deserialize(&value)))
117    }
118}
119
120pub struct RocksDBPrefixIterator<'a, K, V>
121    where K: Serialize + Deserialize,
122          V: Serialize + Deserialize,
123{
124    base: DBIteratorWithThreadMode<'a, Rocks>,
125    prefix: Vec<u8>,
126    _k: PhantomData<fn() -> K>,
127    _v: PhantomData<fn() -> V>,
128}
129
130impl<'a, K, V> RocksDBPrefixIterator<'a, K, V>
131    where K: Serialize + Deserialize,
132          V: Serialize + Deserialize,
133{
134    pub fn new(base: DBIteratorWithThreadMode<'a, Rocks>, prefix: Vec<u8>) -> Self
135    {
136        Self {
137            base,
138            prefix,
139            _k: PhantomData,
140            _v: PhantomData,
141        }
142    }
143}
144
145impl<'a, K, V> Iterator for RocksDBPrefixIterator<'a, K, V>
146    where K: Serialize + Deserialize,
147          V: Serialize + Deserialize,
148{
149    type Item = (K, V);
150    fn next(&mut self) -> Option<Self::Item> {
151        match self.base.next() {
152            Some((key, value)) => {
153                if self.prefix != key[0..self.prefix.len()] {
154                    None
155                } else {
156                    Some((K::deserialize(&key), V::deserialize(&value)))
157                }
158            },
159            None => None,
160        }
161    }
162}
163
164pub struct RocksDBColumnFamily<'a, K, V>
165    where K: Serialize + Deserialize + 'static,
166          V: Serialize + Deserialize + 'static,
167{
168    base: &'a RocksDB<Empty, Empty>,
169    name: String,
170    cf: BoundColumnFamily<'a>,
171    _k: PhantomData<fn() -> K>,
172    _v: PhantomData<fn() -> V>,
173}
174
175impl<'a, K, V> RocksDBColumnFamily<'a, K, V>
176    where K: Serialize + Deserialize + 'static,
177          V: Serialize + Deserialize + 'static,
178{
179    pub fn new(base: &'a RocksDB<Empty, Empty>, name: &str) -> Self {
180        let cf = match base.db.cf_handle(name) {
181            Some(cf) => cf,
182            None => {
183                let mut opts = Options::default();
184                opts.set_max_open_files(100);
185                opts.create_if_missing(true);
186                base.db.create_cf(name, &opts).unwrap();
187                base.db.cf_handle(name).unwrap()
188            },
189        };
190        Self {
191            base,
192            name: name.to_string(),
193            cf,
194            _k: PhantomData,
195            _v: PhantomData,
196        }
197    }
198    pub fn name(&self) -> &str {
199        self.name.as_str()
200    }
201    pub fn get(&self, key: &K) -> Option<V> {
202        self.base.db.get_pinned_cf(self.cf, key.serialize()).unwrap().map(|value| V::deserialize(&value))
203    }
204    pub fn put(&self, key: &K, value: &V) {
205        self.base.db.put_cf(self.cf, key.serialize(), value.serialize()).unwrap();
206    }
207    pub fn delete(&self, key: &K) {
208        self.base.db.delete_cf(self.cf, key.serialize()).unwrap();
209    }
210    pub fn iter(&self) -> RocksDBIterator<'_, K, V> {
211        RocksDBIterator::new(self.base.db.iterator_cf(self.cf, rocksdb::IteratorMode::Start))
212    }
213    pub fn prefix_iter(&self, prefix: Vec<u8>) -> RocksDBPrefixIterator<'_, K, V> {
214        RocksDBPrefixIterator::new(self.base.db.prefix_iterator_cf(self.cf, prefix.clone()), prefix)
215    }
216}
217
218#[derive(Debug)]
219pub struct RocksDB<K, V>
220    where K: Serialize + Deserialize + 'static,
221          V: Serialize + Deserialize + 'static,
222{
223    temporary: bool,
224    path: String,
225    db: Rocks,
226    _k: PhantomData<fn() -> K>,
227    _v: PhantomData<fn() -> V>,
228}
229
230impl<K, V> RocksDB<K, V>
231    where K: Serialize + Deserialize + 'static,
232          V: Serialize + Deserialize + 'static,
233{
234    pub fn new(path: &str, temporary: bool) -> Self {
235        if temporary && std::path::Path::new(path).exists() {
236            remove_dir_all(path).unwrap();
237        }
238        let mut opts = Options::default();
239        opts.set_max_open_files(100);
240        opts.create_if_missing(true);
241        let db = Rocks::open(&opts, path).expect("Failed to open the database.");
242        Self {
243            temporary,
244            path: path.to_string(),
245            db,
246            _k: PhantomData,
247            _v: PhantomData,
248        }
249    }
250    pub fn get(&self, key: &K) -> Option<V> {
251        self.db.get_pinned(key.serialize()).unwrap().map(|value| V::deserialize(&value))
252    }
253    pub fn get_raw(&self, key: &K) -> Option<DBPinnableSlice<'_>> {
254        self.db.get_pinned(key.serialize()).unwrap()
255    }
256    pub fn multi_get<I: IntoIterator<Item = K>>(&self, keys: I) -> Vec<Option<V>> {
257        let keys: Vec<Vec<u8>> = keys.into_iter().map(|key| key.serialize()).collect();
258        self.db.multi_get(keys).unwrap().iter().map(|value| {
259            if value.is_empty() {
260                None
261            } else {
262                Some(V::deserialize(value))
263            }
264        }).collect()
265    }
266    pub fn put(&self, key: &K, value: &V) {
267        self.db.put(key.serialize(), value.serialize()).unwrap();
268    }
269    pub fn delete(&self, key: &K) {
270        self.db.delete(key.serialize()).unwrap();
271    }
272    pub fn iter(&self) -> RocksDBIterator<'_, K, V> {
273        RocksDBIterator::new(self.db.iterator(rocksdb::IteratorMode::Start))
274    }
275    pub fn prefix_iter(&self, prefix: Vec<u8>) -> RocksDBPrefixIterator<'_, K, V> {
276        RocksDBPrefixIterator::new(self.db.prefix_iterator(prefix.clone()), prefix)
277    }
278    pub fn purge(&self) {
279        remove_dir_all(&self.path).unwrap();
280    }
281}
282
283impl<K, V> Drop for RocksDB<K, V>
284    where K: Serialize + Deserialize + 'static,
285          V: Serialize + Deserialize + 'static,
286{
287    fn drop(&mut self) {
288        if self.temporary {
289            self.purge();
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises put/get/iter/prefix_iter/delete/multi_get on a temporary database.
    #[test]
    fn rocks_db() {
        let store = RocksDB::<String, Vec<u32>>::new("/tmp/chainseeker/test_rocks_db", true);

        let bar_key = "bar".to_string();
        let bar_value = vec![3939, 4649];
        let foo_key = "foo".to_string();
        let foo_value = vec![1234, 5678];

        store.put(&bar_key, &bar_value);
        store.put(&foo_key, &foo_value);
        assert_eq!(store.get(&bar_key), Some(bar_value.clone()));
        assert_eq!(store.get(&foo_key), Some(foo_value.clone()));

        // Full iteration returns both entries in key order.
        let everything: Vec<(String, Vec<u32>)> = store.iter().collect();
        assert_eq!(
            everything,
            vec![
                (bar_key.clone(), bar_value.clone()),
                (foo_key.clone(), foo_value.clone()),
            ]
        );

        // Prefix iteration with the whole first key returns only that entry.
        let under_bar: Vec<(String, Vec<u32>)> =
            store.prefix_iter(bar_key.as_bytes().to_vec()).collect();
        assert_eq!(under_bar, vec![(bar_key.clone(), bar_value)]);

        store.delete(&bar_key);
        assert_eq!(store.get(&bar_key), None);
        assert_eq!(
            store.multi_get(vec![bar_key, foo_key]),
            vec![None, Some(foo_value)]
        );
    }

    /// Checks that two column families keep separate values for the same key.
    #[test]
    fn rocks_db_cf() {
        let store = RocksDB::<Empty, Empty>::new("/tmp/chainseeker/test_rocks_db_cf", true);
        let first = RocksDBColumnFamily::<u32, u32>::new(&store, "cf1");
        let second = RocksDBColumnFamily::<u32, u32>::new(&store, "cf2");

        first.put(&114514, &12345);
        second.put(&114514, &67890);

        assert_eq!(first.get(&114514), Some(12345));
        assert_eq!(second.get(&114514), Some(67890));
    }
}