use std::{
collections::HashSet,
sync::atomic::{AtomicUsize, Ordering},
};
use futures::{Stream, TryStream, TryStreamExt as _};
use tokio::sync::Notify;
use tracing::trace;
pub(crate) fn try_unique_stream<S, T, E, F, V>(
f: F,
stream: S,
) -> impl TryStream<Ok = T, Error = E>
where
F: Fn(&S::Ok) -> V,
S: TryStream<Ok = T, Error = E>,
V: Eq + std::hash::Hash,
{
let mut seen = HashSet::new();
stream.try_filter(move |item| {
let v = f(item);
if seen.insert(v) {
futures::future::ready(true)
} else {
futures::future::ready(false)
}
})
}
#[derive(Debug)]
pub(crate) struct StreamLimiter {
name: String,
max_usage: usize,
current_usage: AtomicUsize,
waiter: Notify,
}
impl StreamLimiter {
pub(crate) fn new(name: String, max_usage: usize ) -> Self {
Self {
name,
max_usage,
current_usage: AtomicUsize::new(0),
waiter: Notify::new(),
}
}
pub(crate) fn current_usage(&self) -> usize {
self.current_usage.load(Ordering::Relaxed)
}
pub(crate) fn limit_stream<S, T, E, F>(
&self,
s: S,
get_usage: F,
) -> impl Stream<Item = Result<T, E>>
where
S: TryStream<Ok = T, Error = E>,
F: Fn(&T) -> usize + Copy,
{
s.and_then(move |elem| async move {
let usage = get_usage(&elem);
self.in_use(usage).await;
Ok(elem)
})
}
pub(crate) fn unlimit_stream<S, T, E, F>(
&self,
s: S,
get_usage: F,
) -> impl Stream<Item = Result<T, E>>
where
S: TryStream<Ok = T, Error = E>,
F: Fn(&T) -> usize + Copy,
{
s.and_then(move |elem| async move {
let usage = get_usage(&elem);
self.return_used(usage).await;
Ok(elem)
})
}
async fn in_use(&self, usage: usize) {
loop {
match self.current_usage.fetch_update(
Ordering::Relaxed,
Ordering::Relaxed,
|current| {
if current == 0 || current + usage <= self.max_usage {
Some(current + usage)
} else {
None
}
},
) {
Ok(old_usage) => {
trace!(
item_usage = usage,
current_usage = old_usage,
limiter = self.name,
"Stream rate limiter: Allowing item"
);
break;
}
Err(old_usage) => {
trace!(
item_usage = usage,
current_usage = old_usage,
limiter = self.name,
"Stream rate limiter: Blocking item"
);
self.waiter.notified().await;
}
}
}
}
async fn return_used(&self, usage: usize) {
let old_usage = self.current_usage.fetch_sub(usage, Ordering::Relaxed);
trace!(
item_usage = usage,
current_usage = old_usage,
limiter = self.name,
"Stream rate limiter: Freeing item"
);
self.waiter.notify_waiters();
}
}
#[cfg(test)]
mod tests {
use std::{convert::Infallible, error::Error, future::ready, sync::Arc};
use futures::{FutureExt as _, StreamExt as _, TryStreamExt as _, stream};
use icechunk_macros::tokio_test;
use tokio::pin;
use super::*;
#[tokio_test]
async fn test_limiter_lets_everything_pass_if_under_limit()
-> Result<(), Box<dyn Error>> {
let limiter = &Arc::new(StreamLimiter::new("test".to_string(), 1_000_000));
let stream = stream::iter(1..=100).map(Ok::<_, Infallible>);
let res = limiter.limit_stream(stream, |n| *n).try_collect::<Vec<_>>().await?;
assert_eq!(res, (1..=100).collect::<Vec<_>>());
Ok(())
}
#[tokio_test]
async fn test_limiter_limits() -> Result<(), Box<dyn Error>> {
let limiter = &Arc::new(StreamLimiter::new("test".to_string(), 3));
let stream = stream::repeat(1).map(Ok::<_, Infallible>);
let stream = limiter.limit_stream(stream, |n| *n);
pin!(stream);
assert_eq!(stream.try_next().now_or_never(), Some(Ok(Some(1))));
assert_eq!(stream.try_next().now_or_never(), Some(Ok(Some(1))));
assert_eq!(stream.try_next().now_or_never(), Some(Ok(Some(1))));
assert_eq!(stream.try_next().now_or_never(), None);
Ok(())
}
#[tokio_test]
async fn test_limiter_frees() -> Result<(), Box<dyn Error>> {
let limiter = &Arc::new(StreamLimiter::new("test".to_string(), 3));
let stream = stream::repeat(1).map(Ok::<_, Infallible>);
let stream = limiter.limit_stream(stream, |n| *n);
let stream = stream.map_ok(|n| n + 1);
let res = limiter
.unlimit_stream(stream, |n| n - 1)
.take(100)
.try_fold(0, |total, partial| ready(Ok(total + partial)))
.await?;
assert_eq!(res, 2 * 100);
assert_eq!(limiter.current_usage(), 0);
Ok(())
}
}