use author_web::session::store::in_memory::InMemorySessionData;
use author_web::session::store::SessionStore;
use author_web::session::{SessionConfig, SessionError, SessionKey};
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::{async_trait, RequestPartsExt};
use axum_extra::extract::cookie::{Cookie, Key};
use axum_extra::extract::PrivateCookieJar;
use futures::future::BoxFuture;
use std::convert::Infallible;
use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use thiserror::Error;
use tower_layer::Layer;
use tower_service::Service;
use tower_util::ServiceExt;
use tracing::{debug, error, trace, warn};

23#[derive(Clone)]
24pub struct Session<T: Clone = Arc<InMemorySessionData>>(pub T);
25
26#[async_trait]
27impl<S, T> FromRequestParts<S> for Session<T>
28where
29 S: Send + Sync,
30 T: Clone + Send + Sync + 'static,
31{
32 type Rejection = (StatusCode, &'static str);
33
34 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
35 parts
36 .extensions
37 .get::<Session<T>>()
38 .cloned()
39 .ok_or((StatusCode::FORBIDDEN, "Forbidden"))
40 }
41}
42
43pub struct SessionManagerService<Inner, Store>
44where
45 Store: SessionStore,
46{
47 inner: Inner,
48 config: SessionConfig,
49 store: Arc<Store>,
50}
51
52impl<Inner, Store> SessionManagerService<Inner, Store>
53where
54 Store: SessionStore,
55{
56 pub fn new(inner: Inner, config: SessionConfig, store: Arc<Store>) -> Self {
57 SessionManagerService {
58 inner,
59 config: config.into(),
60 store,
61 }
62 }
63}
64
65impl<Inner, Store> Clone for SessionManagerService<Inner, Store>
69where
70 Inner: Clone,
71 Store: SessionStore,
72{
73 fn clone(&self) -> Self {
74 Self {
75 inner: self.inner.clone(),
76 config: self.config.clone(),
77 store: self.store.clone(),
78 }
79 }
80}
81
82impl<Inner, S, K, B, ResBody, Store> Service<Request<B>> for SessionManagerService<Inner, Store>
83where
84 Inner: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
85 + Clone
86 + Send
87 + 'static,
88 Inner::Response: IntoResponse,
89 Inner::Future: Send,
90 B: Send + 'static,
91 K: SessionKey + Display + Send + Sync + 'static,
92 <K as FromStr>::Err: Send,
93 S: Clone + Send + Sync + 'static,
94 Store: SessionStore<Session = S, Key = K> + Send + Sync + 'static,
95{
96 type Response = (
97 Option<PrivateCookieJar>,
98 Result<Inner::Response, StatusCode>,
99 );
100 type Error = Infallible;
101 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
102
103 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104 self.inner.poll_ready(cx)
105 }
106
107 fn call(&mut self, req: Request<B>) -> Self::Future {
108 let config = self.config.clone();
109 let store = self.store.clone();
110
111 let clone = self.inner.clone();
112 let inner = std::mem::replace(&mut self.inner, clone);
113
114 Box::pin(async move {
115 let (mut parts, body) = req.into_parts();
116
117 let mut cookie_jar = match parts
118 .extract_with_state::<PrivateCookieJar, Key>(&config.key)
119 .await
120 {
121 Err(e) => {
122 error!("Failed to extract session cookie: {}", e);
123 return Ok((None, Err(StatusCode::INTERNAL_SERVER_ERROR)));
124 }
125 Ok(j) => j,
126 };
127
128 let cookie = cookie_jar.get(&config.cookie_name);
129
130 let existing_session = match cookie {
132 Some(c) => {
133 let session = match K::from_str(c.value()) {
134 Err(_) => {
135 error!("Error parsing key in session cookie: {}", c.value());
136 None
137 }
138 Ok(session_key) => {
139 debug!(
140 "Existing session cookie found containing key {}",
141 session_key
142 );
143
144 match store.load_session(&session_key).await {
147 Err(e) => {
148 error!("Failed to load session: {}", e);
149 None
150 }
151 Ok(u) => match u {
152 None => {
153 error!("Session with key {} not found", session_key);
154 None
155 }
156 Some(s) => Some(s),
157 },
158 }
159 }
160 };
161
162 session
163 }
164 None => {
165 debug!("No existing session cookie found");
166 None
167 }
168 };
169
170 let session = match existing_session {
172 Some(s) => s,
173 None => {
174 debug!("No existing session found, creating new session");
175
176 let (session_key, session) = match store.create_session().await {
177 Err(e) => {
178 error!("Failed to create session: {}", e);
179 return Ok((None, Err(StatusCode::INTERNAL_SERVER_ERROR)));
180 }
181 Ok(s) => s,
182 };
183
184 trace!("Session created with key {}", session_key);
185
186 let cookie =
187 Cookie::build((config.cookie_name.to_string(), session_key.to_string()))
188 .same_site(config.same_site)
189 .secure(true)
190 .http_only(true)
191 .path("/")
192 .build();
193
194 cookie_jar = cookie_jar.add(cookie);
195
196 session
197 }
198 };
199
200 trace!("Adding session to extensions");
201
202 parts.extensions.insert(Session(session));
203
204 trace!("Processing inner service");
205
206 let response = inner.oneshot(Request::from_parts(parts, body)).await?;
207
208 Ok((Some(cookie_jar), Ok(response)))
209 })
210 }
211}
212
213pub struct SessionManagerLayer<Store>
214where
215 Store: SessionStore,
216{
217 config: SessionConfig,
218 store: Arc<Store>,
219}
220
221impl<Store> Clone for SessionManagerLayer<Store>
225where
226 Store: SessionStore,
227{
228 fn clone(&self) -> Self {
229 Self {
230 config: self.config.clone(),
231 store: self.store.clone(),
232 }
233 }
234}
235
236impl<Store> SessionManagerLayer<Store>
237where
238 Store: SessionStore,
239{
240 pub fn new(config: SessionConfig, store: Store) -> Self {
241 SessionManagerLayer {
242 config,
243 store: Arc::new(store),
244 }
245 }
246}
247
248impl<Inner, Store> Layer<Inner> for SessionManagerLayer<Store>
249where
250 Store: SessionStore,
251{
252 type Service = SessionManagerService<Inner, Store>;
253
254 fn layer(&self, inner: Inner) -> Self::Service {
255 SessionManagerService::new(inner, self.config.clone(), self.store.clone())
256 }
257}
258
259#[derive(Debug, Error)]
260pub enum AxumSessionError<E>
261where
262 E: IntoResponse,
263{
264 #[error("Error from inner service: {0}")]
265 InnerServiceError(E),
266 #[error("Unexpected session error: {0}")]
267 SessionError(#[from] SessionError),
268 #[error("Session store not found")]
269 SessionStoreNotFound,
270 #[error("Session config not found")]
271 SessionConfigNotFound,
272 #[error("UUID error: {0}")]
273 UuidError(#[from] uuid::Error),
274}
275
276impl<E> IntoResponse for AxumSessionError<E>
277where
278 E: IntoResponse,
279{
280 fn into_response(self) -> Response {
281 match self {
282 AxumSessionError::InnerServiceError(inner) => inner.into_response(),
283 AxumSessionError::SessionError(SessionError::SessionNotFound) => {
284 (StatusCode::FORBIDDEN, "Forbidden").into_response()
285 }
286 _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error").into_response(),
287 }
288 }
289}