1use async_trait::async_trait;
7use parking_lot::RwLock;
8use std::collections::BTreeMap;
9use std::io::{self, Read, Write};
10use std::ops::Range;
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13
14use super::{Directory, FileSlice, LazyFileHandle, OwnedBytes, RangeReadFn};
15
/// File extension used for persisted slice-cache files.
pub const SLICE_CACHE_EXTENSION: &str = "slicecache";

/// Magic byte sequence prefixing every serialized slice cache.
const SLICE_CACHE_MAGIC: &[u8; 8] = b"HRMSCACH";

/// Version of the serialized slice-cache format; `deserialize` rejects any
/// other version.
const SLICE_CACHE_VERSION: u32 = 1;
24
/// One contiguous cached byte range of a file.
#[derive(Debug, Clone)]
struct CachedSlice {
    /// Absolute byte range of the file that `data` covers.
    range: Range<u64>,
    /// Cached bytes; length always equals `range.end - range.start`.
    data: Arc<Vec<u8>>,
    /// Logical timestamp of the most recent access; smaller = colder (LRU).
    access_count: u64,
}

/// Cache of byte slices for a single file, keyed by slice start offset.
///
/// `insert` merges overlapping entries, so stored slices are always pairwise
/// disjoint.
struct FileSliceCache {
    /// Cached slices keyed by `range.start`.
    slices: BTreeMap<u64, CachedSlice>,
    /// Sum of `data.len()` over all entries.
    total_bytes: usize,
}

impl FileSliceCache {
    /// Creates an empty cache.
    fn new() -> Self {
        Self {
            slices: BTreeMap::new(),
            total_bytes: 0,
        }
    }

    /// Serializes the cache as
    /// `[count: u32][start: u64, end: u64, len: u32, bytes]*`,
    /// all integers little-endian. Inverse of [`FileSliceCache::deserialize`].
    fn serialize(&self) -> Vec<u8> {
        // Pre-size the buffer: 4-byte count, 20-byte header per slice, payloads.
        let mut buf = Vec::with_capacity(4 + self.slices.len() * 20 + self.total_bytes);
        buf.extend_from_slice(&(self.slices.len() as u32).to_le_bytes());
        for slice in self.slices.values() {
            buf.extend_from_slice(&slice.range.start.to_le_bytes());
            buf.extend_from_slice(&slice.range.end.to_le_bytes());
            buf.extend_from_slice(&(slice.data.len() as u32).to_le_bytes());
            buf.extend_from_slice(&slice.data);
        }
        buf
    }

    /// Parses a cache produced by [`FileSliceCache::serialize`] from the
    /// front of `data`.
    ///
    /// Every restored slice gets `access_counter` as its LRU timestamp.
    /// Returns the cache and the number of bytes consumed, or `InvalidData`
    /// if `data` is truncated or internally inconsistent.
    fn deserialize(data: &[u8], access_counter: u64) -> io::Result<(Self, usize)> {
        let mut pos = 0;
        if data.len() < 4 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "truncated slice cache",
            ));
        }
        let num_slices = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
        pos += 4;

        let mut cache = FileSliceCache::new();
        for _ in 0..num_slices {
            if pos + 20 > data.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "truncated slice entry",
                ));
            }
            let range_start = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
            pos += 8;
            let range_end = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
            pos += 8;
            let data_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
            pos += 4;

            // A well-formed entry's range length matches its payload length
            // (see `insert`). Reject inconsistent entries so later reads
            // cannot index past the end of `data` in `try_read`.
            if range_end.checked_sub(range_start) != Some(data_len as u64) {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "slice range/payload length mismatch",
                ));
            }

            if pos + data_len > data.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "truncated slice data",
                ));
            }
            let slice_data = data[pos..pos + data_len].to_vec();
            pos += data_len;

            cache.total_bytes += slice_data.len();
            cache.slices.insert(
                range_start,
                CachedSlice {
                    range: range_start..range_end,
                    data: Arc::new(slice_data),
                    access_count: access_counter,
                },
            );
        }
        Ok((cache, pos))
    }

    /// Iterates over cached slices in ascending start-offset order.
    #[allow(dead_code)]
    fn iter_slices(&self) -> impl Iterator<Item = (&u64, &CachedSlice)> {
        self.slices.iter()
    }

    /// Returns a copy of the bytes for `range` if a cached slice fully
    /// contains it, bumping that slice's LRU timestamp. Partial coverage is
    /// a miss.
    fn try_read(&mut self, range: Range<u64>, access_counter: &mut u64) -> Option<Vec<u8>> {
        let start = range.start;
        let end = range.end;

        // Scan candidates starting at or before `start`, nearest first. With
        // the disjointness invariant only the nearest candidate can contain
        // the range; the full loop is kept as a defensive measure.
        let mut found_key = None;
        for (&slice_start, slice) in self.slices.range(..=start).rev() {
            if slice_start <= start && slice.range.end >= end {
                found_key = Some((
                    slice_start,
                    (start - slice_start) as usize,
                    (end - start) as usize,
                ));
                break;
            }
        }

        if let Some((key, offset, len)) = found_key {
            *access_counter += 1;
            if let Some(s) = self.slices.get_mut(&key) {
                s.access_count = *access_counter;
                return Some(s.data[offset..offset + len].to_vec());
            }
        }

        None
    }

    /// Inserts `data` covering `range`, merging with any overlapping slices.
    ///
    /// Freshly inserted bytes win over previously cached bytes where they
    /// overlap. Returns the net change in cached bytes (negative when the
    /// merge frees more than it adds).
    fn insert(&mut self, range: Range<u64>, data: Vec<u8>, access_counter: u64) -> isize {
        let start = range.start;
        let end = range.end;
        let data_len = data.len();

        let mut to_remove = Vec::new();
        let mut merged_start = start;
        let mut merged_end = end;
        let mut merged_data: Option<Vec<u8>> = None;
        let mut bytes_removed: usize = 0;

        // Collect every existing slice that overlaps the incoming range and
        // grow the merged bounds to cover them all.
        for (&slice_start, slice) in &self.slices {
            if slice_start < end && slice.range.end > start {
                to_remove.push(slice_start);

                merged_start = merged_start.min(slice_start);
                merged_end = merged_end.max(slice.range.end);
            }
        }

        if !to_remove.is_empty() {
            // Each removed slice intersects the new range, so the union is
            // contiguous and a single flat buffer can hold the merged result.
            let merged_len = (merged_end - merged_start) as usize;
            let mut new_data = vec![0u8; merged_len];

            for &slice_start in &to_remove {
                if let Some(slice) = self.slices.get(&slice_start) {
                    let offset = (slice_start - merged_start) as usize;
                    new_data[offset..offset + slice.data.len()].copy_from_slice(&slice.data);
                    bytes_removed += slice.data.len();
                    self.total_bytes -= slice.data.len();
                }
            }

            // Copy the fresh data last so it overwrites stale overlap bytes.
            let offset = (start - merged_start) as usize;
            new_data[offset..offset + data_len].copy_from_slice(&data);

            for slice_start in to_remove {
                self.slices.remove(&slice_start);
            }

            merged_data = Some(new_data);
        }

        let (final_start, final_data) = if let Some(md) = merged_data {
            (merged_start, md)
        } else {
            (start, data)
        };

        let bytes_added = final_data.len();
        self.total_bytes += bytes_added;

        self.slices.insert(
            final_start,
            CachedSlice {
                range: final_start..final_start + bytes_added as u64,
                data: Arc::new(final_data),
                access_count: access_counter,
            },
        );

        bytes_added as isize - bytes_removed as isize
    }

    /// Evicts least-recently-used slices until at least `bytes_to_free`
    /// bytes have been freed (or the cache is empty). Returns the number of
    /// bytes actually freed.
    fn evict_lru(&mut self, bytes_to_free: usize) -> usize {
        // Visit slices in LRU order with a single sort instead of re-scanning
        // for the minimum on every eviction (previously O(n^2)). Sorting by
        // (access_count, key) matches the old tie-breaking: among equally old
        // slices the one with the smallest start offset goes first.
        let mut order: Vec<(u64, u64)> = self
            .slices
            .iter()
            .map(|(&key, slice)| (slice.access_count, key))
            .collect();
        order.sort_unstable();

        let mut freed = 0;
        for (_, key) in order {
            if freed >= bytes_to_free {
                break;
            }
            if let Some(slice) = self.slices.remove(&key) {
                freed += slice.data.len();
                self.total_bytes -= slice.data.len();
            }
        }

        freed
    }
}
254
/// A [`Directory`] wrapper that caches byte ranges read from the inner
/// directory, with a global byte budget and LRU eviction across files.
pub struct SliceCachingDirectory<D: Directory> {
    /// Wrapped directory that cache misses are forwarded to.
    inner: Arc<D>,
    /// Per-file slice caches, keyed by file path.
    caches: Arc<RwLock<std::collections::HashMap<PathBuf, FileSliceCache>>>,
    /// Upper bound on the total number of cached bytes.
    max_bytes: usize,
    /// Current total of cached bytes across all files.
    current_bytes: Arc<RwLock<usize>>,
    /// Monotonically increasing logical clock used to timestamp slice
    /// accesses for LRU eviction.
    access_counter: Arc<RwLock<u64>>,
}
272
273impl<D: Directory> SliceCachingDirectory<D> {
274 pub fn new(inner: D, max_bytes: usize) -> Self {
276 Self {
277 inner: Arc::new(inner),
278 caches: Arc::new(RwLock::new(std::collections::HashMap::new())),
279 max_bytes,
280 current_bytes: Arc::new(RwLock::new(0)),
281 access_counter: Arc::new(RwLock::new(0)),
282 }
283 }
284
285 pub fn inner(&self) -> &D {
287 &self.inner
288 }
289
290 fn try_cache_read(&self, path: &Path, range: Range<u64>) -> Option<Vec<u8>> {
292 let mut caches = self.caches.write();
293 let mut counter = self.access_counter.write();
294
295 if let Some(file_cache) = caches.get_mut(path) {
296 file_cache.try_read(range, &mut counter)
297 } else {
298 None
299 }
300 }
301
302 fn cache_insert(&self, path: &Path, range: Range<u64>, data: Vec<u8>) {
304 let data_len = data.len();
305
306 {
308 let current = *self.current_bytes.read();
309 if current + data_len > self.max_bytes {
310 self.evict_to_fit(data_len);
311 }
312 }
313
314 let mut caches = self.caches.write();
315 let counter = *self.access_counter.read();
316
317 let file_cache = caches
318 .entry(path.to_path_buf())
319 .or_insert_with(FileSliceCache::new);
320
321 let net_change = file_cache.insert(range, data, counter);
322 let mut current = self.current_bytes.write();
323 if net_change >= 0 {
324 *current += net_change as usize;
325 } else {
326 *current = current.saturating_sub((-net_change) as usize);
327 }
328 }
329
330 fn evict_to_fit(&self, needed: usize) {
332 let mut caches = self.caches.write();
333 let mut current = self.current_bytes.write();
334
335 let target = if *current + needed > self.max_bytes {
336 (*current + needed) - self.max_bytes
337 } else {
338 return;
339 };
340
341 let mut freed = 0;
342
343 while freed < target {
345 let oldest_file = caches
347 .iter()
348 .filter(|(_, fc)| !fc.slices.is_empty())
349 .min_by_key(|(_, fc)| {
350 fc.slices
351 .values()
352 .map(|s| s.access_count)
353 .min()
354 .unwrap_or(u64::MAX)
355 })
356 .map(|(p, _)| p.clone());
357
358 if let Some(path) = oldest_file {
359 if let Some(file_cache) = caches.get_mut(&path) {
360 freed += file_cache.evict_lru(target - freed);
361 }
362 } else {
363 break;
364 }
365 }
366
367 *current = current.saturating_sub(freed);
368 }
369
370 pub fn stats(&self) -> SliceCacheStats {
372 let caches = self.caches.read();
373 let mut total_slices = 0;
374 let mut files_cached = 0;
375
376 for fc in caches.values() {
377 if !fc.slices.is_empty() {
378 files_cached += 1;
379 total_slices += fc.slices.len();
380 }
381 }
382
383 SliceCacheStats {
384 total_bytes: *self.current_bytes.read(),
385 max_bytes: self.max_bytes,
386 total_slices,
387 files_cached,
388 }
389 }
390
391 pub fn serialize(&self) -> Vec<u8> {
402 let caches = self.caches.read();
403 let mut buf = Vec::new();
404
405 buf.extend_from_slice(SLICE_CACHE_MAGIC);
407 buf.extend_from_slice(&SLICE_CACHE_VERSION.to_le_bytes());
408
409 let non_empty: Vec<_> = caches
411 .iter()
412 .filter(|(_, fc)| !fc.slices.is_empty())
413 .collect();
414 buf.extend_from_slice(&(non_empty.len() as u32).to_le_bytes());
415
416 for (path, file_cache) in non_empty {
417 let path_str = path.to_string_lossy();
419 let path_bytes = path_str.as_bytes();
420 buf.extend_from_slice(&(path_bytes.len() as u32).to_le_bytes());
421 buf.extend_from_slice(path_bytes);
422
423 let cache_data = file_cache.serialize();
425 buf.extend_from_slice(&cache_data);
426 }
427
428 buf
429 }
430
431 pub fn deserialize(&self, data: &[u8]) -> io::Result<()> {
436 let mut pos = 0;
437
438 if data.len() < 16 {
440 return Err(io::Error::new(
441 io::ErrorKind::InvalidData,
442 "slice cache too short",
443 ));
444 }
445 if &data[pos..pos + 8] != SLICE_CACHE_MAGIC {
446 return Err(io::Error::new(
447 io::ErrorKind::InvalidData,
448 "invalid slice cache magic",
449 ));
450 }
451 pos += 8;
452
453 let version = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap());
455 if version != SLICE_CACHE_VERSION {
456 return Err(io::Error::new(
457 io::ErrorKind::InvalidData,
458 format!("unsupported slice cache version: {}", version),
459 ));
460 }
461 pos += 4;
462
463 let num_files = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
465 pos += 4;
466
467 let mut caches = self.caches.write();
468 let mut current_bytes = self.current_bytes.write();
469 let counter = *self.access_counter.read();
470
471 for _ in 0..num_files {
472 if pos + 4 > data.len() {
474 return Err(io::Error::new(
475 io::ErrorKind::InvalidData,
476 "truncated path length",
477 ));
478 }
479 let path_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
480 pos += 4;
481
482 if pos + path_len > data.len() {
484 return Err(io::Error::new(io::ErrorKind::InvalidData, "truncated path"));
485 }
486 let path_str = std::str::from_utf8(&data[pos..pos + path_len])
487 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
488 let path = PathBuf::from(path_str);
489 pos += path_len;
490
491 let (file_cache, consumed) = FileSliceCache::deserialize(&data[pos..], counter)?;
493 pos += consumed;
494
495 *current_bytes += file_cache.total_bytes;
497 caches.insert(path, file_cache);
498 }
499
500 Ok(())
501 }
502
503 pub fn serialize_to_writer<W: Write>(&self, mut writer: W) -> io::Result<()> {
505 let data = self.serialize();
506 writer.write_all(&data)
507 }
508
509 pub fn deserialize_from_reader<R: Read>(&self, mut reader: R) -> io::Result<()> {
511 let mut data = Vec::new();
512 reader.read_to_end(&mut data)?;
513 self.deserialize(&data)
514 }
515
516 pub fn is_empty(&self) -> bool {
518 *self.current_bytes.read() == 0
519 }
520
521 pub fn clear(&self) {
523 let mut caches = self.caches.write();
524 let mut current_bytes = self.current_bytes.write();
525 caches.clear();
526 *current_bytes = 0;
527 }
528}
529
/// A point-in-time snapshot of slice-cache occupancy, as returned by
/// `SliceCachingDirectory::stats`.
#[derive(Debug, Clone)]
pub struct SliceCacheStats {
    /// Total bytes currently cached across all files.
    pub total_bytes: usize,
    /// Configured byte budget for the cache.
    pub max_bytes: usize,
    /// Number of cached slices across all files.
    pub total_slices: usize,
    /// Number of files with at least one cached slice.
    pub files_cached: usize,
}
538
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl<D: Directory> Directory for SliceCachingDirectory<D> {
    // Pass-through: existence checks are never cached.
    async fn exists(&self, path: &Path) -> io::Result<bool> {
        self.inner.exists(path).await
    }

    // Pass-through: sizes always come from the inner directory.
    async fn file_size(&self, path: &Path) -> io::Result<u64> {
        self.inner.file_size(path).await
    }

    /// Opens a file for reading. Served from the cache only when a cached
    /// slice covers the entire file; otherwise the whole file is read from
    /// the inner directory and cached.
    async fn open_read(&self, path: &Path) -> io::Result<FileSlice> {
        let file_size = self.inner.file_size(path).await?;
        let full_range = 0..file_size;

        if let Some(data) = self.try_cache_read(path, full_range.clone()) {
            return Ok(FileSlice::new(OwnedBytes::new(data)));
        }

        let slice = self.inner.open_read(path).await?;
        let bytes = slice.read_bytes().await?;

        // Cache the complete file contents for subsequent reads.
        self.cache_insert(path, full_range, bytes.as_slice().to_vec());

        Ok(FileSlice::new(bytes))
    }

    /// Reads `range` of `path`, preferring the cache and caching the inner
    /// directory's bytes on a miss.
    async fn read_range(&self, path: &Path, range: Range<u64>) -> io::Result<OwnedBytes> {
        if let Some(data) = self.try_cache_read(path, range.clone()) {
            return Ok(OwnedBytes::new(data));
        }

        let data = self.inner.read_range(path, range.clone()).await?;

        self.cache_insert(path, range, data.as_slice().to_vec());

        Ok(data)
    }

    // Pass-through: listings are never cached.
    async fn list_files(&self, prefix: &Path) -> io::Result<Vec<PathBuf>> {
        self.inner.list_files(prefix).await
    }

    /// Returns a lazy handle whose range reads consult and populate the
    /// shared slice cache.
    ///
    /// The read closure captures `Arc` clones of the cache state so the
    /// returned handle does not borrow `self`.
    async fn open_lazy(&self, path: &Path) -> io::Result<LazyFileHandle> {
        let file_size = self.inner.file_size(path).await? as usize;

        let path_buf = path.to_path_buf();
        let caches = Arc::clone(&self.caches);
        let current_bytes = Arc::clone(&self.current_bytes);
        let access_counter = Arc::clone(&self.access_counter);
        let max_bytes = self.max_bytes;
        let inner = Arc::clone(&self.inner);

        let read_fn: RangeReadFn = Arc::new(move |range: Range<u64>| {
            let path = path_buf.clone();
            let caches = Arc::clone(&caches);
            let current_bytes = Arc::clone(&current_bytes);
            let access_counter = Arc::clone(&access_counter);
            let inner = Arc::clone(&inner);

            Box::pin(async move {
                // Cache lookup lives in its own scope so the parking_lot
                // guards (which must not be held across an await) are
                // dropped before the inner read below.
                {
                    let mut caches_guard = caches.write();
                    let mut counter = access_counter.write();
                    if let Some(file_cache) = caches_guard.get_mut(&path)
                        && let Some(data) = file_cache.try_read(range.clone(), &mut counter)
                    {
                        return Ok(OwnedBytes::new(data));
                    }
                }

                let data = inner.read_range(&path, range.clone()).await?;

                let data_len = data.as_slice().len();
                {
                    let current = *current_bytes.read();
                    // Only cache when the new slice fits in the remaining
                    // budget. NOTE(review): unlike `cache_insert`, this path
                    // never evicts, so once the cache fills up lazy reads
                    // stop being cached — confirm this is intended.
                    if current + data_len > max_bytes {
                    } else {
                        let mut caches_guard = caches.write();
                        let counter = *access_counter.read();

                        let file_cache = caches_guard
                            .entry(path.clone())
                            .or_insert_with(FileSliceCache::new);

                        // Merging may shrink the cache; apply the signed net
                        // change to the global byte count.
                        let net_change =
                            file_cache.insert(range, data.as_slice().to_vec(), counter);
                        let mut current = current_bytes.write();
                        if net_change >= 0 {
                            *current += net_change as usize;
                        } else {
                            *current = current.saturating_sub((-net_change) as usize);
                        }
                    }
                }

                Ok(data)
            })
        });

        Ok(LazyFileHandle::new(file_size, read_fn))
    }
}
657
#[cfg(test)]
mod tests {
    use super::*;
    use crate::directories::{DirectoryWriter, RamDirectory};

    #[tokio::test]
    async fn test_slice_cache_basic() {
        let backing = RamDirectory::new();
        backing
            .write(Path::new("test.bin"), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            .await
            .unwrap();

        let dir = SliceCachingDirectory::new(backing, 1024);

        // Cold read: goes to the backing store and populates the cache.
        let first = dir.read_range(Path::new("test.bin"), 2..5).await.unwrap();
        assert_eq!(first.as_slice(), &[2, 3, 4]);

        // Warm read: the identical range must return the same bytes.
        let second = dir.read_range(Path::new("test.bin"), 2..5).await.unwrap();
        assert_eq!(second.as_slice(), &[2, 3, 4]);

        let stats = dir.stats();
        assert_eq!(stats.total_slices, 1);
        assert_eq!(stats.total_bytes, 3);
    }

    #[tokio::test]
    async fn test_slice_cache_overlap_merge() {
        let backing = RamDirectory::new();
        backing
            .write(Path::new("test.bin"), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            .await
            .unwrap();

        let dir = SliceCachingDirectory::new(backing, 1024);

        // Two overlapping reads should collapse into one merged slice.
        dir.read_range(Path::new("test.bin"), 2..5).await.unwrap();
        dir.read_range(Path::new("test.bin"), 4..7).await.unwrap();

        let stats = dir.stats();
        assert_eq!(stats.total_slices, 1);
        assert_eq!(stats.total_bytes, 5);

        // A range inside the merged slice is served correctly.
        let merged = dir.read_range(Path::new("test.bin"), 3..6).await.unwrap();
        assert_eq!(merged.as_slice(), &[3, 4, 5]);
    }

    #[tokio::test]
    async fn test_slice_cache_eviction() {
        let backing = RamDirectory::new();
        backing.write(Path::new("test.bin"), &[0; 100]).await.unwrap();

        // A 50-byte budget cannot hold two 30-byte slices at once.
        let dir = SliceCachingDirectory::new(backing, 50);

        dir.read_range(Path::new("test.bin"), 0..30).await.unwrap();
        dir.read_range(Path::new("test.bin"), 50..80).await.unwrap();

        assert!(dir.stats().total_bytes <= 50);
    }

    #[tokio::test]
    async fn test_slice_cache_serialize_deserialize() {
        let backing = RamDirectory::new();
        backing
            .write(Path::new("file1.bin"), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
            .await
            .unwrap();
        backing
            .write(Path::new("file2.bin"), &[10, 11, 12, 13, 14, 15])
            .await
            .unwrap();

        let dir = SliceCachingDirectory::new(backing.clone(), 1024);

        dir.read_range(Path::new("file1.bin"), 2..6).await.unwrap();
        dir.read_range(Path::new("file2.bin"), 1..4).await.unwrap();

        let stats = dir.stats();
        assert_eq!(stats.files_cached, 2);
        assert_eq!(stats.total_bytes, 7);

        let snapshot = dir.serialize();
        assert!(!snapshot.is_empty());

        // A fresh wrapper starts empty and can be rehydrated from the snapshot.
        let restored = SliceCachingDirectory::new(backing.clone(), 1024);
        assert!(restored.is_empty());

        restored.deserialize(&snapshot).unwrap();

        let restored_stats = restored.stats();
        assert_eq!(restored_stats.files_cached, 2);
        assert_eq!(restored_stats.total_bytes, 7);

        // Rehydrated slices serve the same bytes as the originals.
        let a = restored
            .read_range(Path::new("file1.bin"), 2..6)
            .await
            .unwrap();
        assert_eq!(a.as_slice(), &[2, 3, 4, 5]);

        let b = restored
            .read_range(Path::new("file2.bin"), 1..4)
            .await
            .unwrap();
        assert_eq!(b.as_slice(), &[11, 12, 13]);
    }

    #[tokio::test]
    async fn test_slice_cache_serialize_empty() {
        let dir = SliceCachingDirectory::new(RamDirectory::new(), 1024);

        // Even an empty cache serializes to a non-empty header.
        let snapshot = dir.serialize();
        assert!(!snapshot.is_empty());

        let restored = SliceCachingDirectory::new(RamDirectory::new(), 1024);
        restored.deserialize(&snapshot).unwrap();
        assert!(restored.is_empty());
    }
}