use std::sync::Arc;

use async_trait::async_trait;
use futures::task::SpawnExt;
use tracing::{trace, warn};

use crate::{Runtime, object::ObjectId};

use super::{BlockIndex, ChecksummedBytes, DataCache, DataCacheResult};

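/// A [DataCache] combining two tiers: a local disk cache and a shared express cache.
///
/// Reads try the disk cache first and fall back to the express cache, back-filling the
/// disk cache in the background on an express hit. Writes go to both tiers.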
pub struct MultilevelDataCache<DiskCache, ExpressCache> {
    disk_cache: Arc<DiskCache>,
    express_cache: ExpressCache,
    runtime: Runtime,
}

impl<DiskCache: DataCache, ExpressCache: DataCache> MultilevelDataCache<DiskCache, ExpressCache> {
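    /// Creates a new multi-level cache from the given tiers.
    ///
    /// Panics if the two caches are configured with different block sizes.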
    pub fn new(disk_cache: Arc<DiskCache>, express_cache: ExpressCache, runtime: Runtime) -> Self {
        assert_eq!(
            disk_cache.block_size(),
            express_cache.block_size(),
            "block sizes must be equal"
        );
        Self {
            disk_cache,
            express_cache,
            runtime,
        }
    }
}

#[async_trait]
impl<DiskCache, ExpressCache> DataCache for MultilevelDataCache<DiskCache, ExpressCache>
where
    DiskCache: DataCache + Sync + Send + 'static,
    ExpressCache: DataCache + Sync,
{
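    /// Tries the disk cache first, falling back to the express cache on a miss.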
    async fn get_block(
        &self,
        cache_key: &ObjectId,
        block_idx: BlockIndex,
        block_offset: u64,
        object_size: usize,
    ) -> DataCacheResult<Option<ChecksummedBytes>> {
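        // Try the disk cache first. A read error is logged and treated as a miss, so a
        // corrupted disk cache never fails the lookup outright.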
        match self
            .disk_cache
            .get_block(cache_key, block_idx, block_offset, object_size)
            .await
        {
            Ok(Some(data)) => {
                trace!(cache_key=?cache_key, block_idx=block_idx, "block served from the disk cache");
                return DataCacheResult::Ok(Some(data));
            }
            Ok(None) => (),
            Err(err) => warn!(cache_key=?cache_key, block_idx=block_idx, ?err, "error reading block from disk cache"),
        }

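        // On a disk miss, fall back to the express cache.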
        if let Some(data) = self
            .express_cache
            .get_block(cache_key, block_idx, block_offset, object_size)
            .await?
        {
            trace!(cache_key=?cache_key, block_idx=block_idx, "block served from the express cache");
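            // Back-fill the disk cache in the background so the next read of this block is
            // served locally. A failure here only affects performance, so it is just logged.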
            let cache_key = cache_key.clone();
            let disk_cache = self.disk_cache.clone();
            let data_cloned = data.clone();
            self.runtime
                .spawn(async move {
                    if let Err(error) = disk_cache
                        .put_block(cache_key.clone(), block_idx, block_offset, data_cloned, object_size)
                        .await
                    {
                        warn!(cache_key=?cache_key, block_idx, ?error, "failed to update the local cache");
                    }
                })
                .unwrap();
            return DataCacheResult::Ok(Some(data));
        }

        DataCacheResult::Ok(None)
    }

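    /// Puts a block to both tiers.
    ///
    /// A disk-cache failure is logged but does not fail the call; only errors from the
    /// express cache are returned to the caller.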
    async fn put_block(
        &self,
        cache_key: ObjectId,
        block_idx: BlockIndex,
        block_offset: u64,
        bytes: ChecksummedBytes,
        object_size: usize,
    ) -> DataCacheResult<()> {
        if let Err(error) = self
            .disk_cache
            .put_block(cache_key.clone(), block_idx, block_offset, bytes.clone(), object_size)
            .await
        {
            warn!(cache_key=?cache_key, block_idx, ?error, "failed to update the local cache");
        }

        self.express_cache
            .put_block(cache_key, block_idx, block_offset, bytes, object_size)
            .await
    }

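    /// Returns the block size shared by both tiers, which [Self::new] asserts are equal.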
    fn block_size(&self) -> u64 {
        self.disk_cache.block_size()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::checksums::ChecksummedBytes;
    use crate::data_cache::{CacheLimit, DiskDataCache, DiskDataCacheConfig, ExpressDataCache, ExpressDataCacheConfig};
    use crate::memory::PagedPool;

    use futures::executor::ThreadPool;
    use mountpoint_s3_client::mock_client::MockClient;
    use mountpoint_s3_client::types::ETag;
    use tempfile::TempDir;
    use test_case::test_case;

    const PART_SIZE: usize = 8 * 1024 * 1024;
    const BLOCK_SIZE: u64 = 1024 * 1024;

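    // Creates a disk cache backed by a fresh temporary directory. The directory handle is
    // returned so tests can drop it to simulate losing the local cache.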
    fn create_disk_cache() -> (TempDir, Arc<DiskDataCache>) {
        let cache_directory = tempfile::tempdir().unwrap();
        let pool = PagedPool::new_with_candidate_sizes([BLOCK_SIZE as usize, PART_SIZE]);
        let cache = DiskDataCache::new(
            DiskDataCacheConfig {
                cache_directory: cache_directory.path().to_path_buf(),
                block_size: BLOCK_SIZE,
                limit: CacheLimit::Unbounded,
            },
            pool,
        );
        (cache_directory, Arc::new(cache))
    }

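    // Creates an express cache backed by a mock S3 client. The client is returned so tests
    // can inspect the cache bucket or wipe it directly.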
    fn create_express_cache() -> (MockClient, ExpressDataCache<MockClient>) {
        let bucket = "test_bucket";
        let client = MockClient::config()
            .bucket(bucket.to_string())
            .part_size(PART_SIZE)
            .enable_backpressure(true)
            .initial_read_window_size(PART_SIZE)
            .build();
        let cache = ExpressDataCache::new(
            client.clone(),
            ExpressDataCacheConfig::new(bucket, "unique source description"),
        );
        (client, cache)
    }

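    // A put goes to both tiers, so a block must still be readable after wiping either one;
    // only wiping both should produce a miss.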
    #[test_case(false, true; "get from local")]
    #[test_case(true, false; "get from express")]
    #[test_case(true, true; "both empty")]
    #[tokio::test]
    async fn test_put_to_both_caches(cleanup_local: bool, cleanup_express: bool) {
        let (cache_dir, disk_cache) = create_disk_cache();
        let (client, express_cache) = create_express_cache();
        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let cache = MultilevelDataCache::new(disk_cache, express_cache, runtime);

        let data = ChecksummedBytes::new("Foo".into());
        let object_size = data.len();
        let cache_key = ObjectId::new("a".into(), ETag::for_tests());

        cache
            .put_block(cache_key.clone(), 0, 0, data.clone(), object_size)
            .await
            .expect("put should succeed");

        if cleanup_local {
            cache_dir.close().expect("should clean up local cache");
        }
        if cleanup_express {
            client.remove_all_objects();
        }

        let entry = cache
            .get_block(&cache_key, 0, 0, object_size)
            .await
            .expect("cache should be accessible");

        if cleanup_local && cleanup_express {
            assert!(entry.is_none());
        } else {
            assert_eq!(
                entry.expect("cache entry should be returned"),
                data,
                "cache entry returned should match original bytes after put"
            );
        }
    }

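    // A block present only in the express tier should be written back to the disk tier in
    // the background after it is read through the multi-level cache.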
    #[tokio::test]
    async fn test_put_from_express_to_local() {
        let (_cache_dir, disk_cache) = create_disk_cache();
        let (client, express_cache) = create_express_cache();

        let data = ChecksummedBytes::new("Foo".into());
        let object_size = data.len();
        let cache_key = ObjectId::new("a".into(), ETag::for_tests());
        express_cache
            .put_block(cache_key.clone(), 0, 0, data.clone(), object_size)
            .await
            .expect("put should succeed");

        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let cache = MultilevelDataCache::new(disk_cache, express_cache, runtime);

        let entry = cache
            .get_block(&cache_key, 0, 0, object_size)
            .await
            .expect("cache should be accessible")
            .expect("cache entry should be returned");
        assert_eq!(
            data, entry,
            "cache entry returned should match original bytes after put"
        );

        client.remove_all_objects();

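        // The write-back to the disk cache runs asynchronously on the runtime, so poll
        // until the block shows up locally (the express copy is already gone).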
        let mut retries = 10;
        let entry = loop {
            let entry = cache
                .get_block(&cache_key, 0, 0, object_size)
                .await
                .expect("cache should be accessible");
            if let Some(entry_data) = entry {
                break entry_data;
            }
            retries -= 1;
            if retries <= 0 {
                panic!("entry was not found in the local cache");
            }
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        };
        assert_eq!(
            data, entry,
            "cache entry returned should match original bytes after put"
        );
        assert_eq!(client.object_count(), 0, "express cache should remain empty");
    }

    #[tokio::test]
    async fn test_get_from_local() {
        let (_cache_dir, disk_cache) = create_disk_cache();
        let (_, express_cache) = create_express_cache();

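        // Seed the disk cache with two keys and give the express cache conflicting data for
        // one of them: a hit must always be served from the disk tier first.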
        let local_data_1 = ChecksummedBytes::new("key in local only".into());
        let local_data_2 = ChecksummedBytes::new("key in both, right data".into());
        let express_data = ChecksummedBytes::new("key in both, wrong data".into());
        let cache_key_in_local = ObjectId::new("key_in_local".into(), ETag::for_tests());
        let cache_key_in_both = ObjectId::new("key_in_both".into(), ETag::for_tests());
        disk_cache
            .put_block(
                cache_key_in_local.clone(),
                0,
                0,
                local_data_1.clone(),
                local_data_1.len(),
            )
            .await
            .expect("put should succeed");
        disk_cache
            .put_block(
                cache_key_in_both.clone(),
                0,
                0,
                local_data_2.clone(),
                local_data_2.len(),
            )
            .await
            .expect("put should succeed");
        express_cache
            .put_block(
                cache_key_in_both.clone(),
                0,
                0,
                express_data.clone(),
                express_data.len(),
            )
            .await
            .expect("put should succeed");

        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let cache = MultilevelDataCache::new(disk_cache, express_cache, runtime);

        let entry = cache
            .get_block(&cache_key_in_local, 0, 0, local_data_1.len())
            .await
            .expect("cache should be accessible")
            .expect("cache entry should be returned");
        assert_eq!(
            local_data_1, entry,
            "cache entry returned should match original bytes after put"
        );

        let entry = cache
            .get_block(&cache_key_in_both, 0, 0, local_data_2.len())
            .await
            .expect("cache should be accessible")
            .expect("cache entry should be returned");
        assert_eq!(
            local_data_2, entry,
            "cache entry returned should match original bytes after put"
        );
    }

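    // A disk miss must fall through to the express tier.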
    #[tokio::test]
    async fn test_get_from_express() {
        let (_cache_dir, disk_cache) = create_disk_cache();
        let (_, express_cache) = create_express_cache();

        let data = ChecksummedBytes::new("Foo".into());
        let object_size = data.len();
        let cache_key = ObjectId::new("a".into(), ETag::for_tests());
        express_cache
            .put_block(cache_key.clone(), 0, 0, data.clone(), object_size)
            .await
            .expect("put should succeed");

        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let cache = MultilevelDataCache::new(disk_cache, express_cache, runtime);

        let entry = cache
            .get_block(&cache_key, 0, 0, object_size)
            .await
            .expect("cache should be accessible")
            .expect("cache entry should be returned");
        assert_eq!(
            data, entry,
            "cache entry returned should match original bytes after put"
        );
    }

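    // Objects above the express cache's size threshold are bypassed rather than written to
    // that tier, so once the local cache directory is gone the block is a miss.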
    #[tokio::test]
    async fn large_object_bypassed() {
        let (cache_dir, disk_cache) = create_disk_cache();
        let (client, express_cache) = create_express_cache();
        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let cache = MultilevelDataCache::new(disk_cache, express_cache, runtime);

        let data = vec![0u8; 1024 * 1024 + 1];
        let data = ChecksummedBytes::new(data.into());
        let object_size = data.len();
        let cache_key = ObjectId::new("a".into(), ETag::for_tests());

        cache
            .put_block(cache_key.clone(), 0, 0, data.clone(), object_size)
            .await
            .expect("put should succeed");

        assert_eq!(client.object_count(), 0, "express cache must be empty");

        cache_dir.close().expect("should clean up local cache");
        let entry = cache
            .get_block(&cache_key, 0, 0, object_size)
            .await
            .expect("cache should be accessible");
        assert!(entry.is_none(), "cache miss is expected for a large object");
    }
}