// google_authz/service.rs

1use std::{
2    fmt,
3    future::Future,
4    pin::Pin,
5    sync::Arc,
6    task::{self, Poll},
7    time::{Duration, Instant},
8};
9
10use hyper::{
11    header::{HeaderValue, AUTHORIZATION},
12    Request,
13};
14use parking_lot::RwLock;
15use tracing::{info, trace};
16
17use crate::{token, Credentials, Token, TokenSource};
18
19pub struct AddAuthorization<S> {
20    inner: Arc<RwLock<Inner>>,
21    service: S,
22}
23
24impl AddAuthorization<()> {
25    pub async fn init<S>(service: S) -> AddAuthorization<S> {
26        AddAuthorization { inner: Arc::new(RwLock::new(Inner::init().await)), service }
27    }
28
29    pub fn init_with<S>(source: impl Into<TokenSource>, service: S) -> AddAuthorization<S> {
30        AddAuthorization { inner: Arc::new(RwLock::new(Inner::init_with(source))), service }
31    }
32}
33
34enum State {
35    Uninitialized,
36    Fetching {
37        retry: u8,
38        fut: RefGuard<Pin<Box<dyn Future<Output = token::Result<Token>> + Send + 'static>>>,
39    },
40    Fetched,
41}
42
// `RefGuard` lives in its own private module so that its field is unreachable
// from the rest of this file. The `Sync` impl below relies on that: the only
// way to reach the wrapped value is `get_mut`, which needs `&mut self`.
mod ref_guard {
    /// Wraps a `Send` type to make it `Sync`, by ensuring that it is only ever
    /// accessed through a `&mut` pointer.
    pub(super) struct RefGuard<T: Send>(T);

    impl<T: Send> RefGuard<T> {
        /// Wraps `value`.
        pub(super) fn new(value: T) -> Self {
            RefGuard(value)
        }

        /// Returns exclusive access to the wrapped value.
        pub(super) fn get_mut(&mut self) -> &mut T {
            &mut self.0
        }
    }

    // SAFETY: `Sync` only allows `&RefGuard<T>` to be shared between threads.
    // No method takes `&self`, and the field is private to this module, so a
    // shared reference gives no access to `T` at all. The only accessor,
    // `get_mut`, requires `&mut RefGuard<T>`, i.e. unique access. Transferring
    // the guard itself between threads is governed by `T: Send`.
    unsafe impl<T: Send> Sync for RefGuard<T> {}
}

use self::ref_guard::RefGuard;
58
59struct Inner {
60    state: State,
61    cache: Option<Cache>,
62    source: TokenSource,
63    max_retry: u8,
64}
65
66impl Inner {
67    async fn init() -> Self {
68        Self::init_with(Credentials::default().await)
69    }
70
71    fn init_with(s: impl Into<TokenSource>) -> Self {
72        Self { state: State::Uninitialized, cache: None, source: s.into(), max_retry: 5 }
73    }
74
75    #[inline]
76    fn cache_ref(&self) -> &Cache {
77        self.cache.as_ref().unwrap()
78    }
79
80    fn can_skip_poll_ready(&self) -> bool {
81        matches!(self.state, State::Fetched) && !self.cache_ref().expired(Instant::now())
82    }
83
84    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<()> {
85        loop {
86            match self.state {
87                State::Uninitialized => {
88                    trace!("token is uninitialized");
89                    self.state =
90                        State::Fetching { retry: 0, fut: RefGuard::new(self.source.token()) };
91                    continue;
92                }
93                State::Fetching { ref retry, ref mut fut } => match fut.get_mut().as_mut().poll(cx)
94                {
95                    Poll::Ready(r) => match r.and_then(|t| t.into_pairs()) {
96                        Ok((value, expiry)) => {
97                            self.cache = Some(Cache::new(value, expiry));
98                            self.state = State::Fetched;
99                            trace!("token updated: expiry={:?}", expiry);
100                            return Poll::Ready(());
101                        }
102                        Err(err) => {
103                            if *retry < self.max_retry {
104                                info!("an error occurred: retry={}, err={:?}", retry, err);
105                            } else {
106                                panic!("max retry exceeded: retry={}, last error={:?}", retry, err);
107                            }
108                            self.state = State::Fetching {
109                                retry: retry + 1,
110                                fut: RefGuard::new(self.source.token()),
111                            };
112                            continue;
113                        }
114                    },
115                    Poll::Pending => return Poll::Pending,
116                },
117                State::Fetched => {
118                    let cache = self.cache_ref();
119                    if !cache.expired(Instant::now()) {
120                        return Poll::Ready(());
121                    }
122                    trace!("token will expire: expiry={:?}", cache.expiry);
123                    self.state =
124                        State::Fetching { retry: 0, fut: RefGuard::new(self.source.token()) };
125                    continue;
126                }
127            }
128        }
129    }
130}
131
132#[derive(Clone)]
133struct Cache {
134    value: HeaderValue,
135    expiry: Instant,
136}
137
138impl Cache {
139    fn new(value: HeaderValue, expiry: Instant) -> Self {
140        Self { value, expiry }
141    }
142
143    fn expired(&self, at: Instant) -> bool {
144        const EXPIRY_DELTA: Duration = Duration::from_secs(10);
145        self.expiry.checked_duration_since(at).map(|dur| dur < EXPIRY_DELTA).unwrap_or(true)
146    }
147
148    fn value(&self) -> HeaderValue {
149        self.value.clone()
150    }
151}
152
153impl<S, B> tower_service::Service<Request<B>> for AddAuthorization<S>
154where
155    S: tower_service::Service<Request<B>>,
156{
157    type Response = S::Response;
158    type Error = S::Error;
159    type Future = S::Future;
160
161    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
162        if self.inner.read().can_skip_poll_ready() {
163            return self.service.poll_ready(cx);
164        }
165        match self.inner.write().poll_ready(cx) {
166            Poll::Ready(()) => self.service.poll_ready(cx),
167            Poll::Pending => Poll::Pending,
168        }
169    }
170
171    fn call(&mut self, mut req: Request<B>) -> Self::Future {
172        req.headers_mut().insert(AUTHORIZATION, self.inner.read().cache_ref().value());
173        self.service.call(req)
174    }
175}
176
177impl<S: Clone> Clone for AddAuthorization<S> {
178    fn clone(&self) -> Self {
179        Self { inner: self.inner.clone(), service: self.service.clone() }
180    }
181}
182
183impl<S: fmt::Debug> fmt::Debug for AddAuthorization<S> {
184    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185        f.debug_struct("AddAuthorization").field("service", &self.service).finish()
186    }
187}
188
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that `AddAuthorization` over a `Send + Clone` service is itself
    /// `Send` and `Sync`.
    #[test]
    fn compile_test() {
        #[derive(Clone)]
        struct Counter {
            cur: i32,
        }

        impl Counter {
            fn new() -> Self {
                Counter { cur: 0 }
            }
        }

        impl<B> tower_service::Service<Request<B>> for Counter {
            type Response = i32;
            type Error = i32;
            type Future = Pin<Box<dyn Future<Output = Result<i32, i32>> + Send + 'static>>;

            fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: Request<B>) -> Self::Future {
                self.cur += 1;
                let current = self.cur;
                Box::pin(async move { Ok(current) })
            }
        }

        fn assert_send<T: Send>(_: T) {}
        fn assert_sync<T: Sync>(_: T) {}

        let svc = AddAuthorization::init_with(
            Credentials::from_json(
                br#"{
  "client_id": "xxx.apps.googleusercontent.com",
  "client_secret": "secret-xxx",
  "refresh_token": "refresh-xxx",
  "type": "authorized_user"
}"#,
                &[],
            ),
            Counter::new(),
        );
        assert_send(svc.clone());
        assert_sync(svc);
    }

    /// Checks the 10 second expiry margin applied by `Cache::expired`.
    #[test]
    fn cache_expiry() {
        let now = Instant::now();
        let cache_expiring_at = |expiry| Cache::new(HeaderValue::from_static("value"), expiry);

        // Expiries are placed at or after `now` so the test never evaluates
        // `now - duration`. That subtraction panics when the result cannot be
        // represented by the platform's `Instant`, e.g. on a monotonic clock
        // shortly after boot.
        assert!(cache_expiring_at(now + Duration::from_secs(5)).expired(now));
        assert!(!cache_expiring_at(now + Duration::from_secs(30)).expired(now));
        // Exactly the 10 s margin remaining still counts as valid.
        assert!(!cache_expiring_at(now + Duration::from_secs(10)).expired(now));
        // Already past expiry.
        assert!(cache_expiring_at(now).expired(now + Duration::from_secs(1)));
    }
}