use crate::buffer::manager::FrameRef;
use crate::simd;
use bytes::{Bytes, BytesMut};
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
use std::io;
use std::sync::{Arc, OnceLock};
use thiserror::Error;
#[cfg(not(feature = "simd"))]
use varint_rs::VarintReader;
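/// Compression codec applied to a block's payload, recorded as a single
/// byte in the block trailer.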
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionType {
None,
#[default]
Lz4,
Zstd,
}
impl CompressionType {
pub(crate) const fn to_byte(self) -> u8 {
match self {
Self::None => 0,
Self::Lz4 => 1,
Self::Zstd => 2,
}
}
pub(crate) const fn from_byte(b: u8) -> Option<Self> {
match b {
0 => Some(Self::None),
1 => Some(Self::Lz4),
2 => Some(Self::Zstd),
_ => None,
}
}
}
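/// Appends `value` in LEB128-style varint form: 7 bits per byte with the
/// high bit as a continuation flag, so a `u64` takes at most 10 bytes.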
#[inline]
fn write_varint(buf: &mut BytesMut, value: u64) {
let mut temp = [0u8; 10];
let mut n = value;
let mut i = 0;
while n >= 0x80 {
temp[i] = (n as u8) | 0x80;
n >>= 7;
i += 1;
}
temp[i] = n as u8;
buf.extend_from_slice(&temp[..=i]);
}
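/// Decodes a varint starting at `*offset`, advancing the offset past the
/// consumed bytes; returns `None` on truncated or malformed input.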
#[inline]
fn read_varint(data: &[u8], offset: &mut usize) -> Option<u64> {
#[cfg(feature = "simd")]
{
if let Some((val, len)) = crate::simd::decode_varint(&data[*offset..]) {
*offset += len;
return Some(val);
}
None
}
#[cfg(not(feature = "simd"))]
{
let mut slice = &data[*offset..];
match slice.read_u64_varint() {
Ok(val) => {
let read = data.len() - *offset - slice.len();
*offset += read;
Some(val)
}
Err(_) => None,
}
}
}
/// Scans the restart array that `BlockBuilder::finish` appends after the
/// entry data: each restart offset as a varint, followed by the restart
/// count itself. Returns the recovered count (0 if the array is malformed).
fn count_restarts(data: &[u8], restart_offset: usize, limit: usize) -> usize {
    let mut offset = restart_offset;
    let mut num_restarts = 0;
    while offset < limit {
        if read_varint(data, &mut offset).is_none() {
            break;
        }
        num_restarts += 1;
        // Peek at the next varint: if it equals the number of offsets read
        // so far, it is the trailing count and the scan is complete.
        let pos_after = offset;
        match read_varint(data, &mut offset) {
            Some(count) if count as usize == num_restarts => break,
            Some(_) => offset = pos_after,
            None => break,
        }
    }
    num_restarts
}
#[derive(Debug, Error)]
pub enum BlockError {
#[error("IO error: {0}")]
Io(#[from] io::Error),
#[error("Block corrupted: checksum mismatch")]
Corruption,
#[error("Invalid block format")]
InvalidFormat,
#[error("Block full")]
BlockFull,
}
pub type Result<T> = std::result::Result<T, BlockError>;
pub const DEFAULT_BLOCK_SIZE: usize = 4096;
const RESTART_INTERVAL: usize = 16;
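/// Builds a prefix-compressed block. Each entry stores the length of the
/// prefix it shares with the previous key plus its distinct suffix; every
/// `RESTART_INTERVAL` entries a full key is re-stored so readers can decode
/// from a restart point without replaying the whole block.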
pub struct BlockBuilder {
buffer: BytesMut,
restart_points: Vec<u32>,
counter: usize,
last_key: Bytes,
max_size: usize,
compression_type: CompressionType,
}
impl BlockBuilder {
#[must_use]
pub fn new() -> Self {
Self::with_capacity(DEFAULT_BLOCK_SIZE)
}
#[must_use]
pub fn with_capacity(max_size: usize) -> Self {
Self {
buffer: BytesMut::with_capacity(max_size),
            restart_points: vec![0],
            counter: 0,
last_key: Bytes::new(),
max_size,
compression_type: CompressionType::Lz4,
}
}
pub const fn set_compression_type(&mut self, compression_type: CompressionType) {
self.compression_type = compression_type;
}
#[deprecated(since = "0.1.0", note = "use set_compression_type instead")]
pub const fn set_compression(&mut self, enabled: bool) {
self.compression_type = if enabled {
CompressionType::Lz4
} else {
CompressionType::None
};
}
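    /// Appends a key/value pair, returning `false` if the encoded entry
    /// plus footer would exceed `max_size`. Keys are expected in sorted
    /// order; only the suffix differing from the previous key is written,
    /// with a full key re-stored at each restart point.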
#[inline]
pub fn add(&mut self, key: &[u8], value: &[u8]) -> bool {
let prefix_len = if self.counter > 0 && !self.last_key.is_empty() {
simd::shared_prefix_len(key, &self.last_key)
} else {
0
};
let suffix_len = key.len() - prefix_len;
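        // Worst case per entry: three varints (prefix, suffix, and value
        // lengths, at most 10 bytes each) plus the suffix and value bytes.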
let entry_size = 10 + 10 + suffix_len + 10 + value.len();
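        // Footer estimate: one varint per restart offset plus the trailing
        // count (10 bytes each, worst case) and the 13-byte trailer, padded
        // by one byte.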
let footer_size = (self.restart_points.len() + 1) * 10 + 14;
if self.buffer.len() + entry_size + footer_size > self.max_size {
return false;
}
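        // Time for a new restart point: record it and re-encode this key
        // with a zero-length shared prefix via the recursive call.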
if self.counter >= RESTART_INTERVAL {
self.restart_points.push(self.buffer.len() as u32);
self.counter = 0;
return self.add(key, value);
}
write_varint(&mut self.buffer, prefix_len as u64);
write_varint(&mut self.buffer, suffix_len as u64);
self.buffer.extend_from_slice(&key[prefix_len..]);
write_varint(&mut self.buffer, value.len() as u64);
self.buffer.extend_from_slice(value);
self.last_key = Bytes::copy_from_slice(key);
self.counter += 1;
true
}
pub fn current_size(&self) -> usize {
self.buffer.len()
}
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
pub fn last_key(&self) -> &[u8] {
&self.last_key
}
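    /// Serializes the block and consumes the builder: entry data, then the
    /// restart array, then a 13-byte trailer of uncompressed size (4 bytes),
    /// compression byte (1), restart offset (4), and a CRC32C checksum (4)
    /// computed over everything preceding it.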
#[inline]
pub fn finish(mut self) -> Bytes {
let restart_offset = self.buffer.len() as u32;
for offset in &self.restart_points {
write_varint(&mut self.buffer, *offset as u64);
}
write_varint(&mut self.buffer, self.restart_points.len() as u64);
        let uncompressed_size = self.buffer.len() as u32;
        match self.compression_type {
            CompressionType::None => {
                self.buffer
                    .extend_from_slice(&uncompressed_size.to_le_bytes());
                self.buffer
                    .extend_from_slice(&[CompressionType::None.to_byte()]);
                self.buffer.extend_from_slice(&restart_offset.to_le_bytes());
                let checksum = crc32c::crc32c(&self.buffer);
                self.buffer.extend_from_slice(&checksum.to_le_bytes());
                self.buffer.freeze()
            }
            CompressionType::Lz4 => {
                let compressed_data = compress_prepend_size(&self.buffer);
                let mut final_buffer = BytesMut::with_capacity(compressed_data.len() + 13);
                final_buffer.extend_from_slice(&compressed_data);
                final_buffer.extend_from_slice(&uncompressed_size.to_le_bytes());
                final_buffer.extend_from_slice(&[CompressionType::Lz4.to_byte()]);
                final_buffer.extend_from_slice(&restart_offset.to_le_bytes());
                let checksum = crc32c::crc32c(&final_buffer);
                final_buffer.extend_from_slice(&checksum.to_le_bytes());
                final_buffer.freeze()
            }
            CompressionType::Zstd => {
                // If zstd encoding fails, fall back to storing the payload
                // uncompressed and label it as such; appending raw bytes
                // under the Zstd tag would make the block unreadable.
                let (payload, type_byte) = match zstd::encode_all(&self.buffer[..], 3) {
                    Ok(compressed) => (compressed, CompressionType::Zstd.to_byte()),
                    Err(_) => (self.buffer.to_vec(), CompressionType::None.to_byte()),
                };
                let mut final_buffer = BytesMut::with_capacity(payload.len() + 13);
                final_buffer.extend_from_slice(&payload);
                final_buffer.extend_from_slice(&uncompressed_size.to_le_bytes());
                final_buffer.extend_from_slice(&[type_byte]);
                final_buffer.extend_from_slice(&restart_offset.to_le_bytes());
                let checksum = crc32c::crc32c(&final_buffer);
                final_buffer.extend_from_slice(&checksum.to_le_bytes());
                final_buffer.freeze()
            }
        }
}
pub fn reset(&mut self) {
self.buffer.clear();
self.restart_points.clear();
self.restart_points.push(0);
self.counter = 0;
self.last_key = Bytes::new();
}
}
impl Default for BlockBuilder {
fn default() -> Self {
Self::new()
}
}
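/// Backing storage for a block: bytes owned by the block itself, or a
/// frame borrowed from the buffer manager.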
#[derive(Clone, Debug)]
pub enum BlockData {
Owned(Bytes),
Borrowed(FrameRef),
}
impl BlockData {
pub fn as_slice(&self) -> &[u8] {
match self {
Self::Owned(bytes) => bytes.as_ref(),
Self::Borrowed(frame) => unsafe { frame.data_unchecked() },
}
}
pub fn slice(&self, range: std::ops::Range<usize>) -> Bytes {
match self {
Self::Owned(bytes) => bytes.slice(range),
Self::Borrowed(frame) => unsafe {
let data = frame.data_unchecked();
Bytes::copy_from_slice(&data[range])
},
}
}
}
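/// An immutable, checksum-verified block. Entries are decoded once into a
/// shared cache on first access; clones share both the data and the cache.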
#[derive(Clone)]
pub struct Block {
data: BlockData,
restart_offset: usize,
num_restarts: usize,
decompressed_cache: Arc<OnceLock<Vec<(Bytes, Bytes)>>>,
}
impl Block {
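    /// Parses a serialized block: verifies the CRC32C trailer, decompresses
    /// the payload if a codec is recorded, and scans the restart array to
    /// recover the restart count.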
pub fn new(data: BlockData) -> Result<Self> {
let (restart_offset, num_restarts) = {
let raw_data = data.as_slice();
if raw_data.len() < 13 {
return Err(BlockError::InvalidFormat);
}
let stored_checksum = u32::from_le_bytes([
raw_data[raw_data.len() - 4],
raw_data[raw_data.len() - 3],
raw_data[raw_data.len() - 2],
raw_data[raw_data.len() - 1],
]);
let computed_checksum = crc32c::crc32c(&raw_data[..raw_data.len() - 4]);
if stored_checksum != computed_checksum {
return Err(BlockError::Corruption);
}
let restart_offset = u32::from_le_bytes([
raw_data[raw_data.len() - 8],
raw_data[raw_data.len() - 7],
raw_data[raw_data.len() - 6],
raw_data[raw_data.len() - 5],
]) as usize;
let compression_byte = raw_data[raw_data.len() - 9];
let compression_type =
CompressionType::from_byte(compression_byte).ok_or(BlockError::InvalidFormat)?;
if compression_type != CompressionType::None {
let compressed_slice = &raw_data[..raw_data.len() - 13];
let uncompressed_data = match compression_type {
CompressionType::Lz4 => decompress_size_prepended(compressed_slice)
.map_err(|_| BlockError::InvalidFormat)?,
CompressionType::Zstd => {
zstd::decode_all(compressed_slice).map_err(|_| BlockError::InvalidFormat)?
}
CompressionType::None => unreachable!(),
};
let data = Bytes::from(uncompressed_data);
if restart_offset >= data.len() {
return Err(BlockError::InvalidFormat);
}
                let num_restarts = count_restarts(&data, restart_offset, data.len());
return Ok(Self {
data: BlockData::Owned(data),
restart_offset,
num_restarts,
decompressed_cache: Arc::new(OnceLock::new()),
});
}
if restart_offset >= raw_data.len() - 13 {
return Err(BlockError::InvalidFormat);
}
            let num_restarts =
                count_restarts(raw_data, restart_offset, raw_data.len() - 13);
(restart_offset, num_restarts)
};
Ok(Self {
data,
restart_offset,
num_restarts,
decompressed_cache: Arc::new(OnceLock::new()),
})
}
pub fn from_bytes(data: Bytes) -> Result<Self> {
Self::new(BlockData::Owned(data))
}
pub fn iter(&self) -> BlockIterator<'_> {
let entries = self
.decompressed_cache
.get_or_init(|| self.decompress_all_entries());
BlockIterator::new_cached(entries)
}
#[inline]
pub fn find_exact(&self, key: &[u8]) -> Option<(Bytes, Bytes)> {
let entries = self
.decompressed_cache
.get_or_init(|| self.decompress_all_entries());
match entries.binary_search_by(|(k, _)| simd::compare_keys(k.as_ref(), key)) {
Ok(idx) => Some(entries[idx].clone()),
Err(_) => None,
}
}
#[inline]
pub fn find_lower_bound(&self, key: &[u8]) -> Option<(Bytes, Bytes)> {
let entries = self
.decompressed_cache
.get_or_init(|| self.decompress_all_entries());
let idx = entries.partition_point(|(k, _)| simd::compare_keys(k.as_ref(), key).is_lt());
entries.get(idx).cloned()
}
#[inline]
pub fn find_lower_bound_by_user_key(&self, user_key: &[u8]) -> Option<(Bytes, Bytes)> {
let entries = self
.decompressed_cache
.get_or_init(|| self.decompress_all_entries());
let idx = entries.partition_point(|(k, _)| {
simd::compare_internal_to_user_key(k.as_ref(), user_key).is_lt()
});
entries.get(idx).cloned()
}
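    /// MVCC point lookup: seeks to the first entry at or after
    /// `encoded_search_key` and returns the newest matching version of
    /// `user_key`. Internal keys carry an 8-byte sequence suffix after the
    /// user key.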
#[inline]
pub fn find_mvcc(&self, encoded_search_key: &[u8], user_key: &[u8]) -> Option<(Bytes, Bytes)> {
let entries = self
.decompressed_cache
.get_or_init(|| self.decompress_all_entries());
let start_idx = entries
.partition_point(|(k, _)| simd::compare_keys(k.as_ref(), encoded_search_key).is_lt());
for (entry_key, entry_value) in entries.iter().skip(start_idx) {
if entry_key.len() < 8 {
continue;
}
let entry_user_key = &entry_key[..entry_key.len() - 8];
if entry_user_key == user_key {
return Some((entry_key.clone(), entry_value.clone()));
}
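            // Entry sorts strictly after `user_key` and does not share it
            // as a prefix, so no later entry can match either.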
if !entry_user_key.starts_with(user_key) && entry_user_key > user_key {
return None;
}
}
None
}
pub const fn num_entries_approx(&self) -> usize {
self.num_restarts * RESTART_INTERVAL
}
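    /// Decodes every entry, reconstructing full keys from the shared-prefix
    /// encoding. Decoding stops early if a varint or length is malformed,
    /// dropping the remainder of the block.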
fn decompress_all_entries(&self) -> Vec<(Bytes, Bytes)> {
let mut entries = Vec::with_capacity(self.num_entries_approx());
let raw_data = self.data.as_slice();
let data = &raw_data[..self.restart_offset];
let mut offset = 0;
let mut key_buffer = BytesMut::with_capacity(256);
while offset < data.len() {
let prefix_len = match read_varint(data, &mut offset) {
Some(len) => len as usize,
None => break,
};
let suffix_len = match read_varint(data, &mut offset) {
Some(len) => len as usize,
None => break,
};
if offset + suffix_len > data.len() {
break;
}
let suffix_start = offset;
let suffix_end = offset + suffix_len;
offset = suffix_end;
let key = if prefix_len == 0 {
key_buffer.clear();
key_buffer.extend_from_slice(&data[suffix_start..suffix_end]);
key_buffer.clone().freeze()
} else {
                if prefix_len > key_buffer.len() {
                    break;
                }
key_buffer.truncate(prefix_len);
key_buffer.extend_from_slice(&data[suffix_start..suffix_end]);
key_buffer.clone().freeze()
};
let value_len = match read_varint(data, &mut offset) {
Some(len) => len as usize,
None => break,
};
if offset + value_len > data.len() {
break;
}
let value = self.data.slice(offset..offset + value_len);
offset += value_len;
entries.push((key, value));
}
entries
}
}
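/// Iterator over a block's cached, fully-decoded entries. `Bytes` clones
/// are cheap reference-count bumps, so items are yielded by value.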
pub struct BlockIterator<'a> {
iter: std::slice::Iter<'a, (Bytes, Bytes)>,
}
impl<'a> BlockIterator<'a> {
fn new_cached(entries: &'a [(Bytes, Bytes)]) -> Self {
Self {
iter: entries.iter(),
}
}
}
impl Iterator for BlockIterator<'_> {
type Item = Result<(Bytes, Bytes)>;
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|(k, v)| Ok((k.clone(), v.clone())))
}
}
impl DoubleEndedIterator for BlockIterator<'_> {
fn next_back(&mut self) -> Option<Self::Item> {
self.iter
.next_back()
.map(|(k, v)| Ok((k.clone(), v.clone())))
}
}
impl<'a> IntoIterator for &'a Block {
type Item = Result<(Bytes, Bytes)>;
type IntoIter = BlockIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_block_builder_single_entry() {
let mut builder = BlockBuilder::new();
assert!(builder.add(b"key1", b"value1"));
let block_data = builder.finish();
let block = Block::from_bytes(block_data).unwrap();
let entries: Vec<_> = block.iter().collect();
assert_eq!(entries.len(), 1);
let (key, value) = entries[0].as_ref().unwrap();
assert_eq!(key, &Bytes::from("key1"));
assert_eq!(value, &Bytes::from("value1"));
}
#[test]
fn test_block_builder_multiple_entries() {
let mut builder = BlockBuilder::new();
assert!(builder.add(b"key1", b"value1"));
assert!(builder.add(b"key2", b"value2"));
assert!(builder.add(b"key3", b"value3"));
let block_data = builder.finish();
let block = Block::from_bytes(block_data).unwrap();
let entries: Vec<_> = block.iter().map(|r| r.unwrap()).collect();
assert_eq!(entries.len(), 3);
assert_eq!(entries[0].0, Bytes::from("key1"));
assert_eq!(entries[1].0, Bytes::from("key2"));
assert_eq!(entries[2].0, Bytes::from("key3"));
}
#[test]
fn test_block_builder_full() {
let mut builder = BlockBuilder::with_capacity(256);
let mut count = 0;
for i in 0..100 {
let key = format!("key{:04}", i);
let value = format!("value{:04}", i);
if !builder.add(key.as_bytes(), value.as_bytes()) {
break;
}
count += 1;
}
assert!(
count > 0 && count < 100,
"Block should fill before 100 entries"
);
let block_data = builder.finish();
let block = Block::from_bytes(block_data).unwrap();
let entries: Vec<_> = block.iter().collect();
assert_eq!(entries.len(), count);
}
#[test]
fn test_block_checksum_validation() {
let mut builder = BlockBuilder::new();
builder.add(b"key1", b"value1");
let mut block_data = builder.finish().to_vec();
block_data[0] ^= 0xFF;
let result = Block::from_bytes(Bytes::from(block_data));
assert!(matches!(result, Err(BlockError::Corruption)));
}
#[test]
fn test_block_restart_points() {
let mut builder = BlockBuilder::new();
for i in 0..40 {
let key = format!("key{:04}", i);
let value = format!("value{:04}", i);
assert!(builder.add(key.as_bytes(), value.as_bytes()));
}
assert!(builder.restart_points.len() > 1);
let block_data = builder.finish();
let block = Block::from_bytes(block_data).unwrap();
let entries: Vec<_> = block.iter().map(|r| r.unwrap()).collect();
assert_eq!(entries.len(), 40);
}
#[test]
fn test_block_large_values() {
let mut builder = BlockBuilder::new();
let large_value = vec![b'x'; 2000];
assert!(builder.add(b"key1", &large_value));
let block_data = builder.finish();
let block = Block::from_bytes(block_data).unwrap();
let entries: Vec<_> = block.iter().collect();
assert_eq!(entries.len(), 1);
let (_, value) = entries[0].as_ref().unwrap();
assert_eq!(value.len(), 2000);
}
#[test]
fn test_block_zstd_compression() {
let mut builder = BlockBuilder::new();
builder.set_compression_type(CompressionType::Zstd);
for i in 0..20 {
let key = format!("key{:04}", i);
let value = format!("value{:04}", i);
assert!(builder.add(key.as_bytes(), value.as_bytes()));
}
let block_data = builder.finish();
let block = Block::from_bytes(block_data).unwrap();
let entries: Vec<_> = block.iter().map(|r| r.unwrap()).collect();
assert_eq!(entries.len(), 20);
assert_eq!(entries[0].0, Bytes::from("key0000"));
assert_eq!(entries[19].0, Bytes::from("key0019"));
}
#[test]
fn test_block_no_compression() {
let mut builder = BlockBuilder::new();
builder.set_compression_type(CompressionType::None);
assert!(builder.add(b"key1", b"value1"));
assert!(builder.add(b"key2", b"value2"));
let block_data = builder.finish();
let block = Block::from_bytes(block_data).unwrap();
let entries: Vec<_> = block.iter().map(|r| r.unwrap()).collect();
assert_eq!(entries.len(), 2);
assert_eq!(entries[0].0, Bytes::from("key1"));
assert_eq!(entries[1].0, Bytes::from("key2"));
}
#[test]
fn test_compression_ratio_comparison() {
let test_data: Vec<(Vec<u8>, Vec<u8>)> = (0..50)
.map(|i| {
let key = format!("user_profile_{:08}", i).into_bytes();
let value = format!("{{\"name\":\"user{}\",\"email\":\"user{}@example.com\",\"bio\":\"This is a sample biography that contains repetitive text patterns for testing compression.\"}}", i, i).into_bytes();
(key, value)
})
.collect();
let mut none_builder = BlockBuilder::with_capacity(16384);
none_builder.set_compression_type(CompressionType::None);
for (k, v) in &test_data {
none_builder.add(k, v);
}
let none_size = none_builder.finish().len();
let mut lz4_builder = BlockBuilder::with_capacity(16384);
lz4_builder.set_compression_type(CompressionType::Lz4);
for (k, v) in &test_data {
lz4_builder.add(k, v);
}
let lz4_size = lz4_builder.finish().len();
let mut zstd_builder = BlockBuilder::with_capacity(16384);
zstd_builder.set_compression_type(CompressionType::Zstd);
for (k, v) in &test_data {
zstd_builder.add(k, v);
}
let zstd_size = zstd_builder.finish().len();
assert!(lz4_size < none_size, "LZ4 should compress data");
assert!(zstd_size < none_size, "ZSTD should compress data");
assert!(
zstd_size <= lz4_size,
"ZSTD ({}) should compress at least as well as LZ4 ({})",
zstd_size,
lz4_size
);
}
#[test]
fn test_zstd_large_values() {
let mut builder = BlockBuilder::with_capacity(16384);
builder.set_compression_type(CompressionType::Zstd);
let embedding: Vec<u8> = (0..768 * 4).map(|i| (i % 256) as u8).collect();
assert!(builder.add(b"embedding_key", &embedding));
let block_data = builder.finish();
let block = Block::from_bytes(block_data).unwrap();
let entries: Vec<_> = block.iter().collect();
assert_eq!(entries.len(), 1);
let (key, value) = entries[0].as_ref().unwrap();
assert_eq!(key, &Bytes::from("embedding_key"));
assert_eq!(value.len(), 768 * 4);
}
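    // A round-trip sketch for the varint helpers, covering one-byte,
    // multi-byte, and maximum-width encodings.
    #[test]
    fn test_varint_round_trip() {
        for value in [0u64, 1, 0x7F, 0x80, 0x3FFF, u64::MAX] {
            let mut buf = BytesMut::new();
            write_varint(&mut buf, value);
            let mut offset = 0;
            assert_eq!(read_varint(&buf, &mut offset), Some(value));
            assert_eq!(offset, buf.len());
        }
    }
    // Point-lookup sketch for the binary-search paths, assuming
    // `simd::compare_keys` is a plain lexicographic compare (as its use in
    // `binary_search_by` implies).
    #[test]
    fn test_block_point_lookups() {
        let mut builder = BlockBuilder::new();
        assert!(builder.add(b"key1", b"value1"));
        assert!(builder.add(b"key3", b"value3"));
        let block = Block::from_bytes(builder.finish()).unwrap();
        let (key, value) = block.find_exact(b"key3").unwrap();
        assert_eq!(key, Bytes::from("key3"));
        assert_eq!(value, Bytes::from("value3"));
        assert!(block.find_exact(b"key2").is_none());
        // Lower bound of "key2" is the first key at or after it: "key3".
        let (key, _) = block.find_lower_bound(b"key2").unwrap();
        assert_eq!(key, Bytes::from("key3"));
    }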
}