Skip to main content

octocrab_rate_limiter/
lib.rs

1use std::{pin::Pin, sync::Arc, time::Duration};
2
3use http::{HeaderValue, header::AUTHORIZATION};
4use moka::future::Cache;
5use tokio::sync::Semaphore;
6use tower::{Layer, Service};
7
8#[derive(Clone, Debug)]
9pub struct AccessTokenRateLimitLayer {
10    per_token_semaphores: Cache<HeaderValue, Arc<tokio::sync::Semaphore>>,
11}
12
13impl AccessTokenRateLimitLayer {
14    pub fn new(idle_ttl: Duration) -> AccessTokenRateLimitLayer {
15        let per_token_semaphores = Cache::builder().time_to_idle(idle_ttl).build();
16        AccessTokenRateLimitLayer {
17            per_token_semaphores,
18        }
19    }
20}
21
22impl<S> Layer<S> for AccessTokenRateLimitLayer {
23    type Service = AccessTokenRateLimit<S>;
24
25    fn layer(&self, service: S) -> Self::Service {
26        AccessTokenRateLimit {
27            inner: service,
28            per_token_semaphores: self.per_token_semaphores.clone(),
29        }
30    }
31}
32
33#[derive(Clone)]
34pub struct AccessTokenRateLimit<T> {
35    inner: T,
36    per_token_semaphores: Cache<HeaderValue, Arc<tokio::sync::Semaphore>>,
37}
38
39impl<Request, S> Service<http::Request<Request>> for AccessTokenRateLimit<S>
40where
41    S: Service<http::Request<Request>> + Clone + Send + 'static,
42    S::Response: Send,
43    S::Error: Send,
44    S::Future: Send,
45    Request: Send + 'static,
46{
47    type Response = S::Response;
48    type Error = S::Error;
49
50    type Future = Pin<Box<dyn Future<Output = <<S as tower::Service<http::Request<Request>>>::Future as std::future::IntoFuture>::Output> + Send + 'static>>;
51
52    fn poll_ready(
53        &mut self,
54        cx: &mut std::task::Context<'_>,
55    ) -> std::task::Poll<Result<(), Self::Error>> {
56        self.inner.poll_ready(cx)
57    }
58
59    fn call(&mut self, request: http::Request<Request>) -> Self::Future {
60        let header_value = request.headers().get(AUTHORIZATION).map(|v| v.to_owned());
61
62        let clone = self.inner.clone();
63        let mut inner = std::mem::replace(&mut self.inner, clone);
64
65        let per_token_semaphores = self.per_token_semaphores.clone();
66        Box::pin(async move {
67            let semaphore;
68            let mut permit = None;
69            if let Some(header_value) = header_value {
70                semaphore = per_token_semaphores
71                    .get_with(header_value, async {
72                        // GitHub's secondary rate limits start kicking in at 100 parallel requests.
73                        Arc::new(Semaphore::new(99))
74                    })
75                    .await;
76                // UNWRAP: We never close these semaphores, and never leak them outside of the struct so no one else can, so acquire can't fail.
77                permit = Some(semaphore.acquire().await.unwrap());
78            }
79            let result = inner.call(request).await;
80            drop(permit);
81            result
82        })
83    }
84}