1use crate::error::Result;
12use crate::multi_tier::CacheKey;
13use std::collections::{HashMap, VecDeque};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use tokio::sync::RwLock;
17
/// Shared FIFO queue of pending `(key, payload)` writes, used by [`WriteBuffer`].
type WriteBufferQueue = Arc<RwLock<VecDeque<(CacheKey, Vec<u8>)>>>;
20
/// Strategy for propagating cache writes toward the backing store.
///
/// See [`WritePolicyManager::handle_write`] for how each variant is routed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WritePolicyType {
    /// Writes are staged in the [`WriteBuffer`]; the caller flushes when told.
    WriteThrough,
    /// Writes are recorded as dirty blocks and flushed later.
    WriteBack,
    /// Writes are acknowledged immediately and completed asynchronously.
    WriteBehind,
    /// Writes bypass the cache and go directly to the backing store.
    WriteAround,
}
33
/// Metadata for a cached block holding writes not yet flushed to the backing
/// store.
#[derive(Debug, Clone)]
pub struct DirtyBlock {
    /// Cache key identifying the block.
    pub key: CacheKey,
    /// Instant at which the block first became dirty.
    pub dirty_time: Instant,
    /// Block size in bytes, as reported when the block was marked dirty.
    pub size: usize,
    /// Number of writes recorded against this block while dirty.
    pub write_count: u64,
}
46
47impl DirtyBlock {
48 pub fn new(key: CacheKey, size: usize) -> Self {
50 Self {
51 key,
52 dirty_time: Instant::now(),
53 size,
54 write_count: 1,
55 }
56 }
57
58 pub fn record_write(&mut self) {
60 self.write_count += 1;
61 }
62
63 pub fn age(&self) -> Duration {
65 self.dirty_time.elapsed()
66 }
67}
68
/// Tracks dirty blocks for the write-back policy and decides when they should
/// be flushed to the backing store.
pub struct WriteBackManager {
    /// Dirty-block metadata keyed by cache key.
    dirty_blocks: Arc<RwLock<HashMap<CacheKey, DirtyBlock>>>,
    /// Dirty-block count at which a flush is requested.
    max_dirty_blocks: usize,
    /// Age past which a dirty block becomes a flush candidate.
    max_dirty_age: Duration,
    /// When true, repeat writes to an already-dirty key are coalesced.
    coalescing_enabled: bool,
}
80
81impl WriteBackManager {
82 pub fn new(max_dirty_blocks: usize, max_dirty_age: Duration) -> Self {
84 Self {
85 dirty_blocks: Arc::new(RwLock::new(HashMap::new())),
86 max_dirty_blocks,
87 max_dirty_age,
88 coalescing_enabled: true,
89 }
90 }
91
92 pub fn set_coalescing(&mut self, enabled: bool) {
94 self.coalescing_enabled = enabled;
95 }
96
97 pub async fn mark_dirty(&self, key: CacheKey, size: usize) -> Result<bool> {
99 let mut dirty = self.dirty_blocks.write().await;
100
101 if let Some(block) = dirty.get_mut(&key) {
102 if self.coalescing_enabled {
104 block.record_write();
105 return Ok(false); }
107 } else {
108 dirty.insert(key.clone(), DirtyBlock::new(key, size));
109 }
110
111 let needs_flush = dirty.len() >= self.max_dirty_blocks;
113 Ok(needs_flush)
114 }
115
116 pub async fn get_flush_candidates(&self) -> Vec<DirtyBlock> {
118 let dirty = self.dirty_blocks.read().await;
119 let _now = Instant::now();
120
121 dirty
122 .values()
123 .filter(|block| {
124 block.age() >= self.max_dirty_age || dirty.len() >= self.max_dirty_blocks
125 })
126 .cloned()
127 .collect()
128 }
129
130 pub async fn mark_clean(&self, key: &CacheKey) {
132 self.dirty_blocks.write().await.remove(key);
133 }
134
135 pub async fn dirty_count(&self) -> usize {
137 self.dirty_blocks.read().await.len()
138 }
139
140 pub async fn dirty_bytes(&self) -> usize {
142 self.dirty_blocks
143 .read()
144 .await
145 .values()
146 .map(|b| b.size)
147 .sum()
148 }
149
150 pub async fn oldest_dirty_age(&self) -> Option<Duration> {
152 self.dirty_blocks
153 .read()
154 .await
155 .values()
156 .map(|b| b.age())
157 .max()
158 }
159}
160
/// Bounded in-memory staging area for pending writes (used by the
/// write-through path).
pub struct WriteBuffer {
    /// FIFO queue of pending `(key, payload)` writes.
    buffer: WriteBufferQueue,
    /// Capacity in payload bytes before a flush is requested.
    max_buffer_size: usize,
    /// Total payload bytes currently buffered.
    current_size: Arc<RwLock<usize>>,
}
170
171impl WriteBuffer {
172 pub fn new(max_buffer_size: usize) -> Self {
174 Self {
175 buffer: Arc::new(RwLock::new(VecDeque::new())),
176 max_buffer_size,
177 current_size: Arc::new(RwLock::new(0)),
178 }
179 }
180
181 pub async fn add_write(&self, key: CacheKey, data: Vec<u8>) -> Result<bool> {
183 let data_size = data.len();
184 let mut size = self.current_size.write().await;
185
186 if *size + data_size >= self.max_buffer_size {
188 return Ok(true); }
190
191 let mut buffer = self.buffer.write().await;
192 buffer.push_back((key, data));
193 *size += data_size;
194
195 Ok(false)
196 }
197
198 pub async fn drain(&self) -> Vec<(CacheKey, Vec<u8>)> {
200 let mut buffer = self.buffer.write().await;
201 let mut size = self.current_size.write().await;
202
203 let writes: Vec<_> = buffer.drain(..).collect();
204 *size = 0;
205
206 writes
207 }
208
209 pub async fn size(&self) -> usize {
211 *self.current_size.read().await
212 }
213
214 pub async fn count(&self) -> usize {
216 self.buffer.read().await.len()
217 }
218}
219
/// Accumulates byte counts of cache vs. backing-store writes so the write
/// amplification factor (backing bytes / cache bytes) can be computed.
pub struct WriteAmplificationTracker {
    /// Total bytes recorded as written into the cache.
    cache_writes: Arc<RwLock<u64>>,
    /// Total bytes recorded as written to the backing store.
    backing_writes: Arc<RwLock<u64>>,
}
227
228impl WriteAmplificationTracker {
229 pub fn new() -> Self {
231 Self {
232 cache_writes: Arc::new(RwLock::new(0)),
233 backing_writes: Arc::new(RwLock::new(0)),
234 }
235 }
236
237 pub async fn record_cache_write(&self, bytes: u64) {
239 *self.cache_writes.write().await += bytes;
240 }
241
242 pub async fn record_backing_write(&self, bytes: u64) {
244 *self.backing_writes.write().await += bytes;
245 }
246
247 pub async fn amplification_factor(&self) -> f64 {
249 let cache = *self.cache_writes.read().await;
250 let backing = *self.backing_writes.read().await;
251
252 if cache == 0 {
253 0.0
254 } else {
255 backing as f64 / cache as f64
256 }
257 }
258
259 pub async fn cache_writes(&self) -> u64 {
261 *self.cache_writes.read().await
262 }
263
264 pub async fn backing_writes(&self) -> u64 {
266 *self.backing_writes.read().await
267 }
268
269 pub async fn reset(&self) {
271 *self.cache_writes.write().await = 0;
272 *self.backing_writes.write().await = 0;
273 }
274}
275
impl Default for WriteAmplificationTracker {
    /// Equivalent to [`WriteAmplificationTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
281
/// Front door for write handling: routes each write according to the active
/// [`WritePolicyType`], using the write-back tracker and write buffer.
pub struct WritePolicyManager {
    /// Active write policy.
    policy_type: WritePolicyType,
    /// Dirty-block tracking for `WriteBack`.
    write_back: WriteBackManager,
    /// Staging buffer for `WriteThrough`.
    write_buffer: WriteBuffer,
    /// Write-amplification statistics. NOTE: not updated by `handle_write`
    /// itself — callers record writes via the accessor.
    amplification: WriteAmplificationTracker,
}
293
294impl WritePolicyManager {
295 pub fn new(
297 policy_type: WritePolicyType,
298 max_dirty_blocks: usize,
299 max_dirty_age: Duration,
300 buffer_size: usize,
301 ) -> Self {
302 Self {
303 policy_type,
304 write_back: WriteBackManager::new(max_dirty_blocks, max_dirty_age),
305 write_buffer: WriteBuffer::new(buffer_size),
306 amplification: WriteAmplificationTracker::new(),
307 }
308 }
309
310 pub fn policy_type(&self) -> WritePolicyType {
312 self.policy_type
313 }
314
315 pub fn set_policy_type(&mut self, policy_type: WritePolicyType) {
317 self.policy_type = policy_type;
318 }
319
320 pub async fn handle_write(&self, key: CacheKey, data: Vec<u8>) -> Result<WriteAction> {
322 let data_size = data.len();
323
324 match self.policy_type {
325 WritePolicyType::WriteThrough => {
326 let needs_flush = self.write_buffer.add_write(key, data).await?;
328
329 if needs_flush {
330 Ok(WriteAction::FlushBuffer)
331 } else {
332 Ok(WriteAction::Buffered)
333 }
334 }
335 WritePolicyType::WriteBack => {
336 let needs_flush = self.write_back.mark_dirty(key, data_size).await?;
338
339 if needs_flush {
340 Ok(WriteAction::FlushDirty)
341 } else {
342 Ok(WriteAction::Deferred)
343 }
344 }
345 WritePolicyType::WriteBehind => {
346 Ok(WriteAction::Async)
348 }
349 WritePolicyType::WriteAround => {
350 Ok(WriteAction::Direct)
352 }
353 }
354 }
355
356 pub fn write_back(&self) -> &WriteBackManager {
358 &self.write_back
359 }
360
361 pub fn write_buffer(&self) -> &WriteBuffer {
363 &self.write_buffer
364 }
365
366 pub fn amplification(&self) -> &WriteAmplificationTracker {
368 &self.amplification
369 }
370}
371
/// What the caller of [`WritePolicyManager::handle_write`] should do next.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WriteAction {
    /// Write was staged in the write buffer; nothing further needed yet.
    Buffered,
    /// Write buffer is full; the caller should drain/flush it.
    FlushBuffer,
    /// Write was recorded as a dirty block; flushing is deferred.
    Deferred,
    /// Dirty-block limit reached; the caller should flush dirty blocks.
    FlushDirty,
    /// Write should be completed asynchronously (write-behind).
    Async,
    /// Write should go directly to the backing store (write-around).
    Direct,
}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dirty_block() {
        let mut block = DirtyBlock::new("key1".to_string(), 1024);
        assert_eq!(block.write_count, 1);

        block.record_write();
        assert_eq!(block.write_count, 2);

        // A freshly-dirtied block should be (much) younger than a second.
        assert!(block.age().as_secs() < 1);
    }

    #[tokio::test]
    async fn test_write_back_manager() {
        let manager = WriteBackManager::new(10, Duration::from_secs(60));

        // Fix: use `expect` instead of `unwrap_or(false)` so an `Err`
        // surfaces as a test failure rather than being treated as "no flush".
        let needs_flush = manager
            .mark_dirty("key1".to_string(), 1024)
            .await
            .expect("mark_dirty failed");
        assert!(!needs_flush);

        assert_eq!(manager.dirty_count().await, 1);
        assert_eq!(manager.dirty_bytes().await, 1024);
    }

    #[tokio::test]
    async fn test_write_buffer() {
        let buffer = WriteBuffer::new(1024 * 10);

        let data = vec![0u8; 1024];
        // Fix: fail loudly on error instead of masking it with `unwrap_or`.
        let needs_flush = buffer
            .add_write("key1".to_string(), data)
            .await
            .expect("add_write failed");
        assert!(!needs_flush);

        assert_eq!(buffer.size().await, 1024);

        let writes = buffer.drain().await;
        assert_eq!(writes.len(), 1);

        // Draining must reset the byte accounting.
        assert_eq!(buffer.size().await, 0);
    }

    #[tokio::test]
    async fn test_write_amplification() {
        let tracker = WriteAmplificationTracker::new();

        tracker.record_cache_write(1000).await;
        tracker.record_backing_write(2000).await;

        let amp = tracker.amplification_factor().await;
        assert!((amp - 2.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_write_policy_manager() {
        let manager = WritePolicyManager::new(
            WritePolicyType::WriteBack,
            10,
            Duration::from_secs(60),
            1024 * 10,
        );

        let data = vec![0u8; 1024];
        // Fix: the old `unwrap_or(WriteAction::Deferred)` made the assertion
        // below vacuous on error — the fallback equaled the expected value.
        let action = manager
            .handle_write("key1".to_string(), data)
            .await
            .expect("handle_write failed");

        assert_eq!(action, WriteAction::Deferred);
    }
}