#[cfg(test)]
mod tests;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use crate::cookie::{CookieJar, SetCookie};
use crate::request::Request;
/// Monotonic per-process counter mixed into every generated session id.
static SESSION_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Generates a new session identifier: 32 lowercase hex characters (128 bits).
///
/// Session ids are bearer credentials, so they must not be guessable. Each
/// 64-bit half is produced by a fresh [`RandomState`] hasher. Per the standard
/// library's `HashMap` documentation, its keys are randomly seeded from a
/// secure source of randomness provided by the host, and they are never
/// revealed. An attacker holding some ids therefore cannot reconstruct the
/// inputs or predict the ids issued next. (The previous construction was an
/// invertible mix of the wall clock and the counter, so a single id leaked
/// both.) The timestamp and counter are still hashed in to keep ids distinct.
///
/// NOTE(review): `RandomState` is the best unpredictable source available in
/// `std`; a dedicated CSPRNG (e.g. the `getrandom` crate) would be preferable
/// if the crate can take the dependency.
fn generate_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0);
    let count = SESSION_COUNTER.fetch_add(1, Ordering::Relaxed);

    // `lane` keeps the two halves distinct even though they share inputs.
    let half = |lane: u64| -> u64 {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(nanos);
        hasher.write_u64(count);
        hasher.write_u64(lane);
        hasher.finish()
    };
    format!("{:016x}{:016x}", half(0), half(1))
}
/// One session's identifier together with its string key/value data.
///
/// The stores in this module hand out copies of the stored data, so edits made
/// through these methods are persisted only when the session is passed back to
/// the owning store's `save`.
pub struct Session {
    /// Identifier under which the store keeps this session.
    pub id: String,
    pub(crate) data: HashMap<String, String>,
}
impl Session {
    /// Looks up the value stored under `key`.
    pub fn get(&self, key: &str) -> Option<&str> {
        let found = self.data.get(key);
        found.map(|value| value.as_str())
    }

    /// Stores `value` under `key`, replacing any previous value.
    pub fn set(&mut self, key: &str, value: impl Into<String>) {
        let owned_key = String::from(key);
        let owned_value: String = value.into();
        self.data.insert(owned_key, owned_value);
    }

    /// Deletes `key` and its value; absent keys are ignored.
    pub fn remove(&mut self, key: &str) {
        let _previous = self.data.remove(key);
    }

    /// Reports whether any value is stored under `key`.
    pub fn contains(&self, key: &str) -> bool {
        let data = &self.data;
        data.contains_key(key)
    }
}
/// Stored state of one in-memory session.
struct Entry {
    data: HashMap<String, String>,
    /// Deadline after which the session counts as expired. `None` when the
    /// store's TTL is too large to add to `Instant::now()`; such sessions are
    /// only removed by `destroy`.
    expires_at: Option<Instant>,
}

impl Entry {
    /// Whether the session is strictly past its deadline at `now`.
    fn is_expired(&self, now: Instant) -> bool {
        matches!(self.expires_at, Some(deadline) if now > deadline)
    }
}

/// Mutable state shared by all clones of a `SessionStore`.
struct Inner {
    sessions: HashMap<String, Entry>,
}

/// Thread-safe, in-process session store with a fixed time-to-live.
///
/// Clones share the same underlying map. Each session expires `ttl` after it
/// is created; `save` does not extend that deadline. Expired sessions are
/// hidden from `load` immediately but stay in memory until `purge_expired`.
pub struct SessionStore {
    inner: Arc<Mutex<Inner>>,
    ttl: Duration,
}
impl Clone for SessionStore {
    fn clone(&self) -> Self {
        SessionStore { inner: Arc::clone(&self.inner), ttl: self.ttl }
    }
}
impl SessionStore {
    /// Creates an empty store whose sessions live for `ttl_secs` seconds.
    pub fn new(ttl_secs: u64) -> Self {
        SessionStore {
            inner: Arc::new(Mutex::new(Inner { sessions: HashMap::new() })),
            ttl: Duration::from_secs(ttl_secs),
        }
    }

    /// Registers a new, empty session under a freshly generated id.
    pub fn create(&self) -> Session {
        self.create_with_id(generate_id())
    }

    /// Registers a new, empty session under `id`, replacing any existing
    /// session with that id.
    pub fn create_with_id(&self, id: String) -> Session {
        // `Instant + Duration` panics on overflow, so a huge TTL (for example
        // `u64::MAX` seconds) used to crash here. Unrepresentable deadlines
        // now mean "never expires".
        let entry = Entry {
            data: HashMap::new(),
            expires_at: Instant::now().checked_add(self.ttl),
        };
        self.inner.lock().unwrap().sessions.insert(id.clone(), entry);
        Session { id, data: HashMap::new() }
    }

    /// Returns a copy of the session `id`, or `None` if it is unknown or expired.
    pub fn load(&self, id: &str) -> Option<Session> {
        let inner = self.inner.lock().unwrap();
        let entry = inner.sessions.get(id)?;
        if entry.is_expired(Instant::now()) {
            return None;
        }
        Some(Session { id: id.to_string(), data: entry.data.clone() })
    }

    /// Replaces the stored data of `session` with its current data.
    ///
    /// Does nothing if the session is no longer in the store. The expiry
    /// deadline is left unchanged.
    pub fn save(&self, session: &Session) {
        let mut inner = self.inner.lock().unwrap();
        if let Some(entry) = inner.sessions.get_mut(&session.id) {
            entry.data = session.data.clone();
        }
    }

    /// Removes the session `id`, if present.
    pub fn destroy(&self, id: &str) {
        self.inner.lock().unwrap().sessions.remove(id);
    }

    /// Drops every session that `load` would report as expired.
    pub fn purge_expired(&self) {
        let now = Instant::now();
        self.inner.lock().unwrap().sessions.retain(|_, e| !e.is_expired(now));
    }

    /// Number of stored sessions, including expired ones not yet purged.
    pub fn len(&self) -> usize {
        self.inner.lock().unwrap().sessions.len()
    }

    /// Whether the store holds no sessions (expired ones included).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Extracts the value of the cookie named `cookie_name` from the request's
/// `Cookie` header.
///
/// Returns `None` when the request has no `Cookie` header or the header holds
/// no cookie with that name.
pub fn session_id_from_request(request: &Request, cookie_name: &str) -> Option<String> {
    let cookie_header = request.get_header(String::from("Cookie"))?;
    let cookies = CookieJar::parse(&cookie_header.value);
    let session_cookie = cookies.get(cookie_name)?;
    Some(session_cookie.value.clone())
}
/// Builds the `Set-Cookie` header value that hands `session_id` to the client
/// under `cookie_name`.
///
/// The cookie is scoped to path `/`, is HttpOnly with SameSite=Lax, and has a
/// max-age of `ttl_secs` seconds (clamped to `i64::MAX`).
pub fn session_cookie(session_id: &str, cookie_name: &str, ttl_secs: u64) -> String {
    // `ttl_secs as i64` wrapped values above `i64::MAX` to a negative Max-Age,
    // which RFC 6265 treats as "expire immediately"; clamp instead.
    let max_age = i64::try_from(ttl_secs).unwrap_or(i64::MAX);
    SetCookie::new(cookie_name, session_id)
        .path("/")
        .http_only()
        .same_site("Lax")
        .max_age(max_age)
        .build()
}
/// Builds a `Set-Cookie` header value that overwrites the cookie `cookie_name`
/// with an empty value at path `/` and a max-age of zero, so the client
/// discards it.
pub fn destroy_cookie(cookie_name: &str) -> String {
    let expired = SetCookie::new(cookie_name, "")
        .path("/")
        .max_age(0);
    expired.build()
}
/// Session store backed by the `rws_sessions` SQL table.
///
/// Each row holds the session id, its data (encoded with
/// `crate::url::URL::build_query` and decoded with `parse_query`) and an
/// absolute expiry in Unix epoch seconds. Expired rows are ignored by `load`
/// and deleted by `purge_expired`.
#[cfg(any(feature = "model-sqlite", feature = "model-postgres", feature = "model-mysql"))]
pub struct DbSessionStore {
    pool: Arc<crate::model::DbPool>,
    ttl: Duration,
}
#[cfg(any(feature = "model-sqlite", feature = "model-postgres", feature = "model-mysql"))]
impl Clone for DbSessionStore {
    fn clone(&self) -> Self {
        DbSessionStore { pool: Arc::clone(&self.pool), ttl: self.ttl }
    }
}
#[cfg(any(feature = "model-sqlite", feature = "model-postgres", feature = "model-mysql"))]
impl DbSessionStore {
    /// Wraps `pool` and creates the sessions table if it does not exist.
    ///
    /// # Errors
    /// Any database error raised while creating the table.
    pub async fn new(pool: crate::model::DbPool, ttl_secs: u64) -> Result<Self, crate::model::DbError> {
        let store = DbSessionStore {
            pool: Arc::new(pool),
            ttl: Duration::from_secs(ttl_secs),
        };
        store.ensure_table().await?;
        Ok(store)
    }

    /// Creates `rws_sessions` if missing.
    ///
    /// The column types are chosen to be accepted by SQLite, PostgreSQL and
    /// MySQL alike:
    /// - MySQL rejects a TEXT primary key without a key length, so `id` is
    ///   `VARCHAR(255)`.
    /// - MySQL rejects a literal DEFAULT on TEXT, so `data` has no default;
    ///   every INSERT below supplies it.
    /// - `expires_at` is bound as an `i64`. PostgreSQL's `INTEGER` is only
    ///   32-bit, while `BIGINT` is 64-bit everywhere.
    ///
    /// `IF NOT EXISTS` leaves tables created by earlier versions untouched.
    async fn ensure_table(&self) -> Result<(), crate::model::DbError> {
        self.pool.execute(
            "CREATE TABLE IF NOT EXISTS rws_sessions \
             (id VARCHAR(255) PRIMARY KEY, data TEXT NOT NULL, expires_at BIGINT NOT NULL)",
            &[],
        ).await?;
        Ok(())
    }

    /// Current Unix time in whole seconds; 0 if the clock reads before the epoch.
    fn now_epoch() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs() as i64)
            .unwrap_or(0)
    }

    /// Absolute expiry (epoch seconds) for a session created at `now`.
    ///
    /// Saturates rather than wrapping. Previously `as i64` turned TTLs above
    /// `i64::MAX` into negative offsets, so sessions were born expired, and a
    /// plain `+` could overflow.
    fn expiry_from(&self, now: i64) -> i64 {
        let ttl = i64::try_from(self.ttl.as_secs()).unwrap_or(i64::MAX);
        now.saturating_add(ttl)
    }

    fn serialize(data: &HashMap<String, String>) -> String {
        crate::url::URL::build_query(data.clone())
    }

    fn deserialize(s: &str) -> HashMap<String, String> {
        crate::url::URL::parse_query(s)
    }

    /// Inserts a new, empty session under a freshly generated id.
    pub async fn create(&self) -> Result<Session, crate::model::DbError> {
        self.create_with_id(generate_id()).await
    }

    /// Inserts a new, empty session under `id`, expiring `ttl` from now.
    ///
    /// # Errors
    /// Any database error. An `id` that already exists violates the primary
    /// key, so the insert is expected to fail in that case.
    pub async fn create_with_id(&self, id: String) -> Result<Session, crate::model::DbError> {
        let expires_at = self.expiry_from(Self::now_epoch());
        self.pool.execute(
            "INSERT INTO rws_sessions (id, data, expires_at) VALUES (?, ?, ?)",
            &[
                crate::model::Value::Text(id.clone()),
                crate::model::Value::Text(String::new()),
                crate::model::Value::Int(expires_at),
            ],
        ).await?;
        Ok(Session { id, data: HashMap::new() })
    }

    /// Loads session `id`; `Ok(None)` if it does not exist or has expired.
    pub async fn load(&self, id: &str) -> Result<Option<Session>, crate::model::DbError> {
        let now = Self::now_epoch();
        let rows = self.pool.query_rows(
            "SELECT data FROM rws_sessions WHERE id = ? AND expires_at > ?",
            &[
                crate::model::Value::Text(id.to_string()),
                crate::model::Value::Int(now),
            ],
        ).await?;
        if rows.is_empty() {
            return Ok(None);
        }
        let raw: String = rows[0].get("data")?;
        Ok(Some(Session { id: id.to_string(), data: Self::deserialize(&raw) }))
    }

    /// Overwrites the stored data of `session`.
    ///
    /// The expiry is left unchanged, and a missing row is silently not updated.
    pub async fn save(&self, session: &Session) -> Result<(), crate::model::DbError> {
        self.pool.execute(
            "UPDATE rws_sessions SET data = ? WHERE id = ?",
            &[
                crate::model::Value::Text(Self::serialize(&session.data)),
                crate::model::Value::Text(session.id.clone()),
            ],
        ).await?;
        Ok(())
    }

    /// Deletes session `id`, if present.
    pub async fn destroy(&self, id: &str) -> Result<(), crate::model::DbError> {
        self.pool.execute(
            "DELETE FROM rws_sessions WHERE id = ?",
            &[crate::model::Value::Text(id.to_string())],
        ).await?;
        Ok(())
    }

    /// Deletes every row whose expiry is at or before the current time.
    pub async fn purge_expired(&self) -> Result<(), crate::model::DbError> {
        let now = Self::now_epoch();
        self.pool.execute(
            "DELETE FROM rws_sessions WHERE expires_at <= ?",
            &[crate::model::Value::Int(now)],
        ).await?;
        Ok(())
    }

    /// Number of rows in the table, including expired rows not yet purged.
    pub async fn len(&self) -> Result<usize, crate::model::DbError> {
        let rows = self.pool.query_rows("SELECT COUNT(*) AS n FROM rws_sessions", &[]).await?;
        if rows.is_empty() {
            return Ok(0);
        }
        let n: i64 = rows[0].get("n")?;
        Ok(n as usize)
    }

    /// Whether the table has no rows (expired rows included).
    pub async fn is_empty(&self) -> Result<bool, crate::model::DbError> {
        Ok(self.len().await? == 0)
    }
}
pub struct RespConn {
addr: String,
password: Option<String>,
stream: Mutex<Option<std::net::TcpStream>>,
}
impl RespConn {
fn new(addr: impl Into<String>, password: Option<String>) -> Self {
RespConn { addr: addr.into(), password, stream: Mutex::new(None) }
}
fn connect(&self) -> std::io::Result<std::net::TcpStream> {
let stream = std::net::TcpStream::connect(&self.addr)?;
stream.set_read_timeout(Some(Duration::from_secs(5)))?;
stream.set_write_timeout(Some(Duration::from_secs(5)))?;
Ok(stream)
}
fn cmd(&self, args: &[&[u8]]) -> std::io::Result<RespReply> {
use std::io::Write;
let mut guard = self.stream.lock().unwrap();
if guard.is_none() {
let mut s = self.connect()?;
if let Some(ref pw) = self.password {
let auth_frame = resp_array(&[b"AUTH", pw.as_bytes()]);
s.write_all(&auth_frame)?;
read_reply(&mut s)?; }
*guard = Some(s);
}
let frame = resp_array(args);
let stream = guard.as_mut().unwrap();
if stream.write_all(&frame).is_err() {
*guard = None;
drop(guard);
return self.cmd(args);
}
read_reply(stream)
}
}
/// One decoded RESP reply, reduced to the shapes this module consumes.
enum RespReply {
    /// Error reply (`-...`), carrying the server's message without the `-`.
    Error(String),
    /// Integer reply (`:...`).
    Int(i64),
    /// Bulk-string reply (`$...`); `None` represents the null bulk string.
    Bulk(Option<Vec<u8>>),
    /// Status reply such as `+OK`; its text is not retained.
    Ok,
}
/// Encodes `args` as a RESP array of bulk strings, the wire form of a command.
///
/// Layout: `*<count>\r\n`, then `$<len>\r\n<bytes>\r\n` for each argument.
/// Arguments are copied verbatim, so they may contain arbitrary bytes.
fn resp_array(args: &[&[u8]]) -> Vec<u8> {
    let header = format!("*{}\r\n", args.len());
    args.iter().fold(header.into_bytes(), |mut frame, arg| {
        frame.extend(format!("${}\r\n", arg.len()).bytes());
        frame.extend_from_slice(arg);
        frame.extend_from_slice(b"\r\n");
        frame
    })
}
fn read_reply(stream: &mut std::net::TcpStream) -> std::io::Result<RespReply> {
use std::io::{BufRead, BufReader, Read};
let mut reader = BufReader::new(stream);
let mut line = String::new();
reader.read_line(&mut line)?;
let line = line.trim_end_matches("\r\n");
match line.chars().next() {
Some('+') => Ok(RespReply::Ok),
Some(':') => {
let n = line[1..].parse::<i64>().unwrap_or(0);
Ok(RespReply::Int(n))
}
Some('-') => Ok(RespReply::Error(line[1..].to_string())),
Some('$') => {
let len = line[1..].parse::<i64>().unwrap_or(-1);
if len < 0 {
return Ok(RespReply::Bulk(None));
}
let mut buf = vec![0u8; len as usize + 2]; reader.read_exact(&mut buf)?;
buf.truncate(len as usize);
Ok(RespReply::Bulk(Some(buf)))
}
_ => Ok(RespReply::Ok), }
}
pub struct RedisSessionStore {
conn: Arc<RespConn>,
ttl: u64,
}
impl Clone for RedisSessionStore {
fn clone(&self) -> Self {
RedisSessionStore { conn: Arc::clone(&self.conn), ttl: self.ttl }
}
}
impl RedisSessionStore {
pub fn new(addr: impl Into<String>, password: Option<String>, ttl_secs: u64) -> Self {
RedisSessionStore {
conn: Arc::new(RespConn::new(addr, password)),
ttl: ttl_secs,
}
}
pub fn from_env() -> Self {
let host = std::env::var("RWS_REDIS_HOST").unwrap_or_else(|_| "127.0.0.1".into());
let port = std::env::var("RWS_REDIS_PORT").unwrap_or_else(|_| "6379".into());
let addr = format!("{}:{}", host, port);
let password = std::env::var("RWS_REDIS_PASSWORD").ok();
let ttl = std::env::var("RWS_REDIS_TTL_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(3600u64);
Self::new(addr, password, ttl)
}
fn key(id: &str) -> Vec<u8> {
format!("rws:sess:{}", id).into_bytes()
}
fn serialize(data: &HashMap<String, String>) -> Vec<u8> {
crate::url::URL::build_query(data.clone()).into_bytes()
}
fn deserialize(bytes: Vec<u8>) -> HashMap<String, String> {
let s = String::from_utf8(bytes).unwrap_or_default();
crate::url::URL::parse_query(&s)
}
pub fn create(&self) -> std::io::Result<Session> {
self.create_with_id(generate_id())
}
pub fn create_with_id(&self, id: String) -> std::io::Result<Session> {
let ttl_str = self.ttl.to_string();
self.conn.cmd(&[
b"SET",
&Self::key(&id),
b"",
b"EX",
ttl_str.as_bytes(),
])?;
Ok(Session { id, data: HashMap::new() })
}
pub fn load(&self, id: &str) -> std::io::Result<Option<Session>> {
match self.conn.cmd(&[b"GET", &Self::key(id)])? {
RespReply::Bulk(Some(bytes)) => {
Ok(Some(Session { id: id.to_string(), data: Self::deserialize(bytes) }))
}
_ => Ok(None),
}
}
pub fn save(&self, session: &Session) -> std::io::Result<()> {
let ttl_str = self.ttl.to_string();
let data = Self::serialize(&session.data);
self.conn.cmd(&[
b"SET",
&Self::key(&session.id),
&data,
b"EX",
ttl_str.as_bytes(),
])?;
Ok(())
}
pub fn destroy(&self, id: &str) -> std::io::Result<()> {
self.conn.cmd(&[b"DEL", &Self::key(id)])?;
Ok(())
}
pub fn purge_expired(&self) {}
pub fn len(&self) -> std::io::Result<usize> {
match self.conn.cmd(&[b"DBSIZE"])? {
RespReply::Int(n) => Ok(n as usize),
_ => Ok(0),
}
}
pub fn is_empty(&self) -> std::io::Result<bool> {
Ok(self.len()? == 0)
}
}