google_authz/
token.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    time::{Duration, Instant},
5};
6
7use hyper::{
8    body::{aggregate, Body},
9    client::HttpConnector,
10    header::{HeaderValue, CONTENT_TYPE, USER_AGENT},
11    Method, Request, StatusCode, Uri,
12};
13use hyper_rustls::HttpsConnector;
14
15use crate::{credentials, token, Credentials};
16
17// === error ===
18
19/// Represents errors that can occur during getting token.
20#[derive(thiserror::Error, Debug)]
21pub enum Error {
22    // internal
23    #[error("http client error: {0}")]
24    Http(#[from] hyper::Error),
25    #[error("gcemeta client error: {0}")]
26    Gcemeta(#[from] gcemeta::Error),
27    // server
28    #[error("response status code error: {0}")]
29    StatusCode(StatusCode),
30    #[error("json deserialization error: {0:?}")]
31    InvalidJson(serde_json::Error),
32    #[error("invalid token: {0:?}")]
33    InvalidToken(Token),
34    #[error("invalid header value: {0:?}")]
35    InvalidHeaderValue(hyper::header::InvalidHeaderValue),
36}
37
38/// Wrapper for the `Result` type with an [`Error`](Error).
39pub type Result<T> = std::result::Result<T, Error>;
40
41// === http ===
42
43struct Client {
44    inner: hyper::Client<HttpsConnector<HttpConnector>, Body>,
45    user_agent: HeaderValue,
46    content_type: HeaderValue,
47}
48
49impl Client {
50    fn new() -> Client {
51        #[allow(unused_variables)]
52        #[cfg(feature = "native-certs")]
53        let https = HttpsConnector::with_native_roots();
54        #[cfg(feature = "webpki-roots")]
55        let https = HttpsConnector::with_webpki_roots();
56
57        Client {
58            inner: hyper::Client::builder().build(https),
59            user_agent: HeaderValue::from_static(concat!(
60                "github.com/mechiru/",
61                env!("CARGO_PKG_NAME"),
62                " v",
63                env!("CARGO_PKG_VERSION")
64            )),
65            content_type: HeaderValue::from_static("application/x-www-form-urlencoded"),
66        }
67    }
68
69    fn request<T>(&self, uri: &Uri, body: &T) -> Request<Body>
70    where
71        T: serde::Serialize,
72    {
73        let mut req = Request::builder().uri(uri).method(Method::POST);
74        let headers = req.headers_mut().unwrap();
75        headers.insert(USER_AGENT, self.user_agent.clone());
76        headers.insert(CONTENT_TYPE, self.content_type.clone());
77        let body = Body::from(serde_urlencoded::to_string(body).unwrap());
78        req.body(body).unwrap()
79    }
80
81    fn send<T>(&self, req: Request<Body>) -> impl Future<Output = Result<T>> + Send + 'static
82    where
83        T: serde::de::DeserializeOwned,
84    {
85        let fut = self.inner.request(req);
86        async {
87            use bytes::Buf as _;
88
89            let (parts, body) = fut.await?.into_parts();
90            match parts.status {
91                StatusCode::OK => {
92                    let buf = aggregate(body).await?;
93                    serde_json::from_reader(buf.reader()).map_err(Error::InvalidJson)
94                }
95                code => Err(Error::StatusCode(code)),
96            }
97        }
98    }
99}
100
101// === token ===
102
103#[derive(Debug, serde::Deserialize)]
104pub struct Token {
105    pub token_type: String,
106    pub access_token: String,
107    pub expires_in: u64,
108}
109
110impl Token {
111    pub fn into_pairs(self) -> Result<(HeaderValue, Instant)> {
112        if self.token_type.is_empty() || self.access_token.is_empty() || self.expires_in == 0 {
113            Err(Error::InvalidToken(self))
114        } else {
115            match HeaderValue::from_str(&format!("{} {}", self.token_type, self.access_token)) {
116                Ok(value) => Ok((value, Instant::now() + Duration::from_secs(self.expires_in))),
117                Err(err) => Err(Error::InvalidHeaderValue(err)),
118            }
119        }
120    }
121}
122
123// === token source ===
124
125pub enum TokenSource {
126    User(user::User),
127    ServiceAccount(service_account::ServiceAccount),
128    Metadata(metadata::Metadata),
129}
130
131impl TokenSource {
132    pub fn token(&self) -> Pin<Box<dyn Future<Output = token::Result<Token>> + Send + 'static>> {
133        match self {
134            TokenSource::User(user) => Box::pin(user.token()),
135            TokenSource::ServiceAccount(sa) => Box::pin(sa.token()),
136            TokenSource::Metadata(meta) => Box::pin(meta.token()),
137        }
138    }
139}
140
141impl From<Credentials> for TokenSource {
142    fn from(c: Credentials) -> Self {
143        use crate::{
144            credentials::Kind,
145            token::{service_account as sa, TokenSource::*},
146        };
147        match c.into_parts() {
148            (s, Kind::User(user)) => User(user::User::new(user, s)),
149            (s, Kind::ServiceAccount(sa)) => ServiceAccount(sa::ServiceAccount::new(sa, s)),
150            (s, Kind::Metadata(meta)) => Metadata(metadata::Metadata::new(meta, s)),
151        }
152    }
153}
154
155pub(super) mod user {
156    use super::*;
157
158    #[derive(serde::Serialize)]
159    struct Payload<'a> {
160        client_id: &'a str,
161        client_secret: &'a str,
162        grant_type: &'a str,
163        refresh_token: &'a str,
164    }
165
166    pub struct User {
167        inner: Client,
168        token_uri: Uri,
169        creds: credentials::User,
170    }
171
172    impl User {
173        pub(crate) fn new(user: credentials::User, _scopes: &'static [&'static str]) -> Self {
174            Self {
175                inner: Client::new(),
176                // https://github.com/golang/oauth2/blob/0f29369cfe4552d0e4bcddc57cc75f4d7e672a33/google/google.go#L24
177                token_uri: Uri::from_static("https://oauth2.googleapis.com/token"),
178                creds: user,
179            }
180        }
181
182        pub(crate) fn token(&self) -> impl Future<Output = Result<Token>> + Send + 'static {
183            let req = self.inner.request(&self.token_uri, &Payload {
184                client_id: &self.creds.client_id,
185                client_secret: &self.creds.client_secret,
186                grant_type: "refresh_token",
187                // The reflesh token is not included in the response from google's server,
188                // so it always uses the specified refresh token from the file.
189                refresh_token: &self.creds.refresh_token,
190            });
191            self.inner.send(req)
192        }
193    }
194}
195
196pub(super) mod service_account {
197    use std::time::SystemTime;
198
199    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
200
201    use super::*;
202
203    // If client machine's time is in the future according
204    // to Google servers, an access token will not be issued.
205    fn issued_at() -> u64 {
206        SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs() - 10
207    }
208
209    // https://cloud.google.com/iot/docs/concepts/device-security#security_standards
210    fn header(typ: impl Into<String>, key_id: impl Into<String>) -> Header {
211        Header {
212            typ: Some(typ.into()),
213            alg: Algorithm::RS256,
214            kid: Some(key_id.into()),
215            ..Default::default()
216        }
217    }
218
219    #[derive(serde::Serialize)]
220    struct Claims<'a> {
221        iss: &'a str,
222        scope: &'a str,
223        aud: &'a str,
224        iat: u64,
225        exp: u64,
226    }
227
228    #[derive(serde::Serialize)]
229    struct Payload<'a> {
230        grant_type: &'a str,
231        assertion: &'a str,
232    }
233
234    pub struct ServiceAccount {
235        inner: Client,
236        header: Header,
237        private_key: EncodingKey,
238        token_uri: Uri,
239        token_uri_str: String,
240        scopes: String,
241        client_email: String,
242    }
243
244    impl ServiceAccount {
245        pub(crate) fn new(
246            sa: credentials::ServiceAccount,
247            scopes: &'static [&'static str],
248        ) -> Self {
249            Self {
250                inner: Client::new(),
251                header: header("JWT", sa.private_key_id),
252                private_key: EncodingKey::from_rsa_pem(sa.private_key.as_bytes()).unwrap(),
253                token_uri: Uri::from_maybe_shared(sa.token_uri.clone()).unwrap(),
254                token_uri_str: sa.token_uri,
255                scopes: scopes.join(" "),
256                client_email: sa.client_email,
257            }
258        }
259
260        pub(crate) fn token(&self) -> impl Future<Output = Result<Token>> + Send + 'static {
261            const EXPIRE: u64 = 60 * 60;
262
263            let iat = issued_at();
264            let claims = Claims {
265                iss: &self.client_email,
266                scope: &self.scopes,
267                aud: &self.token_uri_str,
268                iat,
269                exp: iat + EXPIRE,
270            };
271
272            let req = self.inner.request(&self.token_uri, &Payload {
273                grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
274                assertion: &encode(&self.header, &claims, &self.private_key).unwrap(),
275            });
276            self.inner.send(req)
277        }
278    }
279}
280
281pub(super) mod metadata {
282    use std::str::FromStr;
283
284    use hyper::{client::HttpConnector, http::uri::PathAndQuery, Body};
285
286    use super::*;
287
288    #[derive(serde::Serialize)]
289    struct Query<'a> {
290        scopes: &'a str,
291    }
292
293    pub struct Metadata {
294        inner: gcemeta::Client<HttpConnector, Body>,
295        path_and_query: PathAndQuery,
296    }
297
298    impl Metadata {
299        pub(crate) fn new(meta: credentials::Metadata, scopes: &'static [&'static str]) -> Self {
300            let query = match scopes.len() {
301                0 => String::new(),
302                _ => serde_urlencoded::to_string(&Query { scopes: &scopes.join(",") }).unwrap(),
303            };
304            let path_and_query = format!(
305                "/computeMetadata/v1/instance/service-accounts/{}/token?{}",
306                meta.account.unwrap_or("default"),
307                query
308            );
309            let path_and_query = PathAndQuery::from_str(&path_and_query).unwrap();
310            Self { inner: meta.client, path_and_query }
311        }
312
313        pub fn token(&self) -> impl Future<Output = Result<Token>> + Send + 'static {
314            // Already checked that this process is running on GCE.
315            let fut = self.inner.get_as(self.path_and_query.clone());
316            async { Ok(fut.await?) }
317        }
318    }
319}