use std::{
fmt::{Debug, Display},
num::NonZero,
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use async_trait::async_trait;
use cached::stores::ExpiringSizedCache;
use futures::stream::BoxStream;
use governor::{DefaultDirectRateLimiter, Jitter, Quota, RateLimiter};
use object_store::{
CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore,
PutMultipartOptions, PutOptions, PutPayload, PutResult, path::Path,
};
use tracing::{debug, error, instrument, warn};
use crate::Result;
/// Upper bound of the jitter applied when the limiter has no explicit jitter
/// configured: zero, i.e. no extra delay.
const DEFAULT_JITTER: Duration = Duration::ZERO;
/// Object store wrapper that keeps one rate limiter per object location and
/// consults it before forwarding put requests to the wrapped store.
///
/// Cloning is shallow with respect to the limiter cache: `entries` sits behind
/// an `Arc`, so every clone shares the same per-location limiters.
#[derive(Clone)]
pub(crate) struct PutRateLimiter<O> {
    /// The wrapped store that every operation is ultimately delegated to.
    object_store: O,

    /// Rate used to build each location's quota (`Quota::per_second`).
    /// `None` means no limiter is ever created, so puts are not delayed.
    rate_per_second: Option<NonZero<u32>>,

    /// Upper bound of the jitter added while waiting on a limiter;
    /// `DEFAULT_JITTER` is used when this is `None`.
    jitter: Option<Duration>,

    /// Expiring cache of limiters keyed by object location. The TTL is the
    /// one supplied to `new`.
    entries: Arc<Mutex<ExpiringSizedCache<Path, Arc<DefaultDirectRateLimiter>>>>,
}
/// Diagnostic formatting: prints the limiter configuration.
///
/// The wrapped store and the limiter cache are deliberately left out (`O` is
/// not required to be `Debug`), which is signalled by the trailing `..`.
impl<O> Debug for PutRateLimiter<O> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PutRateLimiter")
            .field("rate_per_second", &self.rate_per_second)
            .field("jitter", &self.jitter)
            .finish_non_exhaustive()
    }
}
/// User-facing formatting: always the fixed type name, with no field data.
impl<O> Display for PutRateLimiter<O> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("PutRateLimiter")
    }
}
impl<O> PutRateLimiter<O> {
pub(crate) fn new(object_store: O, ttl: Duration) -> Self {
Self {
object_store,
entries: Arc::new(Mutex::new(ExpiringSizedCache::new(ttl))),
rate_per_second: Default::default(),
jitter: Default::default(),
}
}
pub(crate) fn with_rate_per_second(self, rate_per_second: Option<NonZero<u32>>) -> Self {
Self {
rate_per_second,
..self
}
}
pub(crate) fn with_jitter(self, jitter: Option<Duration>) -> Self {
Self { jitter, ..self }
}
fn rate_limiter(&self) -> Option<Arc<DefaultDirectRateLimiter>> {
self.rate_per_second
.map(Quota::per_second)
.map(RateLimiter::direct)
.map(Arc::new)
}
#[instrument(skip_all, fields(location = %location))]
fn location_rate_limiter(&self, location: &Path) -> Option<Arc<DefaultDirectRateLimiter>> {
self.entries.lock().ok().and_then(|mut entries| {
entries
.get(location)
.cloned()
.or_else(|| self.rate_limiter())
.and_then(|rate_limiter| {
entries
.insert_evict(location.to_owned(), rate_limiter.clone(), true)
.ok()
.map(|_| rate_limiter)
})
})
}
#[instrument(skip_all, fields(location = %location))]
async fn rate_limit(&self, location: &Path) {
if let Some(rate_per_second) = self.rate_per_second
&& let Some(rate_limiter) = self.location_rate_limiter(location)
{
let rate_limit_start = SystemTime::now();
_ = rate_limiter
.until_n_ready_with_jitter(
rate_per_second,
Jitter::up_to(self.jitter.unwrap_or(DEFAULT_JITTER)),
)
.await
.inspect_err(|err| error!(%location, ?err))
.inspect(|_| {
let rate_limited_ms = rate_limit_start
.elapsed()
.map_or(0, |duration| duration.as_millis() as u64);
debug!(rate_limited_ms);
})
.ok();
} else {
warn!("no_rate_limit");
}
}
}
/// Delegates every operation to the wrapped store. Only the put entry points
/// (`put_opts`, `put_multipart_opts`) wait on the per-location rate limiter
/// before delegating; reads, lists, deletes and copies pass straight through.
#[async_trait]
impl<O> ObjectStore for PutRateLimiter<O>
where
    O: ObjectStore,
{
    /// Waits on the limiter for `location`, then forwards the put.
    #[instrument(skip_all, fields(location = %location))]
    async fn put_opts(
        &self,
        location: &Path,
        payload: PutPayload,
        opts: PutOptions,
    ) -> Result<PutResult, object_store::Error> {
        self.rate_limit(location).await;
        self.object_store.put_opts(location, payload, opts).await
    }

    /// Waits on the limiter for `location` once, when the upload is created,
    /// then forwards the request.
    #[instrument(skip_all, fields(location = %location))]
    async fn put_multipart_opts(
        &self,
        location: &Path,
        opts: PutMultipartOptions,
    ) -> Result<Box<dyn MultipartUpload>, object_store::Error> {
        self.rate_limit(location).await;
        self.object_store.put_multipart_opts(location, opts).await
    }

    /// Forwards the read unchanged; reads are not rate limited.
    #[instrument(skip_all, fields(%location, if_none_match = options.if_none_match), ret)]
    async fn get_opts(
        &self,
        location: &Path,
        options: GetOptions,
    ) -> Result<GetResult, object_store::Error> {
        // `options` is owned and not used afterwards, so it is moved rather
        // than cloned.
        self.object_store.get_opts(location, options).await
    }

    /// Forwards the delete stream unchanged.
    fn delete_stream(
        &self,
        locations: BoxStream<'static, Result<Path, object_store::Error>>,
    ) -> BoxStream<'static, Result<Path, object_store::Error>> {
        self.object_store.delete_stream(locations)
    }

    /// Forwards the listing unchanged.
    // A bare `fields(prefix)` only declares an empty field that is never
    // filled in; the value has to be named explicitly to be recorded.
    #[instrument(skip_all, fields(prefix = ?prefix))]
    fn list(
        &self,
        prefix: Option<&Path>,
    ) -> BoxStream<'static, Result<ObjectMeta, object_store::Error>> {
        self.object_store.list(prefix)
    }

    /// Forwards the delimited listing unchanged.
    // See `list` for why the field value is spelled out.
    #[instrument(skip_all, fields(prefix = ?prefix))]
    async fn list_with_delimiter(
        &self,
        prefix: Option<&Path>,
    ) -> Result<ListResult, object_store::Error> {
        self.object_store.list_with_delimiter(prefix).await
    }

    /// Forwards the copy unchanged; copies are not rate limited.
    #[instrument(skip_all, fields(from = %from, to = %to))]
    async fn copy_opts(
        &self,
        from: &Path,
        to: &Path,
        opts: CopyOptions,
    ) -> Result<(), object_store::Error> {
        self.object_store.copy_opts(from, to, opts).await
    }
}
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Instant};

    use bytes::Bytes;
    use object_store::memory::InMemory;
    use tracing::subscriber::DefaultGuard;
    use tracing_subscriber::EnvFilter;

    use crate::Error;

    use super::*;

    /// Installs a debug-level subscriber for this crate that writes to
    /// `../logs/<package>/<thread name>.log`, active until the returned guard
    /// is dropped.
    ///
    /// Fails when the current thread has no name or the log file cannot be
    /// created.
    fn init_tracing() -> Result<DefaultGuard> {
        use std::{fs::File, sync::Arc, thread};

        Ok(tracing::subscriber::set_default(
            tracing_subscriber::fmt()
                .with_level(true)
                .with_line_number(true)
                .with_thread_names(false)
                .with_env_filter(EnvFilter::from_default_env().add_directive(
                    format!("{}=debug", env!("CARGO_PKG_NAME").replace('-', "_")).parse()?,
                ))
                .with_writer(
                    thread::current()
                        .name()
                        // Build the error only when the thread really is unnamed.
                        .ok_or_else(|| Error::Message(String::from("unnamed thread")))
                        .and_then(|name| {
                            File::create(format!("../logs/{}/{name}.log", env!("CARGO_PKG_NAME"),))
                                .map_err(Into::into)
                        })
                        .map(Arc::new)?,
                )
                .finish(),
        ))
    }

    /// Puts a small fixed payload at `location` and returns how long the call
    /// took, measured on the monotonic clock.
    async fn timed_put(limiter: &PutRateLimiter<InMemory>, location: &Path) -> Result<Duration> {
        let started = Instant::now();

        _ = limiter
            .put_opts(
                location,
                PutPayload::from(Bytes::from_static(b"12321")),
                PutOptions::default(),
            )
            .await?;

        Ok(started.elapsed())
    }

    /// With a rate of one put per second, the first put to a location is
    /// admitted immediately, an immediate second put to the same location is
    /// delayed, and a different location has its own independent limiter.
    #[tokio::test]
    async fn test() -> Result<()> {
        let _guard = init_tracing()?;

        const EXPECTED_DELAY: Duration = Duration::from_millis(900);

        let prl = PutRateLimiter::new(InMemory::new(), Duration::from_mins(5))
            .with_rate_per_second(NonZeroU32::new(1));

        let a = Path::from("a");
        let b = Path::from("b");

        // First put to "a": nothing to wait for.
        let delay = timed_put(&prl, &a).await?;
        assert!(delay < EXPECTED_DELAY, "{delay:?}");

        // Second put to "a" straight away: throttled.
        let delay = timed_put(&prl, &a).await?;
        assert!(delay >= EXPECTED_DELAY, "{delay:?}");

        // First put to "b": unaffected by the traffic on "a".
        let delay = timed_put(&prl, &b).await?;
        assert!(delay < EXPECTED_DELAY, "{delay:?}");

        // Back to "a" right after its previous put: throttled again.
        let delay = timed_put(&prl, &a).await?;
        assert!(delay >= EXPECTED_DELAY, "{delay:?}");

        Ok(())
    }
}