mod bloom_filter;
pub mod constants;
mod memtable;
mod sstable;
mod wal;
use crate::constants::{
BLOCK_CACHE_CAPACITY, BLOOM_FILTER_FPR, L0_COMPACTION_TRIGGER, L1_MAX_BYTES,
LEVEL_SIZE_MULTIPLIER, MAX_LEVELS, MEMTABLE_CAPACITY_BYTES,
};
use crate::memtable::MemTable;
use crate::sstable::{Manifest, SSTableBuilder, SSTableReader, VersionEdit, compaction::compact};
use crate::wal::Wal;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
// Shared LRU cache for SSTable data blocks. Keyed by a (u64, u64) pair —
// presumably (sst_id, block offset); confirm against `SSTableReader::get`.
pub type BlockCache = Arc<RwLock<lru::LruCache<(u64, u64), Arc<Vec<u8>>>>>;
/// An LSM-tree storage engine: writes are appended to a WAL and applied to
/// an in-memory memtable; full memtables are flushed to on-disk SSTables,
/// which are merged into deeper levels by background compaction.
pub struct StorageEngine {
    // Memtable currently receiving writes.
    active_memtable: Arc<Mutex<MemTable>>,
    // Memtable sealed for flushing; `None` when no flush is pending.
    immutable_memtable: Arc<Mutex<Option<Arc<MemTable>>>>,
    // Write-ahead log providing crash recovery for the active memtable.
    wal: Arc<Mutex<Wal>>,
    // SSTable readers grouped by level; index 0 is level 0, whose tables
    // are kept newest-first because they may overlap.
    sstables: Arc<RwLock<Vec<Vec<SSTableReader>>>>,
    // MANIFEST log recording which SSTable belongs to which level.
    manifest: Arc<RwLock<Manifest>>,
    // Byte threshold at which the active memtable is sealed for flushing.
    memtable_capacity: usize,
    // Monotonically increasing sequence number assigned to every write.
    next_seq_num: Arc<AtomicU64>,
    // Database root containing the `wal/` and `sst/` dirs and MANIFEST.
    db_path: Arc<PathBuf>,
    // Shared LRU block cache passed to SSTable reads.
    block_cache: BlockCache,
    // (flag, condvar) pair serializing background flushes; the bool is
    // true while a flush is in flight.
    flush_condvar: Arc<(Mutex<bool>, Condvar)>,
}
impl StorageEngine {
    /// Opens (or creates) the database rooted at `path`.
    ///
    /// Recovery replays any surviving WAL records into a fresh memtable
    /// (tracking the largest sequence number so new writes continue after
    /// it) and rebuilds the per-level SSTable reader lists from the
    /// MANIFEST.
    ///
    /// # Errors
    /// Fails if the directory layout cannot be created or the WAL /
    /// MANIFEST cannot be opened.
    pub fn open(path: impl Into<PathBuf>) -> Result<Self, anyhow::Error> {
        let db_path = path.into();
        std::fs::create_dir_all(&db_path)?;
        let wal_dir = db_path.join("wal");
        let sst_dir = db_path.join("sst");
        std::fs::create_dir_all(&wal_dir)?;
        std::fs::create_dir_all(&sst_dir)?;
        let mut wal = Wal::new(wal_dir)?;
        let memtable_capacity = MEMTABLE_CAPACITY_BYTES;
        let mut memtable = MemTable::new(memtable_capacity, BLOOM_FILTER_FPR);
        let mut max_seq = 0;
        // Best-effort replay: a recovery error (e.g. a torn tail record)
        // is treated as "nothing to replay" rather than refusing to open.
        if let Ok(records) = wal.recover() {
            for record in records {
                max_seq = max_seq.max(record.seq_num);
                memtable.set(record.key, record.val);
            }
        }
        let manifest_path = db_path.join("MANIFEST");
        // `recover` yields the live SSTable ids per level; `open` then
        // readies the manifest for appending new version edits.
        let active_ssts = Manifest::recover(&manifest_path)?;
        let manifest = Manifest::open(&manifest_path)?;
        let mut sstables: Vec<Vec<SSTableReader>> = Vec::new();
        for (level, sst_ids) in active_ssts.iter().enumerate() {
            let mut level_readers = Vec::new();
            for sst_id in sst_ids {
                let path = sst_dir.join(format!("{}.sst", sst_id));
                // Skip tables the manifest mentions but that are missing on
                // disk (e.g. a crash between manifest append and file sync).
                if path.exists() {
                    level_readers.push(SSTableReader::new(path));
                }
            }
            if level == 0 {
                // Level-0 tables may overlap; keep them newest-first (ids
                // descending) so `get` sees the latest version of a key.
                level_readers.sort_by(|a, b| b.id.cmp(&a.id));
            }
            sstables.push(level_readers);
        }
        if sstables.is_empty() {
            sstables.push(Vec::new());
        }
        Ok(Self {
            active_memtable: Arc::new(Mutex::new(memtable)),
            immutable_memtable: Arc::new(Mutex::new(None)),
            wal: Arc::new(Mutex::new(wal)),
            sstables: Arc::new(RwLock::new(sstables)),
            manifest: Arc::new(RwLock::new(manifest)),
            memtable_capacity,
            next_seq_num: Arc::new(AtomicU64::new(max_seq + 1)),
            db_path: Arc::new(db_path),
            block_cache: Arc::new(RwLock::new(lru::LruCache::new(
                std::num::NonZeroUsize::new(BLOCK_CACHE_CAPACITY).unwrap(),
            ))),
            flush_condvar: Arc::new((Mutex::new(false), Condvar::new())),
        })
    }
    /// Inserts or overwrites `key` with `value`.
    ///
    /// The write is appended to the WAL first (durability), then applied
    /// to the active memtable; if the memtable exceeds its capacity a
    /// background flush is triggered.
    ///
    /// NOTE(review): the WAL append and the memtable insert run under
    /// separate locks, so two concurrent writers may apply their memtable
    /// updates in the opposite order of their sequence numbers — confirm
    /// that last-writer-wins on the memtable is acceptable here.
    pub fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        key: K,
        value: V,
    ) -> Result<(), anyhow::Error> {
        let key = key.as_ref().to_vec();
        let value = value.as_ref().to_vec();
        let seq = self.next_seq_num.fetch_add(1, Ordering::SeqCst);
        self.wal
            .lock()
            .map_err(|_| anyhow::anyhow!("WAL lock poisoned"))?
            .add(seq, key.clone(), value.clone())?;
        let needs_flush = {
            let mut memtable = self
                .active_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("MemTable lock poisoned"))?;
            memtable.set(key, value);
            memtable.needs_flush()
        };
        if needs_flush {
            self.trigger_background_flush()?;
        }
        Ok(())
    }
    /// Looks up `key`, newest data first: active memtable → immutable
    /// memtable → SSTables from level 0 downward.
    ///
    /// An empty value is the tombstone sentinel, so a hit with an empty
    /// value reports the key as absent.
    pub fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, anyhow::Error> {
        let key = key.as_ref();
        {
            let memtable = self
                .active_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("MemTable lock poisoned"))?;
            if let Some(val) = memtable.get(key) {
                if val.is_empty() {
                    // Tombstone: the key was deleted after any older copies.
                    return Ok(None);
                }
                return Ok(Some(val.clone()));
            }
        }
        {
            let imm = self
                .immutable_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("Immutable MemTable lock poisoned"))?;
            if let Some(imm_memtable) = imm.as_ref()
                && let Some(val) = imm_memtable.get(key)
            {
                if val.is_empty() {
                    return Ok(None);
                }
                return Ok(Some(val.clone()));
            }
        }
        {
            let sstables = self
                .sstables
                .read()
                .map_err(|_| anyhow::anyhow!("SSTables read lock poisoned"))?;
            // Levels are ordered newest-to-oldest, and within a level the
            // newest table sits at index 0, so the first hit wins.
            for level in sstables.iter() {
                for reader in level.iter() {
                    if let Some(val) = reader.get(key, Some(&self.block_cache)) {
                        if val.is_empty() {
                            return Ok(None);
                        }
                        return Ok(Some(val));
                    }
                }
            }
        }
        Ok(None)
    }
    /// Alias for [`put`](Self::put): an update is just an overwriting
    /// insert in an LSM tree.
    pub fn update<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        key: K,
        value: V,
    ) -> Result<(), anyhow::Error> {
        self.put(key, value)
    }
    /// Deletes `key` by writing a tombstone (empty value). The tombstone
    /// shadows older copies in deeper levels until compaction drops them.
    pub fn remove<K: AsRef<[u8]>>(&self, key: K) -> Result<(), anyhow::Error> {
        let key = key.as_ref().to_vec();
        let seq = self.next_seq_num.fetch_add(1, Ordering::SeqCst);
        let tombstone_val: Vec<u8> = vec![];
        self.wal
            .lock()
            .map_err(|_| anyhow::anyhow!("WAL lock poisoned"))?
            .remove(seq, key.clone())?;
        let needs_flush = {
            let mut memtable = self
                .active_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("MemTable lock poisoned"))?;
            memtable.set(key, tombstone_val);
            memtable.needs_flush()
        };
        if needs_flush {
            self.trigger_background_flush()?;
        }
        Ok(())
    }
    /// Wipes the entire database: memtables, SSTables, WAL and MANIFEST
    /// are all reset to a fresh, empty state.
    ///
    /// The SSTables write lock is held across the whole reset so readers
    /// never observe a partially-cleared level structure.
    pub fn clear(&self) -> Result<(), anyhow::Error> {
        {
            let mut memtable = self
                .active_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("MemTable lock poisoned"))?;
            memtable.clear();
        }
        {
            let mut imm = self
                .immutable_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("Immutable MemTable lock poisoned"))?;
            *imm = None;
        }
        let mut sstables = self
            .sstables
            .write()
            .map_err(|_| anyhow::anyhow!("SSTables lock poisoned"))?;
        sstables.clear();
        let sst_dir = self.db_path.join("sst");
        let wal_dir = self.db_path.join("wal");
        let manifest_path = self.db_path.join("MANIFEST");
        // Best-effort removal: the directories are recreated right after.
        let _ = std::fs::remove_dir_all(&sst_dir);
        let _ = std::fs::remove_dir_all(&wal_dir);
        let _ = std::fs::remove_file(&manifest_path);
        std::fs::create_dir_all(&wal_dir)?;
        std::fs::create_dir_all(&sst_dir)?;
        {
            let mut wal_writer = self
                .wal
                .lock()
                .map_err(|_| anyhow::anyhow!("WAL lock poisoned"))?;
            *wal_writer = Wal::new(wal_dir)?;
        }
        let manifest = crate::sstable::Manifest::open(&manifest_path)?;
        let mut manifest_lock = self
            .manifest
            .write()
            .map_err(|_| anyhow::anyhow!("Manifest lock poisoned"))?;
        *manifest_lock = manifest;
        // Restore the empty level-0 slot before releasing the write lock.
        sstables.push(Vec::new());
        let mut cache = self
            .block_cache
            .write()
            .map_err(|_| anyhow::anyhow!("Block Cache lock poisoned"))?;
        cache.clear();
        Ok(())
    }
    /// Seals the active memtable and hands it to a background thread for
    /// flushing. Only one flush runs at a time: callers block on the
    /// condvar until any in-flight flush releases the slot.
    fn trigger_background_flush(&self) -> Result<(), anyhow::Error> {
        let (flush_mutex, flush_condvar) = &*self.flush_condvar;
        {
            // Wait for any in-flight flush, then claim the flush slot.
            let mut flushing = flush_mutex.lock().unwrap();
            while *flushing {
                flushing = flush_condvar.wait(flushing).unwrap();
            }
            *flushing = true;
        }
        {
            // Swap in an empty active memtable and publish the full one as
            // the immutable memtable. Both locks are held together so no
            // write can slip between the swap and the publish.
            let mut active = self
                .active_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("MemTable lock poisoned"))?;
            let mut imm = self
                .immutable_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("Immutable MemTable lock poisoned"))?;
            let empty_memtable =
                MemTable::new(self.memtable_capacity, crate::constants::BLOOM_FILTER_FPR);
            let memtable_to_flush = std::mem::replace(&mut *active, empty_memtable);
            *imm = Some(Arc::new(memtable_to_flush));
            // Rotate to a fresh WAL file so the old one can be deleted once
            // the sealed memtable is durably on disk. (The previous code
            // also wrote a dummy `seq 0, empty key, empty value` record
            // here; on recovery that replayed as a spurious tombstone for
            // the empty key, so it has been dropped.)
            let wal_dir = self.db_path.join("wal");
            if let Ok(wal) = Wal::new(wal_dir) {
                let mut wal_writer = self
                    .wal
                    .lock()
                    .map_err(|_| anyhow::anyhow!("WAL lock poisoned"))?;
                *wal_writer = wal;
            }
        };
        let imm_memtable_arc = Arc::clone(&self.immutable_memtable);
        let sstables_arc = Arc::clone(&self.sstables);
        let manifest_arc = Arc::clone(&self.manifest);
        let db_path_arc = Arc::clone(&self.db_path);
        let wal_arc = Arc::clone(&self.wal);
        let condvar_arc = Arc::clone(&self.flush_condvar);
        std::thread::spawn(move || {
            if let Err(e) = Self::flush_immutable_memtable(
                imm_memtable_arc,
                sstables_arc,
                manifest_arc,
                db_path_arc,
                wal_arc,
            ) {
                eprintln!("Background flush failed: {}", e);
            }
            // Release the flush slot whether or not the flush succeeded.
            let (mutex, condvar) = &*condvar_arc;
            let mut flushing = mutex.lock().unwrap();
            *flushing = false;
            condvar.notify_all();
        });
        Ok(())
    }
    /// Background worker: writes the sealed memtable out as a level-0
    /// SSTable, records it in the manifest, deletes WAL files that are no
    /// longer needed, then gives compaction a chance to run.
    fn flush_immutable_memtable(
        immutable_memtable: Arc<Mutex<Option<Arc<MemTable>>>>,
        sstables: Arc<RwLock<Vec<Vec<SSTableReader>>>>,
        manifest: Arc<RwLock<crate::sstable::Manifest>>,
        db_path: Arc<PathBuf>,
        wal: Arc<Mutex<Wal>>,
    ) -> Result<(), anyhow::Error> {
        let memtable_arc = {
            let imm = immutable_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("Immutable MemTable lock poisoned"))?;
            match imm.as_ref() {
                Some(m) => Arc::clone(m),
                None => return Ok(()),
            }
        };
        if memtable_arc.approximate_memory_usage() == 0 {
            // Nothing to persist; clear the empty immutable slot so it does
            // not linger until the next flush replaces it.
            let mut imm = immutable_memtable
                .lock()
                .map_err(|_| anyhow::anyhow!("Immutable MemTable lock poisoned"))?;
            *imm = None;
            return Ok(());
        }
        // SSTable ids are millisecond timestamps. NOTE(review): a flush and
        // a compaction landing in the same millisecond would collide on the
        // same file name — a monotonic file-number counter would be safer.
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis();
        let sst_path = db_path.join(format!("sst/{}.sst", timestamp));
        let mut sst_builder = SSTableBuilder::new(sst_path.clone());
        for (k, v) in memtable_arc.entries() {
            sst_builder.add(k, v);
        }
        sst_builder.finish()?;
        {
            let mut sstables_write = sstables
                .write()
                .map_err(|_| anyhow::anyhow!("SSTables lock poisoned"))?;
            // Defensive: make sure level 0 exists before indexing into it.
            if sstables_write.is_empty() {
                sstables_write.push(Vec::new());
            }
            // Newest level-0 table goes first so reads prefer it.
            sstables_write[0].insert(0, SSTableReader::new(sst_path));
        }
        let sst_id = timestamp as u64;
        if let Ok(mut m_lock) = manifest.write() {
            let _ = m_lock.log_edit(&VersionEdit::AddTable { level: 0, sst_id });
        }
        // Everything in WAL files older than the current one is now in the
        // SSTable just written, so those files can be deleted.
        let safe_to_delete_wal_num = {
            let wal_lock = wal.lock().unwrap();
            wal_lock.current_file_num().saturating_sub(1)
        };
        {
            let mut imm = immutable_memtable.lock().unwrap();
            *imm = None;
        }
        if safe_to_delete_wal_num > 0 {
            let wal_lock = wal.lock().unwrap();
            let _ = wal_lock.delete_old_files(safe_to_delete_wal_num);
        }
        let _ = Self::run_compaction(sstables, manifest, db_path);
        Ok(())
    }
    /// Leveled compaction: for each level exceeding its threshold (file
    /// count for L0, total bytes for L1+), merge all of its tables into a
    /// single table in the next level, log the swap in the manifest, and
    /// delete the input files.
    ///
    /// Fix: the original `break`-ed out of the loop when a level was empty
    /// or under budget, so deeper oversized levels were never examined;
    /// such levels are now skipped with `continue` instead.
    ///
    /// NOTE(review): inputs never include overlapping tables from the next
    /// level, so duplicate keys across levels rely on read order until a
    /// later compaction of that level — confirm `compact`'s tie-breaking.
    fn run_compaction(
        sstables: Arc<RwLock<Vec<Vec<SSTableReader>>>>,
        manifest: Arc<RwLock<Manifest>>,
        db_path: Arc<PathBuf>,
    ) -> Result<(), anyhow::Error> {
        let max_levels = MAX_LEVELS;
        for level in 0..max_levels.saturating_sub(1) {
            let next_level = level + 1;
            let (input_ids, input_paths) = {
                let sst_read = sstables.read().unwrap();
                if sst_read.len() <= level || sst_read[level].is_empty() {
                    continue;
                }
                let should_compact = if level == 0 {
                    // L0 compacts on file count (its tables may overlap).
                    sst_read[level].len() >= L0_COMPACTION_TRIGGER
                } else {
                    // L1+ compact when total bytes exceed a budget that
                    // grows geometrically per level.
                    let level_budget = L1_MAX_BYTES
                        * (LEVEL_SIZE_MULTIPLIER as u64).pow(level.saturating_sub(1) as u32);
                    let total_bytes: u64 = sst_read[level]
                        .iter()
                        .map(|r| {
                            let path = db_path.join(format!("sst/{}.sst", r.id));
                            std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0)
                        })
                        .sum();
                    total_bytes > level_budget
                };
                if !should_compact {
                    continue;
                }
                let ids: Vec<u64> = sst_read[level].iter().map(|r| r.id).collect();
                let paths: Vec<std::path::PathBuf> = ids
                    .iter()
                    .map(|id| db_path.join(format!("sst/{}.sst", id)))
                    .collect();
                (ids, paths)
            };
            let timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_millis() as u64;
            let output_path = db_path.join(format!("sst/{}.sst", timestamp));
            compact(input_paths, output_path.clone())?;
            // Log the table swap before mutating the in-memory view so a
            // crash replays the manifest to a consistent state.
            {
                let mut m_lock = manifest.write().unwrap();
                m_lock.log_edit(&VersionEdit::AddTable {
                    level: next_level as u32,
                    sst_id: timestamp,
                })?;
                for id in &input_ids {
                    m_lock.log_edit(&VersionEdit::RemoveTable {
                        level: level as u32,
                        sst_id: *id,
                    })?;
                }
            }
            {
                let mut sst_write = sstables.write().unwrap();
                while sst_write.len() <= next_level {
                    sst_write.push(Vec::new());
                }
                sst_write[next_level].insert(0, SSTableReader::new(output_path));
                sst_write[level].retain(|r| !input_ids.contains(&r.id));
            }
            // The inputs are no longer referenced; remove them from disk.
            for id in &input_ids {
                let _ = std::fs::remove_file(db_path.join(format!("sst/{}.sst", id)));
            }
        }
        Ok(())
    }
}