// loro_kv_store — mem_store.rs

1use crate::block::BlockIter;
2use crate::compress::CompressionType;
3use crate::sstable::{SsTable, SsTableBuilder, SsTableIter};
4use crate::{KvIterator, MergeIterator};
5use bytes::Bytes;
6
7use std::ops::Bound;
8use std::{cmp::Ordering, collections::BTreeMap};
9
#[derive(Debug, Clone)]
pub struct MemKvStore {
    /// Write buffer consulted before any SSTable. An empty `Bytes` value is a
    /// deletion tombstone (see `remove`/`get`).
    mem_table: BTreeMap<Bytes, Bytes>,
    // From the oldest to the newest
    ss_table: Vec<SsTable>,
    /// Target block size (bytes) used when building SSTables on export.
    block_size: usize,
    /// Compression applied to SSTable blocks built on export.
    compression_type: CompressionType,
    /// It's only true when using it to fuzz.
    /// Otherwise, importing and exporting GC snapshot relies on this field being false to work.
    should_encode_none: bool,
}
21
/// Builder-style configuration for a [`MemKvStore`]; finish with `build()`.
pub struct MemKvConfig {
    // Target SSTable block size in bytes (default: MemKvStore::DEFAULT_BLOCK_SIZE).
    block_size: usize,
    // Compression used when exporting (default: LZ4).
    compression_type: CompressionType,
    // See the field docs on MemKvStore; default: false.
    should_encode_none: bool,
}
27
28impl Default for MemKvConfig {
29    fn default() -> Self {
30        Self {
31            block_size: MemKvStore::DEFAULT_BLOCK_SIZE,
32            compression_type: CompressionType::LZ4,
33            should_encode_none: false,
34        }
35    }
36}
37
38impl MemKvConfig {
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    pub fn block_size(mut self, block_size: usize) -> Self {
44        self.block_size = block_size;
45        self
46    }
47
48    pub fn compression_type(mut self, compression_type: CompressionType) -> Self {
49        self.compression_type = compression_type;
50        self
51    }
52
53    pub fn should_encode_none(mut self, should_encode_none: bool) -> Self {
54        self.should_encode_none = should_encode_none;
55        self
56    }
57
58    pub fn build(self) -> MemKvStore {
59        MemKvStore::new(self)
60    }
61}
62
63impl MemKvStore {
64    pub const DEFAULT_BLOCK_SIZE: usize = 4 * 1024;
65    pub fn new(config: MemKvConfig) -> Self {
66        Self {
67            mem_table: BTreeMap::new(),
68            ss_table: Vec::new(),
69            block_size: config.block_size,
70            compression_type: config.compression_type,
71            should_encode_none: config.should_encode_none,
72        }
73    }
74
75    pub fn get(&self, key: &[u8]) -> Option<Bytes> {
76        if let Some(v) = self.mem_table.get(key) {
77            if v.is_empty() {
78                return None;
79            }
80            return Some(v.clone());
81        }
82
83        for table in self.ss_table.iter().rev() {
84            if table.first_key > key || table.last_key < key {
85                continue;
86            }
87            // table.
88            let idx = table.find_block_idx(key);
89            let block = table.read_block_cached(idx);
90            let block_iter = BlockIter::new_seek_to_key(block, key);
91            if let Some(k) = block_iter.peek_next_curr_key() {
92                let v = block_iter.peek_next_curr_value().unwrap();
93                if k == key {
94                    return if v.is_empty() { None } else { Some(v) };
95                }
96            }
97        }
98        None
99    }
100
101    pub fn set(&mut self, key: &[u8], value: Bytes) {
102        self.mem_table.insert(Bytes::copy_from_slice(key), value);
103    }
104
105    pub fn compare_and_swap(&mut self, key: &[u8], old: Option<Bytes>, new: Bytes) -> bool {
106        match self.get(key) {
107            Some(v) => {
108                if old == Some(v) {
109                    self.set(key, new);
110                    true
111                } else {
112                    false
113                }
114            }
115            None => {
116                if old.is_none() {
117                    self.set(key, new);
118                    true
119                } else {
120                    false
121                }
122            }
123        }
124    }
125
126    pub fn remove(&mut self, key: &[u8]) {
127        self.set(key, Bytes::new());
128    }
129
130    /// Check if the key exists in the mem table or the sstable
131    ///
132    /// If the value is empty, it means the key is deleted
133    pub fn contains_key(&self, key: &[u8]) -> bool {
134        if self.mem_table.contains_key(key) {
135            return !self.mem_table.get(key).unwrap().is_empty();
136        }
137
138        for table in self.ss_table.iter().rev() {
139            if table.contains_key(key) {
140                if let Some(v) = table.get(key) {
141                    return !v.is_empty();
142                }
143            }
144        }
145        false
146    }
147
148    pub fn scan(
149        &self,
150        start: std::ops::Bound<&[u8]>,
151        end: std::ops::Bound<&[u8]>,
152    ) -> Box<dyn DoubleEndedIterator<Item = (Bytes, Bytes)> + '_> {
153        if self.ss_table.is_empty() {
154            return Box::new(
155                self.mem_table
156                    .range::<[u8], _>((start, end))
157                    .filter(|(_, v)| !v.is_empty())
158                    .map(|(k, v)| (k.clone(), v.clone())),
159            );
160        }
161
162        Box::new(MemStoreIterator::new(
163            self.mem_table
164                .range::<[u8], _>((start, end))
165                .map(|(k, v)| (k.clone(), v.clone())),
166            MergeIterator::new(
167                self.ss_table
168                    .iter()
169                    .rev()
170                    .map(|table| SsTableIter::new_scan(table, start, end))
171                    .collect(),
172            ),
173            true,
174        ))
175    }
176
177    /// The number of valid keys in the mem table and sstable, it's expensive to call
178    pub fn len(&self) -> usize {
179        // TODO: PERF
180        self.scan(Bound::Unbounded, Bound::Unbounded).count()
181    }
182
183    pub fn is_empty(&self) -> bool {
184        self.len() == 0
185    }
186
187    pub fn size(&self) -> usize {
188        self.mem_table
189            .iter()
190            .fold(0, |acc, (k, v)| acc + k.len() + v.len())
191            + self
192                .ss_table
193                .iter()
194                .map(|table| table.data_size())
195                .sum::<usize>()
196    }
197
198    pub fn export_all(&mut self) -> Bytes {
199        if self.mem_table.is_empty() && self.ss_table.len() == 1 {
200            return self.ss_table[0].export_all();
201        }
202
203        if self.ss_table.len() == 1 {
204            return self.export_with_encoded_block();
205        }
206
207        let mut builder = SsTableBuilder::new(
208            self.block_size,
209            self.compression_type,
210            self.should_encode_none,
211        );
212        // we could use scan() here, we should keep the empty value
213        let iter = MemStoreIterator::new(
214            self.mem_table
215                .range::<[u8], _>((Bound::Unbounded, Bound::Unbounded))
216                .map(|(k, v)| (k.clone(), v.clone())),
217            MergeIterator::new(
218                self.ss_table
219                    .iter()
220                    .rev()
221                    .map(|table| SsTableIter::new_scan(table, Bound::Unbounded, Bound::Unbounded))
222                    .collect(),
223            ),
224            false,
225        );
226
227        for (k, v) in iter {
228            builder.add(k, v);
229        }
230
231        if builder.is_empty() {
232            return Bytes::new();
233        }
234        self.mem_table.clear();
235        let ss = builder.build();
236        let ans = ss.export_all();
237        let _ = std::mem::replace(&mut self.ss_table, vec![ss]);
238        ans
239    }
240
241    /// We can import several times, the latter will override the former.
242    pub fn import_all(&mut self, bytes: Bytes) -> Result<(), String> {
243        if bytes.is_empty() {
244            return Ok(());
245        }
246        let ss_table = SsTable::import_all(bytes).map_err(|e| e.to_string())?;
247        self.ss_table.push(ss_table);
248        Ok(())
249    }
250
251    #[tracing::instrument(level = "debug", skip(self))]
252    fn export_with_encoded_block(&mut self) -> Bytes {
253        ensure_cov::notify_cov("kv-store::mem_store::export_with_encoded_block");
254        let mut mem_iter = self.mem_table.iter().peekable();
255        let mut sstable_iter = self.ss_table[0].iter();
256        let mut builder = SsTableBuilder::new(
257            self.block_size,
258            self.compression_type,
259            self.should_encode_none,
260        );
261        'outer: while let Some(next_mem_pair) = mem_iter.peek() {
262            let block = loop {
263                let Some(block) = sstable_iter.peek_next_block() else {
264                    builder.add(next_mem_pair.0.clone(), next_mem_pair.1.clone());
265                    mem_iter.next();
266                    continue 'outer;
267                };
268                if block.last_key() < next_mem_pair.0 {
269                    builder.add_new_block(block.clone());
270                    sstable_iter.next_block();
271                    continue;
272                }
273                break block;
274            };
275
276            if block.first_key() > next_mem_pair.0 {
277                builder.add(next_mem_pair.0.clone(), next_mem_pair.1.clone());
278                mem_iter.next();
279                continue;
280            }
281
282            // There are overlap between next_mem_pair and block
283            let mut iter = BlockIter::new(block.clone());
284            let mut next_mem_pair = mem_iter.peek();
285            while let Some(k) = iter.peek_next_key() {
286                loop {
287                    match next_mem_pair {
288                        Some(next_mem_pair_inner) => {
289                            if k > next_mem_pair_inner.0 {
290                                builder.add(
291                                    next_mem_pair_inner.0.clone(),
292                                    next_mem_pair_inner.1.clone(),
293                                );
294                                mem_iter.next();
295                                next_mem_pair = mem_iter.peek();
296                                continue;
297                            }
298                            if k == next_mem_pair_inner.0 {
299                                builder.add(k, next_mem_pair_inner.1.clone());
300                                mem_iter.next();
301                                next_mem_pair = mem_iter.peek();
302                                iter.next();
303                                break;
304                            }
305                            // k < next_mem_pair_inner.0
306                            builder.add(k, iter.peek_next_value().unwrap());
307                            iter.next();
308                            break;
309                        }
310                        None => {
311                            builder.add(k, iter.peek_next_value().unwrap());
312                            iter.next();
313                            break;
314                        }
315                    }
316                }
317            }
318
319            sstable_iter.next_block();
320        }
321
322        while let Some(block) = sstable_iter.peek_next_block() {
323            builder.add_new_block(block.clone());
324            sstable_iter.next_block();
325        }
326
327        if builder.is_empty() {
328            return Bytes::new();
329        }
330
331        drop(mem_iter);
332        self.mem_table.clear();
333        let ss = builder.build();
334        let ans = ss.export_all();
335        let _ = std::mem::replace(&mut self.ss_table, vec![ss]);
336        ans
337    }
338
339    #[allow(unused)]
340    fn check_encode_data_correctness(&self, bytes: &Bytes) {
341        let this_data: BTreeMap<Bytes, Bytes> =
342            self.scan(Bound::Unbounded, Bound::Unbounded).collect();
343        let mut other_kv = MemKvStore::new(Default::default());
344        other_kv.import_all(bytes.clone()).unwrap();
345        let other_data: BTreeMap<Bytes, Bytes> =
346            other_kv.scan(Bound::Unbounded, Bound::Unbounded).collect();
347        assert_eq!(this_data, other_data);
348    }
349}
350
/// Double-ended iterator that merges mem-table entries with SSTable entries
/// in key order. On duplicate keys the mem-table entry wins (it is newer).
#[derive(Debug)]
pub struct MemStoreIterator<T, S> {
    mem: T,
    sst: S,
    /// One-element lookahead buffer for the front of `mem`.
    current_mem: Option<(Bytes, Bytes)>,
    /// One-element lookahead buffer for the front of `sst`.
    current_sstable: Option<(Bytes, Bytes)>,
    /// One-element lookahead buffer for the back of `mem`.
    back_mem: Option<(Bytes, Bytes)>,
    /// One-element lookahead buffer for the back of `sst`.
    back_sstable: Option<(Bytes, Bytes)>,
    /// When true, entries with empty values (tombstones) are skipped.
    filter_empty: bool,
}
361
362impl<T, S> MemStoreIterator<T, S>
363where
364    T: DoubleEndedIterator<Item = (Bytes, Bytes)>,
365    S: DoubleEndedIterator<Item = (Bytes, Bytes)>,
366{
367    fn new(mut mem: T, sst: S, filter_empty: bool) -> Self {
368        let current_mem = mem.next();
369        let back_mem = mem.next_back();
370        Self {
371            mem,
372            sst,
373            current_mem,
374            back_mem,
375            current_sstable: None,
376            back_sstable: None,
377            filter_empty,
378        }
379    }
380}
381
382impl<T, S> Iterator for MemStoreIterator<T, S>
383where
384    T: DoubleEndedIterator<Item = (Bytes, Bytes)>,
385    S: DoubleEndedIterator<Item = (Bytes, Bytes)>,
386{
387    type Item = (Bytes, Bytes);
388    fn next(&mut self) -> Option<Self::Item> {
389        loop {
390            if self.current_sstable.is_none() {
391                if let Some((k, v)) = self.sst.next() {
392                    self.current_sstable = Some((k, v));
393                }
394            }
395
396            if self.current_mem.is_none() && self.back_mem.is_some() {
397                std::mem::swap(&mut self.back_mem, &mut self.current_mem);
398            }
399            let ans = match (&self.current_mem, &self.current_sstable) {
400                (Some((mem_key, _)), Some((iter_key, _))) => match mem_key.cmp(iter_key) {
401                    Ordering::Less => self.current_mem.take().inspect(|_kv| {
402                        self.current_mem = self.mem.next();
403                    }),
404                    Ordering::Equal => {
405                        self.current_sstable.take();
406                        self.current_mem.take().inspect(|_kv| {
407                            self.current_mem = self.mem.next();
408                        })
409                    }
410                    Ordering::Greater => self.current_sstable.take(),
411                },
412                (Some(_), None) => self.current_mem.take().inspect(|_kv| {
413                    self.current_mem = self.mem.next();
414                }),
415                (None, Some(_)) => self.current_sstable.take(),
416                (None, None) => None,
417            };
418
419            if self.filter_empty {
420                if let Some((_k, v)) = &ans {
421                    if v.is_empty() {
422                        continue;
423                    }
424                }
425            }
426
427            return ans;
428        }
429    }
430}
431
432impl<T, S> DoubleEndedIterator for MemStoreIterator<T, S>
433where
434    T: DoubleEndedIterator<Item = (Bytes, Bytes)>,
435    S: DoubleEndedIterator<Item = (Bytes, Bytes)>,
436{
437    fn next_back(&mut self) -> Option<Self::Item> {
438        if self.back_sstable.is_none() {
439            if let Some((k, v)) = self.sst.next_back() {
440                self.back_sstable = Some((k, v));
441            }
442        }
443
444        if self.back_mem.is_none() && self.current_mem.is_some() {
445            std::mem::swap(&mut self.back_mem, &mut self.current_mem);
446        }
447
448        let ans = match (&self.back_mem, &self.back_sstable) {
449            (Some((mem_key, _)), Some((iter_key, _))) => match mem_key.cmp(iter_key) {
450                Ordering::Greater => self.back_mem.take().inspect(|_kv| {
451                    self.back_mem = self.mem.next_back();
452                }),
453                Ordering::Equal => {
454                    self.back_sstable.take();
455                    self.back_mem.take().inspect(|_kv| {
456                        self.back_mem = self.mem.next_back();
457                    })
458                }
459                Ordering::Less => self.back_sstable.take(),
460            },
461            (Some(_), None) => self.back_mem.take().inspect(|_kv| {
462                self.back_mem = self.mem.next_back();
463            }),
464            (None, Some(_)) => self.back_sstable.take(),
465            (None, None) => None,
466        };
467        if self.filter_empty {
468            if let Some((_k, v)) = &ans {
469                if v.is_empty() {
470                    return self.next_back();
471                }
472            }
473        }
474        ans
475    }
476}
477
#[cfg(test)]
mod tests {
    use std::vec;

    use crate::{mem_store::MemKvConfig, MemKvStore};
    use bytes::Bytes;

    // Basic CRUD + compare_and_swap, then an export/import round-trip whose
    // forward and backward scans must agree between the two stores.
    #[test]
    fn test_mem_kv_store() {
        let key = &[0];
        let value = Bytes::from_static(&[0]);

        let key2 = &[0, 1];
        let value2 = Bytes::from_static(&[0, 1]);
        let mut store = new_store();
        store.set(key, value.clone());
        assert_eq!(store.get(key), Some(value));
        // Removing writes a tombstone; the key reads as absent.
        store.remove(key);
        assert!(store.is_empty());
        assert_eq!(store.get(key), None);
        // CAS with old == None succeeds on a deleted key.
        store.compare_and_swap(key, None, value2.clone());
        assert_eq!(store.get(key), Some(value2.clone()));
        assert!(store.contains_key(key));
        assert!(!store.contains_key(key2));

        store.set(key2, value2.clone());
        assert_eq!(store.get(key2), Some(value2.clone()));
        assert_eq!(store.len(), 2);
        // size() = sum of key/value byte lengths: (1+2) + (2+2) = 7.
        assert_eq!(store.size(), 7);
        let bytes = store.export_all();
        let mut new_store = new_store();
        assert_eq!(new_store.len(), 0);
        assert_eq!(new_store.size(), 0);
        new_store.import_all(bytes).unwrap();

        // Forward scans of exporter and importer must match.
        let iter1 = store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .collect::<Vec<_>>();
        let iter2 = new_store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .collect::<Vec<_>>();
        assert_eq!(iter1, iter2);

        // Reverse scans must match too (DoubleEndedIterator path).
        let iter1 = store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .rev()
            .collect::<Vec<_>>();
        let iter2 = new_store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .rev()
            .collect::<Vec<_>>();
        assert_eq!(iter1, iter2);
    }

    // Values larger than the default 4 KiB block size, mixed with an
    // export_all() in the middle so reads span mem table and SSTable.
    #[test]
    fn test_large_block() {
        let mut store = new_store();
        let key = &[0];
        let value = Bytes::from_static(&[0]);

        let key2 = &[0, 1];
        let key3 = &[0, 1, 2];
        let large_value = Bytes::from_iter([0; 1024 * 8]);
        let large_value2 = Bytes::from_iter([0; 1024 * 8]);
        store.set(key, value.clone());
        store.set(key2, large_value.clone());
        let v2 = store.get(&[]);
        assert_eq!(v2, None);
        assert_eq!(store.get(key), Some(value.clone()));
        assert_eq!(store.get(key2), Some(large_value.clone()));
        // Flush to an SSTable, then write key3 into the fresh mem table.
        store.export_all();
        store.set(key3, large_value2.clone());
        assert_eq!(store.get(key3), Some(large_value2.clone()));
        assert_eq!(store.len(), 3);

        let iter = store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .collect::<Vec<_>>();
        assert_eq!(
            iter,
            vec![
                (Bytes::from_static(key), value.clone()),
                (Bytes::from_static(key2), large_value.clone()),
                (Bytes::from_static(key3), large_value2.clone())
            ]
        );

        // Inclusive bounds covering everything equal the unbounded scan.
        let iter2 = store
            .scan(
                std::ops::Bound::Included(key),
                std::ops::Bound::Included(key3),
            )
            .collect::<Vec<_>>();
        assert_eq!(iter, iter2);

        // Exclusive bounds drop both endpoints, leaving only key2.
        let iter3 = store
            .scan(
                std::ops::Bound::Excluded(key),
                std::ops::Bound::Excluded(key3),
            )
            .collect::<Vec<_>>();
        assert_eq!(iter3.len(), 1);
        assert_eq!(iter3[0], (Bytes::from_static(key2), large_value.clone()));

        let v = store.get(key2).unwrap();
        assert_eq!(v, large_value);

        let v2 = store.get(&[]);
        assert_eq!(v2, None);

        store.compare_and_swap(key, Some(value.clone()), large_value.clone());
        assert!(store.contains_key(key));
    }

    // A mem-table tombstone must shadow the same key stored in an SSTable.
    #[test]
    fn same_key() {
        let mut store = new_store();
        let key = &[0];
        let value = Bytes::from_static(&[0]);
        store.set(key, value.clone());
        store.export_all();
        store.set(key, Bytes::new());
        assert_eq!(store.get(key), None);
        let iter = store
            .scan(std::ops::Bound::Unbounded, std::ops::Bound::Unbounded)
            .collect::<Vec<_>>();
        assert_eq!(iter.len(), 0);
        store.set(key, value.clone());
        assert_eq!(store.get(key), Some(value));
    }

    // Later imports override earlier ones, including tombstones that delete
    // keys written by an earlier import.
    #[test]
    fn import_several_times() {
        dev_utils::setup_test_log();
        let a = Bytes::from_static(b"a");
        let b = Bytes::from_static(b"b");
        let c = Bytes::from_static(b"c");
        let d = Bytes::from_static(b"d");
        let e = Bytes::from_static(b"e");
        let mut store = new_store();
        store.set(&a, a.clone());
        store.export_all();
        store.set(&c, c.clone());
        let encode1 = store.export_all();
        let mut store2 = new_store();
        store2.set(&b, b.clone());
        store2.export_all();
        store2.set(&c, Bytes::new());
        let encode2 = store2.export_all();
        let mut store3 = new_store();
        store3.set(&d, d.clone());
        store3.set(&a, Bytes::new());
        tracing::info_span!("export da").in_scope(|| {
            store3.export_all();
        });
        store3.set(&e, e.clone());
        store3.set(&c, c.clone());
        let encode3 = tracing::info_span!("export ec").in_scope(|| store3.export_all());

        // Import order: encode1, then encode2, then encode3 (newest wins).
        let mut store = new_store();
        store.import_all(encode1).unwrap();
        store.import_all(encode2).unwrap();
        store.import_all(encode3).unwrap();
        // `a` deleted by encode3, `c` deleted by encode2 but re-added by encode3.
        assert_eq!(store.get(&a), None);
        assert_eq!(store.get(&b), Some(b.clone()));
        assert_eq!(store.get(&c), Some(c.clone()));
        assert_eq!(store.get(&d), Some(d.clone()));
        assert_eq!(store.get(&e), Some(e.clone()));
    }

    // Tests enable should_encode_none (the fuzzing configuration) so that
    // tombstones survive export/import.
    fn new_store() -> MemKvStore {
        MemKvStore::new(MemKvConfig::default().should_encode_none(true))
    }
}