use crate::{AuthUser, auth_user};
use async_trait::async_trait;
use axum_core::extract::FromRequestParts;
use http::StatusCode;
use http::request::Parts;
use umbral::auth::{Authentication, Identity};
use umbral::web::HeaderMap;
use umbral_sessions::SessionError;
/// Resolves the active [`AuthUser`] for the session carried in `headers`.
///
/// Returns `Ok(None)` when the session layer reports no user id, when that id
/// does not parse as an `i64`, or when no row matches both the id and
/// `is_active = true`.
///
/// # Errors
///
/// Propagates failures from the session lookup and from the user query.
pub async fn current_user(headers: &HeaderMap) -> Result<Option<AuthUser>, SessionError> {
    let raw_id = match umbral_sessions::current_user_id_str(headers).await? {
        Some(raw) => raw,
        None => return Ok(None),
    };

    // A session value that is not a valid integer id is treated the same as
    // "nobody is signed in" rather than as an error.
    let user_id = match raw_id.parse::<i64>() {
        Ok(id) => id,
        Err(_) => return Ok(None),
    };

    let found: Option<AuthUser> = AuthUser::objects()
        .filter(auth_user::ID.eq(user_id) & auth_user::IS_ACTIVE.eq(true))
        .first()
        .await?;
    Ok(found)
}
/// Signs `user` in when the incoming request's headers are not at hand.
///
/// Delegates to [`login_with_request`] with an empty request header map and
/// returns whatever that call returns.
///
/// # Errors
///
/// Same as [`login_with_request`].
pub async fn login(
    response_headers: &mut HeaderMap,
    user: &AuthUser,
) -> Result<String, SessionError> {
    let no_request_headers = HeaderMap::new();
    login_with_request(&no_request_headers, response_headers, user).await
}
/// Starts a session for `user` and records the login time.
///
/// The session is created through `umbral_sessions::login_user_id`, which
/// receives both header maps and the user's id as a string; its result is
/// returned to the caller. Afterwards the user's `last_login` column is
/// updated to the current UTC time on a best-effort basis.
///
/// # Errors
///
/// Returns the session layer's error if the session cannot be created. A
/// failure to update `last_login` is only logged — the login still succeeds.
pub async fn login_with_request(
    request_headers: &HeaderMap,
    response_headers: &mut HeaderMap,
    user: &AuthUser,
) -> Result<String, SessionError> {
    let session_token = umbral_sessions::login_user_id(
        request_headers,
        response_headers,
        Some(user.id.to_string()),
    )
    .await?;

    // Single-column patch: `last_login` = now (JSON null if the timestamp
    // fails to serialize).
    let now = serde_json::to_value(chrono::Utc::now()).unwrap_or(serde_json::Value::Null);
    let changes: serde_json::Map<String, serde_json::Value> =
        std::iter::once(("last_login".to_string(), now)).collect();

    let outcome = AuthUser::objects()
        .filter(auth_user::ID.eq(user.id))
        .update_values(changes)
        .await;
    if let Err(err) = outcome {
        // The session already exists at this point, so do not fail the login.
        tracing::warn!(
            error = ?err,
            user_id = user.id,
            "umbral-auth::login: failed to update last_login (session still active)",
        );
    }

    Ok(session_token)
}
/// Stateless authentication backend keyed on the session cookie.
///
/// The type carries no data; `new()` and `Default::default()` produce the
/// same value.
#[derive(Debug, Default, Clone, Copy)]
pub struct SessionAuthentication;

impl SessionAuthentication {
    /// Creates the (field-less) backend value.
    pub fn new() -> Self {
        SessionAuthentication
    }
}
#[async_trait]
impl Authentication for SessionAuthentication {
    /// Builds an [`Identity`] from the session in `headers`.
    ///
    /// Returns `None` when [`current_user`] yields no user or fails; the
    /// error itself is discarded. On success the identity carries the user's
    /// id string, staff and superuser flags, and an `"auth": "session"` extra.
    async fn authenticate(&self, headers: &HeaderMap) -> Option<Identity> {
        let account = match current_user(headers).await {
            Ok(Some(account)) => account,
            Ok(None) | Err(_) => return None,
        };

        let principal_id = crate::UserModel::id_string(&account);
        let identity = Identity::user(principal_id)
            .with_staff(account.is_staff)
            .with_superuser(account.is_superuser)
            .with_extra("auth", serde_json::json!("session"));
        Some(identity)
    }

    /// Describes this backend as a named security scheme.
    ///
    /// The JSON is shaped like an OpenAPI `apiKey` scheme carried in a cookie
    /// named `umbral_session`.
    fn security_scheme(&self) -> Option<(String, serde_json::Value)> {
        let scheme = serde_json::json!({
            "type": "apiKey",
            "in": "cookie",
            "name": "umbral_session",
            "description": "umbral session cookie. Set by `POST /api/auth/login`; cleared by `/logout`."
        });
        Some((String::from("SessionAuth"), scheme))
    }
}
/// Request extractor wrapping the signed-in [`AuthUser`].
///
/// Its `FromRequestParts` implementation in this file rejects the request
/// with `401 Unauthorized` when no user can be resolved from the headers.
#[derive(Debug, Clone)]
pub struct User(pub AuthUser);
/// Request extractor wrapping the signed-in [`AuthUser`], if any.
///
/// Its `FromRequestParts` implementation in this file never rejects: the
/// inner value is `None` when no user can be resolved from the headers.
#[derive(Debug, Clone)]
pub struct OptionalUser(pub Option<AuthUser>);
impl<S> FromRequestParts<S> for User
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
match current_user(&parts.headers).await.ok().flatten() {
Some(u) => Ok(User(u)),
None => Err((StatusCode::UNAUTHORIZED, "authentication required")),
}
}
}
impl<S> FromRequestParts<S> for OptionalUser
where
    S: Send + Sync,
{
    type Rejection = std::convert::Infallible;

    /// Extracts the signed-in user if there is one; never rejects.
    ///
    /// A failed lookup is treated the same as "nobody is signed in".
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let maybe_user = match current_user(&parts.headers).await {
            Ok(found) => found,
            Err(_) => None,
        };
        Ok(OptionalUser(maybe_user))
    }
}
/// Middleware that makes the current user available to templates.
///
/// Clones the request headers, wraps user resolution in a
/// `umbral::templates::LazyUser`, and runs the rest of the stack inside
/// `with_current_user_lazy`. The closure resolves to the serialized user
/// (with relations) when one is signed in, and to the anonymous value when
/// nobody is signed in or the lookup fails.
// NOTE(review): presumably the closure is only invoked when a template
// actually reads the user — confirm against `LazyUser`.
pub async fn user_context_layer(
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    // Snapshot the headers now; `req` is moved into `next.run` below.
    let request_headers = req.headers().clone();

    let lazy_user = umbral::templates::LazyUser::new(move || {
        let snapshot = request_headers.clone();
        async move {
            let signed_in = current_user(&snapshot).await.ok().flatten();
            match signed_in {
                Some(account) => serialize_authenticated_with_relations(&account).await,
                None => anonymous_user_value(),
            }
        }
    });

    umbral::templates::with_current_user_lazy(lazy_user, next.run(req)).await
}
/// Depth handed to `expand_relations` when serializing the signed-in user:
/// the number of relation levels inlined below the user row.
const USER_RELATION_DEPTH: usize = 2;
/// Builds the template value for a signed-in user.
///
/// Starts from the JSON object form of `user` (an empty object if the user
/// does not serialize to a JSON object), sets `is_authenticated` to `true`,
/// and — when a model with table `auth_user` is registered — inlines related
/// rows through `expand_relations`, up to `USER_RELATION_DEPTH` levels.
async fn serialize_authenticated_with_relations(user: &AuthUser) -> umbral::templates::Value {
    let mut json = match serde_json::to_value(user) {
        Ok(serde_json::Value::Object(map)) => map,
        _ => serde_json::Map::new(),
    };
    json.insert(
        "is_authenticated".to_string(),
        serde_json::Value::Bool(true),
    );

    let registered = umbral::migrate::registered_models();
    if let Some(meta) = registered.iter().find(|m| m.table == "auth_user") {
        let mut visited: std::collections::HashSet<(String, String)> =
            std::collections::HashSet::new();
        // Seed the visited set with the user row itself so that a relation
        // pointing back at this user is not expanded a second time.
        let seed_pk = serde_json::Value::Number(user.id.into());
        visited.insert(("auth_user".to_string(), pk_json_key(&seed_pk)));
        expand_relations(
            meta,
            &registered,
            &mut json,
            USER_RELATION_DEPTH,
            &mut visited,
        )
        .await;
    }

    umbral::templates::Value::from_serialize(serde_json::Value::Object(json))
}
/// Recursively inlines related rows into `row`, at most `depth` levels deep.
///
/// Two passes are made over the model described by `meta`:
///
/// 1. **Forward:** every column that has an `fk_target` and a non-null value
///    in `row` is looked up in the target table by primary key; on a hit the
///    column's value in `row` is replaced with the fetched row object.
/// 2. **Reverse:** every registered model that has exactly one `unique`
///    column whose `fk_target` is this model's table is queried for the row
///    pointing at this row's primary key; on a hit the child row is inserted
///    under the child's table name, unless `row` already has that key.
///
/// `visited` holds `(table, pk key)` pairs already expanded; such rows are
/// skipped, which stops cycles. Query errors, missing rows, unknown target
/// tables and models without a primary-key column are skipped silently.
///
/// Returns a boxed future because the function awaits itself recursively.
fn expand_relations<'a>(
    meta: &'a umbral::migrate::ModelMeta,
    registered: &'a [umbral::migrate::ModelMeta],
    row: &'a mut serde_json::Map<String, serde_json::Value>,
    depth: usize,
    visited: &'a mut std::collections::HashSet<(String, String)>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'a>> {
    Box::pin(async move {
        if depth == 0 {
            return;
        }

        // Capture this row's primary key BEFORE the forward pass. The forward
        // pass overwrites FK columns with nested objects, so if the PK column
        // is itself a foreign key, reading it afterwards would yield an
        // object instead of the key and break the reverse lookups below.
        let parent_pk: Option<serde_json::Value> = meta
            .pk_column()
            .and_then(|pk_col| row.get(&pk_col.name).cloned())
            .filter(|pk| !pk.is_null());

        // Collect first, mutate later: `row` cannot be changed while the
        // column list is being read against it.
        let forward_fks: Vec<(String, String, serde_json::Value)> = meta
            .fields
            .iter()
            .filter_map(|col| {
                let target_table = col.fk_target.as_deref()?;
                let fk_val = row.get(&col.name)?.clone();
                if fk_val.is_null() {
                    return None;
                }
                Some((col.name.clone(), target_table.to_string(), fk_val))
            })
            .collect();

        for (col_name, target_table, fk_val) in forward_fks {
            let visit_key = (target_table.clone(), pk_json_key(&fk_val));
            if visited.contains(&visit_key) {
                // Already expanded elsewhere: leave the raw FK value in place.
                continue;
            }
            let Some(target_meta) = registered.iter().find(|m| m.table == target_table) else {
                continue;
            };
            let Some(target_pk) = target_meta.pk_column() else {
                continue;
            };
            let fetched = umbral::orm::DynQuerySet::for_meta(target_meta)
                .filter_eq_string(&target_pk.name, &json_value_to_pk_string(&fk_val))
                .first_as_json()
                .await;
            let Ok(Some(mut target_row)) = fetched else {
                continue;
            };
            visited.insert(visit_key);
            expand_relations(target_meta, registered, &mut target_row, depth - 1, visited).await;
            row.insert(col_name, serde_json::Value::Object(target_row));
        }

        // Reverse pass needs a usable primary key for this row.
        let Some(parent_pk) = parent_pk else {
            return;
        };
        // Loop-invariant: the same key string is used for every child query.
        let parent_pk_str = json_value_to_pk_string(&parent_pk);

        // Child models with exactly one unique FK column pointing at this
        // table; models with several such columns are ambiguous and skipped.
        let candidates: Vec<(&umbral::migrate::ModelMeta, String)> = registered
            .iter()
            .filter_map(|child| {
                let mut matches = child
                    .fields
                    .iter()
                    .filter(|c| c.fk_target.as_deref() == Some(&meta.table) && c.unique);
                let first = matches.next()?;
                if matches.next().is_some() {
                    return None;
                }
                Some((child, first.name.clone()))
            })
            .collect();

        for (child_meta, fk_col_name) in candidates {
            // Never clobber an existing key of the same name on the row.
            if row.contains_key(&child_meta.table) {
                continue;
            }
            let fetched = umbral::orm::DynQuerySet::for_meta(child_meta)
                .filter_eq_string(&fk_col_name, &parent_pk_str)
                .first_as_json()
                .await;
            let Ok(Some(mut child_row)) = fetched else {
                continue;
            };
            let Some(child_pk_col) = child_meta.pk_column() else {
                continue;
            };
            let Some(child_pk) = child_row.get(&child_pk_col.name).cloned() else {
                continue;
            };
            if child_pk.is_null() {
                continue;
            }
            let visit_key = (child_meta.table.clone(), pk_json_key(&child_pk));
            if visited.contains(&visit_key) {
                continue;
            }
            visited.insert(visit_key);
            expand_relations(child_meta, registered, &mut child_row, depth - 1, visited).await;
            row.insert(
                child_meta.table.clone(),
                serde_json::Value::Object(child_row),
            );
        }
    })
}
fn pk_json_key(v: &serde_json::Value) -> String {
match v {
serde_json::Value::Number(n) => format!("n:{n}"),
serde_json::Value::String(s) => format!("s:{s}"),
other => format!("o:{other}"),
}
}
fn json_value_to_pk_string(v: &serde_json::Value) -> String {
match v {
serde_json::Value::Number(n) => n.to_string(),
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
}
}
/// Template value for a visitor who is not signed in: an object whose only
/// key is `is_authenticated`, set to `false`.
fn anonymous_user_value() -> umbral::templates::Value {
    let anonymous = serde_json::json!({ "is_authenticated": false });
    umbral::templates::Value::from_serialize(anonymous)
}