mod middleware;
pub use middleware::AuthMiddleware;
use crate::error::Error;
use crate::extract::{FromRequestParts, PathParams};
use crate::state::AppState;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// JWT payload that [`AuthConfig`] encodes into tokens and decodes back out of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject claim; `AuthConfig::create_token` stores the user id here.
    pub sub: String,
    /// Expiration time, in whole seconds since the UNIX epoch (`iat` + lifetime in `Claims::new`).
    pub exp: u64,
    /// Issued-at time, in whole seconds since the UNIX epoch.
    pub iat: u64,
}
impl Claims {
    /// Returns the current wall-clock time as whole seconds since the UNIX epoch.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time earlier than the UNIX epoch.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
            .as_secs()
    }

    /// Builds claims for `sub` that expire `expires_in_secs` seconds from now.
    ///
    /// `iat` is the current time and `exp` is `iat + expires_in_secs`, saturating at
    /// `u64::MAX` instead of overflowing.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time earlier than the UNIX epoch.
    pub fn new(sub: impl Into<String>, expires_in_secs: u64) -> Self {
        let now = Self::now_secs();
        Self {
            sub: sub.into(),
            // The lifetime can come straight from configuration (e.g. `JWT_EXPIRATION`),
            // so a huge value must not panic (debug) or wrap into the past (release).
            exp: now.saturating_add(expires_in_secs),
            iat: now,
        }
    }

    /// Returns `true` once the current time is strictly later than `exp`.
    ///
    /// This check applies no leeway. NOTE(review): `AuthConfig::decode` relies on
    /// `jsonwebtoken::Validation::default()`, which may allow some clock skew, so the two
    /// checks can disagree near the expiry instant — confirm this is acceptable.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time earlier than the UNIX epoch.
    pub fn is_expired(&self) -> bool {
        self.exp < Self::now_secs()
    }
}
// Serializable payload carrying an issued token (presumably returned by login-style
// handlers — confirm at call sites).
// Plain `//` comments are used on this type so they do not become descriptions in the
// schema generated by the `schemars::JsonSchema` derive.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct TokenResponse {
    // The encoded JWT string.
    pub token: String,
    // Token lifetime; presumably seconds, like `AuthConfig::expiration` — confirm at call sites.
    pub expires_in: u64,
}
impl TokenResponse {
pub fn new(token: String, expires_in: u64) -> Self {
Self { token, expires_in }
}
}
/// The authenticated user associated with the current request.
///
/// Extracted via its `FromRequestParts` impl, which reads it from the request
/// extensions and fails with an unauthorized error when it is absent.
#[derive(Debug, Clone)]
pub struct CurrentUser {
    /// User identifier. NOTE(review): presumably the `sub` of `claims`; confirm where it is inserted.
    pub id: String,
    /// Claims decoded from the request's token.
    pub claims: Claims,
}
impl FromRequestParts for CurrentUser {
async fn from_request_parts(
parts: &http::request::Parts,
_params: &PathParams,
_state: &Arc<AppState>,
) -> Result<Self, Error> {
parts
.extensions
.get::<CurrentUser>()
.cloned()
.ok_or_else(|| Error::unauthorized("authentication required"))
}
}
/// Shared secret and token lifetime used to issue and verify JWTs.
///
/// NOTE(review): `Debug` is not derived, which keeps `secret` out of `{:?}` output;
/// preserve that unless a redacting `Debug` impl is added.
#[derive(Clone)]
pub struct AuthConfig {
    /// Secret passed to `EncodingKey::from_secret` / `DecodingKey::from_secret`.
    secret: String,
    /// Lifetime of issued tokens in seconds; added to `iat` to form `exp`.
    expiration: u64,
}
impl AuthConfig {
pub fn new(secret: impl Into<String>, expiration: u64) -> Self {
Self {
secret: secret.into(),
expiration,
}
}
pub fn from_env() -> Result<Self, crate::config::ConfigError> {
let secret = crate::config::get_env("JWT_SECRET")?;
let expiration = crate::config::get_env_parsed_or("JWT_EXPIRATION", 3600);
Ok(Self { secret, expiration })
}
pub fn expiration(&self) -> u64 {
self.expiration
}
pub fn encode(&self, claims: &Claims) -> Result<String, Error> {
encode(
&Header::default(),
claims,
&EncodingKey::from_secret(self.secret.as_bytes()),
)
.map_err(|e| Error::internal(format!("failed to encode token: {}", e)))
}
pub fn decode(&self, token: &str) -> Result<Claims, Error> {
let token_data = decode::<Claims>(
token,
&DecodingKey::from_secret(self.secret.as_bytes()),
&Validation::default(),
)
.map_err(|e| match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
Error::unauthorized("token expired")
}
jsonwebtoken::errors::ErrorKind::InvalidToken => Error::unauthorized("invalid token"),
_ => Error::unauthorized(format!("token validation failed: {}", e)),
})?;
Ok(token_data.claims)
}
pub fn create_token(&self, user_id: impl Into<String>) -> Result<String, Error> {
let claims = Claims::new(user_id, self.expiration);
self.encode(&claims)
}
}
/// Set of `(method, path pattern)` pairs that may be served without authentication.
///
/// Patterns are `/`-separated. A segment starting with `:` (e.g. `/users/:id`) matches
/// any single path segment. `/__rapina` and everything nested under `/__rapina/` are
/// always public.
#[derive(Debug, Clone, Default)]
pub struct PublicRoutes {
    /// Registered `(method, pattern)` pairs, in insertion order.
    routes: Vec<(String, String)>,
}

impl PublicRoutes {
    /// Creates an empty set; only `/__rapina` paths are public.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `path` (which may contain `:param` segments) as public for `method`.
    ///
    /// `method` is later compared case-sensitively.
    pub fn add(&mut self, method: &str, path: &str) {
        self.routes.push((method.to_owned(), path.to_owned()));
    }

    /// Returns `true` if a request with `method` and `path` may skip authentication.
    ///
    /// `/__rapina` itself and any path under `/__rapina/` are always public. Any other
    /// path must match a registered route with the same method (case-sensitive), with
    /// segments matched by [`PublicRoutes::matches_pattern`].
    pub fn is_public(&self, method: &str, path: &str) -> bool {
        // Framework introspection endpoints are always reachable. Require a segment
        // boundary after the prefix so user routes such as `/__rapina_admin` are not
        // silently exempted from authentication.
        let is_internal = path
            .strip_prefix("/__rapina")
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'));
        if is_internal {
            return true;
        }
        self.routes
            .iter()
            .any(|(m, p)| m == method && Self::matches_pattern(p, path))
    }

    /// Segment-wise match of `path` against `pattern`.
    ///
    /// Both strings are split on `/`, so they must have the same number of segments,
    /// including the empty ones produced by leading or trailing slashes. A `:`-prefixed
    /// pattern segment matches any single segment, even an empty one. The two split
    /// iterators are walked in lockstep, so nothing is allocated per call.
    fn matches_pattern(pattern: &str, path: &str) -> bool {
        let mut pattern_parts = pattern.split('/');
        let mut path_parts = path.split('/');
        loop {
            match (pattern_parts.next(), path_parts.next()) {
                (Some(expected), Some(actual)) => {
                    if !(expected.starts_with(':') || expected == actual) {
                        return false;
                    }
                }
                // Both exhausted together: equal segment counts and every segment matched.
                (None, None) => return true,
                // One side has extra segments.
                _ => return false,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PublicRoutes` with every `(method, path)` pair registered in order.
    fn public_routes(entries: &[(&str, &str)]) -> PublicRoutes {
        let mut routes = PublicRoutes::new();
        for &(method, path) in entries {
            routes.add(method, path);
        }
        routes
    }

    #[test]
    fn test_claims_new() {
        let claims = Claims::new("user123", 3600);
        assert_eq!(claims.sub, "user123");
        assert!(claims.iat < claims.exp);
        assert_eq!(claims.exp - claims.iat, 3600);
    }

    #[test]
    fn test_claims_not_expired() {
        assert!(!Claims::new("user123", 3600).is_expired());
    }

    #[test]
    fn test_claims_expired() {
        let mut claims = Claims::new("user123", 0);
        // Push the expiry one second before issuance so it is definitely in the past.
        claims.exp = claims.iat - 1;
        assert!(claims.is_expired());
    }

    #[test]
    fn test_auth_config_new() {
        assert_eq!(AuthConfig::new("secret", 7200).expiration(), 7200);
    }

    #[test]
    fn test_auth_config_encode_decode() {
        let config = AuthConfig::new("test-secret", 3600);
        let token = config
            .encode(&Claims::new("user456", 3600))
            .expect("encoding should succeed");
        assert!(!token.is_empty());
        let decoded = config.decode(&token).expect("decoding should succeed");
        assert_eq!(decoded.sub, "user456");
    }

    #[test]
    fn test_auth_config_create_token() {
        let config = AuthConfig::new("test-secret", 3600);
        let token = config
            .create_token("user789")
            .expect("token creation should succeed");
        assert!(!token.is_empty());
        let decoded = config.decode(&token).expect("decoding should succeed");
        assert_eq!(decoded.sub, "user789");
    }

    #[test]
    fn test_auth_config_invalid_token() {
        let config = AuthConfig::new("test-secret", 3600);
        assert!(config.decode("invalid.token.here").is_err());
    }

    #[test]
    fn test_auth_config_wrong_secret() {
        let issuer = AuthConfig::new("secret1", 3600);
        let verifier = AuthConfig::new("secret2", 3600);
        let token = issuer
            .create_token("user")
            .expect("token creation should succeed");
        assert!(verifier.decode(&token).is_err());
    }

    #[test]
    fn test_public_routes_empty() {
        assert!(!PublicRoutes::new().is_public("GET", "/protected"));
    }

    #[test]
    fn test_public_routes_exact_match() {
        let routes = public_routes(&[("GET", "/health"), ("POST", "/login")]);
        assert!(routes.is_public("GET", "/health"));
        assert!(routes.is_public("POST", "/login"));
        assert!(!routes.is_public("GET", "/login"));
        assert!(!routes.is_public("POST", "/health"));
    }

    #[test]
    fn test_public_routes_with_params() {
        let routes = public_routes(&[("GET", "/users/:id/public")]);
        for id in ["123", "abc"] {
            assert!(routes.is_public("GET", &format!("/users/{id}/public")));
        }
        assert!(!routes.is_public("GET", "/users/123/private"));
    }

    #[test]
    fn test_public_routes_introspection_always_public() {
        let routes = PublicRoutes::new();
        for path in ["/__rapina/routes", "/__rapina/openapi.json"] {
            assert!(routes.is_public("GET", path));
        }
    }
}