use crate::{headers::*, DatabasePool, Session, SessionData, SessionError, SessionStore};
use axum::{response::Response, BoxError};
use bytes::Bytes;
use chrono::{Duration, Utc};
#[cfg(feature = "key-store")]
use fastbloom_rs::Deletable;
use futures::future::BoxFuture;
use http::Request;
use http_body::Body as HttpBody;
use std::{
convert::Infallible,
fmt::{self, Debug, Formatter},
sync::Arc,
task::{Context, Poll},
};
use tokio::task::JoinHandle;
use tower_service::Service;
/// Middleware service that wraps an inner service `S` and manages sessions
/// stored through the database pool `T`.
#[derive(Clone)]
pub struct SessionService<S, T: DatabasePool + Clone + Debug + Sync + Send + 'static> {
    /// Store shared by every request: in-memory session data, configuration,
    /// and database access.
    pub(crate) session_store: SessionStore<T>,
    /// Handle to a spawned background task. Its job is not visible in this file;
    /// it is presumably periodic session cleanup (confirm at the construction site).
    pub(crate) handle: Arc<JoinHandle<Result<(), SessionError>>>,
    /// The wrapped service that actually handles requests.
    pub(crate) inner: S,
}
/// Emits an error-level tracing event that records `err` (via `Display`) and
/// `msg` as fields. Returns a bare `500 Internal Server Error` response with a
/// default body.
///
/// The result is wrapped in `Result<_, Infallible>` so code inside the service
/// future can `return trace_error(...)` directly.
pub(crate) fn trace_error<ResBody>(
    err: SessionError,
    msg: &str,
) -> Result<Response<ResBody>, Infallible>
where
    ResBody: HttpBody<Data = Bytes> + Default + Send + 'static,
    ResBody::Error: Into<BoxError>,
{
    tracing::error!(err = %err, msg);
    let mut internal_error: Response<ResBody> = Response::new(ResBody::default());
    *internal_error.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
    Ok(internal_error)
}
/// Tower [`Service`] that attaches a [`Session`] to each request and forwards the
/// request to the wrapped service. Afterwards it renews, persists, or destroys the
/// session as requested, then writes the session headers onto the response.
impl<S, T, ReqBody, ResBody> Service<Request<ReqBody>> for SessionService<S, T>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
    Infallible: From<<S as Service<Request<ReqBody>>>::Error>,
    ResBody: HttpBody<Data = Bytes> + Default + Send + 'static,
    ResBody::Error: Into<BoxError>,
    T: DatabasePool + Clone + Debug + Sync + Send + 'static,
{
    type Response = Response<ResBody>;
    type Error = Infallible;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Handles one request.
    ///
    /// Failures while creating the session, generating a renewed id, or
    /// saving/removing it in the database are logged via `trace_error`. Each is
    /// answered with a `500 Internal Server Error` carrying a default body.
    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let store = self.session_store.clone();
        // Move the service that `poll_ready` made ready into the future and leave a
        // fresh clone behind, because a clone is not guaranteed to be ready.
        let not_ready_inner = self.inner.clone();
        let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
        Box::pin(async move {
            // Value derived from the request by `get_ips_hash` (presumably an
            // IP/user-agent fingerprint; see the `headers` module). It is used both
            // to read the incoming session key and to write the outgoing headers.
            let ip_user_agent = get_ips_hash(&req, &store);
            // Read the session id and the `storable` flag (whether the session may be
            // kept) from cookies by default, or from plain headers under `rest_mode`.
            #[cfg(not(feature = "rest_mode"))]
            let cookies = get_cookies(req.headers());
            #[cfg(not(feature = "rest_mode"))]
            let (session_id, storable) = get_headers_and_key(&store, cookies, &ip_user_agent).await;
            #[cfg(feature = "rest_mode")]
            let headers = get_headers(&store, req.headers());
            #[cfg(feature = "rest_mode")]
            let (session_id, storable) = get_headers_and_key(&store, headers, &ip_user_agent).await;
            // Presumably `is_new` is true when `Session::new` had to create a new id
            // instead of reusing `session_id`. TODO confirm in `Session::new`.
            let (mut session, is_new) = match Session::new(store, session_id).await {
                Ok(v) => v,
                Err(err) => {
                    return trace_error(err, "failed to generate Session ID");
                }
            };
            if is_new && !session.store.config.session_mode.is_manual() {
                // Brand-new session outside manual mode: create its in-memory data now.
                let sess = SessionData::new(session.id.clone(), storable, &session.store.config);
                session.store.inner.insert(session.id.clone(), sess);
            } else if (!is_new || !session.store.config.session_mode.is_manual())
                && !session.store.service_session_data(&session)
            {
                // The first branch was not taken, so the parenthesised test above
                // reduces to `!is_new`. This is an existing id for which
                // `service_session_data` returned false (presumably "not held in
                // memory"), so the data is reloaded from the database.
                //
                // NOTE(review): `.ok()` discards database errors. A failed load is
                // therefore treated exactly like a missing row, and a fresh session
                // is created.
                let mut fresh_session = session
                    .store
                    .load_session(session.id.clone())
                    .await
                    .ok()
                    .flatten()
                    .unwrap_or_else(|| {
                        tracing::debug!(
                            "Session {} did not exist in Database. So it was Recreated.",
                            session.id.clone()
                        );
                        SessionData::new(session.id.clone(), storable, &session.store.config)
                    });
                // Refresh the in-memory lifetime and set `update` so the save step
                // below writes this data to the database. `requests = 1` presumably
                // counts this request as the only one in flight.
                fresh_session.autoremove = Utc::now() + session.store.config.memory.memory_lifespan;
                fresh_session.store = storable;
                fresh_session.update = true;
                fresh_session.requests = 1;
                session
                    .store
                    .inner
                    .insert(session.id.clone(), fresh_session);
            };
            // Expose the session to handlers through the request extensions.
            req.extensions_mut().insert(session.clone());
            let mut response = ready_inner.call(req).await?;
            // Snapshot the sweep timers. The read guard is dropped at the end of this
            // block, before any further `.await`.
            let (last_sweep, last_database_sweep) = {
                let timers = session.store.timers.read().await;
                (timers.last_expiry_sweep, timers.last_database_expiry_sweep)
            };
            // Read the flags the handler may have changed. If no data is in memory
            // (e.g. a new id in manual mode, where nothing was inserted above), all
            // of them are false.
            let (renew, storable, destroy, loaded) =
                if let Some(session_data) = session.store.inner.get(&session.id) {
                    (
                        session_data.renew,
                        session_data.store,
                        session_data.destroy,
                        true,
                    )
                } else {
                    (false, false, false, false)
                };
            tracing::trace!(
                renew = renew,
                storable = storable,
                destroy = destroy,
                loaded = loaded,
                "Session id: {}",
                session.id
            );
            // Renewal rotates the session id. The old id is removed from the database
            // and (with `key-store`) from the bloom filter, then the in-memory entry is
            // re-keyed under the new id.
            // NOTE(review): the new id is not added to the bloom filter here. It is
            // written to the database only if the save step below decides to
            // (presumably `renew()` also sets `update`). Confirm both.
            if !destroy && (!session.store.config.session_mode.is_manual() || loaded) && renew {
                let session_id = match Session::generate_id(&session.store).await {
                    Ok(v) => v,
                    Err(err) => {
                        // NOTE(review): this early return, like the other error returns
                        // before `remove_request` below, skips `session.remove_request()`.
                        // Confirm the in-flight request count cannot leak on this path.
                        return trace_error(err, "failed to Generate Session ID");
                    }
                };
                if session.store.is_persistent() {
                    if let Err(err) = session
                        .store
                        .database_remove_session(session.id.clone())
                        .await
                    {
                        return trace_error(err, "failed to remove session from database");
                    };
                }
                #[cfg(feature = "key-store")]
                if session.store.config.memory.use_bloom_filters {
                    let mut filter = session.store.filter.write().await;
                    filter.remove(session.id.as_bytes());
                }
                if let Some((_, mut session_data)) = session.store.inner.remove(&session.id) {
                    session_data.id = session_id.clone();
                    session_data.renew = false;
                    session.id = session_id.clone();
                    session.store.inner.insert(session.id.clone(), session_data);
                }
            }
            // Persist the session only when all of these hold:
            // - it is not in opt-in mode, or the client allowed storage;
            // - the store is persistent;
            // - the session is not being destroyed.
            if (!session.store.config.session_mode.is_opt_in() || storable)
                && session.store.is_persistent()
                && !destroy
            {
                // Decide while holding the map entry, and clone the data out. The entry
                // guard is released before `store_session(...).await` below.
                let clone_session = if let Some(mut sess) =
                    session.store.inner.get_mut(&session.id.clone())
                {
                    let now = Utc::now();
                    let time_since_last_db_update =
                        (now - sess.last_db_update).max(Duration::zero());
                    let db_update_threshold = session.store.config.database.db_update_interval;
                    // NOTE(review): these differences are clamped to zero whenever the
                    // stored timer is in the past. If the timers hold the *previous* sweep
                    // time (as their names suggest), both are always zero. The sweep
                    // checks below then force a database write on every request.
                    // Presumably they hold the next scheduled sweep; confirm.
                    let time_remaining_next_database_sweep =
                        (last_database_sweep - now).max(Duration::zero());
                    let time_remaining_next_memory_sweep = (last_sweep - now).max(Duration::zero());
                    let time_until_next_update =
                        (db_update_threshold - time_since_last_db_update).max(Duration::zero());
                    let should_update =
                        session.store.config.database.always_save || sess.update || !sess.expired();
                    if should_update {
                        // Push the expiry forward. Long-term sessions use `max_lifespan`.
                        if sess.longterm {
                            sess.expires = Utc::now() + session.store.config.max_lifespan;
                        } else {
                            sess.expires = Utc::now() + session.store.config.lifespan;
                        };
                        // Write to the database if any of these hold:
                        // - saving is forced (`always_save`);
                        // - the data changed (`update`);
                        // - the update interval has elapsed;
                        // - a memory or database sweep is due before the next scheduled
                        //   update.
                        // NOTE(review): `expires` was just moved forward, so
                        // `!sess.expired()` is presumably always true here.
                        let should_update_db = session.store.config.database.always_save
                            || sess.update
                            || (!sess.expired()
                                && time_since_last_db_update >= db_update_threshold)
                            || (!sess.expired()
                                && time_remaining_next_memory_sweep <= time_until_next_update)
                            || (!sess.expired()
                                && time_remaining_next_database_sweep <= time_until_next_update);
                        if should_update_db {
                            sess.last_db_update = Utc::now();
                            sess.update = false;
                            Some(sess.clone())
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                } else {
                    None
                };
                if let Some(sess) = clone_session {
                    if let Err(err) = session.store.store_session(&sess).await {
                        return trace_error(err, "failed to save session to database");
                    } else {
                        tracing::debug!("Session id {}: was saved to the database.", session.id);
                    }
                }
            }
            // Mark this request as finished. This presumably decrements the in-flight
            // count that `is_parallel` consults; confirm in `Session`.
            session.remove_request();
            // Unload and delete the session when opt-in mode may not store it, or when
            // it was destroyed. Skip this if `is_parallel` reports other requests for
            // the same session.
            if ((session.store.config.session_mode.is_opt_in() && !storable) || destroy)
                && !session.is_parallel()
            {
                #[cfg(feature = "key-store")]
                if session.store.config.memory.use_bloom_filters {
                    let mut filter = session.store.filter.write().await;
                    filter.remove(session.id.as_bytes());
                }
                let _ = session.store.inner.remove(&session.id);
                if session.store.is_persistent() {
                    if let Err(err) = session
                        .store
                        .database_remove_session(session.id.clone())
                        .await
                    {
                        return trace_error(err, "failed to remove session from database");
                    }
                }
            }
            // With a zero memory lifespan the session is not kept in memory after the
            // request. The bloom-filter entry is cleared only for non-persistent
            // stores, presumably because a persistent store still holds the session.
            if session.store.config.memory.memory_lifespan.is_zero() && !session.is_parallel() {
                #[cfg(feature = "key-store")]
                if !session.store.is_persistent() && session.store.config.memory.use_bloom_filters {
                    let mut filter = session.store.filter.write().await;
                    filter.remove(session.id.as_bytes());
                }
                session.store.inner.remove(&session.id);
            }
            // Write the session cookies/headers onto the outgoing response.
            set_headers(
                &session,
                response.headers_mut(),
                &ip_user_agent,
                destroy,
                storable,
            )
            .await;
            Ok(response)
        })
    }
}
/// Formats the session store and the wrapped service.
///
/// The `handle` field is deliberately left out. `finish_non_exhaustive` marks the
/// omission in the output with `..`, so the struct does not appear complete.
impl<S, T> Debug for SessionService<S, T>
where
    S: Debug,
    T: DatabasePool + Clone + Debug + Sync + Send + 'static,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("SessionService")
            .field("session_store", &self.session_store)
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}