Skip to main content

authx_axum/
middleware.rs

1use std::sync::Arc;
2
3use axum::response::Response;
4use tower::{Layer, Service};
5
6use authx_core::{crypto::sha256_hex, identity::Identity};
7use authx_storage::ports::{SessionRepository, UserRepository};
8
/// Name of the custom request header that may carry the raw session token.
const SESSION_HEADER: &str = "x-authx-token";
/// Name of the cookie that may carry the raw session token.
const SESSION_COOKIE: &str = "authx_session";
11
12// ── Public Layer ─────────────────────────────────────────────────────────────
13
14/// Tower [`Layer`] that resolves session tokens into [`Identity`] extensions.
15///
16/// Add this to your router **after** all routes. Unauthenticated requests pass
17/// through; use [`RequireAuth`] on individual routes to enforce auth.
18///
19/// ```rust,ignore
20/// let app = Router::new()
21///     .route("/me", get(me))
22///     .layer(SessionLayer::new(store));
23/// ```
24#[derive(Clone)]
25pub struct SessionLayer<S> {
26    storage: Arc<S>,
27}
28
29impl<S> SessionLayer<S>
30where
31    S: SessionRepository + UserRepository + Clone + Send + Sync + 'static,
32{
33    pub fn new(storage: S) -> Self {
34        Self {
35            storage: Arc::new(storage),
36        }
37    }
38}
39
40impl<S, Svc> Layer<Svc> for SessionLayer<S>
41where
42    S: SessionRepository + UserRepository + Clone + Send + Sync + 'static,
43{
44    type Service = SessionService<S, Svc>;
45
46    fn layer(&self, inner: Svc) -> Self::Service {
47        SessionService {
48            storage: Arc::clone(&self.storage),
49            inner,
50        }
51    }
52}
53
54// ── Inner Service ─────────────────────────────────────────────────────────────
55
/// Tower [`Service`] produced by [`SessionLayer`].
///
/// Holds the shared store used to resolve session tokens and the downstream
/// service every request is ultimately forwarded to.
#[derive(Clone)]
pub struct SessionService<S, Svc> {
    /// Shared session/user store.
    storage: Arc<S>,
    /// The wrapped downstream service.
    inner: Svc,
}
61
62impl<S, Svc, ReqBody> Service<axum::http::Request<ReqBody>> for SessionService<S, Svc>
63where
64    S: SessionRepository + UserRepository + Clone + Send + Sync + 'static,
65    Svc: Service<axum::http::Request<ReqBody>, Response = Response> + Clone + Send + 'static,
66    Svc::Future: Send + 'static,
67    ReqBody: Send + 'static,
68{
69    type Response = Response;
70    type Error = Svc::Error;
71    type Future =
72        std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Svc::Error>> + Send>>;
73
74    fn poll_ready(
75        &mut self,
76        cx: &mut std::task::Context<'_>,
77    ) -> std::task::Poll<Result<(), Self::Error>> {
78        self.inner.poll_ready(cx)
79    }
80
81    fn call(&mut self, mut req: axum::http::Request<ReqBody>) -> Self::Future {
82        let storage = Arc::clone(&self.storage);
83        let mut inner = self.inner.clone();
84
85        Box::pin(async move {
86            let token_hash = extract_token(&req).map(|t| sha256_hex(t.as_bytes()));
87
88            if let Some(hash) = token_hash
89                && let Some(identity) = resolve_identity(&*storage, &hash).await
90            {
91                req.extensions_mut().insert(identity);
92                tracing::debug!("identity resolved");
93            }
94
95            inner.call(req).await
96        })
97    }
98}
99
100// ── Helpers ───────────────────────────────────────────────────────────────────
101
102async fn resolve_identity<S>(storage: &S, token_hash: &str) -> Option<Identity>
103where
104    S: SessionRepository + UserRepository + Clone + Send + Sync + 'static,
105{
106    let session = storage.find_by_token_hash(token_hash).await.ok()??;
107    if session.expires_at < chrono::Utc::now() {
108        tracing::debug!(session_id = %session.id, "session expired");
109        return None;
110    }
111    let user = storage.find_by_id(session.user_id).await.ok()??;
112    Some(Identity::new(user, session))
113}
114
115fn extract_token<B>(request: &axum::http::Request<B>) -> Option<String> {
116    if let Some(bearer) = request
117        .headers()
118        .get(axum::http::header::AUTHORIZATION)
119        .and_then(|v| v.to_str().ok())
120        .and_then(|v| v.strip_prefix("Bearer "))
121    {
122        return Some(bearer.to_owned());
123    }
124
125    if let Some(token) = request
126        .headers()
127        .get(SESSION_HEADER)
128        .and_then(|v| v.to_str().ok())
129    {
130        return Some(token.to_owned());
131    }
132
133    let cookie_header = request
134        .headers()
135        .get(axum::http::header::COOKIE)
136        .and_then(|v| v.to_str().ok())?;
137
138    for part in cookie_header.split(';') {
139        let part = part.trim();
140        if let Some(value) = part.strip_prefix(&format!("{SESSION_COOKIE}=")) {
141            return Some(value.to_owned());
142        }
143    }
144
145    None
146}