use super::chunk::VersionedChunk;
use super::types::{ChunkId, VersionMismatch, VersionedResult};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
/// Shared, versioned table of chunks with a secondary lookup by content hash.
///
/// Every field sits behind an [`Arc`], so the state can be shared between
/// several handles to the same store.
pub struct VersionedChunkStore {
    /// Primary table: chunk id -> shared chunk.
    chunks: Arc<RwLock<HashMap<ChunkId, Arc<VersionedChunk>>>>,
    /// Store-wide version counter, incremented by mutating operations.
    global_version: Arc<AtomicU64>,
    /// Secondary index: 8-byte content hash -> id of a chunk with that hash.
    hash_index: Arc<RwLock<HashMap<[u8; 8], ChunkId>>>,
}
impl VersionedChunkStore {
pub fn new() -> Self {
Self {
chunks: Arc::new(RwLock::new(HashMap::new())),
global_version: Arc::new(AtomicU64::new(0)),
hash_index: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn version(&self) -> u64 {
self.global_version.load(Ordering::Acquire)
}
pub fn get(&self, chunk_id: ChunkId) -> Option<(Arc<VersionedChunk>, u64)> {
let chunks = self.chunks.read().unwrap();
let version = self.version();
chunks
.get(&chunk_id)
.map(|chunk| (Arc::clone(chunk), version))
}
pub fn find_by_hash(&self, content_hash: &[u8; 8]) -> Option<(ChunkId, Arc<VersionedChunk>)> {
let hash_index = self.hash_index.read().unwrap();
let chunks = self.chunks.read().unwrap();
hash_index.get(content_hash).and_then(|&chunk_id| {
chunks
.get(&chunk_id)
.map(|chunk| (chunk_id, Arc::clone(chunk)))
})
}
pub fn insert(
&self,
chunk_id: ChunkId,
chunk: VersionedChunk,
expected_version: u64,
) -> VersionedResult<u64> {
let mut chunks = self.chunks.write().unwrap();
let mut hash_index = self.hash_index.write().unwrap();
let current_version = self.version();
if current_version != expected_version {
return Err(VersionMismatch {
expected: expected_version,
actual: current_version,
});
}
hash_index.insert(chunk.content_hash, chunk_id);
chunks.insert(chunk_id, Arc::new(chunk));
let new_version = self.global_version.fetch_add(1, Ordering::AcqRel) + 1;
Ok(new_version)
}
pub fn batch_insert(
&self,
updates: Vec<(ChunkId, VersionedChunk)>,
expected_version: u64,
) -> VersionedResult<u64> {
let mut chunks = self.chunks.write().unwrap();
let mut hash_index = self.hash_index.write().unwrap();
let current_version = self.version();
if current_version != expected_version {
return Err(VersionMismatch {
expected: expected_version,
actual: current_version,
});
}
for (chunk_id, chunk) in updates {
hash_index.insert(chunk.content_hash, chunk_id);
chunks.insert(chunk_id, Arc::new(chunk));
}
let new_version = self.global_version.fetch_add(1, Ordering::AcqRel) + 1;
Ok(new_version)
}
pub fn batch_insert_new(
&self,
updates: Vec<(ChunkId, VersionedChunk)>,
) -> VersionedResult<u64> {
let mut chunks = self.chunks.write().unwrap();
let mut hash_index = self.hash_index.write().unwrap();
for (chunk_id, chunk) in updates {
hash_index.insert(chunk.content_hash, chunk_id);
chunks.insert(chunk_id, Arc::new(chunk));
}
let new_version = self.global_version.fetch_add(1, Ordering::AcqRel) + 1;
Ok(new_version)
}
pub fn remove(
&self,
chunk_id: ChunkId,
expected_version: u64,
) -> VersionedResult<Option<Arc<VersionedChunk>>> {
let mut chunks = self.chunks.write().unwrap();
let mut hash_index = self.hash_index.write().unwrap();
let current_version = self.version();
if current_version != expected_version {
return Err(VersionMismatch {
expected: expected_version,
actual: current_version,
});
}
let removed = chunks.remove(&chunk_id);
if let Some(ref chunk) = removed {
hash_index.remove(&chunk.content_hash);
}
self.global_version.fetch_add(1, Ordering::AcqRel);
Ok(removed)
}
pub fn len(&self) -> usize {
self.chunks.read().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.chunks.read().unwrap().is_empty()
}
pub fn chunk_ids(&self) -> Vec<ChunkId> {
self.chunks.read().unwrap().keys().copied().collect()
}
pub fn iter(&self) -> Vec<(ChunkId, Arc<VersionedChunk>)> {
self.chunks
.read()
.unwrap()
.iter()
.map(|(&id, chunk)| (id, Arc::clone(chunk)))
.collect()
}
pub fn gc(&self, expected_version: u64) -> VersionedResult<usize> {
let mut chunks = self.chunks.write().unwrap();
let mut hash_index = self.hash_index.write().unwrap();
let current_version = self.version();
if current_version != expected_version {
return Err(VersionMismatch {
expected: expected_version,
actual: current_version,
});
}
let to_remove: Vec<ChunkId> = chunks
.iter()
.filter(|(_, chunk)| chunk.is_unreferenced())
.map(|(id, _)| *id)
.collect();
let count = to_remove.len();
for chunk_id in to_remove {
if let Some(chunk) = chunks.remove(&chunk_id) {
hash_index.remove(&chunk.content_hash);
}
}
if count > 0 {
self.global_version.fetch_add(1, Ordering::AcqRel);
}
Ok(count)
}
pub fn stats(&self) -> CodebookStats {
let chunks = self.chunks.read().unwrap();
let total_chunks = chunks.len();
let mut total_refs = 0u64;
let mut unreferenced = 0;
let mut total_size = 0usize;
for chunk in chunks.values() {
let refs = chunk.ref_count();
total_refs += refs as u64;
if refs == 0 {
unreferenced += 1;
}
total_size += chunk.original_size;
}
CodebookStats {
total_chunks,
unreferenced_chunks: unreferenced,
total_references: total_refs,
avg_references: if total_chunks > 0 {
total_refs as f64 / total_chunks as f64
} else {
0.0
},
total_size_bytes: total_size,
version: self.version(),
}
}
}
impl Default for VersionedChunkStore {
fn default() -> Self {
Self::new()
}
}
impl Clone for VersionedChunkStore {
    /// Produces another handle to the *same* underlying state.
    ///
    /// Only the `Arc` pointers are cloned, so the clone and the original
    /// share the same chunk table, hash index and version counter.
    fn clone(&self) -> Self {
        let Self {
            chunks,
            global_version,
            hash_index,
        } = self;
        Self {
            chunks: Arc::clone(chunks),
            global_version: Arc::clone(global_version),
            hash_index: Arc::clone(hash_index),
        }
    }
}
/// Aggregate statistics for a chunk store at a single point in time.
#[derive(Debug, Clone)]
pub struct CodebookStats {
    /// Number of chunks in the store.
    pub total_chunks: usize,
    /// Number of chunks whose reference count is zero.
    pub unreferenced_chunks: usize,
    /// Sum of the reference counts of all chunks.
    pub total_references: u64,
    /// `total_references / total_chunks`, or `0.0` for an empty store.
    pub avg_references: f64,
    /// Sum of the `original_size` of all chunks.
    pub total_size_bytes: usize,
    /// Store version at the time the statistics were taken.
    pub version: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SparseVec;

    /// Builds a test chunk from an empty `SparseVec`, size argument 4096, and
    /// an 8-byte hash that repeats the low byte of `id`.
    fn make_test_chunk(id: usize) -> VersionedChunk {
        let low_byte = (id & 0xFF) as u8;
        VersionedChunk::new(SparseVec::new(), 4096, [low_byte; 8])
    }

    /// A fresh store is empty and starts at version 0.
    #[test]
    fn test_codebook_creation() {
        let store = VersionedChunkStore::new();
        assert_eq!(store.version(), 0);
        assert!(store.is_empty());
    }

    /// Inserting bumps the store version; `get` reports that version.
    #[test]
    fn test_insert_and_get() {
        let store = VersionedChunkStore::new();
        let new_version = store.insert(1, make_test_chunk(1), 0).unwrap();
        assert_eq!(new_version, 1);

        let (stored, seen_version) = store.get(1).unwrap();
        assert_eq!(stored.version, 0);
        assert_eq!(seen_version, 1);
    }

    /// A write against an outdated version is rejected with both versions reported.
    #[test]
    fn test_version_mismatch() {
        let store = VersionedChunkStore::new();
        let first = make_test_chunk(1);
        let second = make_test_chunk(2);
        store.insert(1, first, 0).unwrap();

        let outcome = store.insert(2, second, 0);
        assert!(outcome.is_err());
        if let Err(VersionMismatch { expected, actual }) = outcome {
            assert_eq!(expected, 0);
            assert_eq!(actual, 1);
        } else {
            panic!("Expected VersionMismatch");
        }
    }

    /// A batch counts as a single version step regardless of its size.
    #[test]
    fn test_batch_insert() {
        let store = VersionedChunkStore::new();
        let batch = vec![
            (1, make_test_chunk(1)),
            (2, make_test_chunk(2)),
            (3, make_test_chunk(3)),
        ];
        let new_version = store.batch_insert(batch, 0).unwrap();
        assert_eq!(new_version, 1);
        assert_eq!(store.len(), 3);
    }

    /// A stored chunk can be found again through its content hash.
    #[test]
    fn test_deduplication() {
        let store = VersionedChunkStore::new();
        let chunk = make_test_chunk(1);
        let hash = chunk.content_hash;
        store.insert(1, chunk, 0).unwrap();

        let hit = store.find_by_hash(&hash);
        assert!(hit.is_some());
        let found_id = hit.map(|(id, _)| id);
        assert_eq!(found_id, Some(1));
    }

    /// `gc` removes a chunk after `dec_ref` has been called on it.
    // NOTE(review): this relies on `chunk.clone()` sharing its reference count
    // with the stored copy - confirm against `VersionedChunk`'s `Clone` impl.
    #[test]
    fn test_garbage_collection() {
        let store = VersionedChunkStore::new();
        let chunk = make_test_chunk(1);
        store.insert(1, chunk.clone(), 0).unwrap();
        chunk.dec_ref();
        assert_eq!(store.len(), 1);

        let collected = store.gc(1).unwrap();
        assert_eq!(collected, 1);
        assert_eq!(store.len(), 0);
    }

    /// Stats report chunk count, summed size and the current version.
    #[test]
    fn test_stats() {
        let store = VersionedChunkStore::new();
        for id in 0..10 {
            store.insert(id, make_test_chunk(id), id as u64).unwrap();
        }

        let summary = store.stats();
        assert_eq!(summary.total_chunks, 10);
        assert_eq!(summary.total_size_bytes, 10 * 4096);
        assert_eq!(summary.version, 10);
    }
}