rest/middleware/
auth.rs

1use crate::middleware::{Middleware, middleware};
2use http::StatusCode;
3use hyper::Body;
4use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
5use serde::Deserialize;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::Arc;
9
// Registered JWT claim names. They are kept for reference only and are
// currently unused, which is why each one silences the `dead_code` lint.

/// Claim name for the token audience.
#[allow(dead_code)]
const JWT_AUDIENCE: &str = "aud";

/// Claim name for the expiration time.
#[allow(dead_code)]
const JWT_EXPIRE: &str = "exp";

/// Claim name for the unique token identifier.
#[allow(dead_code)]
const JWT_ID: &str = "jti";

/// Claim name for the issued-at time.
#[allow(dead_code)]
const JWT_ISSUED_AT: &str = "iat";

/// Claim name for the token issuer.
#[allow(dead_code)]
const JWT_ISSUER: &str = "iss";

/// Claim name for the not-before time.
#[allow(dead_code)]
const JWT_NOT_BEFORE: &str = "nbf";

/// Claim name for the token subject.
#[allow(dead_code)]
const JWT_SUBJECT: &str = "sub";
25
26/// Unauthorized callback signature.
27pub type UnauthorizedCallback =
28    Arc<dyn Fn(&mut http::Response<Body>, &http::Request<Body>, &anyhow::Error) + Send + Sync>;
29
30/// Options for authorize middleware.
31#[derive(Clone, Default)]
32pub struct AuthorizeOptions {
33    pub prev_secret: Option<String>,
34    pub callback: Option<UnauthorizedCallback>,
35}
36
37/// Enable previous secret for token transition.
38pub fn with_prev_secret(secret: impl Into<String>) -> impl Fn(&mut AuthorizeOptions) {
39    let s = secret.into();
40    move |opts: &mut AuthorizeOptions| {
41        opts.prev_secret = Some(s.clone());
42    }
43}
44
45/// Set unauthorized callback.
46pub fn with_unauthorized_callback(
47    callback: UnauthorizedCallback,
48) -> impl Fn(&mut AuthorizeOptions) {
49    move |opts: &mut AuthorizeOptions| {
50        opts.callback = Some(callback.clone());
51    }
52}
53
54/// Authorize middleware: validates JWT Bearer token with secret/prev_secret.
55pub fn authorize(
56    secret: impl Into<String>,
57    opts: impl IntoIterator<Item = impl Fn(&mut AuthorizeOptions)>,
58) -> Middleware {
59    let mut options = AuthorizeOptions::default();
60    for opt in opts {
61        opt(&mut options);
62    }
63    let secret = secret.into();
64    let prev_secret = options.prev_secret.clone();
65    let callback = options.callback.clone();
66
67    middleware(move |mut req: http::Request<Body>, next| {
68        let secret = secret.clone();
69        let prev_secret = prev_secret.clone();
70        let callback = callback.clone();
71        async move {
72            match validate_token(&mut req, &secret, prev_secret.as_deref()) {
73                Ok(Some(claims)) => {
74                    req.extensions_mut().insert(claims);
75                    next.call(req).await
76                }
77                Ok(None) => {
78                    unauthorized_response(anyhow::anyhow!("missing bearer token"), callback, &req)
79                }
80                Err(e) => unauthorized_response(e, callback, &req),
81            }
82        }
83    })
84}
85
86/// Extracted custom claims stored in request extensions.
87#[derive(Debug, Clone, PartialEq)]
88pub struct AuthClaims {
89    pub claims: HashMap<String, Value>,
90}
91
92impl AuthClaims {
93    pub fn get(&self, key: &str) -> Option<&Value> {
94        self.claims.get(key)
95    }
96}
97
98#[derive(Debug, Deserialize)]
99struct RawClaims(HashMap<String, Value>);
100
101fn validate_token(
102    req: &mut http::Request<Body>,
103    secret: &str,
104    prev_secret: Option<&str>,
105) -> anyhow::Result<Option<AuthClaims>> {
106    let token = extract_bearer(req)?;
107    let mut validation = Validation::new(Algorithm::HS256);
108    validation.validate_aud = false;
109    validation.validate_exp = true;
110    let decode_with = |sec: &str| {
111        decode::<RawClaims>(
112            &token,
113            &DecodingKey::from_secret(sec.as_bytes()),
114            &validation,
115        )
116    };
117
118    let decoded = decode_with(secret).or_else(|e| {
119        if let Some(prev) = prev_secret {
120            decode_with(prev).map_err(|_| e)
121        } else {
122            Err(e)
123        }
124    })?;
125
126    let claims = decoded.claims.0;
127    Ok(Some(AuthClaims { claims }))
128}
129
130fn extract_bearer(req: &http::Request<Body>) -> anyhow::Result<String> {
131    let header = req
132        .headers()
133        .get(http::header::AUTHORIZATION)
134        .ok_or_else(|| anyhow::anyhow!("missing Authorization header"))?
135        .to_str()
136        .map_err(|_| anyhow::anyhow!("invalid Authorization header"))?;
137    let parts: Vec<&str> = header.split_whitespace().collect();
138    if parts.len() != 2 || parts[0] != "Bearer" {
139        anyhow::bail!("invalid bearer token");
140    }
141    Ok(parts[1].to_string())
142}
143
144fn unauthorized_response(
145    err: anyhow::Error,
146    callback: Option<UnauthorizedCallback>,
147    req: &http::Request<Body>,
148) -> http::Response<Body> {
149    let mut resp = http::Response::builder()
150        .status(StatusCode::UNAUTHORIZED)
151        .body(Body::from(err.to_string()))
152        .unwrap();
153    if let Some(cb) = callback {
154        cb(&mut resp, req, &err);
155    }
156    resp
157}