use std::{
cmp::min,
collections::BTreeMap,
io::{BufReader, ErrorKind, Read, Seek, SeekFrom},
ops::Range,
sync::{Arc, Condvar, Mutex},
};
use ranges::Ranges;
use reqwest::blocking::Response;
use thiserror::Error;
use super::http_range_reader::{self, RangeFetcher};
/// Default upper bound, in bytes, on the size of a single block that is read
/// from the HTTP stream and stored in the cache (1 MiB).
const DEFAULT_MAX_BLOCK: usize = 1024 * 1024;
/// Default forward distance, in bytes, beyond which a read that was announced
/// via `read_skip_expected` opens a fresh HTTP stream instead of reading
/// through the intervening data (2 MiB).
const DEFAULT_SKIP_AHEAD_THRESHOLD: u64 = 2 * 1024 * 1024;
/// How the caller expects to read the resource; this decides how long cached
/// blocks are kept around.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) enum AccessPattern {
/// Cached blocks are kept after they have been read, so they can serve
/// later reads of the same region.
RandomAccess,
/// Cached blocks are dropped as soon as every byte in them has been read.
/// Switching an engine to this pattern also tries to open a stream at
/// offset zero.
SequentialIsh,
}
impl Default for AccessPattern {
fn default() -> Self {
Self::RandomAccess
}
}
/// Errors that can occur while setting up a seekable HTTP reader engine.
#[derive(Error, Debug)]
pub(crate) enum Error {
/// The range fetcher reported that the resource does not accept range
/// requests, so seeking cannot be supported.
#[error(
"This HTTP resource did not advertise that it accepts ranges via the Accept-Ranges header"
)]
AcceptRangesNotSupported,
/// Creating the underlying range fetcher failed; the wrapped error is
/// displayed as-is.
#[error(transparent)]
RangeFetcherError(http_range_reader::Error),
}
/// One cached block of the resource together with a record of which of its
/// bytes have already been handed out to a reader.
struct CacheCell {
// The raw bytes of the block.
data: Vec<u8>,
// Offsets within `data` (relative to the start of the block) that have
// been returned by `read` at least once.
bytes_read: Ranges<usize>,
}
impl CacheCell {
fn new(data: Vec<u8>) -> Self {
Self {
data,
bytes_read: Ranges::new(),
}
}
fn read(&mut self, range: Range<usize>) -> &[u8] {
let new_range = self.bytes_read.clone().union(Ranges::from(range.clone()));
self.bytes_read = new_range;
&self.data[range]
}
fn len(&self) -> usize {
self.data.len()
}
fn entirely_consumed(&self) -> bool {
let data_left_to_read =
Ranges::from(0..self.data.len()).difference(self.bytes_read.clone());
data_left_to_read.is_empty()
}
}
/// Everything the engine keeps behind its mutex: the block cache, the tuning
/// knobs, the statistics and (when no read is in flight) the HTTP reader.
#[derive(Default)]
struct State {
// Decides whether fully-read blocks are discarded from the cache.
access_pattern: AccessPattern,
// Maximum total number of cached bytes; `None` means unlimited.
readahead_limit: Option<usize>,
// Running total of the sizes of the blocks currently in `cache`.
current_size: usize,
// Cached blocks, keyed by the offset of their first byte in the resource.
cache: BTreeMap<u64, CacheCell>,
// Set by `read_skip_expected`; consumed (reset to false) by the next read
// that misses the cache.
expect_skip_ahead: bool,
// Forward distance beyond which an expected skip opens a new stream.
skip_ahead_threshold: u64,
// Largest block that is read from the stream and cached in one go.
max_block: usize,
// Counters describing how the cache and streams have been used.
stats: SeekableHttpReaderStatistics,
// The HTTP reading machinery. `None` while a thread has taken it out to
// perform a read without holding the lock.
reader: Option<Box<ReadingMaterials>>,
// Set once a read through the reader has failed; threads waiting for the
// reader then give up with an error.
read_failed_somewhere: bool,
}
impl State {
    /// Builds the initial state for an engine.
    ///
    /// A `readahead_limit` smaller than `max_block` is raised to `max_block`,
    /// because the cache must always be able to hold at least the block that
    /// a read has just fetched. `None` leaves the cache unbounded.
    fn new(
        readahead_limit: Option<usize>,
        access_pattern: AccessPattern,
        skip_ahead_threshold: u64,
        max_block: usize,
        reader: Box<ReadingMaterials>,
    ) -> Self {
        let readahead_limit = readahead_limit.map(|limit| limit.max(max_block));
        Self {
            readahead_limit,
            access_pattern,
            skip_ahead_threshold,
            max_block,
            reader: Some(reader),
            ..Default::default()
        }
    }

    /// Stores `block` in the cache as the data starting at offset `pos`, then
    /// evicts the lowest-positioned *other* blocks until the cache fits within
    /// the readahead limit again (if there is one).
    fn insert(&mut self, pos: u64, block: Vec<u8>) {
        log::debug!(
            "Inserting into cache, block is 0x{:x}-0x{:x}",
            pos,
            pos + block.len() as u64
        );
        self.current_size += block.len();
        // A block may already exist at this offset (e.g. a reader reading
        // through a region that an earlier reader cached). Its bytes are
        // replaced, so they must stop counting towards the cache size.
        if let Some(replaced) = self.cache.insert(pos, CacheCell::new(block)) {
            self.current_size -= replaced.len();
        }
        if let Some(readahead_limit) = self.readahead_limit {
            while self.current_size > readahead_limit {
                // Never evict the block we were just asked to store: the
                // caller is about to read from it.
                let victim = self.cache.keys().copied().find(|&start| start != pos);
                match victim.and_then(|start| self.cache.remove(&start)) {
                    Some(evicted) => {
                        self.stats.cache_shrinks += 1;
                        self.current_size -= evicted.len();
                    }
                    // Nothing left to evict; stop rather than spin forever.
                    None => break,
                }
            }
        }
    }

    /// Tries to satisfy a read at offset `pos` from the cache.
    ///
    /// Copies as many bytes as are available from the first cached block that
    /// contains `pos` (at most `buf.len()`) into the front of `buf` and
    /// returns how many were copied, or `None` if no cached block contains
    /// `pos`. In `SequentialIsh` mode a block is dropped from the cache once
    /// all of its bytes have been read.
    fn read_from_cache(&mut self, pos: u64, buf: &mut [u8]) -> Option<usize> {
        let discard_read_data = matches!(self.access_pattern, AccessPattern::SequentialIsh);
        // Blocks are at most `max_block` long, so a block containing `pos`
        // cannot start any earlier than this.
        let earliest_start = pos.saturating_sub(self.max_block as u64);
        let mut hit = None;
        for (&block_start, block) in self.cache.range_mut(earliest_start..=pos) {
            // Subtract in u64 first: the difference is bounded by `max_block`
            // and therefore always fits in a usize, unlike `pos` itself.
            let block_offset = (pos - block_start) as usize;
            let block_len = block.len();
            if block_offset >= block_len {
                continue;
            }
            let to_read = min(buf.len(), block_len - block_offset);
            buf[..to_read].copy_from_slice(block.read(block_offset..block_offset + to_read));
            let fully_consumed = discard_read_data && block.entirely_consumed();
            hit = Some((block_start, block_len, to_read, fully_consumed));
            break;
        }
        let (block_start, block_len, to_read, fully_consumed) = hit?;
        self.stats.cache_hits += 1;
        if fully_consumed {
            self.cache.remove(&block_start);
            self.current_size -= block_len;
        }
        Some(to_read)
    }
}
/// Compact debug view of the cache: only its size limit and current size.
impl std::fmt::Debug for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("Cache");
        summary.field("max_size", &self.readahead_limit);
        summary.field("current_size", &self.current_size);
        summary.finish()
    }
}
/// The pieces needed to pull data over HTTP. They travel together so that a
/// thread can take them out of the shared state while it performs a read.
struct ReadingMaterials {
// Issues range requests against the resource.
range_fetcher: RangeFetcher,
// The currently open response stream, if any, paired with the offset in
// the resource of the next byte that stream will yield.
reader: Option<(BufReader<Response>, u64)>, }
/// Shared engine behind one or more `SeekableHttpReader`s: it owns the block
/// cache and the HTTP stream, and serves reads at arbitrary offsets.
pub(crate) struct SeekableHttpReaderEngine {
// Total length of the resource, as reported by the range fetcher.
len: u64,
// Cache, configuration, statistics and the reader, guarded by a mutex.
state: Mutex<State>,
// Notified whenever a block is added to the cache or a read finishes, so
// that threads waiting for data or for the reader can re-check.
read_completed: Condvar,
}
/// Counters describing how an engine has used its cache and HTTP streams.
#[derive(Default, Debug, Clone)]
pub(crate) struct SeekableHttpReaderStatistics {
/// Incremented when a read had to create a new reader, and when the access
/// pattern is switched to `SequentialIsh`.
pub(crate) num_http_streams: usize,
/// Number of reads served from an already cached block.
pub(crate) cache_hits: usize,
/// Number of reads that had to go to the HTTP reader.
pub(crate) cache_misses: usize,
/// Number of passes through the eviction loop that runs when the cache
/// exceeds its readahead limit.
pub(crate) cache_shrinks: usize,
}
impl SeekableHttpReaderEngine {
    /// Creates an engine for the resource at `uri` using the default block
    /// size and skip-ahead threshold.
    ///
    /// `readahead_limit` bounds the number of cached bytes (`None` means
    /// unlimited); `access_pattern` sets the initial caching behaviour.
    ///
    /// # Errors
    /// See [`Self::with_configuration`].
    pub(crate) fn new(
        uri: String,
        readahead_limit: Option<usize>,
        access_pattern: AccessPattern,
    ) -> Result<Arc<Self>, Error> {
        Self::with_configuration(
            uri,
            readahead_limit,
            access_pattern,
            DEFAULT_SKIP_AHEAD_THRESHOLD,
            DEFAULT_MAX_BLOCK,
        )
    }

    /// Creates an engine with explicit tuning parameters.
    ///
    /// # Errors
    /// * [`Error::RangeFetcherError`] if the range fetcher cannot be created.
    /// * [`Error::AcceptRangesNotSupported`] if the resource does not accept
    ///   range requests.
    fn with_configuration(
        uri: String,
        readahead_limit: Option<usize>,
        access_pattern: AccessPattern,
        skip_ahead_threshold: u64,
        max_block: usize,
    ) -> Result<Arc<Self>, Error> {
        let range_fetcher = RangeFetcher::new(uri).map_err(Error::RangeFetcherError)?;
        if !range_fetcher.accepts_ranges() {
            return Err(Error::AcceptRangesNotSupported);
        }
        let len = range_fetcher.len();
        let reading_materials = Box::new(ReadingMaterials {
            range_fetcher,
            reader: None,
        });
        let state = State::new(
            readahead_limit,
            access_pattern,
            skip_ahead_threshold,
            max_block,
            reading_materials,
        );
        Ok(Arc::new(Self {
            len,
            state: Mutex::new(state),
            read_completed: Condvar::new(),
        }))
    }

    /// Creates a reader over this engine, positioned at the start of the
    /// resource. Each reader has its own position but shares the cache.
    pub(crate) fn create_reader(self: Arc<Self>) -> SeekableHttpReader {
        SeekableHttpReader {
            engine: self,
            pos: 0u64,
        }
    }

    /// Error handed to callers once the shared reader has been lost because
    /// some read through it failed.
    fn reader_failed_error() -> std::io::Error {
        std::io::Error::new(
            std::io::ErrorKind::BrokenPipe,
            "another thread experienced a problem creating a reader",
        )
    }

    /// Reads up to `buf.len()` bytes starting at offset `pos`, from the cache
    /// if possible and otherwise from the HTTP stream. Returns the number of
    /// bytes placed at the front of `buf`.
    ///
    /// If another thread currently holds the reader, this waits until either
    /// that thread has cached the requested data or the reader is free again.
    ///
    /// # Errors
    /// * `UnexpectedEof` if `pos` is at or beyond the end of the resource.
    /// * `BrokenPipe` if the reader was lost to an earlier failed read.
    /// * Whatever error the underlying fetch or stream read produces.
    fn read(&self, buf: &mut [u8], pos: u64) -> std::io::Result<usize> {
        log::debug!("Read: requested position 0x{:x}.", pos);
        // `>=` rather than `==`: a position past the end can never be
        // satisfied, and letting it through would make the fill loop in
        // `perform_read_using_reader` spin on zero-length blocks.
        if pos >= self.len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "read beyond end of stream",
            ));
        }
        let mut state = self.state.lock().unwrap();
        if let Some(bytes_read_from_cache) = state.read_from_cache(pos, buf) {
            log::debug!("Immediate cache success");
            return Ok(bytes_read_from_cache);
        }
        let reading_stuff = loop {
            if let Some(reading_stuff) = state.reader.take() {
                break reading_stuff;
            }
            // The reader is never returned after a failed read. Check before
            // waiting, otherwise we would sleep on the condvar forever with
            // nobody left to wake us.
            if state.read_failed_somewhere {
                return Err(Self::reader_failed_error());
            }
            state = self.read_completed.wait(state).unwrap();
            if state.read_failed_somewhere {
                return Err(Self::reader_failed_error());
            }
            // The thread that held the reader may have cached what we need.
            if let Some(bytes_read_from_cache) = state.read_from_cache(pos, buf) {
                log::debug!("Deferred cache success");
                return Ok(bytes_read_from_cache);
            }
        };
        state.stats.cache_misses += 1;
        // The skip-ahead hint applies to a single read only.
        let expect_skip_ahead = state.expect_skip_ahead;
        state.expect_skip_ahead = false;
        let skip_ahead_threshold = state.skip_ahead_threshold;
        let max_block = state.max_block;
        // Do the (slow) network read without holding the lock, so that other
        // threads can keep being served from the cache.
        drop(state);
        let read_result = self.perform_read_using_reader(
            buf,
            pos,
            reading_stuff,
            expect_skip_ahead,
            skip_ahead_threshold,
            max_block,
        );
        if read_result.is_err() {
            self.state.lock().unwrap().read_failed_somewhere = true;
        }
        self.read_completed.notify_all();
        read_result
    }

    /// Performs a read that missed the cache, using the reader taken out of
    /// the shared state.
    ///
    /// Reuses the open stream when `pos` is at or ahead of it (unless a skip
    /// was announced and the gap exceeds `skip_ahead_threshold`), otherwise
    /// opens a new stream at `pos`. Blocks of up to `max_block` bytes are read
    /// and cached until the block containing `pos` is available, then the
    /// request is served from the cache. On success the reader is handed back
    /// to the shared state; on failure it is dropped.
    ///
    /// Callers must ensure `pos` is below the length of the resource.
    ///
    /// # Panics
    /// If the cache cannot serve `pos` even after the containing block has
    /// been read.
    fn perform_read_using_reader(
        &self,
        buf: &mut [u8],
        pos: u64,
        mut reading_stuff: Box<ReadingMaterials>,
        expect_skip_ahead: bool,
        skip_ahead_threshold: u64,
        max_block: usize,
    ) -> std::io::Result<usize> {
        let existing_reader_unusable = match reading_stuff.reader.as_ref() {
            Some((_, reader_pos)) if pos < *reader_pos => {
                log::debug!(
                    "Rewinding: New reader will be required at 0x{:x} - old reader pos was 0x{:x}",
                    pos,
                    *reader_pos
                );
                true
            }
            // Only reached when `pos >= reader_pos`, so the subtraction
            // cannot underflow.
            Some((_, reader_pos))
                if expect_skip_ahead && pos - *reader_pos > skip_ahead_threshold =>
            {
                log::debug!("Fast forwarding expected skip: New reader will be required at 0x{:x} - old reader pos was 0x{:x}",
                    pos,
                    *reader_pos
                );
                true
            }
            _ => false,
        };
        if existing_reader_unusable {
            reading_stuff.reader = None;
        }
        let reader_created = reading_stuff.reader.is_none();
        if reader_created {
            log::debug!("create_reader");
            let response = reading_stuff
                .range_fetcher
                .fetch_range(pos)
                .map_err(|e| std::io::Error::new(ErrorKind::Unsupported, e.to_string()))?;
            reading_stuff.reader = Some((BufReader::new(response), pos));
        }
        let (reader, reader_pos) = reading_stuff
            .reader
            .as_mut()
            .expect("a reader exists: one was created just above if none was usable");
        if pos > *reader_pos {
            log::debug!(
                "Read: reading ahead from 0x{:x} to 0x{:x} without skipping",
                *reader_pos,
                pos
            );
        }
        while pos >= *reader_pos {
            // Compute in u64 and narrow afterwards; the result is bounded by
            // `max_block`, so the narrowing cannot truncate.
            let remaining = self.len - *reader_pos;
            let to_read = min(max_block as u64, remaining) as usize;
            let mut new_block = vec![0u8; to_read];
            reader.read_exact(&mut new_block)?;
            self.state.lock().unwrap().insert(*reader_pos, new_block);
            // Another thread may be waiting for exactly this block.
            self.read_completed.notify_all();
            *reader_pos += to_read as u64;
        }
        let mut state = self.state.lock().unwrap();
        let bytes_read = state
            .read_from_cache(pos, buf)
            .expect("Cache still couldn't satisfy request event after reading beyond read pos");
        log::debug!("Cache success after read");
        if reader_created {
            state.stats.num_http_streams += 1;
        }
        state.reader = Some(reading_stuff);
        Ok(bytes_read)
    }

    /// Total length of the resource in bytes.
    pub(crate) fn len(&self) -> u64 {
        self.len
    }

    /// Changes how cached data is retained from now on.
    ///
    /// When switching to `SequentialIsh`, a stream starting at offset zero is
    /// opened on a best-effort basis; if that fails the existing stream (if
    /// any) is kept.
    ///
    /// # Panics
    /// If called while a read is in progress on another thread.
    pub(crate) fn set_expected_access_pattern(&self, access_pattern: AccessPattern) {
        let mut state = self.state.lock().unwrap();
        if state.access_pattern == access_pattern {
            return;
        }
        log::debug!(
            "Changing access pattern - current stats are {:?}",
            state.stats
        );
        if matches!(access_pattern, AccessPattern::SequentialIsh) {
            log::debug!("create_reader_at_zero");
            let reading_materials = state
                .reader
                .as_mut()
                .expect("Must not call set_expected_access_pattern while a read is in progress");
            let stream_opened = match reading_materials.range_fetcher.fetch_range(0) {
                Ok(new_reader) => {
                    reading_materials.reader = Some((BufReader::new(new_reader), 0));
                    true
                }
                Err(e) => {
                    log::debug!("Could not open a stream at offset zero: {}", e);
                    false
                }
            };
            // Only count streams that were really opened.
            if stream_opened {
                state.stats.num_http_streams += 1;
            }
        }
        state.access_pattern = access_pattern;
    }

    /// Announces that the next read is expected to jump ahead, allowing it to
    /// open a new stream rather than read through the gap.
    pub(crate) fn read_skip_expected(&self) {
        self.state.lock().unwrap().expect_skip_ahead = true;
    }

    /// Returns a snapshot of the engine's statistics.
    pub(crate) fn get_stats(&self) -> SeekableHttpReaderStatistics {
        self.state.lock().unwrap().stats.clone()
    }
}
/// Logs the final statistics when the engine goes away.
impl Drop for SeekableHttpReaderEngine {
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so no locking is needed.
        // A poisoned mutex (some thread panicked while holding it) must not
        // cause a second panic here: panicking in `drop` during unwinding
        // aborts the process. The statistics are still worth logging.
        let state = match self.state.get_mut() {
            Ok(state) => state,
            Err(poisoned) => poisoned.into_inner(),
        };
        log::debug!("Dropping: stats are {:?}", state.stats)
    }
}
/// A `Read + Seek` view onto a shared engine. Cloning yields another reader
/// with its own position that shares the same engine.
#[derive(Clone)]
pub(crate) struct SeekableHttpReader {
// The shared engine that performs the actual reads and caching.
engine: Arc<SeekableHttpReaderEngine>,
// This reader's current offset in the resource.
pos: u64,
}
impl Seek for SeekableHttpReader {
    /// Moves this reader's position; no network activity takes place.
    ///
    /// # Errors
    /// Returns an `Unsupported` error, leaving the position unchanged, when
    /// the target would lie before the start of the resource, when seeking
    /// past the end with `SeekFrom::End`, or when the new position would not
    /// fit in a `u64`.
    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
        let unsupported = |message: &'static str| {
            std::io::Error::new(std::io::ErrorKind::Unsupported, message)
        };
        // `unsigned_abs` is used instead of negation so that `i64::MIN`
        // cannot overflow, and checked arithmetic so that out-of-range
        // targets become errors rather than panics or wrap-around.
        let new_pos = match pos {
            SeekFrom::Start(pos) => pos,
            SeekFrom::End(offset) if offset > 0 => {
                return Err(unsupported("Seeked beyond end"));
            }
            SeekFrom::End(offset) => self
                .engine
                .len()
                .checked_sub(offset.unsigned_abs())
                .ok_or_else(|| unsupported("Rewind too far"))?,
            SeekFrom::Current(offset) if offset >= 0 => self
                .pos
                .checked_add(offset.unsigned_abs())
                .ok_or_else(|| unsupported("Seeked beyond end"))?,
            SeekFrom::Current(offset) => self
                .pos
                .checked_sub(offset.unsigned_abs())
                .ok_or_else(|| unsupported("Rewind too far"))?,
        };
        self.pos = new_pos;
        Ok(new_pos)
    }
}
impl Read for SeekableHttpReader {
    /// Reads from the engine at this reader's current position and, on
    /// success, advances the position by the number of bytes obtained.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.engine.read(buf, self.pos).map(|count| {
            self.pos += count as u64;
            count
        })
    }
}
// Unit tests for the cache cell bookkeeping and for the engine's behaviour
// against a mock HTTP server.
#[cfg(test)]
mod tests {
use ripunzip_test_utils::{ExpectedRange, RangeAwareResponse, RangeAwareResponseType};
use std::io::{Read, Seek, SeekFrom};
use test_log::test;
use httptest::{matchers::*, Expectation, Server};
use crate::unzip::seekable_http_reader::DEFAULT_MAX_BLOCK;
use super::{AccessPattern, CacheCell, SeekableHttpReaderEngine};
// A cell counts as consumed only once every byte has been read at least
// once; repeated and overlapping reads are fine.
#[test]
fn test_cachecell() {
let mut cell = CacheCell::new(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
assert_eq!(cell.len(), 10);
assert!(!cell.entirely_consumed());
assert_eq!(cell.read(0..2), &[0, 1]);
assert!(!cell.entirely_consumed());
assert_eq!(cell.read(3..10), &[3, 4, 5, 6, 7, 8, 9]);
assert!(!cell.entirely_consumed());
assert_eq!(cell.read(0..2), &[0, 1]);
assert!(!cell.entirely_consumed());
// Byte 2 was the only one still unread; this read covers it.
assert_eq!(cell.read(1..4), &[1, 2, 3]);
assert!(cell.entirely_consumed());
}
#[test]
fn test_unlimited_readahead() {
do_test(None, AccessPattern::SequentialIsh)
}
#[test]
fn test_big_readahead() {
const ONE_HUNDRED_MB: usize = 1024usize * 1024usize * 100usize;
do_test(Some(ONE_HUNDRED_MB), AccessPattern::SequentialIsh)
}
// The requested limit (4 bytes) is below the block size, so the engine
// raises it to the block size.
#[test]
fn test_small_readahead() {
do_test(Some(4), AccessPattern::SequentialIsh)
}
#[test]
fn test_random_access() {
do_test(None, AccessPattern::RandomAccess)
}
// Expectation for the HEAD request, answered with a 12-byte length.
fn get_head_expectation() -> Expectation {
Expectation::matching(request::method_path("HEAD", "/foo")).respond_with(
RangeAwareResponse::new(200, RangeAwareResponseType::LengthOnly(12)),
)
}
// The 12-byte body served by the mock server.
const TEST_BODY: &[u8] = "0123456789AB".as_bytes();
// Expectation for at least one GET of `/foo`, answered with a 206 whose
// expected range is `expected_start`..`expected_end`.
fn get_range_expectation(expected_start: u64, expected_end: u64) -> Expectation {
Expectation::matching(request::method_path("GET", "/foo"))
.times(1..)
.respond_with(RangeAwareResponse::new(
206,
RangeAwareResponseType::Body {
body: TEST_BODY.into(),
expected_range: Some(ExpectedRange {
expected_start,
expected_end,
}),
},
))
}
// Drives a reader through a fixed sequence of reads and seeks, checking
// both the data returned and which HTTP requests the server sees.
fn do_test(readahead_limit: Option<usize>, access_pattern: AccessPattern) {
let mut server = Server::run();
// The HEAD request is expected while the engine is being constructed.
server.expect(get_head_expectation());
let seekable_http_reader_engine = SeekableHttpReaderEngine::with_configuration(
server.url("/foo").to_string(),
readahead_limit,
access_pattern,
4,
DEFAULT_MAX_BLOCK,
)
.unwrap();
let mut seekable_http_reader = seekable_http_reader_engine.create_reader();
server.verify_and_clear();
let mut throwaway = [0u8; 4];
// First pass: read the whole body front to back in 4-byte pieces.
server.expect(get_range_expectation(0, 12));
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "0123");
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "4567");
seekable_http_reader.stream_position().unwrap();
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "89AB");
server.verify_and_clear();
// Second pass after rewinding. A new GET is only expected in sequential
// mode, where data that has been read is dropped from the cache.
seekable_http_reader.rewind().unwrap();
if matches!(access_pattern, AccessPattern::SequentialIsh) {
server.expect(get_range_expectation(0, 12));
}
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "0123");
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "4567");
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "89AB");
if matches!(access_pattern, AccessPattern::SequentialIsh) {
server.verify_and_clear();
}
// Seek into the middle; again only sequential mode expects a new request.
seekable_http_reader.seek(SeekFrom::Start(4)).unwrap();
if matches!(access_pattern, AccessPattern::SequentialIsh) {
server.expect(get_range_expectation(4, 12));
}
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "4567");
if matches!(access_pattern, AccessPattern::SequentialIsh) {
server.verify_and_clear();
}
if matches!(access_pattern, AccessPattern::SequentialIsh) {
seekable_http_reader.rewind().unwrap();
server.expect(get_range_expectation(0, 12));
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "0123");
server.verify_and_clear();
// Jumping forward without announcing a skip: no expectation is
// registered, so this read must not issue a new request.
seekable_http_reader.seek(SeekFrom::Start(8)).unwrap();
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "89AB");
server.verify_and_clear();
// Fresh engine with a 4-byte block size, to exercise an announced skip.
server.expect(get_head_expectation());
let seekable_http_reader_engine = SeekableHttpReaderEngine::with_configuration(
server.url("/foo").to_string(),
readahead_limit,
access_pattern,
4,
4,
)
.unwrap();
let mut seekable_http_reader = seekable_http_reader_engine.clone().create_reader();
seekable_http_reader.rewind().unwrap();
server.expect(get_range_expectation(0, 12));
seekable_http_reader.read_exact(&mut throwaway).unwrap();
assert_eq!(std::str::from_utf8(&throwaway).unwrap(), "0123");
server.verify_and_clear();
// With the skip announced and the gap above the threshold of 4, a new
// range request starting at offset 10 is expected.
seekable_http_reader_engine.read_skip_expected();
seekable_http_reader.seek(SeekFrom::Start(10)).unwrap();
server.expect(get_range_expectation(10, 12));
seekable_http_reader
.read_exact(&mut throwaway[0..2])
.unwrap();
assert_eq!(std::str::from_utf8(&throwaway[0..2]).unwrap(), "AB");
server.verify_and_clear();
}
}
}