use std::cmp::Ordering;
use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use memmap2::{Mmap, MmapOptions};
use parking_lot::RwLock;
use super::block::{Block, BlockHandle, BlockIterator, BlockType};
use super::filter::FilterReader;
use super::format::{Footer, Header, Section, SectionType, SSTableFormat, HEADER_SIZE};
/// A single block as held by [`BlockCache`].
///
/// `SSTable::read_block` fills `data` with the block bytes exactly as they
/// appear in the file and `decompressed` with the result of decoding those
/// bytes according to `block_type`.
pub struct CachedBlock {
/// Block bytes as stored in the file (still compressed for compressed blocks).
pub data: Vec<u8>,
/// Block type tag that `read_block` took from the byte following the block.
pub block_type: BlockType,
/// Decoded block contents; this is what a cache hit hands back to readers.
pub decompressed: Vec<u8>,
}
/// In-memory cache of decoded blocks, keyed by `(file_id, block offset)`.
pub struct BlockCache {
/// Cached blocks behind a reader-writer lock; values are `Arc`s so a lookup
/// hands out a shared handle instead of copying the block.
entries: RwLock<HashMap<(u64, u64), Arc<CachedBlock>>>,
/// Entry-count threshold checked by `insert` (it counts blocks, not bytes).
capacity: usize,
}
impl BlockCache {
    /// Creates an empty cache sized for `capacity` blocks.
    ///
    /// `capacity` is both the initial allocation hint for the map and the
    /// entry count at which [`BlockCache::insert`] flushes the cache.
    pub fn new(capacity: usize) -> Self {
        Self {
            entries: RwLock::new(HashMap::with_capacity(capacity)),
            capacity,
        }
    }

    /// Returns a shared handle to the block cached under `(file_id, offset)`,
    /// or `None` when it is not cached. Only the `Arc` is cloned, not the block.
    pub fn get(&self, file_id: u64, offset: u64) -> Option<Arc<CachedBlock>> {
        self.entries.read().get(&(file_id, offset)).cloned()
    }

    /// Stores `block` under `(file_id, offset)` and returns a shared handle to it.
    ///
    /// Eviction is wholesale: when adding a *new* key to a cache that already
    /// holds `capacity` entries, every existing entry is dropped first.
    /// Replacing a key that is already cached does not grow the map, so it
    /// never triggers a flush.
    pub fn insert(&self, file_id: u64, offset: u64, block: CachedBlock) -> Arc<CachedBlock> {
        let key = (file_id, offset);
        let block = Arc::new(block);
        let mut entries = self.entries.write();
        // Only a genuinely new entry can push the map past its capacity.
        if entries.len() >= self.capacity && !entries.contains_key(&key) {
            entries.clear();
        }
        entries.insert(key, Arc::clone(&block));
        block
    }
}
/// Per-read switches accepted by table lookups.
#[derive(Debug, Clone)]
pub struct ReadOptions {
    /// Recompute each block's CRC32 and compare it with the stored checksum
    /// when the block is read from the file.
    pub verify_checksums: bool,
    /// Add blocks read from the file to the table's block cache, if it has one.
    pub fill_cache: bool,
    /// Consult the table's filter (if any) before searching the index in `get`.
    pub use_filter: bool,
}

impl Default for ReadOptions {
    /// Returns options with checksum verification, cache filling and filter
    /// use all switched on.
    fn default() -> Self {
        const ENABLED: bool = true;
        ReadOptions {
            use_filter: ENABLED,
            fill_cache: ENABLED,
            verify_checksums: ENABLED,
        }
    }
}
/// Read-only handle to an SSTable file, backed by a memory map of the file.
pub struct SSTable {
/// Path the table was opened from.
path: PathBuf,
/// Hash of `path`; used as the file component of block-cache keys.
file_id: u64,
/// Memory map of the table file; blocks are read from it.
mmap: Mmap,
/// Header decoded from the first `HEADER_SIZE` bytes of the file.
header: Header,
/// Footer decoded from the bytes starting at `header.footer_offset`.
footer: Footer,
/// Owned copy of the raw index section bytes.
index: Vec<u8>,
/// Index entries parsed from `index`, in the order the index block yields them.
index_entries: Vec<IndexEntry>,
/// Filter loaded from the filter section; `None` when the section is absent
/// or `FilterReader::from_bytes` rejected it.
filter: Option<FilterReader>,
/// Summary information computed while opening the table.
metadata: TableMetadata,
/// Optional block cache consulted and filled by `read_block`.
cache: Option<Arc<BlockCache>>,
}
/// One entry of the table index: a key paired with the location of a data block.
#[derive(Debug, Clone)]
struct IndexEntry {
/// Key stored in the index block for this entry. `find_block_for_key`
/// treats it as an inclusive upper bound on the keys held by the block.
largest_key: Vec<u8>,
/// Handle (offset and size) of the data block this entry points to.
handle: BlockHandle,
}
/// Summary information about an opened table.
#[derive(Debug, Clone)]
pub struct TableMetadata {
/// Length of the table file in bytes, as reported by the file's metadata.
pub file_size: u64,
/// Number of entries in the table index.
pub num_data_blocks: usize,
/// `None` for a table with no index entries.
///
/// NOTE(review): `open_with_cache` fills this from the first index entry's
/// `largest_key`, i.e. an upper bound for the first block rather than the
/// table's true minimum key. Confirm whether callers rely on it being exact.
pub smallest_key: Option<Vec<u8>>,
/// `largest_key` of the last index entry; `None` for a table with no index entries.
pub largest_key: Option<Vec<u8>>,
}
impl SSTable {
pub fn open<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
Self::open_with_cache(path, None)
}
/// Opens the SSTable at `path`, optionally attaching a shared block cache.
///
/// The file is memory-mapped, the header is decoded from the first
/// `HEADER_SIZE` bytes, the footer is decoded from `header.footer_offset`,
/// the index section is copied and parsed, and the filter section (if the
/// footer lists one) is loaded.
///
/// # Errors
///
/// Propagates I/O errors from opening, inspecting or mapping the file.
/// Returns `InvalidData` when the file is shorter than the header, the
/// header or footer cannot be decoded, the footer offset or a section range
/// lies outside the file, the index section is missing, or the index cannot
/// be parsed.
pub fn open_with_cache<P: AsRef<Path>>(
    path: P,
    cache: Option<Arc<BlockCache>>,
) -> std::io::Result<Self> {
    // Every structural problem is reported the same way.
    let invalid =
        |msg: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidData, msg);

    let path = path.as_ref();
    let file = File::open(path)?;
    let file_size = file.metadata()?.len();
    // SAFETY: mapping a file is unsafe because the mapped bytes may change
    // underneath us if the file is modified or truncated while mapped.
    // NOTE(review): this relies on table files never being rewritten in
    // place after they are finished — confirm with the writer/compaction code.
    let mmap = unsafe { MmapOptions::new().map(&file)? };

    // The cache key for this table is derived from its path.
    let file_id = {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        path.hash(&mut hasher);
        hasher.finish()
    };

    if mmap.len() < HEADER_SIZE {
        return Err(invalid("File too small for SSTable header"));
    }
    let header = Header::decode(&mmap[..HEADER_SIZE])
        .ok_or_else(|| invalid("Invalid SSTable header"))?;

    let footer_offset = header.footer_offset as usize;
    if footer_offset >= mmap.len() {
        return Err(invalid("Footer offset beyond file"));
    }
    let footer = Footer::decode(&mmap[footer_offset..], header.num_sections)
        .ok_or_else(|| invalid("Invalid SSTable footer"))?;

    let index_section = footer
        .sections
        .iter()
        .find(|s| s.section_type == SectionType::Index)
        .ok_or_else(|| invalid("Missing index section"))?;
    // The section range comes from the file, so validate it (including
    // arithmetic overflow) instead of letting a corrupt footer panic.
    let index_start = index_section.offset as usize;
    let index = index_start
        .checked_add(index_section.size as usize)
        .and_then(|index_end| mmap.get(index_start..index_end))
        .ok_or_else(|| invalid("Index section extends beyond file"))?
        .to_vec();
    let index_entries = Self::parse_index(&index)?;

    let filter = match footer
        .sections
        .iter()
        .find(|s| s.section_type == SectionType::Filter)
    {
        Some(section) => {
            let start = section.offset as usize;
            let bytes = start
                .checked_add(section.size as usize)
                .and_then(|end| mmap.get(start..end))
                .ok_or_else(|| invalid("Filter section extends beyond file"))?;
            // A filter that fails to decode is tolerated: the table is
            // simply opened without one.
            FilterReader::from_bytes(bytes)
        }
        None => None,
    };

    let metadata = TableMetadata {
        file_size,
        num_data_blocks: index_entries.len(),
        // NOTE(review): this is the first block's *largest* key, so it is
        // only an upper bound on the table's real smallest key.
        smallest_key: index_entries.first().map(|e| e.largest_key.clone()),
        largest_key: index_entries.last().map(|e| e.largest_key.clone()),
    };

    Ok(Self {
        path: path.to_path_buf(),
        file_id,
        mmap,
        header,
        footer,
        index,
        index_entries,
        filter,
        metadata,
        cache,
    })
}
/// Decodes the raw index section into a list of index entries.
///
/// `data` is copied into a `Block`; each key/value pair the block yields
/// becomes one entry whose value is decoded as a `BlockHandle`.
///
/// # Errors
///
/// Returns `InvalidData` when `data` is not a valid block or when any value
/// fails to decode as a block handle.
fn parse_index(data: &[u8]) -> std::io::Result<Vec<IndexEntry>> {
    let block = match Block::new(data.to_vec()) {
        Some(block) => block,
        None => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Invalid index block",
            ))
        }
    };

    let mut entries = Vec::new();
    let mut cursor = block.iter();
    while cursor.valid() {
        let largest_key = cursor.key().to_vec();
        // Only the handle is needed; the consumed byte count is ignored.
        let (handle, _consumed) = match BlockHandle::decode(cursor.value()) {
            Some(decoded) => decoded,
            None => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "Invalid block handle",
                ))
            }
        };
        entries.push(IndexEntry { largest_key, handle });
        cursor.next();
    }
    Ok(entries)
}
pub fn get(&self, key: &[u8], options: &ReadOptions) -> std::io::Result<Option<Vec<u8>>> {
if options.use_filter {
if let Some(ref filter) = self.filter {
if !filter.may_contain(key) {
return Ok(None);
}
}
}
let block_idx = self.find_block_for_key(key);
if block_idx >= self.index_entries.len() {
return Ok(None);
}
let block_data = self.read_block(&self.index_entries[block_idx].handle, options)?;
let block = Block::new(block_data).ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid data block")
})?;
let iter = block.seek(key);
if iter.valid() && iter.key() == key {
Ok(Some(iter.value().to_vec()))
} else {
Ok(None)
}
}
fn find_block_for_key(&self, key: &[u8]) -> usize {
self.index_entries
.binary_search_by(|entry| {
if entry.largest_key.as_slice() < key {
Ordering::Less
} else {
Ordering::Greater
}
})
.unwrap_or_else(|i| i)
}
/// Reads and decodes the block described by `handle`.
///
/// The block cache (when the table has one) is consulted first; a hit is
/// returned as-is, without re-verifying the checksum. Otherwise the block is
/// taken from the memory map, where it is followed by a 5-byte trailer: one
/// block-type byte and a little-endian CRC32 of the stored block bytes.
/// When `options.fill_cache` is set, the decoded block is added to the cache.
///
/// # Errors
///
/// Returns `InvalidData` when the block or its trailer lies outside the
/// file, when checksum verification is enabled and fails, when
/// decompression fails, or when the block is Snappy-compressed.
fn read_block(
    &self,
    handle: &BlockHandle,
    options: &ReadOptions,
) -> std::io::Result<Vec<u8>> {
    // One block-type byte followed by a 4-byte little-endian CRC32.
    const TRAILER_LEN: usize = 5;

    let offset = handle.offset();
    let size = handle.size();
    if let Some(ref cache) = self.cache {
        if let Some(block) = cache.get(self.file_id, offset) {
            return Ok(block.decompressed.clone());
        }
    }

    // The handle comes from the file, so compute the range with checked
    // arithmetic: a corrupt offset/size must yield an error, not an
    // overflow or an out-of-range slice panic.
    let start = offset as usize;
    let bounds = start
        .checked_add(size as usize)
        .and_then(|end| end.checked_add(TRAILER_LEN).map(|trailer_end| (end, trailer_end)))
        .filter(|&(_, trailer_end)| trailer_end <= self.mmap.len());
    let (end, trailer_end) = match bounds {
        Some(bounds) => bounds,
        None => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Block extends beyond file",
            ))
        }
    };

    let block_data = &self.mmap[start..end];
    let trailer = &self.mmap[end..trailer_end];
    let block_type = BlockType::from_u8(trailer[0]);
    let stored_checksum = u32::from_le_bytes([trailer[1], trailer[2], trailer[3], trailer[4]]);

    if options.verify_checksums {
        // The checksum covers the stored block bytes only, not the type byte.
        let computed_checksum = crc32fast::hash(block_data);
        if computed_checksum != stored_checksum {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Block checksum mismatch",
            ));
        }
    }

    let decompressed = match block_type {
        BlockType::Uncompressed => block_data.to_vec(),
        BlockType::Lz4 => lz4_flex::decompress_size_prepended(block_data).map_err(|e| {
            std::io::Error::new(std::io::ErrorKind::InvalidData, format!("LZ4 error: {}", e))
        })?,
        BlockType::Zstd => zstd::decode_all(block_data).map_err(|e| {
            std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Zstd error: {}", e))
        })?,
        BlockType::Snappy => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Snappy not supported",
            ))
        }
    };

    if options.fill_cache {
        if let Some(ref cache) = self.cache {
            cache.insert(
                self.file_id,
                offset,
                CachedBlock {
                    data: block_data.to_vec(),
                    block_type,
                    decompressed: decompressed.clone(),
                },
            );
        }
    }
    Ok(decompressed)
}
/// Creates an [`SSTableIterator`] over this table; the iterator starts at
/// block index 0 and attempts to load that block on construction.
pub fn iter(&self) -> SSTableIterator {
SSTableIterator::new(self)
}
/// Creates a [`RangeIterator`] for the given optional bounds.
///
/// When `start` is given, the iterator begins at the block that
/// `find_block_for_key` selects for it; otherwise it begins at block 0.
/// NOTE(review): whether `end` is inclusive or exclusive is not determined
/// by the visible code — confirm before relying on it.
pub fn range(
&self,
start: Option<&[u8]>,
end: Option<&[u8]>,
) -> RangeIterator {
RangeIterator::new(self, start, end)
}
/// Returns the summary information computed when the table was opened.
pub fn metadata(&self) -> &TableMetadata {
&self.metadata
}
/// Returns the path this table was opened from.
pub fn path(&self) -> &Path {
    self.path.as_path()
}
/// Returns the number of entries in the table index.
pub fn num_blocks(&self) -> usize {
self.index_entries.len()
}
/// Asks the table's filter whether `key` might be present.
///
/// Returns `true` when the table has no filter, since nothing can then be
/// ruled out.
pub fn may_contain(&self, key: &[u8]) -> bool {
    match &self.filter {
        Some(filter) => filter.may_contain(key),
        None => true,
    }
}
}
/// Cursor over the blocks of an [`SSTable`].
pub struct SSTableIterator<'a> {
/// Table being iterated.
table: &'a SSTable,
/// Index into the table's index entries of the current block.
block_idx: usize,
/// Decoded bytes of the most recently loaded block, if any load succeeded.
block_data: Option<Vec<u8>>,
/// NOTE(review): initialised to `None` and never assigned in the visible
/// code — presumably reserved for per-entry iteration within a block.
block_iter: Option<BlockIterator<'a>>,
/// Read options used when loading blocks (the defaults).
options: ReadOptions,
/// Whether the last attempt to load a block succeeded.
valid: bool,
}
impl<'a> SSTableIterator<'a> {
fn new(table: &'a SSTable) -> Self {
let mut iter = Self {
table,
block_idx: 0,
block_data: None,
block_iter: None,
options: ReadOptions::default(),
valid: false,
};
iter.load_block();
iter
}
fn load_block(&mut self) {
if self.block_idx >= self.table.index_entries.len() {
self.valid = false;
return;
}
let handle = &self.table.index_entries[self.block_idx].handle;
match self.table.read_block(handle, &self.options) {
Ok(data) => {
self.block_data = Some(data);
self.valid = true;
}
Err(_) => {
self.valid = false;
}
}
}
pub fn valid(&self) -> bool {
self.valid
}
pub fn key(&self) -> Option<&[u8]> {
if !self.valid {
return None;
}
self.block_data.as_ref().map(|_| &b""[..])
}
pub fn value(&self) -> Option<&[u8]> {
if !self.valid {
return None;
}
self.block_data.as_ref().map(|_| &b""[..])
}
pub fn next(&mut self) {
self.block_idx += 1;
self.load_block();
}
pub fn seek(&mut self, target: &[u8]) {
self.block_idx = self.table.find_block_for_key(target);
self.load_block();
}
}
/// State for iterating a key range of an [`SSTable`].
///
/// NOTE(review): in the visible code only `exhausted` is ever read back; the
/// other fields are stored by `new` but not yet consumed.
pub struct RangeIterator<'a> {
/// Table being iterated.
table: &'a SSTable,
/// Owned copy of the lower bound, if one was given.
start: Option<Vec<u8>>,
/// Owned copy of the upper bound, if one was given.
end: Option<Vec<u8>>,
/// Block index the iteration starts from.
current_block: usize,
/// Exhaustion flag; initialised to `false`.
exhausted: bool,
}
impl<'a> RangeIterator<'a> {
fn new(table: &'a SSTable, start: Option<&[u8]>, end: Option<&[u8]>) -> Self {
let start_block = start
.map(|k| table.find_block_for_key(k))
.unwrap_or(0);
Self {
table,
start: start.map(|s| s.to_vec()),
end: end.map(|e| e.to_vec()),
current_block: start_block,
exhausted: false,
}
}
pub fn exhausted(&self) -> bool {
self.exhausted
}
}
// Unit tests: two build a table with `SSTableBuilder` in a temporary directory
// and read it back; the third exercises `BlockCache` on its own.
#[cfg(test)]
mod tests {
use super::*;
use crate::sstable::builder::{SSTableBuilder, SSTableBuilderOptions};
use tempfile::tempdir;
// Writes 100 entries with a small block size and checks that the opened
// table's block count agrees with its metadata.
#[test]
fn test_roundtrip() {
let dir = tempdir().unwrap();
let path = dir.path().join("test.sst");
let options = SSTableBuilderOptions {
block_size: 256,
filter_policy: None,
..Default::default()
};
let mut builder = SSTableBuilder::new(&path, options).unwrap();
for i in 0..100 {
let key = format!("key{:05}", i);
let value = format!("value{:05}", i);
builder.add(key.as_bytes(), value.as_bytes()).unwrap();
}
builder.finish().unwrap();
let table = SSTable::open(&path).unwrap();
assert_eq!(table.num_blocks(), table.metadata.num_data_blocks);
}
// Writes 100 entries, then checks that a present key returns its value and
// an absent key returns `None`.
#[test]
fn test_get() {
let dir = tempdir().unwrap();
let path = dir.path().join("test_get.sst");
let options = SSTableBuilderOptions {
block_size: 256,
filter_policy: None,
..Default::default()
};
let mut builder = SSTableBuilder::new(&path, options).unwrap();
for i in 0..100 {
let key = format!("key{:05}", i);
let value = format!("value{:05}", i);
builder.add(key.as_bytes(), value.as_bytes()).unwrap();
}
builder.finish().unwrap();
let table = SSTable::open(&path).unwrap();
let read_opts = ReadOptions::default();
let result = table.get(b"key00050", &read_opts).unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap(), b"value00050");
// "nonexistent" sorts after every "keyNNNNN" key, so the lookup falls past the last block.
let result = table.get(b"nonexistent", &read_opts).unwrap();
assert!(result.is_none());
}
// Inserts one block and checks that it is found under its key while a
// different offset for the same file misses.
#[test]
fn test_block_cache() {
let cache = BlockCache::new(100);
let block = CachedBlock {
data: vec![1, 2, 3],
block_type: BlockType::Uncompressed,
decompressed: vec![1, 2, 3],
};
cache.insert(1, 0, block);
let cached = cache.get(1, 0);
assert!(cached.is_some());
assert_eq!(cached.unwrap().data, vec![1, 2, 3]);
let missing = cache.get(1, 100);
assert!(missing.is_none());
}
}