mas_tower/metrics/
in_flight.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16
17use opentelemetry::{metrics::UpDownCounter, KeyValue};
18use pin_project_lite::pin_project;
19use tower::{Layer, Service};
20
21use crate::MetricsAttributes;
22
/// A [`Layer`] that records the number of in-flight requests.
///
/// Services produced by this layer increment the counter when a request is
/// dispatched, and decrement it when the future returned for that request is
/// dropped.
///
/// # Generic Parameters
///
/// * `OnRequest`: A type that can extract attributes from a request.
#[derive(Clone, Debug)]
pub struct InFlightCounterLayer<OnRequest = ()> {
    /// The up/down counter; a clone of it is handed to every wrapped service.
    counter: UpDownCounter<i64>,
    /// The attribute extractor; a clone of it is handed to every wrapped service.
    on_request: OnRequest,
}
33
34impl InFlightCounterLayer {
35    /// Create a new [`InFlightCounterLayer`].
36    #[must_use]
37    pub fn new(name: &'static str) -> Self {
38        let counter = crate::meter()
39            .i64_up_down_counter(name)
40            .with_unit("{request}")
41            .with_description("The number of in-flight requests")
42            .init();
43
44        Self {
45            counter,
46            on_request: (),
47        }
48    }
49}
50
51impl<F> InFlightCounterLayer<F> {
52    /// Set the [`MetricsAttributes`] to use.
53    #[must_use]
54    pub fn on_request<OnRequest>(self, on_request: OnRequest) -> InFlightCounterLayer<OnRequest> {
55        InFlightCounterLayer {
56            counter: self.counter,
57            on_request,
58        }
59    }
60}
61
62impl<S, OnRequest> Layer<S> for InFlightCounterLayer<OnRequest>
63where
64    OnRequest: Clone,
65{
66    type Service = InFlightCounterService<S, OnRequest>;
67
68    fn layer(&self, inner: S) -> Self::Service {
69        InFlightCounterService {
70            inner,
71            counter: self.counter.clone(),
72            on_request: self.on_request.clone(),
73        }
74    }
75}
76
/// A middleware that records the number of in-flight requests.
///
/// Each call increments the counter before the inner service is invoked. The
/// returned [`InFlightFuture`] decrements it again when it is dropped.
///
/// # Generic Parameters
///
/// * `S`: The type of the inner service.
/// * `OnRequest`: A type that can extract attributes from a request.
#[derive(Clone, Debug)]
pub struct InFlightCounterService<S, OnRequest = ()> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// The up/down counter tracking in-flight requests.
    counter: UpDownCounter<i64>,
    /// Extracts the attributes attached to each increment/decrement.
    on_request: OnRequest,
}
89
/// A guard that decrements the in-flight request count when dropped.
///
/// The matching increment is recorded by `InFlightGuard::new`.
struct InFlightGuard {
    /// The counter to decrement on drop.
    counter: UpDownCounter<i64>,
    /// The attributes used for both the increment and the decrement, so that
    /// both land on the same attribute set.
    attributes: Vec<KeyValue>,
}
95
96impl InFlightGuard {
97    fn new(counter: UpDownCounter<i64>, attributes: Vec<KeyValue>) -> Self {
98        counter.add(1, &attributes);
99
100        Self {
101            counter,
102            attributes,
103        }
104    }
105}
106
107impl Drop for InFlightGuard {
108    fn drop(&mut self) {
109        self.counter.add(-1, &self.attributes);
110    }
111}
112
pin_project! {
    /// The future returned by [`InFlightCounterService`]
    ///
    /// It resolves to the inner service's output. The in-flight count is
    /// decremented when this future is dropped, not when it completes.
    pub struct InFlightFuture<F> {
        // Decrements the in-flight counter when this future is dropped.
        guard: InFlightGuard,

        // The inner service's future; structurally pinned so it can be polled.
        #[pin]
        inner: F,
    }
}
122
123impl<F> Future for InFlightFuture<F>
124where
125    F: Future,
126{
127    type Output = F::Output;
128
129    fn poll(
130        self: std::pin::Pin<&mut Self>,
131        cx: &mut std::task::Context<'_>,
132    ) -> std::task::Poll<Self::Output> {
133        self.project().inner.poll(cx)
134    }
135}
136
137impl<R, S, OnRequest> Service<R> for InFlightCounterService<S, OnRequest>
138where
139    S: Service<R>,
140    OnRequest: MetricsAttributes<R>,
141{
142    type Response = S::Response;
143    type Error = S::Error;
144    type Future = InFlightFuture<S::Future>;
145
146    fn poll_ready(
147        &mut self,
148        cx: &mut std::task::Context<'_>,
149    ) -> std::task::Poll<Result<(), Self::Error>> {
150        self.inner.poll_ready(cx)
151    }
152
153    fn call(&mut self, req: R) -> Self::Future {
154        // Extract attributes from the request.
155        let attributes = self.on_request.attributes(&req).collect();
156
157        // Increment the in-flight request count.
158        let guard = InFlightGuard::new(self.counter.clone(), attributes);
159
160        // Call the inner service, and return a future that decrements the in-flight
161        // when dropped.
162        let inner = self.inner.call(req);
163        InFlightFuture { guard, inner }
164    }
165}