use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
/// The kind of access token issued.
///
/// Serialized in lowercase (`"bearer"`, `"mac"`). RFC 6749 §5.1 states that the
/// `token_type` value is case insensitive, and many servers emit `"Bearer"`,
/// so the common capitalizations are also accepted when deserializing. This
/// includes the exact strings produced by this type's `Display` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    /// Bearer token; this is the default variant.
    #[default]
    #[serde(alias = "Bearer", alias = "BEARER")]
    Bearer,
    /// MAC token.
    #[serde(alias = "Mac", alias = "MAC")]
    Mac,
}
impl std::fmt::Display for TokenType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TokenType::Bearer => write!(f, "Bearer"),
TokenType::Mac => write!(f, "MAC"),
}
}
}
/// Successful access-token response; the field names mirror RFC 6749 §5.1.
///
/// Every optional member is omitted from the serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    /// The issued access token string.
    pub access_token: String,
    /// Kind of token issued (`TokenResponse::new` sets this to `Bearer`).
    pub token_type: TokenType,
    /// Token lifetime in seconds, if advertised.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_in: Option<u64>,
    /// Refresh token issued alongside the access token, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Granted scope string. NOTE(review): presumably space-delimited
    /// (RFC 6749 §3.3); this type does not enforce a format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
impl TokenResponse {
    /// Starts a `Bearer` response for `access_token`.
    ///
    /// No expiry, refresh token, or scope is set.
    pub fn new(access_token: impl Into<String>) -> Self {
        let access_token = access_token.into();
        Self {
            access_token,
            token_type: TokenType::Bearer,
            expires_in: None,
            refresh_token: None,
            scope: None,
        }
    }

    /// Sets `expires_in` to `seconds`.
    pub fn with_expires_in(self, seconds: u64) -> Self {
        Self {
            expires_in: Some(seconds),
            ..self
        }
    }

    /// Attaches a refresh token to the response.
    pub fn with_refresh_token(self, refresh_token: impl Into<String>) -> Self {
        Self {
            refresh_token: Some(refresh_token.into()),
            ..self
        }
    }

    /// Sets the granted scope string.
    pub fn with_scope(self, scope: impl Into<String>) -> Self {
        Self {
            scope: Some(scope.into()),
            ..self
        }
    }

    /// Replaces the token type (the default is `Bearer`).
    pub fn with_token_type(self, token_type: TokenType) -> Self {
        Self { token_type, ..self }
    }
}
/// An issued access token with its client, optional user, scopes and lifetime.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessToken {
    /// The token string.
    pub token: String,
    /// Identifier of the client the token was issued to.
    pub client_id: String,
    /// Associated user id, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
    /// Scopes granted to the token.
    pub scopes: Vec<String>,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Instant (UTC) from which the token is considered expired.
    pub expires_at: DateTime<Utc>,
    /// Whether the token has been revoked.
    ///
    /// Defaults to `false` when the field is absent from deserialized input.
    #[serde(default)]
    pub revoked: bool,
}
impl AccessToken {
pub fn new(
token: impl Into<String>,
client_id: impl Into<String>,
scopes: Vec<String>,
expires_in: Duration,
) -> Self {
let now = Utc::now();
Self {
token: token.into(),
client_id: client_id.into(),
user_id: None,
scopes,
created_at: now,
expires_at: now + expires_in,
revoked: false,
}
}
pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
self.user_id = Some(user_id.into());
self
}
pub fn is_valid(&self) -> bool {
!self.revoked && self.expires_at > Utc::now()
}
pub fn is_expired(&self) -> bool {
self.expires_at <= Utc::now()
}
pub fn has_scope(&self, scope: &str) -> bool {
self.scopes.iter().any(|s| s == scope)
}
pub fn has_all_scopes(&self, scopes: &[&str]) -> bool {
scopes.iter().all(|s| self.has_scope(s))
}
pub fn revoke(&mut self) {
self.revoked = true;
}
pub fn remaining_lifetime(&self) -> Option<i64> {
if self.is_valid() {
Some((self.expires_at - Utc::now()).num_seconds())
} else {
None
}
}
}
/// A refresh token issued to a client on behalf of a user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthRefreshToken {
    /// The token string.
    pub token: String,
    /// Identifier of the client the token was issued to.
    pub client_id: String,
    /// Identifier of the user the token was issued for.
    pub user_id: String,
    /// Scopes granted to the token.
    pub scopes: Vec<String>,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Expiry instant (UTC).
    ///
    /// `None` means the token never expires (see `is_valid`). Omitted from
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<DateTime<Utc>>,
    /// Whether the token has been revoked.
    ///
    /// Defaults to `false` when absent from deserialized input.
    #[serde(default)]
    pub revoked: bool,
    /// Optional id of a linked access token; omitted when `None`.
    /// NOTE(review): presumably the access token issued together with this
    /// refresh token — confirm with the code that sets it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub access_token_id: Option<String>,
}
impl OAuthRefreshToken {
    /// Creates a non-revoked refresh token created now.
    ///
    /// The token has no expiry and no linked access token.
    pub fn new(
        token: impl Into<String>,
        client_id: impl Into<String>,
        user_id: impl Into<String>,
        scopes: Vec<String>,
    ) -> Self {
        Self {
            token: token.into(),
            client_id: client_id.into(),
            user_id: user_id.into(),
            scopes,
            created_at: Utc::now(),
            expires_at: None,
            revoked: false,
            access_token_id: None,
        }
    }

    /// Sets an absolute expiry instant.
    pub fn with_expires_at(self, expires_at: DateTime<Utc>) -> Self {
        Self {
            expires_at: Some(expires_at),
            ..self
        }
    }

    /// Sets the expiry to `duration` from the current time.
    pub fn with_expires_in(self, duration: Duration) -> Self {
        let deadline = Utc::now() + duration;
        self.with_expires_at(deadline)
    }

    /// Records the id of a linked access token.
    pub fn with_access_token_id(self, access_token_id: impl Into<String>) -> Self {
        Self {
            access_token_id: Some(access_token_id.into()),
            ..self
        }
    }

    /// A token is valid when it is not revoked and, if it has an expiry, that expiry is still ahead.
    pub fn is_valid(&self) -> bool {
        match (self.revoked, self.expires_at) {
            (true, _) => false,
            (false, None) => true,
            (false, Some(deadline)) => deadline > Utc::now(),
        }
    }

    /// Marks the token as revoked; `is_valid` returns `false` afterwards.
    pub fn revoke(&mut self) {
        self.revoked = true;
    }
}
/// Token introspection response; member names follow RFC 7662 §2.2.
///
/// Every member except `active` is omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntrospectionResponse {
    /// Whether the introspected token is currently active.
    pub active: bool,
    /// Type of the token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_type: Option<TokenType>,
    /// Granted scopes as one string (`from_access_token` joins them with spaces).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Client the token was issued to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,
    /// Username associated with the token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,
    /// Expiry time as a Unix timestamp in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<i64>,
    /// Issued-at time as a Unix timestamp in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<i64>,
    /// Not-before time as a Unix timestamp in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nbf: Option<i64>,
    /// Subject of the token (`from_access_token` fills it from the token's `user_id`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub: Option<String>,
    /// Intended audience.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aud: Option<String>,
    /// Issuer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iss: Option<String>,
    /// Token identifier.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
}
impl IntrospectionResponse {
    /// Response for a token that is not active.
    ///
    /// Only `active: false` is set; every other member is `None` and is
    /// therefore omitted when serialized.
    pub fn inactive() -> Self {
        Self {
            active: false,
            token_type: None,
            scope: None,
            client_id: None,
            username: None,
            exp: None,
            iat: None,
            nbf: None,
            sub: None,
            aud: None,
            iss: None,
            jti: None,
        }
    }

    /// Builds an introspection response describing `token`.
    ///
    /// A revoked or expired token (`!token.is_valid()`) yields
    /// [`IntrospectionResponse::inactive`]. RFC 7662 §2.2 says the server
    /// SHOULD NOT disclose any additional information about an inactive token.
    ///
    /// For a valid token the response has the following members:
    /// - `active: true` and a `Bearer` token type.
    /// - `scope`: the space-joined scopes, omitted when there are none.
    /// - `client_id`.
    /// - `sub`: taken from the token's `user_id`.
    /// - `exp`/`iat`/`nbf`: Unix timestamps in seconds, where `nbf` equals the creation time.
    pub fn from_access_token(token: &AccessToken) -> Self {
        if !token.is_valid() {
            return Self::inactive();
        }
        let issued_at = token.created_at.timestamp();
        Self {
            active: true,
            token_type: Some(TokenType::Bearer),
            scope: (!token.scopes.is_empty()).then(|| token.scopes.join(" ")),
            client_id: Some(token.client_id.clone()),
            username: None,
            exp: Some(token.expires_at.timestamp()),
            iat: Some(issued_at),
            nbf: Some(issued_at),
            sub: token.user_id.clone(),
            aud: None,
            iss: None,
            jti: None,
        }
    }
}
/// OAuth 2.0 error response body; member names follow RFC 6749 §5.2.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthError {
    /// The error code.
    pub error: OAuthErrorCode,
    /// Optional human-readable description; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_description: Option<String>,
    /// Optional URI with more information about the error; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_uri: Option<String>,
}
/// OAuth 2.0 error codes.
///
/// The variants are the union of the RFC 6749 authorization-endpoint codes
/// (§4.1.2.1) and token-endpoint codes (§5.2). Each variant is serialized in
/// snake_case, e.g. `InvalidRequest` <-> `"invalid_request"`.
// Keep the variant order stable: non-self-describing serde formats encode the
// variant index rather than its name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OAuthErrorCode {
    InvalidRequest,
    InvalidClient,
    InvalidGrant,
    UnauthorizedClient,
    UnsupportedGrantType,
    InvalidScope,
    ServerError,
    TemporarilyUnavailable,
    AccessDenied,
    UnsupportedResponseType,
}
impl OAuthError {
    /// Creates an error with only its code set.
    pub fn new(error: OAuthErrorCode) -> Self {
        Self {
            error,
            error_description: None,
            error_uri: None,
        }
    }

    /// Attaches a human-readable description.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            error_description: Some(description.into()),
            ..self
        }
    }

    /// Attaches a URI pointing at more information about the error.
    pub fn with_uri(self, uri: impl Into<String>) -> Self {
        Self {
            error_uri: Some(uri.into()),
            ..self
        }
    }

    /// `invalid_request` with the given description.
    pub fn invalid_request(description: impl Into<String>) -> Self {
        Self::described(OAuthErrorCode::InvalidRequest, description)
    }

    /// `invalid_client` with the given description.
    pub fn invalid_client(description: impl Into<String>) -> Self {
        Self::described(OAuthErrorCode::InvalidClient, description)
    }

    /// `invalid_grant` with the given description.
    pub fn invalid_grant(description: impl Into<String>) -> Self {
        Self::described(OAuthErrorCode::InvalidGrant, description)
    }

    /// `invalid_scope` with the given description.
    pub fn invalid_scope(description: impl Into<String>) -> Self {
        Self::described(OAuthErrorCode::InvalidScope, description)
    }

    /// `unsupported_grant_type` with a fixed, generic description.
    pub fn unsupported_grant_type() -> Self {
        Self::described(
            OAuthErrorCode::UnsupportedGrantType,
            "The authorization grant type is not supported",
        )
    }

    /// Shared shape of the convenience constructors: a code plus a description.
    fn described(code: OAuthErrorCode, description: impl Into<String>) -> Self {
        Self::new(code).with_description(description)
    }
}
impl std::fmt::Display for OAuthError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.error)?;
if let Some(desc) = &self.error_description {
write!(f, ": {}", desc)?;
}
Ok(())
}
}
impl std::error::Error for OAuthError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_response_builder() {
        let response = TokenResponse::new("access_token_123")
            .with_expires_in(3600)
            .with_refresh_token("refresh_token_456")
            .with_scope("read write");
        assert_eq!(response.access_token, "access_token_123");
        assert_eq!(response.token_type, TokenType::Bearer);
        assert_eq!(response.expires_in, Some(3600));
        assert_eq!(
            response.refresh_token,
            Some("refresh_token_456".to_string())
        );
        assert_eq!(response.scope, Some("read write".to_string()));
    }

    #[test]
    fn test_access_token_validity() {
        let token = AccessToken::new(
            "test_token",
            "client_123",
            vec!["read".to_string()],
            Duration::hours(1),
        );
        assert!(token.is_valid());
        assert!(!token.is_expired());
        assert!(token.has_scope("read"));
        assert!(!token.has_scope("write"));
    }

    #[test]
    fn test_access_token_expired() {
        let token = AccessToken::new("test_token", "client_123", vec![], Duration::seconds(-10));
        assert!(!token.is_valid());
        assert!(token.is_expired());
    }

    #[test]
    fn test_access_token_revoked() {
        let mut token = AccessToken::new("test_token", "client_123", vec![], Duration::hours(1));
        assert!(token.is_valid());
        token.revoke();
        assert!(!token.is_valid());
    }

    #[test]
    fn test_access_token_has_all_scopes() {
        let token = AccessToken::new(
            "test_token",
            "client_123",
            vec!["read".to_string(), "write".to_string()],
            Duration::hours(1),
        );
        assert!(token.has_all_scopes(&["read", "write"]));
        assert!(!token.has_all_scopes(&["read", "admin"]));
        // Vacuously true for an empty requirement list.
        assert!(token.has_all_scopes(&[]));
    }

    #[test]
    fn test_access_token_remaining_lifetime() {
        let mut token = AccessToken::new("test_token", "client_123", vec![], Duration::hours(1));
        let secs = token
            .remaining_lifetime()
            .expect("a freshly issued one-hour token should be valid");
        assert!((3590..=3600).contains(&secs));
        token.revoke();
        assert_eq!(token.remaining_lifetime(), None);

        let expired = AccessToken::new("test_token", "client_123", vec![], Duration::seconds(-10));
        assert_eq!(expired.remaining_lifetime(), None);
    }

    #[test]
    fn test_introspection_inactive() {
        let response = IntrospectionResponse::inactive();
        assert!(!response.active);
    }

    #[test]
    fn test_introspection_inactive_serialization() {
        let json = serde_json::to_string(&IntrospectionResponse::inactive()).unwrap();
        assert_eq!(json, r#"{"active":false}"#);
    }

    #[test]
    fn test_introspection_from_token() {
        let token = AccessToken::new(
            "test_token",
            "client_123",
            vec!["read".to_string(), "write".to_string()],
            Duration::hours(1),
        )
        .with_user_id("user_456");
        let response = IntrospectionResponse::from_access_token(&token);
        assert!(response.active);
        assert_eq!(response.client_id, Some("client_123".to_string()));
        assert_eq!(response.sub, Some("user_456".to_string()));
        assert_eq!(response.scope, Some("read write".to_string()));
    }

    #[test]
    fn test_introspection_of_expired_token_is_inactive() {
        let token = AccessToken::new("test_token", "client_123", vec![], Duration::seconds(-10));
        let response = IntrospectionResponse::from_access_token(&token);
        assert!(!response.active);
    }

    #[test]
    fn test_oauth_error() {
        let error = OAuthError::invalid_request("Missing required parameter: client_id");
        assert_eq!(error.error, OAuthErrorCode::InvalidRequest);
        assert!(error.error_description.is_some());
    }

    #[test]
    fn test_oauth_error_serialization() {
        let error = OAuthError::unsupported_grant_type();
        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains(r#""error":"unsupported_grant_type""#));
        assert!(!json.contains("error_uri"));

        let with_uri = OAuthError::invalid_client("bad secret").with_uri("https://example.com/e");
        assert_eq!(with_uri.error_uri.as_deref(), Some("https://example.com/e"));
    }

    #[test]
    fn test_refresh_token() {
        let token = OAuthRefreshToken::new(
            "refresh_123",
            "client_123",
            "user_456",
            vec!["read".to_string()],
        )
        .with_expires_in(Duration::days(30));
        assert!(token.is_valid());
        assert!(token.expires_at.is_some());
    }

    #[test]
    fn test_refresh_token_without_expiry_and_revocation() {
        let mut token = OAuthRefreshToken::new("refresh_123", "client_123", "user_456", vec![]);
        assert!(token.expires_at.is_none());
        assert!(token.is_valid());
        token.revoke();
        assert!(!token.is_valid());
    }

    #[test]
    fn test_refresh_token_expired() {
        let token = OAuthRefreshToken::new("refresh_123", "client_123", "user_456", vec![])
            .with_expires_in(Duration::seconds(-10));
        assert!(!token.is_valid());
    }

    #[test]
    fn test_token_type_display() {
        assert_eq!(TokenType::Bearer.to_string(), "Bearer");
        assert_eq!(TokenType::Mac.to_string(), "MAC");
    }

    #[test]
    fn test_token_type_serde_representation() {
        assert_eq!(TokenType::default(), TokenType::Bearer);
        assert_eq!(serde_json::to_string(&TokenType::Bearer).unwrap(), "\"bearer\"");
        let parsed: TokenType = serde_json::from_str("\"mac\"").unwrap();
        assert_eq!(parsed, TokenType::Mac);
    }

    #[test]
    fn test_token_response_serialization() {
        let response = TokenResponse::new("test_token")
            .with_expires_in(3600)
            .with_scope("read");
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("access_token"));
        assert!(json.contains("token_type"));
        let deserialized: TokenResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.access_token, "test_token");
    }
}