1use crate::error::{AmateRSError, ErrorContext, Result};
7use crate::types::{CipherBlob, Key};
8use parking_lot::RwLock;
9use std::collections::BTreeMap;
10use std::sync::Arc;
11
/// Settings that control when a [`Memtable`] asks to be flushed.
#[derive(Debug, Clone)]
pub struct MemtableConfig {
    /// Estimated-size threshold (bytes) at which `Memtable::should_flush`
    /// starts returning `true`.
    pub max_size_bytes: usize,
    /// Write-ahead-log toggle.
    ///
    /// NOTE(review): `Memtable` in this module never reads this flag;
    /// presumably it is consumed by whatever owns the WAL — confirm.
    pub enable_wal: bool,
}
20
21impl Default for MemtableConfig {
22 fn default() -> Self {
23 Self {
24 max_size_bytes: 64 * 1024 * 1024, enable_wal: true,
26 }
27 }
28}
29
30#[derive(Debug, Clone)]
32enum MemtableEntry {
33 Value(CipherBlob),
35 Tombstone,
37}
38
39pub struct Memtable {
44 data: Arc<RwLock<BTreeMap<Key, MemtableEntry>>>,
46 size_bytes: Arc<RwLock<usize>>,
48 config: MemtableConfig,
50 sequence: Arc<RwLock<u64>>,
52}
53
54impl Memtable {
55 pub fn new() -> Self {
57 Self::with_config(MemtableConfig::default())
58 }
59
60 pub fn with_config(config: MemtableConfig) -> Self {
62 Self {
63 data: Arc::new(RwLock::new(BTreeMap::new())),
64 size_bytes: Arc::new(RwLock::new(0)),
65 config,
66 sequence: Arc::new(RwLock::new(0)),
67 }
68 }
69
70 pub fn put(&self, key: Key, value: CipherBlob) -> Result<()> {
72 let entry_size = Self::estimate_entry_size(&key, &value);
73
74 let mut data = self.data.write();
75 let mut size = self.size_bytes.write();
76
77 if let Some(old_entry) = data.get(&key) {
79 let old_size = match old_entry {
80 MemtableEntry::Value(v) => Self::estimate_entry_size(&key, v),
81 MemtableEntry::Tombstone => key.len() + 1,
82 };
83 *size = size.saturating_sub(old_size);
84 }
85
86 data.insert(key, MemtableEntry::Value(value));
87 *size += entry_size;
88
89 let mut seq = self.sequence.write();
91 *seq += 1;
92
93 Ok(())
94 }
95
96 pub fn get(&self, key: &Key) -> Result<Option<CipherBlob>> {
98 let data = self.data.read();
99
100 match data.get(key) {
101 Some(MemtableEntry::Value(v)) => Ok(Some(v.clone())),
102 Some(MemtableEntry::Tombstone) => Ok(None),
103 None => Ok(None),
104 }
105 }
106
107 pub fn delete(&self, key: Key) -> Result<()> {
109 let mut data = self.data.write();
110 let mut size = self.size_bytes.write();
111
112 if let Some(old_entry) = data.get(&key) {
114 let old_size = match old_entry {
115 MemtableEntry::Value(v) => Self::estimate_entry_size(&key, v),
116 MemtableEntry::Tombstone => key.len() + 1,
117 };
118 *size = size.saturating_sub(old_size);
119 }
120
121 let tombstone_size = key.len() + 1;
122 data.insert(key, MemtableEntry::Tombstone);
123 *size += tombstone_size;
124
125 let mut seq = self.sequence.write();
127 *seq += 1;
128
129 Ok(())
130 }
131
132 pub fn should_flush(&self) -> bool {
134 let size = *self.size_bytes.read();
135 size >= self.config.max_size_bytes
136 }
137
138 pub fn size_bytes(&self) -> usize {
140 *self.size_bytes.read()
141 }
142
143 pub fn len(&self) -> usize {
145 self.data.read().len()
146 }
147
148 pub fn is_empty(&self) -> bool {
150 self.data.read().is_empty()
151 }
152
153 pub fn sequence(&self) -> u64 {
155 *self.sequence.read()
156 }
157
158 pub fn entries(&self) -> Vec<(Key, Option<CipherBlob>)> {
162 let data = self.data.read();
163 data.iter()
164 .map(|(k, v)| {
165 let value = match v {
166 MemtableEntry::Value(blob) => Some(blob.clone()),
167 MemtableEntry::Tombstone => None,
168 };
169 (k.clone(), value)
170 })
171 .collect()
172 }
173
174 pub fn range(&self, start: &Key, end: &Key) -> Vec<(Key, CipherBlob)> {
176 let data = self.data.read();
177 data.range(start..end)
178 .filter_map(|(k, v)| match v {
179 MemtableEntry::Value(blob) => Some((k.clone(), blob.clone())),
180 MemtableEntry::Tombstone => None,
181 })
182 .collect()
183 }
184
185 #[cfg(test)]
187 pub fn clear(&self) {
188 let mut data = self.data.write();
189 let mut size = self.size_bytes.write();
190 data.clear();
191 *size = 0;
192 }
193
194 fn estimate_entry_size(key: &Key, value: &CipherBlob) -> usize {
196 key.len() + value.len() + 64
198 }
199}
200
201impl Default for Memtable {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// put/get/delete round trip, including the retained tombstone.
    #[test]
    fn test_memtable_basic_operations() -> Result<()> {
        let memtable = Memtable::new();
        assert!(memtable.is_empty());

        let key = Key::from_str("test_key");
        let value = CipherBlob::new(vec![1, 2, 3, 4, 5]);

        memtable.put(key.clone(), value.clone())?;
        assert_eq!(memtable.len(), 1);
        assert_eq!(memtable.get(&key)?, Some(value));
        assert_eq!(memtable.get(&Key::from_str("missing"))?, None);

        memtable.delete(key.clone())?;
        assert_eq!(memtable.get(&key)?, None);
        // The tombstone is kept as an entry rather than removing the key.
        assert_eq!(memtable.len(), 1);
        assert_eq!(memtable.entries(), vec![(key, None)]);

        Ok(())
    }

    /// Size grows on insert, is not double-counted on overwrite, and
    /// shrinks when the value is replaced by a tombstone.
    #[test]
    fn test_memtable_size_tracking() -> Result<()> {
        let memtable = Memtable::new();
        assert_eq!(memtable.size_bytes(), 0);

        let key = Key::from_str("key");
        memtable.put(key.clone(), CipherBlob::new(vec![0u8; 1000]))?;
        let after_put = memtable.size_bytes();
        assert!(after_put > 1000);

        memtable.put(key.clone(), CipherBlob::new(vec![1u8; 1000]))?;
        assert_eq!(memtable.size_bytes(), after_put);

        memtable.delete(key)?;
        assert!(memtable.size_bytes() < after_put);

        Ok(())
    }

    /// `entries` yields keys in ascending order regardless of insert order.
    #[test]
    fn test_memtable_ordering() -> Result<()> {
        let memtable = Memtable::new();

        memtable.put(Key::from_str("key3"), CipherBlob::new(vec![3]))?;
        memtable.put(Key::from_str("key1"), CipherBlob::new(vec![1]))?;
        memtable.put(Key::from_str("key2"), CipherBlob::new(vec![2]))?;

        let entries = memtable.entries();
        assert_eq!(entries.len(), 3);

        assert_eq!(entries[0].0, Key::from_str("key1"));
        assert_eq!(entries[1].0, Key::from_str("key2"));
        assert_eq!(entries[2].0, Key::from_str("key3"));
        assert_eq!(entries[0].1, Some(CipherBlob::new(vec![1])));

        Ok(())
    }

    /// `range` is half-open `[start, end)` and skips tombstones.
    #[test]
    fn test_memtable_range() -> Result<()> {
        let memtable = Memtable::new();

        for i in 0..10 {
            let key = Key::from_str(&format!("key_{:02}", i));
            let value = CipherBlob::new(vec![i as u8]);
            memtable.put(key, value)?;
        }

        let start = Key::from_str("key_03");
        let end = Key::from_str("key_07");

        let keys: Vec<Key> = memtable
            .range(&start, &end)
            .into_iter()
            .map(|(k, _)| k)
            .collect();
        let expected: Vec<Key> = (3..7)
            .map(|i| Key::from_str(&format!("key_{:02}", i)))
            .collect();
        assert_eq!(keys, expected);

        memtable.delete(Key::from_str("key_04"))?;
        assert_eq!(memtable.range(&start, &end).len(), 3);

        Ok(())
    }

    /// `should_flush` flips once the configured threshold is reached.
    #[test]
    fn test_memtable_flush_threshold() -> Result<()> {
        let config = MemtableConfig {
            max_size_bytes: 1000,
            enable_wal: false,
        };
        let memtable = Memtable::with_config(config);

        assert!(!memtable.should_flush());

        for i in 0..100 {
            let key = Key::from_str(&format!("key_{}", i));
            let value = CipherBlob::new(vec![0u8; 100]);
            memtable.put(key, value)?;

            if memtable.should_flush() {
                break;
            }
        }

        assert!(memtable.should_flush());

        Ok(())
    }

    /// Overwriting a key adjusts the size by exactly the value-length delta.
    #[test]
    fn test_memtable_update() -> Result<()> {
        let memtable = Memtable::new();

        let key = Key::from_str("key");
        let value1 = CipherBlob::new(vec![1, 2, 3]);
        let value2 = CipherBlob::new(vec![4, 5, 6, 7, 8]);

        memtable.put(key.clone(), value1)?;
        let size1 = memtable.size_bytes();

        memtable.put(key.clone(), value2.clone())?;
        let size2 = memtable.size_bytes();

        // value2 is two bytes longer than value1; nothing else changed.
        assert_eq!(size2, size1 + 2);
        assert_eq!(memtable.len(), 1);

        assert_eq!(memtable.get(&key)?, Some(value2));

        Ok(())
    }

    /// Every put/delete bumps the sequence, even deletes of unknown keys.
    #[test]
    fn test_memtable_sequence() -> Result<()> {
        let memtable = Memtable::new();

        assert_eq!(memtable.sequence(), 0);

        memtable.put(Key::from_str("key1"), CipherBlob::new(vec![1]))?;
        assert_eq!(memtable.sequence(), 1);

        memtable.put(Key::from_str("key2"), CipherBlob::new(vec![2]))?;
        assert_eq!(memtable.sequence(), 2);

        memtable.delete(Key::from_str("key1"))?;
        assert_eq!(memtable.sequence(), 3);

        memtable.delete(Key::from_str("never_written"))?;
        assert_eq!(memtable.sequence(), 4);

        Ok(())
    }
}