author_axum/
session.rs

1use author_web::session::store::in_memory::InMemorySessionData;
2use author_web::session::store::SessionStore;
3use author_web::session::{SessionConfig, SessionError, SessionKey};
4use axum::extract::FromRequestParts;
5use axum::http::request::Parts;
6use axum::http::{Request, StatusCode};
7use axum::response::{IntoResponse, Response};
8use axum::{async_trait, RequestPartsExt};
9use axum_extra::extract::cookie::{Cookie, Key};
10use axum_extra::extract::PrivateCookieJar;
11use futures::future::BoxFuture;
12use std::convert::Infallible;
13use std::fmt::Display;
14use std::str::FromStr;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use thiserror::Error;
18use tower_layer::Layer;
19use tower_service::Service;
20use tower_util::ServiceExt;
21use tracing::{debug, error, trace};
22
23#[derive(Clone)]
24pub struct Session<T: Clone = Arc<InMemorySessionData>>(pub T);
25
26#[async_trait]
27impl<S, T> FromRequestParts<S> for Session<T>
28where
29    S: Send + Sync,
30    T: Clone + Send + Sync + 'static,
31{
32    type Rejection = (StatusCode, &'static str);
33
34    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
35        parts
36            .extensions
37            .get::<Session<T>>()
38            .cloned()
39            .ok_or((StatusCode::FORBIDDEN, "Forbidden"))
40    }
41}
42
43pub struct SessionManagerService<Inner, Store>
44where
45    Store: SessionStore,
46{
47    inner: Inner,
48    config: SessionConfig,
49    store: Arc<Store>,
50}
51
52impl<Inner, Store> SessionManagerService<Inner, Store>
53where
54    Store: SessionStore,
55{
56    pub fn new(inner: Inner, config: SessionConfig, store: Arc<Store>) -> Self {
57        SessionManagerService {
58            inner,
59            config: config.into(),
60            store,
61        }
62    }
63}
64
65// #[derive(Clone)] requires Store to be Clone, which shouldn't really be necessary because it's
66// in an Arc. The only way to get around this is to manually implement Clone.
67// See https://github.com/rust-lang/rust/issues/26925
68impl<Inner, Store> Clone for SessionManagerService<Inner, Store>
69where
70    Inner: Clone,
71    Store: SessionStore,
72{
73    fn clone(&self) -> Self {
74        Self {
75            inner: self.inner.clone(),
76            config: self.config.clone(),
77            store: self.store.clone(),
78        }
79    }
80}
81
82impl<Inner, S, K, B, ResBody, Store> Service<Request<B>> for SessionManagerService<Inner, Store>
83where
84    Inner: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
85        + Clone
86        + Send
87        + 'static,
88    Inner::Response: IntoResponse,
89    Inner::Future: Send,
90    B: Send + 'static,
91    K: SessionKey + Display + Send + Sync + 'static,
92    <K as FromStr>::Err: Send,
93    S: Clone + Send + Sync + 'static,
94    Store: SessionStore<Session = S, Key = K> + Send + Sync + 'static,
95{
96    type Response = (
97        Option<PrivateCookieJar>,
98        Result<Inner::Response, StatusCode>,
99    );
100    type Error = Infallible;
101    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
102
103    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
104        self.inner.poll_ready(cx)
105    }
106
107    fn call(&mut self, req: Request<B>) -> Self::Future {
108        let config = self.config.clone();
109        let store = self.store.clone();
110
111        let clone = self.inner.clone();
112        let inner = std::mem::replace(&mut self.inner, clone);
113
114        Box::pin(async move {
115            let (mut parts, body) = req.into_parts();
116
117            let mut cookie_jar = match parts
118                .extract_with_state::<PrivateCookieJar, Key>(&config.key)
119                .await
120            {
121                Err(e) => {
122                    error!("Failed to extract session cookie: {}", e);
123                    return Ok((None, Err(StatusCode::INTERNAL_SERVER_ERROR)));
124                }
125                Ok(j) => j,
126            };
127
128            let cookie = cookie_jar.get(&config.cookie_name);
129
130            // Check whether we have any existing session
131            let existing_session = match cookie {
132                Some(c) => {
133                    let session = match K::from_str(c.value()) {
134                        Err(_) => {
135                            error!("Error parsing key in session cookie: {}", c.value());
136                            None
137                        }
138                        Ok(session_key) => {
139                            debug!(
140                                "Existing session cookie found containing key {}",
141                                session_key
142                            );
143
144                            // TODO: Refresh the session cookie with a new key
145
146                            match store.load_session(&session_key).await {
147                                Err(e) => {
148                                    error!("Failed to load session: {}", e);
149                                    None
150                                }
151                                Ok(u) => match u {
152                                    None => {
153                                        error!("Session with key {} not found", session_key);
154                                        None
155                                    }
156                                    Some(s) => Some(s),
157                                },
158                            }
159                        }
160                    };
161
162                    session
163                }
164                None => {
165                    debug!("No existing session cookie found");
166                    None
167                }
168            };
169
170            // If there's no usable existing session for any reason, create a new one
171            let session = match existing_session {
172                Some(s) => s,
173                None => {
174                    debug!("No existing session found, creating new session");
175
176                    let (session_key, session) = match store.create_session().await {
177                        Err(e) => {
178                            error!("Failed to create session: {}", e);
179                            return Ok((None, Err(StatusCode::INTERNAL_SERVER_ERROR)));
180                        }
181                        Ok(s) => s,
182                    };
183
184                    trace!("Session created with key {}", session_key);
185
186                    let cookie =
187                        Cookie::build((config.cookie_name.to_string(), session_key.to_string()))
188                            .same_site(config.same_site)
189                            .secure(true)
190                            .http_only(true)
191                            .path("/")
192                            .build();
193
194                    cookie_jar = cookie_jar.add(cookie);
195
196                    session
197                }
198            };
199
200            trace!("Adding session to extensions");
201
202            parts.extensions.insert(Session(session));
203
204            trace!("Processing inner service");
205
206            let response = inner.oneshot(Request::from_parts(parts, body)).await?;
207
208            Ok((Some(cookie_jar), Ok(response)))
209        })
210    }
211}
212
213pub struct SessionManagerLayer<Store>
214where
215    Store: SessionStore,
216{
217    config: SessionConfig,
218    store: Arc<Store>,
219}
220
221// #[derive(Clone)] requires Store to be Clone, which shouldn't really be necessary because it's
222// in an Arc. The only way to get around this is to manually implement Clone.
223// See https://github.com/rust-lang/rust/issues/26925
224impl<Store> Clone for SessionManagerLayer<Store>
225where
226    Store: SessionStore,
227{
228    fn clone(&self) -> Self {
229        Self {
230            config: self.config.clone(),
231            store: self.store.clone(),
232        }
233    }
234}
235
236impl<Store> SessionManagerLayer<Store>
237where
238    Store: SessionStore,
239{
240    pub fn new(config: SessionConfig, store: Store) -> Self {
241        SessionManagerLayer {
242            config,
243            store: Arc::new(store),
244        }
245    }
246}
247
248impl<Inner, Store> Layer<Inner> for SessionManagerLayer<Store>
249where
250    Store: SessionStore,
251{
252    type Service = SessionManagerService<Inner, Store>;
253
254    fn layer(&self, inner: Inner) -> Self::Service {
255        SessionManagerService::new(inner, self.config.clone(), self.store.clone())
256    }
257}
258
259#[derive(Debug, Error)]
260pub enum AxumSessionError<E>
261where
262    E: IntoResponse,
263{
264    #[error("Error from inner service: {0}")]
265    InnerServiceError(E),
266    #[error("Unexpected session error: {0}")]
267    SessionError(#[from] SessionError),
268    #[error("Session store not found")]
269    SessionStoreNotFound,
270    #[error("Session config not found")]
271    SessionConfigNotFound,
272    #[error("UUID error: {0}")]
273    UuidError(#[from] uuid::Error),
274}
275
276impl<E> IntoResponse for AxumSessionError<E>
277where
278    E: IntoResponse,
279{
280    fn into_response(self) -> Response {
281        match self {
282            AxumSessionError::InnerServiceError(inner) => inner.into_response(),
283            AxumSessionError::SessionError(SessionError::SessionNotFound) => {
284                (StatusCode::FORBIDDEN, "Forbidden").into_response()
285            }
286            _ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error").into_response(),
287        }
288    }
289}