umbral_auth/login_required.rs
1//! A login-required gate for umbral handlers.
2//!
3//! Gates a view behind authentication. umbral ships the idea in two
4//! composable shapes:
5//!
6//! - [`LoggedIn<U>`] — a per-handler axum extractor. Drop it in a handler
7//! signature and the handler only runs when a valid session exists.
8//! - [`LoginRequiredLayer`] — a per-Router tower middleware layer. Every
9//! route in the wrapped subtree is gated; unauthenticated requests never
10//! reach the inner handler.
11//!
12//! Both shapes share [`LoginRequired`] for the redirect vs. 401 fork.
13//!
14//! ## Design decisions
15//!
//! - `LoggedIn<U: UserModel>` is **fully generic** over the user model.
//!   Resolving the session is a short chain of direct logic (read the
//!   cookie, hash it, query the session table, hydrate `U`), so keeping it
//!   generic costs little: a custom user model (`TenantUser` etc.) can use
//!   `LoggedIn<TenantUser>` without any wrapper or code duplication.
21//!
22//! - The `LoginRequired` config is read from `request.extensions()` when
23//! set by `LoginRequiredLayer`, or falls back to `LoginRequired::API`
24//! (401 JSON) if the extractor is used directly without the layer.
25//!
26//! - `LoginRequiredLayer` implements `tower::Layer<S>` directly so it
27//! works with `Router::layer(login_required())` and
28//! `Router::layer(login_required_html("/login"))` without extra
29//! wrapping.
30//!
31//! - The layer gate does NOT load the full user struct — it checks only
32//! the session table (`user_id IS NOT NULL AND expires_at > now`). The
33//! `LoggedIn<U>` extractor does the full hydration. This avoids the `U`
34//! bound at the layer level, so `login_required()` works with any user
35//! model without a type parameter on the layer.
36//!
37//! ## Deferred
38//!
39//! - `permission_required(perm)` and `staff_member_required` are deferred
40//! pending gap 33 (groups + content-type model). They can be added as
41//! thin wrappers once permission objects exist.
42
43use std::future::Future;
44use std::pin::Pin;
45use std::task::{Context, Poll};
46
47use axum::body::Body;
48use axum::http::{StatusCode, Uri};
49use axum::response::{IntoResponse, Response};
50use axum_core::extract::FromRequestParts;
51use chrono::{DateTime, Utc};
52use http::request::Parts;
53use serde_json::json;
54use sha2::{Digest, Sha256};
55use tower::{Layer, Service};
56
57use crate::UserModel;
58
59// =========================================================================
60// LoginRequired — shared config struct
61// =========================================================================
62
63/// Configuration shared by both the extractor and the middleware.
64///
65/// Controls whether an unauthenticated request gets a JSON 401 (REST/API
66/// behaviour) or a 302 redirect to a login page (server-rendered HTML
67/// behaviour).
68#[derive(Debug, Clone)]
69pub struct LoginRequired {
70 /// `None` = return 401 JSON. `Some("/login")` = 302 to
71 /// `login_url?next=<uri>`.
72 pub login_url: Option<String>,
73 /// The query-string parameter name to append with the original URI.
74 /// `Some("next")` appends `?next=<uri>`; `None` redirects without it.
75 /// Only used when `login_url` is `Some`.
76 pub next_param: Option<String>,
77}
78
79impl LoginRequired {
80 /// API/REST shape: return a JSON 401 with a `WWW-Authenticate: Bearer`
81 /// header.
82 pub const API: Self = Self {
83 login_url: None,
84 next_param: None,
85 };
86
87 /// HTML shape: redirect to `login_url?next=<original-uri>`. The `next`
88 /// parameter is named `"next"` by default.
89 pub fn html(login_url: impl Into<String>) -> Self {
90 Self {
91 login_url: Some(login_url.into()),
92 next_param: Some("next".to_string()),
93 }
94 }
95
96 /// Drop the `next` parameter from the redirect.
97 pub fn no_next(mut self) -> Self {
98 self.next_param = None;
99 self
100 }
101
102 /// Build the rejection response.
103 pub(crate) fn rejection_response(&self, uri: &Uri) -> Response {
104 match &self.login_url {
105 None => {
106 let body = json!({"error": "authentication required"}).to_string();
107 axum::http::Response::builder()
108 .status(StatusCode::UNAUTHORIZED)
109 .header("content-type", "application/json")
110 .header("www-authenticate", "Bearer")
111 .body(Body::from(body))
112 .expect("building 401 response cannot fail")
113 .into_response()
114 }
115 Some(url) => {
116 let location = match &self.next_param {
117 Some(param) => {
118 let original = uri.to_string();
119 format!("{url}?{param}={}", urlencoded(original.as_str()))
120 }
121 None => url.clone(),
122 };
123 axum::http::Response::builder()
124 .status(StatusCode::FOUND)
125 .header("location", location)
126 .body(Body::empty())
127 .expect("building 302 response cannot fail")
128 .into_response()
129 }
130 }
131 }
132}
133
134/// Percent-encode a URI for safe embedding in a query-string value.
135fn urlencoded(s: &str) -> String {
136 let mut out = String::with_capacity(s.len());
137 for c in s.chars() {
138 match c {
139 '?' => out.push_str("%3F"),
140 '&' => out.push_str("%26"),
141 '=' => out.push_str("%3D"),
142 '+' => out.push_str("%2B"),
143 '%' => out.push_str("%25"),
144 ' ' => out.push_str("%20"),
145 c => out.push(c),
146 }
147 }
148 out
149}
150
151// =========================================================================
152// LoggedIn<U> extractor
153// =========================================================================
154
155/// Per-handler axum extractor that resolves the session cookie into a user
156/// of type `U`.
157///
158/// ```rust,ignore
159/// use umbral_auth::{AuthUser, login_required::LoggedIn};
160///
161/// async fn dashboard(LoggedIn(user): LoggedIn<AuthUser>) -> String {
162/// format!("Hello, {}!", user.username())
163/// }
164/// ```
165///
166/// If no valid session exists the extractor returns the configured rejection
167/// response. The config is read from `request.extensions()` (set by
168/// [`LoginRequiredLayer`]) or falls back to [`LoginRequired::API`].
169pub struct LoggedIn<U: UserModel>(pub U);
170
171// `LoggedIn` is a tuple-newtype around `U`. Drop in `Deref` /
172// `DerefMut` (so `user.username()` works directly without the
173// `.0`) and `Serialize` (so it slots into template contexts via
174// `context!(user)` without `user.0`). Closes BUG-18 from
175// bugs/tests/testBugs.md — the original ergonomic gap that
176// pushed test code to write `let username = user.0.username();`
177// for what should be the obvious shape.
178impl<U: UserModel> std::ops::Deref for LoggedIn<U> {
179 type Target = U;
180 fn deref(&self) -> &Self::Target {
181 &self.0
182 }
183}
184
185impl<U: UserModel> std::ops::DerefMut for LoggedIn<U> {
186 fn deref_mut(&mut self) -> &mut Self::Target {
187 &mut self.0
188 }
189}
190
191impl<U: UserModel + serde::Serialize> serde::Serialize for LoggedIn<U> {
192 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193 where
194 S: serde::Serializer,
195 {
196 // Forward verbatim so `LoggedIn<AuthUser>` round-trips
197 // exactly the same shape `AuthUser` would on its own.
198 self.0.serialize(serializer)
199 }
200}
201
202impl<U, S> FromRequestParts<S> for LoggedIn<U>
203where
204 U: UserModel
205 + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
206 + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
207 + umbral::orm::HydrateRelated
208 + Unpin
209 + Send,
210 // The session-row parse step is the bit that needs FromStr —
211 // an `i64`, `Uuid`, `String`, or hand-rolled PK type all
212 // implement it for free; a future PK shape with no string
213 // representation would have to override `id_string` AND
214 // provide a `FromStr` mirror to keep this extractor happy.
215 <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
216 S: Send + Sync,
217{
218 type Rejection = Response;
219
220 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
221 let config = parts
222 .extensions
223 .get::<LoginRequired>()
224 .cloned()
225 .unwrap_or(LoginRequired::API);
226
227 let uri = parts.uri.clone();
228
229 match resolve_user::<U>(&parts.headers).await {
230 Some(user) => Ok(LoggedIn(user)),
231 None => Err(config.rejection_response(&uri)),
232 }
233 }
234}
235
236// =========================================================================
237// Session resolution helpers
238// =========================================================================
239
240/// SHA-256 hash the raw session token. Mirrors `umbral-sessions`'s
241/// `hash_token`. umbral-auth must not depend on umbral-sessions (the dep
242/// arrow runs the other way), so we re-implement the trivial hash step.
243fn hash_token(raw: &str) -> String {
244 let mut h = Sha256::new();
245 h.update(raw.as_bytes());
246 format!("{:x}", h.finalize())
247}
248
249/// Extract the `umbral_session` cookie from the request headers.
250fn cookie_from_headers(headers: &http::HeaderMap) -> Option<String> {
251 let header = headers.get(http::header::COOKIE)?.to_str().ok()?;
252 for pair in header.split(';') {
253 let pair = pair.trim();
254 if let Some(value) = pair.strip_prefix("umbral_session=") {
255 return Some(value.to_string());
256 }
257 }
258 None
259}
260
261/// Load a user of type `U` from the session cookie in the given
262/// headers. The generic shape powers both [`LoggedIn`] and the
263/// public [`crate::current_user_as`] helper — apps using a custom
264/// `UserModel` reach for the latter from their own handlers when
265/// the AuthUser-flavoured [`crate::current_user`] doesn't fit.
266///
267/// **Polymorphic over `U::PrimaryKey`** — the session row stores
268/// the user PK as text (gap #59); we parse it back to the typed PK
269/// via `FromStr` before feeding it to the ORM, so a `UuidUser`
270/// stays UUID-shaped on the WHERE clause and an `AuthUser` stays
271/// `i64`-shaped. There is no `parse::<i64>()` hardcoded anywhere
272/// in the framework's session-read path; the typed PK threads
273/// through verbatim.
274///
275/// Conventions assumed: `U` has an `id` column populated by the
276/// model's PK type, and an `is_active` boolean column the filter
277/// excludes deactivated rows on. Custom user models that rename
278/// either column write their own resolver against
279/// [`umbral_sessions::current_user_id_str`] instead.
280pub async fn resolve_user<U>(headers: &http::HeaderMap) -> Option<U>
281where
282 U: UserModel
283 + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
284 + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
285 + umbral::orm::HydrateRelated
286 + Unpin
287 + Send,
288 <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
289{
290 let user_id = current_session_user_pk::<U>(headers).await?;
291 umbral::orm::Manager::<U>::default()
292 .filter(
293 umbral::orm::Predicate::<U>::col_eq("id", user_id)
294 & umbral::orm::Predicate::<U>::col_eq("is_active", true),
295 )
296 .first()
297 .await
298 .ok()
299 .flatten()
300}
301
302/// Read the request's session cookie and resolve it to the
303/// authenticated user's TYPED primary key. The generic version of
304/// [`current_session_user_id`]; this is what [`resolve_user`] and
305/// the future `permission_required_as<U>` build on.
306///
307/// Parses the text `session.user_id` (gap #59) via
308/// `<U::PrimaryKey as FromStr>::from_str`. A non-parseable value
309/// (the row was written by a different `UserModel` impl) resolves
310/// to `None` — same as missing cookie or expired session, so the
311/// caller's "anonymous" branch fires.
312pub async fn current_session_user_pk<U>(
313 headers: &http::HeaderMap,
314) -> Option<<U as umbral::orm::Model>::PrimaryKey>
315where
316 U: UserModel,
317 <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
318{
319 let raw_token = cookie_from_headers(headers)?;
320 let stored_id = hash_token(&raw_token);
321 let row: Option<SessionRow> = umbral::orm::Manager::<SessionRow>::default()
322 .filter(umbral::orm::Predicate::<SessionRow>::col_eq("id", stored_id))
323 .first()
324 .await
325 .ok()
326 .flatten();
327 let row = row?;
328 if row.expires_at < Utc::now() {
329 return None;
330 }
331 row.user_id?.parse().ok()
332}
333
334/// Check whether headers carry a valid authenticated session.
335/// Returns `true` iff a valid, non-expired, non-anonymous session is present.
336pub(crate) async fn is_authenticated(headers: &http::HeaderMap) -> bool {
337 current_session_user_id(headers).await.is_some()
338}
339
340/// Resolve the `umbral_session` cookie in `headers` to the
341/// authenticated user's `i64` PK — the AuthUser-specific shorthand
342/// for [`current_session_user_pk::<AuthUser>`]. Returns `None` for
343/// missing cookie, expired session, anonymous session, a
344/// non-parseable `user_id` (session written by a non-AuthUser
345/// model), or any sqlx error.
346///
347/// This is the primitive `permission_required` (in `umbral-permissions`)
348/// builds on. Callers using a custom user model reach for
349/// [`current_session_user_pk`] (the typed generic) or
350/// [`umbral_sessions::current_user_id_str`] (the raw string)
351/// instead — both stay polymorphic over the active user model's PK.
352pub async fn current_session_user_id(headers: &http::HeaderMap) -> Option<i64> {
353 current_session_user_pk::<crate::AuthUser>(headers).await
354}
355
356/// Private mirror of `umbral_sessions::Session`. Lives here because
357/// `umbral-auth` does not depend on `umbral-sessions` (the dep arrow runs
358/// the other way), but we still need ORM access to the `session` table.
359/// Multiple `Model` impls can target the same table — sea-query treats
360/// the schema as data, not a type-level singleton.
361#[doc(hidden)]
362#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize, serde::Deserialize, umbral::orm::Model)]
363#[umbral(table = "session")]
364pub struct SessionRow {
365 pub id: String,
366 /// Polymorphic user-PK column (gap #59). Stored as the user's PK
367 /// `Display` form — i64 for AuthUser, UUID for custom user models,
368 /// etc. Parse with `<U::PrimaryKey as FromStr>::from_str` on the
369 /// way out.
370 pub user_id: Option<String>,
371 pub data: String,
372 pub created_at: DateTime<Utc>,
373 pub expires_at: DateTime<Utc>,
374}
375
376// =========================================================================
377// LoginRequiredLayer — tower::Layer impl
378// =========================================================================
379
380/// Per-router middleware layer that gates every route in the wrapped subtree.
381///
382/// ```rust,ignore
383/// use umbral_auth::login_required::{login_required, login_required_html};
384///
385/// // REST subtree — 401 JSON on unauthenticated.
386/// let api_router = Router::new()
387/// .route("/api/me", get(me_handler))
388/// .layer(login_required());
389///
390/// // HTML subtree — 302 to /login?next=<uri>.
391/// let app_router = Router::new()
392/// .route("/dashboard", get(dashboard_handler))
393/// .layer(login_required_html("/login"));
394/// ```
395///
396/// The layer also inserts the [`LoginRequired`] config into request
397/// extensions so nested [`LoggedIn<U>`] extractors pick it up without
398/// re-declaration.
399#[derive(Clone)]
400pub struct LoginRequiredLayer {
401 config: LoginRequired,
402}
403
404impl LoginRequiredLayer {
405 /// Build a layer with an explicit config.
406 pub fn new(config: LoginRequired) -> Self {
407 Self { config }
408 }
409
410 /// Apply this layer to a Router, returning the gated router.
411 ///
412 /// ```rust,ignore
413 /// let gated = LoginRequiredLayer::new(LoginRequired::html("/login"))
414 /// .apply(my_router);
415 /// ```
416 pub fn apply(self, router: axum::Router) -> axum::Router {
417 router.layer(self)
418 }
419}
420
421impl<S> Layer<S> for LoginRequiredLayer {
422 type Service = LoginRequiredService<S>;
423
424 fn layer(&self, inner: S) -> Self::Service {
425 LoginRequiredService {
426 inner,
427 config: self.config.clone(),
428 }
429 }
430}
431
432/// The tower `Service` produced by [`LoginRequiredLayer`].
433#[derive(Clone)]
434pub struct LoginRequiredService<S> {
435 inner: S,
436 config: LoginRequired,
437}
438
439impl<S> Service<axum::extract::Request> for LoginRequiredService<S>
440where
441 S: Service<axum::extract::Request, Response = Response> + Clone + Send + 'static,
442 S::Future: Send + 'static,
443{
444 type Response = Response;
445 type Error = S::Error;
446 type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send + 'static>>;
447
448 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
449 self.inner.poll_ready(cx)
450 }
451
452 fn call(&mut self, mut req: axum::extract::Request) -> Self::Future {
453 let config = self.config.clone();
454 // Clone inner for the async block — `self.inner` is consumed
455 // by `call()` semantically and must be driven after `poll_ready`.
456 let mut inner = self.inner.clone();
457
458 Box::pin(async move {
459 let uri = req.uri().clone();
460
461 if !is_authenticated(req.headers()).await {
462 return Ok(config.rejection_response(&uri));
463 }
464
465 // Insert config so LoggedIn<U> extractors can find it.
466 req.extensions_mut().insert(config);
467
468 inner.call(req).await
469 })
470 }
471}
472
473// =========================================================================
474// Convenience constructors
475// =========================================================================
476
477/// Returns a [`LoginRequiredLayer`] configured for REST/API use (401 JSON).
478///
479/// ```rust,ignore
480/// Router::new()
481/// .route("/api/me", get(me_handler))
482/// .layer(login_required())
483/// ```
484pub fn login_required() -> LoginRequiredLayer {
485 LoginRequiredLayer::new(LoginRequired::API)
486}
487
488/// Returns a [`LoginRequiredLayer`] configured for HTML use (302 redirect).
489///
490/// ```rust,ignore
491/// Router::new()
492/// .route("/dashboard", get(dashboard_handler))
493/// .layer(login_required_html("/login"))
494/// ```
495pub fn login_required_html(login_url: impl Into<String>) -> LoginRequiredLayer {
496 LoginRequiredLayer::new(LoginRequired::html(login_url))
497}