use async_trait::async_trait;
use serde::Deserialize;
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::auth::types::{AccessToken, AuthorizedUserCreds, CachedToken};
use crate::token::{TokenError, TokenProvider};
/// Safety margin, in seconds, passed to the token-cache lookup: a cached token whose expiry
/// falls within this window is not reused and a fresh one is fetched instead.
const TOKEN_EXPIRY_BUFFER_SECS: u64 = 60;

/// OAuth 2.0 token endpoint that receives the `refresh_token` grant.
const TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
/// JSON body of a successful response from [`TOKEN_ENDPOINT`].
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Access token issued by the endpoint.
    access_token: String,
    /// Token lifetime in seconds, counted from when the response was produced.
    expires_in: u64,
    /// Token type reported by the server. Required during deserialization but never read,
    /// hence the lint allowance.
    #[allow(dead_code)]
    token_type: String,
    /// Scopes granted, when the server reports them. Never read, hence the lint allowance.
    #[allow(dead_code)]
    scope: Option<String>,
}
/// Credential for an OAuth 2.0 "authorized user".
///
/// It exchanges a refresh token for access tokens at the Google token endpoint and caches the
/// most recent one.
pub struct AuthorizedUserCredential {
    client_id: String,
    client_secret: String,
    refresh_token: String,
    /// Quota project loaded from the credentials JSON, if present.
    quota_project_id: Option<String>,
    /// Most recently fetched access token.
    cache: CachedToken,
    http_client: reqwest::Client,
}

// `Debug` is written by hand rather than derived so that secrets never reach logs or panic
// messages. The client secret and refresh token are redacted. The token cache (which holds the
// current access token) and the HTTP client are left out entirely.
impl std::fmt::Debug for AuthorizedUserCredential {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthorizedUserCredential")
            .field("client_id", &self.client_id)
            .field("client_secret", &"[redacted]")
            .field("refresh_token", &"[redacted]")
            .field("quota_project_id", &self.quota_project_id)
            .finish_non_exhaustive()
    }
}
impl AuthorizedUserCredential {
    /// Creates a credential from an OAuth 2.0 client ID, client secret and refresh token.
    ///
    /// Token requests use a default [`reqwest::Client`], and no quota project is set. To supply
    /// a preconfigured client, use [`AuthorizedUserCredential::with_http_client`].
    pub fn new(
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
        refresh_token: impl Into<String>,
    ) -> Self {
        // Delegate so the two constructors cannot drift apart.
        Self::with_http_client(client_id, client_secret, refresh_token, reqwest::Client::new())
    }

    /// Parses and validates an `authorized_user` credentials JSON document.
    ///
    /// Keeps any `quota_project_id` present in the document.
    ///
    /// # Errors
    ///
    /// * [`AuthorizedUserError::InvalidJson`] if `json` cannot be deserialized.
    /// * [`AuthorizedUserError::InvalidCredType`] if the credential type is not
    ///   `authorized_user`.
    /// * [`AuthorizedUserError::MissingField`] if `client_id`, `client_secret` or
    ///   `refresh_token` is empty.
    pub fn from_json(json: &str) -> Result<Self, AuthorizedUserError> {
        let creds: AuthorizedUserCreds =
            serde_json::from_str(json).map_err(|e| AuthorizedUserError::InvalidJson {
                message: e.to_string(),
            })?;
        Self::validate_creds(&creds)?;
        let mut credential = Self::new(creds.client_id, creds.client_secret, creds.refresh_token);
        credential.quota_project_id = creds.quota_project_id;
        Ok(credential)
    }

    /// Reads `path` and parses its contents with [`AuthorizedUserCredential::from_json`].
    ///
    /// # Errors
    ///
    /// Returns [`AuthorizedUserError::FileReadError`] if the file cannot be read. Also returns
    /// any error produced by [`AuthorizedUserCredential::from_json`].
    pub fn from_file(path: &Path) -> Result<Self, AuthorizedUserError> {
        let json =
            std::fs::read_to_string(path).map_err(|source| AuthorizedUserError::FileReadError {
                path: path.to_path_buf(),
                source,
            })?;
        Self::from_json(&json)
    }

    /// Same as [`AuthorizedUserCredential::new`], except token requests go through
    /// `http_client`, for example a client configured with a request timeout.
    pub fn with_http_client(
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
        refresh_token: impl Into<String>,
        http_client: reqwest::Client,
    ) -> Self {
        Self {
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            refresh_token: refresh_token.into(),
            quota_project_id: None,
            cache: CachedToken::new(),
            http_client,
        }
    }

    /// Returns the OAuth 2.0 client ID.
    pub fn client_id(&self) -> &str {
        &self.client_id
    }

    /// Checks the credential type and that every required field is non-empty.
    ///
    /// Fields are checked in the order `client_id`, `client_secret`, `refresh_token`. The first
    /// empty one is reported.
    fn validate_creds(creds: &AuthorizedUserCreds) -> Result<(), AuthorizedUserError> {
        if creds.cred_type != "authorized_user" {
            return Err(AuthorizedUserError::InvalidCredType {
                expected: "authorized_user".to_string(),
                actual: creds.cred_type.clone(),
            });
        }
        let required = [
            ("client_id", &creds.client_id),
            ("client_secret", &creds.client_secret),
            ("refresh_token", &creds.refresh_token),
        ];
        for (field, value) in required {
            if value.is_empty() {
                return Err(AuthorizedUserError::MissingField {
                    field: field.to_string(),
                });
            }
        }
        Ok(())
    }

    /// Exchanges the refresh token for a new access token at [`TOKEN_ENDPOINT`].
    ///
    /// The returned token's expiry is an absolute time in seconds since the Unix epoch: the
    /// local clock plus the server-supplied `expires_in`.
    ///
    /// # Errors
    ///
    /// Returns [`AuthorizedUserError::TokenExchangeFailed`] in these cases:
    /// * the request fails or its body cannot be read;
    /// * the endpoint returns a non-success status;
    /// * the response JSON cannot be parsed;
    /// * the system clock is before the Unix epoch.
    async fn fetch_token(&self) -> Result<AccessToken, AuthorizedUserError> {
        // Form-encoded refresh_token grant. Each value is percent-encoded so that characters
        // such as `&`, `=` or `+` in the secrets reach the server intact.
        let body = format!(
            "client_id={}&client_secret={}&refresh_token={}&grant_type=refresh_token",
            urlencoding::encode(&self.client_id),
            urlencoding::encode(&self.client_secret),
            urlencoding::encode(&self.refresh_token),
        );
        let response = self
            .http_client
            .post(TOKEN_ENDPOINT)
            .header("Content-Type", "application/x-www-form-urlencoded")
            .body(body)
            .send()
            .await
            .map_err(|e| AuthorizedUserError::TokenExchangeFailed {
                message: format!("HTTP request failed: {}", e),
            })?;
        let status = response.status();
        // Read the body before checking the status so error responses can be reported in full.
        let response_text =
            response
                .text()
                .await
                .map_err(|e| AuthorizedUserError::TokenExchangeFailed {
                    message: format!("Failed to read response body: {}", e),
                })?;
        if !status.is_success() {
            return Err(AuthorizedUserError::TokenExchangeFailed {
                message: format!("Token endpoint returned {}: {}", status, response_text),
            });
        }
        let token_response: TokenResponse = serde_json::from_str(&response_text).map_err(|e| {
            AuthorizedUserError::TokenExchangeFailed {
                message: format!("Failed to parse token response: {}", e),
            }
        })?;
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_err(|e| AuthorizedUserError::TokenExchangeFailed {
                message: format!("Failed to get current time: {}", e),
            })?
            .as_secs();
        // Saturate instead of overflowing if the server reports an absurd `expires_in`.
        Ok(AccessToken::new(
            token_response.access_token,
            now.saturating_add(token_response.expires_in),
        ))
    }
}
#[async_trait]
impl TokenProvider for AuthorizedUserCredential {
    /// Returns the cached access token if it is not within [`TOKEN_EXPIRY_BUFFER_SECS`] of
    /// expiring. Otherwise fetches a new token, caches it and returns it.
    ///
    /// `_scopes` is ignored by this provider.
    ///
    /// # Errors
    ///
    /// Any token-exchange failure is reported as [`TokenError::RefreshFailed`].
    async fn get_token(&self, _scopes: &[&str]) -> Result<String, TokenError> {
        let cached = self.cache.get(TOKEN_EXPIRY_BUFFER_SECS).await;
        match cached {
            Some(token) => Ok(token),
            None => {
                let fresh = self.fetch_token().await.map_err(|err| {
                    TokenError::RefreshFailed {
                        message: err.to_string(),
                    }
                })?;
                let value = fresh.token.clone();
                self.cache.set(fresh).await;
                Ok(value)
            }
        }
    }

    /// Clears the cached token so the next `get_token` call fetches a fresh one.
    fn on_token_rejected(&self) {
        self.cache.clear_sync();
    }

    /// Returns the quota project read from the credentials JSON, if any.
    fn quota_project_id(&self) -> Option<&str> {
        self.quota_project_id.as_deref()
    }
}
/// Errors raised while loading `authorized_user` credentials or refreshing their access token.
#[derive(Debug, thiserror::Error)]
pub enum AuthorizedUserError {
    /// The credentials file could not be read.
    #[error("Failed to read credentials file at {path}: {source}")]
    FileReadError {
        /// Path that was passed to the loader.
        path: std::path::PathBuf,
        /// Underlying I/O failure.
        #[source]
        source: std::io::Error,
    },

    /// The credentials document could not be deserialized.
    #[error("Invalid JSON: {message}")]
    InvalidJson { message: String },

    /// The document's credential type is not the one this loader handles.
    #[error("Invalid credential type: expected '{expected}', got '{actual}'")]
    InvalidCredType { expected: String, actual: String },

    /// A required field was present but empty.
    #[error("Missing required field: {field}")]
    MissingField { field: String },

    /// The refresh-token exchange with the token endpoint failed.
    #[error("Token exchange failed: {message}")]
    TokenExchangeFailed { message: String },
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid `authorized_user` credentials document.
    fn test_creds_json() -> String {
        r#"{
            "type": "authorized_user",
            "client_id": "test-client-id.apps.googleusercontent.com",
            "client_secret": "test-client-secret",
            "refresh_token": "1//test-refresh-token"
        }"#
        .to_string()
    }

    #[test]
    fn test_new() {
        let cred = AuthorizedUserCredential::new("client-id", "client-secret", "refresh-token");
        assert_eq!(cred.client_id(), "client-id");
    }

    #[test]
    fn test_from_json_valid() {
        let json = test_creds_json();
        let cred = AuthorizedUserCredential::from_json(&json).unwrap();
        assert_eq!(
            cred.client_id(),
            "test-client-id.apps.googleusercontent.com"
        );
    }

    #[test]
    fn test_from_json_with_quota_project() {
        let json = r#"{
            "type": "authorized_user",
            "client_id": "test-client-id.apps.googleusercontent.com",
            "client_secret": "test-client-secret",
            "refresh_token": "1//test-refresh-token",
            "quota_project_id": "my-quota-project"
        }"#;
        let cred = AuthorizedUserCredential::from_json(json).unwrap();
        assert_eq!(
            cred.client_id(),
            "test-client-id.apps.googleusercontent.com"
        );
        assert_eq!(cred.quota_project_id(), Some("my-quota-project"));
    }

    #[test]
    fn test_from_json_invalid_json() {
        let result = AuthorizedUserCredential::from_json("not valid json");
        assert!(matches!(
            result,
            Err(AuthorizedUserError::InvalidJson { .. })
        ));
    }

    #[test]
    fn test_from_json_wrong_type() {
        let json = r#"{
            "type": "service_account",
            "client_id": "123",
            "client_secret": "secret",
            "refresh_token": "token"
        }"#;
        let result = AuthorizedUserCredential::from_json(json);
        assert!(matches!(
            result,
            Err(AuthorizedUserError::InvalidCredType { .. })
        ));
    }

    #[test]
    fn test_from_json_missing_client_id() {
        let json = r#"{
            "type": "authorized_user",
            "client_id": "",
            "client_secret": "secret",
            "refresh_token": "token"
        }"#;
        let result = AuthorizedUserCredential::from_json(json);
        assert!(matches!(
            result,
            Err(AuthorizedUserError::MissingField { field }) if field == "client_id"
        ));
    }

    #[test]
    fn test_from_json_missing_client_secret() {
        let json = r#"{
            "type": "authorized_user",
            "client_id": "client-id",
            "client_secret": "",
            "refresh_token": "token"
        }"#;
        let result = AuthorizedUserCredential::from_json(json);
        assert!(matches!(
            result,
            Err(AuthorizedUserError::MissingField { field }) if field == "client_secret"
        ));
    }

    #[test]
    fn test_from_json_missing_refresh_token() {
        let json = r#"{
            "type": "authorized_user",
            "client_id": "client-id",
            "client_secret": "secret",
            "refresh_token": ""
        }"#;
        let result = AuthorizedUserCredential::from_json(json);
        assert!(matches!(
            result,
            Err(AuthorizedUserError::MissingField { field }) if field == "refresh_token"
        ));
    }

    #[test]
    fn test_from_file_not_found() {
        let result = AuthorizedUserCredential::from_file(Path::new("/nonexistent/file.json"));
        assert!(matches!(
            result,
            Err(AuthorizedUserError::FileReadError { .. })
        ));
    }

    #[test]
    fn test_error_display() {
        let err = AuthorizedUserError::InvalidJson {
            message: "test error".to_string(),
        };
        assert!(err.to_string().contains("Invalid JSON"));

        let err = AuthorizedUserError::InvalidCredType {
            expected: "authorized_user".to_string(),
            actual: "other".to_string(),
        };
        assert!(err.to_string().contains("authorized_user"));
        assert!(err.to_string().contains("other"));

        let err = AuthorizedUserError::MissingField {
            field: "client_id".to_string(),
        };
        assert!(err.to_string().contains("client_id"));

        let err = AuthorizedUserError::TokenExchangeFailed {
            message: "exchange error".to_string(),
        };
        assert!(err.to_string().contains("exchange error"));
    }

    #[tokio::test]
    async fn test_token_caching() {
        let cred = AuthorizedUserCredential::new("client-id", "client-secret", "refresh-token");
        let token = AccessToken::new("cached-token", u64::MAX);
        cred.cache.set(token).await;
        let result = cred.get_token(&[]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "cached-token");
    }

    #[tokio::test]
    async fn test_expired_token_not_returned() {
        let cred = AuthorizedUserCredential::new("client-id", "client-secret", "refresh-token");
        cred.cache.set(AccessToken::new("expired-token", 0)).await;
        // Assert on the same cache lookup that `get_token` performs first. Calling `get_token`
        // here would fall through to a real HTTP request against the live token endpoint.
        assert!(cred.cache.get(TOKEN_EXPIRY_BUFFER_SECS).await.is_none());
    }

    #[tokio::test]
    async fn test_token_near_expiry_not_returned() {
        let cred = AuthorizedUserCredential::new("client-id", "client-secret", "refresh-token");
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        // Valid for another 30s, which is inside the refresh buffer, so it must not be reused.
        cred.cache
            .set(AccessToken::new("near-expiry-token", now + 30))
            .await;
        // Check the cache directly to avoid a network call; see the test above.
        assert!(cred.cache.get(TOKEN_EXPIRY_BUFFER_SECS).await.is_none());
    }

    #[tokio::test]
    async fn test_on_token_rejected() {
        let cred = AuthorizedUserCredential::new("client-id", "client-secret", "refresh-token");
        let token = AccessToken::new("valid-token", u64::MAX);
        cred.cache.set(token).await;
        let cached = cred.cache.get(0).await;
        assert_eq!(cached, Some("valid-token".to_string()));
        cred.on_token_rejected();
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        let cached = cred.cache.get(0).await;
        assert!(cached.is_none());
    }

    #[test]
    fn test_with_http_client() {
        let custom_client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap();
        let cred = AuthorizedUserCredential::with_http_client(
            "client-id",
            "client-secret",
            "refresh-token",
            custom_client,
        );
        assert_eq!(cred.client_id(), "client-id");
    }

    #[test]
    fn test_quota_project_id_from_json() {
        let json = r#"{
            "type": "authorized_user",
            "client_id": "test-client-id.apps.googleusercontent.com",
            "client_secret": "test-client-secret",
            "refresh_token": "1//test-refresh-token",
            "quota_project_id": "my-billing-project"
        }"#;
        let cred = AuthorizedUserCredential::from_json(json).unwrap();
        assert_eq!(cred.quota_project_id(), Some("my-billing-project"));
    }

    #[test]
    fn test_quota_project_id_none_when_missing() {
        let json = r#"{
            "type": "authorized_user",
            "client_id": "test-client-id.apps.googleusercontent.com",
            "client_secret": "test-client-secret",
            "refresh_token": "1//test-refresh-token"
        }"#;
        let cred = AuthorizedUserCredential::from_json(json).unwrap();
        assert!(cred.quota_project_id().is_none());
    }
}