use tower_layer::Layer;
use crate::Algorithm;
use crate::classifier::DefaultClassifier;
use crate::service::ConcurrencyLimit;
/// A tower [`Layer`] that wraps an inner service in a [`ConcurrencyLimit`].
///
/// The layer stores a concurrency `algorithm` (`A`) and a `classifier` (`C`, which
/// defaults to [`DefaultClassifier`]). Each call to [`Layer::layer`] clones both values
/// into the service it produces, so a single layer can wrap any number of services.
///
/// NOTE(review): whether those per-service clones share limiter state depends entirely
/// on `A`'s `Clone` implementation — confirm before relying on a limit that is shared
/// across services.
#[derive(Debug, Clone)]
pub struct ConcurrencyLimitLayer<A, C = DefaultClassifier> {
    /// Algorithm handed (by clone) to every wrapped service.
    algorithm: A,
    /// Classifier handed (by clone) to every wrapped service; `DefaultClassifier`
    /// unless the layer was built with `with_classifier`.
    classifier: C,
}
impl<A> ConcurrencyLimitLayer<A> {
    /// Builds a layer around `algorithm`, paired with the [`DefaultClassifier`].
    ///
    /// To supply a different classifier, use [`ConcurrencyLimitLayer::with_classifier`].
    pub fn new(algorithm: A) -> Self {
        let classifier = DefaultClassifier;
        Self {
            classifier,
            algorithm,
        }
    }
}
impl<A, C> ConcurrencyLimitLayer<A, C> {
pub fn with_classifier(algorithm: A, classifier: C) -> Self {
Self {
algorithm,
classifier,
}
}
}
impl<S, A, C> Layer<S> for ConcurrencyLimitLayer<A, C>
where
    A: Algorithm + Clone,
    C: Clone,
{
    type Service = ConcurrencyLimit<S, A, C>;

    /// Wraps `service` in a [`ConcurrencyLimit`].
    ///
    /// The limiter gets its own clones of this layer's algorithm and classifier. The
    /// layer itself is only borrowed and can be reused.
    fn layer(&self, service: S) -> Self::Service {
        let algorithm = self.algorithm.clone();
        let classifier = self.classifier.clone();
        ConcurrencyLimit::with_classifier(service, algorithm, classifier)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::future::{Ready, ready};
    use std::task::{Context, Poll};
    use std::time::Duration;
    use tower_service::Service;

    /// Algorithm stub reporting a constant limit; `update` ignores all feedback.
    #[derive(Clone, Debug)]
    struct FixedAlgorithm(usize);

    impl Algorithm for FixedAlgorithm {
        fn max_concurrency(&self) -> usize {
            self.0
        }

        fn update(
            &mut self,
            _rtt: Duration,
            _num_inflight: usize,
            _is_error: bool,
            _is_canceled: bool,
        ) {
        }
    }

    /// Service stub that is always ready and immediately returns its request.
    #[derive(Clone, Debug)]
    struct EchoService;

    impl Service<&'static str> for EchoService {
        type Response = &'static str;
        type Error = Infallible;
        type Future = Ready<Result<&'static str, Infallible>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: &'static str) -> Self::Future {
            ready(Ok(req))
        }
    }

    #[test]
    fn layer_produces_concurrency_limit_service() {
        let layer = ConcurrencyLimitLayer::new(FixedAlgorithm(10));
        // The explicit annotation pins the exact service type `Layer::layer` produces,
        // including the default classifier chosen by `new`.
        let svc: ConcurrencyLimit<EchoService, FixedAlgorithm, DefaultClassifier> =
            layer.layer(EchoService);
        let inner: &EchoService = svc.get_ref();
        assert!(format!("{:?}", inner).contains("EchoService"));
    }

    #[tokio::test]
    async fn layered_service_forwards_requests() {
        let layer = ConcurrencyLimitLayer::new(FixedAlgorithm(10));
        let mut svc = layer.layer(EchoService);
        std::future::poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
        let resp = svc.call("hello").await.unwrap();
        assert_eq!(resp, "hello");
    }

    #[test]
    fn layer_is_clone() {
        let layer = ConcurrencyLimitLayer::new(FixedAlgorithm(5));
        let layer2 = layer.clone();
        // The clone must carry the same configuration as the original, not just compile.
        assert_eq!(format!("{:?}", layer), format!("{:?}", layer2));
        let _ = layer.layer(EchoService);
        let _ = layer2.layer(EchoService);
    }

    #[test]
    fn layer_is_debug() {
        let layer = ConcurrencyLimitLayer::new(FixedAlgorithm(5));
        let debug = format!("{:?}", layer);
        assert!(debug.contains("ConcurrencyLimitLayer"));
        // The derived impl should also expose the configured algorithm.
        assert!(debug.contains("FixedAlgorithm(5)"));
    }

    #[tokio::test]
    async fn layer_with_custom_classifier() {
        let classifier = |_result: &Result<&str, Infallible>| false;
        let layer = ConcurrencyLimitLayer::with_classifier(FixedAlgorithm(10), classifier);
        let mut svc = layer.layer(EchoService);
        std::future::poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
        let resp = svc.call("hello").await.unwrap();
        assert_eq!(resp, "hello");
    }
}