actix_web_lab/
load_shed.rs

1// Code mostly copied from `tower`:
2// https://github.com/tower-rs/tower/tree/5064987f/tower/src/load_shed
3
4//! Load-shedding middleware.
5
6use std::{
7    cell::Cell,
8    error::Error as StdError,
9    fmt,
10    pin::Pin,
11    task::{Context, Poll, ready},
12};
13
14use actix_service::{Service, Transform};
15use actix_utils::future::{Ready, ok};
16use actix_web::ResponseError;
17use pin_project_lite::pin_project;
18
/// Load-shedding middleware.
///
/// Requests that arrive while the wrapped service has not reported readiness are failed
/// immediately with [`Overloaded::Overloaded`] rather than being handed to the service.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct LoadShed;
23
24impl LoadShed {
25    /// Creates a new load-shedding middleware.
26    pub fn new() -> Self {
27        LoadShed
28    }
29}
30
31impl<S: Service<Req>, Req> Transform<S, Req> for LoadShed {
32    type Response = S::Response;
33    type Error = Overloaded<S::Error>;
34    type Transform = LoadShedService<S>;
35    type InitError = ();
36    type Future = Ready<Result<Self::Transform, Self::InitError>>;
37
38    fn new_transform(&self, service: S) -> Self::Future {
39        ok(LoadShedService::new(service))
40    }
41}
42
/// Service wrapper produced by [`LoadShed`].
///
/// It remembers the outcome of the most recent readiness poll of the inner service and uses it
/// to decide whether the next call is forwarded or rejected.
#[derive(Debug)]
pub struct LoadShedService<S> {
    // The wrapped service that requests are forwarded to.
    inner: S,
    // Whether the last `poll_ready` saw the inner service ready; consumed by `call`.
    is_ready: Cell<bool>,
}
49
50impl<S> LoadShedService<S> {
51    /// Wraps a service in [`LoadShedService`] middleware.
52    pub(crate) fn new(inner: S) -> Self {
53        Self {
54            inner,
55            is_ready: Cell::new(false),
56        }
57    }
58}
59
60impl<S, Req> Service<Req> for LoadShedService<S>
61where
62    S: Service<Req>,
63{
64    type Response = S::Response;
65    type Error = Overloaded<S::Error>;
66    type Future = LoadShedFuture<S::Future>;
67
68    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
69        // We check for readiness here, so that we can know in `call` if
70        // the inner service is overloaded or not.
71        let is_ready = match self.inner.poll_ready(cx) {
72            Poll::Ready(Err(err)) => return Poll::Ready(Err(Overloaded::Service(err))),
73            res => res.is_ready(),
74        };
75
76        self.is_ready.set(is_ready);
77
78        // But we always report Ready, so that layers above don't wait until
79        // the inner service is ready (the entire point of this layer!)
80        Poll::Ready(Ok(()))
81    }
82
83    fn call(&self, req: Req) -> Self::Future {
84        if self.is_ready.get() {
85            // readiness only counts once, you need to check again!
86            self.is_ready.set(false);
87            LoadShedFuture::called(self.inner.call(req))
88        } else {
89            LoadShedFuture::overloaded()
90        }
91    }
92}
93
94pin_project! {
95    /// Future for [`LoadShedService`].
96    pub struct LoadShedFuture<F> {
97        #[pin]
98        state: LoadShedFutureState<F>,
99    }
100}
101
102pin_project! {
103    #[project = LoadShedFutureStateProj]
104    enum LoadShedFutureState<F> {
105        Called { #[pin] fut: F },
106        Overloaded,
107    }
108}
109
110impl<F> LoadShedFuture<F> {
111    pub(crate) fn called(fut: F) -> Self {
112        LoadShedFuture {
113            state: LoadShedFutureState::Called { fut },
114        }
115    }
116
117    pub(crate) fn overloaded() -> Self {
118        LoadShedFuture {
119            state: LoadShedFutureState::Overloaded,
120        }
121    }
122}
123
124impl<F, T, E> Future for LoadShedFuture<F>
125where
126    F: Future<Output = Result<T, E>>,
127{
128    type Output = Result<T, Overloaded<E>>;
129
130    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
131        match self.project().state.project() {
132            LoadShedFutureStateProj::Called { fut } => {
133                Poll::Ready(ready!(fut.poll(cx)).map_err(Overloaded::Service))
134            }
135            LoadShedFutureStateProj::Overloaded => Poll::Ready(Err(Overloaded::Overloaded)),
136        }
137    }
138}
139
140impl<F> fmt::Debug for LoadShedFuture<F>
141where
142    // bounds for future-proofing...
143    F: fmt::Debug,
144{
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        f.write_str("LoadShedFuture")
147    }
148}
149
/// Error type of the [`LoadShed`] middleware.
///
/// It is either an error raised by the wrapped service, or a notice that the wrapped service
/// was not ready to handle the request at the moment it was called.
#[non_exhaustive]
#[derive(Debug)]
pub enum Overloaded<E> {
    /// An error produced by the wrapped service.
    Service(E),

    /// The wrapped service was not ready, so the request was rejected.
    Overloaded,
}
161
162impl<E: fmt::Display> fmt::Display for Overloaded<E> {
163    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164        match self {
165            Overloaded::Service(err) => write!(f, "{err}"),
166            Overloaded::Overloaded => f.write_str("service overloaded"),
167        }
168    }
169}
170
171impl<E: StdError + 'static> StdError for Overloaded<E> {
172    fn source(&self) -> Option<&(dyn StdError + 'static)> {
173        match self {
174            Overloaded::Service(err) => Some(err),
175            Overloaded::Overloaded => None,
176        }
177    }
178}
179
180impl<E> ResponseError for Overloaded<E>
181where
182    E: fmt::Debug + fmt::Display,
183{
184    fn status_code(&self) -> actix_http::StatusCode {
185        actix_web::http::StatusCode::SERVICE_UNAVAILABLE
186    }
187}
188
#[cfg(test)]
mod tests {
    use actix_web::{
        App,
        middleware::{Compat, Logger},
    };

    use super::*;

    /// Type-level smoke test: `LoadShed` wrapped in `Compat` can be registered on an `App`
    /// together with another middleware. Nothing is executed beyond building the app.
    #[test]
    fn integration() {
        let load_shed = Compat::new(LoadShed::new());
        let _app = App::new().wrap(load_shed).wrap(Logger::default());
    }
}