little_loadshedder/
lib.rs

//! A load-shedding middleware based on [Little's law].
//!
//! This provides middleware for shedding load to maintain a target average
//! latency; see the documentation on the [`LoadShed`] service for more detail.
//!
//! The `metrics` feature uses the [metrics] crate to provide insight into the
//! current queue sizes and measured latency.
//!
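//! A minimal usage sketch (the `service_fn` inner service below is just a
//! stand-in for a real application):
//!
//! ```no_run
//! use std::{convert::Infallible, time::Duration};
//!
//! use little_loadshedder::{LoadShedLayer, LoadShedResponse};
//! use tower::{service_fn, ServiceBuilder, ServiceExt};
//!
//! async fn example() -> Result<(), Infallible> {
//!     // Target an average latency of 250ms, updating the moving average with
//!     // an EWMA parameter of 0.1.
//!     let service = ServiceBuilder::new()
//!         .layer(LoadShedLayer::new(0.1, Duration::from_millis(250)))
//!         .service(service_fn(|name: String| async move {
//!             Ok::<_, Infallible>(format!("hello {name}"))
//!         }));
//!
//!     match service.oneshot("world".to_owned()).await? {
//!         LoadShedResponse::Inner(reply) => println!("{reply}"),
//!         LoadShedResponse::Overload => println!("request was shed"),
//!     }
//!     Ok(())
//! }
//! ```
//!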
//! [Little's law]: https://en.wikipedia.org/wiki/Little%27s_law
//! [metrics]: https://docs.rs/metrics/latest/metrics

#![warn(missing_debug_implementations, missing_docs, non_ascii_idents)]
#![forbid(unsafe_code)]

use std::{
    cmp::Ordering,
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
    time::{Duration, Instant},
};

#[cfg(feature = "metrics")]
use metrics::{decrement_gauge, gauge, histogram, increment_counter, increment_gauge};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tower::{Layer, Service, ServiceExt};

/// The load-shedding service's current state of the world.
#[derive(Debug, Clone)]
struct LoadShedConf {
    /// The target average latency in seconds.
    target: f64,
    /// The exponentially weighted moving average parameter.
    /// Must be in the range (0, 1); e.g. `0.25` means the newest value accounts
    /// for 25% of the moving average.
    ewma_param: f64,
    /// Semaphore controlling the waiting queue of requests.
    available_queue: Arc<Semaphore>,
    /// Semaphore controlling concurrency to the inner service.
    available_concurrency: Arc<Semaphore>,
    /// Stats about the latency that change with each completed request.
    stats: Arc<Mutex<ConfStats>>,
}

#[derive(Debug)]
struct ConfStats {
    /// The current average latency in seconds.
    average_latency: f64,
    /// The average of the latency measured when
    /// `available_concurrency.available_permits() == 0`.
    average_latency_at_capacity: f64,
    /// The total number of permits in the queue semaphore
    /// (the current capacity of the queue).
    queue_capacity: usize,
    /// The total number of permits in the `available_concurrency` semaphore.
    concurrency: usize,
    /// The in-use concurrency measured when the throughput was last sampled.
    previous_concurrency: usize,
    /// The time that the concurrency was last adjusted, to rate limit changing it.
    last_changed: Instant,
    /// Average throughput measured the last time the concurrency was adjusted.
    previous_throughput: f64,
}

// size of system [req] = target latency [s] * throughput [req/s]
// size of queue [req] = size of system [req] - concurrency [req]
// throughput [req/s] = concurrency [req] / average latency of service [s]
// => (size of queue [req] + concurrency [req]) = target latency [s] * concurrency [req] / latency [s]
// => size of queue [req] = concurrency [req] * (target latency [s] / latency [s] - 1)
//
// Control the concurrency:
// increase concurrency, but not beyond the target latency
//
// Control the queue length:
// queue capacity = concurrency * ((target latency / average latency of service) - 1)
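//
// For example, with a target latency of 0.5 s, a measured service latency of
// 0.1 s and a concurrency of 4, the queue capacity works out to
// 4 * ((0.5 / 0.1) - 1) = 16 requests.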

impl LoadShedConf {
    fn new(ewma_param: f64, target: f64) -> Self {
        #[cfg(feature = "metrics")]
        {
            gauge!("loadshedder.capacity", 1.0, "component" => "service");
            gauge!("loadshedder.capacity", 1.0, "component" => "queue");
            gauge!("loadshedder.size", 0.0, "component" => "service");
            gauge!("loadshedder.size", 0.0, "component" => "queue");
            gauge!("loadshedder.average_latency", target);
        }
        Self {
            target,
            ewma_param,
            available_concurrency: Arc::new(Semaphore::new(1)),
            available_queue: Arc::new(Semaphore::new(1)),
            stats: Arc::new(Mutex::new(ConfStats {
                average_latency: target,
                average_latency_at_capacity: target,
                queue_capacity: 1,
                concurrency: 1,
                previous_concurrency: 0,
                last_changed: Instant::now(),
                previous_throughput: 0.0,
            })),
        }
    }

    /// Add ourselves to the queue and wait until we've made it through and have
    /// obtained a permit to send the request.
    async fn start(&self) -> Result<Permit, ()> {
        {
            // Work inside a block so we drop the stats lock asap.
            let mut stats = self.stats.lock().unwrap();
            let desired_queue_capacity = usize::max(
                1, // The queue must always be at least 1 request long.
                // Use average latency at (concurrency) capacity so that this doesn't
                // grow too large while the system is under-utilised.
                (stats.concurrency as f64
                    * ((self.target / stats.average_latency_at_capacity) - 1.0))
                    .floor() as usize,
            );
            #[cfg(feature = "metrics")]
            gauge!("loadshedder.capacity", desired_queue_capacity as f64, "component" => "queue");

            // Adjust the semaphore capacity by adding or acquiring many permits.
            // If acquiring permits fails we can return overload and let the next
            // request recompute the queue capacity.
            match desired_queue_capacity.cmp(&stats.queue_capacity) {
                Ordering::Less => {
                    match self
                        .available_queue
                        .try_acquire_many((stats.queue_capacity - desired_queue_capacity) as u32)
                    {
                        Ok(permits) => permits.forget(),
                        Err(TryAcquireError::NoPermits) => return Err(()),
                        Err(TryAcquireError::Closed) => panic!(),
                    }
                }
                Ordering::Equal => {}
                Ordering::Greater => self
                    .available_queue
                    .add_permits(desired_queue_capacity - stats.queue_capacity),
            }
            stats.queue_capacity = desired_queue_capacity;
        }

        // Finally get our queue permit; if this fails then the queue is full
        // and we need to bail out.
        let queue_permit = match self.available_queue.clone().try_acquire_owned() {
            Ok(queue_permit) => Permit::new(queue_permit, "queue"),
            Err(TryAcquireError::NoPermits) => return Err(()),
            Err(TryAcquireError::Closed) => panic!("queue semaphore closed?"),
        };
        // We're in the queue now, so wait until we get ourselves a concurrency permit.
        let concurrency_permit = self
            .available_concurrency
            .clone()
            .acquire_owned()
            .await
            .unwrap();
        // Now we've got the permit required to send the request, we can leave the queue.
        drop(queue_permit);
        Ok(Permit::new(concurrency_permit, "service"))
    }

    /// Register a completed call of the inner service, providing the latency to
    /// update the statistics.
    fn stop(&mut self, elapsed: Duration, concurrency_permit: Permit) {
        let elapsed = elapsed.as_secs_f64();
        #[cfg(feature = "metrics")]
        histogram!("loadshedder.latency", elapsed);

        // This function solely updates the stats (and is not async), so hold the
        // lock for the entire function.
        let mut stats = self.stats.lock().expect("To be able to lock stats");

        let available_permits = self.available_concurrency.available_permits();
        // Have some leeway on what "at max concurrency" means, as you might
        // otherwise never see this condition at large concurrency values.
        let at_max_concurrency = available_permits <= usize::max(1, stats.concurrency / 10);

        // Update the average latency using the EWMA algorithm.
        stats.average_latency =
            (stats.average_latency * (1.0 - self.ewma_param)) + (self.ewma_param * elapsed);
        #[cfg(feature = "metrics")]
        gauge!("loadshedder.average_latency", stats.average_latency);
        if at_max_concurrency {
            stats.average_latency_at_capacity = (stats.average_latency_at_capacity
                * (1.0 - self.ewma_param))
                + (self.ewma_param * elapsed);
        }

        // Only ever change max concurrency if we're at the limit, as we need
        // measurements to have happened at the current limit.
        // Also, introduce a max rate of change that's somewhat magically
        // related to the latency and ewma parameter to prevent this from
        // changing too quickly.
        if stats.last_changed.elapsed().as_secs_f64()
            > (stats.average_latency / self.ewma_param) / 10.0
            && at_max_concurrency
        {
            // Plausibly should be using average latency at capacity here and
            // stats.concurrency but this appears to work. It might do weird
            // things if it's been running under capacity for a while then spikes.
            let current_concurrency = stats.concurrency - available_permits;
            let throughput = current_concurrency as f64 / stats.average_latency;
            // Was the throughput better or worse than it was previously?
            let negative_gradient = (throughput > stats.previous_throughput)
                ^ (current_concurrency > stats.previous_concurrency);
            if negative_gradient || (stats.average_latency > self.target) {
                // Don't reduce concurrency below 1 or everything stops.
                if stats.concurrency > 1 {
                    // Negative gradient or latency over target, so decrease the concurrency.
                    concurrency_permit.forget();
                    stats.concurrency -= 1;
                    #[cfg(feature = "metrics")]
                    gauge!("loadshedder.capacity", stats.concurrency as f64, "component" => "service");

                    // Adjust the average latency assuming that the change in
                    // concurrency doesn't affect the service latency, which is
                    // closer to the truth than the latency not changing.
                    let latency_factor =
                        stats.concurrency as f64 / (stats.concurrency as f64 + 1.0);
                    stats.average_latency *= latency_factor;
                    stats.average_latency_at_capacity *= latency_factor;
                }
            } else {
                self.available_concurrency.add_permits(1);
                stats.concurrency += 1;
                #[cfg(feature = "metrics")]
                gauge!("loadshedder.capacity", stats.concurrency as f64, "component" => "service");

                // Adjust the average latency assuming that the change in
                // concurrency doesn't affect the service latency, which is
                // closer to the truth than the latency not changing.
                let latency_factor = stats.concurrency as f64 / (stats.concurrency as f64 - 1.0);
                stats.average_latency *= latency_factor;
                stats.average_latency_at_capacity *= latency_factor;
            }

            stats.previous_throughput = throughput;
            stats.previous_concurrency = current_concurrency;
            stats.last_changed = Instant::now();
        }
    }
}

/// A permit from one of the semaphores; this wrapper is used to keep the size
/// metrics up to date.
#[derive(Debug)]
struct Permit {
    /// The permit; this is only optional to enable the `forget` function.
    permit: Option<OwnedSemaphorePermit>,
    /// The name of the component this permit is for, used as a metric label.
    #[allow(unused)]
    component: &'static str,
}

impl Permit {
    /// Create a new permit for the given component.
    fn new(permit: OwnedSemaphorePermit, component: &'static str) -> Self {
        #[cfg(feature = "metrics")]
        increment_gauge!("loadshedder.size", 1.0, "component" => component);
        Self {
            permit: Some(permit),
            component,
        }
    }

    /// Forget the permit, essentially reducing the size of the semaphore by one.
    /// Note that this still decrements the size metric.
    fn forget(mut self) {
        self.permit.take().unwrap().forget()
    }
}

impl Drop for Permit {
    fn drop(&mut self) {
        #[cfg(feature = "metrics")]
        decrement_gauge!("loadshedder.size", 1.0, "component" => self.component);
    }
}

/// A [`Service`] that attempts to hold the average latency at a given target.
///
/// It does this by placing a queue in front of the service and rejecting
/// requests when that queue is full (this means requests are either immediately
/// rejected or will be processed by the inner service). It calculates the size
/// of that queue using [Little's law], which states that the average number of
/// items in a system is equal to the average throughput multiplied by the
/// average latency.
///
/// This service therefore measures the average latency and sets the queue size
/// such that, when the queue is full, a request will on average take the target
/// latency to be responded to. Note that if the queue is not full the latency
/// will be below the target.
///
/// This service also optimises the number of concurrent requests to the inner
/// service. This will usually be the same as the queue size, unless the target
/// latency (and hence the queue) is very large, or the underlying service can
/// cope with very few concurrent requests.
///
/// This service is reactive: if the underlying service degrades then the queues
/// will shorten, and if it improves they will lengthen. The queue lengths will
/// be underestimates at startup and will only increase while the service is
/// near its concurrency limit. Be wary of using the queue length as a measure
/// of system capacity unless the queue has been at or above the concurrency
/// limit for a while.
///
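/// # Example
///
/// A small sketch of wrapping a service directly (the `service_fn` inner
/// service is just a placeholder):
///
/// ```
/// use std::{convert::Infallible, time::Duration};
///
/// use little_loadshedder::LoadShed;
/// use tower::service_fn;
///
/// let inner = service_fn(|x: u32| async move { Ok::<_, Infallible>(x + 1) });
/// // Target a 250ms average latency, with an EWMA parameter of 0.1.
/// let shed = LoadShed::new(inner, 0.1, Duration::from_millis(250));
///
/// // The concurrency limit starts at 1 and adapts as requests complete.
/// assert_eq!(shed.concurrency(), 1);
/// ```
///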
/// [Little's law]: https://en.wikipedia.org/wiki/Little%27s_law
#[derive(Debug, Clone)]
pub struct LoadShed<Inner> {
    conf: LoadShedConf,
    inner: Inner,
}

impl<Inner> LoadShed<Inner> {
    /// Wrap a service with this middleware, using the given target average
    /// latency and computing the current average latency using an exponentially
    /// weighted moving average with the given parameter.
    pub fn new(inner: Inner, ewma_param: f64, target: Duration) -> Self {
        Self {
            inner,
            conf: LoadShedConf::new(ewma_param, target.as_secs_f64()),
        }
    }

    /// The current average latency of requests through the inner service,
    /// that is, ignoring the time spent in the queue this service adds.
    pub fn average_latency(&self) -> Duration {
        Duration::from_secs_f64(self.conf.stats.lock().unwrap().average_latency)
    }

    /// The current maximum concurrency of requests to the inner service.
    pub fn concurrency(&self) -> usize {
        self.conf.stats.lock().unwrap().concurrency
    }

    /// The current maximum capacity of this service (including the queue).
    pub fn queue_capacity(&self) -> usize {
        let stats = self.conf.stats.lock().unwrap();
        stats.concurrency + stats.queue_capacity
    }

    /// The current number of requests that have been accepted by this service
    /// and not yet completed (queued plus in flight).
    pub fn queue_len(&self) -> usize {
        let stats = self.conf.stats.lock().unwrap();
        let current_concurrency =
            stats.concurrency - self.conf.available_concurrency.available_permits();
        let current_queue = stats.queue_capacity - self.conf.available_queue.available_permits();
        current_concurrency + current_queue
    }
}

/// Either a response from the wrapped service or a message that the request was shed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LoadShedResponse<T> {
    /// A response from the inner service.
    Inner(T),
    /// The request was shed due to overload.
    Overload,
}

type BoxFuture<Output> = Pin<Box<dyn Future<Output = Output> + Send>>;

impl<Request, Inner> Service<Request> for LoadShed<Inner>
where
    Request: Send + 'static,
    Inner: Service<Request> + Clone + Send + 'static,
    Inner::Future: Send,
{
    type Response = LoadShedResponse<Inner::Response>;
    type Error = Inner::Error;
    type Future = BoxFuture<Result<Self::Response, Self::Error>>;

    /// Always ready because there's a queue between this service and the inner one.
    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, req: Request) -> Self::Future {
        // We're fine to use the clone because the inner service hasn't been
        // polled to readiness yet (`oneshot` below handles that).
        let inner = self.inner.clone();
        let mut conf = self.conf.clone();
        Box::pin(async move {
            let permit = match conf.start().await {
                Ok(permit) => {
                    #[cfg(feature = "metrics")]
                    increment_counter!("loadshedder.request", "status" => "accepted");
                    permit
                }
                Err(()) => {
                    #[cfg(feature = "metrics")]
                    increment_counter!("loadshedder.request", "status" => "rejected");
                    return Ok(LoadShedResponse::Overload);
                }
            };
            let start = Instant::now();
            // The elapsed time includes waiting for readiness, which should help
            // us stay under any upstream concurrency limiters.
            let response = inner.oneshot(req).await;
            conf.stop(start.elapsed(), permit);
            Ok(LoadShedResponse::Inner(response?))
        })
    }
}

/// A [`Layer`] to wrap services in a [`LoadShed`] middleware.
///
/// See [`LoadShed`] for details of the load shedding algorithm.
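///
/// # Example
///
/// A brief sketch of composing the layer with `tower::ServiceBuilder`
/// (the target and EWMA parameter values here are arbitrary):
///
/// ```
/// use std::time::Duration;
///
/// use little_loadshedder::LoadShedLayer;
/// use tower::ServiceBuilder;
///
/// // Target a 250ms average latency, with an EWMA parameter of 0.1.
/// let layer = LoadShedLayer::new(0.1, Duration::from_millis(250));
/// let builder = ServiceBuilder::new().layer(layer);
/// ```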
#[derive(Debug, Clone)]
pub struct LoadShedLayer {
    ewma_param: f64,
    target: Duration,
}

impl LoadShedLayer {
    /// Create a new layer with the given target average latency, computing the
    /// current average latency using an exponentially weighted moving average
    /// with the given parameter.
    pub fn new(ewma_param: f64, target: Duration) -> Self {
        Self { ewma_param, target }
    }
}

impl<Inner> Layer<Inner> for LoadShedLayer {
    type Service = LoadShed<Inner>;

    fn layer(&self, inner: Inner) -> Self::Service {
        LoadShed::new(inner, self.ewma_param, self.target)
    }
}