// octocrab_rate_limiter/lib.rs
use std::{fmt, pin::Pin, sync::Arc, time::Duration};

use http::{HeaderValue, header::AUTHORIZATION};
use moka::future::Cache;
use tokio::sync::Semaphore;
use tower::{Layer, Service};

8#[derive(Clone, Debug)]
9pub struct AccessTokenRateLimitLayer {
10 per_token_semaphores: Cache<HeaderValue, Arc<tokio::sync::Semaphore>>,
11}
12
13impl AccessTokenRateLimitLayer {
14 pub fn new(idle_ttl: Duration) -> AccessTokenRateLimitLayer {
15 let per_token_semaphores = Cache::builder().time_to_idle(idle_ttl).build();
16 AccessTokenRateLimitLayer {
17 per_token_semaphores,
18 }
19 }
20}
21
22impl<S> Layer<S> for AccessTokenRateLimitLayer {
23 type Service = AccessTokenRateLimit<S>;
24
25 fn layer(&self, service: S) -> Self::Service {
26 AccessTokenRateLimit {
27 inner: service,
28 per_token_semaphores: self.per_token_semaphores.clone(),
29 }
30 }
31}
32
33#[derive(Clone)]
34pub struct AccessTokenRateLimit<T> {
35 inner: T,
36 per_token_semaphores: Cache<HeaderValue, Arc<tokio::sync::Semaphore>>,
37}
38
39impl<Request, S> Service<http::Request<Request>> for AccessTokenRateLimit<S>
40where
41 S: Service<http::Request<Request>> + Clone + Send + 'static,
42 S::Response: Send,
43 S::Error: Send,
44 S::Future: Send,
45 Request: Send + 'static,
46{
47 type Response = S::Response;
48 type Error = S::Error;
49
50 type Future = Pin<Box<dyn Future<Output = <<S as tower::Service<http::Request<Request>>>::Future as std::future::IntoFuture>::Output> + Send + 'static>>;
51
52 fn poll_ready(
53 &mut self,
54 cx: &mut std::task::Context<'_>,
55 ) -> std::task::Poll<Result<(), Self::Error>> {
56 self.inner.poll_ready(cx)
57 }
58
59 fn call(&mut self, request: http::Request<Request>) -> Self::Future {
60 let header_value = request.headers().get(AUTHORIZATION).map(|v| v.to_owned());
61
62 let clone = self.inner.clone();
63 let mut inner = std::mem::replace(&mut self.inner, clone);
64
65 let per_token_semaphores = self.per_token_semaphores.clone();
66 Box::pin(async move {
67 let semaphore;
68 let mut permit = None;
69 if let Some(header_value) = header_value {
70 semaphore = per_token_semaphores
71 .get_with(header_value, async {
72 Arc::new(Semaphore::new(99))
74 })
75 .await;
76 permit = Some(semaphore.acquire().await.unwrap());
78 }
79 let result = inner.call(request).await;
80 drop(permit);
81 result
82 })
83 }
84}