use super::{
middleware::{User, UserProvider as UserProvider, auth_layer, protect, protected_error_layer},
state::StateProvider,
};
use crate::app::state::StateProvider as AppStateProvider;
use axum::{
Router,
middleware::from_fn_with_state,
routing::MethodRouter,
};
use std::sync::Arc;
use tower_sessions::{
MemoryStore as SessionMemoryStore,
SessionManagerLayer,
cookie::Key as SessionKey,
};
/// Extension methods for attaching this module's authentication machinery to an axum
/// [`Router`].
///
/// Each method consumes the router and hands it back with additional layers or routes, so the
/// calls can be chained while the application router is being assembled.
pub trait RouterExt<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Installs the authentication middleware (and the session support it relies on) on the
    /// router.
    ///
    /// * `SP` – state passed to the middleware.
    /// * `U` – the authenticated user type.
    /// * `UP` – the provider used to resolve users of type `U`.
    #[must_use]
    fn add_authentication<SP: StateProvider, U: User, UP: UserProvider<User = U>>(
        self,
        state: &Arc<SP>,
    ) -> Self;

    /// Installs the `protected_error_layer` middleware on the router.
    ///
    /// Unlike the other methods, the state here implements the application-level
    /// `crate::app::state::StateProvider` rather than this module's `StateProvider`.
    #[must_use]
    fn add_protected_error_catcher<SP: AppStateProvider, U: User>(self, state: &Arc<SP>) -> Self;

    /// Registers each `(path, method_router)` pair in `routes` and guards it with the `protect`
    /// middleware for user type `U`.
    #[must_use]
    fn protected_routes<SP: StateProvider, U: User>(
        self,
        routes: Vec<(&str, MethodRouter<S>)>,
        state: &Arc<SP>,
    ) -> Self;
}
#[expect(clippy::similar_names, reason = "Not too similar")]
impl<S> RouterExt<S> for Router<S>
where
S: Clone + Send + Sync + 'static,
{
fn add_authentication<SP, U, UP>(self, state: &Arc<SP>) -> Self
where
SP: StateProvider,
U: User,
UP: UserProvider<User = U>,
{
let session_key = SessionKey::generate();
let session_store = SessionMemoryStore::default();
self
.layer(from_fn_with_state(Arc::clone(state), auth_layer::<_, U, UP>))
.layer(SessionManagerLayer::new(session_store).with_secure(false).with_signed(session_key))
}
fn add_protected_error_catcher<SP, U>(self, state: &Arc<SP>) -> Self
where
SP: AppStateProvider,
U: User,
{
self
.layer(from_fn_with_state(Arc::clone(state), protected_error_layer::<_, U>))
}
fn protected_routes<SP, U>(self, routes: Vec<(&str, MethodRouter<S>)>, state: &Arc<SP>) -> Self
where
SP: StateProvider,
U: User,
{
let mut router = self;
for (path, method_router) in routes {
router = router.route(path, method_router);
}
router
.route_layer(from_fn_with_state(Arc::clone(state), protect::<_, U>))
}
}