use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use crate::block_storage::{
BlockCompression, BlockHeader, BlockRef, is_compressible, is_json_content, is_soch_content,
};
use crate::{Result, SochDBError};
/// Number of shards created by `ShardedBlockStore::new`.
const DEFAULT_SHARD_COUNT: usize = 8;
/// Segment size in bytes (64 MiB) stored on every `ShardedBlockStore` and used by
/// `shard_for_offset` to map an offset to a shard.
const DEFAULT_SEGMENT_SIZE: u64 = 64 * 1024 * 1024;
/// One in-memory shard of the block store: a byte buffer plus per-block bookkeeping.
pub struct BlockShard {
/// Identifier of this shard; reported in error messages and in `ShardStats`.
id: usize,
/// Backing byte buffer; each block is stored as its header bytes followed by its payload.
data: RwLock<Vec<u8>>,
/// Offset at which the next block will be placed; advanced atomically by every write.
next_offset: AtomicU64,
/// Maps a block's starting offset to the `BlockRef` that was returned when it was written.
index: RwLock<HashMap<u64, BlockRef>>,
/// Reference count per block, keyed by starting offset; a freshly written block starts at 1.
ref_counts: RwLock<HashMap<u64, AtomicU32>>,
/// Running total of bytes (headers plus payloads) written to this shard.
bytes_written: AtomicU64,
/// Running total of blocks written to this shard.
blocks_written: AtomicU64,
}
impl BlockShard {
pub fn new(id: usize) -> Self {
Self {
id,
data: RwLock::new(Vec::new()),
next_offset: AtomicU64::new(0),
index: RwLock::new(HashMap::new()),
ref_counts: RwLock::new(HashMap::new()),
bytes_written: AtomicU64::new(0),
blocks_written: AtomicU64::new(0),
}
}
pub fn write_block(&self, data: &[u8], compression: BlockCompression) -> Result<BlockRef> {
let compressed = self.compress(data, compression)?;
let checksum = crc32fast::hash(&compressed);
let header_size = BlockHeader::SIZE;
let total_size = header_size + compressed.len();
let local_offset = self
.next_offset
.fetch_add(total_size as u64, Ordering::SeqCst);
let header = BlockHeader {
magic: BlockHeader::MAGIC,
compression: compression as u8,
original_size: data.len() as u32,
compressed_size: compressed.len() as u32,
checksum,
};
{
let mut store = self.data.write();
let required_size = (local_offset + total_size as u64) as usize;
if store.len() < required_size {
store.resize(required_size, 0);
}
let header_bytes = header.to_bytes();
store[local_offset as usize..local_offset as usize + header_size]
.copy_from_slice(&header_bytes);
store[local_offset as usize + header_size..local_offset as usize + total_size]
.copy_from_slice(&compressed);
}
let block_ref = BlockRef {
store_offset: local_offset,
compressed_len: compressed.len() as u32,
original_len: data.len() as u32,
compression,
checksum,
};
self.index.write().insert(local_offset, block_ref.clone());
self.ref_counts
.write()
.insert(local_offset, AtomicU32::new(1));
self.bytes_written
.fetch_add(total_size as u64, Ordering::Relaxed);
self.blocks_written.fetch_add(1, Ordering::Relaxed);
Ok(block_ref)
}
/// Reads the block described by `block_ref` and returns its decompressed contents.
///
/// The stored header's checksum is compared with `block_ref.checksum`, the payload's CRC32
/// is recomputed and compared as well, and the payload is then passed to `decompress`.
///
/// # Errors
///
/// Returns `SochDBError::Corruption` when the referenced byte range does not lie inside the
/// shard buffer (including offsets so large that the range cannot be represented), or when
/// either checksum comparison fails. Errors from `BlockHeader::from_bytes` and from
/// `decompress` are propagated unchanged.
pub fn read_block(&self, block_ref: &BlockRef) -> Result<Vec<u8>> {
    let header_size = BlockHeader::SIZE;
    let payload_len = block_ref.compressed_len as usize;

    let compressed = {
        let store = self.data.read();
        let offset = block_ref.store_offset as usize;

        // `block_ref` may have been deserialized from untrusted bytes, so compute the range
        // with checked arithmetic: an overflowing offset is corruption, not a panic.
        let span = offset
            .checked_add(header_size)
            .and_then(|payload_start| {
                payload_start
                    .checked_add(payload_len)
                    .map(|end| (payload_start, end))
            })
            .filter(|&(_, end)| end <= store.len());

        let (payload_start, end) = match span {
            Some(bounds) => bounds,
            None => {
                return Err(SochDBError::Corruption(format!(
                    "Block at offset {} extends beyond shard {} data (size {})",
                    offset,
                    self.id,
                    store.len()
                )));
            }
        };

        let header = BlockHeader::from_bytes(&store[offset..payload_start])?;
        if header.checksum != block_ref.checksum {
            return Err(SochDBError::Corruption(format!(
                "Checksum mismatch in shard {}: expected {}, got {}",
                self.id, block_ref.checksum, header.checksum
            )));
        }

        // Copy the payload out so the read lock is released before hashing and decompressing.
        store[payload_start..end].to_vec()
    };

    let computed_checksum = crc32fast::hash(&compressed);
    if computed_checksum != block_ref.checksum {
        return Err(SochDBError::Corruption(format!(
            "Data checksum mismatch in shard {}: expected {}, got {}",
            self.id, block_ref.checksum, computed_checksum
        )));
    }

    self.decompress(
        &compressed,
        block_ref.compression,
        block_ref.original_len as usize,
    )
}
/// Increments the reference count of the block that starts at `offset`.
///
/// Does nothing when no block is registered at that offset.
pub fn add_ref(&self, offset: u64) {
    let counts = self.ref_counts.read();
    let counter = match counts.get(&offset) {
        Some(counter) => counter,
        None => return,
    };
    counter.fetch_add(1, Ordering::Relaxed);
}
/// Decrements the reference count of the block that starts at `offset`.
///
/// Returns `true` only when this call dropped the count from 1 to 0, i.e. the caller released
/// the last reference. Returns `false` when other references remain, when no block is
/// registered at `offset`, or when the count was already 0.
pub fn release_ref(&self, offset: u64) -> bool {
    let counts = self.ref_counts.read();
    match counts.get(&offset) {
        Some(counter) => {
            // `checked_sub` leaves a zero count untouched instead of wrapping to u32::MAX on
            // an extra release. AcqRel makes the final decrement synchronize with earlier
            // releases, so whoever sees `true` also sees the other holders' prior writes.
            let previous = counter.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current.checked_sub(1)
            });
            matches!(previous, Ok(1))
        }
        None => false,
    }
}
/// Builds a snapshot of this shard's counters.
///
/// Original and compressed byte totals are summed over every indexed block while the index
/// read lock is held; `bytes_written` comes from the shard's running counter.
pub fn stats(&self) -> ShardStats {
    let index = self.index.read();
    let (total_original_bytes, total_compressed_bytes) =
        index
            .values()
            .fold((0u64, 0u64), |(original, compressed), entry| {
                (
                    original + u64::from(entry.original_len),
                    compressed + u64::from(entry.compressed_len),
                )
            });

    ShardStats {
        shard_id: self.id,
        block_count: index.len(),
        bytes_written: self.bytes_written.load(Ordering::Relaxed),
        total_original_bytes,
        total_compressed_bytes,
    }
}
/// Encodes `data` with the requested `compression`, falling back to a verbatim copy.
///
/// The encoded form is kept only when it is strictly smaller than the input. If the encoder
/// fails or does not shrink the data, or if `compression` is `None`, the bytes are returned
/// unchanged. This function therefore never returns an error.
fn compress(&self, data: &[u8], compression: BlockCompression) -> Result<Vec<u8>> {
    let encoded: Option<Vec<u8>> = match compression {
        BlockCompression::None => None,
        BlockCompression::Lz4 => lz4::block::compress(data, None, false).ok(),
        BlockCompression::Zstd => zstd::encode_all(data, 3).ok(),
    };

    let stored = encoded
        .filter(|candidate| candidate.len() < data.len())
        .unwrap_or_else(|| data.to_vec());
    Ok(stored)
}
/// Reverses `compress` for a stored payload.
///
/// A payload whose length equals `original_size` is returned verbatim whatever the
/// `compression` tag says, because `compress` stores the raw bytes whenever encoding does not
/// make them strictly smaller.
///
/// # Errors
///
/// Returns `SochDBError::Corruption` when the LZ4 or Zstd decoder rejects the payload.
fn decompress(
    &self,
    data: &[u8],
    compression: BlockCompression,
    original_size: usize,
) -> Result<Vec<u8>> {
    match compression {
        BlockCompression::None => Ok(data.to_vec()),
        // Same length as the original means the raw fallback was stored.
        BlockCompression::Lz4 | BlockCompression::Zstd if data.len() == original_size => {
            Ok(data.to_vec())
        }
        BlockCompression::Lz4 => lz4::block::decompress(data, Some(original_size as i32))
            .map_err(|err| SochDBError::Corruption(format!("LZ4 decompress failed: {}", err))),
        BlockCompression::Zstd => zstd::decode_all(data)
            .map_err(|err| SochDBError::Corruption(format!("ZSTD decompress failed: {}", err))),
    }
}
}
/// Snapshot of a single shard's counters, as produced by `BlockShard::stats`.
#[derive(Debug, Clone)]
pub struct ShardStats {
/// Identifier of the shard the snapshot was taken from.
pub shard_id: usize,
/// Number of blocks present in the shard's index.
pub block_count: usize,
/// Total bytes written to the shard, headers included.
pub bytes_written: u64,
/// Sum of the original (uncompressed) lengths of all indexed blocks.
pub total_original_bytes: u64,
/// Sum of the stored (compressed) lengths of all indexed blocks.
pub total_compressed_bytes: u64,
}
/// Block store that spreads blocks over several `BlockShard`s, chosen by hashing a file id.
pub struct ShardedBlockStore {
/// The shards, indexed by shard id.
shards: Vec<BlockShard>,
/// Number of shards; used as the modulus when selecting a shard.
shard_count: usize,
// Only read by `shard_for_offset`, which is itself marked as dead code.
#[allow(dead_code)]
segment_size: u64,
/// Count of successful block writes across all shards.
total_writes: AtomicU64,
}
impl ShardedBlockStore {
pub fn new() -> Self {
Self::with_shards(DEFAULT_SHARD_COUNT)
}
pub fn with_shards(shard_count: usize) -> Self {
let shards = (0..shard_count).map(BlockShard::new).collect();
Self {
shards,
shard_count,
segment_size: DEFAULT_SEGMENT_SIZE,
total_writes: AtomicU64::new(0),
}
}
/// Maps `file_id` to a shard index in `0..shard_count`.
///
/// The id is scrambled with the xor-shift/multiply steps of the MurmurHash3 64-bit finalizer
/// and then reduced modulo the shard count.
#[inline]
fn shard_for_file(&self, file_id: u64) -> usize {
    const MIX_A: u64 = 0xff51afd7ed558ccd;
    const MIX_B: u64 = 0xc4ceb9fe1a85ec53;

    let mixed = file_id ^ (file_id >> 33);
    let mixed = mixed.wrapping_mul(MIX_A);
    let mixed = mixed ^ (mixed >> 33);
    let mixed = mixed.wrapping_mul(MIX_B);
    let mixed = mixed ^ (mixed >> 33);

    (mixed as usize) % self.shard_count
}
/// Maps a byte `offset` to a shard index by taking its segment number modulo the shard count.
///
/// Currently unused, hence the `dead_code` allowance.
#[inline]
#[allow(dead_code)]
fn shard_for_offset(&self, offset: u64) -> usize {
    let segment_index = offset / self.segment_size;
    (segment_index as usize) % self.shard_count
}
/// Writes `data` for `file_id`, choosing the compression from the content.
///
/// The compression is picked by `select_compression`; shard selection, the write itself and
/// the write counter are handled by `write_block_with_compression`.
///
/// # Errors
///
/// Propagates any error returned by the shard write.
pub fn write_block(&self, file_id: u64, data: &[u8]) -> Result<ShardedBlockRef> {
    let compression = self.select_compression(data);
    self.write_block_with_compression(file_id, data, compression)
}
/// Writes `data` for `file_id` using the caller-specified `compression`.
///
/// The target shard is derived from `file_id`. `total_writes` is incremented only after the
/// shard write succeeds.
///
/// # Errors
///
/// Propagates any error returned by the shard write.
pub fn write_block_with_compression(
    &self,
    file_id: u64,
    data: &[u8],
    compression: BlockCompression,
) -> Result<ShardedBlockRef> {
    let shard_id = self.shard_for_file(file_id);
    let shard = &self.shards[shard_id];

    shard.write_block(data, compression).map(|block_ref| {
        self.total_writes.fetch_add(1, Ordering::Relaxed);
        ShardedBlockRef {
            shard_id,
            block_ref,
        }
    })
}
/// Reads the block identified by `shard_ref` from the shard it names.
///
/// # Errors
///
/// Returns `SochDBError::Corruption` when `shard_ref.shard_id` does not name an existing
/// shard; otherwise propagates whatever the shard's `read_block` returns.
pub fn read_block(&self, shard_ref: &ShardedBlockRef) -> Result<Vec<u8>> {
    match self.shards.get(shard_ref.shard_id) {
        Some(shard) => shard.read_block(&shard_ref.block_ref),
        None => Err(SochDBError::Corruption(format!(
            "Invalid shard ID: {} (max {})",
            shard_ref.shard_id,
            // Saturating so that building the message cannot underflow on an empty store.
            self.shard_count.saturating_sub(1)
        ))),
    }
}
/// Adds one reference to the block identified by `shard_ref`.
///
/// Silently ignores a `shard_id` that does not name an existing shard.
pub fn add_ref(&self, shard_ref: &ShardedBlockRef) {
    if let Some(shard) = self.shards.get(shard_ref.shard_id) {
        shard.add_ref(shard_ref.block_ref.store_offset);
    }
}
/// Drops one reference to the block identified by `shard_ref`.
///
/// Returns the result of the shard's `release_ref`, or `false` when `shard_id` does not name
/// an existing shard.
pub fn release_ref(&self, shard_ref: &ShardedBlockRef) -> bool {
    self.shards
        .get(shard_ref.shard_id)
        .map_or(false, |shard| {
            shard.release_ref(shard_ref.block_ref.store_offset)
        })
}
/// Collects per-shard statistics and aggregates them into store-wide totals.
///
/// `compression_ratio` is original bytes divided by compressed bytes, or `1.0` when no
/// compressed bytes have been recorded.
pub fn stats(&self) -> ShardedBlockStoreStats {
    let mut shard_stats: Vec<ShardStats> = Vec::with_capacity(self.shards.len());
    let mut total_blocks = 0usize;
    let mut total_bytes_written = 0u64;
    let mut total_original_bytes = 0u64;
    let mut total_compressed_bytes = 0u64;

    for shard in &self.shards {
        let snapshot = shard.stats();
        total_blocks += snapshot.block_count;
        total_bytes_written += snapshot.bytes_written;
        total_original_bytes += snapshot.total_original_bytes;
        total_compressed_bytes += snapshot.total_compressed_bytes;
        shard_stats.push(snapshot);
    }

    let compression_ratio = match total_compressed_bytes {
        0 => 1.0,
        compressed => total_original_bytes as f64 / compressed as f64,
    };

    ShardedBlockStoreStats {
        shard_count: self.shard_count,
        total_blocks,
        total_bytes_written,
        total_original_bytes,
        total_compressed_bytes,
        compression_ratio,
        shard_stats,
    }
}
/// Picks a compression scheme for `data` from its size and content.
///
/// Payloads shorter than 128 bytes are stored uncompressed. Otherwise content recognised by
/// `is_soch_content` gets Zstd, content recognised by `is_json_content` or `is_compressible`
/// gets LZ4, and everything else is stored uncompressed.
fn select_compression(&self, data: &[u8]) -> BlockCompression {
    const MIN_COMPRESS_LEN: usize = 128;

    if data.len() < MIN_COMPRESS_LEN {
        return BlockCompression::None;
    }
    if is_soch_content(data) {
        return BlockCompression::Zstd;
    }
    if is_json_content(data) || is_compressible(data) {
        return BlockCompression::Lz4;
    }
    BlockCompression::None
}
}
impl Default for ShardedBlockStore {
fn default() -> Self {
Self::new()
}
}
/// Reference to a block in a `ShardedBlockStore`: which shard holds it and where.
#[derive(Debug, Clone)]
pub struct ShardedBlockRef {
/// Index of the shard that holds the block.
pub shard_id: usize,
/// Location and metadata of the block inside that shard.
pub block_ref: BlockRef,
}
impl ShardedBlockRef {
    /// Serializes the reference as a 4-byte little-endian shard id followed by the 21-byte
    /// encoding of the inner `BlockRef`.
    ///
    /// NOTE(review): if `BlockRef::to_bytes` fails, 21 zero bytes are written in its place
    /// because this signature cannot report an error. Callers that need to detect that case
    /// should confirm whether the fallback is acceptable.
    pub fn to_bytes(&self) -> Vec<u8> {
        const SHARD_ID_LEN: usize = 4;
        const BLOCK_REF_LEN: usize = 21;

        let mut encoded = Vec::with_capacity(SHARD_ID_LEN + BLOCK_REF_LEN);
        let shard_id_bytes = (self.shard_id as u32).to_le_bytes();
        encoded.extend_from_slice(&shard_id_bytes);
        let block_ref_bytes = self.block_ref.to_bytes().unwrap_or([0u8; BLOCK_REF_LEN]);
        encoded.extend_from_slice(&block_ref_bytes);
        encoded
    }

    /// Parses a reference from the layout produced by `to_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `SochDBError::Corruption` when `data` is shorter than 25 bytes, and propagates
    /// any error from `BlockRef::from_bytes` on the bytes after the shard id.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        const SHARD_ID_LEN: usize = 4;
        const ENCODED_LEN: usize = 25;

        if data.len() < ENCODED_LEN {
            return Err(SochDBError::Corruption("ShardedBlockRef too short".into()));
        }

        let (id_bytes, block_ref_bytes) = data.split_at(SHARD_ID_LEN);
        let mut raw_id = [0u8; SHARD_ID_LEN];
        raw_id.copy_from_slice(id_bytes);

        let shard_id = u32::from_le_bytes(raw_id) as usize;
        let block_ref = BlockRef::from_bytes(block_ref_bytes)?;
        Ok(Self {
            shard_id,
            block_ref,
        })
    }
}
/// Store-wide statistics, as produced by `ShardedBlockStore::stats`.
#[derive(Debug, Clone)]
pub struct ShardedBlockStoreStats {
/// Number of shards in the store.
pub shard_count: usize,
/// Total number of indexed blocks across all shards.
pub total_blocks: usize,
/// Total bytes written across all shards, headers included.
pub total_bytes_written: u64,
/// Sum of the original (uncompressed) lengths of all indexed blocks.
pub total_original_bytes: u64,
/// Sum of the stored (compressed) lengths of all indexed blocks.
pub total_compressed_bytes: u64,
/// Original bytes divided by compressed bytes; `1.0` when no compressed bytes are recorded.
pub compression_ratio: f64,
/// The per-shard snapshots the totals were computed from.
pub shard_stats: Vec<ShardStats>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A small block written through the store reads back byte-for-byte.
    #[test]
    fn test_sharded_store_basic() {
        let store = ShardedBlockStore::new();
        let payload = b"Hello, sharded world!";

        let written = store.write_block(1, payload).unwrap();
        let read_back = store.read_block(&written).unwrap();

        assert_eq!(read_back, payload);
    }

    /// Blocks written for many different file ids all read back intact.
    #[test]
    fn test_sharded_store_multiple_files() {
        let store = ShardedBlockStore::new();

        let written: Vec<_> = (0..100u64)
            .map(|file_id| {
                let data = format!("Data for file {}", file_id).into_bytes();
                let shard_ref = store.write_block(file_id, &data).unwrap();
                (file_id, shard_ref, data)
            })
            .collect();

        for (file_id, shard_ref, expected) in written {
            let recovered = store.read_block(&shard_ref).unwrap();
            assert_eq!(recovered, expected, "File {} mismatch", file_id);
        }
    }

    /// With 1000 file ids over 4 shards, every shard receives at least one block.
    #[test]
    fn test_sharded_store_distribution() {
        let store = ShardedBlockStore::with_shards(4);
        (0..1000u64).for_each(|i| {
            let data = vec![i as u8; 64];
            store.write_block(i, &data).unwrap();
        });

        let stats = store.stats();
        for per_shard in stats.shard_stats.iter() {
            assert!(
                per_shard.block_count > 0,
                "Shard {} has no blocks",
                per_shard.shard_id
            );
        }
        assert_eq!(stats.total_blocks, 1000);
    }

    /// A reference survives a `to_bytes` / `from_bytes` round trip.
    #[test]
    fn test_sharded_ref_serialization() {
        let block_ref = BlockRef {
            store_offset: 12345,
            compressed_len: 100,
            original_len: 200,
            compression: BlockCompression::Lz4,
            checksum: 0xDEADBEEF,
        };
        let original = ShardedBlockRef {
            shard_id: 3,
            block_ref,
        };

        let encoded = original.to_bytes();
        let decoded = ShardedBlockRef::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.shard_id, 3);
        assert_eq!(decoded.block_ref.store_offset, 12345);
        assert_eq!(decoded.block_ref.compression, BlockCompression::Lz4);
    }

    /// Highly repetitive data is stored smaller than its original size and still round-trips.
    #[test]
    fn test_sharded_store_compression() {
        let store = ShardedBlockStore::new();
        let zeros = vec![0u8; 4096];

        let written = store.write_block(1, &zeros).unwrap();
        assert!(written.block_ref.compressed_len < written.block_ref.original_len);

        let read_back = store.read_block(&written).unwrap();
        assert_eq!(read_back, zeros);
    }

    /// Only the release that drops the final reference reports `true`.
    #[test]
    fn test_ref_counting() {
        let store = ShardedBlockStore::new();
        let payload = b"Reference counted block";
        let written = store.write_block(1, payload).unwrap();

        // The write holds one reference; take two more for a total of three.
        store.add_ref(&written);
        store.add_ref(&written);

        assert!(!store.release_ref(&written));
        assert!(!store.release_ref(&written));
        assert!(store.release_ref(&written));
    }
}