use crate::errors::TokenError;
use rand::{Rng, distr::Alphanumeric, rng};
use sha2::{Digest, Sha256};
use webgates_codecs::Codec;
/// Validated auth-token string.
///
/// Values are only created through [`AuthToken::new`], which rejects empty and
/// whitespace-only input.
///
/// `Debug` is written by hand so the token itself is never printed: it is
/// credential material and should not leak into logs or panic messages.
#[derive(Clone, PartialEq, Eq)]
pub struct AuthToken {
    value: String,
}

impl std::fmt::Debug for AuthToken {
    /// Formats the token with its value replaced by `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthToken")
            .field("value", &format_args!("<redacted>"))
            .finish()
    }
}

impl AuthToken {
    /// Wraps `value` as an auth token.
    ///
    /// # Errors
    ///
    /// Returns [`TokenError::InvalidTokenMaterial`] when `value` is empty or
    /// contains only whitespace.
    pub fn new(value: impl Into<String>) -> Result<Self, TokenError> {
        let value = value.into();
        if value.trim().is_empty() {
            return Err(TokenError::InvalidTokenMaterial);
        }
        Ok(Self { value })
    }

    /// Borrows the token text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.value
    }

    /// Consumes the wrapper and returns the owned token text.
    #[must_use]
    pub fn into_inner(self) -> String {
        self.value
    }
}
/// Plaintext refresh token.
///
/// Values are only created through [`RefreshTokenPlaintext::new`], which
/// rejects empty and whitespace-only input.
///
/// `Debug` is written by hand so the plaintext is never printed: it is a
/// secret, and a derived implementation would expose it in logs, panic
/// messages and in the `Debug` output of every type that embeds it.
#[derive(Clone, PartialEq, Eq)]
pub struct RefreshTokenPlaintext {
    value: String,
}

impl std::fmt::Debug for RefreshTokenPlaintext {
    /// Formats the token with its value replaced by `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshTokenPlaintext")
            .field("value", &format_args!("<redacted>"))
            .finish()
    }
}

impl RefreshTokenPlaintext {
    /// Wraps `value` as a plaintext refresh token.
    ///
    /// # Errors
    ///
    /// Returns [`TokenError::InvalidTokenMaterial`] when `value` is empty or
    /// contains only whitespace.
    pub fn new(value: impl Into<String>) -> Result<Self, TokenError> {
        let value = value.into();
        if value.trim().is_empty() {
            return Err(TokenError::InvalidTokenMaterial);
        }
        Ok(Self { value })
    }

    /// Borrows the plaintext token.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.value
    }

    /// Consumes the wrapper and returns the owned plaintext token.
    #[must_use]
    pub fn into_inner(self) -> String {
        self.value
    }
}
/// Hash string derived from a refresh token.
///
/// Values are only created through [`RefreshTokenHash::new`], which rejects
/// empty and whitespace-only input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefreshTokenHash {
    value: String,
}

impl RefreshTokenHash {
    /// Wraps `value` as a refresh-token hash.
    ///
    /// # Errors
    ///
    /// Returns [`TokenError::InvalidTokenMaterial`] when `value` is empty or
    /// contains only whitespace.
    pub fn new(value: impl Into<String>) -> Result<Self, TokenError> {
        let value: String = value.into();
        // Blank means "no character other than whitespace", which also covers
        // the empty string.
        let is_blank = value.chars().all(char::is_whitespace);
        if is_blank {
            Err(TokenError::InvalidTokenMaterial)
        } else {
            Ok(Self { value })
        }
    }

    /// Borrows the hash text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.value.as_str()
    }

    /// Consumes the wrapper and returns the owned hash text.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self { value } = self;
        value
    }
}
/// String-slice alias. Presumably meant for passing a refresh-token hash by
/// reference; no use of it is visible in this file, so confirm against callers.
pub type RefreshTokenHashRef<'a> = &'a str;
/// Smallest value accepted by [`RefreshTokenLength::new`].
pub const MIN_REFRESH_TOKEN_LENGTH: usize = 32;
/// Value wrapped by [`RefreshTokenLength::default`].
pub const DEFAULT_REFRESH_TOKEN_LENGTH: usize = 64;
/// Validated refresh-token length.
///
/// A value built through [`RefreshTokenLength::new`] or `Default` is never
/// smaller than [`MIN_REFRESH_TOKEN_LENGTH`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RefreshTokenLength(usize);

impl RefreshTokenLength {
    /// Validates `value` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns [`TokenError::InvalidRefreshTokenLength`] when `value` is below
    /// [`MIN_REFRESH_TOKEN_LENGTH`].
    pub fn new(value: usize) -> Result<Self, TokenError> {
        if value >= MIN_REFRESH_TOKEN_LENGTH {
            Ok(Self(value))
        } else {
            Err(TokenError::InvalidRefreshTokenLength)
        }
    }

    /// Returns the wrapped length.
    #[must_use]
    pub fn get(self) -> usize {
        let Self(length) = self;
        length
    }
}

impl Default for RefreshTokenLength {
    /// Uses [`DEFAULT_REFRESH_TOKEN_LENGTH`].
    fn default() -> Self {
        RefreshTokenLength(DEFAULT_REFRESH_TOKEN_LENGTH)
    }
}
/// Refresh-token generator configured with a fixed token length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpaqueRefreshTokenGenerator {
    length: RefreshTokenLength,
}

impl OpaqueRefreshTokenGenerator {
    /// Creates a generator that produces tokens of the given `length`.
    #[must_use]
    pub fn new(length: RefreshTokenLength) -> Self {
        Self { length }
    }

    /// Returns the configured token length.
    #[must_use]
    pub fn length(&self) -> RefreshTokenLength {
        let Self { length } = *self;
        length
    }
}

impl Default for OpaqueRefreshTokenGenerator {
    /// Creates a generator using the default [`RefreshTokenLength`].
    fn default() -> Self {
        Self {
            length: RefreshTokenLength::default(),
        }
    }
}
impl RefreshTokenGenerator for OpaqueRefreshTokenGenerator {
    type Error = TokenError;

    /// Produces a refresh token made of `self.length()` characters sampled
    /// from the `Alphanumeric` distribution using the generator returned by
    /// `rng()`.
    ///
    /// The length is read eagerly, so the returned future does not borrow
    /// `self`.
    fn generate_refresh_token(
        &self,
    ) -> impl std::future::Future<Output = Result<RefreshTokenPlaintext, Self::Error>> + Send {
        let char_count = self.length.get();
        async move {
            let random_chars = rng()
                .sample_iter(&Alphanumeric)
                .map(char::from)
                .take(char_count);
            RefreshTokenPlaintext::new(random_chars.collect::<String>())
        }
    }
}
/// Refresh-token hasher that produces the lowercase hex encoding of the
/// SHA-256 digest of the plaintext token.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Sha256RefreshTokenHasher;

impl RefreshTokenHasher for Sha256RefreshTokenHasher {
    type Error = TokenError;

    /// Hashes `refresh_token` with SHA-256 and hex-encodes the digest.
    ///
    /// The returned future borrows the plaintext rather than copying it, so
    /// no additional copy of the secret is placed on the heap.
    ///
    /// # Errors
    ///
    /// Returns [`TokenError::HashFailed`] if formatting the digest fails or
    /// the encoded digest is rejected by [`RefreshTokenHash::new`].
    fn hash_refresh_token(
        &self,
        refresh_token: &RefreshTokenPlaintext,
    ) -> impl std::future::Future<Output = Result<RefreshTokenHash, Self::Error>> + Send {
        // Return-position `impl Trait` in trait methods captures the argument
        // lifetime, so the future may hold this borrow.
        let plaintext = refresh_token.as_str();
        async move {
            let digest = Sha256::digest(plaintext.as_bytes());
            // Two hex digits per digest byte.
            let mut encoded = String::with_capacity(digest.len() * 2);
            for byte in digest {
                use std::fmt::Write as _;
                write!(&mut encoded, "{byte:02x}").map_err(|_| TokenError::HashFailed)?;
            }
            RefreshTokenHash::new(encoded).map_err(|_| TokenError::HashFailed)
        }
    }
}
/// Result of [`TokenPairIssuer::issue_for_subject`]: an issued token pair plus
/// a refresh-token hash.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IssuedSessionTokens {
/// The auth token and plaintext refresh token.
pub token_pair: IssuedTokenPair,
/// Refresh-token hash. When built by [`TokenPairIssuer::issue_for_subject`]
/// this is the hash of `token_pair.refresh_token`; `new` itself does not
/// check that relationship.
pub refresh_token_hash: RefreshTokenHash,
}
impl IssuedSessionTokens {
/// Bundles `token_pair` and `refresh_token_hash` without further validation.
#[must_use]
pub fn new(token_pair: IssuedTokenPair, refresh_token_hash: RefreshTokenHash) -> Self {
Self {
token_pair,
refresh_token_hash,
}
}
}
/// An auth token together with a plaintext refresh token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IssuedTokenPair {
/// The auth token.
pub auth_token: AuthToken,
/// The plaintext refresh token.
pub refresh_token: RefreshTokenPlaintext,
}
impl IssuedTokenPair {
/// Pairs `auth_token` with `refresh_token` without further validation.
#[must_use]
pub fn new(auth_token: AuthToken, refresh_token: RefreshTokenPlaintext) -> Self {
Self {
auth_token,
refresh_token,
}
}
}
/// Issues an [`AuthToken`] for a subject of type `Subject`.
///
/// Implementors must be `Send + Sync`, and the returned future must be `Send`.
pub trait AuthTokenIssuer<Subject>: Send + Sync {
/// Error produced when issuing the token fails.
type Error;
/// Issues an auth token for `subject`.
fn issue_auth_token(
&self,
subject: &Subject,
) -> impl std::future::Future<Output = Result<AuthToken, Self::Error>> + Send;
}
/// Produces new [`RefreshTokenPlaintext`] values.
///
/// Implementors must be `Send + Sync`, and the returned future must be `Send`.
pub trait RefreshTokenGenerator: Send + Sync {
/// Error produced when generating the token fails.
type Error;
/// Generates a new plaintext refresh token.
fn generate_refresh_token(
&self,
) -> impl std::future::Future<Output = Result<RefreshTokenPlaintext, Self::Error>> + Send;
}
/// Turns a [`RefreshTokenPlaintext`] into a [`RefreshTokenHash`].
///
/// Implementors must be `Send + Sync`, and the returned future must be `Send`.
pub trait RefreshTokenHasher: Send + Sync {
/// Error produced when hashing fails.
type Error;
/// Computes the hash of `refresh_token`.
fn hash_refresh_token(
&self,
refresh_token: &RefreshTokenPlaintext,
) -> impl std::future::Future<Output = Result<RefreshTokenHash, Self::Error>> + Send;
}
/// Auth-token issuer built from a codec and a claims factory.
///
/// `claims_factory` is called with the subject to build the claims, and
/// `codec` encodes those claims into the token (see the
/// [`AuthTokenIssuer`] implementation for this type).
#[derive(Clone)]
pub struct CodecAuthTokenIssuer<C, F> {
    codec: C,
    claims_factory: F,
}

impl<C, F> std::fmt::Debug for CodecAuthTokenIssuer<C, F> {
    /// Prints only the type name.
    ///
    /// No bounds are placed on `C` or `F`, so the impl is available even when
    /// `claims_factory` is a closure (closures do not implement `Debug`).
    /// This also lets the derived `Debug` of wrapper types that require their
    /// parameters to be `Debug` apply when this issuer is plugged in.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CodecAuthTokenIssuer")
            .finish_non_exhaustive()
    }
}

impl<C, F> CodecAuthTokenIssuer<C, F> {
    /// Creates an issuer from `codec` and `claims_factory`.
    #[must_use]
    pub fn new(codec: C, claims_factory: F) -> Self {
        Self {
            codec,
            claims_factory,
        }
    }
}
impl<Subject, Claims, C, F> AuthTokenIssuer<Subject> for CodecAuthTokenIssuer<C, F>
where
    C: Codec<Payload = Claims> + Send + Sync,
    F: Fn(&Subject) -> Claims + Send + Sync,
    Subject: Send + Sync,
    Claims: Send + Sync,
{
    type Error = TokenError;

    /// Builds claims for `subject`, encodes them with the codec and wraps the
    /// UTF-8 result in an [`AuthToken`].
    ///
    /// All work happens before the future is returned; the future is already
    /// resolved.
    ///
    /// # Errors
    ///
    /// Resolves to [`TokenError::AuthIssuanceFailed`] when encoding fails, the
    /// encoded bytes are not valid UTF-8, or the resulting string is rejected
    /// by [`AuthToken::new`].
    fn issue_auth_token(
        &self,
        subject: &Subject,
    ) -> impl std::future::Future<Output = Result<AuthToken, Self::Error>> + Send {
        let claims = (self.claims_factory)(subject);

        let Ok(encoded) = self.codec.encode(&claims) else {
            return std::future::ready(Err(TokenError::AuthIssuanceFailed));
        };
        let Ok(text) = String::from_utf8(encoded) else {
            return std::future::ready(Err(TokenError::AuthIssuanceFailed));
        };

        std::future::ready(AuthToken::new(text).map_err(|_| TokenError::AuthIssuanceFailed))
    }
}
/// Combines an auth-token issuer, a refresh-token generator and a
/// refresh-token hasher.
///
/// The struct itself places no bounds on `A`, `G` or `H`; the methods that use
/// them declare the trait bounds they need.
#[derive(Debug, Clone)]
pub struct TokenPairIssuer<A, G, H> {
// Used by `issue_for_subject` to issue the auth token.
auth_token_issuer: A,
// Used by `issue_for_subject` to generate the refresh token.
refresh_token_generator: G,
// Used by `hash_refresh_token` (and therefore by `issue_for_subject`).
refresh_token_hasher: H,
}
impl<A, G, H> TokenPairIssuer<A, G, H> {
/// Creates an issuer from its three collaborators.
#[must_use]
pub fn new(auth_token_issuer: A, refresh_token_generator: G, refresh_token_hasher: H) -> Self {
Self {
auth_token_issuer,
refresh_token_generator,
refresh_token_hasher,
}
}
}
impl<A, G, H> TokenPairIssuer<A, G, H> {
pub async fn hash_refresh_token(
&self,
refresh_token: &RefreshTokenPlaintext,
) -> Result<RefreshTokenHash, TokenError>
where
H: RefreshTokenHasher,
{
self.refresh_token_hasher
.hash_refresh_token(refresh_token)
.await
.map_err(|_| TokenError::HashFailed)
}
pub async fn issue_for_subject<Subject>(
&self,
subject: &Subject,
) -> Result<IssuedSessionTokens, TokenError>
where
A: AuthTokenIssuer<Subject>,
G: RefreshTokenGenerator,
H: RefreshTokenHasher,
{
let auth_token = self
.auth_token_issuer
.issue_auth_token(subject)
.await
.map_err(|_| TokenError::AuthIssuanceFailed)?;
let refresh_token = self
.refresh_token_generator
.generate_refresh_token()
.await
.map_err(|_| TokenError::GenerationFailed)?;
let refresh_token_hash = self.hash_refresh_token(&refresh_token).await?;
Ok(IssuedSessionTokens::new(
IssuedTokenPair::new(auth_token, refresh_token),
refresh_token_hash,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};
use webgates_codecs::jwt::{JsonWebToken, JwtClaims, RegisteredClaims};
use webgates_codecs::{Codec, jsonwebtoken};
// Custom claims payload used by the codec round-trip test.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct TestClaims {
// Filled from the subject and compared after decoding.
session_id: String,
}
/// Test double that issues `auth-<subject>` as the auth token.
#[derive(Debug, Clone, Copy)]
struct StaticAuthTokenIssuer;

impl AuthTokenIssuer<String> for StaticAuthTokenIssuer {
    type Error = TokenError;

    fn issue_auth_token(
        &self,
        subject: &String,
    ) -> impl std::future::Future<Output = Result<AuthToken, Self::Error>> + Send {
        let token_text = ["auth-", subject.as_str()].concat();
        let issued = AuthToken::new(token_text);
        std::future::ready(issued)
    }
}
#[derive(Debug, Clone, Copy)]
struct StaticRefreshTokenGenerator;
impl RefreshTokenGenerator for StaticRefreshTokenGenerator {
type Error = TokenError;
fn generate_refresh_token(
&self,
) -> impl std::future::Future<Output = Result<RefreshTokenPlaintext, Self::Error>> + Send
{
std::future::ready(RefreshTokenPlaintext::new("fixed-refresh-token"))
}
}
#[test]
fn auth_token_rejects_empty_value() {
    // Whitespace-only input counts as missing token material.
    match AuthToken::new(" ") {
        Err(error) => assert_eq!(error, TokenError::InvalidTokenMaterial),
        Ok(token) => panic!("expected invalid auth token, got {}", token.as_str()),
    }
}
#[test]
fn refresh_token_plaintext_rejects_empty_value() {
    // The empty string is not valid token material.
    match RefreshTokenPlaintext::new("") {
        Err(error) => assert_eq!(error, TokenError::InvalidTokenMaterial),
        Ok(token) => panic!("expected invalid refresh token, got {}", token.as_str()),
    }
}
#[test]
fn refresh_token_hash_rejects_empty_value() {
    // Whitespace-only input counts as missing hash material.
    match RefreshTokenHash::new(" ") {
        Err(error) => assert_eq!(error, TokenError::InvalidTokenMaterial),
        Ok(hash) => panic!("expected invalid refresh token hash, got {}", hash.as_str()),
    }
}
#[test]
fn refresh_token_length_rejects_short_values() {
    // One below the minimum is the largest value that must be rejected.
    let too_short = MIN_REFRESH_TOKEN_LENGTH - 1;
    match RefreshTokenLength::new(too_short) {
        Err(error) => assert_eq!(error, TokenError::InvalidRefreshTokenLength),
        Ok(length) => panic!("expected invalid length, got {}", length.get()),
    }
}
#[test]
fn issued_token_pair_keeps_both_tokens() {
    let auth_token = AuthToken::new("auth-token")
        .unwrap_or_else(|error| panic!("expected valid auth token: {error}"));
    let refresh_token = RefreshTokenPlaintext::new("refresh-token")
        .unwrap_or_else(|error| panic!("expected valid refresh token: {error}"));

    // Take the pair apart again and compare each half with what went in.
    let IssuedTokenPair {
        auth_token: kept_auth_token,
        refresh_token: kept_refresh_token,
    } = IssuedTokenPair::new(auth_token.clone(), refresh_token.clone());

    assert_eq!(kept_auth_token, auth_token);
    assert_eq!(kept_refresh_token, refresh_token);
}
#[tokio::test]
async fn opaque_refresh_token_generator_uses_requested_length() {
    const REQUESTED: usize = 48;

    let length = RefreshTokenLength::new(REQUESTED)
        .unwrap_or_else(|error| panic!("expected valid refresh token length: {error}"));
    let token = OpaqueRefreshTokenGenerator::new(length)
        .generate_refresh_token()
        .await
        .unwrap_or_else(|error| panic!("expected generated refresh token: {error}"));

    assert_eq!(token.as_str().len(), REQUESTED);
}
#[tokio::test]
async fn sha256_refresh_token_hasher_is_deterministic() {
    let token = RefreshTokenPlaintext::new("repeatable-refresh-token")
        .unwrap_or_else(|error| panic!("expected valid refresh token: {error}"));
    let hasher = Sha256RefreshTokenHasher;

    // Hash the same plaintext twice; both runs must agree.
    let first_hash = hasher
        .hash_refresh_token(&token)
        .await
        .unwrap_or_else(|error| panic!("expected successful hash: {error}"));
    let second_hash = hasher
        .hash_refresh_token(&token)
        .await
        .unwrap_or_else(|error| panic!("expected successful hash: {error}"));

    assert_eq!(first_hash, second_hash);
}
#[tokio::test]
async fn codec_auth_token_issuer_encodes_claims_with_codec() {
    // The result is deliberately ignored, as in the rest of this test setup.
    let _ = jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER.install_default();

    let codec = JsonWebToken::<JwtClaims<TestClaims>>::default();
    let claims_for = |subject: &String| {
        let custom = TestClaims {
            session_id: subject.clone(),
        };
        JwtClaims::new(custom, RegisteredClaims::new("tests", 4_102_444_800))
    };
    let issuer = CodecAuthTokenIssuer::new(codec.clone(), claims_for);

    let subject = String::from("session-123");
    let token = issuer
        .issue_auth_token(&subject)
        .await
        .unwrap_or_else(|error| panic!("expected successful auth-token issuance: {error}"));

    // Round-trip through the same codec to check what was encoded.
    let decoded = codec
        .decode(token.as_str().as_bytes())
        .unwrap_or_else(|error| panic!("expected decodable auth token: {error}"));

    assert_eq!(decoded.custom_claims.session_id, subject);
}
#[tokio::test]
async fn token_pair_issuer_returns_tokens_and_hash() {
    let hasher = Sha256RefreshTokenHasher;
    let issuer = TokenPairIssuer::new(StaticAuthTokenIssuer, StaticRefreshTokenGenerator, hasher);

    let subject = String::from("subject-123");
    let issued = issuer
        .issue_for_subject(&subject)
        .await
        .unwrap_or_else(|error| panic!("expected successful token-pair issuance: {error}"));

    // Recompute the hash independently from the returned plaintext.
    let expected_hash = hasher
        .hash_refresh_token(&issued.token_pair.refresh_token)
        .await
        .unwrap_or_else(|error| panic!("expected successful hash calculation: {error}"));

    assert_eq!(issued.token_pair.auth_token.as_str(), "auth-subject-123");
    assert_eq!(issued.refresh_token_hash, expected_hash);
}
}