//! Write-back cache for piece data: buffers writes in memory, coalescing
//! overlapping and adjacent byte ranges per piece, until they are drained
//! as offset-sorted blocks for flushing.

use std::collections::BTreeMap;

use bytes::{Bytes, BytesMut};
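
/// Buffers writes per piece so they can be flushed in larger, offset-sorted
/// blocks instead of many small random writes.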
#[derive(Default)]
pub struct WriteBackCache {
    pieces: BTreeMap<usize, PieceCacheEntry>,
    total_bytes: usize,
}
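
/// Byte ranges buffered for a single piece, kept sorted by offset and
/// non-overlapping.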
struct PieceCacheEntry {
    ranges: Vec<(u64, BytesMut)>,
}
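
/// A contiguous block of cached data, ready to be written out.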
pub struct FlushBlock {
    pub offset: u64,
    pub data: Bytes,
}

impl WriteBackCache {
    pub fn new() -> Self {
        Self::default()
    }
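
    /// Number of bytes currently buffered across all pieces. Overlapping
    /// inserts are deduplicated, so this counts bytes actually stored.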
    pub fn total_bytes(&self) -> usize {
        self.total_bytes
    }
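
    /// Buffers `data` for `piece_id` at byte `offset`, merging it with any
    /// overlapping or adjacent cached ranges for that piece.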
    pub fn insert(&mut self, piece_id: usize, offset: u64, data: Bytes) {
        if data.is_empty() {
            return;
        }
        let entry = self
            .pieces
            .entry(piece_id)
            .or_insert_with(|| PieceCacheEntry { ranges: Vec::new() });
        // Count only the bytes the merge actually adds: overlapping regions
        // are deduplicated, so `data.len()` would over-count here.
        let before = entry.stored_bytes();
        entry.merge(offset, data);
        self.total_bytes += entry.stored_bytes() - before;
    }
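
    /// Removes and returns all cached blocks for `piece_id`, ordered by
    /// offset; returns an empty `Vec` if nothing is cached for that piece.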
    pub fn drain_piece(&mut self, piece_id: usize) -> Vec<FlushBlock> {
        match self.pieces.remove(&piece_id) {
            Some(entry) => entry
                .ranges
                .into_iter()
                .map(|(offset, data)| {
                    self.total_bytes -= data.len();
                    FlushBlock {
                        offset,
                        data: data.freeze(),
                    }
                })
                .collect(),
            None => Vec::new(),
        }
    }
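
    /// Removes and returns every cached block across all pieces, sorted by
    /// offset so a flush can write them sequentially.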
    pub fn drain_all(&mut self) -> Vec<FlushBlock> {
        let mut blocks = Vec::new();
        let pieces = std::mem::take(&mut self.pieces);
        for (_piece_id, entry) in pieces {
            for (offset, data) in entry.ranges {
                self.total_bytes -= data.len();
                blocks.push(FlushBlock {
                    offset,
                    data: data.freeze(),
                });
            }
        }
        blocks.sort_by_key(|b| b.offset);
        blocks
    }
}

impl PieceCacheEntry {
    /// Total bytes currently stored across this entry's ranges.
    fn stored_bytes(&self) -> usize {
        self.ranges.iter().map(|(_, buf)| buf.len()).sum()
    }

    /// Inserts `data` at `offset`, extending or coalescing with existing
    /// ranges so they stay sorted and non-overlapping.
    fn merge(&mut self, offset: u64, data: Bytes) {
        let end = offset + data.len() as u64;
        // First existing range whose start is at or after `offset`.
        let pos = self.ranges.partition_point(|(o, _)| *o < offset);
        // If the previous range reaches `offset`, append only the part of
        // `data` that extends past it, then coalesce forward.
        if pos > 0 {
            let (prev_start, prev_buf) = &self.ranges[pos - 1];
            let prev_end = *prev_start + prev_buf.len() as u64;
            if prev_end >= offset {
                let idx = pos - 1;
                if end > prev_end {
                    // `end > prev_end` guarantees the skip stays in bounds.
                    let skip = (prev_end - offset) as usize;
                    self.ranges[idx].1.extend_from_slice(&data[skip..]);
                }
                self.coalesce_from(idx);
                return;
            }
        }
        // Otherwise start a new range at `pos`. Coalescing absorbs any
        // overlap with the ranges that follow, and is a no-op across a gap.
        let mut buf = BytesMut::with_capacity(data.len());
        buf.extend_from_slice(&data);
        self.ranges.insert(pos, (offset, buf));
        self.coalesce_from(pos);
    }

    /// Merges `ranges[idx]` with any following ranges it now reaches,
    /// deduplicating the overlapping bytes.
    fn coalesce_from(&mut self, idx: usize) {
        while idx + 1 < self.ranges.len() {
            let cur_end = self.ranges[idx].0 + self.ranges[idx].1.len() as u64;
            let next_start = self.ranges[idx + 1].0;
            if cur_end < next_start {
                // A gap remains; nothing further can coalesce.
                break;
            }
            let (next_offset, next_buf) = self.ranges.remove(idx + 1);
            let next_end = next_offset + next_buf.len() as u64;
            if next_end > cur_end {
                // Append only the portion of `next` that extends past us.
                let skip = (cur_end - next_offset) as usize;
                self.ranges[idx].1.extend_from_slice(&next_buf[skip..]);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_insert_and_drain() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 100]));
        cache.insert(0, 100, Bytes::from(vec![2u8; 100]));
        assert_eq!(cache.total_bytes(), 200);
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[0].data.len(), 200);
        assert_eq!(cache.total_bytes(), 0);
    }

    #[test]
    fn test_non_contiguous_ranges() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 100]));
        cache.insert(0, 200, Bytes::from(vec![2u8; 100]));
        assert_eq!(cache.total_bytes(), 200);
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[0].data.len(), 100);
        assert_eq!(blocks[1].offset, 200);
        assert_eq!(blocks[1].data.len(), 100);
    }

    #[test]
    fn test_overlapping_merge() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 100]));
        cache.insert(0, 50, Bytes::from(vec![2u8; 100]));
        // The 50 overlapping bytes are deduplicated on insert, so only the
        // 150 bytes actually stored are counted.
        assert_eq!(cache.total_bytes(), 150);
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[0].data.len(), 150);
        assert_eq!(cache.total_bytes(), 0);
    }

    #[test]
    fn test_drain_all() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 100]));
        cache.insert(1, 1000, Bytes::from(vec![2u8; 100]));
        let blocks = cache.drain_all();
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[1].offset, 1000);
        assert_eq!(cache.total_bytes(), 0);
    }

    #[test]
    fn test_multiple_pieces() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 100]));
        cache.insert(1, 1000, Bytes::from(vec![2u8; 200]));
        cache.insert(0, 100, Bytes::from(vec![3u8; 50]));
        assert_eq!(cache.total_bytes(), 350);
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].data.len(), 150);
        let blocks = cache.drain_piece(1);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].data.len(), 200);
    }

    #[test]
    fn test_drain_piece_empty() {
        let mut cache = WriteBackCache::new();
        let blocks = cache.drain_piece(99);
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_insert_empty_data() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::new());
        assert_eq!(cache.total_bytes(), 0);
        let blocks = cache.drain_piece(0);
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_overlap_with_next_range() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 100, Bytes::from(vec![2u8; 50]));
        cache.insert(0, 0, Bytes::from(vec![1u8; 120]));
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[0].data.len(), 150);
    }

    #[test]
    fn test_coalesce_multiple_ranges() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 50]));
        cache.insert(0, 100, Bytes::from(vec![2u8; 50]));
        cache.insert(0, 200, Bytes::from(vec![3u8; 50]));
        cache.insert(0, 40, Bytes::from(vec![4u8; 170]));
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[0].data.len(), 250);
    }

    #[test]
    fn test_insert_contiguous_extends_previous() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 100]));
        cache.insert(0, 100, Bytes::from(vec![2u8; 100]));
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].data.len(), 200);
    }

    #[test]
    fn test_fully_overlapping_insert() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 200]));
        cache.insert(0, 50, Bytes::from(vec![2u8; 50]));
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[0].data.len(), 200);
    }

    #[test]
    fn test_drain_all_sorts_by_offset() {
        let mut cache = WriteBackCache::new();
        cache.insert(5, 5000, Bytes::from(vec![1u8; 10]));
        cache.insert(0, 0, Bytes::from(vec![2u8; 10]));
        cache.insert(3, 3000, Bytes::from(vec![3u8; 10]));
        let blocks = cache.drain_all();
        assert_eq!(blocks.len(), 3);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[1].offset, 3000);
        assert_eq!(blocks[2].offset, 5000);
        assert_eq!(cache.total_bytes(), 0);
    }

    #[test]
    fn test_coalesce_stops_at_gap() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 50]));
        cache.insert(0, 50, Bytes::from(vec![2u8; 50]));
        cache.insert(0, 200, Bytes::from(vec![3u8; 50]));
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[0].data.len(), 100);
        assert_eq!(blocks[1].offset, 200);
        assert_eq!(blocks[1].data.len(), 50);
    }

    #[test]
    fn test_coalesce_breaks_when_next_range_is_still_gapped() {
        let mut cache = WriteBackCache::new();
        cache.insert(0, 0, Bytes::from(vec![1u8; 50]));
        cache.insert(0, 100, Bytes::from(vec![2u8; 50]));
        cache.insert(0, 200, Bytes::from(vec![3u8; 50]));
        cache.insert(0, 40, Bytes::from(vec![4u8; 40]));
        let blocks = cache.drain_piece(0);
        assert_eq!(blocks.len(), 3);
        assert_eq!(blocks[0].offset, 0);
        assert_eq!(blocks[0].data.len(), 80);
        assert_eq!(blocks[1].offset, 100);
        assert_eq!(blocks[2].offset, 200);
    }
}