use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures_util::Stream;
use crate::error::Error;
use crate::rate::bandwidth::TokenBucket;
pub struct ThrottledStream<S> {
inner: S,
limiter: Arc<TokenBucket>,
pending_chunk: Option<Bytes>,
}
impl<S> ThrottledStream<S>
where
S: Stream<Item = std::result::Result<Bytes, Box<dyn std::error::Error + Send>>> + Unpin,
{
pub fn new(inner: S, bytes_per_second: u64) -> Self {
let capacity = bytes_per_second;
let refill_rate = bytes_per_second;
Self {
inner,
limiter: Arc::new(TokenBucket::new(capacity, refill_rate)),
pending_chunk: None,
}
}
pub fn with_bucket(inner: S, bucket: Arc<TokenBucket>) -> Self {
Self {
inner,
limiter: bucket,
pending_chunk: None,
}
}
}
impl<S> Stream for ThrottledStream<S>
where
S: Stream<Item = std::result::Result<Bytes, Box<dyn std::error::Error + Send>>> + Unpin,
{
type Item = std::result::Result<Bytes, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if let Some(chunk) = this.pending_chunk.take() {
let chunk_size = chunk.len();
let limiter = Arc::clone(&this.limiter);
if limiter.try_acquire(chunk_size) {
return Poll::Ready(Some(Ok(chunk)));
} else {
this.pending_chunk = Some(chunk.clone());
let limiter_clone = Arc::clone(&this.limiter);
let waker = cx.waker().clone();
tokio::spawn(async move {
limiter_clone.acquire(chunk_size).await;
waker.wake();
});
return Poll::Pending;
}
}
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Ready(Some(Ok(chunk))) => {
this.pending_chunk = Some(chunk.clone());
let chunk_size = chunk.len();
let limiter = Arc::clone(&this.limiter);
if limiter.try_acquire(chunk_size) {
this.pending_chunk = None;
Poll::Ready(Some(Ok(chunk)))
} else {
let limiter_clone = Arc::clone(&this.limiter);
let waker = cx.waker().clone();
tokio::spawn(async move {
limiter_clone.acquire(chunk_size).await;
waker.wake();
});
Poll::Pending
}
}
Poll::Ready(Some(Err(e))) => {
Poll::Ready(Some(Err(Error::Network(e.to_string()))))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
pub struct AsyncThrottledStream<S> {
inner: S,
limiter: Arc<TokenBucket>,
pending_chunk: Option<Bytes>,
}
impl<S> AsyncThrottledStream<S>
where
S: Stream<Item = std::result::Result<Bytes, Box<dyn std::error::Error + Send>>> + Unpin,
{
pub fn new(inner: S, bytes_per_second: u64) -> Self {
let capacity = bytes_per_second;
let refill_rate = bytes_per_second;
Self {
inner,
limiter: Arc::new(TokenBucket::new(capacity, refill_rate)),
pending_chunk: None,
}
}
}
impl<S> Stream for AsyncThrottledStream<S>
where
S: Stream<Item = std::result::Result<Bytes, Box<dyn std::error::Error + Send>>> + Unpin,
{
type Item = std::result::Result<Bytes, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if let Some(chunk) = this.pending_chunk.take() {
return Poll::Ready(Some(Ok(chunk)));
}
match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Ready(Some(Ok(chunk))) => {
this.pending_chunk = Some(chunk.clone());
let limiter = Arc::clone(&this.limiter);
let chunk_size = chunk.len();
let waker = cx.waker().clone();
tokio::spawn(async move {
limiter.acquire(chunk_size).await;
waker.wake();
});
Poll::Pending
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(Error::Network(e.to_string())))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use futures_util::stream;

    /// Item type produced by the source streams these adapters wrap.
    type SourceItem = std::result::Result<Bytes, Box<dyn std::error::Error + Send>>;

    /// Builds an in-memory source stream of zero-filled chunks with the given sizes.
    fn zero_chunks(sizes: &[usize]) -> stream::Iter<std::vec::IntoIter<SourceItem>> {
        let chunks: Vec<SourceItem> = sizes
            .iter()
            .map(|&size| Ok(Bytes::from(vec![0u8; size])))
            .collect();
        stream::iter(chunks)
    }

    /// Lengths of the successfully yielded chunks, in yield order.
    fn ok_sizes(results: &[std::result::Result<Bytes, Error>]) -> Vec<usize> {
        results
            .iter()
            .filter_map(|r| r.as_ref().ok())
            .map(Bytes::len)
            .collect()
    }

    #[tokio::test]
    async fn test_throttled_stream_basic() {
        let chunks = vec![
            Ok::<_, Box<dyn std::error::Error + Send>>(Bytes::from("hi")),
            Ok(Bytes::from("hi")),
            Ok(Bytes::from("hi")),
        ];
        let stream = stream::iter(chunks);
        let throttled = ThrottledStream::new(stream, 100);
        let results: Vec<_> = throttled.collect().await;
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_throttled_stream_rate_limit() {
        let chunks = vec![
            Ok::<_, Box<dyn std::error::Error + Send>>(Bytes::from(vec![0u8; 50])),
            Ok(Bytes::from(vec![0u8; 50])),
        ];
        let stream = stream::iter(chunks);
        let throttled = ThrottledStream::new(stream, 100);
        let results: Vec<_> = throttled.collect().await;
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_throttled_stream_maps_inner_error_to_network() {
        let chunks: Vec<SourceItem> = vec![
            Ok(Bytes::from("ok")),
            Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "boom"))
                as Box<dyn std::error::Error + Send>),
        ];
        let throttled = ThrottledStream::new(stream::iter(chunks), 100);
        let results: Vec<_> = throttled.collect().await;
        assert_eq!(results.len(), 2);
        assert!(results[0].is_ok());
        assert!(matches!(&results[1], Err(Error::Network(msg)) if msg.contains("boom")));
    }

    #[tokio::test]
    async fn test_async_throttled_stream_yields_all_chunks_in_order() {
        let throttled = AsyncThrottledStream::new(zero_chunks(&[1, 2, 3]), 100);
        let results: Vec<_> = throttled.collect().await;
        assert_eq!(results.len(), 3);
        assert_eq!(ok_sizes(&results), vec![1, 2, 3]);
    }

    // The total exceeds one second's worth of tokens, so at least one chunk has to
    // wait for a refill. NOTE(review): assumes TokenBucket refills continuously at
    // `refill_rate`, so this takes roughly 0.1-1.1 s; confirm.
    #[tokio::test]
    async fn test_async_throttled_stream_waits_for_tokens_then_yields() {
        let throttled = AsyncThrottledStream::new(zero_chunks(&[10_000, 1_000]), 10_000);
        let results: Vec<_> = throttled.collect().await;
        assert_eq!(results.len(), 2);
        assert_eq!(ok_sizes(&results), vec![10_000, 1_000]);
    }
}