1#![deny(unsafe_code)]
2
3use std::{
4 cmp::min,
5 fs::File,
6 io::{self, Read, Seek, SeekFrom, Write},
7 ops::Range,
8 path::Path,
9};
10
11use memmap2::MmapMut;
12
13const EMPTY_RANGE: &[u8] = &[];
14
15pub struct LazyCache<R>
16where
17 R: Read + Seek,
18{
19 source: R,
20 loaded: Vec<bool>,
21 hot_head: Vec<u8>,
22 hot_tail: Vec<u8>,
23 warm: Option<MmapMut>,
24 cold: Vec<u8>,
25 block_size: u64,
26 warm_size: Option<u64>,
27 stream_pos: u64,
28 pos_end: u64,
29}
30
31const BLOCK_SIZE: usize = 4096;
32
33impl<R> Seek for LazyCache<R>
34where
35 R: Read + Seek,
36{
37 #[inline(always)]
38 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
39 self.stream_pos = self.offset_from_start(pos);
40 Ok(self.stream_pos)
41 }
42}
43
44impl LazyCache<File> {
45 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
46 Self::from_read_seek(File::open(path)?)
47 }
48}
49
50impl<R> io::Read for LazyCache<R>
51where
52 R: Read + Seek,
53{
54 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
55 let r = self.inner_read_count(buf.len() as u64)?;
56 for (i, b) in r.iter().enumerate() {
57 buf[i] = *b;
58 }
59 Ok(r.len())
60 }
61}
62
63impl<R> LazyCache<R>
64where
65 R: Read + Seek,
66{
67 pub fn from_read_seek(mut rs: R) -> Result<Self, io::Error> {
68 let block_size = BLOCK_SIZE as u64;
69 let pos_end = rs.seek(SeekFrom::End(0))?;
70 let cache_cap = pos_end.div_ceil(BLOCK_SIZE as u64);
71
72 Ok(Self {
73 source: rs,
74 hot_head: vec![],
75 hot_tail: vec![],
76 warm: None,
77 cold: vec![0; block_size as usize],
78 loaded: vec![false; cache_cap as usize],
79 block_size,
80 warm_size: None,
81 stream_pos: 0,
82 pos_end,
83 })
84 }
85
86 pub fn with_hot_cache(mut self, size: usize) -> Result<Self, io::Error> {
87 let head_tail_size = size / 2;
88
89 self.source.seek(SeekFrom::Start(0))?;
90
91 if self.pos_end > size as u64 {
92 self.hot_head = vec![0u8; head_tail_size];
93 self.source.read_exact(self.hot_head.as_mut_slice())?;
94
95 self.source.seek(SeekFrom::End(-(size as i64)))?;
96 self.hot_tail = vec![0u8; head_tail_size];
97 self.source.read_exact(self.hot_tail.as_mut_slice())?;
98 } else {
99 self.hot_head = vec![0u8; self.pos_end as usize];
100 self.source.read_exact(self.hot_head.as_mut())?;
101 }
102
103 Ok(self)
104 }
105
106 pub fn with_warm_cache(mut self, warm_size: u64) -> Self {
107 self.warm_size = Some(warm_size);
108 self
109 }
110
111 #[inline(always)]
112 pub fn offset_from_start(&self, pos: SeekFrom) -> u64 {
113 match pos {
114 SeekFrom::Start(s) => s,
115 SeekFrom::Current(p) => (self.stream_pos as i128 + p as i128) as u64,
116 SeekFrom::End(e) => (self.pos_end as i128 + e as i128) as u64,
117 }
118 }
119
120 #[inline(always)]
121 pub fn lazy_stream_position(&self) -> u64 {
122 self.stream_pos
123 }
124
125 #[inline(always)]
126 fn warm(&mut self) -> Result<&mut MmapMut, io::Error> {
127 if self.warm.is_none() && self.warm_size.is_some() {
128 self.warm = Some(MmapMut::map_anon(
129 self.warm_size.unwrap_or_default() as usize
130 )?);
131 }
132 Ok(self.warm.as_mut().unwrap())
133 }
134
135 #[inline(always)]
136 fn range_warmup(&mut self, range: Range<u64>) -> Result<(), io::Error> {
137 let start_chunk_id = range.start / self.block_size;
138 let end_chunk_id = (range.end.saturating_sub(1)) / self.block_size;
139
140 if self.loaded.is_empty() {
141 return Ok(());
142 }
143
144 for chunk_id in start_chunk_id..=end_chunk_id {
145 if self.loaded[chunk_id as usize] {
146 continue;
147 }
148
149 let offset = chunk_id * self.block_size;
150 let buf_size = min(
151 self.block_size as usize,
152 (self.pos_end.saturating_sub(offset)) as usize,
153 );
154 let mut buf = vec![0u8; buf_size];
155 self.source.seek(SeekFrom::Start(offset))?;
156 self.source.read_exact(&mut buf)?;
157
158 (&mut self.warm()?[offset as usize..]).write_all(&buf)?;
159 self.loaded[chunk_id as usize] = true;
160 }
161
162 Ok(())
163 }
164
165 #[inline(always)]
166 fn get_range_u64(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
167 let range = if range.end > self.pos_end {
169 range.start..self.pos_end
170 } else {
171 range
172 };
173
174 let range_len = range.end.saturating_sub(range.start);
175
176 if range.start > self.pos_end || range_len == 0 {
177 return Ok(EMPTY_RANGE);
178 } else if range.start < self.hot_head.len() as u64
179 && range.end <= self.hot_head.len() as u64
180 {
181 self.seek(SeekFrom::Start(range.end))?;
182 return Ok(&self.hot_head[range.start as usize..range.end as usize]);
183 } else if range.start > (self.pos_end - self.hot_tail.len() as u64) {
184 let start_from_end = self.pos_end.saturating_sub(1).saturating_sub(range.start);
185 self.seek(SeekFrom::Start(range.end))?;
186 return Ok(&self.hot_tail
187 [start_from_end as usize..start_from_end.saturating_add(range_len) as usize]);
188 } else if range.end < self.warm_size.unwrap_or_default() {
189 self.range_warmup(range.clone())?;
190 self.seek(SeekFrom::Start(range.end))?;
191 return Ok(&self.warm()?[range.start as usize..range.end as usize]);
192 } else if range_len > self.cold.len() as u64 {
193 self.cold.resize(range_len as usize, 0);
194 }
195
196 self.source.seek(SeekFrom::Start(range.start))?;
197 let n = self.source.read(self.cold[..range_len as usize].as_mut())?;
198 self.seek(SeekFrom::Start(range.end))?;
199 Ok(&self.cold[..n])
200 }
201
202 pub fn read_range(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
203 let range = range.start..range.end;
204 self.get_range_u64(range)
205 }
206
207 #[inline(always)]
208 fn inner_read_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
209 let pos = self.stream_pos;
210 let range = pos..(pos.saturating_add(count));
211 self.get_range_u64(range)
212 }
213
214 pub fn read_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
216 self.inner_read_count(count)
217 }
218
219 pub fn read_exact_range(&mut self, range: Range<u64>) -> Result<&[u8], io::Error> {
220 let range_len = range.end - range.start;
221 let b = self.read_range(range)?;
222 if b.len() as u64 != range_len {
223 Err(io::Error::from(io::ErrorKind::UnexpectedEof))
224 } else {
225 Ok(b)
226 }
227 }
228
229 pub fn read_exact_count(&mut self, count: u64) -> Result<&[u8], io::Error> {
230 let b = self.read_count(count)?;
231 debug_assert!(b.len() <= count as usize);
232 if b.len() as u64 != count {
233 Err(io::ErrorKind::UnexpectedEof.into())
234 } else {
235 Ok(b)
236 }
237 }
238
239 pub fn read_exact_into(&mut self, buf: &mut [u8]) -> Result<(), io::Error> {
240 let read = self.read_exact_count(buf.len() as u64)?;
241 buf.copy_from_slice(read);
244 Ok(())
245 }
246
247 pub fn read_until_any_delim_or_limit(
248 &mut self,
249 delims: &[u8],
250 limit: u64,
251 ) -> Result<&[u8], io::Error> {
252 self._read_while_or_limit(|b| !delims.contains(&b), limit, true)
253 }
254
255 pub fn read_until_or_limit(&mut self, byte: u8, limit: u64) -> Result<&[u8], io::Error> {
256 self._read_while_or_limit(|b| b != byte, limit, true)
257 }
258
259 #[inline(always)]
261 fn _read_while_or_limit<F>(
262 &mut self,
263 f: F,
264 limit: u64,
265 include_last: bool,
266 ) -> Result<&[u8], io::Error>
267 where
268 F: Fn(u8) -> bool,
269 {
270 let start = self.stream_pos;
271 let mut end = 0;
272
273 'outer: while limit - end > 0 {
274 let buf = self.read_count(self.block_size)?;
275
276 for b in buf {
277 if limit - end == 0 {
278 break 'outer;
279 }
280
281 if !f(*b) {
282 if include_last {
283 end += 1;
284 }
285 break 'outer;
287 }
288
289 end += 1;
290 }
291
292 if buf.len() as u64 != self.block_size {
294 break;
295 }
296 }
297
298 self.read_exact_range(start..start + end)
299 }
300
301 pub fn read_while_or_limit<F>(&mut self, f: F, limit: u64) -> Result<&[u8], io::Error>
302 where
303 F: Fn(u8) -> bool,
304 {
305 self._read_while_or_limit(f, limit, false)
306 }
307
308 pub fn read_until_utf16_or_limit(
310 &mut self,
311 utf16_char: &[u8; 2],
312 limit: u64,
313 ) -> Result<&[u8], io::Error> {
314 let start = self.stream_pos;
315 let mut end = 0;
316
317 let even_bs = if self.block_size.is_multiple_of(2) {
318 self.block_size
319 } else {
320 self.block_size.saturating_add(1)
321 };
322
323 'outer: while limit.saturating_sub(end) > 0 {
324 let buf = self.read_count(even_bs)?;
325
326 let even = buf
327 .iter()
328 .enumerate()
329 .filter(|(i, _)| i % 2 == 0)
330 .map(|t| t.1);
331
332 let odd = buf
333 .iter()
334 .enumerate()
335 .filter(|(i, _)| i % 2 != 0)
336 .map(|t| t.1);
337
338 for t in even.zip(odd) {
339 if limit.saturating_sub(end) == 0 {
340 break 'outer;
341 }
342
343 end += 2;
344
345 if t.0 == &utf16_char[0] && t.1 == &utf16_char[1] {
347 break 'outer;
349 }
350 }
351
352 if buf.len() as u64 != even_bs {
354 if buf.len() % 2 != 0 {
356 end += 1
358 }
359 break;
360 }
361 }
362
363 self.read_exact_range(start..start + end)
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cold-only cache over an in-memory byte-string literal.
    macro_rules! lazy_cache {
        ($content: literal) => {
            LazyCache::from_read_seek(std::io::Cursor::new($content)).unwrap()
        };
    }

    #[test]
    fn test_get_single_block() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"hell");
    }

    #[test]
    fn test_get_across_blocks() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(2..7).unwrap();
        assert_eq!(data, b"llo w");
    }

    #[test]
    fn test_get_entire_file() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..11).unwrap();
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_get_empty_range() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..0).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_get_out_of_bounds() {
        let mut cache = lazy_cache!(b"hello world");
        assert!(cache.read_range(20..30).unwrap().is_empty());
    }

    #[test]
    fn test_cache_eviction() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        let _ = cache.read_range(0..8).unwrap();
        let _ = cache.read_range(8..12).unwrap();
        let data = cache.read_range(8..12).unwrap();
        assert_eq!(data, b"89ab");
    }

    #[test]
    fn test_chunk_consolidation() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(4..8).unwrap();
        let _ = cache.read_range(8..12).unwrap();
        let _ = cache.read_range(2..6).unwrap();
        let data = cache.read_range(0..8).unwrap();
        assert_eq!(data, b"01234567");
    }

    #[test]
    fn test_overlapping_ranges() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        let _ = cache.read_range(2..6).unwrap();
        let _ = cache.read_range(4..10).unwrap();
        let data = cache.read_range(2..10).unwrap();
        assert_eq!(data, b"23456789");
    }

    #[test]
    fn test_lru_behavior() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(4..8).unwrap();
        let _ = cache.read_range(8..12).unwrap();
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"0123");
    }

    #[test]
    fn test_small_block_size() {
        let mut cache = lazy_cache!(b"abc");
        let data = cache.read_range(0..3).unwrap();
        assert_eq!(data, b"abc");
    }

    #[test]
    fn test_large_block_size() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_range(0..11).unwrap();
        assert_eq!(data, b"hello world");
    }

    #[test]
    fn test_file_smaller_than_block() {
        let mut cache = lazy_cache!(b"abc");
        let data = cache.read_range(0..3).unwrap();
        assert_eq!(data, b"abc");
    }

    #[test]
    fn test_multiple_gets_same_block() {
        let mut cache = lazy_cache!(b"0123456789abcdef");
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(0..4).unwrap();
        let _ = cache.read_range(0..4).unwrap();
        let data = cache.read_range(0..4).unwrap();
        assert_eq!(data, b"0123");
    }

    #[test]
    fn test_read_method() {
        let mut cache = lazy_cache!(b"hello world");
        let _ = cache.read_count(6).unwrap();
        let data = cache.read_count(5).unwrap();
        assert_eq!(data, b"world");
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_empty() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_count(0).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_read_beyond_end() {
        let mut cache = lazy_cache!(b"hello world");
        let _ = cache.read_count(11).unwrap();
        let data = cache.read_count(5).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn test_read_exact_range() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_exact_range(0..5).unwrap();
        assert_eq!(data, b"hello");
        assert_eq!(cache.read_exact_range(5..11).unwrap(), b" world");
        assert!(cache.read_exact_range(12..13).is_err());
    }

    #[test]
    fn test_read_exact_range_error() {
        let mut cache = lazy_cache!(b"hello world");
        let result = cache.read_exact_range(0..20);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_exact() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_exact_count(5).unwrap();
        assert_eq!(data, b"hello");
        assert_eq!(cache.read_exact_count(6).unwrap(), b" world");
        assert!(cache.read_exact_count(0).is_ok());
        assert!(cache.read_exact_count(1).is_err());
    }

    #[test]
    fn test_read_exact_error() {
        let mut cache = lazy_cache!(b"hello world");
        let result = cache.read_exact_count(20);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_exact_into() {
        let mut cache = lazy_cache!(b"hello world");
        let mut buf = [0u8; 5];
        cache.read_exact_into(&mut buf).unwrap();
        assert_eq!(&buf, b"hello");
        // Only 6 bytes remain, so asking for 7 must fail.
        assert!(cache.read_exact_into(&mut [0u8; 7]).is_err());
    }

    #[test]
    fn test_read_until_limit() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b' ', 10).unwrap();
        assert_eq!(data, b"hello ");
        assert_eq!(cache.read_exact_count(5).unwrap(), b"world");
    }

    #[test]
    fn test_read_until_limit_not_found() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b'\n', 11).unwrap();
        assert_eq!(data, b"hello world");
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_until_limit_beyond_stream() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b'\n', 42).unwrap();
        assert_eq!(data, b"hello world");
        assert!(cache.read_count(1).unwrap().is_empty());
    }

    #[test]
    fn test_read_until_limit_with_limit() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache.read_until_or_limit(b' ', 42).unwrap();
        assert_eq!(data, b"hello ");

        let data = cache.read_until_or_limit(b' ', 2).unwrap();
        assert_eq!(data, b"wo");

        let data = cache.read_until_or_limit(b' ', 42).unwrap();
        assert_eq!(data, b"rld");
    }

    #[test]
    fn test_read_while_or_limit_stops_before_rejected_byte() {
        let mut cache = lazy_cache!(b"hello world");
        let data = cache
            .read_while_or_limit(|b| b.is_ascii_alphabetic(), 42)
            .unwrap();
        assert_eq!(data, b"hello");
        // The rejected space was not consumed.
        assert_eq!(cache.read_count(1).unwrap(), b" ");
    }

    #[test]
    fn test_read_until_utf16_limit() {
        let mut cache = lazy_cache!(
            b"\x61\x00\x62\x00\x63\x00\x64\x00\x00\x00\x61\x00\x62\x00\x63\x00\x64\x00\x00"
        );
        let data = cache.read_until_utf16_or_limit(b"\x00\x00", 512).unwrap();
        assert_eq!(data, b"\x61\x00\x62\x00\x63\x00\x64\x00\x00\x00");

        let data = cache.read_until_utf16_or_limit(b"\x00\x00", 1).unwrap();
        assert_eq!(data, b"\x61\x00");

        assert_eq!(
            cache.read_until_utf16_or_limit(b"\xff\xff", 64).unwrap(),
            b"\x62\x00\x63\x00\x64\x00\x00"
        );
    }

    #[test]
    fn test_hot_cache_head_and_cold_fallback() {
        let mut cache = lazy_cache!(b"hello world").with_hot_cache(4).unwrap();
        // Served from the hot head (first size / 2 bytes).
        assert_eq!(cache.read_range(0..2).unwrap(), b"he");
        // Outside the hot tier: served from the source.
        assert_eq!(cache.read_range(2..7).unwrap(), b"llo w");
        assert_eq!(cache.read_range(9..11).unwrap(), b"ld");
    }

    #[test]
    fn test_hot_cache_covers_small_stream() {
        let mut cache = lazy_cache!(b"hello world").with_hot_cache(64).unwrap();
        assert_eq!(cache.read_range(3..8).unwrap(), b"lo wo");
        assert_eq!(cache.read_range(0..11).unwrap(), b"hello world");
    }

    #[test]
    fn test_warm_cache_reads() {
        let mut cache = lazy_cache!(b"hello world").with_warm_cache(64);
        assert_eq!(cache.read_range(0..4).unwrap(), b"hell");
        // The block is already loaded; this read reuses it.
        assert_eq!(cache.read_range(6..11).unwrap(), b"world");
    }

    #[test]
    fn test_std_read_trait() {
        let mut cache = lazy_cache!(b"hello world");
        let mut out = Vec::new();
        cache.read_to_end(&mut out).unwrap();
        assert_eq!(out, b"hello world");
    }

    #[test]
    fn test_std_seek_trait() {
        let mut cache = lazy_cache!(b"hello world");
        assert_eq!(cache.seek(SeekFrom::End(-5)).unwrap(), 6);
        assert_eq!(cache.read_count(5).unwrap(), b"world");
        assert_eq!(cache.seek(SeekFrom::Current(-6)).unwrap(), 5);
        assert_eq!(cache.read_count(1).unwrap(), b" ");
        assert_eq!(cache.seek(SeekFrom::Start(0)).unwrap(), 0);
        assert_eq!(cache.lazy_stream_position(), 0);
    }
}