use hyper::{
body::{aggregate, Body},
client::HttpConnector,
header::{HeaderValue, CONTENT_TYPE, USER_AGENT},
Method, Request, StatusCode, Uri,
};
use hyper_rustls::HttpsConnector;
use std::time::{Duration, Instant};
use crate::{credentials, Credentials};
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("http client error: {0}")]
Http(#[from] hyper::Error),
#[error("gcemeta client error: {0}")]
Gcemeta(#[from] gcemeta::Error),
#[error("response status code error: {0}")]
StatusCode(StatusCode),
#[error("json deserialization error: {0:?}")]
InvalidJson(serde_json::Error),
#[error("invalid token: {0:?}")]
InvalidToken(Token),
#[error("invalid header value: {0:?}")]
InvalidHeaderValue(hyper::header::InvalidHeaderValue),
}
pub type Result<T> = std::result::Result<T, Error>;
struct Client {
inner: hyper::Client<HttpsConnector<HttpConnector>, Body>,
user_agent: HeaderValue,
content_type: HeaderValue,
}
impl Client {
fn new() -> Client {
#[allow(unused_variables)]
#[cfg(feature = "native-certs")]
let https = HttpsConnector::with_native_roots();
#[cfg(feature = "webpki-roots")]
let https = HttpsConnector::with_webpki_roots();
Client {
inner: hyper::Client::builder().build(https),
user_agent: HeaderValue::from_static(concat!(
"github.com/mechiru/",
env!("CARGO_PKG_NAME"),
" v",
env!("CARGO_PKG_VERSION")
)),
content_type: HeaderValue::from_static("application/x-www-form-urlencoded"),
}
}
async fn request<T, U>(&self, uri: &Uri, body: &T) -> Result<U>
where
T: serde::Serialize,
U: serde::de::DeserializeOwned,
{
use bytes::Buf as _;
let mut req = Request::builder().uri(uri).method(Method::POST);
let headers = req.headers_mut().unwrap();
headers.insert(USER_AGENT, self.user_agent.clone());
headers.insert(CONTENT_TYPE, self.content_type.clone());
let body = Body::from(serde_urlencoded::to_string(body).unwrap());
let req = req.body(body).unwrap();
let (parts, body) = self.inner.request(req).await?.into_parts();
match parts.status {
StatusCode::OK => {
let buf = aggregate(body).await?;
serde_json::from_reader(buf.reader()).map_err(Error::InvalidJson)
}
code => Err(Error::StatusCode(code)),
}
}
}
#[derive(Debug, serde::Deserialize)]
pub struct Token {
pub token_type: String,
pub access_token: String,
pub expires_in: u64,
}
impl Token {
pub fn into_pairs(self) -> Result<(HeaderValue, Instant)> {
if self.token_type.is_empty() || self.access_token.is_empty() || self.expires_in == 0 {
Err(Error::InvalidToken(self))
} else {
match HeaderValue::from_str(&format!("{} {}", self.token_type, self.access_token)) {
Ok(value) => Ok((value, Instant::now() + Duration::from_secs(self.expires_in))),
Err(err) => Err(Error::InvalidHeaderValue(err)),
}
}
}
}
/// Source of access tokens, one variant per supported credential kind.
pub enum TokenSource {
    /// Exchanges an end-user refresh token at the OAuth2 token endpoint.
    User(user::User),
    /// Signs a JWT assertion with a service-account key and exchanges it for a token.
    ServiceAccount(service_account::ServiceAccount),
    /// Requests a token from the GCE metadata server.
    Metadata(metadata::Metadata),
}

impl TokenSource {
    /// Requests a new token from whichever backend this source wraps.
    ///
    /// # Errors
    /// Propagates the wrapped backend's error unchanged.
    pub async fn token(&self) -> Result<Token> {
        match self {
            Self::User(source) => source.token().await,
            Self::ServiceAccount(source) => source.token().await,
            Self::Metadata(source) => source.token().await,
        }
    }
}
impl From<Credentials> for TokenSource {
fn from(c: Credentials) -> Self {
use crate::{
credentials::Kind,
token::{service_account as sa, TokenSource::*},
};
match c.into_parts() {
(s, Kind::User(user)) => User(user::User::new(user, s)),
(s, Kind::ServiceAccount(sa)) => ServiceAccount(sa::ServiceAccount::new(sa, s)),
(s, Kind::Metadata(meta)) => Metadata(metadata::Metadata::new(meta, s)),
}
}
}
pub(super) mod user {
    use super::*;

    /// OAuth2 endpoint that accepts the refresh-token grant.
    const TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";

    /// Form fields of a `refresh_token` grant request (field order is the wire order).
    #[derive(serde::Serialize)]
    struct Payload<'a> {
        client_id: &'a str,
        client_secret: &'a str,
        grant_type: &'a str,
        refresh_token: &'a str,
    }

    /// Token source backed by end-user OAuth2 client credentials and a refresh token.
    pub struct User {
        inner: Client,
        token_uri: Uri,
        creds: credentials::User,
    }

    impl User {
        /// Wraps `user` credentials.
        ///
        /// `_scopes` is ignored; the refresh request below sends no scope field.
        pub(crate) fn new(user: credentials::User, _scopes: &'static [&'static str]) -> Self {
            let token_uri = Uri::from_static(TOKEN_ENDPOINT);
            Self {
                inner: Client::new(),
                token_uri,
                creds: user,
            }
        }

        /// Trades the stored refresh token for a new access token.
        ///
        /// # Errors
        /// Any error from the underlying form POST / JSON decode.
        pub(crate) async fn token(&self) -> Result<Token> {
            let payload = Payload {
                client_id: &self.creds.client_id,
                client_secret: &self.creds.client_secret,
                grant_type: "refresh_token",
                refresh_token: &self.creds.refresh_token,
            };
            self.inner.request(&self.token_uri, &payload).await
        }
    }
}
pub(super) mod service_account {
    use super::*;
    use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
    use std::time::SystemTime;

    /// Requested lifetime of each signed assertion, in seconds (one hour).
    const ASSERTION_LIFETIME_SECS: u64 = 60 * 60;

    /// Current Unix time in whole seconds, moved back by 10 seconds.
    ///
    /// NOTE(review): the backdating is presumably slack for clock skew against the
    /// token server; confirm.
    ///
    /// # Panics
    /// If the system clock reads earlier than the Unix epoch.
    fn issued_at() -> u64 {
        let since_epoch = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap();
        since_epoch.as_secs() - 10
    }

    /// Builds an RS256 JWT header with the given `typ` and key id.
    /// All other fields keep their defaults.
    fn header(typ: impl Into<String>, key_id: impl Into<String>) -> Header {
        let mut jwt_header = Header {
            alg: Algorithm::RS256,
            ..Default::default()
        };
        jwt_header.typ = Some(typ.into());
        jwt_header.kid = Some(key_id.into());
        jwt_header
    }

    /// JWT claim set (field order is the serialized key order).
    #[derive(serde::Serialize)]
    struct Claims<'a> {
        iss: &'a str,
        scope: &'a str,
        aud: &'a str,
        iat: u64,
        exp: u64,
    }

    /// Form fields of a JWT-bearer grant request.
    #[derive(serde::Serialize)]
    struct Payload<'a> {
        grant_type: &'a str,
        assertion: &'a str,
    }

    /// Token source that signs JWT assertions with a service-account private key.
    pub struct ServiceAccount {
        inner: Client,
        header: Header,
        private_key: EncodingKey,
        token_uri: Uri,
        token_uri_str: String,
        scopes: String,
        client_email: String,
    }

    impl ServiceAccount {
        /// Prepares a signer from service-account credentials.
        /// The scopes are joined with single spaces.
        ///
        /// # Panics
        /// If `sa.private_key` is not an RSA PEM key, or `sa.token_uri` is not a valid URI.
        pub(crate) fn new(
            sa: credentials::ServiceAccount,
            scopes: &'static [&'static str],
        ) -> Self {
            // Bindings follow the original field-initialization order.
            let inner = Client::new();
            let jwt_header = header("JWT", sa.private_key_id);
            let private_key = EncodingKey::from_rsa_pem(sa.private_key.as_bytes()).unwrap();
            let token_uri = Uri::from_maybe_shared(sa.token_uri.clone()).unwrap();
            Self {
                inner,
                header: jwt_header,
                private_key,
                token_uri,
                token_uri_str: sa.token_uri,
                scopes: scopes.join(" "),
                client_email: sa.client_email,
            }
        }

        /// Signs a fresh assertion and exchanges it at the token URI.
        ///
        /// # Errors
        /// Any error from the underlying form POST / JSON decode.
        ///
        /// # Panics
        /// If JWT encoding fails, or the clock is before the Unix epoch.
        pub(crate) async fn token(&self) -> Result<Token> {
            let iat = issued_at();
            let claims = Claims {
                iss: &self.client_email,
                scope: &self.scopes,
                aud: &self.token_uri_str,
                iat,
                exp: iat + ASSERTION_LIFETIME_SECS,
            };
            let assertion = encode(&self.header, &claims, &self.private_key).unwrap();
            let payload = Payload {
                grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
                assertion: &assertion,
            };
            self.inner.request(&self.token_uri, &payload).await
        }
    }
}
pub(super) mod metadata {
    use super::*;
    use hyper::{client::HttpConnector, http::uri::PathAndQuery, Body};
    use std::str::FromStr;

    /// Query string carried on the metadata token request.
    #[derive(serde::Serialize)]
    struct Query<'a> {
        scopes: &'a str,
    }

    /// Token source that reads tokens from the GCE metadata server via `gcemeta`.
    pub struct Metadata {
        inner: gcemeta::Client<HttpConnector, Body>,
        path_and_query: PathAndQuery,
    }

    impl Metadata {
        /// Precomputes the token path for the configured account (or `"default"`).
        ///
        /// The scopes are comma-joined into a `scopes=` query parameter. When no
        /// scopes are given, the query after `?` is empty.
        ///
        /// # Panics
        /// If the resulting string is not a valid `PathAndQuery`.
        pub(crate) fn new(meta: credentials::Metadata, scopes: &'static [&'static str]) -> Self {
            let query = if scopes.is_empty() {
                String::new()
            } else {
                let joined = scopes.join(",");
                serde_urlencoded::to_string(&Query { scopes: &joined }).unwrap()
            };
            let account = meta.account.unwrap_or("default");
            let raw = format!(
                "/computeMetadata/v1/instance/service-accounts/{}/token?{}",
                account, query
            );
            let path_and_query = PathAndQuery::from_str(&raw).unwrap();
            Self {
                inner: meta.client,
                path_and_query,
            }
        }

        /// Fetches a token from the metadata server.
        ///
        /// # Errors
        /// [`Error::Gcemeta`] if the GCE check or the token request fails.
        ///
        /// # Panics
        /// If the metadata client reports that the process is not running on GCE.
        pub async fn token(&self) -> Result<Token> {
            let on_gce = self.inner.on_gce().await?;
            if !on_gce {
                panic!("this process is not running on GCE")
            }
            let token: Token = self.inner.get_as(self.path_and_query.clone()).await?;
            Ok(token)
        }
    }
}