// lazy_cache/lib.rs

1#![deny(unsafe_code)]
2
3use std::{
4    cmp::min,
5    fs::File,
6    io::{self, Read, Seek, SeekFrom, Write},
7    ops::Range,
8    path::Path,
9};
10
11use memmap2::MmapMut;
12
13const EMPTY_RANGE: &[u8] = &[];
14
/// A read-through cache over any `Read + Seek` source.
///
/// Reads are served, in order of preference, from a "hot" head/tail
/// cache, a "warm" anonymous mmap populated one block at a time, or a
/// "cold" scratch buffer refilled from the source on each miss.
pub struct LazyCache<R>
where
    R: Read + Seek,
{
    // Underlying data source; only consulted on cache misses.
    source: R,
    // One flag per `block_size` chunk: true once copied into `warm`.
    loaded: Vec<bool>,
    // First bytes of the source, filled by `with_hot_cache`.
    hot_head: Vec<u8>,
    // Bytes near the end of the source, filled by `with_hot_cache`.
    hot_tail: Vec<u8>,
    // Lazily created anonymous mapping of capacity `warm_size`.
    warm: Option<MmapMut>,
    // Scratch buffer for reads not covered by hot/warm caches.
    cold: Vec<u8>,
    // Chunk granularity for warming, in bytes (see `BLOCK_SIZE`).
    block_size: u64,
    // Warm-cache capacity; `None` disables the warm cache.
    warm_size: Option<u64>,
    // Logical stream position (independent of `source`'s position).
    stream_pos: u64,
    // Total length of the source, measured once at construction.
    pos_end: u64,
}
30
31const BLOCK_SIZE: usize = 4096;
32
impl<R> Seek for LazyCache<R>
where
    R: Read + Seek,
{
    #[inline(always)]
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // Purely logical: records the new position without touching the
        // underlying source (it is repositioned on the next actual read).
        // NOTE(review): unlike std `Seek` semantics, a seek resolving to a
        // negative offset is not rejected — `offset_from_start` wraps it
        // to a huge u64 via `as u64` instead of erroring; confirm callers
        // never seek before position 0.
        self.stream_pos = self.offset_from_start(pos);
        Ok(self.stream_pos)
    }
}
43
44impl LazyCache<File> {
45    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
46        Self::from_read_seek(File::open(path)?)
47    }
48}
49
50impl<R> io::Read for LazyCache<R>
51where
52    R: Read + Seek,
53{
54    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
55        let r = self.inner_read_count(buf.len() as u64)?;
56        for (i, b) in r.iter().enumerate() {
57            buf[i] = *b;
58        }
59        Ok(r.len())
60    }
61}
62
63impl<R> LazyCache<R>
64where
65    R: Read + Seek,
66{
67    pub fn from_read_seek(mut rs: R) -> Result<Self, io::Error> {
68        let block_size = BLOCK_SIZE as u64;
69        let pos_end = rs.seek(SeekFrom::End(0))?;
70        let cache_cap = pos_end.div_ceil(BLOCK_SIZE as u64);
71
72        Ok(Self {
73            source: rs,
74            hot_head: vec![],
75            hot_tail: vec![],
76            warm: None,
77            cold: vec![0; block_size as usize],
78            loaded: vec![false; cache_cap as usize],
79            block_size,
80            warm_size: None,
81            stream_pos: 0,
82            pos_end,
83        })
84    }
85
86    pub fn with_hot_cache(mut self, size: usize) -> Result<Self, io::Error> {
87        let head_tail_size = size / 2;
88
89        self.source.seek(SeekFrom::Start(0))?;
90
91        if self.pos_end > size as u64 {
92            self.hot_head = vec![0u8; head_tail_size];
93            self.source.read_exact(self.hot_head.as_mut_slice())?;
94
95            self.source.seek(SeekFrom::End(-(size as i64)))?;
96            self.hot_tail = vec![0u8; head_tail_size];
97            self.source.read_exact(self.hot_tail.as_mut_slice())?;
98        } else {
99            self.hot_head = vec![0u8; self.pos_end as usize];
100            self.source.read_exact(self.hot_head.as_mut())?;
101        }
102
103        Ok(self)
104    }
105
    /// Enable the warm cache: up to `warm_size` bytes of the source may
    /// be mirrored into an anonymous mmap, one block at a time, on demand.
    /// The mapping itself is created lazily on first use (see `warm`).
    pub fn with_warm_cache(mut self, warm_size: u64) -> Self {
        self.warm_size = Some(warm_size);
        self
    }
110
    /// Resolve a `SeekFrom` into an absolute offset from the start of the
    /// source, using the cached logical position and total length.
    ///
    /// NOTE(review): a `Current`/`End` offset resolving to a negative
    /// position wraps to a very large u64 via the `as u64` cast instead
    /// of reporting an error (std `Seek` would return `InvalidInput`);
    /// confirm callers never seek before position 0.
    #[inline(always)]
    pub fn offset_from_start(&self, pos: SeekFrom) -> u64 {
        match pos {
            SeekFrom::Start(s) => s,
            // i128 arithmetic cannot overflow while adding signed deltas.
            SeekFrom::Current(p) => (self.stream_pos as i128 + p as i128) as u64,
            SeekFrom::End(e) => (self.pos_end as i128 + e as i128) as u64,
        }
    }
119
    /// Current logical stream position, without touching the underlying
    /// source (the lazy analogue of `Seek::stream_position`).
    #[inline(always)]
    pub fn lazy_stream_position(&self) -> u64 {
        self.stream_pos
    }
124
125    #[inline(always)]
126    fn warm(&mut self) -> Result<&mut MmapMut, io::Error> {
127        if self.warm.is_none() && self.warm_size.is_some() {
128            self.warm = Some(MmapMut::map_anon(
129                self.warm_size.unwrap_or_default() as usize
130            )?);
131        }
132        Ok(self.warm.as_mut().unwrap())
133    }
134
135    #[inline(always)]
136    fn range_warmup(&mut self, range: Range<u64>) -> Result<(), io::Error> {
137        let start_chunk_id = range.start / self.block_size;
138        let end_chunk_id = (range.end.saturating_sub(1)) / self.block_size;
139
140        if self.loaded.is_empty() {
141            return Ok(());
142        }
143
144        for chunk_id in start_chunk_id..=end_chunk_id {
145            if self.loaded[chunk_id as usize] {
146                continue;
147            }
148
149            let offset = chunk_id * self.block_size;
150            let buf_size = min(
151                self.block_size as usize,
152                (self.pos_end.saturating_sub(offset)) as usize,
153            );
154            let mut buf = vec![0u8; buf_size];
155            self.source.seek(SeekFrom::Start(offset))?;
156            self.source.read_exact(&mut buf)?;
157
158            (&mut self.warm()?[offset as usize..]).write_all(&buf)?;
159            self.loaded[chunk_id as usize] = true;
160        }
161
162        Ok(())
163    }
164
165    #[inline(always)]
166    fn get_range_u64(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
167        // we fix range in case we attempt at reading beyond end of file
168        let range = if range.end > self.pos_end {
169            range.start..self.pos_end
170        } else {
171            range
172        };
173
174        let range_len = range.end.saturating_sub(range.start);
175
176        if range.start > self.pos_end || range_len == 0 {
177            return Ok(EMPTY_RANGE);
178        } else if range.start < self.hot_head.len() as u64
179            && range.end <= self.hot_head.len() as u64
180        {
181            self.seek(SeekFrom::Start(range.end))?;
182            return Ok(&self.hot_head[range.start as usize..range.end as usize]);
183        } else if range.start > (self.pos_end - self.hot_tail.len() as u64) {
184            let start_from_end = self.pos_end.saturating_sub(1).saturating_sub(range.start);
185            self.seek(SeekFrom::Start(range.end))?;
186            return Ok(&self.hot_tail
187                [start_from_end as usize..start_from_end.saturating_add(range_len) as usize]);
188        } else if range.end < self.warm_size.unwrap_or_default() {
189            self.range_warmup(range.clone())?;
190            self.seek(SeekFrom::Start(range.end))?;
191            return Ok(&self.warm()?[range.start as usize..range.end as usize]);
192        } else if range_len > self.cold.len() as u64 {
193            self.cold.resize(range_len as usize, 0);
194        }
195
196        self.source.seek(SeekFrom::Start(range.start))?;
197        let n = self.source.read(self.cold[..range_len as usize].as_mut())?;
198        self.seek(SeekFrom::Start(range.end))?;
199        Ok(&self.cold[..n])
200    }
201
202    pub fn read_range(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
203        let range = range.start..range.end;
204        self.get_range_u64(range)
205    }
206
207    #[inline(always)]
208    fn inner_read_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
209        let pos = self.stream_pos;
210        let range = pos..(pos.saturating_add(count));
211        self.get_range_u64(range)
212    }
213
    /// Read at current reader position and return byte slice
    /// (may be shorter than `count` when the source ends first).
    pub fn read_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
        self.inner_read_count(count)
    }
218
219    pub fn read_exact_range(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
220        let range_len = range.end - range.start;
221        let b = self.read_range(range)?;
222        if b.len() as u64 != range_len {
223            Err(io::Error::from(io::ErrorKind::UnexpectedEof))
224        } else {
225            Ok(b)
226        }
227    }
228
229    pub fn read_exact_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
230        let b = self.read_count(count)?;
231        debug_assert!(b.len() <= count as usize);
232        if b.len() as u64 != count {
233            Err(io::ErrorKind::UnexpectedEof.into())
234        } else {
235            Ok(b)
236        }
237    }
238
239    pub fn read_exact_into(&mut self, buf: &mut [u8]) -> Result<(), io::Error> {
240        let read = self.read_exact_count(buf.len() as u64)?;
241        // this function call should not panic as read_exact
242        // guarantees we read exactly the length of buf
243        buf.copy_from_slice(read);
244        Ok(())
245    }
246
    /// Read until any byte in `delims` is encountered (the delimiter is
    /// included in the returned slice) or `limit` bytes have been taken.
    pub fn read_until_any_delim_or_limit(
        &mut self,
        delims: &[u8],
        limit: u64,
    ) -> Result<&[u8], io::Error> {
        self._read_while_or_limit(|b| !delims.contains(&b), limit, true)
    }
254
    /// Read until `byte` is encountered (included in the returned slice)
    /// or `limit` bytes have been taken.
    pub fn read_until_or_limit(&mut self, byte: u8, limit: u64) -> Result<&[u8], io::Error> {
        self._read_while_or_limit(|b| b != byte, limit, true)
    }
258
    // reads while f returns true or we reach limit
    //
    // Scans forward in `block_size` chunks from the current position.
    // `end` counts accepted bytes; when `f` rejects a byte and
    // `include_last` is set, the rejected byte is counted too (delimiter
    // semantics). The final `read_exact_range` re-reads the accepted span
    // and leaves the stream position just past it.
    #[inline(always)]
    fn _read_while_or_limit<F>(
        &mut self,
        f: F,
        limit: u64,
        include_last: bool,
    ) -> Result<&[u8], io::Error>
    where
        F: Fn(u8) -> bool,
    {
        let start = self.stream_pos;
        let mut end = 0;

        // `end` never exceeds `limit` (checked before each increment),
        // so `limit - end` cannot underflow.
        'outer: while limit - end > 0 {
            // Advances the stream position; repositioned by the final
            // read_exact_range below.
            let buf = self.read_count(self.block_size)?;

            for b in buf {
                if limit - end == 0 {
                    break 'outer;
                }

                if !f(*b) {
                    if include_last {
                        end += 1;
                    }
                    // read_until includes delimiter
                    break 'outer;
                }

                end += 1;
            }

            // we processed last chunk (short read means end of source)
            if buf.len() as u64 != self.block_size {
                break;
            }
        }

        self.read_exact_range(start..start + end)
    }
300
    /// Read while `f` returns true for each byte, stopping at the first
    /// rejected byte (excluded from the result) or after `limit` bytes.
    pub fn read_while_or_limit<F>(&mut self, f: F, limit: u64) -> Result<&[u8], io::Error>
    where
        F: Fn(u8) -> bool,
    {
        self._read_while_or_limit(f, limit, false)
    }
307
    // limit is expressed in numbers of utf16 chars
    //
    // NOTE(review): `end` below counts BYTES (it advances by 2 per code
    // unit) yet is compared directly against `limit`, so the scan stops
    // after roughly `limit` bytes rather than `limit` UTF-16 chars —
    // confirm which unit callers rely on before changing either side.
    //
    // Scans byte pairs for `utf16_char`, including the matching pair in
    // the returned slice. Pairs are formed by zipping the even- and
    // odd-indexed bytes of each chunk; an odd trailing byte at EOF is
    // appended unpaired.
    pub fn read_until_utf16_or_limit(
        &mut self,
        utf16_char: &[u8; 2],
        limit: u64,
    ) -> Result<&[u8], io::Error> {
        let start = self.stream_pos;
        let mut end = 0;

        // Round the chunk size up to an even number so pairs never
        // straddle a chunk boundary.
        let even_bs = if self.block_size.is_multiple_of(2) {
            self.block_size
        } else {
            self.block_size.saturating_add(1)
        };

        'outer: while limit.saturating_sub(end) > 0 {
            let buf = self.read_count(even_bs)?;

            // Bytes at even offsets (first byte of each pair).
            let even = buf
                .iter()
                .enumerate()
                .filter(|(i, _)| i % 2 == 0)
                .map(|t| t.1);

            // Bytes at odd offsets (second byte of each pair).
            let odd = buf
                .iter()
                .enumerate()
                .filter(|(i, _)| i % 2 != 0)
                .map(|t| t.1);

            for t in even.zip(odd) {
                if limit.saturating_sub(end) == 0 {
                    break 'outer;
                }

                end += 2;

                // tail check
                if t.0 == &utf16_char[0] && t.1 == &utf16_char[1] {
                    // we include char
                    break 'outer;
                }
            }

            // we processed the last chunk
            if buf.len() as u64 != even_bs {
                // if we arrive here we reached end of file
                if buf.len() % 2 != 0 {
                    // we include last byte missed by zip
                    end += 1
                }
                break;
            }
        }

        self.read_exact_range(start..start + end)
    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a LazyCache over an in-memory cursor holding `$content`.
    macro_rules! lazy_cache {
        ($content: literal) => {
            LazyCache::from_read_seek(std::io::Cursor::new($content)).unwrap()
        };
    }

    #[test]
    fn test_get_single_block() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"hell");
    }

    #[test]
    fn test_get_across_blocks() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(2..7).unwrap();
        assert_eq!(data, b"llo w");
    }

    #[test]
    fn test_get_entire_file() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..11).unwrap();
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_get_empty_range() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..0).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_get_out_of_bounds() {
        let mut cache = lazy_cache!(b"hello world");
        // A range entirely past the end of the source must not panic:
        // it is clamped and served as an empty slice.
        assert!(cache.read_range(20..30).unwrap().is_empty());
    }

    #[test]
    fn test_cache_eviction() {
        // NOTE(review): despite the name, this implementation has no
        // eviction; this test only checks that sequential and repeated
        // range reads stay correct.
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // Read the first two 4-byte regions.
        let _ = cache.read_range(0..8).unwrap();
        // Read a third region.
        let _ = cache.read_range(8..12).unwrap();
        // Re-reading it must return the same bytes.
        let data = cache.read_range(8..12).unwrap();
        assert_eq!(data, b"89ab");
    }

    #[test]
    fn test_chunk_consolidation() {
        // NOTE(review): there is no chunk consolidation in this
        // implementation; this test exercises adjacent and overlapping
        // range reads and checks the results stay correct.
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // Read two adjacent regions separately.
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(4..8).unwrap();
        // Read a third, disjoint region.
        let _ = cache.read_range(8..12).unwrap();
        // Read a range straddling the first two regions.
        let _ = cache.read_range(2..6).unwrap();
        // A read spanning everything touched so far must be correct.
        let data = cache.read_range(0..8).unwrap();
        assert_eq!(data, b"01234567");
    }

    #[test]
    fn test_overlapping_ranges() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // Load overlapping ranges
        let _ = cache.read_range(2..6).unwrap();
        let _ = cache.read_range(4..10).unwrap();
        // Check that the data is correct
        let data = cache.read_range(2..10).unwrap();
        assert_eq!(data, b"23456789");
    }

    #[test]
    fn test_lru_behavior() {
        // NOTE(review): there is no LRU policy in this implementation;
        // this test only checks that revisiting an earlier range after
        // later reads still returns the right bytes.
        let mut cache = lazy_cache!(b"0123456789abcdef");
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(4..8).unwrap();
        let _ = cache.read_range(8..12).unwrap();
        // Going back to the first range must still work.
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"0123");
    }

    #[test]
    fn test_small_block_size() {
        let mut cache = lazy_cache!(b"abc");
        let data = cache.read_range(0..3).unwrap();
        assert_eq!(data, b"abc");
    }

    #[test]
    fn test_large_block_size() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..11).unwrap();
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_file_smaller_than_block() {
        let mut cache = lazy_cache!(b"abc");
        let data = cache.read_range(0..3).unwrap();
        assert_eq!(data, b"abc");
    }

    #[test]
    fn test_multiple_gets_same_block() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // Get the same range multiple times
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(0..4).unwrap();
        // The result must remain stable
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"0123");
    }

    #[test]
    fn test_read_method() {
        let mut cache = lazy_cache!(b"hello world");
        let _ = cache.read_count(6).unwrap();
        let data = cache.read_count(5).unwrap();
        assert_eq!(data, b"world");
        // We reached the end so next read should bring an empty slice
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_empty() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_count(0).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_read_beyond_end() {
        let mut cache = lazy_cache!(b"hello world");
        let _ = cache.read_count(11).unwrap();
        let data = cache.read_count(5).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_read_exact_range() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_exact_range(0..5).unwrap();
        assert_eq!(data, b"hello");
        assert_eq!(cache.read_exact_range(5..11).unwrap(), b" world");
        assert!(cache.read_exact_range(12..13).is_err());
    }

    #[test]
    fn test_read_exact_range_error() {
        let mut cache = lazy_cache!(b"hello world");
        let result = cache.read_exact_range(0..20);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_exact() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_exact_count(5).unwrap();
        assert_eq!(data, b"hello");
        assert_eq!(cache.read_exact_count(6).unwrap(), b" world");
        assert!(cache.read_exact_count(0).is_ok());
        assert!(cache.read_exact_count(1).is_err());
    }

    #[test]
    fn test_read_exact_error() {
        let mut cache = lazy_cache!(b"hello world");
        let result = cache.read_exact_count(20);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_until_limit() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b' ', 10).unwrap();
        assert_eq!(data, b"hello ");
        assert_eq!(cache.read_exact_count(5).unwrap(), b"world");
    }

    #[test]
    fn test_read_until_limit_not_found() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b'\n', 11).unwrap();
        assert_eq!(data, b"hello world");
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_until_limit_beyond_stream() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b'\n', 42).unwrap();
        assert_eq!(data, b"hello world");
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_until_limit_with_limit() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b' ', 42).unwrap();
        assert_eq!(data, b"hello ");

        let data = cache.read_until_or_limit(b' ', 2).unwrap();
        assert_eq!(data, b"wo");

        let data = cache.read_until_or_limit(b' ', 42).unwrap();
        assert_eq!(data, b"rld");
    }

    #[test]
    fn test_read_until_utf16_limit() {
        let mut cache = lazy_cache!(
            b"\x61\x00\x62\x00\x63\x00\x64\x00\x00\x00\x61\x00\x62\x00\x63\x00\x64\x00\x00"
        );
        let data = cache.read_until_utf16_or_limit(b"\x00\x00", 512).unwrap();
        assert_eq!(data, b"\x61\x00\x62\x00\x63\x00\x64\x00\x00\x00");

        let data = cache.read_until_utf16_or_limit(b"\x00\x00", 1).unwrap();
        assert_eq!(data, b"\x61\x00");

        assert_eq!(
            cache.read_until_utf16_or_limit(b"\xff\xff", 64).unwrap(),
            b"\x62\x00\x63\x00\x64\x00\x00"
        );
    }
}