// tempest-kv 0.0.2
//
// Key-Value storage layer for TempestDB
// Documentation
use std::{cell::RefCell, collections::BTreeMap, ops::Bound, rc::Rc};

use bytes::Bytes;
use tempest_io::Io;

use crate::{
    StorageError,
    base::{Comparer, InternalKey, KeyKind, KeyTrailer, SeqNum},
    iterator::StorageIterator,
};

/// Forward cursor over a shared [`Memtable`].
///
/// The cursor does not hold a borrow of the memtable between calls. Each
/// `next()` borrows the `RefCell` only long enough to look up the first entry
/// at or after `next_bound` and clone it out.
pub(crate) struct MemtableIterator<C: Comparer> {
    /// Shared handle to the memtable being iterated.
    memtable: Rc<RefCell<Memtable<C>>>,
    /// The lower bound for the next BTreeMap query.
    /// `Unbounded` on a fresh iterator, so the first `next()` starts at the smallest key.
    /// `Included(k)` after a seek positions next() to return k.
    /// `Excluded(k)` after a next() call advances past k.
    next_bound: Bound<InternalKey<C, Bytes>>,
    /// The last key returned, used for the forward-only seek check.
    /// `None` until `next()` has yielded at least one entry.
    last_key: Option<InternalKey<C, Bytes>>,
}

impl<C: Comparer> MemtableIterator<C> {
    /// Creates a cursor over `memtable` that has not yet returned any entry.
    ///
    /// The lower bound starts open, so the first `next()` yields the smallest
    /// key currently in the map.
    pub(crate) fn new(memtable: Rc<RefCell<Memtable<C>>>) -> Self {
        Self {
            // Nothing has been returned yet, so seeks are unrestricted.
            last_key: None,
            // Open lower bound: start from the very first entry.
            next_bound: Bound::Unbounded,
            memtable,
        }
    }
}

impl<I: Io, C: Comparer> StorageIterator<I, C> for MemtableIterator<C> {
    /// Returns the first entry at or after the current lower bound, or `None`
    /// when no such entry exists.
    ///
    /// On success the cursor moves just past the returned key and remembers it
    /// for the forward-only check in `seek`.
    async fn next(&mut self) -> Result<Option<(InternalKey<C, Bytes>, Bytes)>, StorageError> {
        let lower = self.next_bound.clone();

        // The `Ref` guard is a temporary of this statement, so the memtable is
        // only borrowed for the duration of the lookup; the result is owned.
        let found = self
            .memtable
            .borrow()
            .map
            .range((lower, Bound::Unbounded))
            .next()
            .map(|(key, value)| (key.clone(), value.clone()));

        match &found {
            Some((returned, _)) => {
                self.last_key = Some(returned.clone());
                self.next_bound = Bound::Excluded(returned.clone());
                trace!(
                    key.len = returned.key().len(),
                    seqnum = returned.trailer().seqnum().get(),
                    kind = ?returned.trailer().kind(),
                    "memtable_iter: next"
                );
            }
            None => {
                trace!("memtable_iter: exhausted");
            }
        }

        Ok(found)
    }

    /// Positions the cursor so that the following `next()` returns the first
    /// entry at or after `key`.
    ///
    /// A seek to a key that is logically at or before the last returned key is
    /// ignored, leaving the cursor where it was.
    async fn seek(&mut self, key: InternalKey<C, Bytes>) -> Result<(), StorageError> {
        trace!(key.len = key.key().len(), "memtable_iter: seek");

        let at_or_behind_cursor = matches!(
            &self.last_key,
            Some(last) if key.compare_logical(last).is_le()
        );
        if !at_or_behind_cursor {
            self.next_bound = Bound::Included(key);
        }

        Ok(())
    }
}

/// In-memory, ordered collection of internal keys and their values.
#[derive(Debug, Default)]
pub(crate) struct Memtable<C: Comparer> {
    // TODO: Replace BTreeMap with a Skiplist
    /// Entries, kept sorted by `InternalKey`'s ordering.
    map: BTreeMap<InternalKey<C>, Bytes>,
    /// Running byte estimate maintained by `insert`: key bytes plus value bytes
    /// plus a fixed 16-byte charge per entry.
    approximate_size: usize,
    /// Smallest seqnum passed to `insert` so far; `None` while nothing has been inserted.
    min_seqnum: Option<SeqNum>,
    /// Largest seqnum passed to `insert` so far; `None` while nothing has been inserted.
    max_seqnum: Option<SeqNum>,
    /// Represents the smallest filenum of the write-ahead logs that back this memtable.
    wal_filenum: u64,
    /// Set by `freeze`; `insert` and `freeze` both assert that it is still `false`.
    frozen: bool,
}

impl<C: Comparer> Memtable<C> {
    /// Fixed per-entry charge added to `approximate_size` on top of the key
    /// and value bytes (trailer + overhead).
    const ENTRY_OVERHEAD: usize = 16;

    /// Creates an empty memtable for `wal_filenum`, wrapped in `Rc<RefCell<_>>`
    /// so it can be shared with iterators.
    pub(crate) fn new_shared(wal_filenum: u64) -> Rc<RefCell<Self>> {
        Rc::new(RefCell::new(Self::new(wal_filenum)))
    }

    /// Creates an empty, unfrozen memtable that records `wal_filenum`.
    pub(crate) fn new(wal_filenum: u64) -> Self {
        Self {
            wal_filenum,
            ..Default::default()
        }
    }

    /// Inserts `value` under `key` and updates the size estimate and the
    /// seqnum range.
    ///
    /// If an entry with an equal internal key already exists, its value is
    /// replaced and the size estimate is adjusted so that the entry is only
    /// charged once.
    ///
    /// # Panics
    ///
    /// Panics if the memtable has been frozen.
    pub(crate) fn insert(&mut self, key: InternalKey<C>, value: Bytes) {
        assert!(!self.frozen, "cannot mutate frozen memtable");
        trace!(
            key_kind = ?key.trailer().kind(), key_len = key.key().len(),
            key=C::format(key.key().as_ref()), ?value, seqnum=?key.trailer().seqnum(),
            "inserting kv pair into memtable",
        );

        // Read everything needed from `key` before it is moved into the map.
        let key_len = key.key().len();
        let seqnum = key.trailer().seqnum();

        self.approximate_size += key_len + value.len() + Self::ENTRY_OVERHEAD;
        if let Some(replaced) = self.map.insert(key, value) {
            // `BTreeMap::insert` keeps the existing key and only swaps the
            // value, so the old entry's charge must be released; otherwise the
            // estimate drifts upward on every overwrite.
            let stale = key_len + replaced.len() + Self::ENTRY_OVERHEAD;
            self.approximate_size = self.approximate_size.saturating_sub(stale);
        }

        self.min_seqnum = Some(self.min_seqnum.map_or(seqnum, |s| s.min(seqnum)));
        self.max_seqnum = Some(self.max_seqnum.map_or(seqnum, |s| s.max(seqnum)));
    }

    /// Marks the memtable as frozen; any later `insert` panics.
    ///
    /// # Panics
    ///
    /// Panics if the memtable is already frozen.
    pub(crate) fn freeze(&mut self) {
        assert!(!self.frozen, "cannot mutate frozen memtable");
        self.frozen = true;
    }

    /// Looks up `key` as of `snapshot`.
    ///
    /// Searches for the first entry at or after the internal key
    /// `(key, snapshot, KeyKind::MAX)` and returns its kind and value when that
    /// entry is logically equal to the search key. The kind is returned as
    /// stored, so interpreting it is left to the caller.
    ///
    /// NOTE(review): which versions of `key` count as visible at `snapshot`
    /// depends on how `InternalKey` orders trailers — confirm in `base`.
    pub(crate) fn get(&self, key: &Bytes, snapshot: SeqNum) -> Option<(KeyKind, Bytes)> {
        let search_trailer = KeyTrailer::new(snapshot, KeyKind::MAX);
        let search_key = InternalKey::new(key.clone(), search_trailer);
        let compare_key = InternalKey::<C, &[u8]>::new(key.as_ref(), search_trailer);

        // An empty range means there is no candidate at all.
        let (found_key, found_value) = self.map.range(search_key..).next()?;

        found_key
            .compare_logical(&compare_key)
            .is_eq()
            .then(|| (found_key.trailer().kind(), found_value.clone()))
    }

    /// Returns the running byte estimate maintained by `insert`.
    pub(crate) const fn approximate_size(&self) -> usize {
        self.approximate_size
    }

    /// Returns the length i.e. the number of entries in this memtable
    pub(crate) fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns the smallest key in this memtable.
    pub(crate) fn min_key(&self) -> Option<&InternalKey<C>> {
        self.map.first_key_value().map(|(key, _)| key)
    }

    /// Returns the largest key in this memtable.
    pub(crate) fn max_key(&self) -> Option<&InternalKey<C>> {
        self.map.last_key_value().map(|(key, _)| key)
    }

    /// Returns the smallest seqnum in this memtable.
    pub(crate) const fn min_seqnum(&self) -> Option<SeqNum> {
        self.min_seqnum
    }

    /// Returns the largest seqnum in this memtable.
    pub(crate) const fn max_seqnum(&self) -> Option<SeqNum> {
        self.max_seqnum
    }

    /// Returns the write-ahead log filenum recorded at construction.
    pub(crate) const fn wal_filenum(&self) -> u64 {
        self.wal_filenum
    }
}

#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;
    use tempest_rt::block_on;

    use crate::base::DefaultComparer;

    use super::*;

    type TestIter = MemtableIterator<DefaultComparer>;

    /// Entries used by every test, listed in ascending key order.
    const FIXTURE: [(u64, &str); 3] = [(3, "three"), (5, "five"), (7, "seven")];

    /// Builds a shared memtable pre-populated with `FIXTURE`.
    fn make_memtable() -> Rc<RefCell<Memtable<DefaultComparer>>> {
        let shared = Memtable::new_shared(0);
        {
            // Scope the mutable borrow so it is released before `shared` is returned.
            let mut table = shared.borrow_mut();
            for (id, value) in FIXTURE {
                table.insert(InternalKey::test(id), value.into());
            }
        }
        shared
    }

    /// Calls `StorageIterator::next` with the I/O type pinned to `VirtualIo`.
    async fn next(iter: &mut TestIter) -> Option<(InternalKey<DefaultComparer, Bytes>, Bytes)> {
        StorageIterator::<VirtualIo, DefaultComparer>::next(iter)
            .await
            .unwrap()
    }

    /// Calls `StorageIterator::seek` for the test key built from `key_id`.
    async fn seek(iter: &mut TestIter, key_id: u64) {
        let target = InternalKey::test(key_id);
        StorageIterator::<VirtualIo, DefaultComparer>::seek(iter, target)
            .await
            .unwrap();
    }

    #[test]
    fn test_memtable_get() {
        let mut memtable = Memtable::<DefaultComparer>::new(0);
        for (id, value) in FIXTURE {
            memtable.insert(InternalKey::test(id), value.into());
        }

        for (id, expected) in FIXTURE {
            let lookup = InternalKey::<DefaultComparer>::test(id);
            let (_, value) = memtable.get(lookup.key(), SeqNum::TEST).unwrap();
            assert_eq!(value, expected);
        }
    }

    #[test]
    fn test_memtable_iterator() {
        block_on(VirtualIo::default(), async {
            let mut iter = TestIter::new(make_memtable());

            // Entries come back in key order, then the iterator is exhausted.
            for (id, value) in FIXTURE {
                let (k, v) = next(&mut iter).await.unwrap();
                assert_eq!(k.test_key_as_u64(), id);
                assert_eq!(v, value);
            }

            assert!(next(&mut iter).await.is_none());
        });
    }

    #[test]
    fn test_memtable_iterator_seek_forward() {
        block_on(VirtualIo::default(), async {
            let mut iter = TestIter::new(make_memtable());

            seek(&mut iter, 5).await;

            let (first, first_value) = next(&mut iter).await.unwrap();
            assert_eq!(first.test_key_as_u64(), 5);
            assert_eq!(first_value, "five");

            let (second, _) = next(&mut iter).await.unwrap();
            assert_eq!(second.test_key_as_u64(), 7);
        });
    }

    #[test]
    fn test_memtable_iterator_seek_backward_noop() {
        block_on(VirtualIo::default(), async {
            let mut iter = TestIter::new(make_memtable());

            seek(&mut iter, 5).await;
            let _consumed = next(&mut iter).await; // consume 5

            // Seeking behind the cursor must not rewind it: 7 is still next.
            seek(&mut iter, 3).await;
            let (after, _) = next(&mut iter).await.unwrap();
            assert_eq!(after.test_key_as_u64(), 7);
        });
    }

    #[test]
    fn test_memtable_iterator_seek_same_position_noop() {
        block_on(VirtualIo::default(), async {
            let mut iter = TestIter::new(make_memtable());

            seek(&mut iter, 5).await;
            let _consumed = next(&mut iter).await; // consume 5

            // Seeking to the key just returned must not replay it: 7 is still next.
            seek(&mut iter, 5).await;
            let (after, _) = next(&mut iter).await.unwrap();
            assert_eq!(after.test_key_as_u64(), 7);
        });
    }
}