use crate::filter::Predicate;
use axum::{extract::Request, http::StatusCode, response::Response};
use headers::{Cookie, HeaderMapExt};
use parking_lot::RwLock;
use std::{cmp::PartialEq, collections::HashMap, sync::Arc};
/// Per-request session handle that `RequireSession` places in the request extensions.
///
/// `current` is a clone of the value stored under the token the request presented;
/// `all` is the same shared store the predicate holds.
#[derive(Clone)]
pub struct Session<T> {
    /// Value stored for this request's session token.
    pub current: T,
    /// Shared handle to the whole session store.
    pub all: Arc<SessionStore<T>>,
}
/// Shared map from session token to the value `T` stored for that session.
///
/// `key` is the cookie name that `RequireSession` compares against when it looks up a
/// token. The map sits behind an `RwLock`, so `insert`/`remove` take `&self` and the
/// store can be shared behind an `Arc`.
///
/// NOTE: the derived `Debug` formats the inner map, which can include every session
/// token; avoid logging a `SessionStore` (or anything wrapping it).
#[derive(Debug)]
pub struct SessionStore<T> {
    /// Name of the cookie that carries the session token.
    key: String,
    /// Session token -> session value.
    inner: RwLock<HashMap<String, T>>,
}

// Bounds live only on the impl that needs them: `remove` is the sole method that compares
// values, so construction, lookup of the key, and insertion work for any `T`.
impl<T> SessionStore<T> {
    /// Creates an empty store whose session cookie is named `key`.
    pub fn new(key: impl Into<String>) -> Self {
        SessionStore {
            key: key.into(),
            inner: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the session cookie name.
    pub fn key(&self) -> &str {
        &self.key
    }

    /// Associates token `k` with `v`, replacing any value previously stored under `k`.
    pub fn insert(&self, k: impl Into<String>, v: T) {
        self.inner.write().insert(k.into(), v);
    }
}

impl<T: PartialEq> SessionStore<T> {
    /// Removes every session whose stored value equals `v`.
    ///
    /// The whole map is scanned, so all tokens mapped to an equal value are dropped
    /// together (e.g. several logins, if `T` identifies a user).
    pub fn remove(&self, v: T) {
        self.inner.write().retain(|_, stored| *stored != v);
    }
}
/// Predicate that attaches the shared session store to every request without
/// rejecting any of them.
#[derive(Clone, Debug)]
pub struct AddSession<T>(Arc<SessionStore<T>>);

impl<T> AddSession<T> {
    /// Wraps a shared store so it can be used as a request predicate.
    pub fn new(store: Arc<SessionStore<T>>) -> Self {
        AddSession(store)
    }
}
impl<T> Predicate<Request> for AddSession<T>
where
T: Send + Sync + 'static,
{
type Request = Request;
type Response = Response;
fn check(&self, mut request: Request) -> Result<Self::Request, Self::Response> {
request.extensions_mut().insert(self.0.clone());
Ok(request)
}
}
/// Predicate that only admits requests carrying a session token known to the store.
#[derive(Clone, Debug)]
pub struct RequireSession<T>(Arc<SessionStore<T>>);

impl<T> RequireSession<T> {
    /// Wraps a shared store so it can be used to guard routes.
    pub fn new(store: Arc<SessionStore<T>>) -> Self {
        RequireSession(store)
    }
}
impl<T> Predicate<Request> for RequireSession<T>
where
T: Clone + Send + Sync + 'static,
{
type Request = Request;
type Response = Response;
fn check(&self, mut request: Request) -> Result<Self::Request, Self::Response> {
if let Some(cookie) = request.headers().typed_get::<Cookie>() {
let sessions = self.0.inner.read();
for (k, v) in cookie.iter() {
if k == self.0.key {
if let Some(u) = sessions.get(v) {
request.extensions_mut().insert(Session {
current: u.clone(),
all: self.0.clone(),
});
return Ok(request);
}
}
}
}
Err({
let mut response = Response::default();
*response.status_mut() = StatusCode::UNAUTHORIZED;
response
})
}
}