use std::{str::FromStr, sync::Arc};
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::EncodingKey;
use serde::{
de::{self, Deserializer},
Deserialize, Serialize,
};
use serde_with::serde_as;
use super::{Token, TokenProvider};
/// OAuth scope URL for Google Cloud Storage `devstorage.full_control` access.
///
/// Pass this as the `scope` argument of [`OAuthTokenProvider::new`] /
/// [`OAuthTokenProvider::new_with_client`].
pub const SCOPE_STORAGE_FULL_CONTROL: &str =
    "https://www.googleapis.com/auth/devstorage.full_control";
/// Token provider that signs a JWT with a service account's private key and
/// exchanges it at the account's `token_uri` for an access token.
pub struct OAuthTokenProvider {
    scope: String,
    service_account: ServiceAccount,
    client: reqwest::Client,
}

impl OAuthTokenProvider {
    /// Creates a provider requesting `scope`, using a default-constructed
    /// `reqwest::Client` for the token requests.
    ///
    /// # Errors
    /// Never returns `Err` at present; the `Result` is part of the public API.
    pub fn new(
        service_account: ServiceAccount,
        scope: impl Into<String>,
    ) -> Result<Self, OAuthError> {
        let default_client = reqwest::Client::default();
        Self::new_with_client(service_account, scope, default_client)
    }

    /// Creates a provider requesting `scope` that sends its token requests
    /// through the caller-supplied `client`.
    ///
    /// # Errors
    /// Never returns `Err` at present; the `Result` is part of the public API.
    pub fn new_with_client(
        service_account: ServiceAccount,
        scope: impl Into<String>,
        client: reqwest::Client,
    ) -> Result<Self, OAuthError> {
        let scope: String = scope.into();
        let provider = Self {
            client,
            service_account,
            scope,
        };
        Ok(provider)
    }
}
#[async_trait::async_trait]
impl TokenProvider for OAuthTokenProvider {
    /// Exchanges a freshly signed JWT assertion for an access token.
    ///
    /// Signs an RS256 claim set whose issuer is the service account's client email,
    /// whose audience is its `token_uri`, and which is valid for one hour. The claim set
    /// is POSTed to `token_uri` with the `urn:ietf:params:oauth:grant-type:jwt-bearer`
    /// grant. The returned token's expiry is `expires_in` measured from the moment the
    /// claims were built, before the request was sent, so the stored expiry errs on the
    /// early side.
    ///
    /// # Errors
    /// - [`OAuthError::InvalidSigningKey`] if the JWT cannot be signed with the key.
    /// - A `reqwest` error (with context) if the request cannot be sent.
    /// - A `reqwest` decode error (with context naming the HTTP status) if the response
    ///   body is not a recognizable token or error response.
    /// - [`OAuthError::InvalidScope`] if an ID token was returned instead of an access token.
    /// - [`OAuthError::Other`] if the endpoint returned an error response.
    async fn get_token(&self) -> anyhow::Result<Arc<Token>> {
        let header = jsonwebtoken::Header {
            alg: jsonwebtoken::Algorithm::RS256,
            ..Default::default()
        };
        let now = Utc::now();
        let expiry = now + Duration::hours(1);
        let claims = Claims {
            iss: &self.service_account.client_email,
            scope: &self.scope,
            aud: &self.service_account.token_uri,
            iat: now,
            exp: expiry,
        };
        // Surface signing failures through the dedicated OAuthError variant so callers
        // can downcast to `OAuthError` for every failure this provider itself classifies.
        let client_assertion =
            jsonwebtoken::encode(&header, &claims, &self.service_account.private_key)
                .map_err(OAuthError::InvalidSigningKey)?;
        let res = self
            .client
            .post(&self.service_account.token_uri)
            .form(&[
                ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
                ("assertion", &client_assertion),
            ])
            .send()
            .await
            .context("failed to request access token from Google")?;
        // Capture the status before `json()` consumes the response, so both the error
        // body and a decode failure can report it.
        let res_status = res.status();
        let response = res.json::<OAuthResponse>().await.with_context(|| {
            format!(
                "failed to parse token response from Google (HTTP status {})",
                res_status
            )
        })?;
        let (token, expires_in) = match response {
            OAuthResponse::Token {
                token: TokenKind::IdToken(..),
                ..
            } => return Err(OAuthError::InvalidScope.into()),
            OAuthResponse::Token {
                token: TokenKind::AccessToken(token),
                expires_in,
            } => (token, expires_in),
            OAuthResponse::Error {
                error_description, ..
            } => {
                return Err(OAuthError::Other(crate::api::GoogleError {
                    status: res_status,
                    message: error_description,
                })
                .into())
            }
        };
        Ok(Arc::new(Token {
            token,
            expiry: now + expires_in,
        }))
    }
}
/// JWT claim set signed with the service account's key and sent as the `assertion`
/// of the JWT-bearer token request.
///
/// NOTE: field declaration order is the order the claims are serialized in.
#[serde_as]
#[derive(Serialize)]
struct Claims<'a> {
    /// Issuer; set to the service account's `client_email` by `get_token`.
    iss: &'a str,
    /// Audience; set to the service account's `token_uri` by `get_token`.
    aud: &'a str,
    /// Requested OAuth scope string.
    scope: &'a str,
    /// Expiry, serialized as a Unix timestamp in whole seconds.
    #[serde_as(as = "serde_with::TimestampSeconds")]
    exp: DateTime<Utc>,
    /// Issued-at time, serialized as a Unix timestamp in whole seconds.
    #[serde_as(as = "serde_with::TimestampSeconds")]
    iat: DateTime<Utc>,
}
/// JSON body returned by the token endpoint.
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order: first as
/// a successful token response, then as an error response.
#[serde_as]
#[derive(Deserialize)]
#[serde(untagged)]
enum OAuthResponse {
    /// Successful response.
    Token {
        /// Read from whichever `id_token` / `access_token` key is present
        /// (flattened externally-tagged enum; see `TokenKind`).
        #[serde(flatten)]
        token: TokenKind,
        /// Token lifetime, given as a whole number of seconds.
        #[serde_as(as = "serde_with::DurationSeconds<i64>")]
        expires_in: Duration,
    },
    /// Error response; only the human-readable description is read.
    ///
    /// NOTE(review): RFC 6749 §5.2 makes `error_description` optional (only `error` is
    /// required). An error body without it matches neither variant and surfaces as a
    /// decode error. Consider also reading `error`.
    Error {
        error_description: String,
    },
}
/// Which kind of token the endpoint returned.
///
/// With `rename_all = "snake_case"` the variants correspond to the JSON keys
/// `id_token` and `access_token`. `get_token` only accepts `AccessToken`.
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum TokenKind {
    /// Value of an `id_token` key.
    IdToken(String),
    /// Value of an `access_token` key.
    AccessToken(String),
}
/// Errors that can occur while building a provider or obtaining an access token.
#[derive(Debug, thiserror::Error)]
pub enum OAuthError {
    /// A `jsonwebtoken` error, reported as an invalid RSA private key.
    #[error("invalid RSA private key: {0}")]
    InvalidSigningKey(#[from] jsonwebtoken::errors::Error),
    /// Transport-level error from `reqwest`; displayed transparently.
    #[error(transparent)]
    Http(#[from] reqwest::Error),
    /// Error reported by a Google endpoint (HTTP status plus message); displayed
    /// transparently.
    #[error(transparent)]
    Other(#[from] crate::api::GoogleError),
    /// The token endpoint returned an ID token where an access token was expected.
    #[error("received an ID token instead of an access token. ensure that the scope is correct.")]
    InvalidScope,
}
impl From<crate::api::Error> for OAuthError {
fn from(api_error: crate::api::Error) -> Self {
match api_error {
crate::api::Error::Http(e) => Self::Http(e),
crate::api::Error::Google(e) => Self::Other(e),
}
}
}
/// Credentials parsed from a Google service-account JSON key.
///
/// Holds only what the JWT-bearer flow needs: the issuer email, the RSA signing key,
/// and the token endpoint URI.
pub struct ServiceAccount {
    client_email: String,
    private_key: EncodingKey,
    token_uri: String,
}

impl ServiceAccount {
    /// Reads the file at `path` and parses it as a service-account JSON key.
    ///
    /// # Errors
    /// - [`ServiceAccountError::Io`] if the file cannot be read.
    /// - [`ServiceAccountError::Parse`] / [`ServiceAccountError::InvalidKey`] if the
    ///   contents are not a valid service-account key.
    pub fn read_from_file(path: impl AsRef<std::path::Path>) -> Result<Self, ServiceAccountError> {
        let path = path.as_ref();
        std::fs::read_to_string(path)
            .map_err(|error| ServiceAccountError::Io {
                file: path.to_path_buf(),
                error,
            })?
            .parse()
    }

    /// Loads the service account whose key file path is given by the
    /// `GOOGLE_APPLICATION_CREDENTIALS` environment variable.
    ///
    /// # Errors
    /// - [`ServiceAccountError::Io`] with kind `NotFound` and an empty `file` if the
    ///   variable is unset or empty.
    /// - Otherwise, any error from [`ServiceAccount::read_from_file`].
    pub fn read_from_canonical_env() -> Result<Self, ServiceAccountError> {
        const CREDENTIALS_VAR: &str = "GOOGLE_APPLICATION_CREDENTIALS";
        match std::env::var_os(CREDENTIALS_VAR) {
            Some(path) if !path.is_empty() => Self::read_from_file(path),
            // Previously an unset variable fell through to reading "" and produced an
            // opaque "No such file or directory". Name the actual cause, keeping the same
            // variant and ErrorKind.
            _ => Err(ServiceAccountError::Io {
                file: std::path::PathBuf::new(),
                error: std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("environment variable `{}` is not set or is empty", CREDENTIALS_VAR),
                ),
            }),
        }
    }
}
impl FromStr for ServiceAccount {
type Err = ServiceAccountError;
fn from_str(sa_json: &str) -> Result<Self, Self::Err> {
let sa: DeserializableServiceAccount = serde_json::from_str(sa_json)?;
Ok(Self {
client_email: sa.client_email,
private_key: jsonwebtoken::EncodingKey::from_rsa_pem(sa.private_key.as_bytes())?,
token_uri: sa.token_uri,
})
}
}
impl<'de> Deserialize<'de> for ServiceAccount {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
String::deserialize(d)?.parse().map_err(de::Error::custom)
}
}
/// Raw shape of a service-account JSON key, as read by `ServiceAccount::from_str`.
///
/// Only these fields are read. Other keys in the file are ignored because serde's
/// default does not deny unknown fields.
#[derive(Deserialize)]
struct DeserializableServiceAccount {
    /// The JSON `type` field. Deserialization fails unless it equals `"service_account"`.
    #[serde(rename = "type")]
    _ty: ServiceAccountMarker,
    client_email: String,
    /// PEM-encoded RSA private key, converted to an `EncodingKey` by the caller.
    private_key: String,
    token_uri: String,
}
const SERVICE_ACCOUNT_MARKER: &str = "service_account";
struct ServiceAccountMarker;
impl<'de> Deserialize<'de> for ServiceAccountMarker {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let ty: String = String::deserialize(d)?;
if ty == SERVICE_ACCOUNT_MARKER {
Ok(Self)
} else {
Err(de::Error::custom(&format!(
"provided JSON had unexpected `type` `{}`. expected `{}`.",
ty, SERVICE_ACCOUNT_MARKER
)))
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ServiceAccountError {
#[error("failed to read service account file `{file}`: {error}")]
Io {
file: std::path::PathBuf,
#[source]
error: std::io::Error,
},
#[error("cound not parse service account json: {0}")]
Parse(#[from] serde_json::Error),
#[error("invalid `private_key`: {0}")]
InvalidKey(#[from] jsonwebtoken::errors::Error),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider scoped to full Cloud Storage control for `sa`.
    fn storage_provider(sa: ServiceAccount) -> OAuthTokenProvider {
        OAuthTokenProvider::new(sa, SCOPE_STORAGE_FULL_CONTROL).unwrap()
    }

    /// Fetches a real token. Requires `GOOGLE_APPLICATION_CREDENTIALS` to point at a
    /// valid key file, and network access to the key's `token_uri`.
    #[tokio::test]
    async fn provides_token() {
        let provider = storage_provider(ServiceAccount::read_from_canonical_env().unwrap());
        provider.get_token().await.unwrap();
    }

    /// Tampers with the issuer email and expects the failure to surface as a Google
    /// API error (`OAuthError::Other`). Has the same environment requirements as
    /// `provides_token`.
    #[tokio::test]
    async fn fails_sanely() {
        let mut sa = ServiceAccount::read_from_canonical_env().unwrap();
        sa.client_email.push('q');
        let failure = storage_provider(sa).get_token().await.unwrap_err();
        let oauth_error = failure.downcast::<OAuthError>().unwrap();
        assert!(matches!(
            oauth_error,
            OAuthError::Other(crate::api::GoogleError { .. })
        ))
    }
}