use crate::Codec;
use crate::errors::{JwtError, JwtOperation};
use crate::jwt::JwtClaims;
use crate::jwt::validation_result::JwtValidationResult;
use std::sync::Arc;
use tracing::{debug, warn};
use webgates_core::accounts::Account;
use webgates_core::authz::access_hierarchy::AccessHierarchy;
/// Validates JWTs by decoding them with a shared codec and comparing the
/// decoded issuer claim against a configured expected issuer.
///
/// `Clone` is implemented by hand rather than derived: `#[derive(Clone)]` would
/// add a `C: Clone` bound even though `Arc<C>` is always `Clone`. The manual
/// impl lets the service be cloned for any codec type. Cloning only bumps the
/// `Arc` reference count and copies the issuer string.
#[derive(Debug)]
pub struct JwtValidationService<C> {
    /// Codec used to decode raw token bytes into claims; shared via `Arc`.
    codec: Arc<C>,
    /// Issuer value that decoded tokens are compared against.
    expected_issuer: String,
}

impl<C> Clone for JwtValidationService<C> {
    fn clone(&self) -> Self {
        Self {
            codec: Arc::clone(&self.codec),
            expected_issuer: self.expected_issuer.clone(),
        }
    }
}
/// Abstraction over anything that can turn a raw token string into verified
/// claims of type `T`.
///
/// Implementors must also be cloneable.
pub trait JwtClaimsVerifier<T>
where
    Self: Clone,
{
    /// Verifies `token_value` and returns its claims on success.
    ///
    /// # Errors
    ///
    /// Returns a [`JwtError`] when the token cannot be verified.
    fn verify_token(
        &self,
        token_value: &str,
    ) -> std::result::Result<T, JwtError>;
}
impl<C> JwtValidationService<C> {
    /// Builds a validation service from a shared `codec` and the issuer that
    /// decoded tokens will be compared against.
    ///
    /// The `expected_issuer` slice is copied into an owned `String`.
    pub fn new(codec: Arc<C>, expected_issuer: &str) -> Self {
        let expected_issuer = String::from(expected_issuer);
        Self {
            codec,
            expected_issuer,
        }
    }
}
impl<C, R, G> JwtValidationService<C>
where
    C: Codec<Payload = JwtClaims<Account<R, G>>>,
    R: AccessHierarchy + Eq,
    G: Eq + Clone,
{
    /// Decodes `token_value` with the configured codec, then checks the
    /// decoded issuer claim against the expected issuer.
    ///
    /// Outcomes:
    /// - [`JwtValidationResult::InvalidToken`] if the codec fails to decode the
    ///   token bytes (the decode error is logged at `debug` level only).
    /// - [`JwtValidationResult::InvalidIssuer`] if the decoded issuer does not
    ///   match; both issuers are returned and a `warn` event is emitted.
    /// - [`JwtValidationResult::Valid`] with the decoded claims otherwise.
    ///
    /// NOTE(review): this method itself only compares the issuer. Any other
    /// checks (signature, expiry, ...) are presumably done inside
    /// `Codec::decode`; confirm against the codec implementation.
    pub fn validate_token(&self, token_value: &str) -> JwtValidationResult<Account<R, G>> {
        let claims = match self.codec.decode(token_value.as_bytes()) {
            Err(error) => {
                debug!(error = %error, "JWT token decoding failed");
                return JwtValidationResult::InvalidToken;
            }
            Ok(claims) => claims,
        };

        debug!(
            account_id = %claims.custom_claims.account_id,
            issuer = %claims.registered_claims.issuer,
            "JWT token decoded successfully"
        );

        if claims.has_issuer(&self.expected_issuer) {
            return JwtValidationResult::Valid(claims);
        }

        warn!(
            expected_issuer = %self.expected_issuer,
            actual_issuer = %claims.registered_claims.issuer,
            account_id = %claims.custom_claims.account_id,
            "JWT issuer validation failed"
        );
        JwtValidationResult::InvalidIssuer {
            expected: self.expected_issuer.clone(),
            actual: claims.registered_claims.issuer,
        }
    }
}
impl<C, R, G> JwtClaimsVerifier<JwtClaims<Account<R, G>>> for JwtValidationService<C>
where
    C: Codec<Payload = JwtClaims<Account<R, G>>> + Clone,
    R: AccessHierarchy + Eq,
    G: Eq + Clone,
{
    /// Runs [`JwtValidationService::validate_token`] and converts its outcome
    /// into a `Result`.
    ///
    /// # Errors
    ///
    /// Returns a `JwtOperation::Validate` processing error when the token fails
    /// to decode, or when its issuer differs from the expected one. In the
    /// issuer case the message names both issuers.
    fn verify_token(
        &self,
        token_value: &str,
    ) -> std::result::Result<JwtClaims<Account<R, G>>, JwtError> {
        let outcome = self.validate_token(token_value);
        let failure = match outcome {
            JwtValidationResult::Valid(claims) => return Ok(claims),
            JwtValidationResult::InvalidIssuer { expected, actual } => JwtError::processing(
                JwtOperation::Validate,
                format!("token issuer mismatch: expected `{expected}`, got `{actual}`"),
            ),
            JwtValidationResult::InvalidToken => {
                JwtError::processing(JwtOperation::Validate, "token verification failed")
            }
        };
        Err(failure)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::{JwtError, JwtOperation};
    use crate::jwt::RegisteredClaims;
    use std::sync::Arc;
    use uuid::Uuid;
    use webgates_core::groups::Group;
    use webgates_core::permissions::Permissions;
    use webgates_core::roles::Role;

    /// Claims type produced by [`MockCodec`].
    type MockClaims = JwtClaims<Account<Role, Group>>;

    /// Test codec that ignores its input and either fails or returns fixed
    /// claims carrying `mock_issuer`.
    #[derive(Clone)]
    struct MockCodec {
        should_fail_decode: bool,
        mock_issuer: String,
    }

    impl MockCodec {
        /// Decodes successfully with issuer `test-issuer`.
        fn new() -> Self {
            Self {
                should_fail_decode: false,
                mock_issuer: "test-issuer".to_string(),
            }
        }

        /// Always fails to decode.
        fn with_decode_failure() -> Self {
            Self {
                should_fail_decode: true,
                mock_issuer: String::new(),
            }
        }

        /// Decodes successfully with issuer `different-issuer`.
        fn with_different_issuer() -> Self {
            Self {
                should_fail_decode: false,
                mock_issuer: "different-issuer".to_string(),
            }
        }
    }

    impl Codec for MockCodec {
        type Payload = JwtClaims<Account<Role, Group>>;

        fn decode(&self, _encoded_value: &[u8]) -> crate::Result<Self::Payload> {
            if self.should_fail_decode {
                return Err(crate::Error::Jwt(JwtError::processing(
                    JwtOperation::Decode,
                    "Mock decode failure",
                )));
            }
            let account = Account {
                account_id: Uuid::now_v7(),
                user_id: "test_user".to_string(),
                roles: vec![Role::User],
                groups: vec![Group::new("engineering")],
                permissions: Permissions::new(),
            };
            let registered_claims = RegisteredClaims {
                issuer: self.mock_issuer.clone(),
                subject: Some("test".to_string()),
                audience: None,
                expiration_time: 9_999_999_999,
                not_before_time: None,
                issued_at_time: 1_000_000_000,
                jwt_id: None,
                session_id: None,
            };
            Ok(JwtClaims {
                custom_claims: account,
                registered_claims,
            })
        }

        fn encode(&self, _payload: &Self::Payload) -> crate::Result<Vec<u8>> {
            Ok(Vec::new())
        }
    }

    #[test]
    fn validation_service_returns_valid_result_for_matching_issuer() {
        let codec = Arc::new(MockCodec::new());
        let service = JwtValidationService::new(codec, "test-issuer");
        let result = service.validate_token("valid-token");
        match result {
            JwtValidationResult::Valid(jwt) => {
                assert_eq!(jwt.custom_claims.user_id, "test_user");
                assert_eq!(jwt.registered_claims.issuer, "test-issuer");
            }
            other => panic!("Expected valid token result, got {other:?}"),
        }
    }

    #[test]
    fn validation_service_returns_invalid_token_when_decoding_fails() {
        let codec = Arc::new(MockCodec::with_decode_failure());
        let service = JwtValidationService::new(codec, "test-issuer");
        let result = service.validate_token("invalid-token");
        assert!(matches!(result, JwtValidationResult::InvalidToken));
    }

    #[test]
    fn validation_service_returns_invalid_issuer_when_issuer_differs() {
        let codec = Arc::new(MockCodec::with_different_issuer());
        let service = JwtValidationService::new(codec, "expected-issuer");
        let result = service.validate_token("valid-token");
        match result {
            JwtValidationResult::InvalidIssuer { expected, actual } => {
                assert_eq!(expected, "expected-issuer");
                assert_eq!(actual, "different-issuer");
            }
            other => panic!("Expected invalid issuer result, got {other:?}"),
        }
    }

    // The tests below cover the `JwtClaimsVerifier` impl, which maps each
    // `JwtValidationResult` variant onto `Ok` / `Err`.

    #[test]
    fn verifier_returns_claims_for_matching_issuer() {
        let service = JwtValidationService::new(Arc::new(MockCodec::new()), "test-issuer");
        let result: std::result::Result<MockClaims, JwtError> =
            service.verify_token("valid-token");
        match result {
            Ok(jwt) => {
                assert_eq!(jwt.custom_claims.user_id, "test_user");
                assert_eq!(jwt.registered_claims.issuer, "test-issuer");
            }
            Err(_) => panic!("Expected verified claims for a matching issuer"),
        }
    }

    #[test]
    fn verifier_returns_error_when_decoding_fails() {
        let service =
            JwtValidationService::new(Arc::new(MockCodec::with_decode_failure()), "test-issuer");
        let result: std::result::Result<MockClaims, JwtError> =
            service.verify_token("invalid-token");
        assert!(result.is_err());
    }

    #[test]
    fn verifier_returns_error_when_issuer_differs() {
        let service = JwtValidationService::new(
            Arc::new(MockCodec::with_different_issuer()),
            "expected-issuer",
        );
        let result: std::result::Result<MockClaims, JwtError> =
            service.verify_token("valid-token");
        assert!(result.is_err());
    }
}