use std::cmp::Ordering;
use std::collections::BTreeMap;
use std::fs;
use std::ops::{Bound, RangeBounds};
use std::path::{Path, PathBuf};
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::batch::{Batch, Op};
use crate::config::LsmConfig;
use crate::error::{Error, Result};
use crate::memtable::{MemTable, Record};
use crate::scan::Scan;
use crate::sstable::{SsTable, SsTableWriter};
/// File name, inside the database directory, of the single on-disk sorted run.
const SSTABLE_FILE: &str = "data.sst";
/// File name a flush writes its output to before renaming it over
/// `SSTABLE_FILE`. A leftover copy is removed when the database is opened.
const SSTABLE_TMP: &str = "data.sst.tmp";
/// Mutable engine state; `Lsm` keeps it behind a single `RwLock`.
#[derive(Debug)]
struct Inner {
/// In-memory buffer that receives every put and delete before it is flushed.
memtable: MemTable,
/// The on-disk run, if one is currently open; `None` when no run file was
/// found at open time and nothing has been flushed since.
sstable: Option<SsTable>,
}
/// Handle to a database directory holding one in-memory buffer and at most
/// one on-disk run. All methods take `&self`; state is guarded by a lock.
#[derive(Debug)]
pub struct Lsm {
/// Directory that contains the run file and the temporary flush file.
dir: PathBuf,
/// Tunables; this file only reads `memtable_capacity_bytes()` from it.
config: LsmConfig,
/// Buffer and run, shared between readers and writers.
inner: RwLock<Inner>,
}
impl Lsm {
pub fn open(dir: impl AsRef<Path>) -> Result<Self> {
Self::open_with(dir, LsmConfig::default())
}
pub fn open_with(dir: impl AsRef<Path>, config: LsmConfig) -> Result<Self> {
let dir = dir.as_ref().to_path_buf();
fs::create_dir_all(&dir).map_err(|e| Error::io("create database directory", e))?;
let tmp = dir.join(SSTABLE_TMP);
if tmp.exists() {
fs::remove_file(&tmp).map_err(|e| Error::io("remove stale temporary run", e))?;
}
let run_path = dir.join(SSTABLE_FILE);
let sstable = if run_path.exists() {
Some(SsTable::open(&run_path)?)
} else {
None
};
Ok(Lsm {
dir,
config,
inner: RwLock::new(Inner {
memtable: MemTable::new(),
sstable,
}),
})
}
pub fn put(&self, key: impl AsRef<[u8]>, value: impl AsRef<[u8]>) -> Result<()> {
let mut inner = self.write_guard();
inner
.memtable
.put(key.as_ref().to_vec(), value.as_ref().to_vec());
self.maybe_flush(&mut inner)
}
pub fn delete(&self, key: impl AsRef<[u8]>) -> Result<()> {
let mut inner = self.write_guard();
inner.memtable.delete(key.as_ref().to_vec());
self.maybe_flush(&mut inner)
}
pub fn write(&self, batch: Batch) -> Result<()> {
let mut inner = self.write_guard();
for (key, op) in batch.into_ops() {
match op {
Op::Put(value) => inner.memtable.put(key, value),
Op::Delete => inner.memtable.delete(key),
}
}
self.maybe_flush(&mut inner)
}
pub fn get(&self, key: impl AsRef<[u8]>) -> Result<Option<Vec<u8>>> {
let key = key.as_ref();
let inner = self.read_guard();
match inner.memtable.get(key) {
Some(Record::Value(value)) => Ok(Some(value.clone())),
Some(Record::Tombstone) => Ok(None),
None => match inner.sstable.as_ref() {
Some(table) => table.get(key),
None => Ok(None),
},
}
}
pub fn scan<R>(&self, range: R) -> Result<Scan>
where
R: RangeBounds<Vec<u8>>,
{
let inner = self.read_guard();
let entries = collect_range(&inner, &range)?;
Ok(Scan::new(entries))
}
pub fn flush(&self) -> Result<()> {
let mut inner = self.write_guard();
if inner.memtable.is_empty() {
return Ok(());
}
self.flush_locked(&mut inner)
}
fn maybe_flush(&self, inner: &mut Inner) -> Result<()> {
if inner.memtable.approx_size() >= self.config.memtable_capacity_bytes()
&& !inner.memtable.is_empty()
{
self.flush_locked(inner)?;
}
Ok(())
}
fn flush_locked(&self, inner: &mut Inner) -> Result<()> {
let memtable = inner.memtable.take();
let tmp = self.dir.join(SSTABLE_TMP);
let run_path = self.dir.join(SSTABLE_FILE);
let mut writer = SsTableWriter::create(&tmp)?;
merge_into(&mut writer, &memtable, inner.sstable.as_ref())?;
writer.finish()?;
inner.sstable = None;
fs::rename(&tmp, &run_path).map_err(|e| Error::io("install flushed run", e))?;
inner.sstable = Some(SsTable::open(&run_path)?);
Ok(())
}
fn read_guard(&self) -> RwLockReadGuard<'_, Inner> {
self.inner
.read()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn write_guard(&self) -> RwLockWriteGuard<'_, Inner> {
self.inner
.write()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
}
/// Returns the key stored at position `si` of the run, or `None` when there
/// is no run or `si` is not below `len` (the caller's cached entry count).
fn run_key_at(sstable: Option<&SsTable>, si: usize, len: usize) -> Option<&[u8]> {
    let table = sstable?;
    if si >= len {
        return None;
    }
    Some(table.key_at(si))
}
/// Streams the sorted union of `memtable` and the optional run into `writer`.
///
/// Both inputs are walked in ascending key order. When a key exists on both
/// sides the buffered record wins and the run entry is skipped. Buffered
/// tombstones are consumed without being written, so a deleted key is absent
/// from the output.
///
/// # Errors
///
/// Propagates any error from reading the run or pushing to `writer`.
fn merge_into(
    writer: &mut SsTableWriter,
    memtable: &BTreeMap<Vec<u8>, Record>,
    sstable: Option<&SsTable>,
) -> Result<()> {
    let run_len = sstable.map_or(0, SsTable::len);
    let mut pending = memtable.iter().peekable();
    let mut cursor = 0usize;
    loop {
        // `Less`: the buffer holds the smaller key, `Greater`: the run does,
        // `Equal`: both sides are positioned on the same key.
        let order = match (pending.peek(), run_key_at(sstable, cursor, run_len)) {
            (None, None) => return Ok(()),
            (Some(_), None) => Ordering::Less,
            (None, Some(_)) => Ordering::Greater,
            (Some((buffered_key, _)), Some(run_key)) => buffered_key.as_slice().cmp(run_key),
        };
        match order {
            Ordering::Less => {
                // The record is consumed either way; only values are written.
                if let Some((key, Record::Value(value))) = pending.next() {
                    writer.push(key, value)?;
                }
            }
            Ordering::Equal => {
                if let Some((key, Record::Value(value))) = pending.next() {
                    writer.push(key, value)?;
                }
                // The run entry is shadowed by the buffered record.
                cursor += 1;
            }
            Ordering::Greater => {
                if let Some(table) = sstable {
                    let value = table.read_value(cursor)?;
                    writer.push(table.key_at(cursor), &value)?;
                    cursor += 1;
                }
            }
        }
    }
}
/// Collects the live `(key, value)` pairs inside `range`, in ascending key
/// order, by merging the buffer with the optional on-disk run.
///
/// A buffered record shadows the run entry with the same key, and a buffered
/// tombstone hides the key altogether. Run values are only read for keys that
/// are inside `range` and not shadowed.
///
/// # Errors
///
/// Propagates any error from reading a value out of the run.
fn collect_range<R>(inner: &Inner, range: &R) -> Result<Vec<(Vec<u8>, Vec<u8>)>>
where
    R: RangeBounds<Vec<u8>>,
{
    let run = inner.sstable.as_ref();
    let run_len = run.map_or(0, SsTable::len);
    let mut pending = inner.memtable.iter().peekable();
    let mut cursor = 0usize;
    let mut found = Vec::new();
    loop {
        // `Less`: the buffer holds the smaller key, `Greater`: the run does,
        // `Equal`: both sides are positioned on the same key.
        let order = match (pending.peek(), run_key_at(run, cursor, run_len)) {
            (None, None) => break,
            (Some(_), None) => Ordering::Less,
            (None, Some(_)) => Ordering::Greater,
            (Some((buffered_key, _)), Some(run_key)) => buffered_key.as_slice().cmp(run_key),
        };
        match order {
            Ordering::Less | Ordering::Equal => {
                // The record is consumed either way; tombstones yield nothing.
                if let Some((key, Record::Value(value))) = pending.next() {
                    if in_range(range, key) {
                        found.push((key.clone(), value.clone()));
                    }
                }
                if order == Ordering::Equal {
                    // Skip the run entry shadowed by the buffered record.
                    cursor += 1;
                }
            }
            Ordering::Greater => {
                if let Some(table) = run {
                    let key = table.key_at(cursor);
                    if in_range(range, key) {
                        let value = table.read_value(cursor)?;
                        found.push((key.to_vec(), value));
                    }
                    cursor += 1;
                }
            }
        }
    }
    Ok(found)
}
/// Reports whether `key` lies within `range`, comparing keys as byte strings
/// and honouring inclusive, exclusive and unbounded ends on either side.
fn in_range<R: RangeBounds<Vec<u8>>>(range: &R, key: &[u8]) -> bool {
    let lower_ok = match range.start_bound() {
        Bound::Unbounded => true,
        Bound::Included(low) => key.cmp(low.as_slice()).is_ge(),
        Bound::Excluded(low) => key.cmp(low.as_slice()).is_gt(),
    };
    let upper_ok = match range.end_bound() {
        Bound::Unbounded => true,
        Bound::Included(high) => key.cmp(high.as_slice()).is_le(),
        Bound::Excluded(high) => key.cmp(high.as_slice()).is_lt(),
    };
    lower_ok && upper_ok
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
fn db() -> (tempfile::TempDir, Lsm) {
let dir = tempfile::tempdir().unwrap();
let db = Lsm::open(dir.path()).unwrap();
(dir, db)
}
#[test]
fn test_put_get_roundtrip() {
let (_d, db) = db();
db.put(b"k", b"v").unwrap();
assert_eq!(db.get(b"k").unwrap(), Some(b"v".to_vec()));
}
#[test]
fn test_get_absent_is_none() {
let (_d, db) = db();
assert_eq!(db.get(b"absent").unwrap(), None);
}
#[test]
fn test_overwrite_returns_latest() {
let (_d, db) = db();
db.put(b"k", b"old").unwrap();
db.put(b"k", b"new").unwrap();
assert_eq!(db.get(b"k").unwrap(), Some(b"new".to_vec()));
}
#[test]
fn test_delete_masks_value() {
let (_d, db) = db();
db.put(b"k", b"v").unwrap();
db.delete(b"k").unwrap();
assert_eq!(db.get(b"k").unwrap(), None);
}
#[test]
fn test_delete_then_put_revives_key() {
let (_d, db) = db();
db.put(b"k", b"v1").unwrap();
db.delete(b"k").unwrap();
db.put(b"k", b"v2").unwrap();
assert_eq!(db.get(b"k").unwrap(), Some(b"v2".to_vec()));
}
#[test]
fn test_flush_then_read_from_run() {
let (_d, db) = db();
db.put(b"k", b"v").unwrap();
db.flush().unwrap();
assert_eq!(db.get(b"k").unwrap(), Some(b"v".to_vec()));
}
#[test]
fn test_buffer_shadows_run_value() {
let (_d, db) = db();
db.put(b"k", b"on-disk").unwrap();
db.flush().unwrap();
db.put(b"k", b"in-memory").unwrap();
assert_eq!(db.get(b"k").unwrap(), Some(b"in-memory".to_vec()));
}
#[test]
fn test_delete_after_flush_masks_run_value() {
let (_d, db) = db();
db.put(b"k", b"v").unwrap();
db.flush().unwrap();
db.delete(b"k").unwrap();
assert_eq!(db.get(b"k").unwrap(), None);
db.flush().unwrap();
assert_eq!(db.get(b"k").unwrap(), None);
}
#[test]
fn test_reopen_reads_flushed_keys() {
let dir = tempfile::tempdir().unwrap();
{
let db = Lsm::open(dir.path()).unwrap();
db.put(b"a", b"1").unwrap();
db.put(b"b", b"2").unwrap();
db.flush().unwrap();
}
let db = Lsm::open(dir.path()).unwrap();
assert_eq!(db.get(b"a").unwrap(), Some(b"1".to_vec()));
assert_eq!(db.get(b"b").unwrap(), Some(b"2".to_vec()));
}
#[test]
fn test_auto_flush_at_capacity() {
let dir = tempfile::tempdir().unwrap();
let db = Lsm::open_with(dir.path(), LsmConfig::new().memtable_capacity(0)).unwrap();
db.put(b"k", b"v").unwrap(); assert!(dir.path().join(SSTABLE_FILE).exists());
assert_eq!(db.get(b"k").unwrap(), Some(b"v".to_vec()));
}
#[test]
fn test_scan_full_range() {
let (_d, db) = db();
db.put(b"c", b"3").unwrap();
db.put(b"a", b"1").unwrap();
db.put(b"b", b"2").unwrap();
let got: Vec<_> = db.scan(..).unwrap().collect();
assert_eq!(
got,
vec![
(b"a".to_vec(), b"1".to_vec()),
(b"b".to_vec(), b"2".to_vec()),
(b"c".to_vec(), b"3".to_vec()),
]
);
}
#[test]
fn test_scan_half_open_range() {
let (_d, db) = db();
for (k, v) in [("a", "1"), ("b", "2"), ("c", "3"), ("d", "4")] {
db.put(k.as_bytes(), v.as_bytes()).unwrap();
}
let got: Vec<_> = db.scan(b"b".to_vec()..b"d".to_vec()).unwrap().collect();
assert_eq!(
got,
vec![
(b"b".to_vec(), b"2".to_vec()),
(b"c".to_vec(), b"3".to_vec())
]
);
}
#[test]
fn test_scan_inclusive_range() {
let (_d, db) = db();
for (k, v) in [("a", "1"), ("b", "2"), ("c", "3")] {
db.put(k.as_bytes(), v.as_bytes()).unwrap();
}
let got: Vec<_> = db.scan(b"a".to_vec()..=b"b".to_vec()).unwrap().collect();
assert_eq!(
got,
vec![
(b"a".to_vec(), b"1".to_vec()),
(b"b".to_vec(), b"2".to_vec())
]
);
}
#[test]
fn test_scan_merges_buffer_and_run() {
let (_d, db) = db();
db.put(b"a", b"old-a").unwrap();
db.put(b"c", b"3").unwrap();
db.flush().unwrap(); db.put(b"a", b"new-a").unwrap(); db.put(b"b", b"2").unwrap();
db.delete(b"c").unwrap(); let got: Vec<_> = db.scan(..).unwrap().collect();
assert_eq!(
got,
vec![
(b"a".to_vec(), b"new-a".to_vec()),
(b"b".to_vec(), b"2".to_vec())
]
);
}
#[test]
fn test_batch_applies_all() {
let (_d, db) = db();
db.put(b"c", b"keep").unwrap();
let mut batch = Batch::new();
batch.put(b"a", b"1");
batch.put(b"b", b"2");
batch.delete(b"c");
db.write(batch).unwrap();
assert_eq!(db.get(b"a").unwrap(), Some(b"1".to_vec()));
assert_eq!(db.get(b"b").unwrap(), Some(b"2".to_vec()));
assert_eq!(db.get(b"c").unwrap(), None);
}
#[test]
fn test_empty_value_roundtrips() {
let (_d, db) = db();
db.put(b"k", b"").unwrap();
assert_eq!(db.get(b"k").unwrap(), Some(Vec::new()));
db.flush().unwrap();
assert_eq!(db.get(b"k").unwrap(), Some(Vec::new()));
}
#[test]
fn test_stale_tmp_is_discarded_on_open() {
let dir = tempfile::tempdir().unwrap();
{
let db = Lsm::open(dir.path()).unwrap();
db.put(b"k", b"v").unwrap();
db.flush().unwrap();
}
std::fs::write(dir.path().join(SSTABLE_TMP), b"garbage").unwrap();
let db = Lsm::open(dir.path()).unwrap();
assert_eq!(db.get(b"k").unwrap(), Some(b"v".to_vec()));
assert!(!dir.path().join(SSTABLE_TMP).exists());
}
#[test]
fn test_engine_is_send_and_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<Lsm>();
}
}