use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Utc};
use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
use rand::TryRngCore;
use rand::rngs::OsRng;
use serde::Serialize;
use sha2::{Digest, Sha256};
use crate::applications::Application;
#[cfg(test)]
use crate::applications::CreateApplicationParams;
use crate::authorization::hash_authorization_code;
use crate::db::Db;
use crate::error::AuthError;
use crate::handle::OnUserActive;
use crate::signing_keys::SigningKey;
#[cfg(test)]
use crate::types::ClientType;
use crate::types::{ApplicationId, AuthorizationCodeId, RefreshTokenId, TokenHash, UserId};
/// Failure outcomes of the token-endpoint grant handlers in this module.
///
/// Variant names mirror the OAuth 2.0 token-endpoint error codes (RFC 6749 §5.2).
/// NOTE(review): the mapping to an HTTP status and JSON error body is not in this
/// file — confirm it at the call site.
#[derive(Debug)]
pub enum TokenError {
/// Malformed request. Not constructed in the visible part of this file;
/// presumably produced by the HTTP layer.
InvalidRequest(String),
/// Client authentication failed. Not constructed in the visible part of this
/// file; presumably produced by the HTTP layer.
InvalidClient(String),
/// The presented authorization code or refresh token is unknown, already used,
/// revoked, expired, bound to another client or redirect URI, failed PKCE, or
/// requested scopes beyond the original grant. The string is a human-readable
/// reason.
InvalidGrant(String),
/// The requested `grant_type` is not supported (not constructed in this file).
UnsupportedGrantType,
/// Internal failure (database, JSON decoding, or signing); the string carries
/// the underlying error's `to_string()` text.
ServerError(String),
}
/// Successful token-endpoint response body, serialized with serde.
///
/// Built by [`exchange_authorization_code`] and [`exchange_refresh_token`].
#[derive(Debug, Serialize)]
pub struct TokenResponse {
/// Signed RS256 access-token JWT (header `typ: at+jwt`).
pub access_token: String,
/// Always `"Bearer"` in this module.
pub token_type: &'static str,
/// Access-token lifetime in seconds; the exchange functions here set 3600,
/// the same TTL they pass to the minting functions.
pub expires_in: i64,
/// Raw (unhashed) refresh token. Only its SHA-256 hash is persisted.
pub refresh_token: String,
/// Signed RS256 OpenID Connect ID token (header `typ: JWT`).
pub id_token: String,
}
/// A persisted row of `allowthem_refresh_tokens`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct RefreshToken {
pub id: RefreshTokenId,
/// Application (OAuth client) the token was issued to.
pub application_id: ApplicationId,
pub user_id: UserId,
/// Lowercase-hex SHA-256 of the raw token (see [`hash_refresh_token`]).
pub token_hash: TokenHash,
/// Granted scopes as a JSON array string, e.g. `["openid","profile"]`.
pub scopes: String,
/// Authorization code the token family originated from. Rotated tokens inherit
/// it, so every token descending from one code shares this id.
pub authorization_code_id: Option<AuthorizationCodeId>,
/// Set to 30 days after creation by [`Db::create_refresh_token`].
pub expires_at: DateTime<Utc>,
/// `None` while usable; set once the token is rotated or revoked.
pub revoked_at: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
}
/// Claims serialized into the access-token JWT payload by [`mint_access_token`].
/// Field names are the JSON claim names.
#[derive(Debug, Serialize)]
struct AccessTokenJwtClaims {
/// Subject: the user id rendered with `to_string()`.
sub: String,
/// Issuer URL.
iss: String,
/// Audience: callers in this module pass the application's client id.
aud: String,
/// Expiry as a Unix timestamp in seconds (`iat + ttl_secs`).
exp: i64,
/// Issue time as a Unix timestamp in seconds.
iat: i64,
/// Space-separated scope list.
scope: String,
email: String,
email_verified: bool,
/// Omitted from the JSON entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
username: Option<String>,
/// Names of the user's roles at mint time.
roles: Vec<String>,
/// Names of the user's permissions at mint time.
permissions: Vec<String>,
}
/// Claims serialized into the OpenID Connect ID token payload by [`mint_id_token`].
/// Field names are the JSON claim names.
#[derive(Debug, Serialize)]
struct IdTokenJwtClaims {
/// Subject: the user id rendered with `to_string()`.
sub: String,
/// Issuer URL.
iss: String,
/// Audience: callers in this module pass the application's client id.
aud: String,
/// Expiry as a Unix timestamp in seconds (`iat + ttl_secs`).
exp: i64,
/// Issue time as a Unix timestamp in seconds.
iat: i64,
/// Echo of the authorization request's nonce; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
nonce: Option<String>,
/// Access-token hash (see [`compute_at_hash`]); omitted when `None`, although
/// `mint_id_token` always sets it.
#[serde(skip_serializing_if = "Option::is_none")]
at_hash: Option<String>,
/// Authentication time as a Unix timestamp in seconds, as supplied by the caller.
auth_time: i64,
}
/// Verifies a PKCE `S256` code verifier against the stored code challenge
/// (RFC 7636 §4.6).
///
/// Computes `BASE64URL-NOPAD(SHA-256(code_verifier))` and returns `true` only if
/// it equals `code_challenge` exactly.
///
/// The byte comparison folds over every position instead of stopping at the
/// first mismatch (as `==` on strings does), so the time taken does not reveal
/// how long a matching prefix was. Only the length is compared up front; a
/// well-formed S256 challenge is always 43 characters, so length is not secret.
pub fn verify_pkce_s256(code_verifier: &str, code_challenge: &str) -> bool {
    let digest = Sha256::digest(code_verifier.as_bytes());
    let computed = Base64UrlUnpadded::encode_string(&digest);
    let (expected, presented) = (computed.as_bytes(), code_challenge.as_bytes());
    expected.len() == presented.len()
        && expected
            .iter()
            .zip(presented)
            .fold(0u8, |diff, (a, b)| diff | (a ^ b))
            == 0
}
/// Computes the OpenID Connect `at_hash` claim for an access token.
///
/// Hashes the token's bytes with SHA-256 (the hash matching the RS256 signing
/// used in this module), keeps the left-most half of the digest (16 bytes), and
/// encodes it as unpadded base64url (22 characters).
pub fn compute_at_hash(access_token_jwt: &str) -> String {
    let full_digest = Sha256::digest(access_token_jwt.as_bytes());
    let (left_half, _) = full_digest.split_at(full_digest.len() / 2);
    Base64UrlUnpadded::encode_string(left_half)
}
/// Signs an RS256 access-token JWT for `sub`.
///
/// The header carries `typ: at+jwt` and `kid`. The payload carries `iss`, `aud`,
/// `iat` (now), `exp` (`iat + ttl_secs`), the space-separated `scope`, the user's
/// email data, and the given role and permission names; `username` is omitted
/// when `None`.
///
/// # Errors
/// Returns [`AuthError::SigningKey`] if `private_key_pem` is not a usable RSA
/// private key or if encoding/signing fails.
#[allow(clippy::too_many_arguments)]
pub fn mint_access_token(
    sub: UserId,
    issuer: &str,
    audience: &str,
    scope: &str,
    kid: &str,
    private_key_pem: &str,
    ttl_secs: i64,
    email: &str,
    email_verified: bool,
    username: Option<&str>,
    roles: &[String],
    permissions: &[String],
) -> Result<String, AuthError> {
    let issued_at = Utc::now().timestamp();
    let claims = AccessTokenJwtClaims {
        sub: sub.to_string(),
        iss: String::from(issuer),
        aud: String::from(audience),
        exp: issued_at + ttl_secs,
        iat: issued_at,
        scope: String::from(scope),
        email: String::from(email),
        email_verified,
        username: username.map(String::from),
        roles: roles.to_owned(),
        permissions: permissions.to_owned(),
    };

    let mut header = Header::new(Algorithm::RS256);
    header.typ = Some("at+jwt".into());
    header.kid = Some(kid.to_owned());

    // Key parsing and signing share one error type, so a single mapping covers both.
    EncodingKey::from_rsa_pem(private_key_pem.as_bytes())
        .and_then(|signing_key| encode(&header, &claims, &signing_key))
        .map_err(|e| AuthError::SigningKey(e.to_string()))
}
/// Signs an RS256 OpenID Connect ID token for `sub`.
///
/// The header carries `typ: JWT` and `kid`. The payload carries `iss`, `aud`,
/// `iat` (now), `exp` (`iat + ttl_secs`), `auth_time` as given, `at_hash`, and
/// `nonce` only when one is supplied.
///
/// # Errors
/// Returns [`AuthError::SigningKey`] if `private_key_pem` is not a usable RSA
/// private key or if encoding/signing fails.
#[allow(clippy::too_many_arguments)]
pub fn mint_id_token(
    sub: UserId,
    issuer: &str,
    audience: &str,
    nonce: Option<&str>,
    at_hash: &str,
    auth_time: i64,
    kid: &str,
    private_key_pem: &str,
    ttl_secs: i64,
) -> Result<String, AuthError> {
    let issued_at = Utc::now().timestamp();
    let claims = IdTokenJwtClaims {
        sub: sub.to_string(),
        iss: String::from(issuer),
        aud: String::from(audience),
        exp: issued_at + ttl_secs,
        iat: issued_at,
        nonce: nonce.map(String::from),
        at_hash: Some(String::from(at_hash)),
        auth_time,
    };

    let mut header = Header::new(Algorithm::RS256);
    header.typ = Some("JWT".into());
    header.kid = Some(kid.to_owned());

    // Key parsing and signing share one error type, so a single mapping covers both.
    EncodingKey::from_rsa_pem(private_key_pem.as_bytes())
        .and_then(|signing_key| encode(&header, &claims, &signing_key))
        .map_err(|e| AuthError::SigningKey(e.to_string()))
}
/// Generates a new raw refresh token: 32 bytes from the OS RNG, encoded as
/// unpadded base64url (43 URL-safe characters).
///
/// # Panics
/// Panics with `"OS RNG unavailable"` if the operating-system RNG cannot be read.
pub fn generate_refresh_token() -> String {
    const TOKEN_BYTES: usize = 32;
    let mut entropy = [0u8; TOKEN_BYTES];
    OsRng.try_fill_bytes(&mut entropy).expect("OS RNG unavailable");
    Base64UrlUnpadded::encode_string(&entropy)
}
/// Hashes a raw refresh token for storage and lookup: SHA-256 of its bytes,
/// rendered as lowercase hex (64 characters).
pub fn hash_refresh_token(raw: &str) -> TokenHash {
    let hex = format!("{:x}", Sha256::digest(raw.as_bytes()));
    TokenHash::new_unchecked(hex)
}
impl Db {
    /// Inserts a refresh token row and returns it as read back from the database.
    ///
    /// The token expires 30 days from now. `scopes` is stored as a JSON array
    /// string. `token_hash` is stored as given; callers in this module pass the
    /// output of [`hash_refresh_token`], so the raw token never reaches the table.
    /// `authorization_code_id` links the token to the code its family came from.
    ///
    /// # Errors
    /// Propagates database failures from the insert or the read-back as [`AuthError`].
    pub async fn create_refresh_token(
        &self,
        application_id: ApplicationId,
        user_id: UserId,
        token_hash: &TokenHash,
        scopes: &[String],
        authorization_code_id: Option<AuthorizationCodeId>,
    ) -> Result<RefreshToken, AuthError> {
        let id = RefreshTokenId::new();
        let scopes_json = serde_json::to_string(scopes).expect("Vec<String> serializes to JSON");
        let now = Utc::now();
        let now_str = now.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
        let expires_at = now + chrono::Duration::days(30);
        let expires_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
        sqlx::query(
            "INSERT INTO allowthem_refresh_tokens \
             (id, application_id, user_id, token_hash, scopes, \
             authorization_code_id, expires_at, created_at) \
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
        )
        .bind(id)
        .bind(application_id)
        .bind(user_id)
        .bind(token_hash)
        .bind(&scopes_json)
        .bind(authorization_code_id)
        .bind(&expires_str)
        .bind(&now_str)
        .execute(self.pool())
        .await?;
        sqlx::query_as::<_, RefreshToken>(
            "SELECT id, application_id, user_id, token_hash, scopes, \
             authorization_code_id, expires_at, revoked_at, created_at \
             FROM allowthem_refresh_tokens WHERE id = ?",
        )
        .bind(id)
        .fetch_one(self.pool())
        .await
        .map_err(AuthError::Database)
    }

    /// Revokes every not-yet-revoked refresh token linked to
    /// `authorization_code_id` and returns how many rows were updated.
    ///
    /// Tokens that are already revoked keep their original `revoked_at`.
    pub async fn revoke_refresh_tokens_by_auth_code(
        &self,
        authorization_code_id: AuthorizationCodeId,
    ) -> Result<u64, AuthError> {
        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
        let result = sqlx::query(
            "UPDATE allowthem_refresh_tokens \
             SET revoked_at = ? \
             WHERE authorization_code_id = ? AND revoked_at IS NULL",
        )
        .bind(&now)
        .bind(authorization_code_id)
        .execute(self.pool())
        .await?;
        Ok(result.rows_affected())
    }

    /// Looks up a refresh token by the hash of its raw value.
    ///
    /// Returns the row regardless of revocation or expiry; callers must check
    /// `revoked_at` and `expires_at` themselves. Returns `Ok(None)` if no row matches.
    pub async fn get_refresh_token_by_hash(
        &self,
        token_hash: &TokenHash,
    ) -> Result<Option<RefreshToken>, AuthError> {
        sqlx::query_as::<_, RefreshToken>(
            "SELECT id, application_id, user_id, token_hash, scopes, \
             authorization_code_id, expires_at, revoked_at, created_at \
             FROM allowthem_refresh_tokens WHERE token_hash = ?",
        )
        .bind(token_hash)
        .fetch_optional(self.pool())
        .await
        .map_err(AuthError::Database)
    }

    /// Marks a single refresh token as revoked.
    ///
    /// The update is guarded by `revoked_at IS NULL`, so an already-revoked token
    /// keeps its original revocation timestamp. Calling this repeatedly, or with
    /// an unknown id, succeeds without changing anything.
    pub async fn revoke_refresh_token(&self, id: RefreshTokenId) -> Result<(), AuthError> {
        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
        sqlx::query(
            "UPDATE allowthem_refresh_tokens SET revoked_at = ? \
             WHERE id = ? AND revoked_at IS NULL",
        )
        .bind(&now)
        .bind(id)
        .execute(self.pool())
        .await?;
        Ok(())
    }
}
#[allow(clippy::too_many_arguments)]
pub async fn exchange_authorization_code(
db: &Db,
code: &str,
redirect_uri: &str,
code_verifier: &str,
application: &Application,
issuer: &str,
signing_key: &SigningKey,
private_key_pem: &str,
on_user_active: Option<&OnUserActive>,
) -> Result<TokenResponse, TokenError> {
let code_hash = hash_authorization_code(code);
let auth_code = db
.get_authorization_code_by_hash(&code_hash)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?
.ok_or_else(|| TokenError::InvalidGrant("invalid authorization code".into()))?;
if auth_code.used_at.is_some() {
let _ = db.revoke_refresh_tokens_by_auth_code(auth_code.id).await;
return Err(TokenError::InvalidGrant(
"authorization code already used".into(),
));
}
db.mark_authorization_code_used(auth_code.id)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
if auth_code.expires_at < Utc::now() {
return Err(TokenError::InvalidGrant(
"authorization code expired".into(),
));
}
if auth_code.application_id != application.id {
return Err(TokenError::InvalidGrant(
"code was issued to a different client".into(),
));
}
if auth_code.redirect_uri != redirect_uri {
return Err(TokenError::InvalidGrant("redirect_uri mismatch".into()));
}
if !verify_pkce_s256(code_verifier, &auth_code.code_challenge) {
return Err(TokenError::InvalidGrant("PKCE verification failed".into()));
}
let scopes: Vec<String> = serde_json::from_str(&auth_code.scopes)
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let scopes_str = scopes.join(" ");
let user = db
.get_user(auth_code.user_id)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let user_roles = db
.get_user_roles(&auth_code.user_id)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let user_perms = db
.get_user_permissions(&auth_code.user_id)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let role_names: Vec<String> = user_roles
.iter()
.map(|r| r.name.as_str().to_owned())
.collect();
let perm_names: Vec<String> = user_perms
.iter()
.map(|p| p.name.as_str().to_owned())
.collect();
let kid = signing_key.id.to_string();
let access_token = mint_access_token(
auth_code.user_id,
issuer,
application.client_id.as_str(),
&scopes_str,
&kid,
private_key_pem,
3600,
user.email.as_str(),
user.email_verified,
user.username.as_ref().map(|u| u.as_str()),
&role_names,
&perm_names,
)
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let at_hash = compute_at_hash(&access_token);
let auth_time = auth_code.created_at.timestamp();
let id_token = mint_id_token(
auth_code.user_id,
issuer,
application.client_id.as_str(),
auth_code.nonce.as_deref(),
&at_hash,
auth_time,
&kid,
private_key_pem,
3600,
)
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let raw_refresh = generate_refresh_token();
let refresh_hash = hash_refresh_token(&raw_refresh);
db.create_refresh_token(
application.id,
auth_code.user_id,
&refresh_hash,
&scopes,
Some(auth_code.id),
)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
if let Some(cb) = on_user_active {
let now = Utc::now();
let user_id = auth_code.user_id;
let cb = cb.clone();
if let Err(_payload) =
std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || cb(user_id, now)))
{
tracing::error!(user_id = %user_id, "on_user_active callback panicked");
}
}
Ok(TokenResponse {
access_token,
token_type: "Bearer",
expires_in: 3600,
refresh_token: raw_refresh,
id_token,
})
}
pub async fn exchange_refresh_token(
db: &Db,
raw_token: &str,
requested_scopes: Option<&str>,
application: &Application,
issuer: &str,
signing_key: &SigningKey,
private_key_pem: &str,
) -> Result<TokenResponse, TokenError> {
let hash = hash_refresh_token(raw_token);
let stored = db
.get_refresh_token_by_hash(&hash)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?
.ok_or_else(|| TokenError::InvalidGrant("invalid refresh token".into()))?;
if stored.revoked_at.is_some() {
return Err(TokenError::InvalidGrant(
"refresh token has been revoked".into(),
));
}
if stored.expires_at < Utc::now() {
return Err(TokenError::InvalidGrant("refresh token has expired".into()));
}
if stored.application_id != application.id {
return Err(TokenError::InvalidGrant(
"refresh token was issued to a different client".into(),
));
}
let original_scopes: Vec<String> =
serde_json::from_str(&stored.scopes).map_err(|e| TokenError::ServerError(e.to_string()))?;
let effective_scopes = match requested_scopes {
Some(s) if !s.is_empty() => {
let requested: Vec<&str> = s.split_whitespace().collect();
for scope in &requested {
if !original_scopes.iter().any(|orig| orig == scope) {
return Err(TokenError::InvalidGrant(
"requested scope exceeds original grant".into(),
));
}
}
requested.iter().map(|s| s.to_string()).collect::<Vec<_>>()
}
_ => original_scopes.clone(),
};
let scopes_str = effective_scopes.join(" ");
db.revoke_refresh_token(stored.id)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let user = db
.get_user(stored.user_id)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let user_roles = db
.get_user_roles(&stored.user_id)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let user_perms = db
.get_user_permissions(&stored.user_id)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let role_names: Vec<String> = user_roles
.iter()
.map(|r| r.name.as_str().to_owned())
.collect();
let perm_names: Vec<String> = user_perms
.iter()
.map(|p| p.name.as_str().to_owned())
.collect();
let kid = signing_key.id.to_string();
let access_token = mint_access_token(
stored.user_id,
issuer,
application.client_id.as_str(),
&scopes_str,
&kid,
private_key_pem,
3600,
user.email.as_str(),
user.email_verified,
user.username.as_ref().map(|u| u.as_str()),
&role_names,
&perm_names,
)
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let at_hash = compute_at_hash(&access_token);
let auth_time = stored.created_at.timestamp();
let id_token = mint_id_token(
stored.user_id,
issuer,
application.client_id.as_str(),
None,
&at_hash,
auth_time,
&kid,
private_key_pem,
3600,
)
.map_err(|e| TokenError::ServerError(e.to_string()))?;
let new_raw = generate_refresh_token();
let new_hash = hash_refresh_token(&new_raw);
db.create_refresh_token(
application.id,
stored.user_id,
&new_hash,
&effective_scopes,
stored.authorization_code_id,
)
.await
.map_err(|e| TokenError::ServerError(e.to_string()))?;
Ok(TokenResponse {
access_token,
token_type: "Bearer",
expires_in: 3600,
refresh_token: new_raw,
id_token,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::signing_keys::decrypt_private_key;
use crate::types::Email;
use jsonwebtoken::Algorithm;
use sqlx::SqlitePool;
use sqlx::sqlite::SqliteConnectOptions;
use std::str::FromStr;
const ENC_KEY: [u8; 32] = [0x42; 32];
const ISSUER: &str = "https://auth.example.com";
async fn test_db() -> Db {
let opts = SqliteConnectOptions::from_str("sqlite::memory:")
.unwrap()
.pragma("foreign_keys", "ON");
let pool = SqlitePool::connect_with(opts).await.unwrap();
Db::new(pool).await.unwrap()
}
#[test]
fn verify_pkce_s256_valid() {
let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
let digest = Sha256::digest(verifier.as_bytes());
let challenge = Base64UrlUnpadded::encode_string(&digest);
assert!(verify_pkce_s256(verifier, &challenge));
}
#[test]
fn verify_pkce_s256_wrong_verifier() {
let verifier = "correct_verifier";
let digest = Sha256::digest(verifier.as_bytes());
let challenge = Base64UrlUnpadded::encode_string(&digest);
assert!(!verify_pkce_s256("wrong_verifier", &challenge));
}
#[test]
fn verify_pkce_s256_empty_verifier() {
let verifier = "actual_verifier";
let digest = Sha256::digest(verifier.as_bytes());
let challenge = Base64UrlUnpadded::encode_string(&digest);
assert!(!verify_pkce_s256("", &challenge));
}
#[test]
fn compute_at_hash_deterministic() {
let input = "eyJhbGciOiJSUzI1NiJ9.test.sig";
let h1 = compute_at_hash(input);
let h2 = compute_at_hash(input);
assert_eq!(h1, h2);
}
#[test]
fn compute_at_hash_known_value() {
let hash = compute_at_hash("test");
let digest = Sha256::digest(b"test");
let expected = Base64UrlUnpadded::encode_string(&digest[..16]);
assert_eq!(hash, expected);
}
#[test]
fn refresh_token_is_43_chars() {
let token = generate_refresh_token();
assert_eq!(token.len(), 43, "32 bytes base64url = 43 chars");
}
#[test]
fn refresh_token_is_url_safe() {
let token = generate_refresh_token();
assert!(
token
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
"token must be URL-safe: got {token}"
);
}
#[test]
fn hash_refresh_token_deterministic() {
let token = generate_refresh_token();
let h1 = hash_refresh_token(&token);
let h2 = hash_refresh_token(&token);
assert_eq!(format!("{h1:?}"), format!("{h2:?}"));
}
#[tokio::test]
async fn mint_access_token_roundtrip() {
let db = test_db().await;
let key = db.create_signing_key(&ENC_KEY).await.unwrap();
db.activate_signing_key(key.id).await.unwrap();
let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
let kid = key.id.to_string();
let user_id = UserId::new();
let token = mint_access_token(
user_id,
ISSUER,
"ath_test_client",
"openid profile",
&kid,
&pem,
3600,
"test@example.com",
true,
Some("testuser"),
&["admin".to_string()],
&["posts:write".to_string()],
)
.unwrap();
let claims = db.validate_access_token(&token, ISSUER).await.unwrap();
assert_eq!(claims.sub, user_id);
assert_eq!(claims.email, "test@example.com");
assert!(claims.email_verified);
assert_eq!(claims.username.as_deref(), Some("testuser"));
assert_eq!(claims.roles, vec!["admin"]);
assert_eq!(claims.permissions, vec!["posts:write"]);
assert_eq!(claims.scope, "openid profile");
assert_eq!(claims.iss, ISSUER);
}
#[tokio::test]
async fn mint_id_token_contains_nonce() {
let db = test_db().await;
let key = db.create_signing_key(&ENC_KEY).await.unwrap();
let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
let kid = key.id.to_string();
let user_id = UserId::new();
let token = mint_id_token(
user_id,
ISSUER,
"ath_test_client",
Some("test_nonce_123"),
"test_at_hash",
1234567890,
&kid,
&pem,
3600,
)
.unwrap();
let parts: Vec<&str> = token.splitn(3, '.').collect();
let payload = base64ct::Base64UrlUnpadded::decode_vec(parts[1]).unwrap();
let claims: serde_json::Value = serde_json::from_slice(&payload).unwrap();
assert_eq!(claims["nonce"], "test_nonce_123");
assert_eq!(claims["at_hash"], "test_at_hash");
assert_eq!(claims["auth_time"], 1234567890);
}
#[tokio::test]
async fn mint_id_token_omits_nonce_when_none() {
let db = test_db().await;
let key = db.create_signing_key(&ENC_KEY).await.unwrap();
let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
let kid = key.id.to_string();
let user_id = UserId::new();
let token = mint_id_token(
user_id,
ISSUER,
"ath_test_client",
None,
"hash",
0,
&kid,
&pem,
3600,
)
.unwrap();
let parts: Vec<&str> = token.splitn(3, '.').collect();
let payload = base64ct::Base64UrlUnpadded::decode_vec(parts[1]).unwrap();
let claims: serde_json::Value = serde_json::from_slice(&payload).unwrap();
assert!(claims.get("nonce").is_none());
}
async fn setup_exchange(db: &Db) -> (Application, SigningKey, String, String, String, String) {
let email = Email::new("exchange@example.com".into()).unwrap();
let user = db
.create_user(email, "password123", None, None)
.await
.unwrap();
let (app, _secret) = db
.create_application(CreateApplicationParams {
name: "ExchangeApp".to_string(),
client_type: ClientType::Confidential,
redirect_uris: vec!["https://example.com/callback".to_string()],
is_trusted: false,
created_by: Some(user.id),
logo_url: None,
primary_color: None,
accent_hex: None,
accent_ink: None,
forced_mode: None,
font_css_url: None,
font_family: None,
splash_text: None,
splash_image_url: None,
splash_primitive: None,
splash_url: None,
shader_cell_scale: None,
})
.await
.unwrap();
let key = db.create_signing_key(&ENC_KEY).await.unwrap();
db.activate_signing_key(key.id).await.unwrap();
let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
let code_verifier = "test_verifier_string_with_enough_entropy_1234567890";
let digest = Sha256::digest(code_verifier.as_bytes());
let code_challenge = Base64UrlUnpadded::encode_string(&digest);
let raw_code = crate::authorization::generate_authorization_code();
let code_hash = hash_authorization_code(&raw_code);
db.create_authorization_code(
app.id,
user.id,
&code_hash,
"https://example.com/callback",
&["openid".to_string(), "profile".to_string()],
&code_challenge,
"S256",
Some("test_nonce"),
)
.await
.unwrap();
(
app,
key,
pem,
raw_code,
code_verifier.to_string(),
"https://example.com/callback".to_string(),
)
}
#[tokio::test]
async fn exchange_valid_code() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let resp = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
assert_eq!(resp.token_type, "Bearer");
assert_eq!(resp.expires_in, 3600);
assert!(!resp.access_token.is_empty());
assert!(!resp.refresh_token.is_empty());
assert!(!resp.id_token.is_empty());
let claims = db
.validate_access_token(&resp.access_token, ISSUER)
.await
.unwrap();
assert_eq!(claims.scope, "openid profile");
}
#[tokio::test]
async fn exchange_used_code_triggers_revocation() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let _resp = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let err = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap_err();
assert!(matches!(err, TokenError::InvalidGrant(ref msg) if msg.contains("already used")));
}
#[tokio::test]
async fn exchange_wrong_redirect_uri() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, _) = setup_exchange(&db).await;
let err = exchange_authorization_code(
&db,
&raw_code,
"https://evil.example.com/callback",
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap_err();
assert!(matches!(err, TokenError::InvalidGrant(ref msg) if msg.contains("redirect_uri")));
}
#[tokio::test]
async fn exchange_bad_pkce() {
let db = test_db().await;
let (app, key, pem, raw_code, _, redirect_uri) = setup_exchange(&db).await;
let err = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
"wrong_verifier",
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap_err();
assert!(matches!(err, TokenError::InvalidGrant(ref msg) if msg.contains("PKCE")));
}
#[tokio::test]
async fn exchange_invalid_code() {
let db = test_db().await;
let (app, key, pem, _, verifier, redirect_uri) = setup_exchange(&db).await;
let err = exchange_authorization_code(
&db,
"nonexistent_code",
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap_err();
assert!(matches!(err, TokenError::InvalidGrant(ref msg) if msg.contains("invalid")));
}
#[tokio::test]
async fn exchange_expired_code() {
let db = test_db().await;
let email = Email::new("expired@example.com".into()).unwrap();
let user = db
.create_user(email, "password123", None, None)
.await
.unwrap();
let (app, _) = db
.create_application(CreateApplicationParams {
name: "ExpiredApp".to_string(),
client_type: ClientType::Confidential,
redirect_uris: vec!["https://example.com/callback".to_string()],
is_trusted: false,
created_by: Some(user.id),
logo_url: None,
primary_color: None,
accent_hex: None,
accent_ink: None,
forced_mode: None,
font_css_url: None,
font_family: None,
splash_text: None,
splash_image_url: None,
splash_primitive: None,
splash_url: None,
shader_cell_scale: None,
})
.await
.unwrap();
let key = db.create_signing_key(&ENC_KEY).await.unwrap();
db.activate_signing_key(key.id).await.unwrap();
let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
let code_verifier = "test_verifier_expired";
let digest = Sha256::digest(code_verifier.as_bytes());
let code_challenge = Base64UrlUnpadded::encode_string(&digest);
let raw_code = crate::authorization::generate_authorization_code();
let code_hash = hash_authorization_code(&raw_code);
db.create_authorization_code(
app.id,
user.id,
&code_hash,
"https://example.com/callback",
&["openid".to_string()],
&code_challenge,
"S256",
None,
)
.await
.unwrap();
sqlx::query(
"UPDATE allowthem_authorization_codes SET expires_at = '2020-01-01T00:00:00.000Z' WHERE code_hash = ?",
)
.bind(&code_hash)
.execute(db.pool())
.await
.unwrap();
let err = exchange_authorization_code(
&db,
&raw_code,
"https://example.com/callback",
code_verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap_err();
assert!(matches!(err, TokenError::InvalidGrant(ref msg) if msg.contains("expired")));
}
#[tokio::test]
async fn exchange_wrong_client() {
let db = test_db().await;
let (_, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let email_b = Email::new("other@example.com".into()).unwrap();
let user_b = db
.create_user(email_b, "password123", None, None)
.await
.unwrap();
let (app_b, _) = db
.create_application(CreateApplicationParams {
name: "OtherApp".to_string(),
client_type: ClientType::Confidential,
redirect_uris: vec!["https://other.example.com/callback".to_string()],
is_trusted: false,
created_by: Some(user_b.id),
logo_url: None,
primary_color: None,
accent_hex: None,
accent_ink: None,
forced_mode: None,
font_css_url: None,
font_family: None,
splash_text: None,
splash_image_url: None,
splash_primitive: None,
splash_url: None,
shader_cell_scale: None,
})
.await
.unwrap();
let err = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app_b,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap_err();
assert!(
matches!(err, TokenError::InvalidGrant(ref msg) if msg.contains("different client"))
);
}
#[tokio::test]
async fn access_token_has_correct_typ_header() {
let db = test_db().await;
let key = db.create_signing_key(&ENC_KEY).await.unwrap();
let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
let token = mint_access_token(
UserId::new(),
ISSUER,
"client",
"openid",
&key.id.to_string(),
&pem,
3600,
"t@example.com",
true,
None,
&[],
&[],
)
.unwrap();
let header = jsonwebtoken::decode_header(&token).unwrap();
assert_eq!(header.typ.as_deref(), Some("at+jwt"));
assert_eq!(header.alg, Algorithm::RS256);
assert!(header.kid.is_some());
}
#[tokio::test]
async fn id_token_has_correct_typ_header() {
let db = test_db().await;
let key = db.create_signing_key(&ENC_KEY).await.unwrap();
let pem = decrypt_private_key(&key, &ENC_KEY).unwrap();
let token = mint_id_token(
UserId::new(),
ISSUER,
"client",
None,
"hash",
0,
&key.id.to_string(),
&pem,
3600,
)
.unwrap();
let header = jsonwebtoken::decode_header(&token).unwrap();
assert_eq!(header.typ.as_deref(), Some("JWT"));
assert_eq!(header.alg, Algorithm::RS256);
}
#[tokio::test]
async fn exchange_id_token_at_hash_matches_access_token() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let resp = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let parts: Vec<&str> = resp.id_token.splitn(3, '.').collect();
let payload = base64ct::Base64UrlUnpadded::decode_vec(parts[1]).unwrap();
let claims: serde_json::Value = serde_json::from_slice(&payload).unwrap();
let expected = compute_at_hash(&resp.access_token);
assert_eq!(claims["at_hash"].as_str().unwrap(), expected);
}
#[tokio::test]
async fn exchange_creates_refresh_token_in_db() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let resp = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let refresh_hash = hash_refresh_token(&resp.refresh_token);
let count: (i64,) =
sqlx::query_as("SELECT COUNT(*) FROM allowthem_refresh_tokens WHERE token_hash = ?")
.bind(&refresh_hash)
.fetch_one(db.pool())
.await
.unwrap();
assert_eq!(count.0, 1, "refresh token should be stored in DB");
}
#[tokio::test]
async fn get_refresh_token_by_hash_returns_token() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let resp = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let hash = hash_refresh_token(&resp.refresh_token);
let stored = db.get_refresh_token_by_hash(&hash).await.unwrap().unwrap();
assert_eq!(stored.application_id, app.id);
assert_eq!(stored.revoked_at, None);
}
#[tokio::test]
async fn get_refresh_token_by_hash_returns_none_for_unknown() {
let db = test_db().await;
let unknown = hash_refresh_token("nonexistent_raw_token");
let result = db.get_refresh_token_by_hash(&unknown).await.unwrap();
assert!(result.is_none());
}
#[tokio::test]
async fn revoke_refresh_token_sets_revoked_at() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let resp = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let hash = hash_refresh_token(&resp.refresh_token);
let stored = db.get_refresh_token_by_hash(&hash).await.unwrap().unwrap();
assert!(stored.revoked_at.is_none());
db.revoke_refresh_token(stored.id).await.unwrap();
let after = db.get_refresh_token_by_hash(&hash).await.unwrap().unwrap();
assert!(after.revoked_at.is_some());
}
#[tokio::test]
async fn revoke_refresh_token_idempotent() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let resp = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let hash = hash_refresh_token(&resp.refresh_token);
let stored = db.get_refresh_token_by_hash(&hash).await.unwrap().unwrap();
db.revoke_refresh_token(stored.id).await.unwrap();
db.revoke_refresh_token(stored.id).await.unwrap();
}
#[tokio::test]
async fn exchange_refresh_token_valid() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let initial = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let resp =
exchange_refresh_token(&db, &initial.refresh_token, None, &app, ISSUER, &key, &pem)
.await
.unwrap();
assert!(!resp.access_token.is_empty());
assert!(!resp.refresh_token.is_empty());
assert_ne!(resp.refresh_token, initial.refresh_token);
assert_eq!(resp.token_type, "Bearer");
assert_eq!(resp.expires_in, 3600);
let claims = db
.validate_access_token(&resp.access_token, ISSUER)
.await
.unwrap();
assert_eq!(claims.scope, "openid profile");
}
#[tokio::test]
async fn exchange_refresh_token_revokes_old_token() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let initial = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let old_hash = hash_refresh_token(&initial.refresh_token);
exchange_refresh_token(&db, &initial.refresh_token, None, &app, ISSUER, &key, &pem)
.await
.unwrap();
let old_stored = db
.get_refresh_token_by_hash(&old_hash)
.await
.unwrap()
.unwrap();
assert!(
old_stored.revoked_at.is_some(),
"old token should be revoked"
);
}
#[tokio::test]
async fn exchange_refresh_token_revoked_token_fails() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let initial = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
exchange_refresh_token(&db, &initial.refresh_token, None, &app, ISSUER, &key, &pem)
.await
.unwrap();
let err =
exchange_refresh_token(&db, &initial.refresh_token, None, &app, ISSUER, &key, &pem)
.await
.unwrap_err();
assert!(matches!(err, TokenError::InvalidGrant(ref msg) if msg.contains("revoked")));
}
#[tokio::test]
async fn exchange_refresh_token_expired_fails() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let initial = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let hash = hash_refresh_token(&initial.refresh_token);
sqlx::query(
"UPDATE allowthem_refresh_tokens SET expires_at = '2020-01-01T00:00:00.000Z' WHERE token_hash = ?",
)
.bind(&hash)
.execute(db.pool())
.await
.unwrap();
let err =
exchange_refresh_token(&db, &initial.refresh_token, None, &app, ISSUER, &key, &pem)
.await
.unwrap_err();
assert!(matches!(err, TokenError::InvalidGrant(ref msg) if msg.contains("expired")));
}
// A refresh token issued to one application must not be redeemable by a
// different application. The error must be `InvalidGrant` mentioning
// "different client".
#[tokio::test]
async fn exchange_refresh_token_wrong_client_fails() {
let db = test_db().await;
let (application, signing_key, private_pem, code, code_verifier, callback_uri) =
setup_exchange(&db).await;
let granted = exchange_authorization_code(
&db,
&code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
None,
)
.await
.unwrap();
// Register a second, unrelated application owned by another user.
let other_owner = db
.create_user(
Email::new("other_refresh@example.com".into()).unwrap(),
"password123",
None,
None,
)
.await
.unwrap();
let other_params = CreateApplicationParams {
name: "OtherRefreshApp".to_string(),
client_type: ClientType::Confidential,
redirect_uris: vec!["https://other.example.com/callback".to_string()],
is_trusted: false,
created_by: Some(other_owner.id),
logo_url: None,
primary_color: None,
accent_hex: None,
accent_ink: None,
forced_mode: None,
font_css_url: None,
font_family: None,
splash_text: None,
splash_image_url: None,
splash_primitive: None,
splash_url: None,
shader_cell_scale: None,
};
let (intruder_app, _) = db.create_application(other_params).await.unwrap();
let outcome = exchange_refresh_token(
&db,
&granted.refresh_token,
None,
&intruder_app,
ISSUER,
&signing_key,
&private_pem,
)
.await;
match outcome {
Err(TokenError::InvalidGrant(reason)) => assert!(
reason.contains("different client"),
"unexpected message: {reason}"
),
other => panic!("expected InvalidGrant for a foreign client, got {other:?}"),
}
}
// Requesting a narrower scope ("openid") on refresh is allowed. The resulting
// access token must carry exactly that narrowed scope.
#[tokio::test]
async fn exchange_refresh_token_scope_subset_succeeds() {
let db = test_db().await;
let (application, signing_key, private_pem, code, code_verifier, callback_uri) =
setup_exchange(&db).await;
let granted = exchange_authorization_code(
&db,
&code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
None,
)
.await
.unwrap();
let narrowed_scope = Some("openid");
let refreshed = exchange_refresh_token(
&db,
&granted.refresh_token,
narrowed_scope,
&application,
ISSUER,
&signing_key,
&private_pem,
)
.await
.unwrap();
let access_claims = db
.validate_access_token(&refreshed.access_token, ISSUER)
.await
.unwrap();
assert_eq!(access_claims.scope, "openid");
}
// Requesting a scope ("admin") beyond what was originally granted must be
// refused with an `InvalidGrant` error mentioning that the request "exceeds"
// the original grant.
#[tokio::test]
async fn exchange_refresh_token_scope_escalation_fails() {
let db = test_db().await;
let (application, signing_key, private_pem, code, code_verifier, callback_uri) =
setup_exchange(&db).await;
let granted = exchange_authorization_code(
&db,
&code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
None,
)
.await
.unwrap();
let widened_scope = Some("openid admin");
let outcome = exchange_refresh_token(
&db,
&granted.refresh_token,
widened_scope,
&application,
ISSUER,
&signing_key,
&private_pem,
)
.await;
match outcome {
Err(TokenError::InvalidGrant(reason)) => {
assert!(reason.contains("exceeds"), "unexpected message: {reason}")
}
other => panic!("expected InvalidGrant for scope escalation, got {other:?}"),
}
}
// When no scope is supplied on refresh, the new access token inherits the
// scope set from the original grant ("openid profile" in this fixture).
#[tokio::test]
async fn exchange_refresh_token_no_scope_uses_original() {
let db = test_db().await;
let (application, signing_key, private_pem, code, code_verifier, callback_uri) =
setup_exchange(&db).await;
let granted = exchange_authorization_code(
&db,
&code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
None,
)
.await
.unwrap();
let refreshed = exchange_refresh_token(
&db,
&granted.refresh_token,
None,
&application,
ISSUER,
&signing_key,
&private_pem,
)
.await
.unwrap();
let access_claims = db
.validate_access_token(&refreshed.access_token, ISSUER)
.await
.unwrap();
assert_eq!(access_claims.scope, "openid profile");
}
// A refresh token string that was never issued must be rejected with an
// `InvalidGrant` error mentioning "invalid". A legitimate grant is still
// performed first so the table is non-empty.
#[tokio::test]
async fn exchange_refresh_token_invalid_hash_fails() {
let db = test_db().await;
let (application, signing_key, private_pem, code, code_verifier, callback_uri) =
setup_exchange(&db).await;
exchange_authorization_code(
&db,
&code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
None,
)
.await
.unwrap();
let bogus_token = "totally_invalid_garbage_token";
let outcome = exchange_refresh_token(
&db,
bogus_token,
None,
&application,
ISSUER,
&signing_key,
&private_pem,
)
.await;
match outcome {
Err(TokenError::InvalidGrant(reason)) => {
assert!(reason.contains("invalid"), "unexpected message: {reason}")
}
other => panic!("expected InvalidGrant for an unknown token, got {other:?}"),
}
}
// The rotated refresh token must keep the same `authorization_code_id` as the
// token it replaced. The rotated row is looked up via its hash to verify this.
#[tokio::test]
async fn exchange_refresh_token_propagates_authorization_code_id() {
let db = test_db().await;
let (application, signing_key, private_pem, code, code_verifier, callback_uri) =
setup_exchange(&db).await;
let granted = exchange_authorization_code(
&db,
&code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
None,
)
.await
.unwrap();
let origin_row = db
.get_refresh_token_by_hash(&hash_refresh_token(&granted.refresh_token))
.await
.unwrap()
.unwrap();
let refreshed = exchange_refresh_token(
&db,
&granted.refresh_token,
None,
&application,
ISSUER,
&signing_key,
&private_pem,
)
.await
.unwrap();
let rotated_row = db
.get_refresh_token_by_hash(&hash_refresh_token(&refreshed.refresh_token))
.await
.unwrap()
.unwrap();
assert_eq!(
rotated_row.authorization_code_id, origin_row.authorization_code_id,
"authorization_code_id must propagate through rotation"
);
}
// Two back-to-back rotations must both succeed. Each yields a refresh token
// different from its predecessor, every superseded token ends up revoked, and
// the newest token in the chain must still be active (not revoked).
#[tokio::test]
async fn exchange_refresh_token_chained_rotation_succeeds() {
let db = test_db().await;
let (app, key, pem, raw_code, verifier, redirect_uri) = setup_exchange(&db).await;
let initial = exchange_authorization_code(
&db,
&raw_code,
&redirect_uri,
&verifier,
&app,
ISSUER,
&key,
&pem,
None,
)
.await
.unwrap();
let second =
exchange_refresh_token(&db, &initial.refresh_token, None, &app, ISSUER, &key, &pem)
.await
.unwrap();
assert_ne!(second.refresh_token, initial.refresh_token);
let third =
exchange_refresh_token(&db, &second.refresh_token, None, &app, ISSUER, &key, &pem)
.await
.unwrap();
assert_ne!(third.refresh_token, second.refresh_token);
// Every token that has been rotated away must be revoked.
for (superseded, label) in [
(&initial.refresh_token, "first"),
(&second.refresh_token, "second"),
] {
let stored = db
.get_refresh_token_by_hash(&hash_refresh_token(superseded))
.await
.unwrap()
.unwrap();
assert!(
stored.revoked_at.is_some(),
"{label} token must be revoked"
);
}
// The head of the chain must remain usable. Without this check, an
// implementation that revoked the freshly issued token would still pass.
let latest = db
.get_refresh_token_by_hash(&hash_refresh_token(&third.refresh_token))
.await
.unwrap()
.unwrap();
assert!(
latest.revoked_at.is_none(),
"latest token in the chain must remain active"
);
}
// A successful authorization-code exchange must invoke the supplied
// `OnUserActive` callback exactly once. Invocations are counted with an atomic
// captured by the callback closure.
#[tokio::test]
async fn on_user_active_fires_on_exchange_authorization_code() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
let hits = Arc::new(AtomicU64::new(0));
let on_active: crate::handle::OnUserActive = {
let hits = Arc::clone(&hits);
Arc::new(move |_user_id, _at| {
hits.fetch_add(1, Ordering::Relaxed);
})
};
let db = test_db().await;
let (application, signing_key, private_pem, code, code_verifier, callback_uri) =
setup_exchange(&db).await;
exchange_authorization_code(
&db,
&code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
Some(&on_active),
)
.await
.unwrap();
let fired = hits.load(Ordering::Relaxed);
assert_eq!(
fired, 1,
"callback must fire exactly once after successful code exchange"
);
}
// When the authorization-code exchange fails (here because the code is
// unknown), the `OnUserActive` callback must not be invoked at all.
#[tokio::test]
async fn on_user_active_no_fire_on_exchange_authorization_code_failure() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
let hits = Arc::new(AtomicU64::new(0));
let on_active: crate::handle::OnUserActive = {
let hits = Arc::clone(&hits);
Arc::new(move |_user_id, _at| {
hits.fetch_add(1, Ordering::Relaxed);
})
};
let db = test_db().await;
let (application, signing_key, private_pem, _unused_code, code_verifier, callback_uri) =
setup_exchange(&db).await;
let bogus_code = "invalid_code_xyz";
let outcome = exchange_authorization_code(
&db,
bogus_code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
Some(&on_active),
)
.await;
assert!(outcome.is_err(), "exchange with invalid code must fail");
let fired = hits.load(Ordering::Relaxed);
assert_eq!(fired, 0, "callback must not fire when exchange fails");
}
// The `OnUserActive` callback fires for the authorization-code grant but must
// not fire again for a subsequent refresh-token grant. `exchange_refresh_token`
// is not given the callback here, and the count must stay at one.
#[tokio::test]
async fn on_user_active_no_fire_on_exchange_refresh_token() {
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
let hits = Arc::new(AtomicU64::new(0));
let on_active: crate::handle::OnUserActive = {
let hits = Arc::clone(&hits);
Arc::new(move |_user_id, _at| {
hits.fetch_add(1, Ordering::Relaxed);
})
};
let db = test_db().await;
let (application, signing_key, private_pem, code, code_verifier, callback_uri) =
setup_exchange(&db).await;
let granted = exchange_authorization_code(
&db,
&code,
&callback_uri,
&code_verifier,
&application,
ISSUER,
&signing_key,
&private_pem,
Some(&on_active),
)
.await
.unwrap();
assert_eq!(hits.load(Ordering::Relaxed), 1, "sanity: auth code fires");
exchange_refresh_token(
&db,
&granted.refresh_token,
None,
&application,
ISSUER,
&signing_key,
&private_pem,
)
.await
.unwrap();
assert_eq!(
hits.load(Ordering::Relaxed),
1,
"refresh_token grant must not fire callback"
);
}
}