use std::ops::Range;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use tokio::sync::Mutex;
use crate::error::AsyncTiffResult;
use crate::metadata::MetadataFetch;
/// Append-only cache of contiguous byte blocks.
///
/// The blocks, concatenated in insertion order, represent bytes `0..len` of
/// the underlying source; the cache only ever grows forward from offset 0.
#[derive(Debug)]
struct SequentialBlockCache {
    /// Cached blocks in fetch order; logically contiguous.
    buffers: Vec<Bytes>,
    /// Total byte count across all `buffers`.
    len: u64,
}
impl SequentialBlockCache {
    /// Creates an empty cache holding no bytes.
    fn new() -> Self {
        Self {
            buffers: Vec::new(),
            len: 0,
        }
    }

    /// Returns `true` when every byte of `range` is already cached.
    ///
    /// Because the cache only grows forward from offset 0, a range is fully
    /// present exactly when its end does not exceed the cached length.
    fn contains(&self, range: Range<u64>) -> bool {
        range.end <= self.len
    }

    /// Returns the bytes for `range` assembled from the cached blocks.
    ///
    /// When the range lies inside a single block the result is a zero-copy
    /// view; otherwise the overlapping pieces are concatenated into a fresh
    /// allocation. Callers are expected to check [`Self::contains`] first —
    /// a range reaching past `len` yields silently truncated data.
    fn slice(&self, range: Range<u64>) -> Bytes {
        let requested = (range.end - range.start) as usize;
        let mut pieces: Vec<Bytes> = Vec::new();
        // Absolute offset (within the logical byte stream) at which the
        // current block begins.
        let mut block_start = 0u64;
        for block in &self.buffers {
            let block_end = block_start + block.len() as u64;
            // Intersect [range.start, range.end) with [block_start, block_end);
            // empty blocks and non-overlapping blocks produce no piece.
            let lo = range.start.max(block_start);
            let hi = range.end.min(block_end);
            if lo < hi {
                pieces.push(block.slice((lo - block_start) as usize..(hi - block_start) as usize));
            }
            block_start = block_end;
            if block_start >= range.end {
                // Every later block starts at or past `range.end`.
                break;
            }
        }
        match pieces.len() {
            // Single overlapping block: hand back the zero-copy view directly.
            1 => pieces.pop().unwrap(),
            _ => {
                let mut assembled = BytesMut::with_capacity(requested);
                for piece in &pieces {
                    assembled.extend_from_slice(piece);
                }
                assembled.freeze()
            }
        }
    }

    /// Appends a newly fetched block to the end of the cache.
    fn append_buffer(&mut self, buffer: Bytes) {
        self.len += buffer.len() as u64;
        self.buffers.push(buffer);
    }
}
#[derive(Debug)]
pub struct ReadaheadMetadataCache<F: MetadataFetch> {
inner: F,
cache: Arc<Mutex<SequentialBlockCache>>,
initial: u64,
multiplier: f64,
}
impl<F: MetadataFetch> ReadaheadMetadataCache<F> {
    /// Wraps `inner` with readahead caching, starting with 32 KiB requests
    /// that double on every subsequent cache miss.
    pub fn new(inner: F) -> Self {
        Self {
            inner,
            cache: Arc::new(Mutex::new(SequentialBlockCache::new())),
            initial: 32 * 1024,
            multiplier: 2.0,
        }
    }

    /// Borrows the wrapped fetch implementation.
    pub fn inner(&self) -> &F {
        &self.inner
    }

    /// Builder-style setter for the size of the first readahead request.
    pub fn with_initial_size(mut self, initial: u64) -> Self {
        self.initial = initial;
        self
    }

    /// Builder-style setter for the growth factor applied on each miss.
    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
        self.multiplier = multiplier;
        self
    }

    /// Size of the next readahead request given how much is already cached:
    /// `initial` for the very first request, then the cached length scaled
    /// by `multiplier` (rounded), giving geometric growth.
    fn next_fetch_size(&self, existing_len: u64) -> u64 {
        match existing_len {
            0 => self.initial,
            cached => (cached as f64 * self.multiplier).round() as u64,
        }
    }
}
#[async_trait]
impl<F: MetadataFetch + Send + Sync> MetadataFetch for ReadaheadMetadataCache<F> {
    /// Serves `range` from the cache, extending it with one readahead fetch
    /// from the inner source on a miss.
    ///
    /// The lock is held across the inner fetch, so concurrent callers are
    /// serialized and the cache stays a single contiguous prefix.
    async fn fetch(&self, range: Range<u64>) -> AsyncTiffResult<Bytes> {
        let mut cache = self.cache.lock().await;
        if !cache.contains(range.clone()) {
            let cached_len = cache.len;
            // Fetch at least enough to cover the request, but no less than
            // the readahead schedule dictates.
            let minimum = range.end.saturating_sub(cached_len);
            let amount = self.next_fetch_size(cached_len).max(minimum);
            // NOTE(review): if the inner fetch returns fewer bytes than
            // requested (e.g. at EOF), the slice below is silently truncated
            // rather than an error — confirm callers tolerate short reads.
            let block = self.inner.fetch(cached_len..cached_len + amount).await?;
            cache.append_buffer(block);
        }
        Ok(cache.slice(range))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Fetch stub backed by an in-memory byte string.
    ///
    /// Counts how many `fetch` calls reach it so tests can verify how often
    /// the readahead cache actually hits the underlying source, and clamps
    /// out-of-range requests the way a real reader would at EOF.
    #[derive(Debug)]
    struct TestFetch {
        data: Bytes,
        num_fetches: Arc<Mutex<u64>>,
    }

    impl TestFetch {
        fn new(data: Bytes) -> Self {
            Self {
                data,
                num_fetches: Arc::new(Mutex::new(0)),
            }
        }
    }

    #[async_trait]
    impl MetadataFetch for TestFetch {
        async fn fetch(&self, range: Range<u64>) -> crate::error::AsyncTiffResult<Bytes> {
            // Requests entirely past the end yield no bytes and are NOT
            // counted as fetches.
            if range.start as usize >= self.data.len() {
                return Ok(Bytes::new());
            }
            // Clamp the end to the data length to emulate a short read at EOF.
            let end = (range.end as usize).min(self.data.len());
            let slice = self.data.slice(range.start as _..end);
            let mut g = self.num_fetches.lock().await;
            *g += 1;
            Ok(slice)
        }
    }

    #[tokio::test]
    async fn test_readahead_cache() {
        let data = Bytes::from_static(b"abcdefghijklmnopqrstuvwxyz");
        let fetch = TestFetch::new(data.clone());
        // initial = 2 bytes; each subsequent readahead = 3x the cached length.
        let cache = ReadaheadMetadataCache::new(fetch)
            .with_initial_size(2)
            .with_multiplier(3.0);
        // Miss: fetches exactly the initial 2 bytes (cache now holds 0..2).
        let result = cache.fetch(0..2).await.unwrap();
        assert_eq!(result.as_ref(), b"ab");
        assert_eq!(*cache.inner.num_fetches.lock().await, 1);
        // Hit: sub-range of the cached prefix, no new fetch.
        let result = cache.fetch(1..2).await.unwrap();
        assert_eq!(result.as_ref(), b"b");
        assert_eq!(*cache.inner.num_fetches.lock().await, 1);
        // Miss: readahead grows to 2 * 3 = 6 bytes (cache now holds 0..8).
        let result = cache.fetch(2..5).await.unwrap();
        assert_eq!(result.as_ref(), b"cde");
        assert_eq!(*cache.inner.num_fetches.lock().await, 2);
        // Hit: 5..8 lies within the 8 cached bytes.
        let result = cache.fetch(5..8).await.unwrap();
        assert_eq!(result.as_ref(), b"fgh");
        assert_eq!(*cache.inner.num_fetches.lock().await, 2);
        // Miss: requests 8 * 3 = 24 more bytes; the stub truncates at EOF
        // (cache now holds all 26 bytes).
        let result = cache.fetch(8..20).await.unwrap();
        assert_eq!(result.as_ref(), b"ijklmnopqrst");
        assert_eq!(*cache.inner.num_fetches.lock().await, 3);
    }

    /// Exercises `SequentialBlockCache` slicing across block boundaries,
    /// including zero-length blocks interleaved with data.
    #[test]
    fn test_sequential_block_cache_empty_buffers() {
        let mut cache = SequentialBlockCache::new();
        cache.append_buffer(Bytes::from_static(b"012"));
        cache.append_buffer(Bytes::from_static(b""));
        cache.append_buffer(Bytes::from_static(b"34"));
        cache.append_buffer(Bytes::from_static(b""));
        cache.append_buffer(Bytes::from_static(b"5"));
        cache.append_buffer(Bytes::from_static(b""));
        cache.append_buffer(Bytes::from_static(b"67"));
        // (range, expected `contains` result, expected bytes when contained)
        let test_cases = [
            (0..3, true, Bytes::from_static(b"012")),
            (4..7, true, Bytes::from_static(b"456")),
            (0..8, true, Bytes::from_static(b"01234567")),
            (6..6, true, Bytes::from_static(b"")),
            (6..9, false, Bytes::from_static(b"")),
            (9..9, false, Bytes::from_static(b"")),
            (8..10, false, Bytes::from_static(b"")),
        ];
        for (range, exists, expected) in test_cases {
            assert_eq!(cache.contains(range.clone()), exists);
            if exists {
                assert_eq!(cache.slice(range.clone()), expected);
            }
        }
    }
}