1use std::{
2 borrow::Cow,
3 future::Future,
4 marker::PhantomData,
5 pin::Pin,
6 sync::Arc,
7 task::{ready, Context, Poll},
8};
9
10use cookie::{Cookie, CookieJar};
11use http::{Request, Response};
12use pin_project_lite::pin_project;
13use tower::{Layer, Service};
14use tower_sesh_core::SessionStore;
15
16use crate::{
17 config::{CookieSecurity, PlainCookie, PrivateCookie, SameSite, SignedCookie},
18 session::{self, Session},
19 util::CookieJarExt,
20};
21
/// Tower [`Layer`] that wraps an inner service in a [`SessionManager`].
///
/// Generic over the session data type `T`, the backing store `Store`, and the
/// cookie security mode `C` (which defaults to `PrivateCookie`). Configure it
/// with the builder methods (`cookie_name`, `path`, `secure`, ...) before
/// calling `layer`.
#[derive(Debug)]
pub struct SessionLayer<T, Store: SessionStore<T>, C: CookieSecurity = PrivateCookie> {
    /// Shared handle to the session store; cloned (by reference count) into
    /// every service produced by this layer.
    store: Arc<Store>,
    /// Cookie and session settings set through the builder methods.
    config: Config,
    /// Controller used to look up the session cookie in a request's jar.
    cookie_controller: C,
    /// Records `T` without storing a value of it. Using `fn() -> T` keeps the
    /// layer `Send`/`Sync` independently of `T`.
    _marker: PhantomData<fn() -> T>,
}
53
/// Middleware service produced by [`SessionLayer`].
///
/// Holds the wrapped `inner` service together with its own clone of the
/// layer (store handle, configuration and cookie controller).
#[derive(Debug)]
pub struct SessionManager<S, T, Store: SessionStore<T>, C: CookieSecurity> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Configuration shared with the layer that created this service.
    layer: SessionLayer<T, Store, C>,
}
62
/// Settings for the session cookie, plus the session-level options in
/// `session_config`. See `impl Default for Config` for the default values.
#[derive(Clone, Debug)]
pub(crate) struct Config {
    /// Name of the cookie carrying the session identifier.
    pub(crate) cookie_name: Cow<'static, str>,
    /// Optional cookie domain; `None` unless set through the builder.
    pub(crate) domain: Option<Cow<'static, str>>,
    /// Whether the cookie should be marked `HttpOnly`.
    pub(crate) http_only: bool,
    /// Cookie path.
    pub(crate) path: Cow<'static, str>,
    /// Cookie `SameSite` policy.
    pub(crate) same_site: SameSite,
    /// Whether the cookie should be marked `Secure`.
    pub(crate) secure: bool,
    /// Options passed to the per-request session loader (`session::lazy::insert`).
    pub(crate) session_config: SessionConfig,
}
73
/// Session-level options, cloned into each request by `SessionManager::call`.
#[derive(Clone, Debug)]
pub(crate) struct SessionConfig {
    /// Controls how an invalid session is handled (`true` by default).
    /// NOTE(review): exact semantics live in `session::lazy` — confirm there.
    pub(crate) ignore_invalid_session: bool,
}
78
/// Session cookie name used by `Config::default` until `cookie_name` overrides it.
const DEFAULT_COOKIE_NAME: &str = "id";
81
82impl Default for Config {
83 fn default() -> Self {
87 Config {
88 cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
89 domain: None,
90 http_only: true,
91 path: Cow::Borrowed("/"),
92 same_site: SameSite::Strict,
93 secure: true,
94 session_config: SessionConfig::default(),
95 }
96 }
97}
98
99impl Default for SessionConfig {
100 fn default() -> Self {
101 SessionConfig {
102 ignore_invalid_session: true,
103 }
104 }
105}
106
107impl<T, Store: SessionStore<T>> SessionLayer<T, Store> {
108 #[track_caller]
112 pub fn new(store: Arc<Store>, key: &[u8]) -> SessionLayer<T, Store> {
113 let key = match cookie::Key::try_from(key) {
114 Ok(key) => key,
115 Err(_) => panic!("key must be 64 bytes in length"),
116 };
117 Self {
118 store,
119 config: Config::default(),
120 cookie_controller: PrivateCookie::new(key),
121 _marker: PhantomData,
122 }
123 }
124}
125
126impl<T, Store: SessionStore<T>, C: CookieSecurity> SessionLayer<T, Store, C> {
128 #[track_caller]
132 pub fn signed(self) -> SessionLayer<T, Store, SignedCookie> {
133 let key = self.cookie_controller.into_key();
134 SessionLayer {
135 store: self.store,
136 config: self.config,
137 cookie_controller: SignedCookie::new(key),
138 _marker: PhantomData,
139 }
140 }
141
142 #[track_caller]
146 pub fn private(self) -> SessionLayer<T, Store, PrivateCookie> {
147 let key = self.cookie_controller.into_key();
148 SessionLayer {
149 store: self.store,
150 config: self.config,
151 cookie_controller: PrivateCookie::new(key),
152 _marker: PhantomData,
153 }
154 }
155
156 pub fn cookie_name(mut self, name: impl Into<Cow<'static, str>>) -> Self {
168 self.config.cookie_name = name.into();
169 self
170 }
171
172 pub fn domain(mut self, domain: impl Into<Cow<'static, str>>) -> Self {
176 self.config.domain = Some(domain.into());
177 self
178 }
179
180 pub fn http_only(mut self, enable: bool) -> Self {
185 self.config.http_only = enable;
186 self
187 }
188
189 pub fn path(mut self, path: impl Into<Cow<'static, str>>) -> Self {
193 self.config.path = path.into();
194 self
195 }
196
197 pub fn same_site(mut self, same_site: SameSite) -> Self {
201 self.config.same_site = same_site;
202 self
203 }
204
205 pub fn secure(mut self, enable: bool) -> Self {
210 self.config.secure = enable;
211 self
212 }
213
214 pub fn ignore_invalid_session(mut self, enable: bool) -> Self {
230 self.config.session_config.ignore_invalid_session = enable;
231 self
232 }
233}
234
235impl<T, Store: SessionStore<T>> SessionLayer<T, Store, PlainCookie> {
236 pub fn plain(store: Arc<Store>) -> SessionLayer<T, Store, PlainCookie> {
238 SessionLayer {
239 store,
240 config: Config::default(),
241 cookie_controller: PlainCookie,
242 _marker: PhantomData,
243 }
244 }
245}
246
247impl<T, Store: SessionStore<T>, C: CookieSecurity> Clone for SessionLayer<T, Store, C> {
248 fn clone(&self) -> Self {
249 Self {
250 store: Arc::clone(&self.store),
251 config: self.config.clone(),
252 cookie_controller: self.cookie_controller.clone(),
253 _marker: PhantomData,
254 }
255 }
256}
257
258impl<S, T, Store: SessionStore<T>, C: CookieSecurity> Layer<S> for SessionLayer<T, Store, C> {
259 type Service = SessionManager<S, T, Store, C>;
260
261 fn layer(&self, inner: S) -> Self::Service {
262 SessionManager {
263 inner,
264 layer: self.clone(),
265 }
266 }
267}
268
269impl<S, T, Store: SessionStore<T>, C: CookieSecurity> Clone for SessionManager<S, T, Store, C>
270where
271 S: Clone,
272{
273 fn clone(&self) -> Self {
274 SessionManager {
275 inner: self.inner.clone(),
276 layer: self.layer.clone(),
277 }
278 }
279}
280
281impl<S, T, Store: SessionStore<T>, C: CookieSecurity> SessionManager<S, T, Store, C> {
282 fn session_cookie<'c>(&self, jar: &'c CookieJar) -> Option<Cookie<'c>> {
283 self.layer
284 .cookie_controller
285 .get(jar, &self.layer.config.cookie_name)
286 }
287}
288
impl<ReqBody, ResBody, S, T, Store: SessionStore<T>, C: CookieSecurity> Service<Request<ReqBody>>
    for SessionManager<S, T, Store, C>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    T: 'static + Send + Sync,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future, T, C>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Handles one request.
    ///
    /// Parses the request's cookies and looks up the session cookie, then
    /// passes it, the store and a clone of the session config to
    /// `session::lazy::insert` together with the request extensions.
    ///
    /// # Panics
    ///
    /// Work in progress: this always panics at `todo!()`. The inner service
    /// is never called and no `ResponseFuture` is built. It can panic earlier
    /// through the `expect` on `session::lazy::take`.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let jar = CookieJar::from_headers(req.headers());
        // Detach the cookie from `jar`'s lifetime so it can be stored.
        let cookie = self.session_cookie(&jar).map(Cookie::into_owned);
        session::lazy::insert(
            cookie,
            &self.layer.store,
            req.extensions_mut(),
            self.layer.config.session_config.clone(),
        );

        // NOTE(review): this takes back out of the extensions what `insert`
        // just put in, before the inner service runs. Confirm that handlers
        // are still meant to see the session. The `expect` is explicitly
        // slated for removal.
        let session: Option<Session<T>> =
            session::lazy::take(req.extensions_mut()).expect("this panic should be removed");

        // TODO: build `ResponseFuture { state, future: self.inner.call(req) }`,
        // with `State::Session` when `session` is `Some`, else `State::Fallback`.
        todo!()
    }
}
322
pin_project! {
    /// Response future returned by [`SessionManager`].
    ///
    /// Drives the inner service's future. When that future resolves, the
    /// response is post-processed according to the carried session state.
    pub struct ResponseFuture<F, T, C: CookieSecurity> {
        // Session bookkeeping for this request (`State::Session` or `State::Fallback`).
        state: State<T, C>,
        // The inner service's future; structurally pinned so it can be polled
        // through `Pin<&mut Self>`.
        #[pin]
        future: F,
    }
}
331
/// Per-request state carried by [`ResponseFuture`].
enum State<T, C> {
    /// A session is attached to this request. It is kept with a cookie
    /// controller so the response can be post-processed once the inner
    /// future resolves.
    Session {
        session: Session<T>,
        cookie_controller: C,
    },
    /// No session handling: the inner response is passed through unchanged.
    Fallback,
}
339
impl<F, B, E, T, C: CookieSecurity> Future for ResponseFuture<F, T, C>
where
    F: Future<Output = Result<Response<B>, E>>,
{
    type Output = Result<Response<B>, E>;

    /// Polls the inner future, then post-processes the response according to
    /// `state`.
    ///
    /// An error from the inner future is returned immediately. In the
    /// `State::Fallback` case the response is returned unmodified.
    ///
    /// # Panics
    ///
    /// Panics via `todo!` in the `State::Session` case. Syncing the session
    /// to the store and emitting `Set-Cookie` are not implemented yet.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // `ready!` returns `Poll::Pending` early; `?` propagates an inner `Err`.
        let mut res = ready!(this.future.poll(cx)?);

        if let State::Session {
            session,
            cookie_controller,
        } = this.state
        {
            todo!("sync changes in session state to store and set the `Set-Cookie` header");
        }

        Poll::Ready(Ok(res))
    }
}