use std::time::Duration;
use super::cipher::TokenCipher;
use crate::error::Error;
use crate::oauth::{AuthClient, TokenResponse};
/// Retry delay suggested on every transient liveness failure produced by this module.
const DEFAULT_TRANSIENT_RETRY_AFTER: Duration = Duration::from_millis(2_000);
/// Reason a stored refresh grant was classified as revoked.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum RevokeCause {
    /// The stored ciphertext could not be decrypted with the configured cipher.
    CipherFailure,
    /// The PAS rejected the refresh request.
    PasRejected,
}
/// Reason a liveness refresh failed in a way that may succeed on retry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TransientCause {
    /// The PAS answered with an error status that signals a temporary condition on its side.
    PasServerError,
    /// The HTTP client reported an error (`Error::Http`).
    Transport,
    /// The refresh succeeded, but the rotated refresh token could not be re-encrypted.
    CipherEncryptFailed,
    /// Any other failure, including OAuth errors that carry no HTTP status.
    Unknown,
}
/// Why a liveness refresh did not confirm the stored grant.
///
/// Derives the common comparison/cloning traits so callers can compare outcomes directly
/// (e.g. `assert_eq!`) instead of pattern-matching every field.
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use]
pub enum LivenessFailure {
    /// The stored grant is considered unusable; `cause` says why.
    Revoked { cause: RevokeCause },
    /// The attempt failed in a way that may succeed on a later try.
    Transient {
        /// Suggested delay before retrying, when one is known.
        retry_after: Option<Duration>,
        /// What went wrong on this attempt.
        cause: TransientCause,
    },
}

/// Result of a single liveness refresh attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use]
pub enum LivenessOutcome {
    /// The refresh succeeded.
    ///
    /// `rotated_ciphertext` is `Some` when the PAS returned a new refresh token, which has
    /// already been encrypted with the same cipher; `None` when no new token was issued.
    Fresh { rotated_ciphertext: Option<String> },
    /// The refresh did not succeed; see [`LivenessFailure`].
    Failed(LivenessFailure),
}
pub fn classify_refresh_error(err: &Error) -> LivenessFailure {
match err {
Error::OAuth {
status: Some(code), ..
} if (500..600).contains(code) => LivenessFailure::Transient {
retry_after: Some(DEFAULT_TRANSIENT_RETRY_AFTER),
cause: TransientCause::PasServerError,
},
Error::OAuth {
status: Some(_), ..
} => LivenessFailure::Revoked {
cause: RevokeCause::PasRejected,
},
Error::OAuth { status: None, .. } => LivenessFailure::Transient {
retry_after: Some(DEFAULT_TRANSIENT_RETRY_AFTER),
cause: TransientCause::Unknown,
},
#[cfg(feature = "oauth")]
Error::Http(_) => LivenessFailure::Transient {
retry_after: Some(DEFAULT_TRANSIENT_RETRY_AFTER),
cause: TransientCause::Transport,
},
_ => LivenessFailure::Transient {
retry_after: Some(DEFAULT_TRANSIENT_RETRY_AFTER),
cause: TransientCause::Unknown,
},
}
}
/// Checks that a stored refresh grant is still live by performing one refresh with the PAS.
///
/// `ciphertext` is decrypted with `cipher`, and the recovered refresh token is exchanged via
/// [`AuthClient::refresh_token`]. The outcome is:
///
/// - [`LivenessOutcome::Fresh`] when the refresh succeeds. `rotated_ciphertext` holds the newly
///   issued refresh token encrypted with `cipher`, or `None` if the response carried none.
/// - [`LivenessFailure::Revoked`] with [`RevokeCause::CipherFailure`] when `ciphertext` cannot
///   be decrypted. No refresh request is made in that case.
/// - [`LivenessFailure::Transient`] with [`TransientCause::CipherEncryptFailed`] when the
///   rotated token cannot be encrypted.
/// - Otherwise, whatever [`classify_refresh_error`] returns for the refresh error.
pub async fn attempt_liveness_refresh(
    cipher: &TokenCipher,
    client: &AuthClient,
    ciphertext: &str,
) -> LivenessOutcome {
    let Ok(current_rt) = cipher.decrypt(ciphertext) else {
        return LivenessOutcome::Failed(LivenessFailure::Revoked {
            cause: RevokeCause::CipherFailure,
        });
    };

    let response = match client.refresh_token(&current_rt).await {
        Ok(response) => response,
        Err(err) => return LivenessOutcome::Failed(classify_refresh_error(&err)),
    };

    let Some(rotated_rt) = response.refresh_token.as_deref() else {
        return LivenessOutcome::Fresh {
            rotated_ciphertext: None,
        };
    };

    // NOTE(review): if the PAS invalidates the previous refresh token on rotation, failing to
    // encrypt the rotated one here may leave only an already-rejected ciphertext for the retry —
    // confirm how callers handle CipherEncryptFailed.
    match cipher.encrypt(rotated_rt) {
        Ok(rotated_ct) => LivenessOutcome::Fresh {
            rotated_ciphertext: Some(rotated_ct),
        },
        Err(_) => LivenessOutcome::Failed(LivenessFailure::Transient {
            retry_after: Some(DEFAULT_TRANSIENT_RETRY_AFTER),
            cause: TransientCause::CipherEncryptFailed,
        }),
    }
}
#[cfg(test)]
// Tests unwrap freely: a panic there means a broken fixture, which should fail the test anyway.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a token-refresh `Error::OAuth` carrying the given HTTP status.
    fn oauth_err(status: Option<u16>) -> Error {
        Error::OAuth {
            operation: "token refresh",
            status,
            detail: "test".into(),
        }
    }

    /// Classifies an OAuth error with `status` and reports whether it became a PAS rejection.
    fn rejected_by_pas(status: u16) -> bool {
        matches!(
            classify_refresh_error(&oauth_err(Some(status))),
            LivenessFailure::Revoked {
                cause: RevokeCause::PasRejected
            }
        )
    }

    /// Returns the transient cause assigned to `err`, or `None` if it was classified as revoked.
    fn transient_cause_of(err: &Error) -> Option<TransientCause> {
        match classify_refresh_error(err) {
            LivenessFailure::Transient { cause, .. } => Some(cause),
            LivenessFailure::Revoked { .. } => None,
        }
    }

    #[test]
    fn oauth_400_is_revoked_pas_rejected() {
        assert!(rejected_by_pas(400));
    }

    #[test]
    fn oauth_401_is_revoked_pas_rejected() {
        assert!(rejected_by_pas(401));
    }

    #[test]
    fn oauth_403_is_revoked_pas_rejected() {
        assert!(rejected_by_pas(403));
    }

    #[test]
    fn oauth_500_is_transient_pas_server_error() {
        assert_eq!(
            transient_cause_of(&oauth_err(Some(500))),
            Some(TransientCause::PasServerError)
        );
    }

    #[test]
    fn oauth_503_is_transient_pas_server_error() {
        assert_eq!(
            transient_cause_of(&oauth_err(Some(503))),
            Some(TransientCause::PasServerError)
        );
    }

    #[test]
    fn oauth_missing_status_is_transient_unknown() {
        assert_eq!(
            transient_cause_of(&oauth_err(None)),
            Some(TransientCause::Unknown)
        );
    }

    #[test]
    fn non_oauth_error_is_transient_unknown() {
        use crate::error::TokenError;

        let err = Error::Token(TokenError::Expired);
        assert_eq!(transient_cause_of(&err), Some(TransientCause::Unknown));
    }

    #[tokio::test]
    async fn decrypt_failure_short_circuits_to_revoked_cipher_failure() {
        use crate::oauth::{AuthClient, OAuthConfig};
        use base64::{Engine, engine::general_purpose::STANDARD};

        // All-zero key; 64 zero bytes are expected not to decrypt under it.
        let cipher = TokenCipher::from_base64_key(&STANDARD.encode([0u8; 32])).unwrap();
        let garbage_ct = STANDARD.encode([0u8; 64]);

        // Reserved `.invalid` host: decryption fails before any refresh request is attempted.
        let client = AuthClient::try_new(OAuthConfig::new(
            "test-client",
            "https://example.invalid".parse().unwrap(),
        ))
        .unwrap();

        let outcome = attempt_liveness_refresh(&cipher, &client, &garbage_ct).await;
        assert!(matches!(
            outcome,
            LivenessOutcome::Failed(LivenessFailure::Revoked {
                cause: RevokeCause::CipherFailure
            })
        ));
    }
}