use crate::error::Result;
use crate::multi_tier::CacheKey;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Shared FIFO queue of pending (key, payload) writes backing [`WriteBuffer`].
type WriteBufferQueue = Arc<RwLock<VecDeque<(CacheKey, Vec<u8>)>>>;
/// Policy controlling how a write propagates from the cache to the backing store.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WritePolicyType {
/// Stage writes in the shared write buffer; flush when the buffer fills.
WriteThrough,
/// Mark blocks dirty and persist them later in batches.
WriteBack,
/// Hand writes off for asynchronous persistence.
WriteBehind,
/// Bypass the cache and write straight to the backing store.
WriteAround,
}
/// Bookkeeping for a cached block whose writes have not yet been persisted.
#[derive(Debug, Clone)]
pub struct DirtyBlock {
/// Cache key identifying the block.
pub key: CacheKey,
/// When the block first became dirty (not reset by subsequent writes).
pub dirty_time: Instant,
/// Size of the block's payload in bytes.
pub size: usize,
/// Number of writes recorded against this block since it became dirty.
pub write_count: u64,
}
impl DirtyBlock {
pub fn new(key: CacheKey, size: usize) -> Self {
Self {
key,
dirty_time: Instant::now(),
size,
write_count: 1,
}
}
pub fn record_write(&mut self) {
self.write_count += 1;
}
pub fn age(&self) -> Duration {
self.dirty_time.elapsed()
}
}
/// Tracks dirty blocks for the write-back policy and decides when they
/// should be flushed to the backing store.
pub struct WriteBackManager {
/// Dirty blocks currently awaiting flush, keyed by cache key.
dirty_blocks: Arc<RwLock<HashMap<CacheKey, DirtyBlock>>>,
/// Dirty-block count at which `mark_dirty` reports that a flush is needed.
max_dirty_blocks: usize,
/// Age past which a block becomes a flush candidate.
max_dirty_age: Duration,
/// When true, repeated writes to the same key are merged into one dirty block.
coalescing_enabled: bool,
}
impl WriteBackManager {
pub fn new(max_dirty_blocks: usize, max_dirty_age: Duration) -> Self {
Self {
dirty_blocks: Arc::new(RwLock::new(HashMap::new())),
max_dirty_blocks,
max_dirty_age,
coalescing_enabled: true,
}
}
pub fn set_coalescing(&mut self, enabled: bool) {
self.coalescing_enabled = enabled;
}
pub async fn mark_dirty(&self, key: CacheKey, size: usize) -> Result<bool> {
let mut dirty = self.dirty_blocks.write().await;
if let Some(block) = dirty.get_mut(&key) {
if self.coalescing_enabled {
block.record_write();
return Ok(false); }
} else {
dirty.insert(key.clone(), DirtyBlock::new(key, size));
}
let needs_flush = dirty.len() >= self.max_dirty_blocks;
Ok(needs_flush)
}
pub async fn get_flush_candidates(&self) -> Vec<DirtyBlock> {
let dirty = self.dirty_blocks.read().await;
let _now = Instant::now();
dirty
.values()
.filter(|block| {
block.age() >= self.max_dirty_age || dirty.len() >= self.max_dirty_blocks
})
.cloned()
.collect()
}
pub async fn mark_clean(&self, key: &CacheKey) {
self.dirty_blocks.write().await.remove(key);
}
pub async fn dirty_count(&self) -> usize {
self.dirty_blocks.read().await.len()
}
pub async fn dirty_bytes(&self) -> usize {
self.dirty_blocks
.read()
.await
.values()
.map(|b| b.size)
.sum()
}
pub async fn oldest_dirty_age(&self) -> Option<Duration> {
self.dirty_blocks
.read()
.await
.values()
.map(|b| b.age())
.max()
}
}
/// FIFO staging buffer for writes, bounded (by flush signalling) on total
/// payload bytes.
pub struct WriteBuffer {
/// Queued (key, payload) writes in arrival order.
buffer: WriteBufferQueue,
/// Byte threshold at which callers are asked to drain the buffer.
max_buffer_size: usize,
/// Total payload bytes currently queued.
current_size: Arc<RwLock<usize>>,
}
impl WriteBuffer {
    /// Creates an empty buffer that requests a flush once `max_buffer_size`
    /// payload bytes are queued.
    pub fn new(max_buffer_size: usize) -> Self {
        Self {
            buffer: Arc::new(RwLock::new(VecDeque::new())),
            max_buffer_size,
            current_size: Arc::new(RwLock::new(0)),
        }
    }

    /// Queues a write, returning `Ok(true)` when the buffer has reached its
    /// byte threshold and should be drained.
    ///
    /// Fixes two defects: (1) the write is now always enqueued — previously a
    /// full buffer returned `Ok(true)` but silently discarded the payload;
    /// (2) locks are taken in the same order as `drain` (buffer first, then
    /// size), eliminating an ABBA lock-order deadlock between concurrent
    /// `add_write` and `drain` calls.
    pub async fn add_write(&self, key: CacheKey, data: Vec<u8>) -> Result<bool> {
        let data_size = data.len();
        // Lock order must match `drain`: buffer, then current_size.
        let mut buffer = self.buffer.write().await;
        let mut size = self.current_size.write().await;
        buffer.push_back((key, data));
        *size += data_size;
        Ok(*size >= self.max_buffer_size)
    }

    /// Removes and returns all queued writes, resetting the byte count.
    pub async fn drain(&self) -> Vec<(CacheKey, Vec<u8>)> {
        let mut buffer = self.buffer.write().await;
        let mut size = self.current_size.write().await;
        let writes: Vec<_> = buffer.drain(..).collect();
        *size = 0;
        writes
    }

    /// Total payload bytes currently queued.
    pub async fn size(&self) -> usize {
        *self.current_size.read().await
    }

    /// Number of queued writes.
    pub async fn count(&self) -> usize {
        self.buffer.read().await.len()
    }
}
/// Tracks bytes written to the cache versus the backing store so write
/// amplification can be computed.
pub struct WriteAmplificationTracker {
/// Total bytes written into the cache tier.
cache_writes: Arc<RwLock<u64>>,
/// Total bytes written to the backing store.
backing_writes: Arc<RwLock<u64>>,
}
impl WriteAmplificationTracker {
pub fn new() -> Self {
Self {
cache_writes: Arc::new(RwLock::new(0)),
backing_writes: Arc::new(RwLock::new(0)),
}
}
pub async fn record_cache_write(&self, bytes: u64) {
*self.cache_writes.write().await += bytes;
}
pub async fn record_backing_write(&self, bytes: u64) {
*self.backing_writes.write().await += bytes;
}
pub async fn amplification_factor(&self) -> f64 {
let cache = *self.cache_writes.read().await;
let backing = *self.backing_writes.read().await;
if cache == 0 {
0.0
} else {
backing as f64 / cache as f64
}
}
pub async fn cache_writes(&self) -> u64 {
*self.cache_writes.read().await
}
pub async fn backing_writes(&self) -> u64 {
*self.backing_writes.read().await
}
pub async fn reset(&self) {
*self.cache_writes.write().await = 0;
*self.backing_writes.write().await = 0;
}
}
impl Default for WriteAmplificationTracker {
/// Equivalent to [`WriteAmplificationTracker::new`]: both counters start at zero.
fn default() -> Self {
Self::new()
}
}
/// Facade that routes writes according to the configured [`WritePolicyType`],
/// owning the write-back tracker, staging buffer, and amplification stats.
pub struct WritePolicyManager {
/// Active policy consulted by `handle_write`.
policy_type: WritePolicyType,
/// Dirty-block tracking used by the write-back policy.
write_back: WriteBackManager,
/// Staging buffer used by the write-through policy.
write_buffer: WriteBuffer,
/// Cache-vs-backing write byte counters.
amplification: WriteAmplificationTracker,
}
impl WritePolicyManager {
    /// Builds a manager for `policy_type` with the given write-back limits
    /// (`max_dirty_blocks`, `max_dirty_age`) and write buffer capacity.
    pub fn new(
        policy_type: WritePolicyType,
        max_dirty_blocks: usize,
        max_dirty_age: Duration,
        buffer_size: usize,
    ) -> Self {
        let write_back = WriteBackManager::new(max_dirty_blocks, max_dirty_age);
        let write_buffer = WriteBuffer::new(buffer_size);
        Self {
            policy_type,
            write_back,
            write_buffer,
            amplification: WriteAmplificationTracker::new(),
        }
    }

    /// Currently active write policy.
    pub fn policy_type(&self) -> WritePolicyType {
        self.policy_type
    }

    /// Switches the active write policy.
    pub fn set_policy_type(&mut self, policy_type: WritePolicyType) {
        self.policy_type = policy_type;
    }

    /// Routes a write through the active policy and reports what the caller
    /// should do with it next.
    pub async fn handle_write(&self, key: CacheKey, data: Vec<u8>) -> Result<WriteAction> {
        let data_size = data.len();
        let action = match self.policy_type {
            WritePolicyType::WriteThrough => {
                // Stage in the buffer; a full buffer asks the caller to flush.
                if self.write_buffer.add_write(key, data).await? {
                    WriteAction::FlushBuffer
                } else {
                    WriteAction::Buffered
                }
            }
            WritePolicyType::WriteBack => {
                // Track the block as dirty; too many dirty blocks force a flush.
                if self.write_back.mark_dirty(key, data_size).await? {
                    WriteAction::FlushDirty
                } else {
                    WriteAction::Deferred
                }
            }
            // Queue for asynchronous persistence.
            WritePolicyType::WriteBehind => WriteAction::Async,
            // Bypass the cache entirely.
            WritePolicyType::WriteAround => WriteAction::Direct,
        };
        Ok(action)
    }

    /// Write-back dirty-block tracker.
    pub fn write_back(&self) -> &WriteBackManager {
        &self.write_back
    }

    /// Write-through staging buffer.
    pub fn write_buffer(&self) -> &WriteBuffer {
        &self.write_buffer
    }

    /// Write-amplification counters.
    pub fn amplification(&self) -> &WriteAmplificationTracker {
        &self.amplification
    }
}
/// What the caller should do with a write after `WritePolicyManager::handle_write`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WriteAction {
/// Write was staged in the write buffer (write-through).
Buffered,
/// Write buffer reached capacity; caller should flush it.
FlushBuffer,
/// Block was marked dirty; persistence is deferred (write-back).
Deferred,
/// Dirty-block limit reached; caller should flush dirty blocks.
FlushDirty,
/// Persist asynchronously (write-behind).
Async,
/// Write directly to the backing store, bypassing the cache (write-around).
Direct,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dirty_block() {
        let mut block = DirtyBlock::new("key1".to_string(), 1024);
        assert_eq!(block.write_count, 1);
        block.record_write();
        assert_eq!(block.write_count, 2);
        assert!(block.age().as_secs() < 1);
    }

    #[tokio::test]
    async fn test_write_back_manager() {
        let manager = WriteBackManager::new(10, Duration::from_secs(60));
        // `expect` instead of `unwrap_or(false)`: an Err must fail the test
        // rather than silently satisfy the assertion below.
        let needs_flush = manager
            .mark_dirty("key1".to_string(), 1024)
            .await
            .expect("mark_dirty failed");
        assert!(!needs_flush);
        assert_eq!(manager.dirty_count().await, 1);
        assert_eq!(manager.dirty_bytes().await, 1024);
    }

    #[tokio::test]
    async fn test_write_buffer() {
        let buffer = WriteBuffer::new(1024 * 10);
        let data = vec![0u8; 1024];
        let needs_flush = buffer
            .add_write("key1".to_string(), data)
            .await
            .expect("add_write failed");
        assert!(!needs_flush);
        assert_eq!(buffer.size().await, 1024);
        let writes = buffer.drain().await;
        assert_eq!(writes.len(), 1);
        assert_eq!(buffer.size().await, 0);
    }

    #[tokio::test]
    async fn test_write_amplification() {
        let tracker = WriteAmplificationTracker::new();
        tracker.record_cache_write(1000).await;
        tracker.record_backing_write(2000).await;
        let amp = tracker.amplification_factor().await;
        assert!((amp - 2.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_write_policy_manager() {
        let manager = WritePolicyManager::new(
            WritePolicyType::WriteBack,
            10,
            Duration::from_secs(60),
            1024 * 10,
        );
        let data = vec![0u8; 1024];
        // `expect` so an Err cannot masquerade as the expected `Deferred`
        // (the old `unwrap_or(WriteAction::Deferred)` made the assertion
        // pass vacuously on error).
        let action = manager
            .handle_write("key1".to_string(), data)
            .await
            .expect("handle_write failed");
        assert_eq!(action, WriteAction::Deferred);
    }
}