Skip to main content

oxide_framework_core/auth/
layer.rs

//! Tower layer: decode a JWT taken from an `Authorization: Bearer` header and/or a session cookie, and attach the resulting [`super::AuthClaims`] to the request extensions.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use axum::body::Body;
9use axum::http::header::{AUTHORIZATION, COOKIE};
10use axum::http::{Request, StatusCode};
11use axum::Json;
12use axum::response::{IntoResponse, Response};
13use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
14use tower::{Layer, Service};
15use tracing::debug;
16
17use super::claims::AuthClaims;
18use super::config::AuthConfig;
19
20/// Tower `Layer` that validates JWTs and inserts [`AuthClaims`] into request extensions.
21#[derive(Clone)]
22pub struct AuthLayer {
23    config: Arc<AuthConfig>,
24    key: DecodingKey,
25    validation: Validation,
26}
27
28impl AuthLayer {
29    pub fn new(config: AuthConfig) -> Self {
30        let config = Arc::new(config);
31        let key = DecodingKey::from_secret(&config.secret);
32        let mut validation = Validation::new(Algorithm::HS256);
33        if let Some(ref iss) = config.issuer {
34            validation.set_issuer(&[iss.as_str()]);
35        }
36        if let Some(ref aud) = config.audience {
37            validation.set_audience(&[aud.as_str()]);
38        }
39        Self {
40            config,
41            key,
42            validation,
43        }
44    }
45}
46
47impl<S> Layer<S> for AuthLayer {
48    type Service = AuthService<S>;
49
50    fn layer(&self, inner: S) -> Self::Service {
51        AuthService {
52            inner,
53            config: self.config.clone(),
54            key: self.key.clone(),
55            validation: self.validation.clone(),
56        }
57    }
58}
59
60/// Inner service that decodes JWT and forwards the request.
61#[derive(Clone)]
62pub struct AuthService<S> {
63    inner: S,
64    config: Arc<AuthConfig>,
65    key: DecodingKey,
66    validation: Validation,
67}
68
69impl<S> Service<Request<Body>> for AuthService<S>
70where
71    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
72    S::Error: Send + 'static,
73    S::Future: Send + 'static,
74{
75    type Response = Response;
76    type Error = S::Error;
77    type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send>>;
78
79    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80        self.inner.poll_ready(cx)
81    }
82
83    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
84        let config = self.config.clone();
85        let key = self.key.clone();
86        let validation = self.validation.clone();
87        let mut inner = self.inner.clone();
88
89        Box::pin(async move {
90            let token_result = resolve_token(&config, req.headers());
91
92            match token_result {
93                TokenResolution::None => inner.call(req).await,
94                TokenResolution::Some(token) => match decode::<AuthClaims>(&token, &key, &validation) {
95                    Ok(data) => {
96                        req.extensions_mut().insert(data.claims);
97                        inner.call(req).await
98                    }
99                    Err(e) => {
100                        debug!(error = %e, "jwt validation failed");
101                        Ok(auth_error_response(
102                            StatusCode::UNAUTHORIZED,
103                            "invalid or expired token",
104                        ))
105                    }
106                },
107                TokenResolution::Malformed => Ok(auth_error_response(
108                    StatusCode::UNAUTHORIZED,
109                    "malformed authorization",
110                )),
111            }
112        })
113    }
114}
115
/// Result of searching an incoming request for a JWT.
enum TokenResolution {
    /// Neither a Bearer token nor a session cookie was found, so the request is anonymous.
    None,
    /// A candidate JWT string that still has to pass signature/claim validation.
    Some(String),
    /// An `Authorization` header was present but unusable (wrong scheme, empty token, or
    /// non-text value).
    Malformed,
}
124
125fn resolve_token(config: &AuthConfig, headers: &axum::http::HeaderMap) -> TokenResolution {
126    if config.bearer_token {
127        if let Some(auth) = headers.get(AUTHORIZATION) {
128            match auth.to_str() {
129                Ok(s) => {
130                    let s = s.trim();
131                    if let Some(rest) = s.strip_prefix("Bearer ").or_else(|| s.strip_prefix("bearer ")) {
132                        let t = rest.trim();
133                        if t.is_empty() {
134                            return TokenResolution::Malformed;
135                        }
136                        return TokenResolution::Some(t.to_string());
137                    }
138                    // Authorization present but not Bearer — reject
139                    return TokenResolution::Malformed;
140                }
141                Err(_) => return TokenResolution::Malformed,
142            }
143        }
144    }
145
146    if let Some(name) = config.session_cookie_name.as_deref() {
147        if let Some(cookie_hdr) = headers.get(COOKIE).and_then(|v| v.to_str().ok()) {
148            if let Some(tok) = cookie_token(cookie_hdr, name) {
149                return TokenResolution::Some(tok);
150            }
151        }
152    }
153
154    TokenResolution::None
155}
156
157fn cookie_token(header: &str, name: &str) -> Option<String> {
158    for c in cookie::Cookie::split_parse(header).flatten() {
159        if c.name() == name {
160            let v = c.value();
161            if !v.is_empty() {
162                return Some(v.to_string());
163            }
164        }
165    }
166    None
167}
168
169fn auth_error_response(status: StatusCode, message: &str) -> Response {
170    let body = serde_json::json!({
171        "status": status.as_u16(),
172        "error": message,
173    });
174    (status, Json(body)).into_response()
175}
176