use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use axum::body::Body;
use axum::http::{StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use axum_core::extract::FromRequestParts;
use chrono::{DateTime, Utc};
use http::request::Parts;
use serde_json::json;
use sha2::{Digest, Sha256};
use tower::{Layer, Service};
use crate::UserModel;
/// Configuration describing how an unauthenticated request is rejected.
///
/// When `login_url` is `None` the rejection is a JSON `401 Unauthorized`;
/// when it is set the rejection is a `302 Found` redirect to that URL.
#[derive(Debug, Clone)]
pub struct LoginRequired {
/// Redirect target for unauthenticated requests; `None` selects the JSON `401` response.
pub login_url: Option<String>,
/// Name of the query parameter that carries the originally requested URI
/// on the redirect; `None` redirects to `login_url` unchanged.
pub next_param: Option<String>,
}
impl LoginRequired {
    /// Configuration for API routes: no redirect, unauthenticated requests
    /// receive a JSON `401 Unauthorized`.
    pub const API: Self = Self {
        login_url: None,
        next_param: None,
    };

    /// Configuration for HTML routes: unauthenticated requests are redirected
    /// to `login_url`, with the requested URI attached as the `next` query
    /// parameter.
    pub fn html(login_url: impl Into<String>) -> Self {
        Self {
            login_url: Some(login_url.into()),
            next_param: Some("next".to_string()),
        }
    }

    /// Disables the "return here after login" query parameter.
    pub fn no_next(mut self) -> Self {
        self.next_param = None;
        self
    }

    /// Builds the response sent to an unauthenticated request for `uri`.
    ///
    /// * Without a `login_url`: `401 Unauthorized` with a JSON error body and
    ///   a `WWW-Authenticate: Bearer` header.
    /// * With a `login_url`: `302 Found` whose `Location` is the login URL,
    ///   extended with `<next_param>=<percent-encoded uri>` when `next_param`
    ///   is set.
    ///
    /// # Panics
    ///
    /// Panics if the response builder rejects the headers, e.g. when the
    /// configured `login_url` contains bytes that are not valid in a header
    /// value.
    pub(crate) fn rejection_response(&self, uri: &Uri) -> Response {
        let Some(url) = &self.login_url else {
            let body = json!({"error": "authentication required"}).to_string();
            return axum::http::Response::builder()
                .status(StatusCode::UNAUTHORIZED)
                .header("content-type", "application/json")
                .header("www-authenticate", "Bearer")
                .body(Body::from(body))
                .expect("building 401 response cannot fail")
                .into_response();
        };

        let location = match &self.next_param {
            Some(param) => {
                // The login URL may already carry a query string; a second `?`
                // would corrupt it, so continue the existing query with `&`.
                let separator = if url.contains('?') { '&' } else { '?' };
                let original = uri.to_string();
                format!("{url}{separator}{param}={}", urlencoded(original.as_str()))
            }
            None => url.clone(),
        };

        axum::http::Response::builder()
            .status(StatusCode::FOUND)
            .header("location", location)
            .body(Body::empty())
            .expect("building 302 response cannot fail")
            .into_response()
    }
}
/// Percent-encodes `s` so it can be embedded as a single query-parameter value.
///
/// Every byte is escaped as `%XX` (uppercase hex, UTF-8 for non-ASCII text)
/// except the RFC 3986 "unreserved" characters (`A-Z a-z 0-9 - . _ ~`) and
/// `/`, which is legal inside a query component and keeps redirect targets
/// readable. In particular `#`, `;`, control characters and non-ASCII input
/// are escaped, so the value cannot terminate the query early or inject
/// header-breaking bytes.
fn urlencoded(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' | b'/' => {
                out.push(char::from(byte));
            }
            _ => {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    out
}
/// Extractor wrapper holding the user resolved for the current request.
///
/// Extraction fails with the configured rejection response when no user can
/// be resolved from the request headers.
pub struct LoggedIn<U: UserModel>(pub U);
/// Gives read access to the wrapped user through the `LoggedIn` wrapper.
impl<U: UserModel> std::ops::Deref for LoggedIn<U> {
    type Target = U;

    fn deref(&self) -> &U {
        let Self(user) = self;
        user
    }
}
/// Gives mutable access to the wrapped user through the `LoggedIn` wrapper.
impl<U: UserModel> std::ops::DerefMut for LoggedIn<U> {
    fn deref_mut(&mut self) -> &mut U {
        let Self(user) = self;
        user
    }
}
/// Serializes exactly like the wrapped user; the wrapper adds no framing.
impl<U: UserModel + serde::Serialize> serde::Serialize for LoggedIn<U> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let Self(user) = self;
        <U as serde::Serialize>::serialize(user, serializer)
    }
}
/// Extracts the logged-in user from the request headers.
///
/// On failure the rejection is built from the `LoginRequired` value stored in
/// the request extensions, falling back to `LoginRequired::API` when none is
/// present.
impl<U, S> FromRequestParts<S> for LoggedIn<U>
where
    U: UserModel
        + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
        + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
        + umbral::orm::HydrateRelated
        + Unpin
        + Send,
    <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
    S: Send + Sync,
{
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        if let Some(user) = resolve_user::<U>(&parts.headers).await {
            return Ok(LoggedIn(user));
        }

        // Only the rejection path needs the configuration, so it is looked up
        // here rather than cloned up front.
        let rejection = match parts.extensions.get::<LoginRequired>() {
            Some(configured) => configured.rejection_response(&parts.uri),
            None => LoginRequired::API.rejection_response(&parts.uri),
        };
        Err(rejection)
    }
}
/// Returns the lowercase hexadecimal SHA-256 digest of `raw`.
fn hash_token(raw: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(raw.as_bytes());
    let digest = hasher.finalize();
    format!("{digest:x}")
}
/// Returns the value of the `umbral_session` cookie, if the request carries one.
///
/// Every `Cookie` header field is inspected, not just the first: HTTP/2
/// clients are allowed to split their cookies across several header fields.
/// Header values that are not valid visible ASCII are skipped instead of
/// aborting the whole lookup. The first matching cookie wins.
fn cookie_from_headers(headers: &http::HeaderMap) -> Option<String> {
    headers
        .get_all(http::header::COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|line| line.split(';'))
        .find_map(|pair| pair.trim().strip_prefix("umbral_session="))
        .map(str::to_owned)
}
/// Loads the active user that owns the session referenced by `headers`.
///
/// Returns `None` when there is no usable session, when the query fails, or
/// when no row matches both the session's user id and `is_active = true`.
pub async fn resolve_user<U>(headers: &http::HeaderMap) -> Option<U>
where
    U: UserModel
        + for<'r> sqlx::FromRow<'r, sqlx::sqlite::SqliteRow>
        + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
        + umbral::orm::HydrateRelated
        + Unpin
        + Send,
    <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
{
    let pk = current_session_user_pk::<U>(headers).await?;

    // Match on the primary key and require the account to be active.
    let active_owner = umbral::orm::Predicate::<U>::col_eq("id", pk)
        & umbral::orm::Predicate::<U>::col_eq("is_active", true);

    // Query errors are deliberately treated the same as "no such user".
    match umbral::orm::Manager::<U>::default()
        .filter(active_owner)
        .first()
        .await
    {
        Ok(found) => found,
        Err(_) => None,
    }
}
/// Resolves the primary key of the user attached to the request's session.
///
/// The session cookie is hashed and looked up as the `id` of a `SessionRow`.
/// Returns `None` when the cookie is missing, the lookup fails or finds
/// nothing, the session's `expires_at` is in the past, the session has no
/// `user_id`, or that id does not parse as `U`'s primary key type.
pub async fn current_session_user_pk<U>(
    headers: &http::HeaderMap,
) -> Option<<U as umbral::orm::Model>::PrimaryKey>
where
    U: UserModel,
    <U as umbral::orm::Model>::PrimaryKey: std::str::FromStr,
{
    let token = cookie_from_headers(headers)?;
    let session_key = hash_token(&token);

    let lookup = umbral::orm::Manager::<SessionRow>::default()
        .filter(umbral::orm::Predicate::<SessionRow>::col_eq("id", session_key))
        .first()
        .await;

    // A failed query and a missing row are both treated as "no session".
    let Ok(Some(session)) = lookup else {
        return None;
    };

    if session.expires_at < Utc::now() {
        return None;
    }

    session.user_id.and_then(|raw_pk| raw_pk.parse().ok())
}
/// Reports whether `headers` reference a session that yields a user id.
pub(crate) async fn is_authenticated(headers: &http::HeaderMap) -> bool {
    let session_user = current_session_user_id(headers).await;
    session_user.is_some()
}
/// Resolves the session's user id as an `i64`, using `crate::AuthUser` as the
/// user model.
pub async fn current_session_user_id(headers: &http::HeaderMap) -> Option<i64> {
    let pk: Option<i64> = current_session_user_pk::<crate::AuthUser>(headers).await;
    pk
}
/// Row of the `session` table, as read when resolving the current session.
#[doc(hidden)]
#[derive(Debug, Clone, sqlx::FromRow, serde::Serialize, serde::Deserialize, umbral::orm::Model)]
#[umbral(table = "session")]
pub struct SessionRow {
/// Session key; lookups compare it against the SHA-256 hex digest of the cookie value.
pub id: String,
/// Owning user's primary key in string form; parsed into the user model's key type on lookup.
pub user_id: Option<String>,
/// Session payload. NOTE(review): not read anywhere in this file — format unconfirmed.
pub data: String,
/// NOTE(review): presumably the creation time; not read anywhere in this file.
pub created_at: DateTime<Utc>,
/// Expiry instant; a session whose `expires_at` is earlier than the current UTC time is rejected.
pub expires_at: DateTime<Utc>,
}
/// Tower layer that wraps a service in `LoginRequiredService`.
#[derive(Clone)]
pub struct LoginRequiredLayer {
// Rejection configuration cloned into every service this layer produces.
config: LoginRequired,
}
impl LoginRequiredLayer {
pub fn new(config: LoginRequired) -> Self {
Self { config }
}
pub fn apply(self, router: axum::Router) -> axum::Router {
router.layer(self)
}
}
/// Produces a `LoginRequiredService` around `inner`, carrying a copy of the
/// layer's configuration.
impl<S> Layer<S> for LoginRequiredLayer {
    type Service = LoginRequiredService<S>;

    fn layer(&self, inner: S) -> LoginRequiredService<S> {
        let config = self.config.clone();
        LoginRequiredService { inner, config }
    }
}
/// Middleware service that lets only authenticated requests reach `inner`.
#[derive(Clone)]
pub struct LoginRequiredService<S> {
// The wrapped service that handles requests passing the authentication check.
inner: S,
// Decides the rejection response and is attached to forwarded requests as an extension.
config: LoginRequired,
}
/// Gate in front of `inner`: unauthenticated requests are answered directly
/// with the configured rejection response, authenticated ones are forwarded
/// with the `LoginRequired` configuration attached as a request extension.
impl<S> Service<axum::extract::Request> for LoginRequiredService<S>
where
    S: Service<axum::extract::Request, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send + 'static>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: axum::extract::Request) -> Self::Future {
        let config = self.config.clone();

        // `poll_ready` was driven on `self.inner`, so that exact instance must
        // handle the request; a fresh clone may not be ready. Move the ready
        // service into the future and leave the clone behind for the next
        // `poll_ready`/`call` cycle.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);

        Box::pin(async move {
            let uri = req.uri().clone();
            if !is_authenticated(req.headers()).await {
                return Ok(config.rejection_response(&uri));
            }
            // Expose the configuration to downstream extractors so they
            // reject with the same response shape as this layer.
            req.extensions_mut().insert(config);
            inner.call(req).await
        })
    }
}
/// Layer for API routes: unauthenticated requests receive the JSON `401`
/// described by `LoginRequired::API`.
pub fn login_required() -> LoginRequiredLayer {
    let config = LoginRequired::API;
    LoginRequiredLayer::new(config)
}
/// Layer for HTML routes: unauthenticated requests are redirected to
/// `login_url` using the `LoginRequired::html` configuration.
pub fn login_required_html(login_url: impl Into<String>) -> LoginRequiredLayer {
    let config = LoginRequired::html(login_url);
    LoginRequiredLayer::new(config)
}