use std::time::Instant;
use std::{ops::Range, sync::Arc};
use futures::task::{Spawn, SpawnExt};
use futures::{Stream, StreamExt, pin_mut};
use mountpoint_s3_client::ObjectClient;
use mountpoint_s3_client::types::GetBodyPart;
use tracing::{Instrument, debug_span, trace, warn};
use crate::async_util::Runtime;
use crate::checksums::ChecksummedBytes;
use crate::data_cache::{BlockIndex, DataCache};
use crate::mem_limiter::MemoryLimiter;
use crate::object::ObjectId;
use crate::prefetch::backpressure_controller::ReadWindowAlignmentConfig;
use super::PrefetchReadError;
use super::backpressure_controller::{BackpressureConfig, BackpressureLimiter, new_backpressure_controller};
use super::part::{Part, PartSource};
use super::part_queue::{PartQueueProducer, unbounded_part_queue};
use super::part_stream::{
ObjectPartStream, RequestRange, RequestReaderOutput, RequestTaskConfig, read_from_client_stream,
};
use super::task::RequestTask;
/// An [`ObjectPartStream`] that serves object data from a block cache.
///
/// Requests read cache blocks in order; from the first block that is missing (or fails to read)
/// onward, data is fetched through `client` instead.
#[derive(Debug)]
pub struct CachingPartStream<Cache, Client>
where
    Client: ObjectClient + Clone + Send + Sync + 'static,
{
    /// Block cache shared with every request spawned by this stream.
    cache: Arc<Cache>,
    /// Runtime on which request tasks are spawned.
    runtime: Runtime,
    /// Client used for blocks that cannot be served from `cache`.
    client: Client,
    /// Memory limiter shared with the backpressure controller and part queues.
    mem_limiter: Arc<MemoryLimiter>,
}
impl<Cache, Client> CachingPartStream<Cache, Client>
where
    Client: ObjectClient + Clone + Send + Sync + 'static,
{
    /// Create a caching part stream that takes ownership of `cache` and shares it (via `Arc`)
    /// with every request it spawns.
    pub fn new(runtime: Runtime, client: Client, mem_limiter: Arc<MemoryLimiter>, cache: Cache) -> Self {
        let cache = Arc::new(cache);
        Self { cache, runtime, client, mem_limiter }
    }
}
impl<Cache, Client> ObjectPartStream<Client> for CachingPartStream<Cache, Client>
where
Cache: DataCache + Send + Sync + 'static,
Client: ObjectClient + Clone + Send + Sync + 'static,
{
fn spawn_get_object_request(&self, config: RequestTaskConfig) -> RequestTask<Client> {
let range = config.range;
let backpressure_config = BackpressureConfig {
initial_read_window_size: config.initial_read_window_size(),
min_read_window_size: config.read_part_size,
max_read_window_size: config.max_read_window_size,
read_window_size_multiplier: config.read_window_size_multiplier,
request_range: range.into(),
read_window_alignment_config: ReadWindowAlignmentConfig::Disable, };
let (backpressure_controller, backpressure_limiter) =
new_backpressure_controller(backpressure_config, self.mem_limiter.clone());
let (part_queue, part_queue_producer) = unbounded_part_queue(self.mem_limiter.clone());
trace!(?range, "spawning request");
let request_task = {
let request = CachingRequest::new(
self.client.clone(),
self.cache.clone(),
self.runtime.clone(),
backpressure_limiter,
config,
);
let span = debug_span!("prefetch", ?range);
request.get_from_cache(range, part_queue_producer).instrument(span)
};
let task_handle = self.runtime.spawn_with_handle(request_task).unwrap();
RequestTask::from_handle(task_handle, range, part_queue, backpressure_controller)
}
fn client(&self) -> &Client {
&self.client
}
}
/// State for a single request that is served from the cache and/or the client.
#[derive(Debug)]
struct CachingRequest<Client, Cache>
where
    Client: ObjectClient,
{
    /// Client used once a cache miss (or cache read error) is hit.
    client: Client,
    /// Block cache read from, and (through the part composer) written to.
    cache: Arc<Cache>,
    /// Runtime handed to the part composer, which spawns cache writes on it.
    runtime: Runtime,
    /// Limiter awaited between cache hits and passed to the client read stream.
    backpressure_limiter: BackpressureLimiter,
    /// Configuration of the request (bucket, object id, range, sizes).
    config: RequestTaskConfig,
}
impl<Client, Cache> CachingRequest<Client, Cache>
where
    Client: ObjectClient + Clone + Send + Sync + 'static,
    Cache: DataCache + Send + Sync + 'static,
{
    /// Bundle the state needed to serve one request.
    fn new(
        client: Client,
        cache: Arc<Cache>,
        runtime: Runtime,
        backpressure_limiter: BackpressureLimiter,
        config: RequestTaskConfig,
    ) -> Self {
        Self {
            client,
            cache,
            runtime,
            backpressure_limiter,
            config,
        }
    }

    /// Serve `range` by reading cache blocks in order and pushing each hit to
    /// `part_queue_producer`.
    ///
    /// On the first cache miss or cache read error, the rest of the range (starting at that
    /// block) is fetched from the client via [`Self::get_from_client`]. An error from the
    /// backpressure limiter is pushed to the queue and ends the request.
    async fn get_from_cache(
        mut self,
        range: RequestRange,
        part_queue_producer: PartQueueProducer<Client::ClientError>,
    ) {
        let cache_key = &self.config.object_id;
        let block_size = self.cache.block_size();
        let block_range = self.block_indices_for_byte_range(&range);

        // Issue cache reads in order and stop at the first miss; the client then serves
        // everything from that block to the end of the range.
        let mut block_offset = block_range.start * block_size;
        for block_index in block_range.clone() {
            match self
                .cache
                .get_block(cache_key, block_index, block_offset, range.object_size())
                .await
            {
                Ok(Some(block)) => {
                    trace!(?cache_key, ?range, block_index, "cache hit");
                    // `block_range` only contains blocks intersecting `range`, so a cached block
                    // is expected to yield a non-empty part.
                    let part = try_make_part(&block, block_offset, cache_key, &range, PartSource::Cache)
                        .expect("cached block should overlap the requested range");
                    part_queue_producer.push(Ok(part));
                    block_offset += block_size;

                    if let Err(e) = self
                        .backpressure_limiter
                        .wait_for_read_window_increment(block_offset)
                        .await
                    {
                        part_queue_producer.push(Err(e));
                        break;
                    }
                    continue;
                }
                Ok(None) => trace!(?cache_key, block_index, ?range, "cache miss - no data for block"),
                Err(error) => warn!(
                    ?cache_key,
                    block_index,
                    ?range,
                    ?error,
                    "error reading block from cache, falling back to S3",
                ),
            }

            // Cache miss or cache error: fetch the remainder of the range from the client.
            return self
                .get_from_client(
                    range.trim_start(block_offset),
                    block_index..block_range.end,
                    part_queue_producer,
                )
                .await;
        }
    }

    /// Fetch the blocks in `block_range` from the client and forward the bytes that fall
    /// inside `range` to `part_queue_producer`, writing each assembled block to the cache.
    ///
    /// The client request is widened to whole blocks (clamped to the object size) so that
    /// complete blocks can be cached.
    async fn get_from_client(
        &mut self,
        range: RequestRange,
        block_range: Range<u64>,
        part_queue_producer: PartQueueProducer<Client::ClientError>,
    ) {
        let bucket = &self.config.bucket;
        let cache_key = &self.config.object_id;
        let initial_request_end_offset = self.config.range.start() + self.config.initial_request_size as u64;
        // A non-zero block size is asserted in `block_indices_for_byte_range`, which always runs
        // before this method.
        let block_size = self.cache.block_size();

        let start_offset = block_range.start * block_size;
        let end_offset = (block_range.end * block_size).min(range.object_size() as u64);
        let request_len = (end_offset - start_offset) as usize;
        let block_aligned_byte_range = RequestRange::new(range.object_size(), start_offset, request_len);

        trace!(
            key = cache_key.key(),
            range =? block_aligned_byte_range,
            original_range =? range,
            "fetching data from client"
        );

        let request_stream = read_from_client_stream(
            &mut self.backpressure_limiter,
            &self.client,
            bucket.clone(),
            cache_key.clone(),
            initial_request_end_offset,
            block_aligned_byte_range,
            self.config.handle_id,
        );

        let mut part_composer = CachingPartComposer {
            part_queue_producer,
            cache_key: cache_key.clone(),
            original_range: range,
            block_index: block_range.start,
            block_offset: start_offset,
            cache: self.cache.clone(),
            runtime: self.runtime.clone(),
        };
        part_composer.try_compose_parts(request_stream, range).await;
    }

    /// Indices of the cache blocks that intersect `range`; empty for an empty range.
    ///
    /// # Panics
    ///
    /// Panics if the cache's block size is zero.
    fn block_indices_for_byte_range(&self, range: &RequestRange) -> Range<BlockIndex> {
        let block_size = self.cache.block_size();
        // Check the precondition here, before the first division by `block_size`.
        assert!(block_size > 0, "cache block size must be non-zero");
        let start_block = range.start() / block_size;
        let mut end_block = range.end() / block_size;
        // Round the end up so that a trailing partial block is included.
        if !range.is_empty() && !range.end().is_multiple_of(block_size) {
            end_block += 1;
        }
        start_block..end_block
    }
}
/// Re-chunks a client body stream into cache-sized blocks, forwarding in-range bytes as parts.
struct CachingPartComposer<E, Cache, Runtime>
where
    E: std::error::Error,
    Runtime: Spawn,
{
    /// Queue that parts (or the terminating error) are pushed to.
    part_queue_producer: PartQueueProducer<E>,
    /// Object whose blocks are being assembled.
    cache_key: ObjectId,
    /// Range that emitted parts are trimmed to.
    original_range: RequestRange,
    /// Index of the block currently being assembled.
    block_index: u64,
    /// Object offset at which the block currently being assembled starts.
    block_offset: u64,
    /// Cache that completed blocks are written to.
    cache: Arc<Cache>,
    /// Spawner used to run cache writes in the background.
    runtime: Runtime,
}
impl<E, Cache, Runtime> CachingPartComposer<E, Cache, Runtime>
where
E: std::error::Error + Send + Sync,
Cache: DataCache + Send + Sync + 'static,
Runtime: Spawn,
{
async fn try_compose_parts(
&mut self,
request_stream: impl Stream<Item = RequestReaderOutput<E>>,
range: RequestRange,
) {
if let Err(e) = self.compose_parts(request_stream, range).await {
trace!(error=?e, "part stream task failed");
self.part_queue_producer.push(Err(e));
}
trace!("part composer finished");
}
async fn compose_parts(
&mut self,
request_stream: impl Stream<Item = RequestReaderOutput<E>>,
range: RequestRange,
) -> Result<(), PrefetchReadError<E>> {
let key = self.cache_key.key();
let block_size = self.cache.block_size();
let mut buffer = ChecksummedBytes::default();
pin_mut!(request_stream);
while let Some(next) = request_stream.next().await {
assert!(
buffer.len() < block_size as usize,
"buffer should be flushed when we get a full block"
);
let GetBodyPart { offset, data: mut body } = next?;
let expected_offset = self.block_offset + buffer.len() as u64;
if offset != expected_offset {
warn!(key, offset, expected_offset, "wrong offset for GetObject body part");
return Err(PrefetchReadError::GetRequestReturnedWrongOffset {
offset,
expected_offset,
});
}
let mut offset = offset;
while !body.is_empty() {
let remaining = (block_size as usize).saturating_sub(buffer.len()).min(body.len());
let chunk: ChecksummedBytes = body.split_to(remaining).into();
if let Some(part) = try_make_part(&chunk, offset, &self.cache_key, &self.original_range, PartSource::S3)
{
self.part_queue_producer.push(Ok(part));
}
offset += chunk.len() as u64;
buffer
.extend(chunk)
.inspect_err(|e| warn!(key, error=?e, "integrity check for body part failed"))?;
if buffer.len() < block_size as usize {
break;
}
self.update_cache(buffer, self.block_index, self.block_offset, &self.cache_key, range);
self.block_index += 1;
self.block_offset += block_size;
buffer = ChecksummedBytes::default();
}
}
if !buffer.is_empty() {
assert_eq!(
self.block_offset as usize + buffer.len(),
self.original_range.object_size(),
"a partial block is only allowed at the end of the object"
);
self.update_cache(buffer, self.block_index, self.block_offset, &self.cache_key, range);
}
Ok(())
}
fn update_cache(
&self,
block: ChecksummedBytes,
block_index: u64,
block_offset: u64,
object_id: &ObjectId,
range: RequestRange,
) {
let object_id = object_id.clone();
let cache = self.cache.clone();
self.runtime
.spawn(async move {
let start = Instant::now();
if let Err(error) = cache
.put_block(object_id.clone(), block_index, block_offset, block, range.object_size())
.await
{
warn!(key=?object_id, block_index, ?error, "failed to update cache");
}
metrics::histogram!("prefetch.cache_update_duration_us").record(start.elapsed().as_micros() as f64);
})
.unwrap();
}
}
/// Build a [`Part`] from the portion of `bytes` that lies inside `range`, where `bytes` starts at
/// object offset `offset`.
///
/// Returns `None` when `bytes` and `range` do not intersect.
fn try_make_part(
    bytes: &ChecksummedBytes,
    offset: u64,
    object_id: &ObjectId,
    range: &RequestRange,
    source: PartSource,
) -> Option<Part> {
    let bytes_end = offset + bytes.len() as u64;
    let part_range = range.trim_start(offset).trim_end(bytes_end);

    if part_range.is_empty() {
        None
    } else {
        trace!(?part_range, "creating part trimmed to the request range");
        // Translate the object offsets of `part_range` into indices within `bytes`.
        let slice_start = part_range.start().saturating_sub(offset) as usize;
        let slice_end = part_range.end().saturating_sub(offset) as usize;
        let data = bytes.slice(slice_start..slice_end);
        Some(Part::new(object_id.clone(), part_range.start(), data, source))
    }
}
#[cfg(test)]
mod tests {
    // `1 * MB` is kept for symmetry with the other sizes in the test cases.
    #![allow(clippy::identity_op)]

    use std::{
        thread,
        time::{Duration, Instant},
    };

    use futures::executor::{ThreadPool, block_on};
    use mountpoint_s3_client::{
        mock_client::{MockClient, MockObject, Operation},
        types::ETag,
    };
    use test_case::test_case;

    use crate::{
        data_cache::InMemoryDataCache,
        mem_limiter::{MINIMUM_MEM_LIMIT, MemoryLimiter},
        memory::PagedPool,
        object::ObjectId,
        prefetch::HandleId,
    };

    use super::*;

    const KB: usize = 1024;
    const MB: usize = 1024 * 1024;

    /// How long to wait for background cache writes before failing, so a broken cache write
    /// fails the test instead of hanging it.
    const CACHE_POPULATION_TIMEOUT: Duration = Duration::from_secs(10);

    #[test_case(1 * MB, 8 * MB, 16 * MB, 0, 16 * MB; "whole object")]
    #[test_case(1 * MB, 8 * MB, 16 * MB, 1 * MB, 3 * MB + 512 * KB; "aligned offset")]
    #[test_case(1 * MB, 8 * MB, 16 * MB, 512 * KB, 3 * MB; "non-aligned range")]
    #[test_case(3 * MB, 8 * MB, 14 * MB, 0, 14 * MB; "whole object, size not aligned to parts or blocks")]
    #[test_case(3 * MB, 8 * MB, 14 * MB, 9 * MB, 100 * MB; "aligned offset, size not aligned to parts or blocks")]
    #[test_case(1 * MB, 8 * MB, 100 * KB, 0, 100 * KB; "small object")]
    #[test_case(8 * MB, 5 * MB, 16 * MB, 0, 16 * MB; "cache blocks larger than client parts")]
    fn test_read_from_cache(
        block_size: usize,
        client_part_size: usize,
        object_size: usize,
        offset: usize,
        preferred_size: usize,
    ) {
        let key = "object";
        let seed = 0xaa;
        let object = MockObject::ramp(seed, object_size, ETag::for_tests());
        let id = ObjectId::new(key.to_owned(), object.etag());
        let handle_id = HandleId::new(1);
        let initial_request_size = 1 * MB;
        let max_read_window_size = 64 * MB;
        let read_window_size_multiplier = 2;

        let cache = InMemoryDataCache::new(block_size as u64);
        let bucket = "test-bucket";
        let mock_client = Arc::new(
            MockClient::config()
                .bucket(bucket)
                .part_size(client_part_size)
                .enable_backpressure(true)
                .initial_read_window_size(client_part_size)
                .build(),
        );
        let pool = PagedPool::new_with_candidate_sizes([block_size, client_part_size]);
        let mem_limiter = Arc::new(MemoryLimiter::new(pool, MINIMUM_MEM_LIMIT));
        mock_client.add_object(key, object.clone());

        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let stream = CachingPartStream::new(runtime, mock_client.clone(), mem_limiter.clone(), cache);
        let range = RequestRange::new(object_size, offset as u64, preferred_size);

        let expected_start_block = (range.start() as usize).div_euclid(block_size);
        let expected_end_block = (range.end() as usize).div_ceil(block_size);

        // Both reads below issue exactly the same request.
        let make_config = || RequestTaskConfig {
            bucket: bucket.to_owned(),
            object_id: id.clone(),
            handle_id,
            range,
            read_part_size: client_part_size,
            preferred_part_size: 256 * KB,
            initial_request_size,
            max_read_window_size,
            read_window_size_multiplier,
        };

        // First read: the cache is empty, so the data must come from the client.
        let first_read_count = {
            let get_object_counter = mock_client.new_counter(Operation::GetObject);
            let request_task = stream.spawn_get_object_request(make_config());
            compare_read(&id, &object, request_task);
            get_object_counter.count()
        };
        assert!(first_read_count > 0);

        // Cache writes run on spawned tasks; wait (with a deadline) for them to land.
        let expected_block_count = expected_end_block - expected_start_block;
        let deadline = Instant::now() + CACHE_POPULATION_TIMEOUT;
        while stream.cache.block_count(&id) < expected_block_count {
            assert!(
                Instant::now() < deadline,
                "timed out waiting for {expected_block_count} blocks to be cached"
            );
            thread::sleep(Duration::from_millis(10));
        }
        assert_eq!(expected_block_count, stream.cache.block_count(&id));

        // Second read: every block is cached, so the client must not be called.
        let second_read_count = {
            let get_object_counter = mock_client.new_counter(Operation::GetObject);
            let request_task = stream.spawn_get_object_request(make_config());
            compare_read(&id, &object, request_task);
            get_object_counter.count()
        };
        assert_eq!(second_read_count, 0);
    }

    #[test_case(1 * MB, 8 * MB)]
    #[test_case(8 * MB, 8 * MB)]
    #[test_case(1 * MB, 5 * MB + 1)]
    #[test_case(1 * MB + 1, 5 * MB)]
    fn test_get_object_parts(block_size: usize, client_part_size: usize) {
        let key = "object";
        let object_size = 16 * MB;
        let seed = 0xaa;
        let object = MockObject::ramp(seed, object_size, ETag::for_tests());
        let id = ObjectId::new(key.to_owned(), object.etag());
        let initial_request_size = 1 * MB;
        let max_read_window_size = 64 * MB;
        let read_window_size_multiplier = 2;

        let cache = InMemoryDataCache::new(block_size as u64);
        let bucket = "test-bucket";
        let mock_client = Arc::new(
            MockClient::config()
                .bucket(bucket)
                .part_size(client_part_size)
                .enable_backpressure(true)
                .initial_read_window_size(client_part_size)
                .build(),
        );
        let pool = PagedPool::new_with_candidate_sizes([block_size, client_part_size]);
        let mem_limiter = Arc::new(MemoryLimiter::new(pool, MINIMUM_MEM_LIMIT));
        mock_client.add_object(key, object.clone());

        let runtime = Runtime::new(ThreadPool::builder().pool_size(1).create().unwrap());
        let stream = CachingPartStream::new(runtime, mock_client, mem_limiter.clone(), cache);

        for offset in [0, 512 * KB, 1 * MB, 4 * MB, 9 * MB] {
            for preferred_size in [1 * KB, 512 * KB, 4 * MB, 12 * MB, 16 * MB] {
                let config = RequestTaskConfig {
                    bucket: bucket.to_owned(),
                    object_id: id.clone(),
                    handle_id: HandleId::new(1),
                    range: RequestRange::new(object_size, offset as u64, preferred_size),
                    read_part_size: client_part_size,
                    preferred_part_size: 256 * KB,
                    initial_request_size,
                    max_read_window_size,
                    read_window_size_multiplier,
                };
                let request_task = stream.spawn_get_object_request(config);
                compare_read(&id, &object, request_task);
            }
        }
    }

    /// Read the whole range of `request_task` and check every byte against `object`.
    fn compare_read<Client: ObjectClient>(id: &ObjectId, object: &MockObject, mut request_task: RequestTask<Client>) {
        let mut offset = request_task.start_offset();
        let mut remaining = request_task.total_size();
        while remaining > 0 {
            let part = block_on(request_task.read(remaining)).unwrap();
            let bytes = part.into_bytes(id, offset).unwrap();

            let expected = object.read(offset, bytes.len());
            let bytes = bytes.into_bytes().unwrap();
            assert_eq!(bytes, *expected);

            offset += bytes.len() as u64;
            remaining -= bytes.len();
        }
    }
}