1use std::sync::Arc;
2
3use axum::response::Response;
4use tower::{Layer, Service};
5
6use authx_core::{crypto::sha256_hex, identity::Identity};
7use authx_storage::ports::{SessionRepository, UserRepository};
8
9const SESSION_HEADER: &str = "x-authx-token";
10const SESSION_COOKIE: &str = "authx_session";
11
/// Tower [`Layer`] that wraps a service in a [`SessionService`], which resolves the
/// caller's [`Identity`] from a session token and inserts it into the request
/// extensions.
///
/// The storage backend is held behind an [`Arc`], so cloning the layer (and every
/// service it produces) shares a single backend instance.
pub struct SessionLayer<S> {
    storage: Arc<S>,
}

// Implemented by hand: `#[derive(Clone)]` would add a needless `S: Clone` bound,
// whereas cloning only needs to bump the `Arc` reference count.
impl<S> Clone for SessionLayer<S> {
    fn clone(&self) -> Self {
        Self {
            storage: Arc::clone(&self.storage),
        }
    }
}
28
29impl<S> SessionLayer<S>
30where
31 S: SessionRepository + UserRepository + Clone + Send + Sync + 'static,
32{
33 pub fn new(storage: S) -> Self {
34 Self {
35 storage: Arc::new(storage),
36 }
37 }
38}
39
40impl<S, Svc> Layer<Svc> for SessionLayer<S>
41where
42 S: SessionRepository + UserRepository + Clone + Send + Sync + 'static,
43{
44 type Service = SessionService<S, Svc>;
45
46 fn layer(&self, inner: Svc) -> Self::Service {
47 SessionService {
48 storage: Arc::clone(&self.storage),
49 inner,
50 }
51 }
52}
53
/// Service produced by [`SessionLayer`].
///
/// For each request it tries to resolve an [`Identity`] from the session token and,
/// on success, inserts it into the request extensions before delegating to `inner`.
pub struct SessionService<S, Svc> {
    storage: Arc<S>,
    inner: Svc,
}

// Implemented by hand so that only `Svc: Clone` is required: the storage is shared
// through the `Arc`, so `S` itself never needs to be cloned.
impl<S, Svc: Clone> Clone for SessionService<S, Svc> {
    fn clone(&self) -> Self {
        Self {
            storage: Arc::clone(&self.storage),
            inner: self.inner.clone(),
        }
    }
}
61
62impl<S, Svc, ReqBody> Service<axum::http::Request<ReqBody>> for SessionService<S, Svc>
63where
64 S: SessionRepository + UserRepository + Clone + Send + Sync + 'static,
65 Svc: Service<axum::http::Request<ReqBody>, Response = Response> + Clone + Send + 'static,
66 Svc::Future: Send + 'static,
67 ReqBody: Send + 'static,
68{
69 type Response = Response;
70 type Error = Svc::Error;
71 type Future =
72 std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Svc::Error>> + Send>>;
73
74 fn poll_ready(
75 &mut self,
76 cx: &mut std::task::Context<'_>,
77 ) -> std::task::Poll<Result<(), Self::Error>> {
78 self.inner.poll_ready(cx)
79 }
80
81 fn call(&mut self, mut req: axum::http::Request<ReqBody>) -> Self::Future {
82 let storage = Arc::clone(&self.storage);
83 let mut inner = self.inner.clone();
84
85 Box::pin(async move {
86 let token_hash = extract_token(&req).map(|t| sha256_hex(t.as_bytes()));
87
88 if let Some(hash) = token_hash
89 && let Some(identity) = resolve_identity(&*storage, &hash).await
90 {
91 req.extensions_mut().insert(identity);
92 tracing::debug!("identity resolved");
93 }
94
95 inner.call(req).await
96 })
97 }
98}
99
100async fn resolve_identity<S>(storage: &S, token_hash: &str) -> Option<Identity>
103where
104 S: SessionRepository + UserRepository + Clone + Send + Sync + 'static,
105{
106 let session = storage.find_by_token_hash(token_hash).await.ok()??;
107 if session.expires_at < chrono::Utc::now() {
108 tracing::debug!(session_id = %session.id, "session expired");
109 return None;
110 }
111 let user = storage.find_by_id(session.user_id).await.ok()??;
112 Some(Identity::new(user, session))
113}
114
115fn extract_token<B>(request: &axum::http::Request<B>) -> Option<String> {
116 if let Some(bearer) = request
117 .headers()
118 .get(axum::http::header::AUTHORIZATION)
119 .and_then(|v| v.to_str().ok())
120 .and_then(|v| v.strip_prefix("Bearer "))
121 {
122 return Some(bearer.to_owned());
123 }
124
125 if let Some(token) = request
126 .headers()
127 .get(SESSION_HEADER)
128 .and_then(|v| v.to_str().ok())
129 {
130 return Some(token.to_owned());
131 }
132
133 let cookie_header = request
134 .headers()
135 .get(axum::http::header::COOKIE)
136 .and_then(|v| v.to_str().ok())?;
137
138 for part in cookie_header.split(';') {
139 let part = part.trim();
140 if let Some(value) = part.strip_prefix(&format!("{SESSION_COOKIE}=")) {
141 return Some(value.to_owned());
142 }
143 }
144
145 None
146}