use std::{
collections::BTreeMap,
fmt::Debug,
sync::Arc,
task::{Context, Poll},
};
use ic_bn_lib_common::{
traits::shed::TypeExtractor,
types::shed::{ShardedOptions, ShedReason, ShedResponse},
};
use tower::{Layer, Service, ServiceExt};
use super::{
BoxFuture,
little::{LoadShedLayer, LoadShedResponse},
};
/// Service wrapper that applies load shedding per request type.
///
/// Each request is classified by `extractor`; if a [`LoadShedLayer`] is
/// registered for the resulting type, the request is run through that layer,
/// otherwise it is forwarded to `inner` untouched.
#[derive(Debug, Clone)]
pub struct ShardedLittleLoadShedder<T, I>
where
    T: TypeExtractor,
{
    /// Classifies a request into the key used to select a shard.
    extractor: T,
    /// The wrapped service that actually handles requests.
    inner: I,
    /// Load-shedding layers keyed by request type; the map sits behind an
    /// `Arc`, so clones of this service share it.
    shards: Arc<BTreeMap<T::Type, LoadShedLayer>>,
}
impl<T: TypeExtractor, I: Send + Sync + Clone> ShardedLittleLoadShedder<T, I> {
    /// Builds a sharded load shedder around `inner`.
    ///
    /// * `inner` - the service to wrap.
    /// * `extractor` - maps a request to its type (the shard key).
    /// * `shards` - the per-type load-shedding layers.
    pub const fn new(
        inner: I,
        extractor: T,
        shards: Arc<BTreeMap<T::Type, LoadShedLayer>>,
    ) -> Self {
        Self {
            inner,
            extractor,
            shards,
        }
    }

    /// Looks up the load-shedding layer responsible for `req`.
    ///
    /// Returns `None` when the extractor yields no type for the request, or
    /// when no shard is registered for the extracted type.
    fn get_shard(&self, req: &T::Request) -> Option<LoadShedLayer> {
        self.extractor
            .extract(req)
            .and_then(|kind| self.shards.get(&kind).cloned())
    }
}
impl<T: TypeExtractor, I> Service<T::Request> for ShardedLittleLoadShedder<T, I>
where
    I: Service<T::Request> + Clone + Send + Sync + 'static,
    I::Future: Send,
{
    type Response = ShedResponse<I::Response>;
    type Error = I::Error;
    type Future = BoxFuture<Result<Self::Response, Self::Error>>;

    /// Always reports readiness; `call` runs a clone of the inner service
    /// through `oneshot` instead of relying on this method.
    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Dispatches `req` either through its shard's load shedder or, when no
    /// shard matches, directly to the inner service.
    ///
    /// A shed request resolves to `ShedResponse::Overload(ShedReason::Latency)`;
    /// anything else is wrapped in `ShedResponse::Inner`.
    fn call(&mut self, req: T::Request) -> Self::Future {
        match self.get_shard(&req) {
            // A shard exists for this request type: wrap a fresh clone of the
            // inner service with that shard's layer and translate its verdict.
            Some(shard) => {
                let guarded = shard.layer(self.inner.clone());
                Box::pin(async move {
                    guarded.oneshot(req).await.map(|outcome| match outcome {
                        LoadShedResponse::Inner(resp) => ShedResponse::Inner(resp),
                        LoadShedResponse::Overload => ShedResponse::Overload(ShedReason::Latency),
                    })
                })
            }
            // No shard configured: pass the request straight through.
            None => {
                let passthrough = self.inner.clone();
                Box::pin(async move { passthrough.oneshot(req).await.map(ShedResponse::Inner) })
            }
        }
    }
}
/// Layer that produces [`ShardedLittleLoadShedder`] services.
///
/// Field `0` holds the options the layer was built from (including the
/// extractor); field `1` holds the per-type load-shedding layers behind an
/// `Arc`, which every service created by this layer receives a handle to.
#[derive(Debug, Clone)]
pub struct ShardedLittleLoadShedderLayer<T>(
    ShardedOptions<T>,
    Arc<BTreeMap<T::Type, LoadShedLayer>>,
)
where
    T: TypeExtractor;
impl<T: TypeExtractor> ShardedLittleLoadShedderLayer<T> {
pub fn new(opts: ShardedOptions<T>) -> Self {
let shards = Arc::new(BTreeMap::from_iter(opts.latencies.iter().map(|x| {
(
x.0.clone(),
LoadShedLayer::new(opts.ewma_alpha, x.1, opts.passthrough_count),
)
})));
Self(opts, shards)
}
}
impl<T: TypeExtractor, I: Send + Sync + Clone> Layer<I> for ShardedLittleLoadShedderLayer<T> {
type Service = ShardedLittleLoadShedder<T, I>;
fn layer(&self, inner: I) -> Self::Service {
ShardedLittleLoadShedder::new(inner, self.0.extractor.clone(), self.1.clone())
}
}
// Tests for the sharded load shedder, driven by a stub service whose latency
// is dictated by the request itself.
#[cfg(test)]
mod test {
use std::{
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
time::Duration,
};
use ic_bn_lib_common::types::shed::TypeLatency;
use tokio_util::task::TaskTracker;
use super::*;
use crate::Error;
/// Inner service used by the test: the request is a `Duration`, and handling
/// it means sleeping for that long before returning `Ok(())`.
#[derive(Debug, Clone)]
struct StubService;
impl Service<Duration> for StubService {
type Response = ();
type Error = Error;
type Future = BoxFuture<Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Duration) -> Self::Future {
let fut = async move {
// The request value is the simulated processing latency.
tokio::time::sleep(req).await;
Ok(())
};
Box::pin(fut)
}
}
/// Extractor that ignores the request and always reports the fixed type it
/// was constructed with.
#[derive(Debug, Clone)]
struct StubExtractor(u8);
impl TypeExtractor for StubExtractor {
type Type = u8;
type Request = Duration;
fn extract(&self, _req: &Self::Request) -> Option<Self::Type> {
Some(self.0)
}
}
#[tokio::test]
async fn test_sharded_shedder() {
// Every request is classified as type 0, and type 0 has a shard configured
// with a 1ms latency value.
let opts = ShardedOptions {
extractor: StubExtractor(0),
passthrough_count: 100,
ewma_alpha: 0.9,
latencies: vec![TypeLatency(0, Duration::from_millis(1))],
};
let inner = StubService;
let layer = ShardedLittleLoadShedderLayer::new(opts);
let mut shedder = layer.layer(inner);
// Phase 1: 100 concurrent requests that each take 10ms. None is expected to
// be shed.
// NOTE(review): presumably because the first `passthrough_count` (100)
// requests are always let through -- confirm in `LoadShedLayer`.
let shedded = Arc::new(AtomicUsize::new(0));
let tracker = TaskTracker::new();
for _ in 0..100 {
let shedder = shedder.clone();
let shedded = shedded.clone();
tracker.spawn(async move {
let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
shedded.fetch_add(1, Ordering::SeqCst);
}
});
}
tracker.close();
tracker.wait().await;
assert_eq!(shedded.load(Ordering::SeqCst), 0);
// Phase 2: 10 sequential 10ms requests (one in flight at a time); each one
// must reach the inner service.
for _ in 0..10 {
let resp = shedder.call(Duration::from_millis(10)).await.unwrap();
assert_eq!(resp, ShedResponse::Inner(()));
}
// Phase 3: 10 concurrent 10ms requests; exactly 8 are expected to be shed.
// NOTE(review): the exact count depends on the shedder's internal latency
// estimate and admission rule, which are not visible here -- confirm in
// `LoadShedLayer`.
let shedded = Arc::new(AtomicUsize::new(0));
let tracker = TaskTracker::new();
for _ in 0..10 {
let shedder = shedder.clone();
let shedded = shedded.clone();
tracker.spawn(async move {
let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
shedded.fetch_add(1, Ordering::SeqCst);
}
});
}
tracker.close();
tracker.wait().await;
assert_eq!(shedded.load(Ordering::SeqCst), 8);
// Phase 4: 10 fast (1ms) requests, with a 2-permit semaphore capping how many
// are in flight at once; none is expected to be shed.
let shedded = Arc::new(AtomicUsize::new(0));
let tracker = TaskTracker::new();
let sem = Arc::new(tokio::sync::Semaphore::new(2));
for _ in 0..10 {
let shedder = shedder.clone();
let shedded = shedded.clone();
let sem = sem.clone();
tracker.spawn(async move {
// Hold the permit for the whole request so at most two run concurrently.
let _permit = sem.acquire().await.unwrap();
let resp = shedder.oneshot(Duration::from_millis(1)).await.unwrap();
if matches!(resp, ShedResponse::Overload(ShedReason::Latency)) {
shedded.fetch_add(1, Ordering::SeqCst);
}
});
}
tracker.close();
tracker.wait().await;
assert_eq!(shedded.load(Ordering::SeqCst), 0);
// After the fast requests, a single slow request is admitted again.
let resp = shedder.oneshot(Duration::from_millis(10)).await.unwrap();
assert_eq!(resp, ShedResponse::Inner(()));
// Final case: the extractor reports type 1, but only type 0 has a shard, so
// the request bypasses load shedding and goes straight to the inner service.
let opts = ShardedOptions {
extractor: StubExtractor(1),
ewma_alpha: 0.9,
passthrough_count: 0,
latencies: vec![TypeLatency(0, Duration::from_millis(1))],
};
let inner = StubService;
let layer = ShardedLittleLoadShedderLayer::new(opts);
let mut shedder = layer.layer(inner);
let resp = shedder.call(Duration::from_millis(50)).await.unwrap();
assert_eq!(resp, ShedResponse::Inner(()));
}
}