use std::fmt;
use std::io;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use futures::{
channel::oneshot,
future::{self, FutureExt, TryFutureExt},
stream::{StreamExt, TryStreamExt},
};
use tokio::{self, sync::Mutex};
use crate::lru;
use super::{LFSObject, Storage, StorageKey, StorageStream};
type Cache = lru::Cache<StorageKey>;
/// Error produced by the caching [`Backend`].
///
/// `C` is the cache's error type and `S` is the permanent storage's error type.
#[derive(Debug)]
pub enum Error<C, S> {
    /// The cache layer failed.
    Cache(C),
    /// The permanent storage layer failed.
    Storage(S),
    /// Driving an object's byte stream failed.
    Stream(io::Error),
}

impl<C, S> fmt::Display for Error<C, S>
where
    C: fmt::Display,
    S: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Select the wrapped error, then delegate to it with the same
        // formatter so flags such as width/alignment are forwarded untouched.
        let inner: &dyn fmt::Display = match self {
            Self::Cache(err) => err,
            Self::Storage(err) => err,
            Self::Stream(err) => err,
        };
        fmt::Display::fmt(inner, f)
    }
}

impl<C, S> Error<C, S> {
    /// Wraps an error coming from the cache layer.
    pub fn from_cache(error: C) -> Self {
        Self::Cache(error)
    }

    /// Wraps an error coming from the permanent storage layer.
    pub fn from_storage(error: S) -> Self {
        Self::Storage(error)
    }

    /// Wraps an I/O error raised while streaming an object.
    pub fn from_stream(error: io::Error) -> Self {
        Self::Stream(error)
    }
}

impl<C, S> std::error::Error for Error<C, S>
where
    C: fmt::Debug + fmt::Display,
    S: fmt::Debug + fmt::Display,
{
}
/// Storage backend that serves objects from `cache` when possible and falls
/// back to `storage` (the permanent store) otherwise.
pub struct Backend<C, S> {
    /// Recency-ordered index of the keys held in `cache`, with their sizes.
    lru: Arc<Mutex<Cache>>,
    /// Size limit for the cache; `0` means unbounded (pruning is skipped).
    max_size: u64,
    /// The cache layer.
    cache: Arc<C>,
    /// The permanent storage layer.
    storage: Arc<S>,
}

impl<C, S> Backend<C, S>
where
    C: Storage + Send + Sync,
    S: Storage,
{
    /// Creates a backend whose LRU index is seeded from the entries `cache`
    /// already lists, then prunes the cache down to `max_size`.
    ///
    /// # Errors
    ///
    /// Returns the cache's error if listing or pruning the cache fails.
    pub async fn new(
        max_size: u64,
        cache: C,
        storage: S,
    ) -> Result<Self, C::Error> {
        let index = Cache::from_stream(cache.list()).await?;
        log::info!(
            "Prepopulated cache with {} entries ({})",
            index.len(),
            humansize::format_size(index.size(), humansize::DECIMAL),
        );

        let shared_index = Arc::new(Mutex::new(index));
        let shared_cache = Arc::new(cache);

        let pruned = prune_cache(
            Arc::clone(&shared_index),
            max_size,
            Arc::clone(&shared_cache),
        )
        .await?;
        if pruned > 0 {
            log::info!("Pruned {} entries from the cache", pruned);
        }

        Ok(Self {
            lru: shared_index,
            max_size,
            cache: shared_cache,
            storage: Arc::new(storage),
        })
    }
}
/// Evicts least-recently-used entries from `storage` until the total size
/// tracked by `lru` is at most `max_size`.
///
/// A `max_size` of `0` means the cache is unbounded, and nothing is pruned.
/// The LRU lock is held for the whole pass. Eviction is best-effort: an entry
/// is removed from the LRU even if deleting it from `storage` fails. Such a
/// failure is logged, and the object is left in `storage` untracked.
///
/// Returns the number of entries evicted from the LRU.
///
/// # Errors
///
/// Currently never returns `Err`. The `Result` is kept so callers do not have
/// to change if deletion failures are ever made fatal.
async fn prune_cache<S>(
    lru: Arc<Mutex<Cache>>,
    max_size: u64,
    storage: Arc<S>,
) -> Result<usize, S::Error>
where
    S: Storage + Send + Sync,
{
    if max_size == 0 {
        return Ok(0);
    }

    let mut deleted = 0;
    let mut lru = lru.lock().await;

    while lru.size() > max_size {
        // If the index reports a size over the limit but has nothing left to
        // pop, its bookkeeping is inconsistent. Bail out rather than spinning
        // forever while holding the lock.
        let key = match lru.pop() {
            Some((key, _)) => key,
            None => {
                log::warn!(
                    "Cache index is over its size limit but has no entries left to prune"
                );
                break;
            }
        };

        log::debug!("Pruning '{}' from cache", key);
        if let Err(err) = storage.delete(&key).await {
            log::warn!("Failed to delete '{}' from cache ({})", key, err);
        }
        deleted += 1;
    }

    Ok(deleted)
}
/// Writes `obj` into `cache` under `key`, records it in the LRU index, and
/// then prunes the cache back down to `max_size`.
///
/// # Errors
///
/// Returns the cache's error if storing the object fails, or if pruning
/// fails. A pruning failure is also logged here.
async fn cache_and_prune<C>(
    cache: Arc<C>,
    key: StorageKey,
    obj: LFSObject,
    lru: Arc<Mutex<Cache>>,
    max_size: u64,
) -> Result<(), C::Error>
where
    C: Storage + Send + Sync,
{
    // Capture what is needed later before `obj` and `key` are moved away.
    let object_len = obj.len();
    let oid = *key.oid();

    log::debug!("Caching {}", oid);
    cache.put(key.clone(), obj).await?;
    log::debug!("Finished caching {}", oid);

    // The guard is a temporary, so the lock is released at the end of this
    // statement, before `prune_cache` acquires it again.
    lru.lock().await.push(key, object_len);

    let pruned = prune_cache(lru, max_size, cache).await.map_err(|err| {
        log::error!("Error caching {} ({})", oid, err);
        err
    })?;
    if pruned > 0 {
        log::info!("Pruned {} entries from the cache", pruned);
    }
    Ok(())
}
#[async_trait]
impl<C, S> Storage for Backend<C, S>
where
S: Storage + Send + Sync + 'static,
S::Error: 'static,
C: Storage + Send + Sync + 'static,
C::Error: 'static,
{
type Error = Error<C::Error, S::Error>;
async fn get(
&self,
key: &StorageKey,
) -> Result<Option<LFSObject>, Self::Error> {
if self.lru.lock().await.get_refresh(key).is_some() {
let obj = self.cache.get(key).await.map_err(Error::from_cache)?;
return match obj {
Some(obj) => Ok(Some(obj)),
None => {
let mut lru = self.lru.lock().await;
lru.remove(key);
self.storage.get(key).await.map_err(Error::from_storage)
}
};
}
let lru = self.lru.clone();
let max_size = self.max_size;
let cache = self.cache.clone();
let key = key.clone();
let obj = self.storage.get(&key).await.map_err(Error::from_storage)?;
match obj {
Some(obj) => {
let (f, a, b) = obj.fanout();
let cache =
cache_and_prune(cache, key.clone(), b, lru, max_size)
.map_err(Error::from_cache);
tokio::spawn(
future::try_join(f.map_err(Error::from_stream), cache)
.map_ok(|((), ())| ())
.map_err(move |err: Self::Error| {
log::error!("Error caching {} ({})", key, err);
}),
);
Ok(Some(a))
}
None => {
Ok(None)
}
}
}
async fn put(
&self,
key: StorageKey,
value: LFSObject,
) -> Result<(), Self::Error> {
let lru = self.lru.clone();
let max_size = self.max_size;
let cache = self.cache.clone();
let (f, a, b) = value.fanout();
let (signal_sender, signal_receiver) = oneshot::channel();
let store = self
.storage
.put(key.clone(), a)
.map_ok(move |()| {
log::debug!("Received last chunk from server.");
signal_sender.send(()).unwrap_or(())
})
.map_err(Error::from_storage);
let (len, stream) = b.into_parts();
let stream = stream.chain(
signal_receiver
.map_ok(|()| Bytes::new())
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
.into_stream(),
);
let cache = cache_and_prune(
cache,
key,
LFSObject::new(len, Box::pin(stream)),
lru,
max_size,
)
.map_err(Error::from_cache);
future::try_join3(f.map_err(Error::from_stream), cache, store).await?;
Ok(())
}
async fn size(&self, key: &StorageKey) -> Result<Option<u64>, Self::Error> {
let lru = self.lru.lock().await;
if let Some(size) = lru.get(key) {
Ok(Some(size))
} else {
self.storage.size(key).await.map_err(Error::from_storage)
}
}
async fn delete(&self, key: &StorageKey) -> Result<(), Self::Error> {
log::info!("Deleted {} from the cache", key);
self.cache.delete(key).await.map_err(Error::from_cache)
}
fn list(&self) -> StorageStream<(StorageKey, u64), Self::Error> {
Box::pin(self.cache.list().map_err(Error::from_cache))
}
async fn total_size(&self) -> Option<u64> {
Some(self.lru.lock().await.size())
}
async fn max_size(&self) -> Option<u64> {
if self.max_size == 0 {
None
} else {
Some(self.max_size)
}
}
fn public_url(&self, key: &StorageKey) -> Option<String> {
self.storage.public_url(key)
}
async fn upload_url(
&self,
key: &StorageKey,
expires_in: Duration,
) -> Option<String> {
self.storage.upload_url(key, expires_in).await
}
}