use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Token-endpoint response carrying an access token and its optional companions.
///
/// Field names follow the OAuth 2.0 access-token response parameters
/// (RFC 6749 §5.1) plus the OpenID Connect `id_token`. Optional fields are
/// omitted from the serialized form when `None`.
///
/// NOTE(review): the derived `Debug` prints every token verbatim; avoid
/// logging this type with `{:?}` or give it a redacting `Debug` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Refresh token, when one was issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Access-token lifetime in seconds.
    ///
    /// RFC 6749 §5.1 only RECOMMENDS this parameter, so some providers omit
    /// it. Instead of failing to deserialize such a response, a missing value
    /// falls back to 3600 s (the same value `TokenResponse::default` uses).
    #[serde(default = "default_token_expires_in")]
    pub expires_in: u64,
    /// Token type string, e.g. `"Bearer"`.
    pub token_type: String,
    /// Granted scope string, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// OpenID Connect ID token, when issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_token: Option<String>,
}

/// Lifetime (seconds) assumed for a token response that omits `expires_in`.
fn default_token_expires_in() -> u64 {
    3600
}
impl Default for TokenResponse {
    /// Builds an empty `"Bearer"` token with a 3600-second lifetime and no
    /// refresh token, scope, or ID token.
    fn default() -> Self {
        let token_type = String::from("Bearer");
        Self {
            access_token: String::default(),
            token_type,
            expires_in: 3600,
            refresh_token: None,
            scope: None,
            id_token: None,
        }
    }
}
/// User profile information tagged with the identity `provider` it came from.
///
/// When deserializing, a missing `email_verified` defaults to `false` and a
/// missing `extra_claims` defaults to an empty map. When serializing, `None`
/// optionals and an empty `extra_claims` map are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Provider-side user identifier.
    pub id: String,
    /// User's email address.
    pub email: String,
    /// Display name, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Profile picture URL, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub picture: Option<String>,
    /// Identity provider identifier (the tests use values like `"google"`).
    pub provider: String,
    /// Whether the provider reported the email as verified.
    #[serde(default)]
    pub email_verified: bool,
    /// Additional claims not mapped to a dedicated field, keyed by claim name.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_claims: HashMap<String, serde_json::Value>,
}
impl UserInfo {
    /// Creates a profile from the three required identity fields.
    ///
    /// `name` and `picture` start as `None`, `email_verified` as `false`, and
    /// `extra_claims` empty; use the `with_*` builders to fill them in.
    pub fn new(
        id: impl Into<String>,
        email: impl Into<String>,
        provider: impl Into<String>,
    ) -> Self {
        let (id, email, provider) = (id.into(), email.into(), provider.into());
        Self {
            id,
            email,
            provider,
            name: None,
            picture: None,
            email_verified: false,
            extra_claims: HashMap::new(),
        }
    }

    /// Returns the profile with its display name set.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Returns the profile with its picture URL set.
    pub fn with_picture(self, picture: impl Into<String>) -> Self {
        Self {
            picture: Some(picture.into()),
            ..self
        }
    }

    /// Returns the profile with its email-verified flag set to `verified`.
    pub fn with_email_verified(self, verified: bool) -> Self {
        Self {
            email_verified: verified,
            ..self
        }
    }

    /// Returns the profile with `value` stored under `key` in `extra_claims`,
    /// replacing any previous value for that key.
    pub fn with_claim(self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let mut updated = self;
        updated.extra_claims.insert(key.into(), value);
        updated
    }
}
/// State recorded when an authorization flow starts for a given provider.
///
/// Presumably persisted until the callback arrives and then looked up by its
/// `state` value (`CallbackParams` carries a `state` field); the storage
/// mechanism is not visible in this file.
///
/// NOTE(review): `code_verifier` and `nonce` are serialized along with the
/// rest. If this struct is ever stored client-side (e.g. in a cookie), the
/// PKCE verifier would be exposed. Confirm it is kept server-side.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthState {
    /// Opaque state value (a UUID v4 string when built by the constructors).
    pub state: String,
    /// Identity provider identifier this flow targets.
    pub provider: String,
    /// PKCE code verifier; `Some` when created via `OAuthState::with_pkce`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_verifier: Option<String>,
    /// Redirect URI associated with this flow, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redirect_uri: Option<String>,
    /// Nonce value; `Some` (a UUID v4 string) when created via `with_pkce`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
    /// Creation timestamp; with `ttl_seconds`, it determines expiry.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Lifetime in seconds after `created_at` (600 from the constructors).
    pub ttl_seconds: u64,
    /// Arbitrary caller-supplied key/value data; omitted when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_data: HashMap<String, String>,
}
impl OAuthState {
pub fn new(provider: impl Into<String>) -> Self {
Self {
state: Uuid::new_v4().to_string(),
provider: provider.into(),
code_verifier: None,
redirect_uri: None,
nonce: None,
created_at: chrono::Utc::now(),
ttl_seconds: 600, extra_data: HashMap::new(),
}
}
pub fn with_pkce(provider: impl Into<String>) -> Self {
let code_verifier = generate_pkce_verifier();
Self {
state: Uuid::new_v4().to_string(),
provider: provider.into(),
code_verifier: Some(code_verifier),
redirect_uri: None,
nonce: Some(Uuid::new_v4().to_string()),
created_at: chrono::Utc::now(),
ttl_seconds: 600,
extra_data: HashMap::new(),
}
}
pub fn with_redirect_uri(mut self, uri: impl Into<String>) -> Self {
self.redirect_uri = Some(uri.into());
self
}
pub fn with_ttl(mut self, ttl_seconds: u64) -> Self {
self.ttl_seconds = ttl_seconds;
self
}
pub fn with_data(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.extra_data.insert(key.into(), value.into());
self
}
pub fn is_expired(&self) -> bool {
let now = chrono::Utc::now();
let expires_at = self.created_at + chrono::Duration::seconds(self.ttl_seconds as i64);
now > expires_at
}
pub fn code_challenge(&self) -> Option<String> {
self.code_verifier
.as_ref()
.map(|v| generate_pkce_challenge(v))
}
}
/// PKCE code-challenge transformation method (RFC 7636 §4.2).
///
/// Serializes to the exact, case-sensitive tokens the spec defines, `"plain"`
/// and `"S256"`, matching the `Display` output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceChallengeMethod {
    /// The challenge equals the verifier.
    Plain,
    /// The challenge is BASE64URL(SHA-256(verifier)); the default.
    // `rename_all = "snake_case"` alone would emit "s256", which authorization
    // servers reject; RFC 7636 requires "S256". The alias keeps previously
    // serialized "s256" values deserializable.
    #[default]
    #[serde(rename = "S256", alias = "s256")]
    S256,
}
impl std::fmt::Display for PkceChallengeMethod {
    /// Writes the RFC 7636 method token: `plain` or `S256`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let token = match self {
            Self::Plain => "plain",
            Self::S256 => "S256",
        };
        f.write_str(token)
    }
}
/// Error payload shaped like an OAuth 2.0 error response
/// (RFC 6749 §5.2: `error`, `error_description`, `error_uri`).
///
/// `Display` renders `error`, followed by `: {error_description}` when a
/// description is present; `error_uri` is not included in that text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthError {
    /// Error code string, e.g. `"invalid_grant"`.
    pub error: String,
    /// Human-readable description, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_description: Option<String>,
    /// URI of a page describing the error, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_uri: Option<String>,
}
impl std::fmt::Display for OAuthError {
    /// Formats as `error` or `error: description`; `error_uri` is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.error_description.as_deref() {
            Some(description) => write!(f, "{}: {}", self.error, description),
            None => f.write_str(&self.error),
        }
    }
}

/// Lets `OAuthError` be used as a standard error (e.g. boxed or with `?`).
impl std::error::Error for OAuthError {}
/// Parameters for an authorization-endpoint request.
///
/// Field names match RFC 6749 §4.1.1 (`response_type`, `client_id`,
/// `redirect_uri`, `scope`, `state`), RFC 7636 §4.3 (`code_challenge`,
/// `code_challenge_method`), and the OpenID Connect `nonce`, `prompt` and
/// `login_hint`. `None` optionals are omitted when serialized.
///
/// `code_challenge_method` is serialized through `PkceChallengeMethod`'s serde
/// representation. On the wire it must be exactly `"S256"` or `"plain"`
/// (RFC 7636 §4.3).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationRequest {
    /// Requested response type (e.g. `"code"`).
    pub response_type: String,
    /// Client identifier registered with the provider.
    pub client_id: String,
    /// Where the provider should redirect after authorization.
    pub redirect_uri: String,
    /// Requested scope string.
    pub scope: String,
    /// Opaque state value echoed back on the callback.
    pub state: String,
    /// PKCE code challenge, when PKCE is used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_challenge: Option<String>,
    /// PKCE transformation applied to produce `code_challenge`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_challenge_method: Option<PkceChallengeMethod>,
    /// OpenID Connect nonce, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,
    /// OpenID Connect `prompt` value, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// Login hint passed to the provider, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub login_hint: Option<String>,
}
/// Parameters for a token-endpoint request.
///
/// Field names match the OAuth 2.0 token-request parameters (RFC 6749 §4.1.3
/// for `authorization_code`, §6 for `refresh_token`) plus the PKCE
/// `code_verifier` (RFC 7636 §4.5). Which optionals must be set depends on
/// `grant_type`; nothing here enforces that. `None` optionals are omitted
/// when serialized.
///
/// NOTE(review): the derived `Debug` prints `client_secret`, `code_verifier`
/// and `refresh_token` verbatim; avoid logging this type with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenExchangeRequest {
    /// Grant type, e.g. `"authorization_code"` or `"refresh_token"`.
    pub grant_type: String,
    /// Authorization code received on the callback.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
    /// Redirect URI used in the authorization request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redirect_uri: Option<String>,
    /// Client identifier.
    pub client_id: String,
    /// Client secret for confidential clients.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// PKCE code verifier matching the earlier code challenge.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code_verifier: Option<String>,
    /// Refresh token for the refresh grant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Requested scope string, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
/// Query parameters received on the redirect-URI callback.
///
/// Per RFC 6749 §4.1.2, a success carries `code` and a failure carries `error`
/// (plus an optional `error_description`), and both echo `state`. Every field
/// is optional so either shape deserializes. Unknown parameters such as
/// `error_uri` are ignored, because the struct does not use
/// `deny_unknown_fields`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallbackParams {
    /// Authorization code on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
    /// Echoed state value.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub state: Option<String>,
    /// Error code on failure (e.g. `"access_denied"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Human-readable error description, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_description: Option<String>,
}
/// Logout request payload; every field is optional and omitted when `None`.
///
/// How the tokens are used (revocation, session lookup, ...) is decided by the
/// handler, which is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogoutRequest {
    /// Refresh token associated with the session, if supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Access token associated with the session, if supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub access_token: Option<String>,
    /// Where to send the user after logout, if supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub post_logout_redirect_uri: Option<String>,
}
/// Request payload for refreshing tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RefreshRequest {
    /// The refresh token to redeem.
    pub refresh_token: String,
    /// Optional requested scope string; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
/// Generates a random 64-character PKCE code verifier.
///
/// Each character is drawn uniformly from the RFC 7636 §4.1 "unreserved" set
/// (`A-Z a-z 0-9 - . _ ~`) using the thread-local RNG. The length falls within
/// the spec's 43..=128 range.
fn generate_pkce_verifier() -> String {
    use rand::Rng;

    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
    const VERIFIER_LENGTH: usize = 64;

    let mut source = rand::rng();
    std::iter::repeat_with(|| char::from(CHARSET[source.random_range(0..CHARSET.len())]))
        .take(VERIFIER_LENGTH)
        .collect()
}
/// Computes the PKCE `S256` code challenge for `verifier`.
///
/// Returns BASE64URL(SHA-256(verifier)) without `=` padding (RFC 7636 §4.2).
fn generate_pkce_challenge(verifier: &str) -> String {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    use sha2::{Digest, Sha256};

    let digest = {
        let mut sha = Sha256::new();
        sha.update(verifier.as_bytes());
        sha.finalize()
    };
    URL_SAFE_NO_PAD.encode(digest)
}
#[cfg(test)]
mod tests {
    // Unit tests for the OAuth data types and PKCE helpers in this module.
    use super::*;

    #[test]
    fn test_token_response_default() {
        let response = TokenResponse::default();
        assert!(response.access_token.is_empty());
        assert_eq!(response.token_type, "Bearer");
        assert_eq!(response.expires_in, 3600);
        assert!(response.refresh_token.is_none());
    }

    #[test]
    fn test_token_response_serialization() {
        let response = TokenResponse {
            access_token: "access123".to_string(),
            refresh_token: Some("refresh456".to_string()),
            expires_in: 7200,
            token_type: "Bearer".to_string(),
            scope: Some("openid email".to_string()),
            id_token: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("access123"));
        assert!(json.contains("refresh456"));
        assert!(json.contains("7200"));
        // `None` optionals are skipped, not serialized as null.
        assert!(!json.contains("id_token"));
        let parsed: TokenResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.access_token, "access123");
        assert_eq!(parsed.refresh_token, Some("refresh456".to_string()));
    }

    #[test]
    fn test_user_info_builder() {
        let user = UserInfo::new("123", "test@example.com", "google")
            .with_name("Test User")
            .with_picture("https://example.com/pic.jpg")
            .with_email_verified(true)
            .with_claim("locale", serde_json::json!("en-US"));
        assert_eq!(user.id, "123");
        assert_eq!(user.email, "test@example.com");
        assert_eq!(user.provider, "google");
        assert_eq!(user.name, Some("Test User".to_string()));
        assert_eq!(
            user.picture,
            Some("https://example.com/pic.jpg".to_string())
        );
        assert!(user.email_verified);
        assert!(user.extra_claims.contains_key("locale"));
    }

    #[test]
    fn test_user_info_serialization() {
        let user = UserInfo::new("456", "user@example.com", "microsoft");
        let json = serde_json::to_string(&user).unwrap();
        let parsed: UserInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "456");
        assert_eq!(parsed.email, "user@example.com");
        assert_eq!(parsed.provider, "microsoft");
    }

    #[test]
    fn test_user_info_deserialize_minimal() {
        let parsed: UserInfo =
            serde_json::from_str(r#"{"id":"1","email":"a@example.com","provider":"github"}"#)
                .unwrap();
        assert!(parsed.name.is_none());
        assert!(parsed.picture.is_none());
        assert!(!parsed.email_verified);
        assert!(parsed.extra_claims.is_empty());
    }

    #[test]
    fn test_oauth_state_creation() {
        let state = OAuthState::new("google");
        assert!(!state.state.is_empty());
        assert_eq!(state.provider, "google");
        assert!(state.code_verifier.is_none());
        assert_eq!(state.ttl_seconds, 600);
        assert!(!state.is_expired());
    }

    #[test]
    fn test_oauth_state_with_pkce() {
        let state = OAuthState::with_pkce("github");
        assert!(state.code_verifier.is_some());
        assert!(state.nonce.is_some());
        assert!(state.code_challenge().is_some());
        let verifier = state.code_verifier.clone().unwrap();
        assert_eq!(state.code_challenge(), Some(generate_pkce_challenge(&verifier)));
    }

    #[test]
    fn test_oauth_state_builder() {
        let state = OAuthState::new("okta")
            .with_redirect_uri("https://app.example.com/callback")
            .with_ttl(300)
            .with_data("flow", "login");
        assert_eq!(
            state.redirect_uri,
            Some("https://app.example.com/callback".to_string())
        );
        assert_eq!(state.ttl_seconds, 300);
        assert_eq!(state.extra_data.get("flow"), Some(&"login".to_string()));
    }

    #[test]
    fn test_oauth_state_expiration() {
        let mut state = OAuthState::new("test");
        state.created_at = chrono::Utc::now() - chrono::Duration::seconds(700);
        assert!(state.is_expired());
    }

    #[test]
    fn test_pkce_verifier_generation() {
        let verifier1 = generate_pkce_verifier();
        let verifier2 = generate_pkce_verifier();
        assert_eq!(verifier1.len(), 64);
        assert_eq!(verifier2.len(), 64);
        assert_ne!(verifier1, verifier2);
        // RFC 7636 §4.1: only unreserved characters are allowed.
        assert!(verifier1
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || "-._~".contains(c)));
    }

    #[test]
    fn test_pkce_challenge_generation() {
        let verifier = "test_verifier_string";
        let challenge = generate_pkce_challenge(verifier);
        assert!(!challenge.is_empty());
        assert!(!challenge.contains('+'));
        assert!(!challenge.contains('/'));
        assert!(!challenge.contains('='));
        // A 32-byte SHA-256 digest is 43 base64url characters without padding.
        assert_eq!(challenge.len(), 43);
        assert_eq!(challenge, generate_pkce_challenge(verifier));
    }

    #[test]
    fn test_pkce_challenge_method_display() {
        assert_eq!(PkceChallengeMethod::Plain.to_string(), "plain");
        assert_eq!(PkceChallengeMethod::S256.to_string(), "S256");
    }

    #[test]
    fn test_oauth_error_display() {
        let error = OAuthError {
            error: "invalid_grant".to_string(),
            error_description: Some("The authorization code has expired".to_string()),
            error_uri: None,
        };
        let display = format!("{}", error);
        assert!(display.contains("invalid_grant"));
        assert!(display.contains("authorization code has expired"));
        assert_eq!(display, "invalid_grant: The authorization code has expired");
    }

    #[test]
    fn test_oauth_error_display_without_description() {
        let error = OAuthError {
            error: "invalid_request".to_string(),
            error_description: None,
            error_uri: Some("https://example.com/errors".to_string()),
        };
        assert_eq!(error.to_string(), "invalid_request");
    }

    #[test]
    fn test_callback_params_serialization() {
        let params = CallbackParams {
            code: Some("auth_code_123".to_string()),
            state: Some("state_456".to_string()),
            error: None,
            error_description: None,
        };
        let json = serde_json::to_string(&params).unwrap();
        let parsed: CallbackParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.code, Some("auth_code_123".to_string()));
        assert_eq!(parsed.state, Some("state_456".to_string()));
    }

    #[test]
    fn test_callback_params_with_error() {
        let params = CallbackParams {
            code: None,
            state: Some("state_789".to_string()),
            error: Some("access_denied".to_string()),
            error_description: Some("User denied access".to_string()),
        };
        assert!(params.code.is_none());
        assert!(params.error.is_some());
        let json = serde_json::to_string(&params).unwrap();
        assert!(!json.contains("\"code\""));
        assert!(json.contains("access_denied"));
    }

    #[test]
    fn test_refresh_request() {
        let request = RefreshRequest {
            refresh_token: "refresh_token_abc".to_string(),
            scope: Some("openid email".to_string()),
        };
        let json = serde_json::to_string(&request).unwrap();
        let parsed: RefreshRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.refresh_token, "refresh_token_abc");
        assert_eq!(parsed.scope, Some("openid email".to_string()));
    }

    #[test]
    fn test_logout_request() {
        let request = LogoutRequest {
            refresh_token: Some("refresh_token".to_string()),
            access_token: Some("access_token".to_string()),
            post_logout_redirect_uri: Some("https://app.example.com".to_string()),
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("refresh_token"));
        assert!(json.contains("access_token"));
        assert!(json.contains("https://app.example.com"));
        let parsed: LogoutRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.refresh_token, Some("refresh_token".to_string()));
        assert_eq!(parsed.access_token, Some("access_token".to_string()));
        assert_eq!(
            parsed.post_logout_redirect_uri,
            Some("https://app.example.com".to_string())
        );
    }
}