Skip to main content

modo_auth/
extractor.rs

1use crate::cache::ResolvedUser;
2use crate::provider::UserProviderService;
3use modo::app::AppState;
4use modo::axum::extract::FromRequestParts;
5use modo::axum::http::request::Parts;
6use modo::{Error, HttpError};
7use modo_session::SessionManager;
8use std::ops::Deref;
9use std::sync::Arc;
10
11/// Resolve the authenticated user, checking the extension cache first.
12///
13/// Returns `Ok(None)` when no session or user exists.
14/// Returns `Err` for infrastructure failures (missing middleware/service, DB errors).
15async fn resolve_user<U: Clone + Send + Sync + 'static>(
16    parts: &mut Parts,
17    state: &AppState,
18) -> Result<Option<U>, Error> {
19    // Fast path: user already resolved by UserContextLayer or a prior extractor
20    if let Some(cached) = parts.extensions.get::<ResolvedUser<U>>() {
21        tracing::debug!(cache_hit = true, "auth user resolved from extension cache");
22        return Ok(Some((*cached.0).clone()));
23    }
24
25    let session = SessionManager::from_request_parts(parts, state)
26        .await
27        .map_err(|_| Error::internal("Auth requires session middleware"))?;
28
29    let user_id = match session.user_id().await {
30        Some(id) => id,
31        None => {
32            tracing::debug!("no session user_id, skipping auth resolution");
33            return Ok(None);
34        }
35    };
36
37    let provider = state
38        .services
39        .get::<UserProviderService<U>>()
40        .ok_or_else(|| {
41            Error::internal(format!(
42                "UserProviderService<{}> not registered",
43                std::any::type_name::<U>()
44            ))
45        })?;
46
47    let user = provider.find_by_id(&user_id).await?;
48
49    if let Some(ref u) = user {
50        tracing::debug!(user_id = %user_id, cache_hit = false, "auth user resolved from provider");
51        parts.extensions.insert(ResolvedUser(Arc::new(u.clone())));
52    } else {
53        tracing::warn!(user_id = %user_id, "session references non-existent user");
54    }
55
56    Ok(user)
57}
58
/// Extractor that rejects the request unless an authenticated user is present.
///
/// The user ID stored in the session is looked up through the registered
/// [`UserProviderService<U>`]. A successfully loaded user is cached in the
/// request extensions, so further `Auth<U>` / `OptionalAuth<U>` extractions in
/// the same request reuse it instead of querying the provider again.
///
/// Rejections:
/// - `401 Unauthorized` — the session has no user ID, or the provider does not
///   find the user.
/// - `500 Internal Server Error` — session middleware or
///   [`UserProviderService<U>`] is not registered, or the provider returns an
///   error.
#[derive(Clone)]
pub struct Auth<U>(
    /// The authenticated user.
    pub U,
)
where
    U: Clone + Send + Sync + 'static;
73
74impl<U: Clone + Send + Sync + 'static> Deref for Auth<U> {
75    type Target = U;
76
77    fn deref(&self) -> &Self::Target {
78        &self.0
79    }
80}
81
82impl<U: Clone + Send + Sync + 'static> FromRequestParts<AppState> for Auth<U> {
83    type Rejection = Error;
84
85    async fn from_request_parts(
86        parts: &mut Parts,
87        state: &AppState,
88    ) -> Result<Self, Self::Rejection> {
89        let user = resolve_user::<U>(parts, state)
90            .await?
91            .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
92        Ok(Auth(user))
93    }
94}
95
/// Extractor that loads the authenticated user when there is one.
///
/// Yields `OptionalAuth(Some(user))` when the session identifies a user the
/// provider can find, and `OptionalAuth(None)` when the session carries no
/// user ID or the provider has no matching user; the handler runs in both
/// cases.
///
/// **Caveat:** "optional" covers only the *absence* of a user. The request is
/// still rejected with `500 Internal Server Error` when infrastructure is
/// misconfigured (session middleware or [`UserProviderService<U>`] not
/// registered) or when the provider returns a hard error (e.g. a database
/// connection failure).
#[derive(Clone)]
pub struct OptionalAuth<U>(
    /// `Some(user)` when authenticated, `None` otherwise.
    pub Option<U>,
)
where
    U: Clone + Send + Sync + 'static;
113
114impl<U: Clone + Send + Sync + 'static> Deref for OptionalAuth<U> {
115    type Target = Option<U>;
116
117    fn deref(&self) -> &Self::Target {
118        &self.0
119    }
120}
121
122impl<U: Clone + Send + Sync + 'static> FromRequestParts<AppState> for OptionalAuth<U> {
123    type Rejection = Error;
124
125    async fn from_request_parts(
126        parts: &mut Parts,
127        state: &AppState,
128    ) -> Result<Self, Self::Rejection> {
129        let user = resolve_user::<U>(parts, state).await?;
130        Ok(OptionalAuth(user))
131    }
132}