use std::fmt;
use serde::{Deserialize, Serialize};
use reqwest::{Url, blocking::Client};
use crate::{Error, ErrorKind, LocalServer, ClientSecrets};
/// An OAuth 2.0 access token together with the refresh token used to renew it.
///
/// The struct is (de)serialized directly with serde and has no field renames, so the
/// field names are exactly the JSON keys read from the token endpoint's reply
/// (see `AccessToken::request`).
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AccessToken {
    /// The token value itself; it is sent to the token info endpoint by `is_valid`.
    pub access_token: String,
    /// Lifetime reported by the token endpoint (presumably in seconds — TODO confirm).
    pub expires_in: i128,
    /// Credential sent with the `refresh_token` grant by `AccessToken::refresh`.
    pub refresh_token: String,
    /// Space-separated list of granted scopes, as checked by `has_scopes`.
    pub scope: String,
    /// Token type reported by the server (e.g. `Bearer`).
    pub token_type: String,
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for AccessToken {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AccessToken")
.field("access_token", &format_args!("[hidden for security]"))
.field("expires_in", &self.expires_in)
.field("refresh_token", &format_args!("[hidden for security]"))
.field("scope", &self.scope)
.field("token_type", &self.token_type)
.finish()
}
}
impl AccessToken {
const TOKEN_INFO_URI: &'static str = "https://oauth2.googleapis.com/tokeninfo";
fn update_with( &mut self, refresh_token: &RefreshToken ) {
self.access_token = refresh_token.access_token.clone();
self.expires_in = refresh_token.expires_in;
self.scope = refresh_token.scope.clone();
self.token_type = refresh_token.token_type.clone();
}
pub fn is_valid( &self ) -> bool {
let parameters = [ ("access_token", &self.access_token) ];
let request_url = match Url::parse_with_params(Self::TOKEN_INFO_URI, ¶meters) {
Ok(url) => url,
#[cfg(not(tarpaulin_include))]
Err(_) => return false,
};
let response = match reqwest::blocking::get(request_url) {
Ok(response) => response,
#[cfg(not(tarpaulin_include))]
Err(_) => return false,
};
response.status() == 200
}
pub fn has_scopes<T: AsRef<str>> ( &self, scopes: &[T] ) -> bool {
let token_scopes: Vec<&str> = self.scope.split(' ').collect();
scopes.iter().all( |s| token_scopes.contains(&s.as_ref()) )
}
#[cfg(not(tarpaulin_include))]
pub fn request<T: AsRef<str>> ( client_secrets: &ClientSecrets, scopes: &[T], ) -> crate::Result<Self> {
let (authorization_code, code_verifier) = client_secrets.get_authorization_code(scopes, true)?;
let redirect_uri = &LocalServer::default().uri;
let parameters = [
( "client_id", &client_secrets.client_id ),
( "client_secret", &client_secrets.client_secret ),
( "code", &authorization_code.to_string() ),
( "code_verifier", &code_verifier.to_string() ),
( "grant_type", &String::from("authorization_code") ),
( "redirect_uri", &redirect_uri ),
];
let request = Client::new()
.post(&client_secrets.token_uri)
.form(¶meters);
let response = request.send()?;
if response.status() != 200 {
return Err( response.into() );
}
let access_token: AccessToken = serde_json::from_str( &response.text()? )?;
if !access_token.has_scopes(scopes) {
return Err( Error::new(
ErrorKind::MismatchedScopes,
"created access token does not contain the requested scopes",
) )
}
Ok(access_token)
}
pub fn refresh( &mut self, client_secrets: &ClientSecrets ) -> crate::Result<()> {
let body: [(&str, &str); 4] = [
( "client_id", &client_secrets.client_id ),
( "client_secret", &client_secrets.client_secret ),
( "grant_type", "refresh_token" ),
( "refresh_token", &self.refresh_token ),
];
let request = Client::new().post(&client_secrets.token_uri).form(&body);
let response = request.send()?;
if response.status() != 200 {
return Err( response.into() );
}
let refresh_token: RefreshToken = serde_json::from_str( &response.text()? )?;
self.update_with(&refresh_token);
Ok(())
}
}
/// The token endpoint's reply to a `refresh_token` grant (see `AccessToken::refresh`).
///
/// Despite its name, this struct does not hold a refresh token. It holds the newly issued
/// access token and its metadata, which `AccessToken::update_with` copies into an
/// existing [`AccessToken`]. Field names are the JSON keys of the reply, because serde
/// is derived without renames.
#[derive(Clone, Serialize, Deserialize)]
pub struct RefreshToken {
    /// The newly issued access token value.
    pub access_token: String,
    /// Lifetime reported for the new token (presumably in seconds — TODO confirm).
    pub expires_in: i128,
    /// Space-separated list of scopes granted to the new token.
    pub scope: String,
    /// Token type reported by the server (e.g. `Bearer`).
    pub token_type: String,
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for RefreshToken {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RefreshToken")
.field("access_token", &format_args!("[hidden for security]"))
.field("expires_in", &self.expires_in)
.field("scope", &self.scope)
.field("token_type", &self.token_type)
.finish()
}
}
#[cfg(test)]
mod tests {
    use super::{AccessToken, RefreshToken};
    use crate::utils::test::{INVALID_CREDENTIALS, VALID_CREDENTIALS};
    use crate::ErrorKind;

    /// Scope requested by the (ignored) interactive `request` tests.
    const DRIVE_METADATA_READONLY: &str = "https://www.googleapis.com/auth/drive.metadata.readonly";

    /// Fixed access-token fixture; none of its values are real credentials.
    fn get_test_access_token() -> AccessToken {
        AccessToken {
            access_token: "test_access_token".to_owned(),
            expires_in: 1984,
            refresh_token: "test_refresh_token".to_owned(),
            scope: "test_scope_one test_scope_two".to_owned(),
            token_type: "Bearer".to_owned(),
        }
    }

    /// Fixed refresh-response fixture. Its `access_token` and `expires_in`
    /// differ from the access-token fixture, so `update_with` is observable.
    fn get_test_refresh_token() -> RefreshToken {
        RefreshToken {
            access_token: "updated_test_access_token".to_owned(),
            expires_in: 9999,
            scope: "test_scope_one test_scope_two".to_owned(),
            token_type: "Bearer".to_owned(),
        }
    }

    #[test]
    fn update_with_test() {
        let mut access_token = get_test_access_token();
        let kept_refresh_token = access_token.refresh_token.clone();
        let refresh = get_test_refresh_token();

        access_token.update_with(&refresh);

        assert_eq!(access_token.access_token, refresh.access_token);
        assert_eq!(access_token.expires_in, refresh.expires_in);
        assert_eq!(access_token.scope, refresh.scope);
        assert_eq!(access_token.token_type, refresh.token_type);
        // The refresh response has no refresh token, so the old one must survive.
        assert_eq!(access_token.refresh_token, kept_refresh_token);
    }

    #[test]
    fn is_valid_test() {
        assert!(!get_test_access_token().is_valid());
        assert!(VALID_CREDENTIALS.access_token.is_valid());
    }

    #[test]
    fn has_scopes_test() {
        let access_token = get_test_access_token();
        let granted: [&[&str]; 3] = [
            &["test_scope_one", "test_scope_two"],
            &["test_scope_one"],
            &["test_scope_two"],
        ];

        for scopes in &granted {
            assert!(access_token.has_scopes(*scopes), "expected {:?} to be granted", scopes);
        }
    }

    #[test]
    fn has_scopes_mismatch_test() {
        let access_token = get_test_access_token();
        let rejected: [&[&str]; 3] = [
            &["invalid_test_scope"],
            &["test_scope_one", "test_scope_one", "test_scope_three"],
            &["test_scope_one", "invalid_test_scope_two"],
        ];

        for scopes in &rejected {
            assert!(!access_token.has_scopes(*scopes), "expected {:?} to be rejected", scopes);
        }
    }

    #[test]
    fn has_scopes_empty_test() {
        let no_scopes: [&str; 0] = [];
        assert!(get_test_access_token().has_scopes(&no_scopes));
    }

    #[test]
    #[ignore = "requires user input (CI/CD)"]
    fn request_test() {
        let scopes = [DRIVE_METADATA_READONLY];
        let result = AccessToken::request(&VALID_CREDENTIALS.client_secrets, &scopes);
        assert!(result.is_ok());
    }

    #[test]
    #[ignore = "requires user input (CI/CD)"]
    fn request_invalid_secrets_uri_test() {
        let mut client_secrets = VALID_CREDENTIALS.client_secrets.clone();
        client_secrets.token_uri = String::from("invalid-token-uri");
        let scopes = [DRIVE_METADATA_READONLY];

        let error = AccessToken::request(&client_secrets, &scopes)
            .expect_err("an unusable token URI must make the request fail");
        assert_eq!(error.kind, ErrorKind::Request);
    }

    #[test]
    #[ignore = "requires user input (CI/CD)"]
    fn request_mismatched_scopes_test() {
        let scopes = [
            DRIVE_METADATA_READONLY,
            "https://www.googleapis.com/auth/drive.file",
        ];

        let error = AccessToken::request(&VALID_CREDENTIALS.client_secrets, &scopes)
            .expect_err("a token missing requested scopes must be rejected");
        assert_eq!(error.kind, ErrorKind::MismatchedScopes);
    }

    #[test]
    fn refresh_test() {
        let mut credentials = INVALID_CREDENTIALS.clone();
        let outcome = credentials.access_token.refresh(&credentials.client_secrets);

        assert!(outcome.is_ok());
        assert!(credentials.access_token.is_valid());
    }

    #[test]
    fn refresh_invalid_refresh_token_test() {
        let mut credentials = VALID_CREDENTIALS.clone();
        credentials.access_token.refresh_token = String::from("invalid-token");

        let error = credentials
            .access_token
            .refresh(&credentials.client_secrets)
            .expect_err("an invalid refresh token must be rejected");
        assert_eq!(error.kind, ErrorKind::Response);
    }

    #[test]
    fn refresh_invalid_secrets_test() {
        let mut credentials = VALID_CREDENTIALS.clone();
        credentials.client_secrets.client_id = String::from("invalid-client-id");
        credentials.client_secrets.client_secret = String::from("invalid-client-secret");

        let error = credentials
            .access_token
            .refresh(&credentials.client_secrets)
            .expect_err("invalid client secrets must be rejected");
        assert_eq!(error.kind, ErrorKind::Response);
    }

    #[test]
    fn refresh_invalid_secrets_uri_test() {
        let mut credentials = VALID_CREDENTIALS.clone();
        credentials.client_secrets.token_uri = String::from("invalid-token-uri");

        let error = credentials
            .access_token
            .refresh(&credentials.client_secrets)
            .expect_err("an unusable token URI must make the refresh fail");
        assert_eq!(error.kind, ErrorKind::Request);
    }
}