//! Provides the authentication-related extensions to the Humphrey app.
use crate::database::AuthDatabase;
use crate::AuthProvider;
use humphrey::http::{Request, Response, StatusCode};
use humphrey::App;
use std::sync::{Arc, MutexGuard};
/// A state type that exposes an `AuthProvider`.
///
/// Implement this on the application state to make authenticated routes available.
///
/// # Example
/// ```
/// type DatabaseWrapper = Arc<RwLock<MyDatabase>>;
///
/// struct MyState {
///     db: DatabaseWrapper,
///     auth_provider: Mutex<AuthProvider<DatabaseWrapper>>
/// }
///
/// impl AuthState<DatabaseWrapper> for MyState {
///     fn auth_provider(&self) -> MutexGuard<AuthProvider<DatabaseWrapper>> {
///         self.auth_provider.lock().unwrap()
///     }
/// }
/// ```
pub trait AuthState<D: AuthDatabase> {
    /// Gives access to the `AuthProvider` through a `MutexGuard`.
    ///
    /// The guard borrows from `self`.
    fn auth_provider(&self) -> MutexGuard<'_, AuthProvider<D>>;
}
/// A callable that can serve an authenticated request.
///
/// It receives the request, the shared state, and the UID of the authenticated user,
/// and produces the response. Any `Fn(Request, Arc<S>, String) -> Response` that is also
/// `Send + Sync` implements this trait automatically through the blanket impl below.
///
/// # Example
/// ```
/// fn auth_req_handler(_: Request, state: Arc<MyState>, uid: String) -> Response {
///     Response::new(StatusCode::OK, uid)
/// }
/// ```
pub trait AuthRequestHandler<S>:
    Fn(Request, Arc<S>, String) -> Response + Send + Sync
{
}

/// Blanket implementation: every suitable function or closure is an `AuthRequestHandler`.
impl<F, S> AuthRequestHandler<S> for F
where
    F: Fn(Request, Arc<S>, String) -> Response + Send + Sync,
{
}
/// A Humphrey application that supports authenticated routes.
///
/// This is implemented for Humphrey's `App` type whenever the state implements `AuthState`
/// and the database implements `AuthDatabase`.
pub trait AuthApp<S: AuthState<D>, D: AuthDatabase> {
    /// Registers `handler` for `route`, to be called only for authenticated requests.
    ///
    /// Routes can include wildcards, such as `/blog/*`.
    /// The app is taken by value and returned so that calls can be chained.
    fn with_auth_route<T: AuthRequestHandler<S> + 'static>(self, route: &str, handler: T) -> Self;
}
impl<S, D> AuthApp<S, D> for App<S>
where
    S: AuthState<D> + Send + Sync,
    D: AuthDatabase,
{
    /// Registers `handler` on `route` behind a token check.
    ///
    /// For each request the `HumphreyToken` cookie is read and its value is passed to the
    /// auth provider's `get_uid_by_token`. When that returns `Ok(uid)`, the handler is called
    /// with the request, the state and the UID. When the cookie is missing or the lookup
    /// fails, the response from `forbidden()` is returned instead.
    fn with_auth_route<T>(self, route: &str, handler: T) -> Self
    where
        T: AuthRequestHandler<S> + 'static,
    {
        self.with_route(route, move |request: Request, state: Arc<S>| {
            // No token cookie at all: reject straight away.
            let token = match request.get_cookie("HumphreyToken") {
                Some(cookie) => cookie.value,
                None => return forbidden(),
            };

            // Keep this as a separate `let` statement: the temporary `MutexGuard` returned by
            // `auth_provider()` is dropped at the end of the statement, so it is no longer
            // held (and `state` is no longer borrowed) when the handler runs.
            let lookup = state.auth_provider().get_uid_by_token(token);

            match lookup {
                Ok(uid) => handler(request, state, uid),
                Err(_) => forbidden(),
            }
        })
    }
}
/// Builds the response returned when a request to an authenticated route is rejected.
///
/// NOTE(review): despite the function's name, the status used is `401 Unauthorized`,
/// not `403 Forbidden`.
fn forbidden() -> Response {
    let body = "401 Unauthorized";
    Response::new(StatusCode::Unauthorized, body)
}