cita_database/
memorydb.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use crate::database::{DataCategory, Database, Result};
5use crate::error::DatabaseError;
6use rocksdb::DBIterator;
7
/// In-memory implementation of the crate's `Database` trait, intended for tests.
///
/// All data lives in a single `HashMap` guarded by an `RwLock`. Per-category
/// namespacing is done by the key-prefixing helper, not by separate maps.
/// Nothing is persisted: the contents are gone once the last handle is dropped.
#[derive(Debug, Default)]
pub struct MemoryDB {
    storage: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
}

impl MemoryDB {
    /// Creates a new, empty database.
    ///
    /// Equivalent to [`MemoryDB::default`]; kept so callers can use the same
    /// `open()` entry point as other backends.
    pub fn open() -> Self {
        Self::default()
    }
}
28
29impl Database for MemoryDB {
30    fn get(&self, category: Option<DataCategory>, key: &[u8]) -> Result<Option<Vec<u8>>> {
31        let storage = Arc::clone(&self.storage);
32        let key = gen_key(&category, key.to_vec());
33
34        let storage = storage.read().map_err(|_| map_rwlock_err())?;
35        let v = storage.get(&key).map(|v| v.to_vec());
36        Ok(v)
37    }
38
39    fn get_batch(
40        &self,
41        category: Option<DataCategory>,
42        keys: &[Vec<u8>],
43    ) -> Result<Vec<Option<Vec<u8>>>> {
44        let storage = Arc::clone(&self.storage);
45        let keys = gen_keys(&category, keys.to_vec());
46
47        let storage = storage.read().map_err(|_| map_rwlock_err())?;
48        let values = keys
49            .into_iter()
50            .map(|key| storage.get(&key.to_vec()).map(|v| v.to_vec()))
51            .collect();
52
53        Ok(values)
54    }
55
56    fn insert(&self, category: Option<DataCategory>, key: Vec<u8>, value: Vec<u8>) -> Result<()> {
57        let storage = Arc::clone(&self.storage);
58        let key = gen_key(&category, key);
59        let value = value.to_vec();
60
61        let mut storage = storage.write().map_err(|_| map_rwlock_err())?;
62        storage.insert(key, value);
63        Ok(())
64    }
65
66    fn insert_batch(
67        &self,
68        category: Option<DataCategory>,
69        keys: Vec<Vec<u8>>,
70        values: Vec<Vec<u8>>,
71    ) -> Result<()> {
72        let storage = Arc::clone(&self.storage);
73        let keys = gen_keys(&category, keys);
74        let values = values.to_vec();
75
76        if keys.len() != values.len() {
77            return Err(DatabaseError::InvalidData);
78        }
79
80        let mut storage = storage.write().map_err(|_| map_rwlock_err())?;
81        for i in 0..keys.len() {
82            let key = keys[i].to_vec();
83            let value = values[i].to_vec();
84
85            storage.insert(key, value);
86        }
87
88        Ok(())
89    }
90
91    fn contains(&self, category: Option<DataCategory>, key: &[u8]) -> Result<bool> {
92        let storage = Arc::clone(&self.storage);
93        let key = gen_key(&category, key.to_vec());
94
95        let storage = storage.read().map_err(|_| map_rwlock_err())?;
96        Ok(storage.contains_key(&key))
97    }
98
99    fn remove(&self, category: Option<DataCategory>, key: &[u8]) -> Result<()> {
100        let storage = Arc::clone(&self.storage);
101        let key = gen_key(&category, key.to_vec());
102
103        let mut storage = storage.write().map_err(|_| map_rwlock_err())?;
104        storage.remove(&key);
105        Ok(())
106    }
107
108    fn remove_batch(&self, category: Option<DataCategory>, keys: &[Vec<u8>]) -> Result<()> {
109        let storage = Arc::clone(&self.storage);
110        let keys = gen_keys(&category, keys.to_vec());
111
112        let mut storage = storage.write().map_err(|_| map_rwlock_err())?;
113        for key in keys {
114            storage.remove(&key);
115        }
116        Ok(())
117    }
118
119    fn restore(&mut self, _new_db: &str) -> Result<()> {
120        unimplemented!()
121    }
122
123    fn iterator(&self, _category: Option<DataCategory>) -> Option<DBIterator> {
124        unimplemented!()
125    }
126
127    fn close(&mut self) {
128        unimplemented!();
129    }
130
131    fn flush(&self) -> Result<()> {
132        unimplemented!();
133    }
134}
135
136fn gen_key(category: &Option<DataCategory>, key: Vec<u8>) -> Vec<u8> {
137    match category {
138        Some(category) => match category {
139            DataCategory::State => [b"state-".to_vec(), key].concat(),
140            DataCategory::Headers => [b"headers-".to_vec(), key].concat(),
141            DataCategory::Bodies => [b"bodies-".to_vec(), key].concat(),
142            DataCategory::Extra => [b"extra-".to_vec(), key].concat(),
143            DataCategory::Trace => [b"trace-".to_vec(), key].concat(),
144            DataCategory::AccountBloom => [b"account-bloom-".to_vec(), key].concat(),
145            DataCategory::Other => [b"other-".to_vec(), key].concat(),
146        },
147        None => key,
148    }
149}
150
151fn gen_keys(category: &Option<DataCategory>, keys: Vec<Vec<u8>>) -> Vec<Vec<u8>> {
152    keys.into_iter().map(|key| gen_key(category, key)).collect()
153}
154
155fn map_rwlock_err() -> DatabaseError {
156    DatabaseError::Internal("rwlock error".to_string())
157}
158
#[cfg(test)]
mod tests {
    use super::MemoryDB;
    use crate::database::{DataCategory, Database};
    use crate::error::DatabaseError;
    use crate::test::{batch_op, insert_get_contains_remove};

    #[test]
    fn test_insert_get_contains_remove_with_category() {
        let db = MemoryDB::open();

        insert_get_contains_remove(&db, Some(DataCategory::State));
    }

    #[test]
    fn test_insert_get_contains_remove() {
        let db = MemoryDB::open();

        insert_get_contains_remove(&db, None);
    }

    #[test]
    fn test_batch_op_with_category() {
        let db = MemoryDB::open();

        batch_op(&db, Some(DataCategory::State));
    }

    #[test]
    fn test_batch_op() {
        let db = MemoryDB::open();

        // Run both namespaces against the same instance to check they coexist.
        batch_op(&db, None);
        batch_op(&db, Some(DataCategory::State));
    }

    #[test]
    fn test_insert_batch_error() {
        let db = MemoryDB::open();

        let data = b"test".to_vec();

        match db.insert_batch(None, vec![data], vec![]) {
            Err(DatabaseError::InvalidData) => (), // pass
            _ => panic!("should return error DatabaseError::InvalidData"),
        }
    }

    #[test]
    fn test_insert_batch_error_writes_nothing() {
        let db = MemoryDB::open();
        let first = b"first".to_vec();
        let second = b"second".to_vec();

        // A length mismatch must be rejected before any pair is written.
        assert!(db
            .insert_batch(None, vec![first.clone(), second.clone()], vec![b"v".to_vec()])
            .is_err());
        assert_eq!(db.contains(None, &first).ok(), Some(false));
        assert_eq!(db.contains(None, &second).ok(), Some(false));
    }

    #[test]
    fn test_categories_are_isolated() {
        let db = MemoryDB::open();
        let key = b"key".to_vec();
        let value = b"value".to_vec();

        assert!(db
            .insert(Some(DataCategory::State), key.clone(), value.clone())
            .is_ok());

        assert_eq!(
            db.get(Some(DataCategory::State), &key).ok(),
            Some(Some(value))
        );
        assert_eq!(db.contains(Some(DataCategory::Headers), &key).ok(), Some(false));
        assert_eq!(db.contains(None, &key).ok(), Some(false));
    }

    #[test]
    fn test_get_batch_reports_missing_keys() {
        let db = MemoryDB::open();
        let present = b"present".to_vec();
        let missing = b"missing".to_vec();
        let value = b"value".to_vec();

        assert!(db.insert(None, present.clone(), value.clone()).is_ok());

        let values = db.get_batch(None, &[present, missing]).ok();
        assert_eq!(values, Some(vec![Some(value), None]));
    }
}