use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::Instant;
use http::HeaderValue;
use parking_lot::Mutex;
use scc::HashMap as SccHashMap;
use serde::Serialize;
use serde::de::DeserializeOwned;
use tako_rs_core::middleware::IntoMiddleware;
use tako_rs_core::middleware::Next;
use tako_rs_core::types::Request;
use tako_rs_core::types::Response;
/// Expiry policy applied to every session.
///
/// A session is discarded once it has gone unused for longer than `idle_secs`, or — when
/// `absolute_secs` is set — once it is older than that many seconds, whichever happens first.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SessionTtl {
  /// Maximum number of seconds a session may stay unused before it expires.
  pub idle_secs: u64,
  /// Maximum total lifetime in seconds, measured from creation; `None` disables the cap.
  pub absolute_secs: Option<u64>,
}

impl Default for SessionTtl {
  /// One hour of idle time, capped at 24 hours of total lifetime.
  fn default() -> Self {
    Self {
      idle_secs: 60 * 60,
      absolute_secs: Some(24 * 60 * 60),
    }
  }
}
/// Value of the session cookie's `SameSite` attribute.
#[derive(Clone, Copy, Debug)]
pub enum SameSite {
  /// Emitted as `SameSite=Strict`.
  Strict,
  /// Emitted as `SameSite=Lax`.
  Lax,
  /// Emitted as `SameSite=None`.
  None,
}

impl SameSite {
  /// Returns the attribute value exactly as it is written after `SameSite=`.
  fn as_str(self) -> &'static str {
    match self {
      Self::Lax => "Lax",
      Self::None => "None",
      Self::Strict => "Strict",
    }
  }
}
/// One session as it is kept in the in-memory store.
#[derive(Clone)]
struct SessionEntry {
/// JSON key/value payload of the session.
data: serde_json::Map<String, serde_json::Value>,
/// When the session was first created; compared against the absolute TTL.
created_at: Instant,
/// Start time of the last request that wrote this entry back; compared against the idle TTL.
last_seen_at: Instant,
}
/// In-memory session table keyed by session id.
///
/// The map sits behind an `Arc`, so every clone of a `Store` refers to the same table.
#[derive(Clone)]
struct Store(Arc<SccHashMap<String, SessionEntry>>);
impl Store {
  /// Creates an empty session table.
  fn new() -> Self {
    let table = SccHashMap::new();
    Self(Arc::new(table))
  }

  /// Returns a copy of the entry stored under `id`, if there is one.
  fn get(&self, id: &str) -> Option<SessionEntry> {
    let found = self.0.get_sync(id)?;
    Some(found.clone())
  }

  /// Inserts `entry` under `id`, replacing whatever was stored there before.
  fn upsert(&self, id: String, entry: SessionEntry) {
    // The result of the upsert is intentionally ignored.
    let _ = self.0.upsert_sync(id, entry);
  }

  /// Deletes the entry stored under `id`; a missing id is not an error.
  fn remove(&self, id: &str) {
    let _ = self.0.remove_sync(id);
  }

  /// Drops every session.
  fn revoke_all(&self) {
    self.0.clear_sync();
  }

  /// Keeps only the sessions for which `keep` returns `true`.
  fn revoke_predicate(&self, mut keep: impl FnMut(&str, &SessionEntry) -> bool) {
    self.0.retain_sync(|id, entry| keep(id, entry));
  }

  /// Drops every session that has exceeded the idle limit or, when one is configured,
  /// the absolute lifetime limit of `ttl`.
  fn retain_expired(&self, ttl: SessionTtl) {
    let now = Instant::now();
    let idle_limit = Duration::from_secs(ttl.idle_secs);
    let absolute_limit = ttl.absolute_secs.map(Duration::from_secs);

    self.0.retain_sync(|_, entry| {
      let within_idle = now.duration_since(entry.last_seen_at) <= idle_limit;
      let within_absolute =
        absolute_limit.is_none_or(|limit| now.duration_since(entry.created_at) <= limit);
      within_idle && within_absolute
    });
  }
}
/// Cloneable handle for revoking sessions from outside the request path.
///
/// Obtained from `SessionMiddleware::handle`; it wraps a clone of the middleware's store.
#[derive(Clone)]
pub struct SessionStoreHandle {
// Clone of the middleware's `Store`.
store: Store,
}
impl SessionStoreHandle {
  /// Removes every session from the store.
  pub fn revoke_all(&self) {
    self.store.revoke_all();
  }

  /// Removes every session for which `pred` returns `true`.
  ///
  /// `pred` receives the session id and the session's JSON data.
  pub fn revoke_where<F>(&self, mut pred: F)
  where
    F: FnMut(&str, &serde_json::Map<String, serde_json::Value>) -> bool,
  {
    // The store keeps entries for which its closure is true, so the verdict is inverted.
    self.store.revoke_predicate(|id, entry| {
      let revoke = pred(id, &entry.data);
      !revoke
    });
  }
}
/// Per-request session handle that the middleware inserts into the request extensions.
///
/// Every field is behind an `Arc`, so clones of a `Session` share the same data and flags.
#[derive(Clone)]
pub struct Session {
// JSON key/value payload, guarded by a mutex.
data: Arc<Mutex<serde_json::Map<String, serde_json::Value>>>,
// Set by `set`, `remove`, `clear`, `destroy` and `rotate`.
dirty: Arc<AtomicBool>,
// Incremented by each `rotate` call; a value above zero means rotation was requested.
rotation_counter: Arc<AtomicU64>,
// Set by `destroy`.
destroyed: Arc<AtomicBool>,
}
impl Session {
  /// Wraps `data` in a new handle with all flags cleared.
  fn new(data: serde_json::Map<String, serde_json::Value>) -> Self {
    let dirty = Arc::new(AtomicBool::new(false));
    let destroyed = Arc::new(AtomicBool::new(false));
    let rotation_counter = Arc::new(AtomicU64::new(0));
    Self {
      data: Arc::new(Mutex::new(data)),
      dirty,
      rotation_counter,
      destroyed,
    }
  }

  /// Flags the session as modified.
  fn mark_dirty(&self) {
    self.dirty.store(true, Ordering::Relaxed);
  }

  /// Reads `key` and deserializes it into `T`.
  ///
  /// Returns `None` when the key is absent or the stored value does not deserialize into `T`.
  pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
    let map = self.data.lock();
    let raw = map.get(key)?;
    serde_json::from_value(raw.clone()).ok()
  }

  /// Stores `value` under `key` and marks the session dirty.
  ///
  /// A value that cannot be serialized to JSON is silently ignored.
  pub fn set<T: Serialize>(&self, key: &str, value: T) {
    let Ok(json) = serde_json::to_value(value) else {
      return;
    };
    self.data.lock().insert(key.to_string(), json);
    self.mark_dirty();
  }

  /// Removes `key`; the session is marked dirty only if the key was present.
  pub fn remove(&self, key: &str) {
    let previous = self.data.lock().remove(key);
    if previous.is_some() {
      self.mark_dirty();
    }
  }

  /// Removes all keys; the session is marked dirty only if it held any data.
  pub fn clear(&self) {
    let mut map = self.data.lock();
    if map.is_empty() {
      return;
    }
    map.clear();
    self.mark_dirty();
  }

  /// Empties the session and flags it as destroyed.
  pub fn destroy(&self) {
    self.data.lock().clear();
    self.destroyed.store(true, Ordering::Release);
    self.mark_dirty();
  }

  /// Reports whether `destroy` has been called.
  fn is_destroyed(&self) -> bool {
    self.destroyed.load(Ordering::Acquire)
  }

  /// Requests a new session id and marks the session dirty.
  pub fn rotate(&self) {
    self.rotation_counter.fetch_add(1, Ordering::AcqRel);
    self.mark_dirty();
  }

  /// Reports whether any mutating method has flagged the session as modified.
  fn is_dirty(&self) -> bool {
    self.dirty.load(Ordering::Relaxed)
  }

  /// Reports whether `rotate` has been called at least once.
  pub fn rotation_requested(&self) -> bool {
    let rotations = self.rotation_counter.load(Ordering::Acquire);
    rotations > 0
  }

  /// Returns a copy of the current session data.
  fn snapshot(&self) -> serde_json::Map<String, serde_json::Value> {
    let map = self.data.lock();
    map.clone()
  }
}
/// Cookie-based session middleware backed by an in-memory store.
///
/// Configure it with the builder methods, then turn it into a middleware closure via
/// `IntoMiddleware::into_middleware`.
pub struct SessionMiddleware {
// Name of the cookie carrying the session id.
cookie_name: String,
// Idle and absolute expiry limits.
ttl: SessionTtl,
// Value of the cookie's `Path` attribute.
path: String,
// Value of the cookie's `Domain` attribute; the attribute is omitted when `None`.
domain: Option<String>,
// Whether the `Secure` attribute is emitted.
secure: bool,
// Whether the `HttpOnly` attribute is emitted.
http_only: bool,
// Value of the cookie's `SameSite` attribute.
same_site: SameSite,
// Session table shared with every `SessionStoreHandle` created by `handle`.
store: Store,
}
impl Default for SessionMiddleware {
fn default() -> Self {
Self::new()
}
}
impl SessionMiddleware {
pub fn new() -> Self {
Self {
cookie_name: "tako_session".to_string(),
ttl: SessionTtl::default(),
path: "/".to_string(),
domain: None,
secure: false,
http_only: true,
same_site: SameSite::Lax,
store: Store::new(),
}
}
pub fn cookie_name(mut self, name: &str) -> Self {
self.cookie_name = name.to_string();
self
}
pub fn ttl_secs(mut self, secs: u64) -> Self {
self.ttl.idle_secs = secs;
self
}
pub fn ttl(mut self, ttl: SessionTtl) -> Self {
self.ttl = ttl;
self
}
pub fn path(mut self, path: &str) -> Self {
self.path = path.to_string();
self
}
pub fn domain(mut self, domain: &str) -> Self {
self.domain = Some(domain.to_string());
self
}
pub fn secure(mut self, secure: bool) -> Self {
self.secure = secure;
self
}
pub fn http_only(mut self, on: bool) -> Self {
self.http_only = on;
self
}
pub fn same_site(mut self, ss: SameSite) -> Self {
self.same_site = ss;
self
}
pub fn handle(&self) -> SessionStoreHandle {
SessionStoreHandle {
store: self.store.clone(),
}
}
}
/// Produces a new session id: a version-4 UUID in its "simple" (hyphen-free) text form.
fn generate_session_id() -> String {
  let id = uuid::Uuid::new_v4();
  id.simple().to_string()
}
/// Looks up `cookie_name` in the request's `Cookie` header.
///
/// Returns the trimmed value of the first matching `name=value` pair, or `None` when the
/// header is missing, is not valid visible ASCII, or holds no such cookie. Pairs without
/// an `=` are skipped.
fn extract_cookie_value<'a>(req: &'a Request, cookie_name: &str) -> Option<&'a str> {
  let header = req.headers().get(http::header::COOKIE)?;
  let cookies = header.to_str().ok()?;

  for pair in cookies.split(';') {
    let Some((name, value)) = pair.trim().split_once('=') else {
      continue;
    };
    if name.trim() == cookie_name {
      return Some(value.trim());
    }
  }
  None
}
/// Renders the `Set-Cookie` header value that hands session id `sid` to the client.
///
/// Attributes are emitted in a fixed order: `Path`, optional `Domain`, `Max-Age`
/// (`ttl_secs`), optional `HttpOnly`, optional `Secure`, then `SameSite`.
#[allow(clippy::too_many_arguments)]
fn build_cookie(
  cookie_name: &str,
  sid: &str,
  path: &str,
  domain: Option<&str>,
  ttl_secs: u64,
  secure: bool,
  http_only: bool,
  same_site: SameSite,
) -> String {
  let mut attributes = vec![format!("{cookie_name}={sid}"), format!("Path={path}")];

  if let Some(d) = domain {
    attributes.push(format!("Domain={d}"));
  }
  attributes.push(format!("Max-Age={ttl_secs}"));
  if http_only {
    attributes.push("HttpOnly".to_owned());
  }
  if secure {
    attributes.push("Secure".to_owned());
  }
  attributes.push(format!("SameSite={}", same_site.as_str()));

  attributes.join("; ")
}
/// Renders a `Set-Cookie` header value that tells the client to delete the session cookie.
///
/// The cookie value is empty and expiry is signalled twice, via `Max-Age=0` and an `Expires`
/// date in 1970. The remaining attributes mirror those of the live cookie.
fn build_expired_cookie(
  cookie_name: &str,
  path: &str,
  domain: Option<&str>,
  secure: bool,
  http_only: bool,
  same_site: SameSite,
) -> String {
  let domain_attr = domain.map(|d| format!("; Domain={d}")).unwrap_or_default();
  let expiry_attr = "; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT";
  let http_only_attr = if http_only { "; HttpOnly" } else { "" };
  let secure_attr = if secure { "; Secure" } else { "" };

  format!(
    "{cookie_name}=; Path={path}{domain_attr}{expiry_attr}{http_only_attr}{secure_attr}; SameSite={}",
    same_site.as_str()
  )
}
impl IntoMiddleware for SessionMiddleware {
fn into_middleware(
self,
) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
+ Clone
+ Send
+ Sync
+ 'static {
let store = self.store.clone();
let cookie_name = Arc::new(self.cookie_name);
let ttl = self.ttl;
let path = Arc::new(self.path);
let domain = self.domain.map(Arc::new);
let secure = self.secure;
let http_only = self.http_only;
let same_site = self.same_site;
{
let store = store.clone();
let interval = Duration::from_secs(ttl.idle_secs.clamp(60, 3_600));
#[cfg(not(feature = "compio"))]
tokio::spawn(async move {
let mut tick = tokio::time::interval(interval);
loop {
tick.tick().await;
store.retain_expired(ttl);
}
});
#[cfg(feature = "compio")]
compio::runtime::spawn(async move {
loop {
compio::time::sleep(interval).await;
store.retain_expired(ttl);
}
})
.detach();
}
move |mut req: Request, next: Next| {
let store = store.clone();
let cookie_name = cookie_name.clone();
let path = path.clone();
let domain = domain.clone();
Box::pin(async move {
let now = Instant::now();
let idle = Duration::from_secs(ttl.idle_secs);
let absolute = ttl.absolute_secs.map(Duration::from_secs);
let inbound_id = extract_cookie_value(&req, &cookie_name).map(str::to_string);
let (sid, data, created_at, was_existing) = match inbound_id {
Some(ref id) => match store.get(id) {
Some(entry)
if now.duration_since(entry.last_seen_at) <= idle
&& absolute.is_none_or(|abs| now.duration_since(entry.created_at) <= abs) =>
{
(id.clone(), entry.data, entry.created_at, true)
}
_ => {
if let Some(id) = inbound_id.as_ref() {
store.remove(id);
}
(generate_session_id(), serde_json::Map::new(), now, false)
}
},
None => (generate_session_id(), serde_json::Map::new(), now, false),
};
let session = Session::new(data);
req.extensions_mut().insert(session.clone());
let resp_outcome = next.run(req).await;
let mut resp = resp_outcome;
let dirty = session.is_dirty();
let rotated = session.rotation_requested();
let destroyed = session.is_destroyed();
if destroyed {
if was_existing {
store.remove(&sid);
}
let expired = build_expired_cookie(
&cookie_name,
&path,
domain.as_deref().map(String::as_str),
secure,
http_only,
same_site,
);
if let Ok(v) = HeaderValue::from_str(&expired) {
resp.headers_mut().append(http::header::SET_COOKIE, v);
}
let _ = dirty;
return resp;
}
let effective_sid = if rotated {
if was_existing {
store.remove(&sid);
}
generate_session_id()
} else {
sid
};
let updated_entry = SessionEntry {
data: session.snapshot(),
created_at,
last_seen_at: now,
};
store.upsert(effective_sid.clone(), updated_entry);
let max_age = match absolute {
Some(abs) => {
let elapsed = now.duration_since(created_at);
let absolute_remaining = abs.saturating_sub(elapsed);
absolute_remaining.as_secs().min(idle.as_secs())
}
None => idle.as_secs(),
};
let cookie_value = build_cookie(
&cookie_name,
&effective_sid,
&path,
domain.as_deref().map(String::as_str),
max_age,
secure,
http_only,
same_site,
);
if let Ok(v) = HeaderValue::from_str(&cookie_value) {
resp.headers_mut().append(http::header::SET_COOKIE, v);
}
let _ = dirty;
resp
})
}
}
}