Skip to main content

modo_auth/
extractor.rs

1use crate::cache::ResolvedUser;
2use crate::provider::UserProviderService;
3use modo::app::AppState;
4use modo::axum::extract::FromRequestParts;
5use modo::axum::http::request::Parts;
6use modo::{Error, HttpError};
7use modo_session::SessionManager;
8use std::ops::Deref;
9use std::sync::Arc;
10
11/// Resolve the authenticated user, checking the extension cache first.
12///
13/// Returns `Ok(None)` when no session or user exists.
14/// Returns `Err` for infrastructure failures (missing middleware/service, DB errors).
15async fn resolve_user<U: Clone + Send + Sync + 'static>(
16    parts: &mut Parts,
17    state: &AppState,
18) -> Result<Option<U>, Error> {
19    // Fast path: user already resolved by UserContextLayer or a prior extractor
20    if let Some(cached) = parts.extensions.get::<ResolvedUser<U>>() {
21        return Ok(Some((*cached.0).clone()));
22    }
23
24    let session = SessionManager::from_request_parts(parts, state)
25        .await
26        .map_err(|_| Error::internal("Auth requires session middleware"))?;
27
28    let user_id = match session.user_id().await {
29        Some(id) => id,
30        None => return Ok(None),
31    };
32
33    let provider = state
34        .services
35        .get::<UserProviderService<U>>()
36        .ok_or_else(|| {
37            Error::internal(format!(
38                "UserProviderService<{}> not registered",
39                std::any::type_name::<U>()
40            ))
41        })?;
42
43    let user = provider.find_by_id(&user_id).await?;
44
45    if let Some(ref u) = user {
46        parts.extensions.insert(ResolvedUser(Arc::new(u.clone())));
47    }
48
49    Ok(user)
50}
51
/// Extractor that demands an authenticated user of type `U`.
///
/// The user is loaded from the session through [`UserProviderService<U>`]. A loaded user is
/// cached in the request extensions, so further extractors in the same request reuse it
/// instead of querying the provider again.
///
/// # Rejections
///
/// - `401 Unauthorized` when the session carries no user ID or the provider cannot find
///   the user for that ID.
/// - `500 Internal Server Error` when session middleware or [`UserProviderService<U>`] is
///   not registered, or when the provider returns an error.
#[derive(Clone)]
pub struct Auth<U>(
    /// The authenticated user.
    pub U,
)
where
    U: Clone + Send + Sync + 'static;
66
67impl<U: Clone + Send + Sync + 'static> Deref for Auth<U> {
68    type Target = U;
69
70    fn deref(&self) -> &Self::Target {
71        &self.0
72    }
73}
74
75impl<U: Clone + Send + Sync + 'static> FromRequestParts<AppState> for Auth<U> {
76    type Rejection = Error;
77
78    async fn from_request_parts(
79        parts: &mut Parts,
80        state: &AppState,
81    ) -> Result<Self, Self::Rejection> {
82        let user = resolve_user::<U>(parts, state)
83            .await?
84            .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
85        Ok(Auth(user))
86    }
87}
88
/// Extractor that loads the authenticated user if there is one.
///
/// Unlike [`Auth`], a missing user is not an error. The extractor yields
/// `OptionalAuth(None)` when the session carries no user ID, or when the provider
/// finds no user for that ID. A found user is cached in request extensions just like
/// with [`Auth`].
///
/// # Rejections
///
/// This extractor still rejects with `500 Internal Server Error` when session
/// middleware or [`UserProviderService<U>`] is not registered, or when the provider
/// returns an error. Only the "no user" case is mapped to `None`.
#[derive(Clone)]
pub struct OptionalAuth<U: Clone + Send + Sync + 'static>(
    /// `Some(user)` when authenticated, `None` otherwise.
    pub Option<U>,
);
102
103impl<U: Clone + Send + Sync + 'static> Deref for OptionalAuth<U> {
104    type Target = Option<U>;
105
106    fn deref(&self) -> &Self::Target {
107        &self.0
108    }
109}
110
111impl<U: Clone + Send + Sync + 'static> FromRequestParts<AppState> for OptionalAuth<U> {
112    type Rejection = Error;
113
114    async fn from_request_parts(
115        parts: &mut Parts,
116        state: &AppState,
117    ) -> Result<Self, Self::Rejection> {
118        let user = resolve_user::<U>(parts, state).await?;
119        Ok(OptionalAuth(user))
120    }
121}