
async_tiff/metadata/cache.rs

//! Caching strategies for metadata fetching.

use std::ops::Range;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use tokio::sync::Mutex;

use crate::error::AsyncTiffResult;
use crate::metadata::MetadataFetch;

/// Logic for managing a cache of sequential buffers
#[derive(Debug)]
struct SequentialBlockCache {
    /// Contiguous blocks from offset 0
    ///
    /// # Invariant
    /// - Buffers are contiguous from offset 0
    buffers: Vec<Bytes>,

    /// Total length cached (== sum of buffer lengths)
    len: u64,
}

impl SequentialBlockCache {
    /// Create a new, empty SequentialBlockCache
    fn new() -> Self {
        Self {
            buffers: vec![],
            len: 0,
        }
    }

    /// Check if the given range is fully contained within the cached buffers
    fn contains(&self, range: Range<u64>) -> bool {
        range.end <= self.len
    }

    /// Slice out the given range from the cached buffers
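    ///
    /// For example, with cached buffers `[b"abc", b"de"]`, `slice(1..4)` returns `b"bcd"`,
    /// stitching bytes from both buffers into a single output buffer.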
    fn slice(&self, range: Range<u64>) -> Bytes {
        // The size of the output buffer
        let out_len = (range.end - range.start) as usize;

        // The remaining range of bytes required. This range is updated as we traverse buffers, so
        // the indexes are relative to the current buffer.
        let mut remaining = range;
        let mut out_buffers: Vec<Bytes> = vec![];

        for buf in &self.buffers {
            let current_buf_len = buf.len() as u64;

            // this block falls entirely before the desired range start
            if remaining.start >= current_buf_len {
                remaining.start -= current_buf_len;
                remaining.end -= current_buf_len;
                continue;
            }

            // we slice bytes out of *this* block
            let start = remaining.start as usize;
            let length =
                (remaining.end - remaining.start).min(current_buf_len - remaining.start) as usize;
            let end = start + length;

            // nothing to take from this block
            if start == end {
                continue;
            }

            let chunk = buf.slice(start..end);
            out_buffers.push(chunk);

            // consumed some portion; update and potentially break
            remaining.start = 0;
            if remaining.end <= current_buf_len {
                break;
            }
            remaining.end -= current_buf_len;
        }

        if out_buffers.len() == 1 {
            out_buffers.into_iter().next().unwrap()
        } else {
            let mut out = BytesMut::with_capacity(out_len);
            for b in out_buffers {
                out.extend_from_slice(&b);
            }
            out.into()
        }
    }

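    /// Append a newly fetched buffer, extending the contiguous cached range from offset 0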
    fn append_buffer(&mut self, buffer: Bytes) {
        self.len += buffer.len() as u64;
        self.buffers.push(buffer);
    }
}

/// A MetadataFetch implementation that caches fetched data in exponentially growing chunks,
/// sequentially from the beginning of the file.
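///
/// # Example
///
/// A minimal usage sketch; `MyFetch` is a hypothetical placeholder for any type implementing
/// [`MetadataFetch`], not a type provided by this module:
///
/// ```ignore
/// // Wrap the raw fetcher so repeated metadata reads are served from the cache, starting
/// // with a 16 KiB fetch and sizing each later fetch to 3x the bytes already cached.
/// let cached = ReadaheadMetadataCache::new(MyFetch::new())
///     .with_initial_size(16 * 1024)
///     .with_multiplier(3.0);
/// let header = cached.fetch(0..8).await?;
/// ```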
#[derive(Debug)]
pub struct ReadaheadMetadataCache<F: MetadataFetch> {
    inner: F,
    cache: Arc<Mutex<SequentialBlockCache>>,
    initial: u64,
    multiplier: f64,
}

impl<F: MetadataFetch> ReadaheadMetadataCache<F> {
    /// Create a new ReadaheadMetadataCache wrapping the given MetadataFetch
    pub fn new(inner: F) -> Self {
        Self {
            inner,
            cache: Arc::new(Mutex::new(SequentialBlockCache::new())),
            initial: 32 * 1024,
            multiplier: 2.0,
        }
    }

    /// Access the inner MetadataFetch
    pub fn inner(&self) -> &F {
        &self.inner
    }

    /// Set the initial fetch size in bytes. Defaults to 32 KiB.
    pub fn with_initial_size(mut self, initial: u64) -> Self {
        self.initial = initial;
        self
    }

    /// Set the multiplier for subsequent fetch sizes. Defaults to 2.0.
    ///
    /// Each subsequent fetch is sized to `multiplier` times the total number of bytes already
    /// cached (or larger, if the requested range requires more).
    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
        self.multiplier = multiplier;
        self
    }

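    /// Size of the next fetch: the initial size for an empty cache, otherwise `multiplier` times
    /// the number of bytes already cached. With the defaults, successive fetches are 32 KiB,
    /// 64 KiB, 192 KiB, ..., so the cached length grows to 32 KiB, 96 KiB, 288 KiB, ...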
    fn next_fetch_size(&self, existing_len: u64) -> u64 {
        if existing_len == 0 {
            self.initial
        } else {
            (existing_len as f64 * self.multiplier).round() as u64
        }
    }
}

#[async_trait]
impl<F: MetadataFetch + Send + Sync> MetadataFetch for ReadaheadMetadataCache<F> {
    async fn fetch(&self, range: Range<u64>) -> AsyncTiffResult<Bytes> {
        let mut cache = self.cache.lock().await;

        // First check if we already have the range cached
        if cache.contains(range.start..range.end) {
            return Ok(cache.slice(range));
        }

        // Cache miss: extend the cache from its current end so buffers stay contiguous from
        // offset 0. Fetch whichever is larger: the readahead size or enough to reach the end of
        // the requested range.
        let start_len = cache.len;
        let needed = range.end.saturating_sub(start_len);
        let fetch_size = self.next_fetch_size(start_len).max(needed);
        let fetch_range = start_len..start_len + fetch_size;

        // Perform the fetch while holding the mutex
        // (this is OK because the mutex is async, so waiting tasks yield rather than block a thread)
        let bytes = self.inner.fetch(fetch_range).await?;

        // Append the fetched block, extending the contiguous cached range
        cache.append_buffer(bytes);

        Ok(cache.slice(range))
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[derive(Debug)]
    struct TestFetch {
        data: Bytes,
        /// The number of fetches that actually reach the underlying MetadataFetch implementation
        num_fetches: Arc<Mutex<u64>>,
    }

    impl TestFetch {
        fn new(data: Bytes) -> Self {
            Self {
                data,
                num_fetches: Arc::new(Mutex::new(0)),
            }
        }
    }

    #[async_trait]
    impl MetadataFetch for TestFetch {
        async fn fetch(&self, range: Range<u64>) -> crate::error::AsyncTiffResult<Bytes> {
            if range.start as usize >= self.data.len() {
                return Ok(Bytes::new());
            }

            let end = (range.end as usize).min(self.data.len());
            let slice = self.data.slice(range.start as _..end);
            let mut g = self.num_fetches.lock().await;
            *g += 1;
            Ok(slice)
        }
    }

    #[tokio::test]
    async fn test_readahead_cache() {
        let data = Bytes::from_static(b"abcdefghijklmnopqrstuvwxyz");
        let fetch = TestFetch::new(data.clone());
        let cache = ReadaheadMetadataCache::new(fetch)
            .with_initial_size(2)
            .with_multiplier(3.0);

        // Make the initial request
        let result = cache.fetch(0..2).await.unwrap();
        assert_eq!(result.as_ref(), b"ab");
        assert_eq!(*cache.inner.num_fetches.lock().await, 1);

        // Making a request within the cached range should not trigger a new fetch
        let result = cache.fetch(1..2).await.unwrap();
        assert_eq!(result.as_ref(), b"b");
        assert_eq!(*cache.inner.num_fetches.lock().await, 1);

        // Making a request that exceeds the cached range should trigger a new fetch
        let result = cache.fetch(2..5).await.unwrap();
        assert_eq!(result.as_ref(), b"cde");
        assert_eq!(*cache.inner.num_fetches.lock().await, 2);

        // The multiplier should be applied: the initial fetch was 2 bytes, the next was 6 (2 * 3),
        // so 8 bytes are now cached and this request needs no new fetch
        let result = cache.fetch(5..8).await.unwrap();
        assert_eq!(result.as_ref(), b"fgh");
        assert_eq!(*cache.inner.num_fetches.lock().await, 2);

        // Should work even for a fetch range larger than the underlying buffer
        let result = cache.fetch(8..20).await.unwrap();
        assert_eq!(result.as_ref(), b"ijklmnopqrst");
        assert_eq!(*cache.inner.num_fetches.lock().await, 3);
    }

    #[test]
    fn test_sequential_block_cache_empty_buffers() {
        let mut cache = SequentialBlockCache::new();
        cache.append_buffer(Bytes::from_static(b"012"));
        cache.append_buffer(Bytes::from_static(b""));
        cache.append_buffer(Bytes::from_static(b"34"));
        cache.append_buffer(Bytes::from_static(b""));
        cache.append_buffer(Bytes::from_static(b"5"));
        cache.append_buffer(Bytes::from_static(b""));
        cache.append_buffer(Bytes::from_static(b"67"));

        // (range, whether it is fully cached, expected slice)
        let test_cases = [
            (0..3, true, Bytes::from_static(b"012")),
            (4..7, true, Bytes::from_static(b"456")),
            (0..8, true, Bytes::from_static(b"01234567")),
            (6..6, true, Bytes::from_static(b"")),
            (6..9, false, Bytes::from_static(b"")),
            (9..9, false, Bytes::from_static(b"")),
            (8..10, false, Bytes::from_static(b"")),
        ];

        for (range, exists, expected) in test_cases {
            assert_eq!(cache.contains(range.clone()), exists);
            if exists {
                assert_eq!(cache.slice(range.clone()), expected);
            }
        }
    }
271}