1use std::collections::HashMap;
2use std::num::NonZeroUsize;
3use std::sync::Arc;
4
5use lru::LruCache;
6use parking_lot::{Condvar, Mutex};
7use smallvec::SmallVec;
8
9use crate::error::Result;
10
11#[derive(Debug, Clone, PartialEq, Eq, Hash)]
16pub struct ChunkKey {
17 pub dataset_addr: u64,
18 pub chunk_offsets: SmallVec<[u64; 4]>,
19}
20
21pub struct ChunkCache {
26 inner: Mutex<ChunkCacheState>,
27 max_bytes: usize,
28}
29
30struct ChunkCacheState {
31 cache: LruCache<ChunkKey, Arc<Vec<u8>>>,
32 current_bytes: usize,
33 in_flight: HashMap<ChunkKey, Arc<InFlightLoad>>,
34}
35
36struct InFlightLoad {
37 completed: Mutex<bool>,
38 ready: Condvar,
39}
40
41impl ChunkCache {
42 pub fn new(max_bytes: usize, max_slots: usize) -> Self {
47 let slots = NonZeroUsize::new(max_slots).unwrap_or(NonZeroUsize::new(521).unwrap());
48 ChunkCache {
49 inner: Mutex::new(ChunkCacheState {
50 cache: LruCache::new(slots),
51 current_bytes: 0,
52 in_flight: HashMap::new(),
53 }),
54 max_bytes,
55 }
56 }
57
58 pub fn get(&self, key: &ChunkKey) -> Option<Arc<Vec<u8>>> {
60 let mut cache = self.inner.lock();
61 cache.cache.get(key).cloned()
62 }
63
64 pub fn insert(&self, key: ChunkKey, data: Vec<u8>) -> Arc<Vec<u8>> {
66 let data_len = data.len();
67 let arc = Arc::new(data);
68
69 if self.max_bytes == 0 || data_len > self.max_bytes {
70 return arc;
71 }
72
73 let mut state = self.inner.lock();
74 while state.current_bytes + data_len > self.max_bytes && !state.cache.is_empty() {
76 if let Some((_, evicted)) = state.cache.pop_lru() {
77 state.current_bytes = state.current_bytes.saturating_sub(evicted.len());
78 }
79 }
80
81 if let Some(replaced) = state.cache.peek(&key) {
82 state.current_bytes = state.current_bytes.saturating_sub(replaced.len());
83 }
84 state.current_bytes += data_len;
85 state.cache.put(key, arc.clone());
86
87 arc
88 }
89
90 pub fn get_or_insert_with<F>(&self, key: ChunkKey, load: F) -> Result<Arc<Vec<u8>>>
92 where
93 F: FnOnce() -> Result<Vec<u8>>,
94 {
95 loop {
96 let in_flight = {
97 let mut state = self.inner.lock();
98 if let Some(cached) = state.cache.get(&key).cloned() {
99 return Ok(cached);
100 }
101
102 if let Some(in_flight) = state.in_flight.get(&key) {
103 Arc::clone(in_flight)
104 } else {
105 let in_flight = Arc::new(InFlightLoad {
106 completed: Mutex::new(false),
107 ready: Condvar::new(),
108 });
109 state.in_flight.insert(key.clone(), Arc::clone(&in_flight));
110 drop(state);
111
112 let result = load().map(|data| self.insert(key.clone(), data));
113
114 let mut state = self.inner.lock();
115 state.in_flight.remove(&key);
116 let mut completed = in_flight.completed.lock();
117 *completed = true;
118 in_flight.ready.notify_all();
119
120 return result;
121 }
122 };
123
124 let mut completed = in_flight.completed.lock();
125 while !*completed {
126 in_flight.ready.wait(&mut completed);
127 }
128 }
129 }
130}
131
132impl Default for ChunkCache {
133 fn default() -> Self {
134 Self::new(64 * 1024 * 1024, 521)
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key for `dataset_addr` with the given chunk offsets.
    fn chunk_key(dataset_addr: u64, offsets: &[u64]) -> ChunkKey {
        ChunkKey {
            dataset_addr,
            chunk_offsets: SmallVec::from_slice(offsets),
        }
    }

    #[test]
    fn test_cache_insert_and_get() {
        let cache = ChunkCache::new(1024, 10);
        let key = chunk_key(100, &[0, 0]);

        cache.insert(key.clone(), vec![1, 2, 3]);

        let stored = cache.get(&key).expect("value should be cached");
        assert_eq!(stored.as_slice(), &[1, 2, 3]);
    }

    #[test]
    fn test_cache_eviction() {
        // Five 4-byte chunks cannot all fit in a 10-byte budget.
        let cache = ChunkCache::new(10, 10);
        (0..5u64).for_each(|offset| {
            cache.insert(chunk_key(100, &[offset]), vec![0; 4]);
        });

        assert!(cache.get(&chunk_key(100, &[0])).is_none());
    }

    #[test]
    fn test_cache_disabled_bypasses_storage() {
        let cache = ChunkCache::new(0, 10);
        let key = chunk_key(100, &[0]);

        cache.insert(key.clone(), vec![1, 2, 3]);

        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cache_promotes_on_get() {
        // Room for exactly three 4-byte chunks.
        let cache = ChunkCache::new(12, 10);
        let [a, b, c, d] = [1, 2, 3, 4].map(|addr| chunk_key(addr, &[0]));

        for key in [&a, &b, &c] {
            cache.insert(key.clone(), vec![0; 4]);
        }
        // Touching `a` leaves `b` as the least recently used entry.
        assert!(cache.get(&a).is_some());

        cache.insert(d, vec![0; 4]);

        assert!(cache.get(&a).is_some());
        assert!(cache.get(&b).is_none());
    }

    #[test]
    fn test_cache_replacement_updates_accounting() {
        let cache = ChunkCache::new(8, 10);
        let key = chunk_key(100, &[0]);
        let other = chunk_key(100, &[1]);

        cache.insert(key.clone(), vec![1, 2, 3, 4]);
        cache.insert(key.clone(), vec![5, 6]);
        // Only 2 + 4 bytes are live, so this must fit without evicting `key`.
        cache.insert(other.clone(), vec![7, 8, 9, 10]);

        assert_eq!(cache.get(&key).unwrap().as_slice(), &[5, 6]);
        assert!(cache.get(&other).is_some());
    }

    #[test]
    fn test_cache_get_or_insert_with_deduplicates_concurrent_loads() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Duration;

        let cache = ChunkCache::new(1024, 10);
        let key = chunk_key(100, &[0, 0]);
        let loads = AtomicUsize::new(0);

        std::thread::scope(|scope| {
            for _ in 0..8 {
                scope.spawn(|| {
                    let value = cache
                        .get_or_insert_with(key.clone(), || {
                            loads.fetch_add(1, Ordering::SeqCst);
                            std::thread::sleep(Duration::from_millis(10));
                            Ok(vec![1, 2, 3, 4])
                        })
                        .unwrap();
                    assert_eq!(value.as_slice(), &[1, 2, 3, 4]);
                });
            }
        });

        assert_eq!(loads.load(Ordering::SeqCst), 1);
    }
}