lazy_cache/
lib.rs

#![deny(unsafe_code)]

use std::{
    cmp::{max, min},
    fs::File,
    io::{self, Read, Seek, SeekFrom, Write},
    ops::Range,
    path::Path,
};

use memmap2::MmapMut;

const EMPTY_RANGE: &[u8] = &[];

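/// A tiered, read-only cache over any `Read + Seek` source.
///
/// Byte ranges are served from the first tier able to satisfy them:
///
/// * `hot_head` / `hot_tail`: bytes eagerly read from the start and the end
///   of the data by [`with_hot_cache`](LazyCache::with_hot_cache)
/// * `warm`: an anonymous memory map filled one block at a time on demand,
///   enabled by [`with_warm_cache`](LazyCache::with_warm_cache)
/// * `cold`: a scratch buffer used for ranges outside both caches, read
///   straight from the source
///
/// Illustrative example, using an in-memory cursor in place of a file:
///
/// ```
/// use std::io::Cursor;
/// use lazy_cache::LazyCache;
///
/// let mut cache = LazyCache::from_read_seek(Cursor::new(b"hello world")).unwrap();
/// assert_eq!(cache.read_exact_range(0..5).unwrap(), b"hello");
/// ```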
pub struct LazyCache<R>
where
    R: Read + Seek,
{
    source: R,
    loaded: Vec<bool>,
    hot_head: Vec<u8>,
    hot_tail: Vec<u8>,
    warm: Option<MmapMut>,
    cold: Vec<u8>,
    block_size: u64,
    warm_size: Option<u64>,
    stream_pos: u64,
    pos_end: u64,
}

const BLOCK_SIZE: usize = 4096;

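// seeking only updates the cache's logical position; the underlying
// source is repositioned lazily, right before data is actually read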
impl<R> Seek for LazyCache<R>
where
    R: Read + Seek,
{
    #[inline(always)]
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        self.stream_pos = self.offset_from_start(pos);
        Ok(self.stream_pos)
    }
}

impl LazyCache<File> {
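    /// Opens the file at `path` and builds a `LazyCache` over it.
    ///
    /// Illustrative usage (the path is a placeholder):
    ///
    /// ```no_run
    /// # fn main() -> std::io::Result<()> {
    /// let mut cache = lazy_cache::LazyCache::open("data.bin")?;
    /// let header = cache.read_exact_range(0..4)?;
    /// # Ok(())
    /// # }
    /// ```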
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
        Self::from_read_seek(File::open(path)?)
    }
}

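// reads go through the cache tiers; like any io::Read implementation,
// a call may return fewer bytes than `buf.len()`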
impl<R> io::Read for LazyCache<R>
where
    R: Read + Seek,
{
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let r = self.inner_read_count(buf.len() as u64)?;
        buf[..r.len()].copy_from_slice(r);
        Ok(r.len())
    }
}

impl<R> LazyCache<R>
where
    R: Read + Seek,
{
    pub fn from_read_seek(mut rs: R) -> Result<Self, io::Error> {
        let block_size = BLOCK_SIZE as u64;
        let pos_end = rs.seek(SeekFrom::End(0))?;
        let cache_cap = pos_end.div_ceil(block_size);

        Ok(Self {
            source: rs,
            hot_head: vec![],
            hot_tail: vec![],
            warm: None,
            cold: vec![0; block_size as usize],
            loaded: vec![false; cache_cap as usize],
            block_size,
            warm_size: None,
            stream_pos: 0,
            pos_end,
        })
    }

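    /// Eagerly loads up to `size` bytes into the hot cache, split evenly
    /// between the head and the tail of the data. If the data is smaller
    /// than `size`, everything is loaded into the head.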
    pub fn with_hot_cache(mut self, size: usize) -> Result<Self, io::Error> {
        let head_tail_size = size / 2;

        self.source.seek(SeekFrom::Start(0))?;

        if self.pos_end > size as u64 {
            self.hot_head = vec![0u8; head_tail_size];
            self.source.read_exact(self.hot_head.as_mut_slice())?;

            self.source.seek(SeekFrom::End(-(head_tail_size as i64)))?;
            self.hot_tail = vec![0u8; head_tail_size];
            self.source.read_exact(self.hot_tail.as_mut_slice())?;
        } else {
            self.hot_head = vec![0u8; self.pos_end as usize];
            self.source.read_exact(self.hot_head.as_mut_slice())?;
        }

        Ok(self)
    }

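    /// Enables the warm cache: an anonymous memory map of roughly
    /// `warm_size` bytes that is filled one block at a time as data
    /// gets read.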
    pub fn with_warm_cache(mut self, mut warm_size: u64) -> Self {
        // the warm cache must hold at least one block, and be a multiple
        // of block_size so that writing a full block at any cached offset
        // never overflows the memory map
        warm_size = max(warm_size, self.block_size).next_multiple_of(self.block_size);
        self.warm_size = Some(warm_size);
        self
    }

    #[inline(always)]
    pub fn offset_from_start(&self, pos: SeekFrom) -> u64 {
        match pos {
            SeekFrom::Start(s) => s,
            SeekFrom::Current(p) => (self.stream_pos as i128 + p as i128) as u64,
            SeekFrom::End(e) => (self.pos_end as i128 + e as i128) as u64,
        }
    }

    #[inline(always)]
    pub fn lazy_stream_position(&self) -> u64 {
        self.stream_pos
    }

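    // lazily allocates the warm memory map; only called from code paths
    // that already checked warm_size is set, hence the unwrap below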
    #[inline(always)]
    fn warm(&mut self) -> Result<&mut MmapMut, io::Error> {
        if self.warm.is_none() && self.warm_size.is_some() {
            self.warm = Some(MmapMut::map_anon(
                self.warm_size.unwrap_or_default() as usize
            )?);
        }
        Ok(self.warm.as_mut().unwrap())
    }

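    // loads every block overlapping `range` that is not cached yet from
    // the source into the warm memory map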
    #[inline(always)]
    fn range_warmup(&mut self, range: Range<u64>) -> Result<(), io::Error> {
        let start_chunk_id = range.start / self.block_size;
        let end_chunk_id = (range.end.saturating_sub(1)) / self.block_size;

        if self.loaded.is_empty() {
            return Ok(());
        }

        for chunk_id in start_chunk_id..=end_chunk_id {
            if self.loaded[chunk_id as usize] {
                continue;
            }

            let offset = chunk_id * self.block_size;
            let buf_size = min(
                self.block_size as usize,
                (self.pos_end.saturating_sub(offset)) as usize,
            );
            let mut buf = vec![0u8; buf_size];
            self.source.seek(SeekFrom::Start(offset))?;
            self.source.read_exact(&mut buf)?;

            (&mut self.warm()?[offset as usize..]).write_all(&buf)?;
            self.loaded[chunk_id as usize] = true;
        }

        Ok(())
    }

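    // serves a byte range from the first tier that can satisfy it:
    // hot head, then hot tail, then the warm memory map (when enabled),
    // falling back to a direct read from the source into the cold buffer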
    #[inline(always)]
    fn get_range_u64(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
        // clamp the range in case we attempt to read beyond the end of the data
        let range = if range.end > self.pos_end {
            range.start..self.pos_end
        } else {
            range
        };

        let range_len = range.end.saturating_sub(range.start);

        if range.start > self.pos_end || range_len == 0 {
            Ok(EMPTY_RANGE)
        } else if range.start < self.hot_head.len() as u64
            && range.end <= self.hot_head.len() as u64
        {
            self.seek(SeekFrom::Start(range.end))?;

            Ok(&self.hot_head[range.start as usize..range.end as usize])
        } else if range.start >= (self.pos_end.saturating_sub(self.hot_tail.len() as u64)) {
            let tail_base = self.pos_end.saturating_sub(self.hot_tail.len() as u64);

            let start = range.start - tail_base;
            let end = range.end - tail_base;

            self.seek(SeekFrom::Start(range.end))?;

            Ok(&self.hot_tail[start as usize..end as usize])
        } else if range.end < self.warm_size.unwrap_or_default() {
            self.range_warmup(range.clone())?;
            self.seek(SeekFrom::Start(range.end))?;

            Ok(&self.warm()?[range.start as usize..range.end as usize])
        } else {
            if range_len > self.cold.len() as u64 {
                self.cold.resize(range_len as usize, 0);
            }

            // the range has already been clamped to the end of the data,
            // so every requested byte exists and read_exact cannot hit EOF
            self.source.seek(SeekFrom::Start(range.start))?;
            self.source.read_exact(&mut self.cold[..range_len as usize])?;
            self.seek(SeekFrom::Start(range.end))?;

            Ok(&self.cold[..range_len as usize])
        }
    }

    pub fn read_range(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
        self.get_range_u64(range)
    }

    #[inline(always)]
    fn inner_read_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
        let pos = self.stream_pos;
        let range = pos..(pos.saturating_add(count));
        self.get_range_u64(range)
    }

    /// Reads `count` bytes at the current reader position and returns them
    /// as a byte slice. Fewer bytes are returned when the end of the data
    /// is reached.
    pub fn read_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
        self.inner_read_count(count)
    }

    pub fn read_exact_range(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
        let range_len = range.end.saturating_sub(range.start);
        let b = self.read_range(range)?;
        if b.len() as u64 != range_len {
            Err(io::Error::from(io::ErrorKind::UnexpectedEof))
        } else {
            Ok(b)
        }
    }

    pub fn read_exact_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
        let b = self.read_count(count)?;
        debug_assert!(b.len() <= count as usize);
        if b.len() as u64 != count {
            Err(io::ErrorKind::UnexpectedEof.into())
        } else {
            Ok(b)
        }
    }

    pub fn read_exact_into(&mut self, buf: &mut [u8]) -> Result<(), io::Error> {
        let read = self.read_exact_count(buf.len() as u64)?;
        // this call cannot panic as read_exact_count guarantees we read
        // exactly buf.len() bytes
        buf.copy_from_slice(read);
        Ok(())
    }

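    /// Reads until any byte in `delims` is found (the delimiter is included
    /// in the returned slice) or `limit` bytes have been read.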
    pub fn read_until_any_delim_or_limit(
        &mut self,
        delims: &[u8],
        limit: u64,
    ) -> Result<&[u8], io::Error> {
        self._read_while_or_limit(|b| !delims.contains(&b), limit, true)
    }

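    /// Reads until `byte` is found (the delimiter is included in the
    /// returned slice) or `limit` bytes have been read.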
    pub fn read_until_or_limit(&mut self, byte: u8, limit: u64) -> Result<&[u8], io::Error> {
        self._read_while_or_limit(|b| b != byte, limit, true)
    }

    // reads while `f` returns true, or until `limit` bytes have been read
    #[inline(always)]
    fn _read_while_or_limit<F>(
        &mut self,
        f: F,
        limit: u64,
        include_last: bool,
    ) -> Result<&[u8], io::Error>
    where
        F: Fn(u8) -> bool,
    {
        let start = self.stream_pos;
        let mut end = 0;

        'outer: while limit.saturating_sub(end) > 0 {
            let buf = self.read_count(self.block_size)?;

            for b in buf {
                if limit.saturating_sub(end) == 0 {
                    break 'outer;
                }

                if !f(*b) {
                    // read_until includes the delimiter in the returned slice
                    if include_last {
                        end += 1;
                    }
                    break 'outer;
                }

                end += 1;
            }

            // we processed the last chunk
            if buf.len() as u64 != self.block_size {
                break;
            }
        }

        self.read_exact_range(start..start + end)
    }

    pub fn read_while_or_limit<F>(&mut self, f: F, limit: u64) -> Result<&[u8], io::Error>
    where
        F: Fn(u8) -> bool,
    {
        self._read_while_or_limit(f, limit, false)
    }

    /// Reads until the utf16 code unit `utf16_char` is found (included in
    /// the result) or `limit` utf16 code units have been read.
    pub fn read_until_utf16_or_limit(
        &mut self,
        utf16_char: &[u8; 2],
        limit: u64,
    ) -> Result<&[u8], io::Error> {
        let start = self.stream_pos;
        let mut end = 0;

        // read an even number of bytes per chunk so code units never
        // straddle a chunk boundary
        let even_bs = if self.block_size.is_multiple_of(2) {
            self.block_size
        } else {
            self.block_size.saturating_add(1)
        };

        'outer: while limit.saturating_sub(end / 2) > 0 {
            let buf = self.read_count(even_bs)?;

            let even = buf
                .iter()
                .enumerate()
                .filter(|(i, _)| i % 2 == 0)
                .map(|t| t.1);

            let odd = buf
                .iter()
                .enumerate()
                .filter(|(i, _)| i % 2 != 0)
                .map(|t| t.1);

            for t in even.zip(odd) {
                // end / 2 is the number of complete code units read so far
                if limit.saturating_sub(end / 2) == 0 {
                    break 'outer;
                }

                end += 2;

                // delimiter check: the matching code unit is included
                if t.0 == &utf16_char[0] && t.1 == &utf16_char[1] {
                    break 'outer;
                }
            }

            // we processed the last chunk
            if buf.len() as u64 != even_bs {
                // we reached the end of the data; include the trailing
                // byte missed by zip if the length was odd
                if buf.len() % 2 != 0 {
                    end += 1
                }
                break;
            }
        }

        self.read_exact_range(start..start + end)
    }

    /// Returns the size of the whole data. If the `LazyCache` is
    /// used with a [`std::fs::File`] reader, this method will return
    /// the file size.
    #[inline(always)]
    pub fn data_size(&self) -> u64 {
        self.pos_end
    }
}

#[cfg(test)]
mod tests {
    use std::os::unix::fs::MetadataExt;

    use super::*;

    macro_rules! lazy_cache {
        ($content: literal) => {
            LazyCache::from_read_seek(std::io::Cursor::new($content)).unwrap()
        };
    }

    /// reads the `io::Read` reader `r` in chunks of size `cs` until the end
    macro_rules! read_to_end {
        ($r: expr, $cs: literal) => {{
            let mut buf = [0u8; $cs];
            let mut out: Vec<u8> = vec![];
            while let Ok(n) = $r.read(&mut buf[..]) {
                if n == 0 {
                    break;
                }
                out.extend(&buf[..n]);
            }
            out
        }};
    }

    #[test]
    fn test_get_single_block() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"hell");
    }

    #[test]
    fn test_get_across_blocks() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(2..7).unwrap();
        assert_eq!(data, b"llo w");
    }

    #[test]
    fn test_get_entire_file() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..11).unwrap();
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_get_empty_range() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..0).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_get_out_of_bounds() {
        let mut cache = lazy_cache!(b"hello world");
        // a range entirely past the end of the data does not panic or
        // error, it simply returns an empty slice
        assert!(cache.read_range(20..30).unwrap().is_empty());
    }

    #[test]
    fn test_cache_eviction() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // read ranges spanning several blocks
        let _ = cache.read_range(0..8).unwrap();
        let _ = cache.read_range(8..12).unwrap();
        // re-reading a range must return the same data
        let data = cache.read_range(8..12).unwrap();
        assert_eq!(data, b"89ab");
    }

    #[test]
    fn test_chunk_consolidation() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // read adjacent ranges separately
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(4..8).unwrap();
        let _ = cache.read_range(8..12).unwrap();
        // then a range overlapping the first two
        let _ = cache.read_range(2..6).unwrap();
        // a range spanning all previously read chunks is still correct
        let data = cache.read_range(0..8).unwrap();
        assert_eq!(data, b"01234567");
    }

    #[test]
    fn test_overlapping_ranges() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // Load overlapping ranges
        let _ = cache.read_range(2..6).unwrap();
        let _ = cache.read_range(4..10).unwrap();
        // Check that the data is correct
        let data = cache.read_range(2..10).unwrap();
        assert_eq!(data, b"23456789");
    }

    #[test]
    fn test_lru_behavior() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // read several ranges in sequence
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(4..8).unwrap();
        let _ = cache.read_range(8..12).unwrap();
        // going back to the first range must still return the right bytes
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"0123");
    }

    #[test]
    fn test_small_block_size() {
        let mut cache = lazy_cache!(b"abc");
        let data = cache.read_range(0..3).unwrap();
        assert_eq!(data, b"abc");
    }

    #[test]
    fn test_large_block_size() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..11).unwrap();
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_file_smaller_than_block() {
        let mut cache = lazy_cache!(b"abc");
        let data = cache.read_range(0..3).unwrap();
        assert_eq!(data, b"abc");
    }

    #[test]
    fn test_multiple_gets_same_block() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        // get the same range multiple times
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(0..4).unwrap();
        // repeated reads must keep returning the same data
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"0123");
    }

    #[test]
    fn test_read_method() {
        let mut cache = lazy_cache!(b"hello world");
        let _ = cache.read_count(6).unwrap();
        let data = cache.read_count(5).unwrap();
        assert_eq!(data, b"world");
        // we reached the end, so the next read returns an empty slice
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_empty() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_count(0).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_read_beyond_end() {
        let mut cache = lazy_cache!(b"hello world");
        let _ = cache.read_count(11).unwrap();
        let data = cache.read_count(5).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_read_exact_range() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_exact_range(0..5).unwrap();
        assert_eq!(data, b"hello");
        assert_eq!(cache.read_exact_range(5..11).unwrap(), b" world");
        assert!(cache.read_exact_range(12..13).is_err());
    }

    #[test]
    fn test_read_exact_range_error() {
        let mut cache = lazy_cache!(b"hello world");
        let result = cache.read_exact_range(0..20);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_exact() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_exact_count(5).unwrap();
        assert_eq!(data, b"hello");
        assert_eq!(cache.read_exact_count(6).unwrap(), b" world");
        assert!(cache.read_exact_count(0).is_ok());
        assert!(cache.read_exact_count(1).is_err());
    }

    #[test]
    fn test_read_exact_error() {
        let mut cache = lazy_cache!(b"hello world");
        let result = cache.read_exact_count(20);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_until_limit() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b' ', 10).unwrap();
        assert_eq!(data, b"hello ");
        assert_eq!(cache.read_exact_count(5).unwrap(), b"world");
    }

    #[test]
    fn test_read_until_limit_not_found() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b'\n', 11).unwrap();
        assert_eq!(data, b"hello world");
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_until_limit_beyond_stream() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b'\n', 42).unwrap();
        assert_eq!(data, b"hello world");
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_until_limit_with_limit() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b' ', 42).unwrap();
        assert_eq!(data, b"hello ");

        let data = cache.read_until_or_limit(b' ', 2).unwrap();
        assert_eq!(data, b"wo");

        let data = cache.read_until_or_limit(b' ', 42).unwrap();
        assert_eq!(data, b"rld");
    }

    #[test]
    fn test_read_until_utf16_limit() {
        let mut cache = lazy_cache!(
            b"\x61\x00\x62\x00\x63\x00\x64\x00\x00\x00\x61\x00\x62\x00\x63\x00\x64\x00\x00"
        );
        let data = cache.read_until_utf16_or_limit(b"\x00\x00", 512).unwrap();
        assert_eq!(data, b"\x61\x00\x62\x00\x63\x00\x64\x00\x00\x00");

        let data = cache.read_until_utf16_or_limit(b"\x00\x00", 1).unwrap();
        assert_eq!(data, b"\x61\x00");

        assert_eq!(
            cache.read_until_utf16_or_limit(b"\xff\xff", 64).unwrap(),
            b"\x62\x00\x63\x00\x64\x00\x00"
        );
    }

    #[test]
    fn test_io_read() {
        let p = "./src/lib.rs";
        let mut f = File::open(p).unwrap();
        let mut lr = LazyCache::from_read_seek(File::open(p).unwrap())
            .unwrap()
            .with_hot_cache(512)
            .unwrap()
            .with_warm_cache(1024);

        let fb = read_to_end!(f, 32);
        let lcb = read_to_end!(lr, 16);

        assert_eq!(lcb, fb);
    }

    #[test]
    fn test_data_size() {
        let f = File::open("./src/lib.rs").unwrap();
        let size = f.metadata().unwrap().size();

        let c = LazyCache::from_read_seek(f).unwrap();
        assert_eq!(size, c.data_size());

        assert_eq!(
            LazyCache::from_read_seek(io::Cursor::new(&[]))
                .unwrap()
                .data_size(),
            0
        );
    }
}