1use std::{
2 fmt,
3 future::Future,
4 pin::Pin,
5 sync::Arc,
6 task::{self, Poll},
7 time::{Duration, Instant},
8};
9
10use hyper::{
11 header::{HeaderValue, AUTHORIZATION},
12 Request,
13};
14use parking_lot::RwLock;
15use tracing::{info, trace};
16
17use crate::{token, Credentials, Token, TokenSource};
18
19pub struct AddAuthorization<S> {
20 inner: Arc<RwLock<Inner>>,
21 service: S,
22}
23
24impl AddAuthorization<()> {
25 pub async fn init<S>(service: S) -> AddAuthorization<S> {
26 AddAuthorization { inner: Arc::new(RwLock::new(Inner::init().await)), service }
27 }
28
29 pub fn init_with<S>(source: impl Into<TokenSource>, service: S) -> AddAuthorization<S> {
30 AddAuthorization { inner: Arc::new(RwLock::new(Inner::init_with(source))), service }
31 }
32}
33
34enum State {
35 Uninitialized,
36 Fetching {
37 retry: u8,
38 fut: RefGuard<Pin<Box<dyn Future<Output = token::Result<Token>> + Send + 'static>>>,
39 },
40 Fetched,
41}
42
/// Wrapper that lets a `Send`-only value live inside `Sync` state by handing
/// out nothing but exclusive (`&mut`) access to it.
///
/// Used for the in-flight token future, whose type is declared `Send` but not `Sync`.
struct RefGuard<T>(T);

impl<T> RefGuard<T> {
    /// Wraps `value`.
    pub fn new(value: T) -> Self {
        RefGuard(value)
    }

    /// Returns exclusive access to the wrapped value.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.0
    }
}

// SAFETY: `Sync` only permits sharing `&RefGuard<T>` between threads, and a
// shared reference gives no access to the wrapped `T`: the field is private
// and the sole accessor, `get_mut`, requires `&mut self`. Concurrent access to
// `T` through this type is therefore impossible. Do NOT add any `&self`
// accessor to `RefGuard`, or this impl becomes unsound. The impl is
// additionally restricted to `T: Send`, which every current use satisfies.
unsafe impl<T: Send> Sync for RefGuard<T> {}
58
59struct Inner {
60 state: State,
61 cache: Option<Cache>,
62 source: TokenSource,
63 max_retry: u8,
64}
65
66impl Inner {
67 async fn init() -> Self {
68 Self::init_with(Credentials::default().await)
69 }
70
71 fn init_with(s: impl Into<TokenSource>) -> Self {
72 Self { state: State::Uninitialized, cache: None, source: s.into(), max_retry: 5 }
73 }
74
75 #[inline]
76 fn cache_ref(&self) -> &Cache {
77 self.cache.as_ref().unwrap()
78 }
79
80 fn can_skip_poll_ready(&self) -> bool {
81 matches!(self.state, State::Fetched) && !self.cache_ref().expired(Instant::now())
82 }
83
84 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<()> {
85 loop {
86 match self.state {
87 State::Uninitialized => {
88 trace!("token is uninitialized");
89 self.state =
90 State::Fetching { retry: 0, fut: RefGuard::new(self.source.token()) };
91 continue;
92 }
93 State::Fetching { ref retry, ref mut fut } => match fut.get_mut().as_mut().poll(cx)
94 {
95 Poll::Ready(r) => match r.and_then(|t| t.into_pairs()) {
96 Ok((value, expiry)) => {
97 self.cache = Some(Cache::new(value, expiry));
98 self.state = State::Fetched;
99 trace!("token updated: expiry={:?}", expiry);
100 return Poll::Ready(());
101 }
102 Err(err) => {
103 if *retry < self.max_retry {
104 info!("an error occurred: retry={}, err={:?}", retry, err);
105 } else {
106 panic!("max retry exceeded: retry={}, last error={:?}", retry, err);
107 }
108 self.state = State::Fetching {
109 retry: retry + 1,
110 fut: RefGuard::new(self.source.token()),
111 };
112 continue;
113 }
114 },
115 Poll::Pending => return Poll::Pending,
116 },
117 State::Fetched => {
118 let cache = self.cache_ref();
119 if !cache.expired(Instant::now()) {
120 return Poll::Ready(());
121 }
122 trace!("token will expire: expiry={:?}", cache.expiry);
123 self.state =
124 State::Fetching { retry: 0, fut: RefGuard::new(self.source.token()) };
125 continue;
126 }
127 }
128 }
129 }
130}
131
132#[derive(Clone)]
133struct Cache {
134 value: HeaderValue,
135 expiry: Instant,
136}
137
138impl Cache {
139 fn new(value: HeaderValue, expiry: Instant) -> Self {
140 Self { value, expiry }
141 }
142
143 fn expired(&self, at: Instant) -> bool {
144 const EXPIRY_DELTA: Duration = Duration::from_secs(10);
145 self.expiry.checked_duration_since(at).map(|dur| dur < EXPIRY_DELTA).unwrap_or(true)
146 }
147
148 fn value(&self) -> HeaderValue {
149 self.value.clone()
150 }
151}
152
153impl<S, B> tower_service::Service<Request<B>> for AddAuthorization<S>
154where
155 S: tower_service::Service<Request<B>>,
156{
157 type Response = S::Response;
158 type Error = S::Error;
159 type Future = S::Future;
160
161 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
162 if self.inner.read().can_skip_poll_ready() {
163 return self.service.poll_ready(cx);
164 }
165 match self.inner.write().poll_ready(cx) {
166 Poll::Ready(()) => self.service.poll_ready(cx),
167 Poll::Pending => Poll::Pending,
168 }
169 }
170
171 fn call(&mut self, mut req: Request<B>) -> Self::Future {
172 req.headers_mut().insert(AUTHORIZATION, self.inner.read().cache_ref().value());
173 self.service.call(req)
174 }
175}
176
177impl<S: Clone> Clone for AddAuthorization<S> {
178 fn clone(&self) -> Self {
179 Self { inner: self.inner.clone(), service: self.service.clone() }
180 }
181}
182
183impl<S: fmt::Debug> fmt::Debug for AddAuthorization<S> {
184 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
185 f.debug_struct("AddAuthorization").field("service", &self.service).finish()
186 }
187}
188
#[cfg(test)]
mod test {
    use super::*;

    /// Checks that the middleware is `Send` and `Sync` when wrapping a plain service.
    #[test]
    fn compile_test() {
        #[derive(Clone)]
        struct Counter {
            cur: i32,
        }

        impl Counter {
            fn new() -> Self {
                Counter { cur: 0 }
            }
        }

        impl<B> tower_service::Service<Request<B>> for Counter {
            type Response = i32;
            type Error = i32;
            type Future = Pin<Box<dyn Future<Output = Result<i32, i32>> + Send + 'static>>;

            fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _: Request<B>) -> Self::Future {
                self.cur += 1;
                let current = self.cur;
                Box::pin(async move { Ok(current) })
            }
        }

        fn assert_send<T: Send>(_: T) {}
        fn assert_sync<T: Sync>(_: T) {}

        let svc = AddAuthorization::init_with(
            Credentials::from_json(
                br#"{
    "client_id": "xxx.apps.googleusercontent.com",
    "client_secret": "secret-xxx",
    "refresh_token": "refresh-xxx",
    "type": "authorized_user"
}"#,
                &[],
            ),
            Counter::new(),
        );
        assert_send(svc.clone());
        assert_sync(svc);
    }

    /// Checks the 10-second refresh margin used by `Cache::expired`.
    #[test]
    fn cache_expiry() {
        // Expiries are built by adding to `now` rather than subtracting from it:
        // `Instant - Duration` may panic when the result cannot be represented
        // (e.g. on a freshly booted machine).
        let now = Instant::now();
        let value = HeaderValue::from_static("value");

        // Inside the margin: refresh needed.
        assert!(Cache::new(value.clone(), now + Duration::from_secs(5)).expired(now));
        // Exactly at the margin: still usable.
        assert!(!Cache::new(value.clone(), now + Duration::from_secs(10)).expired(now));
        // Well ahead of the margin: still usable.
        assert!(!Cache::new(value.clone(), now + Duration::from_secs(30)).expired(now));
        // Expiry already passed at the time of the check.
        assert!(Cache::new(value, now).expired(now + Duration::from_secs(1)));
    }
}