1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
//! Layer to ratelimit a hRPC service and return errors respecting the hRPC
//! protocol.
//!
//! This layer allows you to control what is used as a key for storing state.
//! For example, this allows you to use separate state for separate
//! connections based on [`std::net::SocketAddr`] or similar.
//!
//! You can extract and bypass keys using `ExtractKey` and `BypassForKey`
//! functions which you can specify on [`RateLimitLayer`] and [`RateLimit`].
//! An example can be found on [`RateLimitLayer::set_key_fns`] documentation.
//!
//! Note that if you don't specify anything and use the default configuration,
//! this essentially acts as a [`tower::limit::RateLimit`].

use pin_project_lite::pin_project;
use std::{
    collections::HashMap,
    convert::Infallible,
    future::Future,
    hash::Hash,
    pin::Pin,
    task::{Context, Poll},
    time::{Duration, Instant},
};
use tower::{Layer, Service};

use crate::{
    encode,
    proto::{Error as HrpcError, RetryInfo},
    request::BoxRequest,
};

/// Layer that wraps services in a [`RateLimit`], capping how many requests
/// they receive within each time period.
///
/// Read module documentation for more information.
#[derive(Clone)]
pub struct RateLimitLayer<ExtractKey, BypassForKey> {
    /// Function mapping a request to the key its quota is tracked under.
    extract_key: ExtractKey,
    /// Predicate deciding whether a key skips rate limiting altogether.
    bypass_for_key: BypassForKey,
    /// Number of requests allowed per period, copied into every wrapped service.
    rate: Rate,
}

type ExtractKeyDefault = fn(&mut BoxRequest) -> Option<()>;
type BypassForKeyDefault = fn(&()) -> bool;

impl RateLimitLayer<ExtractKeyDefault, BypassForKeyDefault> {
    /// Build a layer allowing `num` requests every `per`.
    ///
    /// No key is extracted from requests, so every request shares a single
    /// state, and nothing bypasses the limit.
    ///
    /// # Panics
    ///
    /// Panics if `num` is 0 or `per` is zero (see [`Rate::new`]).
    pub fn new(num: u64, per: Duration) -> Self {
        // Default key extractor: never produces a key.
        fn no_key(_: &mut BoxRequest) -> Option<()> {
            None
        }

        // Default bypass predicate: never exempts a key.
        fn never_bypass(_: &()) -> bool {
            false
        }

        Self {
            rate: Rate::new(num, per),
            extract_key: no_key,
            bypass_for_key: never_bypass,
        }
    }
}

impl<ExtractKey, BypassForKey> RateLimitLayer<ExtractKey, BypassForKey> {
    /// Replace the key extraction and bypass functions, keeping the configured rate.
    ///
    /// `extract` is called on each request. A `Some(key)` result gets its own
    /// limiter state; `None` falls back to a single shared state. `bypass` is
    /// then consulted with the extracted key, and returning `true` lets the
    /// request through without touching any limiter state.
    ///
    /// ```
    /// # use hrpc::server::layer::ratelimit::RateLimitLayer;
    /// # use std::{time::Duration, net::SocketAddr};
    ///
    /// // create a rate limit layer that uses SocketAddr as keys
    /// // to distinguish connections and use separate state for them
    /// let layer = RateLimitLayer::new(5, Duration::from_secs(10))
    ///     .set_key_fns(
    ///         // extract ip addr from request
    ///         |req| req.extensions().get::<SocketAddr>().map(|addr| addr.ip()),
    ///         // bypass ratelimit for loopback ips
    ///         |key| key.is_loopback(),
    ///     );
    /// ```
    pub fn set_key_fns<NewBypassForKey, NewExtractKeyFn, NewKey>(
        self,
        extract: NewExtractKeyFn,
        bypass: NewBypassForKey,
    ) -> RateLimitLayer<NewExtractKeyFn, NewBypassForKey>
    where
        NewBypassForKey: Fn(&NewKey) -> bool + Clone,
        NewExtractKeyFn: Fn(&mut BoxRequest) -> Option<NewKey> + Clone,
        NewKey: Eq + Hash,
    {
        let rate = self.rate;
        RateLimitLayer {
            extract_key: extract,
            bypass_for_key: bypass,
            rate,
        }
    }
}

impl<BypassForKey, ExtractKey, Key, S> Layer<S> for RateLimitLayer<ExtractKey, BypassForKey>
where
    ExtractKey: Fn(&mut BoxRequest) -> Option<Key> + Clone,
    BypassForKey: Fn(&Key) -> bool + Clone,
    Key: Eq + Hash,
{
    type Service = RateLimit<S, ExtractKey, BypassForKey, Key>;

    /// Wrap `service` in a [`RateLimit`] built from this layer's rate and
    /// clones of its key functions; the new limiter starts with fresh state.
    fn layer(&self, service: S) -> Self::Service {
        let Self {
            rate,
            extract_key,
            bypass_for_key,
        } = self;
        let (extract_key, bypass_for_key) = (extract_key.clone(), bypass_for_key.clone());
        RateLimit::new(service, *rate, extract_key, bypass_for_key)
    }
}

/// Service wrapper that caps how many requests reach the inner service
/// within each time period, answering excess requests with an hRPC error.
///
/// Read module documentation for more information.
pub struct RateLimit<T, ExtractKey, BypassForKey, Key> {
    /// Wrapped service that receives admitted requests.
    inner: T,
    /// Function mapping a request to the key its quota is tracked under.
    extract_key: ExtractKey,
    /// Predicate exempting particular keys from limiting.
    bypass_for_key: BypassForKey,
    /// Number of requests allowed per period.
    rate: Rate,
    /// State used by requests for which no key was extracted.
    global_state: State,
    /// Independent state for each extracted key.
    keyed_states: HashMap<Key, State>,
}

/// Limiter bookkeeping for one key (or for the shared, key-less bucket).
#[derive(Debug)]
enum State {
    /// Quota exhausted: requests are rejected while the current time is before `after`.
    Limited { after: Instant },
    /// Requests may pass: `rem` permits remain in the window ending at `until`.
    Ready { until: Instant, rem: u64 },
}

impl State {
    /// Fresh state holding a full quota.
    ///
    /// The window end is set to "now", so the first request that inspects
    /// this state sees an elapsed window and starts a new one.
    fn new_ready(rate: &Rate) -> Self {
        let full_quota = rate.num();
        let window_end = Instant::now();
        State::Ready {
            until: window_end,
            rem: full_quota,
        }
    }
}

impl<S, ExtractKey, BypassForKey, Key> RateLimit<S, ExtractKey, BypassForKey, Key>
where
    ExtractKey: Fn(&mut BoxRequest) -> Option<Key>,
    BypassForKey: Fn(&Key) -> bool,
    Key: Eq + Hash,
{
    /// Wrap `inner` so that it receives at most `rate` requests per period.
    ///
    /// The limiter starts with a full shared quota and no per-key entries.
    pub fn new(
        inner: S,
        rate: Rate,
        extract_key: ExtractKey,
        bypass_for_key: BypassForKey,
    ) -> Self {
        let global_state = State::new_ready(&rate);
        let keyed_states = HashMap::new();
        Self {
            inner,
            rate,
            global_state,
            keyed_states,
            extract_key,
            bypass_for_key,
        }
    }

    /// Borrow the wrapped service.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Mutably borrow the wrapped service.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Unwrap the limiter, discarding its state and returning the wrapped service.
    pub fn into_inner(self) -> S {
        let Self { inner, .. } = self;
        inner
    }
}

impl<S, ExtractKey, BypassForKey, Key> Service<BoxRequest>
    for RateLimit<S, ExtractKey, BypassForKey, Key>
where
    S: Service<BoxRequest, Response = BoxResponse, Error = Infallible>,
    ExtractKey: Fn(&mut BoxRequest) -> Option<Key>,
    BypassForKey: Fn(&Key) -> bool,
    Key: Eq + Hash,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = RateLimitFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Service::poll_ready(&mut self.inner, cx)
    }

    fn call(&mut self, mut request: BoxRequest) -> Self::Future {
        let state = match (self.extract_key)(&mut request) {
            Some(key) => {
                if (self.bypass_for_key)(&key) {
                    let fut = Service::call(&mut self.inner, request);
                    return RateLimitFuture::ready(fut);
                }

                self.keyed_states
                    .entry(key)
                    .or_insert_with(|| State::new_ready(&self.rate))
            }
            None => &mut self.global_state,
        };

        match *state {
            State::Ready { mut until, mut rem } => {
                let now = Instant::now();

                // If the period has elapsed, reset it.
                if now >= until {
                    until = now + self.rate.per();
                    rem = self.rate.num();
                }

                if rem > 1 {
                    rem -= 1;
                    *state = State::Ready { until, rem };
                } else {
                    // The service is disabled until further notice
                    let after = Instant::now() + self.rate.per();
                    *state = State::Limited { after };
                }

                // Call the inner future
                let fut = Service::call(&mut self.inner, request);
                RateLimitFuture::ready(fut)
            }
            State::Limited { after } => {
                let now = Instant::now();
                if now < after {
                    tracing::trace!("rate limit exceeded.");
                    let after = after - now;
                    return RateLimitFuture::limited(after);
                }

                // Reset state
                *state = State::Ready {
                    until: now + self.rate.per(),
                    rem: self.rate.num(),
                };

                // Call the inner future
                let fut = Service::call(&mut self.inner, request);
                RateLimitFuture::ready(fut)
            }
        }
    }
}

// Internal state of `RateLimitFuture`: it either forwards to the inner
// service's future or reports a rate-limit error for a rejected request.
pin_project! {
    #[project = EnumProj]
    enum RateLimitFutureInner<Fut> {
        // Request admitted: `fut` is the inner service's future. It is marked
        // `#[pin]` so it can be polled through the pinned projection.
        Ready { #[pin] fut: Fut },
        // Request rejected: `after` is the remaining time until the limit lifts.
        Limited { after: Duration },
    }
}

pin_project! {
    /// Future for [`RateLimit`].
    ///
    /// Resolves to the inner service's response when the request was admitted,
    /// or to an hRPC "resource exhausted" error carrying retry information when
    /// it was rejected.
    pub struct RateLimitFuture<Fut> {
        // Pinned so the wrapped inner future can be polled in place.
        #[pin]
        inner: RateLimitFutureInner<Fut>,
    }
}

impl<Fut> RateLimitFuture<Fut> {
    /// Wrap the inner service's future for an admitted request.
    fn ready(fut: Fut) -> Self {
        let inner = RateLimitFutureInner::Ready { fut };
        Self { inner }
    }

    /// Build a future for a rejected request; `after` is how long until the limit lifts.
    fn limited(after: Duration) -> Self {
        let inner = RateLimitFutureInner::Limited { after };
        Self { inner }
    }
}

impl<Fut> Future for RateLimitFuture<Fut>
where
    Fut: Future<Output = Result<BoxResponse, Infallible>>,
{
    type Output = Result<BoxResponse, Infallible>;

    /// Poll the inner future for admitted requests. Rejected requests resolve
    /// immediately to a "resource exhausted" hRPC error whose `RetryInfo`
    /// carries the wait rounded up to whole seconds.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let wait = match self.project().inner.project() {
            EnumProj::Ready { fut } => return fut.poll(cx),
            EnumProj::Limited { after } => *after,
        };

        // `as` from f64 saturates, so very long waits clamp to `u32::MAX`.
        let retry_info = RetryInfo {
            retry_after: wait.as_secs_f64().ceil() as u32,
        };
        let details = encode::encode_protobuf_message(&retry_info).freeze();
        let err =
            HrpcError::new_resource_exhausted("rate limited; please try later").with_details(details);

        Poll::Ready(Ok(err.into()))
    }
}

use crate::response::BoxResponse;

#[doc(inline)]
pub use self::rate::Rate;

mod rate {
    use std::time::Duration;

    /// How many requests (`num`) are allowed in each time window of length `per`.
    #[derive(Debug, Copy, Clone)]
    pub struct Rate {
        num: u64,
        per: Duration,
    }

    impl Rate {
        /// Construct a rate permitting `num` requests every `per`.
        ///
        /// # Panics
        ///
        /// Panics if `num` is 0 or `per` is a zero duration.
        pub fn new(num: u64, per: Duration) -> Self {
            assert!(num > 0);
            assert!(per > Duration::from_millis(0));
            Self { num, per }
        }

        /// Number of requests permitted per window.
        pub(crate) fn num(&self) -> u64 {
            let Self { num, .. } = *self;
            num
        }

        /// Length of one window.
        pub(crate) fn per(&self) -> Duration {
            let Self { per, .. } = *self;
            per
        }
    }
}