//! tempest-kv 0.0.2
//!
//! Key-value storage layer for TempestDB.
#![allow(unused)]

#[macro_use]
extern crate tracing;

mod iterator;
mod manifest;
mod memtable;
mod sst;
mod wal;

pub mod base;
pub mod batch;
pub mod commit;
pub mod config;
pub mod error;
pub mod migration;
pub mod strategy;

#[cfg(test)]
mod tests;

use bytes::Bytes;
pub use error::*;
pub use strategy::*;

use std::{cell::RefCell, marker::PhantomData, ops::Bound, path::PathBuf, rc::Rc};

use futures::{FutureExt, pin_mut, select};
use tempest_io::Io;
use tempest_rt::{
    JoinHandle, spawn,
    sync::{
        mpsc::{self, Receiver},
        oneshot::{self, TryRecvError},
    },
};
use tracing::Instrument;

use crate::{
    base::{InternalKey, KeyKind, KeyTrailer, SeqNum},
    batch::WriteBatch,
    config::StorageConfig,
    iterator::{LogicalDedupIterator, SnapshotIterator, StorageIterator},
    memtable::{Memtable, MemtableIterator},
    wal::Wal,
};

///// Options that control the behavior of a single read operation.
/////
///// Passed to [`Storage::scan`] and [`Storage::get`] to inject per-read
///// strategy without changing the storage's own configuration. Currently
///// carries a [`Filter`] that gates entry visibility, but is designed as
///// an extension point - future fields like scan limits or read direction
///// can be added without breaking existing callsites.
//#[derive(Default)]
//pub struct ReadOpts<S: StorageStrategy> {
//    ///// The filter applied to every entry on this read's iterator stack.
//    ///// Entries where [`Filter::keep`] returns `false` are skipped entirely.
//    //pub filter: S::Filter,
//    _marker: PhantomData<S>,
//}

/// A request asking the storage worker to apply a [`WriteBatch`].
///
/// Sent by [`StorageHandle::write`]; the worker reports the outcome on `tx`.
#[derive(Debug)]
pub struct WriteMessage {
    /// The batch to commit.
    batch: WriteBatch,
    /// One-shot reply channel for the result of the write.
    tx: oneshot::Sender<Result<(), StorageError>>,
}

/// A point-lookup request for a single user key.
///
/// Sent by [`StorageHandle::get`]; the worker replies on `tx` with
/// `Ok(Some(value))` for a live key and `Ok(None)` when the key is absent or
/// deleted.
pub struct GetMessage {
    /// The user key to look up.
    key: Bytes,
    /// One-shot reply channel for the lookup result.
    tx: oneshot::Sender<Result<Option<Bytes>, StorageError>>,
}

/// A range-scan request over user keys.
///
/// Sent by [`StorageHandle::scan`]; the worker replies on `tx` with the
/// receiving half of a channel on which the scanned key/value pairs are
/// streamed.
pub struct ScanMessage {
    /// Lower bound of the scanned user-key range.
    start: Bound<Bytes>,
    /// Upper bound of the scanned user-key range.
    end: Bound<Bytes>,
    /// One-shot reply channel carrying the result stream's receiver.
    tx: oneshot::Sender<mpsc::Receiver<Result<(Bytes, Bytes), StorageError>>>,
}

/// Client-side handle to a running storage worker.
///
/// Each operation is forwarded to the worker over its own bounded channel
/// (created in [`Storage::init`]).
pub struct StorageHandle {
    /// Channel for [`WriteMessage`] requests.
    write_tx: mpsc::BoundedSender<WriteMessage>,
    /// Channel for [`GetMessage`] requests.
    get_tx: mpsc::BoundedSender<GetMessage>,
    /// Channel for [`ScanMessage`] requests.
    scan_tx: mpsc::BoundedSender<ScanMessage>,
}

impl StorageHandle {
    /// Collapses any channel failure into [`StorageError::WorkerDied`].
    ///
    /// Both a failed send and a dropped reply sender mean the worker is gone.
    fn worker_died<E>(_cause: E) -> StorageError {
        StorageError::WorkerDied
    }

    /// Submits `batch` to the storage worker and waits until it is applied.
    ///
    /// # Errors
    ///
    /// [`StorageError::WorkerDied`] if the request cannot be delivered or the
    /// worker drops the reply channel; otherwise whatever the worker reports.
    pub async fn write(&self, batch: WriteBatch) -> Result<(), StorageError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = WriteMessage {
            batch,
            tx: reply_tx,
        };

        self.write_tx
            .clone() // PERF: rc clone of sender (cheap and required)
            .send(request)
            .await
            .map_err(Self::worker_died)?;

        // Outer `?` handles a vanished worker; the inner result is the
        // worker's verdict on the write itself.
        reply_rx.recv().await.map_err(Self::worker_died)?
    }

    /// Looks up `key` and returns its value, or `None` if absent or deleted.
    ///
    /// # Errors
    ///
    /// [`StorageError::WorkerDied`] if the request cannot be delivered or the
    /// worker drops the reply channel; otherwise whatever the worker reports.
    pub async fn get(&self, key: Bytes) -> Result<Option<Bytes>, StorageError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = GetMessage { key, tx: reply_tx };

        self.get_tx
            .clone() // PERF: rc clone of sender (cheap and required)
            .send(request)
            .await
            .map_err(Self::worker_died)?;

        reply_rx.recv().await.map_err(Self::worker_died)?
    }

    /// Starts a scan over `start..end` and returns the receiver of its results.
    ///
    /// Each received item is either a key/value pair or an error raised while
    /// scanning; the stream ends when the channel closes.
    ///
    /// # Errors
    ///
    /// [`StorageError::WorkerDied`] if the request cannot be delivered or the
    /// worker drops the reply channel before handing back the receiver.
    pub async fn scan(
        &self,
        start: Bound<Bytes>,
        end: Bound<Bytes>,
    ) -> Result<Receiver<Result<(Bytes, Bytes), StorageError>>, StorageError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = ScanMessage {
            start,
            end,
            tx: reply_tx,
        };

        self.scan_tx
            .clone()
            .send(request)
            .await
            .map_err(Self::worker_died)?;

        let results = reply_rx.recv().await.map_err(Self::worker_died)?;
        Ok(results)
    }
}

/// State owned by the storage worker task.
///
/// Constructed and driven by [`Storage::init`]; clients talk to it only
/// through a [`StorageHandle`].
pub struct Storage<I: Io, S: StorageStrategy> {
    /// Memtable that receives every write; shared (via `Rc`) with scan tasks.
    active: Rc<RefCell<Memtable<S::Comparer>>>,
    /// Memtables consulted by point lookups after `active`.
    /// Nothing in this file pushes into it yet.
    immutables: Vec<Rc<RefCell<Memtable<S::Comparer>>>>,

    //wal: Wal<I>,
    /// Sequence number to hand out to the next write batch.
    next_seqnum: u64,

    /// Configuration supplied to `init`.
    config: StorageConfig,
    /// Keeps `I` and `S` as type parameters without storing values of them.
    _marker: PhantomData<fn() -> (I, S)>,
}

impl<I: Io, S: StorageStrategy> Storage<I, S> {
    fn get_seqnum(&mut self) -> SeqNum {
        let seqnum = self.next_seqnum;
        self.next_seqnum += 1;
        SeqNum::new(seqnum).expect("seqnum overflow")
    }

    fn snapshot(&mut self) -> SeqNum {
        // SAFETY: the least this will be is SeqNum::MIN (START - 1), which is still valid
        unsafe { SeqNum::new_unchecked(self.next_seqnum - 1) }
    }

    async fn handle_write(&mut self, msg: WriteMessage) {
        let mut batch = msg.batch;
        let seqnum = self.get_seqnum();
        batch.commit(seqnum);

        let _guard = debug_span!(
            "storage.write",
            batch.count = batch.count(),
            seqnum = seqnum.get(),
        )
        .entered();

        // TODO: commit to WAL

        let mut memtable = self.active.borrow_mut();
        for (key, trailer, value) in batch.into_iter() {
            memtable.insert(InternalKey::new(key, trailer), value);
        }
        debug!("write committed to memtable");

        let _ = msg.tx.send(Ok(()));
    }

    async fn handle_get(&mut self, msg: GetMessage) {
        let snapshot = self.snapshot();
        let _guard = debug_span!(
            "storage.get",
            key.len = msg.key.len(),
            snapshot = snapshot.get()
        )
        .entered();

        macro_rules! respond {
            ($kind:expr, $value:expr) => {
                // receiver closed; probably does not matter here
                let _ = match $kind {
                    KeyKind::Delete => {
                        debug!(found = true, kind = "Delete");
                        msg.tx.send(Ok(None))
                    }
                    KeyKind::Put => {
                        debug!(found = true, kind = "Put");
                        msg.tx.send(Ok(Some($value)))
                    }
                };
            };
            (None) => {
                debug!(found = false);
                let _ = msg.tx.send(Ok(None));
            };
        }

        // -- check memtables --
        if let Some((kind, value)) = self.active.borrow().get(&msg.key, snapshot) {
            respond!(kind, value);
            return;
        }

        if let Some((kind, value)) = self
            .immutables
            .iter()
            .find_map(|imm| imm.borrow().get(&msg.key, snapshot))
        {
            respond!(kind, value);
            return;
        }

        respond!(None);
    }

    async fn handle_scan(&mut self, msg: ScanMessage) {
        let active = self.active.clone();
        let snapshot = self.snapshot();
        spawn(async move {
            async move {
                let (mut tx, rx) = mpsc::bounded(64);
                if let Err(_) = msg.tx.send(rx) {
                    debug!("scan aborted: caller dropped receiver before channel was sent");
                    return;
                }

                // -- construct source iterator --
                let source = MemtableIterator::new(active);
                // TODO: source from immutable memtables and ssts

                // TODO: apply filter iterator based on filter context in msg to allow for engine MVCC
                let mut iter = LogicalDedupIterator::new(
                    SnapshotIterator::<I, S::Comparer, _>::new(source, snapshot),
                );

                if let Err(err) = match msg.start {
                    Bound::Included(seek) => {
                        iter.seek(InternalKey::new(
                            seek,
                            KeyTrailer::new(snapshot, KeyKind::MAX),
                        ))
                        .await
                    }
                    Bound::Excluded(seek) => Ok(()),
                    Bound::Unbounded => Ok(()),
                } {
                    let _ = tx.send(Err(err)).await;
                    return;
                }

                loop {
                    match iter.next().await {
                        Ok(Some((key, value))) => {
                            if match &msg.end {
                                Bound::Included(end) => key.key() > end,
                                Bound::Excluded(end) => key.key() >= end,
                                Bound::Unbounded => false,
                            } {
                                break;
                            }
                            match key.trailer().kind() {
                                KeyKind::Put => {
                                    if let Err(_) = tx.send(Ok((key.key().clone(), value))).await {
                                        break;
                                    }
                                }
                                KeyKind::Delete => {}
                            }
                        }
                        Ok(None) => break,
                        Err(err) => {
                            let _ = tx.send(Err(err)).await;
                            break;
                        }
                    }
                }
                debug!("scan complete");
                // NB: dropping tx will close the channel i.e. EOF
            }
            .instrument(debug_span!("storage.scan", snapshot = snapshot.get()))
            .await
        });
        debug!("scan dispatched");
    }

    async fn run(
        &mut self,
        mut write_rx: mpsc::Receiver<WriteMessage>,
        mut get_rx: mpsc::Receiver<GetMessage>,
        mut scan_rx: mpsc::Receiver<ScanMessage>,
    ) {
        loop {
            let mut write_fut = write_rx.recv().fuse();
            let mut get_fut = get_rx.recv().fuse();
            let mut scan_fut = scan_rx.recv().fuse();
            pin_mut!(write_fut, get_fut, scan_fut);

            select! {
                res = write_fut => match res {
                    Ok(msg) => self.handle_write(msg).await,
                    Err(_) => break, // channel closed
                },
                res = get_fut => match res {
                    Ok(msg) => self.handle_get(msg).await,
                    Err(_) => break, // channel closed
                },
                res = scan_fut => match res {
                    Ok(msg) => self.handle_scan(msg).await,
                    Err(_) => break, // channel closed
                },
            }

            // TODO: drive WAL? => make WAL a background task?
        }
    }

    pub fn init(dir: PathBuf, config: StorageConfig) -> (StorageHandle, JoinHandle<()>) {
        let (write_tx, write_rx) = mpsc::bounded(256);
        let (get_tx, get_rx) = mpsc::bounded(256);
        let (scan_tx, scan_rx) = mpsc::bounded(256);

        let join_handle = spawn(async move {
            let dir_display = dir.display().to_string();
            async move {
                info!("storage worker starting");

                // -- initialize wal --
                //let wal = match Wal::init(dir.join("wal"), config.wal.clone(), |r| {
                //    todo!("apply wal recovery: {r:?}")
                //})
                //.await
                //{
                //    Ok(wal) => {
                //        info!("WAL initialized");
                //        wal
                //    }
                //    Err(err) => {
                //        error!(?err, "failed to initialize WAL");
                //        return;
                //    }
                //};

                Storage::<I, S> {
                    active: Memtable::new_shared(0),
                    immutables: Vec::new(),

                    //wal,
                    next_seqnum: SeqNum::START.get(),

                    config,
                    _marker: PhantomData,
                }
                .run(write_rx, get_rx, scan_rx)
                .await;

                info!("storage worker exiting");
            }
            .instrument(info_span!("storage.worker", dir = dir_display))
            .await
        });

        let storage_handle = StorageHandle {
            write_tx,
            get_tx,
            scan_tx,
        };

        (storage_handle, join_handle)
    }
}