#![warn(missing_debug_implementations)]
#![allow(clippy::significant_drop_tightening)]
#![allow(clippy::significant_drop_in_scrutinee)]
#![forbid(unsafe_code)]
use std::{
cmp::Ordering,
future::Future,
pin::Pin,
sync::{Arc, Mutex, atomic::AtomicU64},
task::{Context, Poll},
time::{Duration, Instant},
};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tower::{Layer, Service, ServiceExt};
/// Shared configuration and live state of the load shedder.
///
/// `LoadShedLayer` wraps one instance in an `Arc` and hands a clone of that
/// `Arc` to every `LoadShed` service it creates.
#[derive(Debug)]
pub struct LoadShedConf {
/// Requests whose sequence number is below this value are forwarded to the
/// inner service even when `start` refuses them a permit.
passthrough_count: u64,
/// Target latency, in seconds.
target: f64,
/// Weight given to the newest sample in the latency moving averages.
ewma_param: f64,
/// Permits for requests that may wait for a concurrency permit.
available_queue: Arc<Semaphore>,
/// Permits for requests that may run on the inner service.
available_concurrency: Arc<Semaphore>,
/// Latency averages and the current queue/concurrency limits.
stats: Mutex<ConfStats>,
/// Number of requests `call` has seen so far.
requests: AtomicU64,
}
/// Mutable controller state; always accessed through the `stats` mutex of
/// `LoadShedConf`.
#[derive(Debug)]
struct ConfStats {
    /// Moving average of request latency in seconds, updated by every `stop`.
    average_latency: f64,
    /// Moving average of request latency in seconds, updated only while
    /// (almost) every concurrency permit is in use.
    average_latency_at_capacity: f64,
    /// Total number of permits the queue semaphore is meant to hold.
    queue_capacity: usize,
    /// Current concurrency limit.
    concurrency: usize,
    /// In-flight request count recorded at the previous limit adjustment.
    previous_concurrency: usize,
    /// When the concurrency limit was last adjusted.
    last_changed: Instant,
    /// Throughput estimate recorded at the previous limit adjustment.
    previous_throughput: f64,
    /// Permits the concurrency semaphore still holds above `concurrency`.
    ///
    /// `Semaphore::forget_permits` can only remove idle permits, so a
    /// reduction requested while every permit is in use is recorded here and
    /// settled later by `start`.
    pending_forget: usize,
}

impl LoadShedConf {
    /// Creates a configuration with a concurrency limit of 1 and a queue
    /// capacity of 1.
    ///
    /// * `ewma_param` - weight of the newest sample in the latency averages.
    /// * `target` - target latency in seconds; also the initial value of both
    ///   latency averages.
    /// * `passthrough_count` - stored as-is; read by `LoadShed::call`.
    pub fn new(ewma_param: f64, target: f64, passthrough_count: u64) -> Self {
        Self {
            passthrough_count,
            target,
            ewma_param,
            available_concurrency: Arc::new(Semaphore::new(1)),
            available_queue: Arc::new(Semaphore::new(1)),
            stats: Mutex::new(ConfStats {
                average_latency: target,
                average_latency_at_capacity: target,
                queue_capacity: 1,
                concurrency: 1,
                previous_concurrency: 0,
                last_changed: Instant::now(),
                previous_throughput: 0.0,
                pending_forget: 0,
            }),
            requests: AtomicU64::new(0),
        }
    }

    /// Resizes the queue, joins it, and waits for a concurrency permit.
    ///
    /// Returns `None` when the queue cannot be shrunk to its new size right
    /// now or when it has no free slot; otherwise returns the concurrency
    /// permit the caller must hold while the request runs.
    ///
    /// # Panics
    ///
    /// Panics if the stats mutex is poisoned or either semaphore is closed.
    async fn start(&self) -> Option<OwnedSemaphorePermit> {
        {
            // Scoped so the (non-`Send`) guard is released before any `.await`.
            let mut stats = self.stats.lock().unwrap();
            // Never below 1. The float-to-integer cast saturates, so a
            // negative or NaN product becomes 0 and is lifted to 1 by `max`.
            let desired_queue_capacity = usize::max(
                1,
                (stats.concurrency as f64
                    * ((self.target / stats.average_latency_at_capacity) - 1.0))
                    .floor() as usize,
            );
            match desired_queue_capacity.cmp(&stats.queue_capacity) {
                Ordering::Less => {
                    // Shrink by taking the surplus permits out of circulation.
                    // If they are in use, reject this request and leave
                    // `queue_capacity` untouched so a later call retries.
                    // NOTE(review): the `as u32` cast assumes the difference
                    // fits in 32 bits - confirm capacities stay that small.
                    match self
                        .available_queue
                        .try_acquire_many((stats.queue_capacity - desired_queue_capacity) as u32)
                    {
                        Ok(permits) => permits.forget(),
                        Err(TryAcquireError::NoPermits) => return None,
                        Err(TryAcquireError::Closed) => panic!("queue semaphore closed?"),
                    }
                }
                Ordering::Equal => {}
                Ordering::Greater => self
                    .available_queue
                    .add_permits(desired_queue_capacity - stats.queue_capacity),
            }
            stats.queue_capacity = desired_queue_capacity;
        }
        // Held until this function returns, i.e. for as long as the request
        // waits for a concurrency permit.
        let _queue_permit = match self.available_queue.clone().try_acquire_owned() {
            Ok(queue_permit) => queue_permit,
            Err(TryAcquireError::NoPermits) => return None,
            Err(TryAcquireError::Closed) => panic!("queue semaphore closed?"),
        };
        loop {
            let concurrency_permit = self
                .available_concurrency
                .clone()
                .acquire_owned()
                .await
                .unwrap();
            {
                let mut stats = self.stats.lock().unwrap();
                if stats.pending_forget == 0 {
                    return Some(concurrency_permit);
                }
                // `stop` lowered the limit while no permit was idle. Destroy
                // this permit to settle one unit of that debt, then wait for
                // another one.
                stats.pending_forget -= 1;
            }
            concurrency_permit.forget();
        }
    }

    /// Records the latency of a finished request and, when due, moves the
    /// concurrency limit up or down by one.
    ///
    /// `elapsed` is the time the request spent in the inner service.
    ///
    /// # Panics
    ///
    /// Panics if the stats mutex is poisoned.
    fn stop(&self, elapsed: Duration) {
        let elapsed = elapsed.as_secs_f64();
        let mut stats = self.stats.lock().expect("To be able to lock stats");
        let available_permits = self.available_concurrency.available_permits();
        // "At max" tolerates a few idle permits: one, or 10% of the limit.
        let at_max_concurrency = available_permits <= usize::max(1, stats.concurrency / 10);
        stats.average_latency = stats
            .average_latency
            .mul_add(1.0 - self.ewma_param, self.ewma_param * elapsed);
        if at_max_concurrency {
            stats.average_latency_at_capacity = stats
                .average_latency_at_capacity
                .mul_add(1.0 - self.ewma_param, self.ewma_param * elapsed);
        }
        // Adjust only while at max concurrency and only after
        // `average_latency / ewma_param / 10` seconds since the last change.
        if stats.last_changed.elapsed().as_secs_f64()
            > (stats.average_latency / self.ewma_param) / 10.0
            && at_max_concurrency
        {
            // Cannot underflow: `at_max_concurrency` bounds `available_permits`
            // by `max(1, concurrency / 10)`, and `concurrency` is never below 1.
            let current_concurrency = stats.concurrency - available_permits;
            let throughput = current_concurrency as f64 / stats.average_latency;
            // True when throughput and in-flight count moved in opposite
            // directions since the previous adjustment.
            let negative_gradient = (throughput > stats.previous_throughput)
                ^ (current_concurrency > stats.previous_concurrency);
            if negative_gradient || (stats.average_latency > self.target) {
                // Never go below a limit of 1.
                if stats.concurrency > 1 {
                    // `forget_permits` removes idle permits only and reports
                    // how many it removed; remember any shortfall so `start`
                    // can remove it once a permit is handed out.
                    if self.available_concurrency.forget_permits(1) == 0 {
                        stats.pending_forget += 1;
                    }
                    stats.concurrency -= 1;
                    // Scale the averages by new_limit / old_limit.
                    let latency_factor =
                        stats.concurrency as f64 / (stats.concurrency as f64 + 1.0);
                    stats.average_latency *= latency_factor;
                    stats.average_latency_at_capacity *= latency_factor;
                }
            } else {
                if stats.pending_forget > 0 {
                    // The semaphore still holds a permit above the limit;
                    // keeping it is the same as adding one.
                    stats.pending_forget -= 1;
                } else {
                    self.available_concurrency.add_permits(1);
                }
                stats.concurrency += 1;
                // Scale the averages by new_limit / old_limit.
                let latency_factor = stats.concurrency as f64 / (stats.concurrency as f64 - 1.0);
                stats.average_latency *= latency_factor;
                stats.average_latency_at_capacity *= latency_factor;
            }
            stats.previous_throughput = throughput;
            stats.previous_concurrency = current_concurrency;
            stats.last_changed = Instant::now()
        }
    }
}
/// Service wrapper that consults a shared `LoadShedConf` before forwarding
/// each request to `inner`.
#[derive(Debug, Clone)]
pub struct LoadShed<Inner> {
/// Limits and statistics shared with every other clone of this service.
conf: Arc<LoadShedConf>,
/// The wrapped service.
inner: Inner,
}
impl<Inner> LoadShed<Inner> {
    /// Wraps `inner` so that it is governed by the shared `conf`.
    pub const fn new(inner: Inner, conf: Arc<LoadShedConf>) -> Self {
        Self { inner, conf }
    }

    /// Current moving average of request latency.
    ///
    /// # Panics
    ///
    /// Panics if the stats mutex is poisoned, or if the stored average is
    /// negative or not finite (`Duration::from_secs_f64` rejects those).
    pub fn average_latency(&self) -> Duration {
        Duration::from_secs_f64(self.conf.stats.lock().unwrap().average_latency)
    }

    /// Current concurrency limit.
    ///
    /// # Panics
    ///
    /// Panics if the stats mutex is poisoned.
    pub fn concurrency(&self) -> usize {
        self.conf.stats.lock().unwrap().concurrency
    }

    /// Concurrency limit plus queue capacity.
    ///
    /// # Panics
    ///
    /// Panics if the stats mutex is poisoned.
    pub fn queue_capacity(&self) -> usize {
        let stats = self.conf.stats.lock().unwrap();
        stats.concurrency + stats.queue_capacity
    }

    /// Number of permits currently in use: concurrency permits plus queue
    /// permits.
    ///
    /// # Panics
    ///
    /// Panics if the stats mutex is poisoned.
    pub fn queue_len(&self) -> usize {
        let stats = self.conf.stats.lock().unwrap();
        // A semaphore can report more available permits than the recorded
        // limit: `stop` lowers the limit even when no idle permit could be
        // removed from the semaphore. Saturate rather than underflow.
        let running = stats
            .concurrency
            .saturating_sub(self.conf.available_concurrency.available_permits());
        let waiting = stats
            .queue_capacity
            .saturating_sub(self.conf.available_queue.available_permits());
        running + waiting
    }
}
/// Response type produced by a `LoadShed` service.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LoadShedResponse<T> {
/// The request was forwarded; this is the inner service's response.
Inner(T),
/// The request was rejected without reaching the inner service.
Overload,
}
/// Heap-allocated, pinned, `Send` future yielding `Output`.
type BoxFuture<Output> = Pin<Box<dyn Future<Output = Output> + Send>>;
impl<Request, Inner> Service<Request> for LoadShed<Inner>
where
Request: Send + 'static,
Inner: Service<Request> + Clone + Send + 'static,
Inner::Future: Send,
{
type Response = LoadShedResponse<Inner::Response>;
type Error = Inner::Error;
type Future = BoxFuture<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request) -> Self::Future {
let inner = self.inner.clone();
let conf = self.conf.clone();
let requests = conf
.requests
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Box::pin(async move {
let permit = conf.start().await;
if permit.is_none() && requests >= conf.passthrough_count {
return Ok(LoadShedResponse::Overload);
}
let start = Instant::now();
let response = inner.oneshot(req).await;
conf.stop(start.elapsed());
Ok(LoadShedResponse::Inner(response?))
})
}
}
/// Layer that wraps services in `LoadShed`, all sharing one `LoadShedConf`.
#[derive(Debug, Clone)]
pub struct LoadShedLayer(Arc<LoadShedConf>);
impl LoadShedLayer {
    /// Builds a layer around a fresh shared configuration.
    ///
    /// * `ewma_param` - weight of the newest sample in the latency averages.
    /// * `target` - target latency; converted to seconds for the
    ///   configuration.
    /// * `passthrough_count` - forwarded unchanged to `LoadShedConf::new`.
    pub fn new(ewma_param: f64, target: Duration, passthrough_count: u64) -> Self {
        let target_secs = target.as_secs_f64();
        Self(Arc::new(LoadShedConf::new(
            ewma_param,
            target_secs,
            passthrough_count,
        )))
    }
}
impl<Inner> Layer<Inner> for LoadShedLayer {
type Service = LoadShed<Inner>;
fn layer(&self, inner: Inner) -> Self::Service {
LoadShed::new(inner, self.0.clone())
}
}
#[cfg(test)]
mod test {
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use tokio_util::task::TaskTracker;
use super::*;
use crate::Error;
/// Inner service for the tests: sleeps for the requested duration, then
/// succeeds.
#[derive(Debug, Clone)]
struct StubService;
impl Service<Duration> for StubService {
type Response = ();
type Error = Error;
type Future = BoxFuture<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Duration) -> Self::Future {
let fut = async move {
tokio::time::sleep(req).await;
Ok(())
};
Box::pin(fut)
}
}
/// Exercises passthrough, sequential traffic, and shedding of a concurrent
/// burst against a 1 ms target with 10 ms requests.
#[tokio::test]
async fn test_little_loadshedder() {
// ewma_param 0.9, target 1 ms, passthrough for the first 100 requests.
let layer = LoadShedLayer::new(0.9, Duration::from_millis(1), 100);
let inner = StubService;
let mut shedder = layer.layer(inner);
// Phase 1: 100 concurrent requests. Their sequence numbers are all below
// the passthrough count, so none may be answered with `Overload`.
let shedded = Arc::new(AtomicUsize::new(0));
let tracker = TaskTracker::new();
for _ in 0..100 {
let shedder = shedder.clone();
let shedded = shedded.clone();
tracker.spawn(async move {
let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
if matches!(resp, LoadShedResponse::Overload) {
shedded.fetch_add(1, Ordering::SeqCst);
}
});
}
tracker.close();
tracker.wait().await;
assert_eq!(shedded.load(Ordering::SeqCst), 0);
// Phase 2: passthrough is used up; requests sent one at a time must all
// be admitted.
for _ in 0..10 {
let resp = shedder.call(Duration::from_millis(10)).await.unwrap();
assert_eq!(resp, LoadShedResponse::Inner(()));
}
// Phase 3: a burst of 10 concurrent requests; 8 are expected to be shed.
// NOTE(review): presumably one request runs and one waits in the queue -
// the exact count depends on scheduling and on the limits reached above.
let shedded = Arc::new(AtomicUsize::new(0));
let tracker = TaskTracker::new();
for _ in 0..10 {
let shedder = shedder.clone();
let shedded = shedded.clone();
tracker.spawn(async move {
let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
if matches!(resp, LoadShedResponse::Overload) {
shedded.fetch_add(1, Ordering::SeqCst);
}
});
}
tracker.close();
tracker.wait().await;
assert_eq!(shedded.load(Ordering::SeqCst), 8);
}
}