axum_gate/codecs/jwt/validation_service.rs

1use super::JwtValidationResult;
2use super::{Codec, JwtClaims};
3use crate::accounts::Account;
4use crate::authz::AccessHierarchy;
5
6use std::sync::Arc;
7use tracing::{debug, warn};
8
/// Service responsible for JWT token validation.
///
/// Validation covers:
/// - decoding the raw token with the wrapped codec (any signature or expiry checks
///   happen inside the codec, presumably via the `jsonwebtoken` library — confirm in
///   the codec implementation), and
/// - checking that the decoded claims carry the expected issuer.
///
/// The codec is shared through an [`Arc`], so every clone of the service refers to
/// the same codec instance.
#[derive(Debug)]
pub struct JwtValidationService<C> {
    /// Codec used to turn raw token bytes into claims.
    codec: Arc<C>,
    /// Issuer value that decoded tokens must carry to be accepted.
    expected_issuer: String,
}

// `Clone` is implemented by hand because `#[derive(Clone)]` would add a `C: Clone`
// bound, even though `Arc<C>` is cloneable for every `C`. Without this, a service
// wrapping a non-`Clone` codec could not be cloned at all.
impl<C> Clone for JwtValidationService<C> {
    fn clone(&self) -> Self {
        Self {
            codec: Arc::clone(&self.codec),
            expected_issuer: self.expected_issuer.clone(),
        }
    }
}
20
21impl<C> JwtValidationService<C> {
22    /// Creates a new JWT validation service.
23    ///
24    /// # Parameters
25    /// - `codec`: The codec used for decoding JWT tokens
26    /// - `expected_issuer`: The issuer that tokens must have to be considered valid
27    pub fn new(codec: Arc<C>, expected_issuer: &str) -> Self {
28        Self {
29            codec,
30            expected_issuer: expected_issuer.to_owned(),
31        }
32    }
33}
34
35impl<C, R, G> JwtValidationService<C>
36where
37    C: Codec<Payload = JwtClaims<Account<R, G>>>,
38    R: AccessHierarchy + Eq,
39    G: Eq + Clone,
40{
41    /// Validates a JWT token from its raw string representation.
42    ///
43    /// This method performs the following validations:
44    /// 1. Attempts to decode the token using the configured codec
45    /// 2. Validates the issuer matches the expected issuer
46    /// 3. Token expiration is automatically handled by the jsonwebtoken library
47    ///
48    /// # Parameters
49    /// - `token_value`: The raw JWT token string
50    ///
51    /// # Returns
52    /// - `JwtValidationResult::Valid` if the token is valid and authorized
53    /// - `JwtValidationResult::InvalidToken` if the token cannot be decoded
54    /// - `JwtValidationResult::InvalidIssuer` if the issuer doesn't match
55    pub fn validate_token(&self, token_value: &str) -> JwtValidationResult<Account<R, G>> {
56        // Attempt to decode the JWT token
57        let jwt = match self.codec.decode(token_value.as_bytes()) {
58            Ok(jwt) => jwt,
59            Err(e) => {
60                debug!("Could not decode JWT token: {e}");
61                return JwtValidationResult::InvalidToken;
62            }
63        };
64
65        debug!(
66            "JWT token decoded successfully for account: {}",
67            jwt.custom_claims.account_id
68        );
69
70        // Validate the issuer
71        if !jwt.has_issuer(&self.expected_issuer) {
72            warn!(
73                "JWT issuer validation failed. Expected: '{}', Actual: {:?}, Account: {}",
74                self.expected_issuer, jwt.registered_claims.issuer, jwt.custom_claims.account_id
75            );
76            return JwtValidationResult::InvalidIssuer {
77                expected: self.expected_issuer.clone(),
78                actual: jwt.registered_claims.issuer,
79            };
80        }
81
82        JwtValidationResult::Valid(jwt)
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::groups::Group;
    use crate::permissions::Permissions;
    use crate::roles::Role;
    use std::sync::Arc;

    /// Test double for `Codec` that fabricates claims instead of parsing its input.
    ///
    /// `issuer: None` makes every `decode` call fail; `Some(iss)` makes `decode`
    /// return a fixed test account whose registered issuer is `iss`.
    #[derive(Clone)]
    struct MockCodec {
        issuer: Option<String>,
    }

    impl MockCodec {
        /// Decodes successfully with the issuer `"test-issuer"`.
        fn new() -> Self {
            Self::issuing("test-issuer")
        }

        /// Fails every decode attempt.
        fn with_decode_failure() -> Self {
            Self { issuer: None }
        }

        /// Decodes successfully with the issuer `"different-issuer"`.
        fn with_different_issuer() -> Self {
            Self::issuing("different-issuer")
        }

        /// Decodes successfully with the given issuer.
        fn issuing(issuer: &str) -> Self {
            Self {
                issuer: Some(issuer.to_string()),
            }
        }
    }

    impl Codec for MockCodec {
        type Payload = JwtClaims<Account<Role, Group>>;

        fn decode(&self, _data: &[u8]) -> crate::errors::Result<Self::Payload> {
            use crate::codecs::jwt::RegisteredClaims;
            use uuid::Uuid;

            let issuer = match &self.issuer {
                Some(issuer) => issuer.clone(),
                None => {
                    return Err(crate::errors::Error::Jwt(
                        crate::codecs::JwtError::processing(
                            crate::codecs::JwtOperation::Decode,
                            "Mock decode failure",
                        ),
                    ));
                }
            };

            Ok(JwtClaims {
                custom_claims: Account {
                    account_id: Uuid::new_v4(),
                    user_id: "test_user".to_string(),
                    roles: vec![Role::User],
                    groups: vec![Group::new("engineering")],
                    permissions: Permissions::new(),
                },
                registered_claims: RegisteredClaims {
                    issuer,
                    subject: Some("test".to_string()),
                    audience: None,
                    expiration_time: 9_999_999_999, // far future
                    not_before_time: None,
                    issued_at_time: 1_000_000_000, // in the past
                    jwt_id: None,
                },
            })
        }

        fn encode(&self, _payload: &Self::Payload) -> crate::errors::Result<Vec<u8>> {
            unimplemented!()
        }
    }

    #[test]
    fn validation_service_valid_token() {
        let service = JwtValidationService::new(Arc::new(MockCodec::new()), "test-issuer");

        let jwt = match service.validate_token("valid-token") {
            JwtValidationResult::Valid(jwt) => jwt,
            _ => panic!("Expected valid token result"),
        };
        assert_eq!(jwt.custom_claims.user_id, "test_user");
        assert_eq!(jwt.registered_claims.issuer, "test-issuer");
    }

    #[test]
    fn validation_service_invalid_token() {
        let service =
            JwtValidationService::new(Arc::new(MockCodec::with_decode_failure()), "test-issuer");

        assert!(matches!(
            service.validate_token("invalid-token"),
            JwtValidationResult::InvalidToken
        ));
    }

    #[test]
    fn validation_service_invalid_issuer() {
        let service = JwtValidationService::new(
            Arc::new(MockCodec::with_different_issuer()),
            "expected-issuer",
        );

        let (expected, actual) = match service.validate_token("valid-token") {
            JwtValidationResult::InvalidIssuer { expected, actual } => (expected, actual),
            _ => panic!("Expected invalid issuer result"),
        };
        assert_eq!(expected, "expected-issuer");
        assert_eq!(actual, "different-issuer");
    }
}