use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::BTreeMap;
use std::io::{self, Read, Write};
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use super::{Directory, FileSlice, LazyFileHandle, OwnedBytes, RangeReadFn};

pub const SLICE_CACHE_EXTENSION: &str = "slicecache";

const SLICE_CACHE_MAGIC: &[u8; 8] = b"HRMSCACH";

const SLICE_CACHE_VERSION: u32 = 2;

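/// A single cached, contiguous byte range of one file, together with the
/// access-counter value used for least-recently-used eviction.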
#[derive(Debug, Clone)]
struct CachedSlice {
    range: Range<u64>,
    data: Arc<Vec<u8>>,
    access_count: u64,
}

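/// All cached slices for one file, keyed by each slice's start offset.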
struct FileSliceCache {
    slices: BTreeMap<u64, CachedSlice>,
    total_bytes: usize,
}

impl FileSliceCache {
    fn new() -> Self {
        Self {
            slices: BTreeMap::new(),
            total_bytes: 0,
        }
    }

    fn serialize(&self) -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(&(self.slices.len() as u32).to_le_bytes());
        for slice in self.slices.values() {
            buf.extend_from_slice(&slice.range.start.to_le_bytes());
            buf.extend_from_slice(&slice.range.end.to_le_bytes());
            buf.extend_from_slice(&(slice.data.len() as u32).to_le_bytes());
            buf.extend_from_slice(&slice.data);
        }
        buf
    }

    fn deserialize(data: &[u8], access_counter: u64) -> io::Result<(Self, usize)> {
        let mut pos = 0;
        if data.len() < 4 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "truncated slice cache",
            ));
        }
        let num_slices = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
        pos += 4;

        let mut cache = FileSliceCache::new();
        for _ in 0..num_slices {
            if pos + 20 > data.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "truncated slice entry",
                ));
            }
            let range_start = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
            pos += 8;
            let range_end = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
            pos += 8;
            let data_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
            pos += 4;

            if pos + data_len > data.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "truncated slice data",
                ));
            }
            let slice_data = data[pos..pos + data_len].to_vec();
            pos += data_len;

            cache.total_bytes += slice_data.len();
            cache.slices.insert(
                range_start,
                CachedSlice {
                    range: range_start..range_end,
                    data: Arc::new(slice_data),
                    access_count: access_counter,
                },
            );
        }
        Ok((cache, pos))
    }

    #[allow(dead_code)]
    fn iter_slices(&self) -> impl Iterator<Item = (&u64, &CachedSlice)> {
        self.slices.iter()
    }

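    /// Returns a copy of the requested range if it lies entirely within one
    /// cached slice, bumping that slice's access counter on a hit.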
    fn try_read(&mut self, range: Range<u64>, access_counter: &mut u64) -> Option<Vec<u8>> {
        let start = range.start;
        let end = range.end;

        let mut found_key = None;
        for (&slice_start, slice) in self.slices.range(..=start).rev() {
            if slice_start <= start && slice.range.end >= end {
                found_key = Some((
                    slice_start,
                    (start - slice_start) as usize,
                    (end - start) as usize,
                ));
                break;
            }
        }

        if let Some((key, offset, len)) = found_key {
            *access_counter += 1;
            if let Some(s) = self.slices.get_mut(&key) {
                s.access_count = *access_counter;
                return Some(s.data[offset..offset + len].to_vec());
            }
        }

        None
    }

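    /// Inserts a slice, merging it with any overlapping cached slices, and
    /// returns the signed net change in cached bytes.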
    fn insert(&mut self, range: Range<u64>, data: Vec<u8>, access_counter: u64) -> isize {
        let start = range.start;
        let end = range.end;
        let data_len = data.len();

        let mut to_remove = Vec::new();
        let mut merged_start = start;
        let mut merged_end = end;
        let mut merged_data: Option<Vec<u8>> = None;
        let mut bytes_removed: usize = 0;

        for (&slice_start, slice) in &self.slices {
            if slice_start < end && slice.range.end > start {
                to_remove.push(slice_start);

                merged_start = merged_start.min(slice_start);
                merged_end = merged_end.max(slice.range.end);
            }
        }

        if !to_remove.is_empty() {
            let merged_len = (merged_end - merged_start) as usize;
            let mut new_data = vec![0u8; merged_len];

            for &slice_start in &to_remove {
                if let Some(slice) = self.slices.get(&slice_start) {
                    let offset = (slice_start - merged_start) as usize;
                    new_data[offset..offset + slice.data.len()].copy_from_slice(&slice.data);
                    bytes_removed += slice.data.len();
                    self.total_bytes -= slice.data.len();
                }
            }

            let offset = (start - merged_start) as usize;
            new_data[offset..offset + data_len].copy_from_slice(&data);

            for slice_start in to_remove {
                self.slices.remove(&slice_start);
            }

            merged_data = Some(new_data);
        }

        let (final_start, final_data) = if let Some(md) = merged_data {
            (merged_start, md)
        } else {
            (start, data)
        };

        let bytes_added = final_data.len();
        self.total_bytes += bytes_added;

        self.slices.insert(
            final_start,
            CachedSlice {
                range: final_start..final_start + bytes_added as u64,
                data: Arc::new(final_data),
                access_count: access_counter,
            },
        );

        bytes_added as isize - bytes_removed as isize
    }

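    /// Removes least-recently-used slices until at least `bytes_to_free` bytes
    /// are freed or the cache is empty; returns the number of bytes freed.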
    fn evict_lru(&mut self, bytes_to_free: usize) -> usize {
        let mut freed = 0;

        while freed < bytes_to_free && !self.slices.is_empty() {
            let lru_key = self
                .slices
                .iter()
                .min_by_key(|(_, s)| s.access_count)
                .map(|(&k, _)| k);

            if let Some(key) = lru_key {
                if let Some(slice) = self.slices.remove(&key) {
                    freed += slice.data.len();
                    self.total_bytes -= slice.data.len();
                }
            } else {
                break;
            }
        }

        freed
    }
}

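/// A read-through caching wrapper around another [`Directory`].
///
/// Byte ranges fetched through `read_range`, `open_read`, or `open_lazy` are
/// cached per file, with overlapping ranges merged into a single slice. When an
/// insert would exceed `max_bytes`, least-recently-used slices are evicted
/// (lazy reads simply skip caching instead). The cache contents can be
/// persisted with `serialize` and restored later with `deserialize`.
///
/// A minimal usage sketch (assuming the crate's `RamDirectory` as the inner
/// directory):
///
/// ```ignore
/// let dir = SliceCachingDirectory::new(RamDirectory::new(), 64 * 1024 * 1024);
/// let header = dir.read_range(Path::new("segment.bin"), 0..4096).await?;
/// let snapshot = dir.serialize(); // can be persisted and fed to `deserialize`
/// ```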
pub struct SliceCachingDirectory<D: Directory> {
    inner: Arc<D>,
    caches: Arc<RwLock<std::collections::HashMap<PathBuf, FileSliceCache>>>,
    file_sizes: Arc<RwLock<std::collections::HashMap<PathBuf, u64>>>,
    max_bytes: usize,
    current_bytes: Arc<RwLock<usize>>,
    access_counter: Arc<RwLock<u64>>,
}

impl<D: Directory> SliceCachingDirectory<D> {
    pub fn new(inner: D, max_bytes: usize) -> Self {
        Self {
            inner: Arc::new(inner),
            caches: Arc::new(RwLock::new(std::collections::HashMap::new())),
            file_sizes: Arc::new(RwLock::new(std::collections::HashMap::new())),
            max_bytes,
            current_bytes: Arc::new(RwLock::new(0)),
            access_counter: Arc::new(RwLock::new(0)),
        }
    }

    pub fn inner(&self) -> &D {
        &self.inner
    }

    fn try_cache_read(&self, path: &Path, range: Range<u64>) -> Option<Vec<u8>> {
        let mut caches = self.caches.write();
        let mut counter = self.access_counter.write();

        if let Some(file_cache) = caches.get_mut(path) {
            file_cache.try_read(range, &mut counter)
        } else {
            None
        }
    }

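    /// Caches `data` as the given `range` of `path`, evicting other slices
    /// first if the new data would push the cache over `max_bytes`, and
    /// updates the global byte count by the net change of the insert.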
    fn cache_insert(&self, path: &Path, range: Range<u64>, data: Vec<u8>) {
        let data_len = data.len();

        {
            let current = *self.current_bytes.read();
            if current + data_len > self.max_bytes {
                self.evict_to_fit(data_len);
            }
        }

        let mut caches = self.caches.write();
        let counter = *self.access_counter.read();

        let file_cache = caches
            .entry(path.to_path_buf())
            .or_insert_with(FileSliceCache::new);

        let net_change = file_cache.insert(range, data, counter);
        let mut current = self.current_bytes.write();
        if net_change >= 0 {
            *current += net_change as usize;
        } else {
            *current = current.saturating_sub((-net_change) as usize);
        }
    }

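    /// Evicts least-recently-used slices across all files until `needed`
    /// additional bytes would fit within `max_bytes`.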
    fn evict_to_fit(&self, needed: usize) {
        let mut caches = self.caches.write();
        let mut current = self.current_bytes.write();

        let target = if *current + needed > self.max_bytes {
            (*current + needed) - self.max_bytes
        } else {
            return;
        };

        let mut freed = 0;

        while freed < target {
            let oldest_file = caches
                .iter()
                .filter(|(_, fc)| !fc.slices.is_empty())
                .min_by_key(|(_, fc)| {
                    fc.slices
                        .values()
                        .map(|s| s.access_count)
                        .min()
                        .unwrap_or(u64::MAX)
                })
                .map(|(p, _)| p.clone());

            if let Some(path) = oldest_file {
                if let Some(file_cache) = caches.get_mut(&path) {
                    freed += file_cache.evict_lru(target - freed);
                }
            } else {
                break;
            }
        }

        *current = current.saturating_sub(freed);
    }

    pub fn stats(&self) -> SliceCacheStats {
        let caches = self.caches.read();
        let mut total_slices = 0;
        let mut files_cached = 0;

        for fc in caches.values() {
            if !fc.slices.is_empty() {
                files_cached += 1;
                total_slices += fc.slices.len();
            }
        }

        SliceCacheStats {
            total_bytes: *self.current_bytes.read(),
            max_bytes: self.max_bytes,
            total_slices,
            files_cached,
        }
    }

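    /// Serializes the cache to a byte buffer: magic and version header,
    /// followed by each non-empty per-file cache (path length, path bytes,
    /// slice entries) and a table of known file sizes.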
    pub fn serialize(&self) -> Vec<u8> {
        let caches = self.caches.read();
        let file_sizes = self.file_sizes.read();
        let mut buf = Vec::new();

        buf.extend_from_slice(SLICE_CACHE_MAGIC);
        buf.extend_from_slice(&SLICE_CACHE_VERSION.to_le_bytes());

        let non_empty: Vec<_> = caches
            .iter()
            .filter(|(_, fc)| !fc.slices.is_empty())
            .collect();
        buf.extend_from_slice(&(non_empty.len() as u32).to_le_bytes());

        for (path, file_cache) in non_empty {
            let path_str = path.to_string_lossy();
            let path_bytes = path_str.as_bytes();
            buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(path_bytes);

            let cache_data = file_cache.serialize();
            buf.extend_from_slice(&cache_data);
        }

        buf.extend_from_slice(&(file_sizes.len() as u32).to_le_bytes());
        for (path, &size) in file_sizes.iter() {
            let path_str = path.to_string_lossy();
            let path_bytes = path_str.as_bytes();
            buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(path_bytes);
            buf.extend_from_slice(&size.to_le_bytes());
        }

        buf
    }

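    /// Restores cache contents previously produced by `serialize`, loading the
    /// decoded slices and file sizes into this directory's in-memory cache.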
    pub fn deserialize(&self, data: &[u8]) -> io::Result<()> {
        let mut pos = 0;

        if data.len() < 16 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "slice cache too short",
            ));
        }
        if &data[pos..pos + 8] != SLICE_CACHE_MAGIC {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid slice cache magic",
            ));
        }
        pos += 8;

        let version = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
        pos += 4;
        if version != 2 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("unsupported slice cache version: {} (expected 2)", version),
            ));
        }

        let num_files = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
        pos += 4;

        let mut caches = self.caches.write();
        let mut current_bytes = self.current_bytes.write();
        let counter = *self.access_counter.read();

        for _ in 0..num_files {
            if pos + 4 > data.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "truncated path length",
                ));
            }
            let path_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
            pos += 4;

            if pos + path_len > data.len() {
                return Err(io::Error::new(io::ErrorKind::InvalidData, "truncated path"));
            }
            let path_str = std::str::from_utf8(&data[pos..pos + path_len])
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
            let path = PathBuf::from(path_str);
            pos += path_len;

            let (file_cache, consumed) = FileSliceCache::deserialize(&data[pos..], counter)?;
            pos += consumed;

            *current_bytes += file_cache.total_bytes;
            caches.insert(path, file_cache);
        }

        if pos + 4 <= data.len() {
            let num_sizes = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
            pos += 4;

            let mut file_sizes = self.file_sizes.write();
            for _ in 0..num_sizes {
                if pos + 4 > data.len() {
                    break;
                }
                let path_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
                pos += 4;

                if pos + path_len > data.len() {
                    break;
                }
                let path_str = match std::str::from_utf8(&data[pos..pos + path_len]) {
                    Ok(s) => s,
                    Err(_) => break,
                };
                let path = PathBuf::from(path_str);
                pos += path_len;

                if pos + 8 > data.len() {
                    break;
                }
                let size = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
                pos += 8;

                file_sizes.insert(path, size);
            }
        }

        Ok(())
    }

    pub fn serialize_to_writer<W: Write>(&self, mut writer: W) -> io::Result<()> {
        let data = self.serialize();
        writer.write_all(&data)
    }

    pub fn deserialize_from_reader<R: Read>(&self, mut reader: R) -> io::Result<()> {
        let mut data = Vec::new();
        reader.read_to_end(&mut data)?;
        self.deserialize(&data)
    }

    pub fn is_empty(&self) -> bool {
        *self.current_bytes.read() == 0
    }

    pub fn clear(&self) {
        let mut caches = self.caches.write();
        let mut current_bytes = self.current_bytes.write();
        caches.clear();
        *current_bytes = 0;
    }
}

#[derive(Debug, Clone)]
pub struct SliceCacheStats {
    pub total_bytes: usize,
    pub max_bytes: usize,
    pub total_slices: usize,
    pub files_cached: usize,
}

#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl<D: Directory> Directory for SliceCachingDirectory<D> {
    async fn exists(&self, path: &Path) -> io::Result<bool> {
        self.inner.exists(path).await
    }

    async fn file_size(&self, path: &Path) -> io::Result<u64> {
        {
            let file_sizes = self.file_sizes.read();
            if let Some(&size) = file_sizes.get(path) {
                return Ok(size);
            }
        }

        let size = self.inner.file_size(path).await?;
        {
            let mut file_sizes = self.file_sizes.write();
            file_sizes.insert(path.to_path_buf(), size);
        }
        Ok(size)
    }

    async fn open_read(&self, path: &Path) -> io::Result<FileSlice> {
        let file_size = self.file_size(path).await?;
        let full_range = 0..file_size;

        if let Some(data) = self.try_cache_read(path, full_range.clone()) {
            return Ok(FileSlice::new(OwnedBytes::new(data)));
        }

        let slice = self.inner.open_read(path).await?;
        let bytes = slice.read_bytes().await?;

        self.cache_insert(path, full_range, bytes.as_slice().to_vec());

        Ok(FileSlice::new(bytes))
    }

    async fn read_range(&self, path: &Path, range: Range<u64>) -> io::Result<OwnedBytes> {
        if let Some(data) = self.try_cache_read(path, range.clone()) {
            return Ok(OwnedBytes::new(data));
        }

        let data = self.inner.read_range(path, range.clone()).await?;

        self.cache_insert(path, range, data.as_slice().to_vec());

        Ok(data)
    }

    async fn list_files(&self, prefix: &Path) -> io::Result<Vec<PathBuf>> {
        self.inner.list_files(prefix).await
    }

    async fn open_lazy(&self, path: &Path) -> io::Result<LazyFileHandle> {
        let file_size = self.file_size(path).await?;

        let path_buf = path.to_path_buf();
        let caches = Arc::clone(&self.caches);
        let current_bytes = Arc::clone(&self.current_bytes);
        let access_counter = Arc::clone(&self.access_counter);
        let max_bytes = self.max_bytes;
        let inner = Arc::clone(&self.inner);

        let read_fn: RangeReadFn = Arc::new(move |range: Range<u64>| {
            let path = path_buf.clone();
            let caches = Arc::clone(&caches);
            let current_bytes = Arc::clone(&current_bytes);
            let access_counter = Arc::clone(&access_counter);
            let inner = Arc::clone(&inner);

            Box::pin(async move {
                {
                    let mut caches_guard = caches.write();
                    let mut counter = access_counter.write();
                    if let Some(file_cache) = caches_guard.get_mut(&path)
                        && let Some(data) = file_cache.try_read(range.clone(), &mut counter)
                    {
                        log::trace!("Cache HIT: {:?} [{}-{}]", path, range.start, range.end);
                        return Ok(OwnedBytes::new(data));
                    }
                }

                log::trace!("Cache MISS: {:?} [{}-{}]", path, range.start, range.end);

                let data = inner.read_range(&path, range.clone()).await?;

                let data_len = data.as_slice().len();
                {
                    let current = *current_bytes.read();
                    // Only cache the range if it fits within the budget; the lazy
                    // read path never evicts other slices to make room.
                    if current + data_len <= max_bytes {
                        let mut caches_guard = caches.write();
                        let counter = *access_counter.read();

                        let file_cache = caches_guard
                            .entry(path.clone())
                            .or_insert_with(FileSliceCache::new);

                        let net_change =
                            file_cache.insert(range, data.as_slice().to_vec(), counter);
                        let mut current = current_bytes.write();
                        if net_change >= 0 {
                            *current += net_change as usize;
                        } else {
                            *current = current.saturating_sub((-net_change) as usize);
                        }
                    }
                }

                Ok(data)
            })
        });

        Ok(LazyFileHandle::new(file_size, read_fn))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::directories::{DirectoryWriter, RamDirectory};

    #[tokio::test]
    async fn test_slice_cache_basic() {
        let ram = RamDirectory::new();
        ram.write(Path::new("test.bin"), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            .await
            .unwrap();

        let cached = SliceCachingDirectory::new(ram, 1024);

        let data = cached
            .read_range(Path::new("test.bin"), 2..5)
            .await
            .unwrap();
        assert_eq!(data.as_slice(), &[2, 3, 4]);

        let data = cached
            .read_range(Path::new("test.bin"), 2..5)
            .await
            .unwrap();
        assert_eq!(data.as_slice(), &[2, 3, 4]);

        let stats = cached.stats();
        assert_eq!(stats.total_slices, 1);
        assert_eq!(stats.total_bytes, 3);
    }

    #[tokio::test]
    async fn test_slice_cache_overlap_merge() {
        let ram = RamDirectory::new();
        ram.write(Path::new("test.bin"), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            .await
            .unwrap();

        let cached = SliceCachingDirectory::new(ram, 1024);

        cached
            .read_range(Path::new("test.bin"), 2..5)
            .await
            .unwrap();

        cached
            .read_range(Path::new("test.bin"), 4..7)
            .await
            .unwrap();

        let stats = cached.stats();
        assert_eq!(stats.total_slices, 1);
        assert_eq!(stats.total_bytes, 5);

        let data = cached
            .read_range(Path::new("test.bin"), 3..6)
            .await
            .unwrap();
        assert_eq!(data.as_slice(), &[3, 4, 5]);
    }

    #[tokio::test]
    async fn test_slice_cache_eviction() {
        let ram = RamDirectory::new();
        ram.write(Path::new("test.bin"), &[0; 100]).await.unwrap();

        let cached = SliceCachingDirectory::new(ram, 50);

        cached
            .read_range(Path::new("test.bin"), 0..30)
            .await
            .unwrap();

        cached
            .read_range(Path::new("test.bin"), 50..80)
            .await
            .unwrap();

        let stats = cached.stats();
        assert!(stats.total_bytes <= 50);
    }

    #[tokio::test]
    async fn test_slice_cache_serialize_deserialize() {
        let ram = RamDirectory::new();
        ram.write(Path::new("file1.bin"), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            .await
            .unwrap();
        ram.write(Path::new("file2.bin"), &[10, 11, 12, 13, 14, 15])
            .await
            .unwrap();

        let cached = SliceCachingDirectory::new(ram.clone(), 1024);

        cached
            .read_range(Path::new("file1.bin"), 2..6)
            .await
            .unwrap();
        cached
            .read_range(Path::new("file2.bin"), 1..4)
            .await
            .unwrap();

        let stats = cached.stats();
        assert_eq!(stats.files_cached, 2);
        assert_eq!(stats.total_bytes, 7);

        let serialized = cached.serialize();
        assert!(!serialized.is_empty());

        let cached2 = SliceCachingDirectory::new(ram.clone(), 1024);
        assert!(cached2.is_empty());

        cached2.deserialize(&serialized).unwrap();

        let stats2 = cached2.stats();
        assert_eq!(stats2.files_cached, 2);
        assert_eq!(stats2.total_bytes, 7);

        let data = cached2
            .read_range(Path::new("file1.bin"), 2..6)
            .await
            .unwrap();
        assert_eq!(data.as_slice(), &[2, 3, 4, 5]);

        let data = cached2
            .read_range(Path::new("file2.bin"), 1..4)
            .await
            .unwrap();
        assert_eq!(data.as_slice(), &[11, 12, 13]);
    }

    #[tokio::test]
    async fn test_slice_cache_serialize_empty() {
        let ram = RamDirectory::new();
        let cached = SliceCachingDirectory::new(ram, 1024);

        let serialized = cached.serialize();
        assert!(!serialized.is_empty());

        let cached2 = SliceCachingDirectory::new(RamDirectory::new(), 1024);
        cached2.deserialize(&serialized).unwrap();
        assert!(cached2.is_empty());
    }
}