at-jet 0.7.2

High-performance HTTP + Protobuf API framework for mobile services
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
//! JWT Authentication Middleware
//!
//! Provides JWT token validation and middleware for protecting routes.
//!
//! Uses HMAC signature verification (HS512 by default, configurable via
//! `JwtConfig::algorithm`) with base64-encoded secrets.
//!
//! # Example
//!
//! ```ignore
//! use at_jet::middleware::jwt_auth::{JwtConfig, JwtAuthLayer};
//!
//! let config = JwtConfig {
//!     secret: "base64-encoded-secret".to_string(),
//!     algorithm: jsonwebtoken::Algorithm::HS512,
//! };
//!
//! let server = JetServer::new()
//!     .route("/api/protected", get(handler))
//!     .layer(JwtAuthLayer::new(&config));
//! ```

use {axum::{body::Body,
            http::{Request,
                   StatusCode},
            response::{IntoResponse,
                       Response}},
     jsonwebtoken::{Algorithm,
                    DecodingKey,
                    Validation,
                    decode},
     serde::{Deserialize,
             Serialize},
     std::{future::Future,
           pin::Pin,
           sync::Arc,
           task::{Context,
                  Poll}},
     thiserror::Error,
     tower::{Layer,
             Service}};

// --- Config ---

/// JWT configuration
///
/// `Debug` is implemented by hand rather than derived so that the secret is
/// never written out: formatting a `JwtConfig` with `{:?}` (for example in a
/// log line) shows `<redacted>` — or `<empty>` when no secret is set — in
/// place of the key material.
#[derive(Clone)]
pub struct JwtConfig {
  /// Base64-encoded HMAC secret
  pub secret:    String,
  /// JWT algorithm (default: HS512)
  pub algorithm: Algorithm,
}

impl std::fmt::Debug for JwtConfig {
  /// Formats the config without revealing the secret.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    // Only whether a secret is present is shown, never its value.
    let secret = if self.secret.is_empty() {
      "<empty>"
    } else {
      "<redacted>"
    };
    f.debug_struct("JwtConfig")
      .field("secret", &secret)
      .field("algorithm", &self.algorithm)
      .finish()
  }
}

impl Default for JwtConfig {
  fn default() -> Self {
    Self {
      secret:    String::new(),
      algorithm: Algorithm::HS512,
    }
  }
}

// --- Claims ---

/// JWT claims structure
///
/// The subject (`sub`) field contains the authenticated identity.
///
/// `sub` is required when deserializing a token payload. `iat` and `exp` are
/// both `Option`s marked `#[serde(default)]`, so a payload that omits them
/// deserializes with those fields set to `None`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
  /// Subject — the authenticated identity.
  ///
  /// This is the value returned by
  /// `JwtValidator::validate_and_extract_subject` on success.
  pub sub: String,
  /// Issued at timestamp
  // NOTE(review): the tests in this file populate this with seconds since the
  // Unix epoch — confirm that is the unit producers use.
  #[serde(default)]
  pub iat: Option<i64>,
  /// Expiration timestamp
  // NOTE(review): whether a token without `exp` is accepted depends on the
  // `jsonwebtoken::Validation` settings used by the validator — verify.
  #[serde(default)]
  pub exp: Option<i64>,
}

// --- Error ---

/// JWT authentication errors
///
/// Returned by `JwtValidator::validate_and_extract_subject`. The
/// `IntoResponse` impl for this type renders every variant with status
/// `401 Unauthorized`.
#[derive(Debug, Error)]
pub enum JwtAuthError {
  /// The token decoder reported `ErrorKind::ExpiredSignature`.
  #[error("Token expired")]
  TokenExpired,
  /// The token decoder reported `ErrorKind::InvalidSignature`.
  #[error("Invalid token signature")]
  InvalidSignature,
  /// Any other decoding failure; the payload is the decoder error rendered
  /// with `to_string()`.
  #[error("Invalid token format: {0}")]
  InvalidFormat(String),
  /// The validator has no usable secret (empty, or not valid base64).
  #[error("JWT not configured")]
  NotConfigured,
}

impl IntoResponse for JwtAuthError {
  /// Renders the error as a `401 Unauthorized` response with a fixed message.
  ///
  /// The message depends only on the variant; the detail string carried by
  /// `InvalidFormat` is not included in the response body.
  fn into_response(self) -> Response {
    let message: &'static str = match self {
      | Self::TokenExpired => "Token expired",
      | Self::InvalidSignature => "Invalid token signature",
      | Self::InvalidFormat(_) => "Invalid token format",
      | Self::NotConfigured => "Authentication not configured",
    };

    let status = StatusCode::UNAUTHORIZED;
    (status, message).into_response()
  }
}

// --- Verification Mode ---

/// How a `JwtValidator` verifies tokens.
#[derive(Clone)]
enum VerificationMode {
  /// Verify the signature with an HMAC key.
  Hmac {
    /// Key built from the base64-encoded secret in `JwtConfig`.
    decoding_key: DecodingKey,
    /// Algorithm copied from `JwtConfig::algorithm`.
    algorithm:    Algorithm,
  },
  /// No usable secret was supplied; validation attempts fail with
  /// `JwtAuthError::NotConfigured`.
  None,
}

// --- Validator ---

/// JWT token validator
///
/// Validates JWT tokens using HMAC signature verification.
/// The secret is base64-encoded in config and decoded automatically.
///
/// Build one with `JwtValidator::from_config`; use `is_configured` to check
/// whether a usable secret was supplied.
#[derive(Clone)]
pub struct JwtValidator {
  /// Verification strategy chosen in `from_config`: HMAC with a decoded key,
  /// or `None` when the secret was empty or not valid base64.
  mode: VerificationMode,
}

impl JwtValidator {
  /// Build a validator from `config`.
  ///
  /// The secret is expected to be base64-encoded. When it is empty, or cannot
  /// be decoded as base64 (which is logged at error level), the resulting
  /// validator is not configured and every validation attempt returns
  /// `JwtAuthError::NotConfigured`.
  pub fn from_config(config: &JwtConfig) -> Self {
    let mode = if config.secret.is_empty() {
      VerificationMode::None
    } else {
      match DecodingKey::from_base64_secret(&config.secret) {
        | Ok(decoding_key) => VerificationMode::Hmac {
          decoding_key,
          algorithm: config.algorithm,
        },
        | Err(e) => {
          tracing::error!("Failed to decode base64 JWT secret: {}", e);
          VerificationMode::None
        }
      }
    };

    Self { mode }
  }

  /// Returns `true` when a usable secret was supplied to `from_config`.
  pub fn is_configured(&self) -> bool {
    match self.mode {
      | VerificationMode::Hmac { .. } => true,
      | VerificationMode::None => false,
    }
  }

  /// Validate `token` and return its subject.
  ///
  /// # Arguments
  /// * `token` - JWT token string; a leading `"Bearer "` prefix is stripped if present
  ///
  /// # Returns
  /// * `Ok(subject)` - The `sub` claim of the successfully validated token
  ///
  /// # Errors
  /// * `JwtAuthError::NotConfigured` - the validator has no usable secret
  /// * `JwtAuthError::TokenExpired` - the decoder reported an expired signature
  /// * `JwtAuthError::InvalidSignature` - the decoder reported a bad signature
  /// * `JwtAuthError::InvalidFormat` - any other decoding failure
  pub fn validate_and_extract_subject(&self, token: &str) -> Result<String, JwtAuthError> {
    let VerificationMode::Hmac {
      decoding_key,
      algorithm,
    } = &self.mode
    else {
      return Err(JwtAuthError::NotConfigured);
    };

    let raw_token = token.strip_prefix("Bearer ").unwrap_or(token);

    // Expiry checking is switched on explicitly.
    let mut validation = Validation::new(*algorithm);
    validation.validate_exp = true;

    match decode::<Claims>(raw_token, decoding_key, &validation) {
      | Ok(token_data) => Ok(token_data.claims.sub),
      | Err(e) => {
        let mapped = match e.kind() {
          | jsonwebtoken::errors::ErrorKind::ExpiredSignature => JwtAuthError::TokenExpired,
          | jsonwebtoken::errors::ErrorKind::InvalidSignature => JwtAuthError::InvalidSignature,
          | _ => JwtAuthError::InvalidFormat(e.to_string()),
        };
        Err(mapped)
      }
    }
  }
}

// --- Middleware Layer ---

/// JWT authentication layer for tower middleware
///
/// Extracts Bearer token from Authorization header, validates via `JwtValidator`,
/// and inserts the subject string into request extensions on success.
/// On failure, returns 401 Unauthorized.
///
/// # Example
///
/// ```ignore
/// use at_jet::middleware::jwt_auth::{JwtConfig, JwtAuthLayer};
///
/// let config = JwtConfig {
///     secret: "base64-secret".to_string(),
///     ..Default::default()
/// };
///
/// let server = JetServer::new()
///     .route("/protected", get(handler))
///     .layer(JwtAuthLayer::new(&config));
/// ```
#[derive(Clone)]
pub struct JwtAuthLayer {
  /// Validator handed (via `Arc` clone) to every middleware service this
  /// layer creates.
  validator:         Arc<JwtValidator>,
  /// Optional predicate run on the subject after the token validates; when it
  /// returns `false` the request is rejected with 401. Set through
  /// `with_subject_validator`.
  subject_validator: Option<fn(&str) -> bool>,
}

impl JwtAuthLayer {
  /// Build a layer whose validator is created from `config`.
  ///
  /// No subject validator is installed; add one with
  /// `with_subject_validator`.
  pub fn new(config: &JwtConfig) -> Self {
    let validator = Arc::new(JwtValidator::from_config(config));
    Self {
      validator,
      subject_validator: None,
    }
  }

  /// Install a subject validator function, replacing any previous one.
  ///
  /// The function receives the token's subject once the token itself has been
  /// validated; returning `false` rejects the request. Use it for
  /// application-specific checks such as format validation or allowlists.
  pub fn with_subject_validator(self, f: fn(&str) -> bool) -> Self {
    Self {
      subject_validator: Some(f),
      ..self
    }
  }
}

impl<S> Layer<S> for JwtAuthLayer {
  type Service = JwtAuthMiddleware<S>;

  fn layer(&self, inner: S) -> Self::Service {
    JwtAuthMiddleware {
      inner,
      validator: self.validator.clone(),
      subject_validator: self.subject_validator,
    }
  }
}

/// JWT authentication middleware service
///
/// Created by `JwtAuthLayer`. Requests are checked for a valid Bearer token
/// before being passed to `inner`; the validated subject is stored in the
/// request extensions as a `String`.
#[derive(Clone)]
pub struct JwtAuthMiddleware<S> {
  /// The wrapped service that handles requests once they are authenticated.
  inner:             S,
  /// Validator shared with the layer that created this service.
  validator:         Arc<JwtValidator>,
  /// Optional predicate applied to the subject after token validation.
  subject_validator: Option<fn(&str) -> bool>,
}

impl<S> Service<Request<Body>> for JwtAuthMiddleware<S>
where
  S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
  S::Future: Send,
{
  type Error = S::Error;
  type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
  type Response = S::Response;

  fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
    self.inner.poll_ready(cx)
  }

  fn call(&mut self, mut req: Request<Body>) -> Self::Future {
    let mut inner = self.inner.clone();
    let validator = self.validator.clone();
    let subject_validator = self.subject_validator;

    Box::pin(async move {
      // Extract Authorization header
      let auth_header = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|h| h.to_str().ok())
        .map(|s| s.to_string());

      let token = match auth_header {
        | Some(ref header) if header.starts_with("Bearer ") => &header[7 ..],
        | _ => {
          return Ok((StatusCode::UNAUTHORIZED, "Missing or invalid Authorization header").into_response());
        }
      };

      // Validate JWT
      let subject = match validator.validate_and_extract_subject(token) {
        | Ok(sub) => sub,
        | Err(e) => return Ok(e.into_response()),
      };

      // Optional subject validation
      if let Some(validate_fn) = subject_validator {
        if !validate_fn(&subject) {
          return Ok((StatusCode::UNAUTHORIZED, "Invalid subject in token").into_response());
        }
      }

      // Insert subject into request extensions
      req.extensions_mut().insert(subject);

      inner.call(req).await
    })
  }
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
  use {super::*,
       jsonwebtoken::{EncodingKey,
                      Header,
                      encode}};

  /// Raw (not yet base64-encoded) HMAC key shared by most tests.
  const TEST_SECRET: &[u8] = b"test-secret-key-for-jwt";

  /// Current time as seconds since the Unix epoch.
  fn unix_now() -> i64 {
    let since_epoch = std::time::SystemTime::now()
      .duration_since(std::time::UNIX_EPOCH)
      .unwrap();
    since_epoch.as_secs() as i64
  }

  /// Build an HS512 validator whose config holds `raw_secret` base64-encoded.
  fn validator_for(raw_secret: &[u8]) -> JwtValidator {
    let config = JwtConfig {
      secret:    data_encoding::BASE64.encode(raw_secret),
      algorithm: Algorithm::HS512,
    };
    JwtValidator::from_config(&config)
  }

  /// Sign an HS512 token for `sub` with `raw_secret`; `iat` is set to now.
  fn sign_token(sub: &str, raw_secret: &[u8], exp: Option<i64>) -> String {
    let claims = Claims {
      sub: sub.to_string(),
      iat: Some(unix_now()),
      exp,
    };
    let key = EncodingKey::from_secret(raw_secret);
    encode(&Header::new(Algorithm::HS512), &claims, &key).unwrap()
  }

  #[test]
  fn test_from_config_empty_secret() {
    let validator = JwtValidator::from_config(&JwtConfig::default());
    assert!(!validator.is_configured());
  }

  #[test]
  fn test_from_config_valid_secret() {
    assert!(validator_for(TEST_SECRET).is_configured());
  }

  #[test]
  fn test_from_config_invalid_base64() {
    let validator = JwtValidator::from_config(&JwtConfig {
      secret:    "not-valid-base64!!!".to_string(),
      algorithm: Algorithm::HS512,
    });
    assert!(!validator.is_configured());
  }

  #[test]
  fn test_not_configured_returns_error() {
    let unconfigured = JwtValidator::from_config(&JwtConfig::default());
    let outcome = unconfigured.validate_and_extract_subject("some-token");
    assert!(matches!(outcome, Err(JwtAuthError::NotConfigured)));
  }

  #[test]
  fn test_valid_token() {
    let validator = validator_for(TEST_SECRET);
    // Expires one hour from now.
    let token = sign_token("user123", TEST_SECRET, Some(unix_now() + 3600));

    let outcome = validator.validate_and_extract_subject(&token);
    assert_eq!(outcome.unwrap(), "user123");
  }

  #[test]
  fn test_valid_token_with_bearer_prefix() {
    let validator = validator_for(TEST_SECRET);
    let token = sign_token("user123", TEST_SECRET, Some(unix_now() + 3600));
    let with_prefix = format!("Bearer {}", token);

    let outcome = validator.validate_and_extract_subject(&with_prefix);
    assert_eq!(outcome.unwrap(), "user123");
  }

  #[test]
  fn test_expired_token() {
    let validator = validator_for(TEST_SECRET);
    // Expired one hour ago.
    let token = sign_token("user123", TEST_SECRET, Some(unix_now() - 3600));

    let outcome = validator.validate_and_extract_subject(&token);
    assert!(matches!(outcome, Err(JwtAuthError::TokenExpired)));
  }

  #[test]
  fn test_invalid_signature() {
    let validator = validator_for(TEST_SECRET);
    // Signed with a key the validator does not know.
    let token = sign_token("user123", b"wrong-secret-key-for-jwt", Some(unix_now() + 3600));

    let outcome = validator.validate_and_extract_subject(&token);
    assert!(matches!(outcome, Err(JwtAuthError::InvalidSignature)));
  }

  #[test]
  fn test_invalid_token_format() {
    let validator = validator_for(TEST_SECRET);

    let outcome = validator.validate_and_extract_subject("not-a-jwt-token");
    assert!(matches!(outcome, Err(JwtAuthError::InvalidFormat(_))));
  }

  #[test]
  fn test_subject_validator_hook() {
    let layer = JwtAuthLayer::new(&JwtConfig::default())
      .with_subject_validator(|s| s.len() == 8 && s.chars().all(|c| c.is_ascii_alphanumeric()));

    let validate = layer.subject_validator.expect("subject validator should be installed");
    assert!(validate("abc12345"));
    assert!(!validate("short"));
    assert!(!validate("has spaces"));
  }
}