1use std::collections::HashMap;
12use std::time::Instant;
13
/// Persistent storage sitting behind a [`WriteBehindCache`].
///
/// The cache reads through it on misses, writes dirty values back to it, and
/// forwards deletions to it. Errors are surfaced to cache callers wrapped in
/// [`WriteBehindError::StoreError`].
pub trait BackingStore {
    /// Key type; hashable so the cache can index entries by it.
    type Key: Clone + Eq + std::hash::Hash;
    /// Value type; cloned when values are read back from the store.
    type Value: Clone;
    /// Error reported by store operations.
    type Error: std::fmt::Debug;

    /// Looks up `key`, returning `Ok(None)` if the store has no value for it.
    fn read(&self, key: &Self::Key) -> Result<Option<Self::Value>, Self::Error>;

    /// Persists `value` under `key`, replacing any previous value.
    fn write(&mut self, key: &Self::Key, value: &Self::Value) -> Result<(), Self::Error>;

    /// Removes `key` from the store.
    fn delete(&mut self, key: &Self::Key) -> Result<(), Self::Error>;
}
37
/// One cached value together with its write-back bookkeeping.
struct CacheEntry<V> {
    /// `true` while `value` has not yet been written back to the store.
    dirty: bool,
    /// When the entry was inserted or last overwritten through the cache;
    /// compared against by age-based flushing.
    last_modified: Instant,
    /// The cached value itself.
    value: V,
}
45
46pub struct WriteBehindCache<S: BackingStore> {
63 entries: HashMap<S::Key, CacheEntry<S::Value>>,
64 order: Vec<S::Key>,
66 capacity: usize,
67 store: S,
68 dirty_count: usize,
70 total_flushes: u64,
72 total_entries_flushed: u64,
74}
75
/// Point-in-time counters describing a [`WriteBehindCache`].
///
/// `PartialEq`/`Eq` let callers compare snapshots directly. `Default` gives an
/// all-zero value, which is handy as a baseline.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct WriteBehindStats {
    /// Number of entries currently held in the cache.
    pub entry_count: usize,
    /// Number of cached entries not yet written back to the store.
    pub dirty_count: usize,
    /// Maximum number of entries the cache will hold.
    pub capacity: usize,
    /// Flush operations performed: every `flush` call, plus each
    /// `flush_older_than` call that wrote at least one entry.
    pub total_flushes: u64,
    /// Entries written back to the store, by flushes and by evictions.
    pub total_entries_flushed: u64,
}
90
/// Errors returned by [`WriteBehindCache`] operations.
#[derive(Debug)]
pub enum WriteBehindError<E>
where
    E: std::fmt::Debug,
{
    /// The backing store rejected a read, write or delete; carries its error.
    StoreError(E),
}
97
98impl<E: std::fmt::Debug> std::fmt::Display for WriteBehindError<E> {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 match self {
101 Self::StoreError(e) => write!(f, "backing store error: {e:?}"),
102 }
103 }
104}
105
106impl<S: BackingStore> WriteBehindCache<S> {
107 pub fn new(capacity: usize, store: S) -> Self {
110 Self {
111 entries: HashMap::new(),
112 order: Vec::new(),
113 capacity: capacity.max(1),
114 store,
115 dirty_count: 0,
116 total_flushes: 0,
117 total_entries_flushed: 0,
118 }
119 }
120
121 pub fn put(&mut self, key: S::Key, value: S::Value) -> Result<(), WriteBehindError<S::Error>> {
126 if self.entries.contains_key(&key) {
127 if let Some(entry) = self.entries.get_mut(&key) {
129 if !entry.dirty {
130 self.dirty_count += 1;
131 }
132 entry.value = value;
133 entry.dirty = true;
134 entry.last_modified = Instant::now();
135 }
136 return Ok(());
137 }
138
139 while self.entries.len() >= self.capacity {
141 self.evict_oldest()?;
142 }
143
144 self.order.push(key.clone());
145 self.entries.insert(
146 key,
147 CacheEntry {
148 value,
149 dirty: true,
150 last_modified: Instant::now(),
151 },
152 );
153 self.dirty_count += 1;
154 Ok(())
155 }
156
157 pub fn get(&mut self, key: &S::Key) -> Result<Option<&S::Value>, WriteBehindError<S::Error>> {
160 if self.entries.contains_key(key) {
161 return Ok(self.entries.get(key).map(|e| &e.value));
162 }
163 let value = self.store.read(key).map_err(WriteBehindError::StoreError)?;
165 if let Some(v) = value {
166 while self.entries.len() >= self.capacity {
168 self.evict_oldest().map_err(|e| match e {
169 WriteBehindError::StoreError(se) => WriteBehindError::StoreError(se),
170 })?;
171 }
172 self.order.push(key.clone());
173 self.entries.insert(
174 key.clone(),
175 CacheEntry {
176 value: v,
177 dirty: false,
178 last_modified: Instant::now(),
179 },
180 );
181 return Ok(self.entries.get(key).map(|e| &e.value));
182 }
183 Ok(None)
184 }
185
186 pub fn delete(&mut self, key: &S::Key) -> Result<bool, WriteBehindError<S::Error>> {
188 if let Some(entry) = self.entries.remove(key) {
189 self.order.retain(|k| k != key);
190 if entry.dirty {
191 self.dirty_count = self.dirty_count.saturating_sub(1);
192 }
193 self.store
194 .delete(key)
195 .map_err(WriteBehindError::StoreError)?;
196 return Ok(true);
197 }
198 Ok(false)
199 }
200
201 pub fn flush(&mut self) -> Result<usize, WriteBehindError<S::Error>> {
205 let dirty_keys: Vec<S::Key> = self
206 .entries
207 .iter()
208 .filter(|(_, e)| e.dirty)
209 .map(|(k, _)| k.clone())
210 .collect();
211 let count = dirty_keys.len();
212 for key in &dirty_keys {
213 if let Some(entry) = self.entries.get(key) {
214 self.store
215 .write(key, &entry.value)
216 .map_err(WriteBehindError::StoreError)?;
217 }
218 if let Some(entry) = self.entries.get_mut(key) {
219 entry.dirty = false;
220 }
221 }
222 self.dirty_count = 0;
223 self.total_flushes += 1;
224 self.total_entries_flushed += count as u64;
225 Ok(count)
226 }
227
228 pub fn flush_if_needed(
230 &mut self,
231 threshold: usize,
232 ) -> Result<usize, WriteBehindError<S::Error>> {
233 if self.dirty_count >= threshold {
234 self.flush()
235 } else {
236 Ok(0)
237 }
238 }
239
240 pub fn dirty_count(&self) -> usize {
242 self.dirty_count
243 }
244
245 pub fn stats(&self) -> WriteBehindStats {
247 WriteBehindStats {
248 entry_count: self.entries.len(),
249 dirty_count: self.dirty_count,
250 capacity: self.capacity,
251 total_flushes: self.total_flushes,
252 total_entries_flushed: self.total_entries_flushed,
253 }
254 }
255
256 pub fn store(&self) -> &S {
258 &self.store
259 }
260
261 pub fn store_mut(&mut self) -> &mut S {
263 &mut self.store
264 }
265
266 pub fn is_dirty(&self, key: &S::Key) -> bool {
268 self.entries.get(key).map(|e| e.dirty).unwrap_or(false)
269 }
270
271 pub fn contains(&self, key: &S::Key) -> bool {
273 self.entries.contains_key(key)
274 }
275
276 pub fn len(&self) -> usize {
278 self.entries.len()
279 }
280
281 pub fn is_empty(&self) -> bool {
283 self.entries.is_empty()
284 }
285
286 pub fn flush_older_than(
294 &mut self,
295 max_age: std::time::Duration,
296 ) -> Result<usize, WriteBehindError<S::Error>> {
297 let now = Instant::now();
298 let old_dirty_keys: Vec<S::Key> = self
299 .entries
300 .iter()
301 .filter(|(_, e)| e.dirty && now.duration_since(e.last_modified) >= max_age)
302 .map(|(k, _)| k.clone())
303 .collect();
304 let count = old_dirty_keys.len();
305 for key in &old_dirty_keys {
306 if let Some(entry) = self.entries.get(key) {
307 self.store
308 .write(key, &entry.value)
309 .map_err(WriteBehindError::StoreError)?;
310 }
311 if let Some(entry) = self.entries.get_mut(key) {
312 entry.dirty = false;
313 }
314 }
315 self.dirty_count = self.dirty_count.saturating_sub(count);
316 if count > 0 {
317 self.total_flushes += 1;
318 self.total_entries_flushed += count as u64;
319 }
320 Ok(count)
321 }
322
323 pub fn dirty_keys(&self) -> Vec<S::Key> {
325 self.entries
326 .iter()
327 .filter(|(_, e)| e.dirty)
328 .map(|(k, _)| k.clone())
329 .collect()
330 }
331
332 pub fn mark_clean(&mut self, key: &S::Key) -> bool {
338 if let Some(entry) = self.entries.get_mut(key) {
339 if entry.dirty {
340 entry.dirty = false;
341 self.dirty_count = self.dirty_count.saturating_sub(1);
342 return true;
343 }
344 }
345 false
346 }
347
348 pub fn capacity(&self) -> usize {
350 self.capacity
351 }
352
353 fn evict_oldest(&mut self) -> Result<(), WriteBehindError<S::Error>> {
355 if self.order.is_empty() {
356 return Ok(());
357 }
358 let key = self.order.remove(0);
359 if let Some(entry) = self.entries.remove(&key) {
360 if entry.dirty {
361 self.store
362 .write(&key, &entry.value)
363 .map_err(WriteBehindError::StoreError)?;
364 self.dirty_count = self.dirty_count.saturating_sub(1);
365 self.total_entries_flushed += 1;
366 }
367 }
368 Ok(())
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    /// In-memory backing store. Clones share the same map, so a test can hand
    /// one clone to the cache and inspect what was written through another.
    #[derive(Clone)]
    struct MemStore {
        data: Arc<Mutex<HashMap<String, String>>>,
    }

    impl MemStore {
        fn new() -> Self {
            Self {
                data: Arc::new(Mutex::new(HashMap::new())),
            }
        }

        /// Copy of the current store contents, recovering from a poisoned lock.
        fn snapshot(&self) -> HashMap<String, String> {
            let guard = self.data.lock().unwrap_or_else(|p| p.into_inner());
            guard.clone()
        }
    }

    impl BackingStore for MemStore {
        type Key = String;
        type Value = String;
        type Error = String;

        fn write(&mut self, key: &String, value: &String) -> Result<(), String> {
            let mut guard = self.data.lock().unwrap_or_else(|p| p.into_inner());
            guard.insert(key.clone(), value.clone());
            Ok(())
        }

        fn read(&self, key: &String) -> Result<Option<String>, String> {
            let guard = self.data.lock().unwrap_or_else(|p| p.into_inner());
            Ok(guard.get(key).cloned())
        }

        fn delete(&mut self, key: &String) -> Result<(), String> {
            let mut guard = self.data.lock().unwrap_or_else(|p| p.into_inner());
            guard.remove(key);
            Ok(())
        }
    }

    // Setup steps use `expect` rather than `.ok()` so a failing call reports its
    // own error instead of surfacing later as an unrelated assertion failure.

    #[test]
    fn test_put_and_get() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("k1".to_string(), "v1".to_string()).expect("put k1");
        let val = cache.get(&"k1".to_string()).ok().flatten();
        assert_eq!(val, Some(&"v1".to_string()));
    }

    #[test]
    fn test_dirty_tracking() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        assert!(cache.is_dirty(&"a".to_string()));
        assert_eq!(cache.dirty_count(), 1);
    }

    #[test]
    fn test_flush_writes_to_store() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store.clone());
        cache.put("x".to_string(), "42".to_string()).expect("put x");
        let flushed = cache.flush().ok();
        assert_eq!(flushed, Some(1));
        assert!(!cache.is_dirty(&"x".to_string()));
        let snap = store.snapshot();
        assert_eq!(snap.get("x"), Some(&"42".to_string()));
    }

    #[test]
    fn test_flush_clears_dirty() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.put("b".to_string(), "2".to_string()).expect("put b");
        cache.flush().expect("flush");
        assert_eq!(cache.dirty_count(), 0);
    }

    #[test]
    fn test_flush_if_needed() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        // One dirty entry is below the threshold of 5: nothing is written.
        let flushed = cache.flush_if_needed(5).ok();
        assert_eq!(flushed, Some(0));

        cache.put("b".to_string(), "2".to_string()).expect("put b");
        cache.put("c".to_string(), "3".to_string()).expect("put c");
        // Three dirty entries meet the threshold of 2: all of them are flushed.
        let flushed = cache.flush_if_needed(2).ok();
        assert_eq!(flushed, Some(3));
    }

    #[test]
    fn test_eviction_flushes_dirty() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(2, store.clone());
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.put("b".to_string(), "2".to_string()).expect("put b");
        // Third insert into a capacity-2 cache evicts "a", writing it back.
        cache.put("c".to_string(), "3".to_string()).expect("put c");
        let snap = store.snapshot();
        assert_eq!(snap.get("a"), Some(&"1".to_string()));
    }

    #[test]
    fn test_delete() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store.clone());
        cache.put("k".to_string(), "v".to_string()).expect("put k");
        cache.flush().expect("flush");
        let deleted = cache.delete(&"k".to_string()).ok();
        assert_eq!(deleted, Some(true));
        let snap = store.snapshot();
        assert!(!snap.contains_key("k"));
    }

    #[test]
    fn test_read_through() {
        let store = MemStore::new();
        {
            let mut guard = store.data.lock().unwrap_or_else(|p| p.into_inner());
            guard.insert("pre".to_string(), "existing".to_string());
        }
        let mut cache = WriteBehindCache::new(10, store);
        let val = cache.get(&"pre".to_string()).ok().flatten();
        assert_eq!(val, Some(&"existing".to_string()));
        assert!(!cache.is_dirty(&"pre".to_string()));
    }

    #[test]
    fn test_update_re_dirties() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.flush().expect("flush");
        assert!(!cache.is_dirty(&"a".to_string()));
        cache.put("a".to_string(), "2".to_string()).expect("re-put a");
        assert!(cache.is_dirty(&"a".to_string()));
    }

    #[test]
    fn test_stats() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.put("b".to_string(), "2".to_string()).expect("put b");
        cache.flush().expect("flush");
        let s = cache.stats();
        assert_eq!(s.entry_count, 2);
        assert_eq!(s.dirty_count, 0);
        assert_eq!(s.total_flushes, 1);
        assert_eq!(s.total_entries_flushed, 2);
    }

    #[test]
    fn test_delete_absent() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        let deleted = cache.delete(&"ghost".to_string()).ok();
        assert_eq!(deleted, Some(false));
    }

    #[test]
    fn test_get_absent() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        let val = cache.get(&"nope".to_string()).ok().flatten();
        assert!(val.is_none());
    }

    #[test]
    fn test_contains() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("x".to_string(), "val".to_string()).expect("put x");
        assert!(cache.contains(&"x".to_string()));
        assert!(!cache.contains(&"y".to_string()));
    }

    #[test]
    fn test_len_and_is_empty() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.put("b".to_string(), "2".to_string()).expect("put b");
        assert_eq!(cache.len(), 2);
        assert!(!cache.is_empty());
    }

    #[test]
    fn test_flush_older_than() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store.clone());
        let age = std::time::Duration::from_millis(50);
        cache.put("old".to_string(), "old_val".to_string()).expect("put old");
        // `sleep` never returns early, so "old" is guaranteed to be at least
        // `age` old. Using the same value as the threshold gives "new" the
        // largest possible margin against scheduler delays.
        std::thread::sleep(age);
        cache.put("new".to_string(), "new_val".to_string()).expect("put new");
        let flushed = cache.flush_older_than(age).ok();
        assert_eq!(flushed, Some(1));
        assert!(!cache.is_dirty(&"old".to_string()));
        assert!(cache.is_dirty(&"new".to_string()));
        let snap = store.snapshot();
        assert!(snap.contains_key("old"));
    }

    #[test]
    fn test_flush_older_than_zero() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.put("b".to_string(), "2".to_string()).expect("put b");
        let flushed = cache
            .flush_older_than(std::time::Duration::from_millis(0))
            .ok();
        assert_eq!(flushed, Some(2));
        assert_eq!(cache.dirty_count(), 0);
    }

    #[test]
    fn test_dirty_keys() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.put("b".to_string(), "2".to_string()).expect("put b");
        cache.put("c".to_string(), "3".to_string()).expect("put c");
        cache.flush().expect("flush");
        cache.put("b".to_string(), "updated".to_string()).expect("re-put b");
        let dirty = cache.dirty_keys();
        assert_eq!(dirty.len(), 1);
        assert_eq!(dirty[0], "b");
    }

    #[test]
    fn test_mark_clean() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store.clone());
        cache.put("x".to_string(), "val".to_string()).expect("put x");
        assert!(cache.is_dirty(&"x".to_string()));
        assert!(cache.mark_clean(&"x".to_string()));
        assert!(!cache.is_dirty(&"x".to_string()));
        assert_eq!(cache.dirty_count(), 0);
        // mark_clean does not write anything to the store.
        let snap = store.snapshot();
        assert!(!snap.contains_key("x"));
    }

    #[test]
    fn test_mark_clean_already_clean() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.flush().expect("flush");
        assert!(!cache.mark_clean(&"a".to_string()));
    }

    #[test]
    fn test_mark_clean_absent() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        assert!(!cache.mark_clean(&"ghost".to_string()));
    }

    #[test]
    fn test_capacity() {
        let store = MemStore::new();
        let cache: WriteBehindCache<MemStore> = WriteBehindCache::new(42, store);
        assert_eq!(cache.capacity(), 42);
    }

    #[test]
    fn test_multiple_flushes_stats() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store);
        cache.put("a".to_string(), "1".to_string()).expect("put a");
        cache.flush().expect("first flush");
        cache.put("b".to_string(), "2".to_string()).expect("put b");
        cache.flush().expect("second flush");
        let s = cache.stats();
        assert_eq!(s.total_flushes, 2);
        assert_eq!(s.total_entries_flushed, 2);
    }

    #[test]
    fn test_eviction_cascade() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(3, store.clone());
        for i in 0..5 {
            cache.put(format!("k{i}"), format!("v{i}")).expect("put");
        }
        // Capacity 3 with 5 inserts: the two oldest were evicted and written back.
        let snap = store.snapshot();
        assert!(snap.contains_key("k0"), "evicted k0 should be in store");
        assert!(snap.contains_key("k1"), "evicted k1 should be in store");
    }

    #[test]
    fn test_read_through_is_clean() {
        let store = MemStore::new();
        {
            let mut guard = store.data.lock().unwrap_or_else(|p| p.into_inner());
            guard.insert("existing".to_string(), "value".to_string());
        }
        let mut cache = WriteBehindCache::new(10, store);
        cache.get(&"existing".to_string()).expect("read-through get");
        assert!(!cache.is_dirty(&"existing".to_string()));
        assert_eq!(cache.dirty_count(), 0);
    }

    #[test]
    fn test_store_accessors() {
        let store = MemStore::new();
        let mut cache = WriteBehindCache::new(10, store.clone());
        cache
            .store_mut()
            .write(&"direct".to_string(), &"1".to_string())
            .expect("direct store write");
        assert_eq!(
            cache.store().read(&"direct".to_string()),
            Ok(Some("1".to_string()))
        );
        assert_eq!(store.snapshot().get("direct"), Some(&"1".to_string()));
    }
}