1use std::{
2 future::Future,
3 pin::Pin,
4 time::{Duration, Instant},
5};
6
7use hyper::{
8 body::{aggregate, Body},
9 client::HttpConnector,
10 header::{HeaderValue, CONTENT_TYPE, USER_AGENT},
11 Method, Request, StatusCode, Uri,
12};
13use hyper_rustls::HttpsConnector;
14
15use crate::{credentials, token, Credentials};
16
17#[derive(thiserror::Error, Debug)]
21pub enum Error {
22 #[error("http client error: {0}")]
24 Http(#[from] hyper::Error),
25 #[error("gcemeta client error: {0}")]
26 Gcemeta(#[from] gcemeta::Error),
27 #[error("response status code error: {0}")]
29 StatusCode(StatusCode),
30 #[error("json deserialization error: {0:?}")]
31 InvalidJson(serde_json::Error),
32 #[error("invalid token: {0:?}")]
33 InvalidToken(Token),
34 #[error("invalid header value: {0:?}")]
35 InvalidHeaderValue(hyper::header::InvalidHeaderValue),
36}
37
38pub type Result<T> = std::result::Result<T, Error>;
40
41struct Client {
44 inner: hyper::Client<HttpsConnector<HttpConnector>, Body>,
45 user_agent: HeaderValue,
46 content_type: HeaderValue,
47}
48
49impl Client {
50 fn new() -> Client {
51 #[allow(unused_variables)]
52 #[cfg(feature = "native-certs")]
53 let https = HttpsConnector::with_native_roots();
54 #[cfg(feature = "webpki-roots")]
55 let https = HttpsConnector::with_webpki_roots();
56
57 Client {
58 inner: hyper::Client::builder().build(https),
59 user_agent: HeaderValue::from_static(concat!(
60 "github.com/mechiru/",
61 env!("CARGO_PKG_NAME"),
62 " v",
63 env!("CARGO_PKG_VERSION")
64 )),
65 content_type: HeaderValue::from_static("application/x-www-form-urlencoded"),
66 }
67 }
68
69 fn request<T>(&self, uri: &Uri, body: &T) -> Request<Body>
70 where
71 T: serde::Serialize,
72 {
73 let mut req = Request::builder().uri(uri).method(Method::POST);
74 let headers = req.headers_mut().unwrap();
75 headers.insert(USER_AGENT, self.user_agent.clone());
76 headers.insert(CONTENT_TYPE, self.content_type.clone());
77 let body = Body::from(serde_urlencoded::to_string(body).unwrap());
78 req.body(body).unwrap()
79 }
80
81 fn send<T>(&self, req: Request<Body>) -> impl Future<Output = Result<T>> + Send + 'static
82 where
83 T: serde::de::DeserializeOwned,
84 {
85 let fut = self.inner.request(req);
86 async {
87 use bytes::Buf as _;
88
89 let (parts, body) = fut.await?.into_parts();
90 match parts.status {
91 StatusCode::OK => {
92 let buf = aggregate(body).await?;
93 serde_json::from_reader(buf.reader()).map_err(Error::InvalidJson)
94 }
95 code => Err(Error::StatusCode(code)),
96 }
97 }
98 }
99}
100
101#[derive(Debug, serde::Deserialize)]
104pub struct Token {
105 pub token_type: String,
106 pub access_token: String,
107 pub expires_in: u64,
108}
109
110impl Token {
111 pub fn into_pairs(self) -> Result<(HeaderValue, Instant)> {
112 if self.token_type.is_empty() || self.access_token.is_empty() || self.expires_in == 0 {
113 Err(Error::InvalidToken(self))
114 } else {
115 match HeaderValue::from_str(&format!("{} {}", self.token_type, self.access_token)) {
116 Ok(value) => Ok((value, Instant::now() + Duration::from_secs(self.expires_in))),
117 Err(err) => Err(Error::InvalidHeaderValue(err)),
118 }
119 }
120 }
121}
122
123pub enum TokenSource {
126 User(user::User),
127 ServiceAccount(service_account::ServiceAccount),
128 Metadata(metadata::Metadata),
129}
130
131impl TokenSource {
132 pub fn token(&self) -> Pin<Box<dyn Future<Output = token::Result<Token>> + Send + 'static>> {
133 match self {
134 TokenSource::User(user) => Box::pin(user.token()),
135 TokenSource::ServiceAccount(sa) => Box::pin(sa.token()),
136 TokenSource::Metadata(meta) => Box::pin(meta.token()),
137 }
138 }
139}
140
141impl From<Credentials> for TokenSource {
142 fn from(c: Credentials) -> Self {
143 use crate::{
144 credentials::Kind,
145 token::{service_account as sa, TokenSource::*},
146 };
147 match c.into_parts() {
148 (s, Kind::User(user)) => User(user::User::new(user, s)),
149 (s, Kind::ServiceAccount(sa)) => ServiceAccount(sa::ServiceAccount::new(sa, s)),
150 (s, Kind::Metadata(meta)) => Metadata(metadata::Metadata::new(meta, s)),
151 }
152 }
153}
154
155pub(super) mod user {
156 use super::*;
157
158 #[derive(serde::Serialize)]
159 struct Payload<'a> {
160 client_id: &'a str,
161 client_secret: &'a str,
162 grant_type: &'a str,
163 refresh_token: &'a str,
164 }
165
166 pub struct User {
167 inner: Client,
168 token_uri: Uri,
169 creds: credentials::User,
170 }
171
172 impl User {
173 pub(crate) fn new(user: credentials::User, _scopes: &'static [&'static str]) -> Self {
174 Self {
175 inner: Client::new(),
176 token_uri: Uri::from_static("https://oauth2.googleapis.com/token"),
178 creds: user,
179 }
180 }
181
182 pub(crate) fn token(&self) -> impl Future<Output = Result<Token>> + Send + 'static {
183 let req = self.inner.request(&self.token_uri, &Payload {
184 client_id: &self.creds.client_id,
185 client_secret: &self.creds.client_secret,
186 grant_type: "refresh_token",
187 refresh_token: &self.creds.refresh_token,
190 });
191 self.inner.send(req)
192 }
193 }
194}
195
196pub(super) mod service_account {
197 use std::time::SystemTime;
198
199 use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
200
201 use super::*;
202
203 fn issued_at() -> u64 {
206 SystemTime::UNIX_EPOCH.elapsed().unwrap().as_secs() - 10
207 }
208
209 fn header(typ: impl Into<String>, key_id: impl Into<String>) -> Header {
211 Header {
212 typ: Some(typ.into()),
213 alg: Algorithm::RS256,
214 kid: Some(key_id.into()),
215 ..Default::default()
216 }
217 }
218
219 #[derive(serde::Serialize)]
220 struct Claims<'a> {
221 iss: &'a str,
222 scope: &'a str,
223 aud: &'a str,
224 iat: u64,
225 exp: u64,
226 }
227
228 #[derive(serde::Serialize)]
229 struct Payload<'a> {
230 grant_type: &'a str,
231 assertion: &'a str,
232 }
233
234 pub struct ServiceAccount {
235 inner: Client,
236 header: Header,
237 private_key: EncodingKey,
238 token_uri: Uri,
239 token_uri_str: String,
240 scopes: String,
241 client_email: String,
242 }
243
244 impl ServiceAccount {
245 pub(crate) fn new(
246 sa: credentials::ServiceAccount,
247 scopes: &'static [&'static str],
248 ) -> Self {
249 Self {
250 inner: Client::new(),
251 header: header("JWT", sa.private_key_id),
252 private_key: EncodingKey::from_rsa_pem(sa.private_key.as_bytes()).unwrap(),
253 token_uri: Uri::from_maybe_shared(sa.token_uri.clone()).unwrap(),
254 token_uri_str: sa.token_uri,
255 scopes: scopes.join(" "),
256 client_email: sa.client_email,
257 }
258 }
259
260 pub(crate) fn token(&self) -> impl Future<Output = Result<Token>> + Send + 'static {
261 const EXPIRE: u64 = 60 * 60;
262
263 let iat = issued_at();
264 let claims = Claims {
265 iss: &self.client_email,
266 scope: &self.scopes,
267 aud: &self.token_uri_str,
268 iat,
269 exp: iat + EXPIRE,
270 };
271
272 let req = self.inner.request(&self.token_uri, &Payload {
273 grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
274 assertion: &encode(&self.header, &claims, &self.private_key).unwrap(),
275 });
276 self.inner.send(req)
277 }
278 }
279}
280
281pub(super) mod metadata {
282 use std::str::FromStr;
283
284 use hyper::{client::HttpConnector, http::uri::PathAndQuery, Body};
285
286 use super::*;
287
288 #[derive(serde::Serialize)]
289 struct Query<'a> {
290 scopes: &'a str,
291 }
292
293 pub struct Metadata {
294 inner: gcemeta::Client<HttpConnector, Body>,
295 path_and_query: PathAndQuery,
296 }
297
298 impl Metadata {
299 pub(crate) fn new(meta: credentials::Metadata, scopes: &'static [&'static str]) -> Self {
300 let query = match scopes.len() {
301 0 => String::new(),
302 _ => serde_urlencoded::to_string(&Query { scopes: &scopes.join(",") }).unwrap(),
303 };
304 let path_and_query = format!(
305 "/computeMetadata/v1/instance/service-accounts/{}/token?{}",
306 meta.account.unwrap_or("default"),
307 query
308 );
309 let path_and_query = PathAndQuery::from_str(&path_and_query).unwrap();
310 Self { inner: meta.client, path_and_query }
311 }
312
313 pub fn token(&self) -> impl Future<Output = Result<Token>> + Send + 'static {
314 let fut = self.inner.get_as(self.path_and_query.clone());
316 async { Ok(fut.await?) }
317 }
318 }
319}