use std::collections::HashMap;
use std::fs::File;
use std::future::Future;
use std::io::{self, Cursor, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use async_trait::async_trait;
use auto_impl::auto_impl;
use aws_sdk_s3::{
error::{GetObjectErrorKind, HeadObjectErrorKind},
types::SdkError,
};
use hashlink::LinkedHashMap;
use hyper::{body::Bytes, client::connect::Connect};
use parking_lot::Mutex;
use sha2::{Digest, Sha256};
use tempfile::tempfile;
use tokio::fs;
use tokio::io::AsyncReadExt;
use tokio::{task, time};
use tokio_stream::StreamExt;
use tokio_util::io::StreamReader;
use crate::client::FileClient;
use crate::fast_aio::file_reader;
use crate::utils::{atomic_copy, hash_path, stream_body};
use crate::{read_to_vec, BlobRange, Error, ReadStream};
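/// Interface for a storage backend that can store and retrieve blobs.
///
/// Blobs are content-addressed: [`Provider::put`] returns the hex-encoded
/// SHA-256 hash of the data, and that hash is the key used by `head` and `get`.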
#[async_trait]
#[auto_impl(&, Box, Arc)]
pub trait Provider: Send + Sync {
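/// Return the size in bytes of the blob with the given hash, or
/// [`Error::NotFound`] if it does not exist.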
async fn head(&self, hash: &str) -> Result<u64, Error>;
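/// Stream the contents of a blob, optionally restricted to a byte range.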
async fn get(&self, hash: &str, range: BlobRange) -> Result<ReadStream<'static>, Error>;
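/// Store a stream of data, returning the hex-encoded SHA-256 hash of its contents.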
async fn put(&self, data: ReadStream<'_>) -> Result<String, Error>;
}
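/// A provider that keeps every blob in an in-memory hash map.
///
/// Contents are lost when the process exits, so this is mostly useful for tests.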
#[derive(Default)]
pub struct Memory {
data: parking_lot::RwLock<HashMap<String, Bytes>>,
}
impl Memory {
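/// Create a new, empty in-memory provider.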
pub fn new() -> Self {
Default::default()
}
}
#[async_trait]
impl Provider for Memory {
async fn head(&self, hash: &str) -> Result<u64, Error> {
let data = self.data.read();
let bytes = data.get(hash).ok_or(Error::NotFound)?;
Ok(bytes.len() as u64)
}
async fn get(&self, hash: &str, range: BlobRange) -> Result<ReadStream<'static>, Error> {
check_range(range)?;
let data = self.data.read();
let mut bytes = match data.get(hash) {
Some(bytes) => bytes.clone(),
None => return Err(Error::NotFound),
};
if let Some((start, end)) = range {
if start > bytes.len() as u64 {
return Ok(empty_stream());
}
bytes = bytes.slice(start as usize..bytes.len().min(end as usize));
}
Ok(Box::pin(Cursor::new(bytes)))
}
async fn put(&self, mut data: ReadStream<'_>) -> Result<String, Error> {
let mut buf = Vec::new();
data.read_to_end(&mut buf).await?;
let hash = format!("{:x}", Sha256::new().chain_update(&buf).finalize());
self.data.write().insert(hash.clone(), Bytes::from(buf));
Ok(hash)
}
}
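/// A provider that stores blobs as objects in an Amazon S3 bucket.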
pub struct S3 {
client: aws_sdk_s3::Client,
bucket: String,
}
impl S3 {
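/// Create a new S3 provider, verifying up front that the bucket is reachable.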
pub async fn new(client: aws_sdk_s3::Client, bucket: &str) -> anyhow::Result<Self> {
client
.head_bucket()
.bucket(bucket)
.send()
.await
.with_context(|| format!("unable to create provider for S3 bucket {bucket}"))?;
Ok(Self {
client,
bucket: bucket.into(),
})
}
}
#[async_trait]
impl Provider for S3 {
async fn head(&self, hash: &str) -> Result<u64, Error> {
let key = hash_path(hash)?;
let result = self
.client
.head_object()
.bucket(&self.bucket)
.key(key)
.send()
.await;
match result {
Ok(resp) => Ok(resp.content_length() as u64),
Err(SdkError::ServiceError { err, .. })
if matches!(err.kind, HeadObjectErrorKind::NotFound(_)) =>
{
Err(Error::NotFound)
}
Err(err) => Err(Error::Internal(err.into())),
}
}
async fn get(&self, hash: &str, range: BlobRange) -> Result<ReadStream<'static>, Error> {
check_range(range)?;
if matches!(range, Some((s, e)) if s == e) {
self.head(hash).await?;
return Ok(empty_stream());
}
let key = hash_path(hash)?;
let result = self
.client
.get_object()
.bucket(&self.bucket)
.key(key)
.set_range(range.map(|(start, end)| format!("bytes={}-{}", start, end - 1)))
.send()
.await;
match result {
Ok(resp) => Ok(Box::pin(resp.body.into_async_read())),
Err(SdkError::ServiceError { err, .. })
if matches!(err.kind, GetObjectErrorKind::NoSuchKey(_)) =>
{
Err(Error::NotFound)
}
Err(SdkError::ServiceError { err, .. }) if err.code() == Some("InvalidRange") => {
Ok(empty_stream())
}
Err(err) => Err(Error::Internal(err.into())),
}
}
async fn put(&self, data: ReadStream<'_>) -> Result<String, Error> {
let (hash, file) = make_data_tempfile(data).await?;
let body = stream_body(file_reader(file, None));
self.client
.put_object()
.bucket(&self.bucket)
.key(hash_path(&hash)?)
.checksum_sha256(base64::encode(hex::decode(&hash).unwrap()))
.body(body.into())
.send()
.await
.map_err(anyhow::Error::from)?;
Ok(hash)
}
}
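/// A provider that stores blobs as files under a directory on the local file system.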
pub struct LocalDir {
dir: PathBuf,
}
impl LocalDir {
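/// Create a provider rooted at the given directory.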
pub fn new(path: impl AsRef<Path>) -> Self {
Self {
dir: path.as_ref().to_owned(),
}
}
}
#[async_trait]
impl Provider for LocalDir {
async fn head(&self, hash: &str) -> Result<u64, Error> {
let key = hash_path(hash)?;
let path = self.dir.join(key);
match fs::metadata(&path).await {
Ok(metadata) => Ok(metadata.len()),
Err(err) if err.kind() == io::ErrorKind::NotFound => Err(Error::NotFound),
Err(err) => Err(err.into()),
}
}
async fn get(&self, hash: &str, range: BlobRange) -> Result<ReadStream<'static>, Error> {
check_range(range)?;
let key = hash_path(hash)?;
let path = self.dir.join(key);
let file = match File::open(path) {
Ok(file) => file,
Err(err) if err.kind() == io::ErrorKind::NotFound => return Err(Error::NotFound),
Err(err) => return Err(err.into()),
};
Ok(file_reader(file, range))
}
async fn put(&self, data: ReadStream<'_>) -> Result<String, Error> {
let (hash, file) = make_data_tempfile(data).await?;
let key = hash_path(&hash)?;
let path = self.dir.join(key);
task::spawn_blocking(move || atomic_copy(file, path))
.await
.map_err(anyhow::Error::from)??;
Ok(hash)
}
}
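/// A provider that delegates every operation to a remote server through a
/// [`FileClient`].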
pub struct Remote<C> {
client: FileClient<C>,
}
impl<C> Remote<C> {
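/// Create a provider backed by the given [`FileClient`].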
pub fn new(client: FileClient<C>) -> Self {
Self { client }
}
}
#[async_trait]
impl<C: Connect + Clone + Send + Sync + 'static> Provider for Remote<C> {
async fn head(&self, hash: &str) -> Result<u64, Error> {
self.client.head(hash).await
}
async fn get(&self, hash: &str, range: BlobRange) -> Result<ReadStream<'static>, Error> {
self.client.get(hash, range).await
}
async fn put(&self, data: ReadStream<'_>) -> Result<String, Error> {
let (_, file) = make_data_tempfile(data).await?;
let file = Arc::new(file);
self.client
.put(|| async { Ok(stream_body(file_reader(Arc::clone(&file), None))) })
.await
}
}
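/// Read a stream into an anonymous temporary file, hashing it along the way.
///
/// Returns the hex-encoded SHA-256 hash of the data together with the file,
/// seeked back to the beginning so it can be read again.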
async fn make_data_tempfile(mut data: ReadStream<'_>) -> anyhow::Result<(String, File)> {
let mut file = task::spawn_blocking(tempfile).await??;
let mut hash = Sha256::new();
loop {
let mut buf = Vec::with_capacity(1 << 21);
let size = data.read_buf(&mut buf).await?;
if size == 0 {
break;
}
hash.update(&buf);
file = task::spawn_blocking(move || file.write_all(&buf).map(|_| file)).await??;
}
let hash = format!("{:x}", hash.finalize());
file = task::spawn_blocking(move || file.seek(SeekFrom::Start(0)).map(|_| file)).await??;
Ok((hash, file))
}
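/// Validate a byte range, rejecting ranges whose start is greater than their end.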
fn check_range(range: BlobRange) -> Result<(), Error> {
match range {
Some((start, end)) if start > end => Err(Error::BadRange),
_ => Ok(()),
}
}
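/// Return a stream that yields no data (immediate EOF).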
fn empty_stream() -> ReadStream<'static> {
Box::pin(b"" as &[u8])
}
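/// A pair of providers acts as a fallback chain: `head` and `get` try the first
/// provider and fall back to the second on any error, while `put` writes only to
/// the first provider.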
#[async_trait]
impl<P1: Provider, P2: Provider> Provider for (P1, P2) {
async fn head(&self, hash: &str) -> Result<u64, Error> {
match self.0.head(hash).await {
Ok(res) => Ok(res),
Err(_) => self.1.head(hash).await,
}
}
async fn get(&self, hash: &str, range: BlobRange) -> Result<ReadStream<'static>, Error> {
match self.0.get(hash, range).await {
Ok(res) => Ok(res),
Err(_) => self.1.get(hash, range).await,
}
}
async fn put(&self, data: ReadStream<'_>) -> Result<String, Error> {
self.0.put(data).await
}
}
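/// A provider that adds a two-level cache in front of another provider.
///
/// Blobs are read in fixed-size pages, and each page is cached both in an
/// in-memory LRU page cache and as a file under a cache directory on disk.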
pub struct Cached<P> {
state: Arc<CachedState<P>>,
}
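/// Fixed bookkeeping cost charged per page cache entry, on top of the length of
/// the cached page itself.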
const PAGE_CACHE_ENTRY_COST: u64 = 80;
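/// An in-memory LRU cache of pages, bounded by total cost rather than entry count.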
struct PageCache {
mapping: LinkedHashMap<(String, u64), Bytes>,
total_cost: u64,
total_capacity: u64,
}
impl PageCache {
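/// Insert a page, evicting the least recently used entries until the total cost
/// fits within the configured capacity.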
fn insert(&mut self, hash: String, n: u64, bytes: Bytes) {
use hashlink::linked_hash_map::Entry;
match self.mapping.entry((hash, n)) {
Entry::Occupied(mut o) => o.to_back(),
Entry::Vacant(v) => {
v.insert(bytes.clone());
self.total_cost += bytes.len() as u64 + PAGE_CACHE_ENTRY_COST;
while self.total_cost > self.total_capacity {
let (_, bytes) = self.mapping.pop_front().expect("cache with cost items");
self.total_cost -= bytes.len() as u64 + PAGE_CACHE_ENTRY_COST;
}
}
}
}
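/// Look up a page, marking it as most recently used if present.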
fn get(&mut self, hash: String, n: u64) -> Option<Bytes> {
use hashlink::linked_hash_map::Entry;
match self.mapping.entry((hash, n)) {
Entry::Occupied(mut o) => {
o.to_back();
Some(o.get().clone())
}
Entry::Vacant(_) => None,
}
}
}
impl Default for PageCache {
fn default() -> Self {
Self {
mapping: LinkedHashMap::new(),
total_cost: 0,
total_capacity: 1 << 26, // 64 MiB
}
}
}
struct CachedState<P> {
inner: P,
page_cache: Mutex<PageCache>,
dir: PathBuf,
pagesize: u64,
}
impl<P> Cached<P> {
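/// Create a cached provider wrapping `inner`, storing cache files under `dir`
/// and splitting blobs into pages of `pagesize` bytes.
///
/// Panics if `pagesize` is less than 4096.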
pub fn new(inner: P, dir: impl AsRef<Path>, pagesize: u64) -> Self {
assert!(pagesize >= 4096, "pagesize must be at least 4096");
Self {
state: Arc::new(CachedState {
inner,
page_cache: Default::default(),
dir: dir.as_ref().to_owned(),
pagesize,
}),
}
}
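/// Return a future that periodically removes random shards of the disk cache so
/// it does not grow without bound. The future never resolves, so it is intended
/// to be spawned as a background task.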
pub fn cleaner(&self) -> impl Future<Output = ()> {
let state = Arc::clone(&self.state);
async move { state.cleaner().await }
}
}
impl<P> CachedState<P> {
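/// Fetch a page through the cache hierarchy: the in-memory page cache first,
/// then the on-disk cache, and finally `func` to produce the page, populating
/// both caches with the result.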
async fn with_cache<F, Out>(&self, hash: String, n: u64, func: F) -> Result<Bytes, Error>
where
F: FnOnce() -> Out,
Out: Future<Output = Result<Bytes, Error>>,
{
if let Some(bytes) = self.page_cache.lock().get(hash.clone(), n) {
return Ok(bytes);
}
let key = hash_path(&hash)?;
let path = self.dir.join(format!("{key}/{n}"));
if let Ok(data) = fs::read(&path).await {
let bytes = Bytes::from(data);
self.page_cache
.lock()
.insert(hash.clone(), n, bytes.clone());
return Ok(bytes);
}
let bytes = func().await?;
let read_buf = Cursor::new(bytes.clone());
task::spawn_blocking(move || {
if let Err(err) = atomic_copy(read_buf, &path) {
eprintln!("error writing {path:?} cache file: {err:?}");
}
});
self.page_cache.lock().insert(hash, n, bytes.clone());
Ok(bytes)
}
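/// Background loop that, every 30 seconds, deletes one randomly chosen cache
/// subdirectory, renaming it out of the way before removal.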
async fn cleaner(&self) {
const CLEAN_INTERVAL: Duration = Duration::from_secs(30);
loop {
time::sleep(CLEAN_INTERVAL).await;
let prefix = fastrand::u16(..);
let (d1, d2) = (prefix / 256, prefix % 256);
let subfolder = self.dir.join(format!("{d1:x}/{d2:x}"));
if fs::metadata(&subfolder).await.is_ok() {
println!("cleaning cache directory: {}", subfolder.display());
let subfolder_tmp = self.dir.join(format!("{d1:x}/.tmp-{d2:x}"));
fs::remove_dir_all(&subfolder_tmp).await.ok();
if fs::rename(&subfolder, &subfolder_tmp).await.is_ok() {
fs::remove_dir_all(&subfolder_tmp).await.ok();
}
}
}
}
}
impl<P: Provider> CachedState<P> {
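/// Fetch the total size of a blob, cached as page 0 in little-endian byte order.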
async fn get_cached_size(&self, hash: &str) -> Result<u64, Error> {
let size = self
.with_cache(hash.into(), 0, || async {
let size = self.inner.head(hash).await?;
Ok(Bytes::from_iter(size.to_le_bytes()))
})
.await?;
Ok(u64::from_le_bytes(
size.as_ref().try_into().map_err(anyhow::Error::from)?,
))
}
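/// Fetch the `n`-th page of a blob's data (pages are 1-indexed; page 0 holds the
/// size), reading the corresponding byte range from the inner provider on a miss.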
async fn get_cached_chunk(&self, hash: &str, n: u64) -> Result<Bytes, Error> {
assert!(n > 0, "chunks of file data start at 1");
let lo = (n - 1) * self.pagesize;
let hi = n * self.pagesize;
self.with_cache(hash.into(), n, move || async move {
Ok(read_to_vec(self.inner.get(hash, Some((lo, hi))).await?)
.await?
.into())
})
.await
}
}
#[async_trait]
impl<P: Provider + 'static> Provider for Cached<P> {
async fn head(&self, hash: &str) -> Result<u64, Error> {
self.state.get_cached_size(hash).await
}
async fn get(&self, hash: &str, range: BlobRange) -> Result<ReadStream<'static>, Error> {
let (start, end) = range.unwrap_or((0, u64::MAX));
check_range(range)?;
if start == end {
self.head(hash).await?;
return Ok(empty_stream());
}
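// Pages of file data are 1-indexed (page 0 caches the blob size), so the page
// containing byte offset `i` is `1 + i / pagesize`.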
let chunk_begin: u64 = 1 + start / self.state.pagesize;
let chunk_end: u64 = 1 + (end - 1) / self.state.pagesize;
debug_assert!(chunk_begin >= 1);
debug_assert!(chunk_begin <= chunk_end);
let first_chunk = self.state.get_cached_chunk(hash, chunk_begin).await?;
let initial_offset = start - (chunk_begin - 1) * self.state.pagesize;
if initial_offset > first_chunk.len() as u64 {
return Ok(empty_stream());
}
let reached_end = (first_chunk.len() as u64) < self.state.pagesize;
let first_chunk = first_chunk.slice(initial_offset as usize..);
if reached_end || first_chunk.len() as u64 > end - start {
let total_len = first_chunk.len().min((end - start) as usize);
return Ok(Box::pin(Cursor::new(first_chunk.slice(..total_len))));
}
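// The request spans more than one page: stream the remaining pages lazily,
// truncating the last one to the number of bytes still requested.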
let remaining_bytes = Arc::new(Mutex::new(end - start - first_chunk.len() as u64));
let state = Arc::clone(&self.state);
let hash = hash.to_string();
let stream = tokio_stream::iter(chunk_begin..=chunk_end).then(move |chunk| {
let state = Arc::clone(&state);
let remaining_bytes = Arc::clone(&remaining_bytes);
let first_chunk = first_chunk.clone();
let hash = hash.clone();
async move {
if chunk == chunk_begin {
return Ok::<_, Error>(first_chunk);
}
let bytes = state.get_cached_chunk(&hash, chunk).await?;
let mut remaining_bytes = remaining_bytes.lock();
if bytes.len() as u64 > *remaining_bytes {
let result = bytes.slice(..*remaining_bytes as usize);
*remaining_bytes = 0;
Ok(result)
} else {
*remaining_bytes -= bytes.len() as u64;
Ok(bytes)
}
}
});
let stream = stream.take_while(|result| match result {
Ok(bytes) => !bytes.is_empty(),
Err(_) => true,
});
Ok(Box::pin(StreamReader::new(stream)))
}
async fn put(&self, data: ReadStream<'_>) -> Result<String, Error> {
self.state.inner.put(data).await
}
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use hyper::body::Bytes;
use super::{Memory, PageCache, Provider};
use crate::Error;
#[test]
fn page_cache_eviction() {
let mut cache = PageCache::default();
let bigpage = Bytes::from(vec![42; 1 << 21]);
for i in 0..4096 {
cache.insert(String::new(), i, bigpage.clone());
}
assert_eq!(cache.get(String::new(), 0), None);
assert_eq!(cache.get(String::new(), 4095), Some(bigpage));
assert!(cache.mapping.len() < 2048);
}
#[test]
fn page_cache_duplicates() {
let mut cache = PageCache::default();
let page = Bytes::from(vec![42; 256]);
for _ in 0..4096 {
cache.insert(String::new(), 0, page.clone());
}
assert_eq!(cache.get(String::new(), 0), Some(page));
assert!(cache.mapping.len() == 1);
}
#[tokio::test]
async fn fallback_provider() {
let p = (Memory::new(), Memory::new());
let hash = p
.put(Box::pin(Cursor::new(vec![42; 1 << 21])))
.await
.unwrap();
assert!(matches!(p.get(&hash, None).await, Ok(_)));
assert!(matches!(p.0.get(&hash, None).await, Ok(_)));
assert!(matches!(p.1.get(&hash, None).await, Err(Error::NotFound)));
}
}