use crate::block::BlockIter;
use crate::compress::CompressionType;
use crate::sstable::{SsTable, SsTableBuilder, SsTableIter};
use crate::{KvIterator, MergeIterator};
use bytes::Bytes;

use std::ops::Bound;
use std::{cmp::Ordering, collections::BTreeMap};

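/// An in-memory LSM-flavoured key-value store.
///
/// Fresh writes go to `mem_table` (a `BTreeMap`); imported or previously
/// exported data is kept as immutable SSTables in `ss_table`. An empty value
/// is used as a tombstone for a deleted key.
///
/// A minimal usage sketch (marked `ignore` here since the crate path of
/// `MemKvConfig`/`Bytes` is elided):
///
/// ```ignore
/// let mut store = MemKvConfig::default().should_encode_none(true).build();
/// store.set(b"key", Bytes::from_static(b"value"));
/// assert_eq!(store.get(b"key"), Some(Bytes::from_static(b"value")));
///
/// // Round-trip through the exported bytes.
/// let exported = store.export_all();
/// let mut restored = MemKvConfig::default().should_encode_none(true).build();
/// restored.import_all(exported).unwrap();
/// assert_eq!(restored.get(b"key"), Some(Bytes::from_static(b"value")));
/// ```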
#[derive(Debug, Clone)]
pub struct MemKvStore {
    mem_table: BTreeMap<Bytes, Bytes>,
    ss_table: Vec<SsTable>,
    block_size: usize,
    compression_type: CompressionType,
    should_encode_none: bool,
}

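/// Builder-style configuration for [`MemKvStore`].
///
/// Defaults: 4 KiB block size, LZ4 compression, `should_encode_none = false`.
/// Finish with [`MemKvConfig::build`] to obtain the configured store.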
pub struct MemKvConfig {
    block_size: usize,
    compression_type: CompressionType,
    should_encode_none: bool,
}

impl Default for MemKvConfig {
    fn default() -> Self {
        Self {
            block_size: MemKvStore::DEFAULT_BLOCK_SIZE,
            compression_type: CompressionType::LZ4,
            should_encode_none: false,
        }
    }
}

impl MemKvConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn block_size(mut self, block_size: usize) -> Self {
        self.block_size = block_size;
        self
    }

    pub fn compression_type(mut self, compression_type: CompressionType) -> Self {
        self.compression_type = compression_type;
        self
    }

    pub fn should_encode_none(mut self, should_encode_none: bool) -> Self {
        self.should_encode_none = should_encode_none;
        self
    }

    pub fn build(self) -> MemKvStore {
        MemKvStore::new(self)
    }
}

impl MemKvStore {
    pub const DEFAULT_BLOCK_SIZE: usize = 4 * 1024;

    pub fn new(config: MemKvConfig) -> Self {
        Self {
            mem_table: BTreeMap::new(),
            ss_table: Vec::new(),
            block_size: config.block_size,
            compression_type: config.compression_type,
            should_encode_none: config.should_encode_none,
        }
    }

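    /// Returns the value stored under `key`, or `None` if the key is absent
    /// or has been deleted.
    ///
    /// The mem table is checked first, then the SSTables from newest to
    /// oldest; an empty (tombstone) value is reported as `None`.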
    pub fn get(&self, key: &[u8]) -> Option<Bytes> {
        if let Some(v) = self.mem_table.get(key) {
            if v.is_empty() {
                return None;
            }
            return Some(v.clone());
        }

        for table in self.ss_table.iter().rev() {
            if table.first_key > key || table.last_key < key {
                continue;
            }
            let idx = table.find_block_idx(key);
            let block = table.read_block_cached(idx);
            let block_iter = BlockIter::new_seek_to_key(block, key);
            if let Some(k) = block_iter.peek_next_curr_key() {
                let v = block_iter.peek_next_curr_value().unwrap();
                if k == key {
                    return if v.is_empty() { None } else { Some(v) };
                }
            }
        }
        None
    }

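    /// Inserts or overwrites `key` with `value` in the mem table. An empty
    /// `value` acts as a tombstone (see `remove`).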
    pub fn set(&mut self, key: &[u8], value: Bytes) {
        self.mem_table.insert(Bytes::copy_from_slice(key), value);
    }

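    /// Writes `new` only if the current value of `key` equals `old`, where
    /// `None` means "absent or deleted". Returns whether the swap happened.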
    pub fn compare_and_swap(&mut self, key: &[u8], old: Option<Bytes>, new: Bytes) -> bool {
        match self.get(key) {
            Some(v) => {
                if old == Some(v) {
                    self.set(key, new);
                    true
                } else {
                    false
                }
            }
            None => {
                if old.is_none() {
                    self.set(key, new);
                    true
                } else {
                    false
                }
            }
        }
    }

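    /// Marks `key` as deleted by writing an empty value (tombstone).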
    pub fn remove(&mut self, key: &[u8]) {
        self.set(key, Bytes::new());
    }

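    /// Returns whether `key` currently has a live value. The mem table is
    /// checked first, then the SSTables from newest to oldest; a tombstone
    /// counts as absent.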
    pub fn contains_key(&self, key: &[u8]) -> bool {
        if self.mem_table.contains_key(key) {
            return !self.mem_table.get(key).unwrap().is_empty();
        }

        for table in self.ss_table.iter().rev() {
            if table.contains_key(key) {
                if let Some(v) = table.get(key) {
                    return !v.is_empty();
                }
            }
        }
        false
    }

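    /// Returns a double-ended iterator over the live entries within the given
    /// bounds, merging the mem table with all SSTables. Mem-table entries
    /// shadow SSTable entries with the same key, and tombstones are skipped.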
    pub fn scan(
        &self,
        start: std::ops::Bound<&[u8]>,
        end: std::ops::Bound<&[u8]>,
    ) -> Box<dyn DoubleEndedIterator<Item = (Bytes, Bytes)> + '_> {
        if self.ss_table.is_empty() {
            return Box::new(
                self.mem_table
                    .range::<[u8], _>((start, end))
                    .filter(|(_, v)| !v.is_empty())
                    .map(|(k, v)| (k.clone(), v.clone())),
            );
        }

        Box::new(MemStoreIterator::new(
            self.mem_table
                .range::<[u8], _>((start, end))
                .map(|(k, v)| (k.clone(), v.clone())),
            MergeIterator::new(
                self.ss_table
                    .iter()
                    .rev()
                    .map(|table| SsTableIter::new_scan(table, start, end))
                    .collect(),
            ),
            true,
        ))
    }

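    /// Number of live entries. This is O(n): it runs a full scan so that
    /// tombstones and shadowed SSTable entries are not counted.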
    pub fn len(&self) -> usize {
        self.scan(Bound::Unbounded, Bound::Unbounded).count()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

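    /// Approximate size in bytes: the summed key/value lengths of the mem
    /// table plus the data size of every SSTable.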
    pub fn size(&self) -> usize {
        self.mem_table
            .iter()
            .fold(0, |acc, (k, v)| acc + k.len() + v.len())
            + self
                .ss_table
                .iter()
                .map(|table| table.data_size())
                .sum::<usize>()
    }

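    /// Serializes the whole store into a single SSTable encoding.
    ///
    /// Fast paths: with an empty mem table and exactly one SSTable, that
    /// table's bytes are returned directly; with exactly one SSTable, the mem
    /// table is merged into it block by block via `export_with_encoded_block`.
    /// Otherwise all sources are merged through a fresh `SsTableBuilder`.
    /// On a non-empty export the mem table is cleared and the resulting
    /// SSTable replaces the existing ones.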
    pub fn export_all(&mut self) -> Bytes {
        if self.mem_table.is_empty() && self.ss_table.len() == 1 {
            return self.ss_table[0].export_all();
        }

        if self.ss_table.len() == 1 {
            return self.export_with_encoded_block();
        }

        let mut builder = SsTableBuilder::new(
            self.block_size,
            self.compression_type,
            self.should_encode_none,
        );
        let iter = MemStoreIterator::new(
            self.mem_table
                .range::<[u8], _>((Bound::Unbounded, Bound::Unbounded))
                .map(|(k, v)| (k.clone(), v.clone())),
            MergeIterator::new(
                self.ss_table
                    .iter()
                    .rev()
                    .map(|table| SsTableIter::new_scan(table, Bound::Unbounded, Bound::Unbounded))
                    .collect(),
            ),
            false,
        );

        for (k, v) in iter {
            builder.add(k, v);
        }

        if builder.is_empty() {
            return Bytes::new();
        }
        self.mem_table.clear();
        let ss = builder.build();
        let ans = ss.export_all();
        let _ = std::mem::replace(&mut self.ss_table, vec![ss]);
        ans
    }

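    /// Decodes `bytes` as an SSTable and appends it to the store. Tables
    /// imported later take precedence over earlier ones for overlapping keys.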
    pub fn import_all(&mut self, bytes: Bytes) -> Result<(), String> {
        if bytes.is_empty() {
            return Ok(());
        }
        let ss_table = SsTable::import_all(bytes).map_err(|e| e.to_string())?;
        self.ss_table.push(ss_table);
        Ok(())
    }

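    /// Merge path used when there is exactly one SSTable: walks the mem table
    /// and the SSTable in key order, reusing already-encoded SSTable blocks
    /// that no mem-table key touches and re-encoding only the overlapping
    /// blocks. Mem-table values win on key collisions.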
    #[tracing::instrument(level = "debug", skip(self))]
    fn export_with_encoded_block(&mut self) -> Bytes {
        ensure_cov::notify_cov("kv-store::mem_store::export_with_encoded_block");
        let mut mem_iter = self.mem_table.iter().peekable();
        let mut sstable_iter = self.ss_table[0].iter();
        let mut builder = SsTableBuilder::new(
            self.block_size,
            self.compression_type,
            self.should_encode_none,
        );
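        // Reuse whole SSTable blocks that end before the next mem-table key;
        // only the first block that may contain it is merged entry by entry.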
        'outer: while let Some(next_mem_pair) = mem_iter.peek() {
            let block = loop {
                let Some(block) = sstable_iter.peek_next_block() else {
                    builder.add(next_mem_pair.0.clone(), next_mem_pair.1.clone());
                    mem_iter.next();
                    continue 'outer;
                };
                if block.last_key() < next_mem_pair.0 {
                    builder.add_new_block(block.clone());
                    sstable_iter.next_block();
                    continue;
                }
                break block;
            };

            if block.first_key() > next_mem_pair.0 {
                builder.add(next_mem_pair.0.clone(), next_mem_pair.1.clone());
                mem_iter.next();
                continue;
            }

            let mut iter = BlockIter::new(block.clone());
            let mut next_mem_pair = mem_iter.peek();
            while let Some(k) = iter.peek_next_key() {
                loop {
                    match next_mem_pair {
                        Some(next_mem_pair_inner) => {
                            if k > next_mem_pair_inner.0 {
                                builder.add(
                                    next_mem_pair_inner.0.clone(),
                                    next_mem_pair_inner.1.clone(),
                                );
                                mem_iter.next();
                                next_mem_pair = mem_iter.peek();
                                continue;
                            }
                            if k == next_mem_pair_inner.0 {
                                builder.add(k, next_mem_pair_inner.1.clone());
                                mem_iter.next();
                                next_mem_pair = mem_iter.peek();
                                iter.next();
                                break;
                            }
                            builder.add(k, iter.peek_next_value().unwrap());
                            iter.next();
                            break;
                        }
                        None => {
                            builder.add(k, iter.peek_next_value().unwrap());
                            iter.next();
                            break;
                        }
                    }
                }
            }

            sstable_iter.next_block();
        }

        while let Some(block) = sstable_iter.peek_next_block() {
            builder.add_new_block(block.clone());
            sstable_iter.next_block();
        }

        if builder.is_empty() {
            return Bytes::new();
        }

        drop(mem_iter);
        self.mem_table.clear();
        let ss = builder.build();
        let ans = ss.export_all();
        let _ = std::mem::replace(&mut self.ss_table, vec![ss]);
        ans
    }

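    /// Debug helper: re-imports `bytes` into a fresh store and asserts that it
    /// yields exactly the same key-value pairs as `self`.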
    #[allow(unused)]
    fn check_encode_data_correctness(&self, bytes: &Bytes) {
        let this_data: BTreeMap<Bytes, Bytes> =
            self.scan(Bound::Unbounded, Bound::Unbounded).collect();
        let mut other_kv = MemKvStore::new(Default::default());
        other_kv.import_all(bytes.clone()).unwrap();
        let other_data: BTreeMap<Bytes, Bytes> =
            other_kv.scan(Bound::Unbounded, Bound::Unbounded).collect();
        assert_eq!(this_data, other_data);
    }
}

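/// Merging iterator over a mem-table range and a merged SSTable range.
///
/// On equal keys the mem-table entry wins and the SSTable entry is discarded.
/// `filter_empty` controls whether tombstone (empty-value) entries are
/// skipped: `scan` passes `true`, while `export_all` passes `false` so
/// tombstones still reach the `SsTableBuilder`.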
#[derive(Debug)]
pub struct MemStoreIterator<T, S> {
    mem: T,
    sst: S,
    current_mem: Option<(Bytes, Bytes)>,
    current_sstable: Option<(Bytes, Bytes)>,
    back_mem: Option<(Bytes, Bytes)>,
    back_sstable: Option<(Bytes, Bytes)>,
    filter_empty: bool,
}

impl<T, S> MemStoreIterator<T, S>
where
    T: DoubleEndedIterator<Item = (Bytes, Bytes)>,
    S: DoubleEndedIterator<Item = (Bytes, Bytes)>,
{
    fn new(mut mem: T, sst: S, filter_empty: bool) -> Self {
        let current_mem = mem.next();
        let back_mem = mem.next_back();
        Self {
            mem,
            sst,
            current_mem,
            back_mem,
            current_sstable: None,
            back_sstable: None,
            filter_empty,
        }
    }
}

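// Forward iteration: pull one candidate from each source, yield the smaller
// key, and on a tie emit the mem-table entry while dropping the SSTable one.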
impl<T, S> Iterator for MemStoreIterator<T, S>
where
    T: DoubleEndedIterator<Item = (Bytes, Bytes)>,
    S: DoubleEndedIterator<Item = (Bytes, Bytes)>,
{
    type Item = (Bytes, Bytes);
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.current_sstable.is_none() {
                if let Some((k, v)) = self.sst.next() {
                    self.current_sstable = Some((k, v));
                }
            }

            if self.current_mem.is_none() && self.back_mem.is_some() {
                std::mem::swap(&mut self.back_mem, &mut self.current_mem);
            }
            let ans = match (&self.current_mem, &self.current_sstable) {
                (Some((mem_key, _)), Some((iter_key, _))) => match mem_key.cmp(iter_key) {
                    Ordering::Less => self.current_mem.take().inspect(|_kv| {
                        self.current_mem = self.mem.next();
                    }),
                    Ordering::Equal => {
                        self.current_sstable.take();
                        self.current_mem.take().inspect(|_kv| {
                            self.current_mem = self.mem.next();
                        })
                    }
                    Ordering::Greater => self.current_sstable.take(),
                },
                (Some(_), None) => self.current_mem.take().inspect(|_kv| {
                    self.current_mem = self.mem.next();
                }),
                (None, Some(_)) => self.current_sstable.take(),
                (None, None) => None,
            };

            if self.filter_empty {
                if let Some((_k, v)) = &ans {
                    if v.is_empty() {
                        continue;
                    }
                }
            }

            return ans;
        }
    }
}

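// Backward iteration mirrors `next`, comparing candidates from the back of
// each source instead.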
impl<T, S> DoubleEndedIterator for MemStoreIterator<T, S>
where
    T: DoubleEndedIterator<Item = (Bytes, Bytes)>,
    S: DoubleEndedIterator<Item = (Bytes, Bytes)>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.back_sstable.is_none() {
            if let Some((k, v)) = self.sst.next_back() {
                self.back_sstable = Some((k, v));
            }
        }

        if self.back_mem.is_none() && self.current_mem.is_some() {
            std::mem::swap(&mut self.back_mem, &mut self.current_mem);
        }

        let ans = match (&self.back_mem, &self.back_sstable) {
            (Some((mem_key, _)), Some((iter_key, _))) => match mem_key.cmp(iter_key) {
                Ordering::Greater => self.back_mem.take().inspect(|_kv| {
                    self.back_mem = self.mem.next_back();
                }),
                Ordering::Equal => {
                    self.back_sstable.take();
                    self.back_mem.take().inspect(|_kv| {
                        self.back_mem = self.mem.next_back();
                    })
                }
                Ordering::Less => self.back_sstable.take(),
            },
            (Some(_), None) => self.back_mem.take().inspect(|_kv| {
                self.back_mem = self.mem.next_back();
            }),
            (None, Some(_)) => self.back_sstable.take(),
            (None, None) => None,
        };
        if self.filter_empty {
            if let Some((_k, v)) = &ans {
                if v.is_empty() {
                    return self.next_back();
                }
            }
        }
        ans
    }
}

#[cfg(test)]
mod tests {
    use std::vec;

    use crate::{mem_store::MemKvConfig, MemKvStore};
    use bytes::Bytes;

    #[test]
    fn test_mem_kv_store() {
        let key = &[0];
        let value = Bytes::from_static(&[0]);

        let key2 = &[0, 1];
        let value2 = Bytes::from_static(&[0, 1]);
        let mut store = new_store();
        store.set(key, value.clone());
        assert_eq!(store.get(key), Some(value));
        store.remove(key);
        assert!(store.is_empty());
        assert_eq!(store.get(key), None);
        store.compare_and_swap(key, None, value2.clone());
        assert_eq!(store.get(key), Some(value2.clone()));
        assert!(store.contains_key(key));
        assert!(!store.contains_key(key2));

        store.set(key2, value2.clone());
        assert_eq!(store.get(key2), Some(value2.clone()));
        assert_eq!(store.len(), 2);
        assert_eq!(store.size(), 7);
        let bytes = store.export_all();
        let mut new_store = new_store();
        assert_eq!(new_store.len(), 0);
        assert_eq!(new_store.size(), 0);
        new_store.import_all(bytes).unwrap();

        let iter1 = store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .collect::<Vec<_>>();
        let iter2 = new_store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .collect::<Vec<_>>();
        assert_eq!(iter1, iter2);

        let iter1 = store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .rev()
            .collect::<Vec<_>>();
        let iter2 = new_store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .rev()
            .collect::<Vec<_>>();
        assert_eq!(iter1, iter2);
    }

    #[test]
    fn test_large_block() {
        let mut store = new_store();
        let key = &[0];
        let value = Bytes::from_static(&[0]);

        let key2 = &[0, 1];
        let key3 = &[0, 1, 2];
        let large_value = Bytes::from_iter([0; 1024 * 8]);
        let large_value2 = Bytes::from_iter([0; 1024 * 8]);
        store.set(key, value.clone());
        store.set(key2, large_value.clone());
        let v2 = store.get(&[]);
        assert_eq!(v2, None);
        assert_eq!(store.get(key), Some(value.clone()));
        assert_eq!(store.get(key2), Some(large_value.clone()));
        store.export_all();
        store.set(key3, large_value2.clone());
        assert_eq!(store.get(key3), Some(large_value2.clone()));
        assert_eq!(store.len(), 3);

        let iter = store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .collect::<Vec<_>>();
        assert_eq!(
            iter,
            vec![
                (Bytes::from_static(key), value.clone()),
                (Bytes::from_static(key2), large_value.clone()),
                (Bytes::from_static(key3), large_value2.clone())
            ]
        );

        let iter2 = store
            .scan(
                std::ops::Bound::Included(key),
                std::ops::Bound::Included(key3),
            )
            .collect::<Vec<_>>();
        assert_eq!(iter, iter2);

        let iter3 = store
            .scan(
                std::ops::Bound::Excluded(key),
                std::ops::Bound::Excluded(key3),
            )
            .collect::<Vec<_>>();
        assert_eq!(iter3.len(), 1);
        assert_eq!(iter3[0], (Bytes::from_static(key2), large_value.clone()));

        let v = store.get(key2).unwrap();
        assert_eq!(v, large_value);

        let v2 = store.get(&[]);
        assert_eq!(v2, None);

        store.compare_and_swap(key, Some(value.clone()), large_value.clone());
        assert!(store.contains_key(key));
    }

    #[test]
    fn same_key() {
        let mut store = new_store();
        let key = &[0];
        let value = Bytes::from_static(&[0]);
        store.set(key, value.clone());
        store.export_all();
        store.set(key, Bytes::new());
        assert_eq!(store.get(key), None);
        let iter = store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .collect::<Vec<_>>();
        assert_eq!(iter.len(), 0);
        store.set(key, value.clone());
        assert_eq!(store.get(key), Some(value));
    }

    #[test]
    fn import_several_times() {
        dev_utils::setup_test_log();
        let a = Bytes::from_static(b"a");
        let b = Bytes::from_static(b"b");
        let c = Bytes::from_static(b"c");
        let d = Bytes::from_static(b"d");
        let e = Bytes::from_static(b"e");
        let mut store = new_store();
        store.set(&a, a.clone());
        store.export_all();
        store.set(&c, c.clone());
        let encode1 = store.export_all();
        let mut store2 = new_store();
        store2.set(&b, b.clone());
        store2.export_all();
        store2.set(&c, Bytes::new());
        let encode2 = store2.export_all();
        let mut store3 = new_store();
        store3.set(&d, d.clone());
        store3.set(&a, Bytes::new());
        tracing::info_span!("export da").in_scope(|| {
            store3.export_all();
        });
        store3.set(&e, e.clone());
        store3.set(&c, c.clone());
        let encode3 = tracing::info_span!("export ec").in_scope(|| store3.export_all());

        let mut store = new_store();
        store.import_all(encode1).unwrap();
        store.import_all(encode2).unwrap();
        store.import_all(encode3).unwrap();
        assert_eq!(store.get(&a), None);
        assert_eq!(store.get(&b), Some(b.clone()));
        assert_eq!(store.get(&c), Some(c.clone()));
        assert_eq!(store.get(&d), Some(d.clone()));
        assert_eq!(store.get(&e), Some(e.clone()));
    }

    fn new_store() -> MemKvStore {
        MemKvStore::new(MemKvConfig::default().should_encode_none(true))
    }
}