use std::time::{Duration, Instant};
use thiserror::Error;
/// A token response from an OAuth 2.0 token endpoint.
///
/// Only `access_token` and `token_type` are mandatory; the remaining fields are set
/// through the `with_*` builder methods. The expiry is stored as an absolute
/// [`Instant`], computed when [`TokenResponse::with_expires_in`] is called.
///
/// `Debug` is implemented by hand so that the access, refresh and ID tokens are
/// never printed (e.g. into logs); only their presence is shown.
#[derive(Clone)]
pub struct TokenResponse {
    access_token: String,
    token_type: String,
    expires_at: Option<Instant>,
    refresh_token: Option<String>,
    scopes: Option<Vec<String>>,
    id_token: Option<String>,
}

impl std::fmt::Debug for TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";
        f.debug_struct("TokenResponse")
            .field("access_token", &REDACTED)
            .field("token_type", &self.token_type)
            .field("expires_at", &self.expires_at)
            // Show whether a secret is present without revealing its value.
            .field("refresh_token", &self.refresh_token.as_ref().map(|_| REDACTED))
            .field("scopes", &self.scopes)
            .field("id_token", &self.id_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}

impl TokenResponse {
    /// Creates a response with only the mandatory fields set.
    ///
    /// No expiry, refresh token, scopes or ID token are recorded.
    pub fn new(access_token: String, token_type: String) -> Self {
        Self {
            access_token,
            token_type,
            expires_at: None,
            refresh_token: None,
            scopes: None,
            id_token: None,
        }
    }

    /// Records that the token expires `expires_in` from now.
    ///
    /// If `now + expires_in` cannot be represented as an [`Instant`] (for example
    /// because the server sent an absurdly large `expires_in`), the token is
    /// treated as non-expiring instead of panicking.
    pub fn with_expires_in(mut self, expires_in: Duration) -> Self {
        self.expires_at = Instant::now().checked_add(expires_in);
        self
    }

    /// Attaches a refresh token.
    pub fn with_refresh_token(mut self, refresh_token: String) -> Self {
        self.refresh_token = Some(refresh_token);
        self
    }

    /// Attaches the list of granted scopes.
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.scopes = Some(scopes);
        self
    }

    /// Attaches an ID token.
    pub fn with_id_token(mut self, id_token: String) -> Self {
        self.id_token = Some(id_token);
        self
    }

    /// The raw access token.
    pub fn access_token(&self) -> &str {
        &self.access_token
    }

    /// The token type as supplied at construction (e.g. `Bearer`).
    pub fn token_type(&self) -> &str {
        &self.token_type
    }

    /// Returns `true` once the current time has reached the recorded expiry.
    ///
    /// A token without an expiry is never considered expired.
    pub fn is_expired(&self) -> bool {
        matches!(self.expires_at, Some(expires_at) if Instant::now() >= expires_at)
    }

    /// The refresh token, if one was attached.
    pub fn refresh_token(&self) -> Option<&str> {
        self.refresh_token.as_deref()
    }

    /// The ID token, if one was attached.
    pub fn id_token(&self) -> Option<&str> {
        self.id_token.as_deref()
    }

    /// The granted scopes, if they were attached.
    pub fn scopes(&self) -> Option<&[String]> {
        self.scopes.as_deref()
    }

    /// Remaining lifetime of the token.
    ///
    /// Returns `None` both when no expiry is recorded and when the expiry already
    /// lies in the past.
    pub fn expires_in(&self) -> Option<Duration> {
        self.expires_at
            .and_then(|exp| exp.checked_duration_since(Instant::now()))
    }

    /// Formats the value for an HTTP `Authorization` header as
    /// `"<token_type> <access_token>"`.
    pub fn authorization_header(&self) -> String {
        format!("{} {}", self.token_type, self.access_token)
    }
}
#[derive(Debug, Error)]
pub enum TokenError {
#[error("Authorization denied: {0}")]
AuthorizationDenied(String),
#[error("Invalid authorization code")]
InvalidCode,
#[error("Invalid CSRF state - possible CSRF attack")]
InvalidState,
#[error("Token exchange failed: {0}")]
ExchangeFailed(String),
#[error("Token refresh failed: {0}")]
RefreshFailed(String),
#[error("Network error: {0}")]
NetworkError(String),
#[error("Invalid response: {0}")]
InvalidResponse(String),
#[error("Token is expired")]
TokenExpired,
#[error("Missing required field: {0}")]
MissingField(String),
}
/// A PKCE (RFC 7636) code verifier together with its derived `S256` code challenge.
///
/// The verifier is a secret that is only sent to the token endpoint, so `Debug`
/// is implemented by hand to redact it; the public challenge and method are shown.
#[derive(Clone)]
pub struct PkceVerifier {
    verifier: String,
    challenge: String,
    method: String,
}

impl std::fmt::Debug for PkceVerifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PkceVerifier")
            .field("verifier", &"[redacted]")
            .field("challenge", &self.challenge)
            .field("method", &self.method)
            .finish()
    }
}

impl PkceVerifier {
    /// Wraps an existing verifier and derives its challenge as
    /// `BASE64URL-NOPAD(SHA256(verifier))`, i.e. the `S256` method.
    ///
    /// The input is not validated. RFC 7636 requires 43–128 characters from the
    /// unreserved set `[A-Za-z0-9-._~]`; callers supplying their own verifier must
    /// ensure that.
    pub fn new(verifier: impl Into<String>) -> Self {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
        use sha2::{Digest, Sha256};

        let verifier = verifier.into();
        let mut hasher = Sha256::new();
        hasher.update(verifier.as_bytes());
        let challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
        Self {
            verifier,
            challenge,
            method: "S256".to_string(),
        }
    }

    /// Generates a fresh verifier from 32 bytes of OS randomness.
    ///
    /// Base64url-encoding 32 bytes without padding yields a 43-character string,
    /// which is within RFC 7636's allowed length and character set.
    pub fn generate() -> Self {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
        use rand::{rngs::OsRng, RngCore};

        let mut verifier_bytes = [0u8; 32];
        OsRng.fill_bytes(&mut verifier_bytes);
        Self::new(URL_SAFE_NO_PAD.encode(verifier_bytes))
    }

    /// The secret verifier string.
    pub fn verifier(&self) -> &str {
        &self.verifier
    }

    /// The challenge derived from the verifier.
    pub fn challenge(&self) -> &str {
        &self.challenge
    }

    /// The challenge method; always `"S256"`.
    pub fn method(&self) -> &str {
        &self.method
    }
}
/// An opaque OAuth 2.0 `state` value used for CSRF protection.
///
/// Equality (`==`) and [`CsrfState::verify`] use the same comparison. It does not
/// stop at the first differing byte, so the time taken does not reveal how long a
/// correct prefix was guessed.
#[derive(Debug, Clone, Eq)]
pub struct CsrfState(String);

impl CsrfState {
    /// Generates a fresh state from 16 bytes of OS randomness, base64url-encoded
    /// without padding. Every generated state therefore has the same length.
    pub fn generate() -> Self {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
        use rand::{rngs::OsRng, RngCore};

        let mut bytes = [0u8; 16];
        OsRng.fill_bytes(&mut bytes);
        Self(URL_SAFE_NO_PAD.encode(bytes))
    }

    /// Wraps an existing state string without validation.
    pub fn new(state: String) -> Self {
        Self(state)
    }

    /// The state as a string slice.
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Returns `true` when `other` is exactly this state.
    ///
    /// A length mismatch returns early. That is acceptable because the length of
    /// a generated state is fixed and not secret. Equal-length inputs are compared
    /// in full.
    pub fn verify(&self, other: &str) -> bool {
        Self::constant_time_eq(self.0.as_bytes(), other.as_bytes())
    }

    /// Byte-wise equality that inspects every byte of equal-length inputs.
    ///
    /// This is best-effort: the optimizer is not formally prevented from
    /// introducing an early exit.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        // OR together all byte differences so a mismatch anywhere is only
        // observed after the whole input has been processed.
        a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
    }
}

impl PartialEq for CsrfState {
    fn eq(&self, other: &Self) -> bool {
        Self::constant_time_eq(self.0.as_bytes(), other.0.as_bytes())
    }
}

impl std::fmt::Display for CsrfState {
    /// Writes the raw state string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_response() {
        let token = TokenResponse::new("access123".to_string(), "Bearer".to_string())
            .with_refresh_token("refresh456".to_string())
            .with_expires_in(Duration::from_secs(3600));
        assert_eq!(token.access_token(), "access123");
        assert_eq!(token.token_type(), "Bearer");
        assert_eq!(token.refresh_token(), Some("refresh456"));
        assert!(!token.is_expired());
        assert_eq!(token.authorization_header(), "Bearer access123");
    }

    #[test]
    fn test_token_expiry_boundaries() {
        // A zero lifetime means the expiry instant is "now", which is already reached.
        let expired = TokenResponse::new("a".to_string(), "Bearer".to_string())
            .with_expires_in(Duration::ZERO);
        assert!(expired.is_expired());

        let no_expiry = TokenResponse::new("a".to_string(), "Bearer".to_string());
        assert!(!no_expiry.is_expired());
        assert_eq!(no_expiry.expires_in(), None);
    }

    #[test]
    fn test_pkce_verifier() {
        let pkce = PkceVerifier::generate();
        assert!(!pkce.verifier().is_empty());
        assert!(!pkce.challenge().is_empty());
        assert_eq!(pkce.method(), "S256");
        assert_ne!(pkce.verifier(), pkce.challenge());
        // 32 random bytes, base64url without padding -> 43 characters (RFC 7636 minimum).
        assert_eq!(pkce.verifier().len(), 43);
        let rebuilt = PkceVerifier::new(pkce.verifier().to_string());
        assert_eq!(rebuilt.verifier(), pkce.verifier());
        assert_eq!(rebuilt.challenge(), pkce.challenge());
    }

    #[test]
    fn test_pkce_rfc7636_appendix_b_vector() {
        // Known-answer test from RFC 7636, Appendix B.
        let pkce = PkceVerifier::new("dBjftJeZ4CVP-mJ92K27uhbUJU1p1r_wW1gFWFOEjXk");
        assert_eq!(pkce.challenge(), "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
        assert_eq!(pkce.method(), "S256");
    }

    #[test]
    fn test_csrf_state() {
        let state1 = CsrfState::generate();
        let state2 = CsrfState::generate();
        assert_ne!(state1, state2);
        assert!(state1.verify(state1.as_str()));
        assert!(!state1.verify(state2.as_str()));
        // Length mismatches (empty, or a valid prefix plus extra data) must be rejected.
        assert!(!state1.verify(""));
        assert!(!state1.verify(&format!("{}x", state1)));
    }
}
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    fn access_token_strategy() -> impl Strategy<Value = String> {
        prop::string::string_regex("[a-zA-Z0-9_.-]{20,100}").unwrap()
    }

    fn token_type_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            Just("Bearer".to_string()),
            Just("bearer".to_string()),
            Just("MAC".to_string()),
        ]
    }

    fn refresh_token_strategy() -> impl Strategy<Value = Option<String>> {
        prop_oneof![
            Just(None),
            prop::string::string_regex("[a-zA-Z0-9_.-]{20,100}")
                .unwrap()
                .prop_map(Some),
        ]
    }

    fn expires_in_strategy() -> impl Strategy<Value = Option<Duration>> {
        prop_oneof![
            Just(None),
            (300u64..86400).prop_map(|secs| Some(Duration::from_secs(secs))),
        ]
    }

    fn scopes_strategy() -> impl Strategy<Value = Option<Vec<String>>> {
        prop_oneof![
            Just(None),
            prop::collection::vec("[a-z]{3,10}", 0..5).prop_map(Some),
        ]
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_token_response_has_access_token(
            access_token in access_token_strategy(),
            token_type in token_type_strategy(),
        ) {
            let response = TokenResponse::new(access_token.clone(), token_type.clone());
            prop_assert_eq!(response.access_token(), access_token.as_str());
            prop_assert_eq!(response.token_type(), token_type.as_str());
        }

        #[test]
        fn prop_token_expiration_tracking(
            access_token in access_token_strategy(),
            token_type in token_type_strategy(),
            expires_in_secs in 1u64..3600,
        ) {
            let expires_in = Duration::from_secs(expires_in_secs);
            let response = TokenResponse::new(access_token, token_type)
                .with_expires_in(expires_in);
            prop_assert!(!response.is_expired());
            let remaining = response.expires_in();
            prop_assert!(remaining.is_some());
            let remaining_secs = remaining.unwrap().as_secs();
            prop_assert!(remaining_secs <= expires_in_secs);
            prop_assert!(remaining_secs >= expires_in_secs.saturating_sub(2));
        }

        #[test]
        fn prop_token_response_builder(
            access_token in access_token_strategy(),
            token_type in token_type_strategy(),
            refresh_token in refresh_token_strategy(),
            scopes in scopes_strategy(),
        ) {
            let mut response = TokenResponse::new(access_token.clone(), token_type.clone());
            if let Some(ref rt) = refresh_token {
                response = response.with_refresh_token(rt.clone());
            }
            if let Some(ref sc) = scopes {
                response = response.with_scopes(sc.clone());
            }
            prop_assert_eq!(response.access_token(), access_token.as_str());
            prop_assert_eq!(response.refresh_token(), refresh_token.as_deref());
            match (response.scopes(), scopes.as_ref()) {
                (Some(got), Some(expected)) => prop_assert_eq!(got, expected.as_slice()),
                (None, None) => {},
                _ => prop_assert!(false, "Scope mismatch"),
            }
        }

        #[test]
        fn prop_authorization_header_format(
            access_token in access_token_strategy(),
            token_type in token_type_strategy(),
        ) {
            let response = TokenResponse::new(access_token.clone(), token_type.clone());
            let header = response.authorization_header();
            let expected = format!("{} {}", token_type, access_token);
            prop_assert_eq!(header.clone(), expected);
            prop_assert!(header.starts_with(&token_type));
            prop_assert!(header.ends_with(&access_token));
        }

        #[test]
        fn prop_pkce_generates_unique_challenges(_seed in 0u32..100) {
            let pkce1 = PkceVerifier::generate();
            let pkce2 = PkceVerifier::generate();
            prop_assert_ne!(pkce1.verifier(), pkce2.verifier());
            prop_assert_ne!(pkce1.challenge(), pkce2.challenge());
            prop_assert_eq!(pkce1.method(), "S256");
            prop_assert_eq!(pkce2.method(), "S256");
        }

        #[test]
        fn prop_pkce_verifier_challenge_different(_seed in 0u32..100) {
            let pkce = PkceVerifier::generate();
            prop_assert_ne!(pkce.verifier(), pkce.challenge());
            prop_assert!(!pkce.verifier().is_empty());
            prop_assert!(!pkce.challenge().is_empty());
            prop_assert!(!pkce.verifier().contains('='));
            prop_assert!(!pkce.challenge().contains('='));
            // RFC 7636 §4.1: 43..=128 chars from the unreserved set; base64url of
            // 32 bytes without padding is exactly 43 chars from [A-Za-z0-9_-].
            prop_assert_eq!(pkce.verifier().len(), 43);
            prop_assert!(pkce
                .verifier()
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_'));
        }

        #[test]
        fn prop_pkce_challenge_deterministic(verifier_input in "[a-zA-Z0-9_-]{32,64}") {
            use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
            use sha2::{Digest, Sha256};
            // Independent reference implementation of the S256 transform; the
            // previous version only compared this reference with itself and never
            // exercised PkceVerifier.
            let mut hasher = Sha256::new();
            hasher.update(verifier_input.as_bytes());
            let expected_challenge = URL_SAFE_NO_PAD.encode(hasher.finalize());
            let first = PkceVerifier::new(verifier_input.clone());
            let second = PkceVerifier::new(verifier_input.clone());
            prop_assert_eq!(first.verifier(), verifier_input.as_str());
            prop_assert_eq!(first.challenge(), expected_challenge.as_str());
            prop_assert_eq!(first.challenge(), second.challenge());
        }

        #[test]
        fn prop_csrf_state_unique(_seed in 0u32..100) {
            let state1 = CsrfState::generate();
            let state2 = CsrfState::generate();
            prop_assert_ne!(state1.clone(), state2.clone());
            prop_assert_ne!(state1.as_str(), state2.as_str());
        }

        #[test]
        fn prop_csrf_state_verification(
            valid_state_str in "[a-zA-Z0-9_-]{10,50}",
            invalid_state_str in "[a-zA-Z0-9_-]{10,50}",
        ) {
            prop_assume!(valid_state_str != invalid_state_str);
            let state = CsrfState::new(valid_state_str.clone());
            prop_assert!(state.verify(&valid_state_str));
            prop_assert!(!state.verify(&invalid_state_str));
        }

        #[test]
        fn prop_csrf_state_roundtrip(state_str in "[a-zA-Z0-9_-]{10,50}") {
            let state1 = CsrfState::new(state_str.clone());
            let state2 = CsrfState::new(state1.as_str().to_string());
            prop_assert_eq!(state1.clone(), state2.clone());
            prop_assert_eq!(state1.as_str(), state2.as_str());
        }

        #[test]
        fn prop_token_expiration_behavior(
            access_token in access_token_strategy(),
            expires_in in expires_in_strategy(),
        ) {
            let mut response = TokenResponse::new(access_token, "Bearer".to_string());
            if let Some(duration) = expires_in {
                response = response.with_expires_in(duration);
                prop_assert!(!response.is_expired());
                prop_assert!(response.expires_in().is_some());
            } else {
                prop_assert!(!response.is_expired());
                prop_assert!(response.expires_in().is_none());
            }
        }
    }
}