use chrono::{DateTime, Duration, Utc};
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use rand::RngCore;
use rusqlite::{Connection, params};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::path::Path;
use crate::Result;
use crate::config::CloudflareKvConfig;
/// A server-side session record.
///
/// The Cloudflare KV backend stores this struct serialized as JSON via serde; the SQLite
/// backend stores `data` as a JSON string column and the timestamps as Unix seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Random identifier produced by `generate_session_id` (128 lowercase hex characters).
    pub id: String,
    /// Arbitrary per-session values; `Session::new` seeds it with a CSRF token stored
    /// under `CSRF_TOKEN_KEY`.
    pub data: HashMap<String, Value>,
    /// Moment the session was created.
    pub created_at: DateTime<Utc>,
    /// Absolute expiry checked by `is_expired`; the `touch` methods do not move it.
    pub expires_at: DateTime<Utc>,
    /// Last time the session was written or refreshed.
    pub last_accessed: DateTime<Utc>,
}
/// Key in `Session::data` under which the per-session CSRF token is stored.
pub const CSRF_TOKEN_KEY: &str = "_csrf_token";
impl Session {
    /// Creates a brand-new session that expires `max_age_seconds` from now.
    ///
    /// The session receives a fresh random id, and its `data` map starts out holding a
    /// freshly generated CSRF token under [`CSRF_TOKEN_KEY`].
    ///
    /// # Panics
    /// Panics if `max_age_seconds` is outside the range accepted by
    /// `chrono::Duration::seconds`, or if the resulting expiry timestamp overflows.
    pub fn new(max_age_seconds: i64) -> Self {
        let now = Utc::now();
        let data = HashMap::from([(
            CSRF_TOKEN_KEY.to_string(),
            Value::String(generate_csrf_token()),
        )]);
        let expires_at = now + Duration::seconds(max_age_seconds);
        Self {
            id: generate_session_id(),
            data,
            created_at: now,
            expires_at,
            last_accessed: now,
        }
    }

    /// Returns `true` once the current time is strictly past `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at < Utc::now()
    }

    /// Flattens the session into a JSON object suitable for template rendering.
    ///
    /// The object holds `id`, `created_at` and `expires_at` (RFC 3339 strings) followed by
    /// every entry of `data`; a data key with one of those names overrides the built-in.
    pub fn to_context(&self) -> Value {
        let mut ctx = serde_json::Map::new();
        ctx.insert("id".into(), self.id.clone().into());
        ctx.insert("created_at".into(), self.created_at.to_rfc3339().into());
        ctx.insert("expires_at".into(), self.expires_at.to_rfc3339().into());
        ctx.extend(self.data.iter().map(|(k, v)| (k.clone(), v.clone())));
        Value::Object(ctx)
    }
}
/// Generates a new session identifier: 64 random bytes rendered as 128 lowercase hex chars.
///
/// Entropy comes from `rand::thread_rng()` (documented by `rand` as a cryptographically
/// secure generator).
pub fn generate_session_id() -> String {
    const ID_BYTES: usize = 64;
    let mut rng = rand::thread_rng();
    let mut raw = [0u8; ID_BYTES];
    rng.fill_bytes(&mut raw);
    hex::encode(&raw)
}
/// Generates a CSRF token: 32 random bytes rendered as 64 lowercase hex characters.
pub fn generate_csrf_token() -> String {
    const TOKEN_BYTES: usize = 32;
    let mut token_bytes = [0u8; TOKEN_BYTES];
    rand::thread_rng().fill_bytes(&mut token_bytes[..]);
    hex::encode(&token_bytes)
}
/// Minimal lowercase hexadecimal encoder (avoids pulling in a dedicated crate).
mod hex {
    /// Lowercase hexadecimal digits indexed by nibble value.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `bytes` as a lowercase hex string, two characters per input byte.
    ///
    /// The output is built in a single pre-sized allocation instead of formatting every
    /// byte into its own temporary `String`.
    pub fn encode(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}
/// Session storage backend chosen at runtime.
///
/// Every method on `SessionBackend` dispatches to the wrapped store; a few operations
/// (cleanup, listing, counting) are only implemented for SQLite.
pub enum SessionBackend {
    /// Local SQLite database, either file-backed or in-memory.
    Sqlite(SqliteSessionStore),
    /// Cloudflare Workers KV, accessed through the Cloudflare REST API.
    CloudflareKv(KvSessionStore),
}
impl SessionBackend {
    /// Creates and persists a fresh session in the active backend.
    pub async fn create(&self) -> Result<Session> {
        match self {
            SessionBackend::Sqlite(store) => store.create().await,
            SessionBackend::CloudflareKv(store) => store.create().await,
        }
    }

    /// Loads the session with `id`, or `None` when it is unknown or expired.
    pub async fn get(&self, id: &str) -> Result<Option<Session>> {
        match self {
            SessionBackend::Sqlite(store) => store.get(id).await,
            SessionBackend::CloudflareKv(store) => store.get(id).await,
        }
    }

    /// Returns the live session for `id`, or creates a new one when `id` is absent,
    /// unknown or expired.
    pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
        match self {
            SessionBackend::Sqlite(store) => store.get_or_create(id).await,
            SessionBackend::CloudflareKv(store) => store.get_or_create(id).await,
        }
    }

    /// Replaces the whole data map of session `id`.
    pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
        match self {
            SessionBackend::Sqlite(store) => store.update(id, data).await,
            SessionBackend::CloudflareKv(store) => store.update(id, data).await,
        }
    }

    /// Refreshes the `last_accessed` timestamp of session `id`.
    pub async fn touch(&self, id: &str) -> Result<()> {
        match self {
            SessionBackend::Sqlite(store) => store.touch(id).await,
            SessionBackend::CloudflareKv(store) => store.touch(id).await,
        }
    }

    /// Removes session `id` from storage.
    pub async fn delete(&self, id: &str) -> Result<()> {
        match self {
            SessionBackend::Sqlite(store) => store.delete(id).await,
            SessionBackend::CloudflareKv(store) => store.delete(id).await,
        }
    }

    /// Purges expired sessions and returns how many rows were removed.
    ///
    /// The KV backend writes its keys with an expiration TTL, so it sweeps nothing here
    /// and always reports `0`.
    pub async fn cleanup_expired(&self) -> Result<u64> {
        match self {
            SessionBackend::Sqlite(store) => store.cleanup_expired().await,
            SessionBackend::CloudflareKv(_) => Ok(0),
        }
    }

    /// Lists ids of unexpired sessions (SQLite only; the KV backend returns an empty list).
    pub async fn list_session_ids(&self) -> Result<Vec<String>> {
        match self {
            SessionBackend::Sqlite(store) => store.list_session_ids().await,
            SessionBackend::CloudflareKv(_) => Ok(Vec::new()),
        }
    }

    /// Counts unexpired sessions (SQLite only; the KV backend returns `0`).
    pub async fn count(&self) -> Result<usize> {
        match self {
            SessionBackend::Sqlite(store) => store.count().await,
            SessionBackend::CloudflareKv(_) => Ok(0),
        }
    }

    /// Applies `mutation` to session `id` and returns the resulting data map.
    ///
    /// For KV this is a fetch / mutate-locally / write-back sequence, which is not atomic:
    /// concurrent mutations of the same session may overwrite each other. A missing KV
    /// session yields an empty map rather than an error.
    pub async fn apply_mutation(
        &self,
        id: &str,
        mutation: &AtomicMutation,
    ) -> Result<HashMap<String, Value>> {
        match self {
            SessionBackend::Sqlite(store) => store.apply_atomic_mutation(id, mutation).await,
            SessionBackend::CloudflareKv(store) => {
                let Some(mut session) = store.get(id).await? else {
                    return Ok(HashMap::new());
                };
                apply_mutation_in_memory(&mut session.data, mutation);
                store.update(id, session.data.clone()).await?;
                Ok(session.data)
            }
        }
    }
}
impl Clone for SessionBackend {
fn clone(&self) -> Self {
match self {
Self::Sqlite(s) => Self::Sqlite(s.clone()),
Self::CloudflareKv(s) => Self::CloudflareKv(s.clone()),
}
}
}
/// Session store backed by a SQLite database (file or in-memory) behind an r2d2 pool.
#[derive(Clone)]
pub struct SqliteSessionStore {
    /// Connection pool; cloning the store clones this pool handle.
    pool: Pool<SqliteConnectionManager>,
    /// Lifetime in seconds given to sessions created by this store.
    max_age: i64,
}
/// r2d2 hook that applies per-connection SQLite settings to file-backed session pools.
#[derive(Debug)]
struct SessionCustomizer;

impl r2d2::CustomizeConnection<Connection, rusqlite::Error> for SessionCustomizer {
    /// Runs when r2d2 establishes a connection.
    ///
    /// `busy_timeout=5000` makes SQLite wait up to 5000 ms for a lock instead of failing
    /// immediately with `SQLITE_BUSY`; `synchronous=NORMAL` relaxes fsync frequency.
    fn on_acquire(&self, conn: &mut Connection) -> std::result::Result<(), rusqlite::Error> {
        const CONNECTION_PRAGMAS: &str = "PRAGMA busy_timeout=5000; PRAGMA synchronous=NORMAL;";
        conn.execute_batch(CONNECTION_PRAGMAS)
    }
}
impl SqliteSessionStore {
pub fn new(db_path: impl AsRef<Path>, max_age_seconds: i64) -> Result<Self> {
let manager = SqliteConnectionManager::file(db_path);
let pool = Pool::builder()
.max_size(4)
.connection_customizer(Box::new(SessionCustomizer))
.build(manager)
.map_err(|e| crate::Error::Session(format!("Session pool creation failed: {}", e)))?;
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Session pool get failed: {}", e)))?;
conn.execute_batch("PRAGMA journal_mode=WAL;")?;
conn.execute(
"CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
data TEXT NOT NULL DEFAULT '{}',
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
last_accessed INTEGER NOT NULL
)",
[],
)?;
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at)",
[],
)?;
let now = Utc::now().timestamp();
let cleaned = conn.execute("DELETE FROM sessions WHERE expires_at < ?1", params![now])?;
if cleaned > 0 {
tracing::info!("Cleaned up {} expired sessions", cleaned);
}
Ok(Self {
pool,
max_age: max_age_seconds,
})
}
pub fn in_memory(max_age_seconds: i64) -> Result<Self> {
let manager = SqliteConnectionManager::memory();
let pool = Pool::builder()
.max_size(1)
.build(manager)
.map_err(|e| crate::Error::Session(format!("Session pool creation failed: {}", e)))?;
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Session pool get failed: {}", e)))?;
conn.execute(
"CREATE TABLE sessions (
id TEXT PRIMARY KEY,
data TEXT NOT NULL DEFAULT '{}',
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
last_accessed INTEGER NOT NULL
)",
[],
)?;
Ok(Self {
pool,
max_age: max_age_seconds,
})
}
pub async fn create(&self) -> Result<Session> {
let pool = self.pool.clone();
let max_age = self.max_age;
tokio::task::spawn_blocking(move || {
let session = Session::new(max_age);
let conn = pool.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
conn.execute(
"INSERT INTO sessions (id, data, created_at, expires_at, last_accessed) VALUES (?1, ?2, ?3, ?4, ?5)",
params![
session.id,
serde_json::to_string(&session.data)?,
session.created_at.timestamp(),
session.expires_at.timestamp(),
session.last_accessed.timestamp(),
],
)?;
Ok(session)
}).await.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
pub async fn get(&self, id: &str) -> Result<Option<Session>> {
let pool = self.pool.clone();
let id = id.to_string();
tokio::task::spawn_blocking(move || {
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
let mut stmt = conn.prepare(
"SELECT id, data, created_at, expires_at, last_accessed FROM sessions WHERE id = ?1"
)?;
let session = match stmt.query_row(params![id], |row| {
let id: String = row.get(0)?;
let data_str: String = row.get(1)?;
let created_at: i64 = row.get(2)?;
let expires_at: i64 = row.get(3)?;
let last_accessed: i64 = row.get(4)?;
Ok(Session {
id,
data: serde_json::from_str(&data_str).unwrap_or_default(),
created_at: DateTime::from_timestamp(created_at, 0).unwrap_or_else(Utc::now),
expires_at: DateTime::from_timestamp(expires_at, 0).unwrap_or_else(Utc::now),
last_accessed: DateTime::from_timestamp(last_accessed, 0)
.unwrap_or_else(Utc::now),
})
}) {
Ok(s) => Some(s),
Err(rusqlite::Error::QueryReturnedNoRows) => None,
Err(e) => return Err(e.into()),
};
match session {
Some(s) if s.is_expired() => {
conn.execute("DELETE FROM sessions WHERE id = ?1", params![s.id])?;
Ok(None)
}
s => Ok(s),
}
})
.await
.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
if let Some(session_id) = id {
if let Some(session) = self.get(session_id).await? {
self.touch(&session.id).await?;
return Ok(session);
}
}
self.create().await
}
pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
let pool = self.pool.clone();
let id = id.to_string();
tokio::task::spawn_blocking(move || {
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
let now = Utc::now().timestamp();
conn.execute(
"UPDATE sessions SET data = ?1, last_accessed = ?2 WHERE id = ?3",
params![serde_json::to_string(&data)?, now, id],
)?;
Ok(())
})
.await
.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
pub async fn touch(&self, id: &str) -> Result<()> {
let pool = self.pool.clone();
let id = id.to_string();
tokio::task::spawn_blocking(move || {
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
let now = Utc::now().timestamp();
conn.execute(
"UPDATE sessions SET last_accessed = ?1 WHERE id = ?2",
params![now, id],
)?;
Ok(())
})
.await
.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
pub async fn delete(&self, id: &str) -> Result<()> {
let pool = self.pool.clone();
let id = id.to_string();
tokio::task::spawn_blocking(move || {
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
conn.execute("DELETE FROM sessions WHERE id = ?1", params![id])?;
Ok(())
})
.await
.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
pub async fn cleanup_expired(&self) -> Result<u64> {
let pool = self.pool.clone();
tokio::task::spawn_blocking(move || {
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
let now = Utc::now().timestamp();
let deleted =
conn.execute("DELETE FROM sessions WHERE expires_at < ?1", params![now])?;
Ok(deleted as u64)
})
.await
.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
pub async fn list_session_ids(&self) -> Result<Vec<String>> {
let pool = self.pool.clone();
tokio::task::spawn_blocking(move || {
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
let now = Utc::now().timestamp();
let mut stmt = conn.prepare(
"SELECT id FROM sessions WHERE expires_at > ?1 ORDER BY last_accessed DESC",
)?;
let ids: Vec<String> = stmt
.query_map(params![now], |row| row.get(0))?
.filter_map(|r| r.ok())
.collect();
Ok(ids)
})
.await
.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
pub async fn count(&self) -> Result<usize> {
let pool = self.pool.clone();
tokio::task::spawn_blocking(move || {
let conn = pool
.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
let now = Utc::now().timestamp();
let count: i64 = conn.query_row(
"SELECT COUNT(*) FROM sessions WHERE expires_at > ?1",
params![now],
|row| row.get(0),
)?;
Ok(count as usize)
})
.await
.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
pub async fn apply_atomic_mutation(
&self,
id: &str,
mutation: &AtomicMutation,
) -> Result<HashMap<String, Value>> {
let pool = self.pool.clone();
let id = id.to_string();
let mutation = mutation.clone();
tokio::task::spawn_blocking(move || {
let conn = pool.get()
.map_err(|e| crate::Error::Session(format!("Pool error: {}", e)))?;
let now = Utc::now().timestamp();
match &mutation {
AtomicMutation::Increment { key, value } => {
let path = format!("$.{}", key);
conn.execute(
"UPDATE sessions SET data = json_set(data, ?1, COALESCE(json_extract(data, ?1), 0) + ?2), last_accessed = ?3 WHERE id = ?4",
params![path, value, now, id],
)?;
}
AtomicMutation::Set { key, value } => {
let path = format!("$.{}", key);
let json_str = serde_json::to_string(value).unwrap_or_default();
conn.execute(
"UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
params![path, json_str, now, id],
)?;
}
AtomicMutation::Push { key, value } => {
let path = format!("$.{}", key);
let json_val = serde_json::to_string(value).unwrap_or_default();
conn.execute(
"UPDATE sessions SET data = json_set(data, ?1, \
CASE WHEN json_extract(data, ?1) IS NULL THEN json_array(json(?2)) \
ELSE json_insert(json_extract(data, ?1), '$[#]', json(?2)) END \
), last_accessed = ?3 WHERE id = ?4",
params![path, json_val, now, id],
)?;
}
AtomicMutation::PushMax { key, max, value } => {
let path = format!("$.{}", key);
let current: String = conn.query_row(
"SELECT COALESCE(json_extract(data, ?1), '[]') FROM sessions WHERE id = ?2",
params![path, id],
|row| row.get(0),
).unwrap_or_else(|_| "[]".to_string());
let mut arr: Vec<Value> = serde_json::from_str(¤t).unwrap_or_default();
arr.push(value.clone());
while arr.len() > *max {
arr.remove(0);
}
let new_arr = serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string());
conn.execute(
"UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
params![path, new_arr, now, id],
)?;
}
AtomicMutation::Unshift { key, value } => {
let path = format!("$.{}", key);
let current: String = conn.query_row(
"SELECT COALESCE(json_extract(data, ?1), '[]') FROM sessions WHERE id = ?2",
params![path, id],
|row| row.get(0),
).unwrap_or_else(|_| "[]".to_string());
let mut arr: Vec<Value> = serde_json::from_str(¤t).unwrap_or_default();
arr.insert(0, value.clone());
let new_arr = serde_json::to_string(&arr).unwrap_or_else(|_| "[]".to_string());
conn.execute(
"UPDATE sessions SET data = json_set(data, ?1, json(?2)), last_accessed = ?3 WHERE id = ?4",
params![path, new_arr, now, id],
)?;
}
AtomicMutation::Clear { key } => {
let path = format!("$.{}", key);
conn.execute(
"UPDATE sessions SET data = json_set(data, ?1, json_array()), last_accessed = ?2 WHERE id = ?3",
params![path, now, id],
)?;
}
}
let data_str: String = conn.query_row(
"SELECT data FROM sessions WHERE id = ?1",
params![id],
|row| row.get(0),
)?;
let data: HashMap<String, Value> = serde_json::from_str(&data_str).unwrap_or_default();
Ok(data)
}).await.map_err(|e| crate::Error::Session(format!("Task join error: {}", e)))?
}
}
/// A single change to one key of a session's `data` map.
///
/// Applied by `SqliteSessionStore::apply_atomic_mutation`, and for the KV backend by
/// `apply_mutation_in_memory` (via `SessionBackend::apply_mutation`).
#[derive(Debug, Clone)]
pub enum AtomicMutation {
    /// Add `value` (which may be negative) to the integer at `key`; a missing value counts as 0.
    Increment { key: String, value: i64 },
    /// Replace whatever is stored at `key` with `value`.
    Set { key: String, value: Value },
    /// Append `value` to the array at `key`, creating the array if `key` is missing.
    Push { key: String, value: Value },
    /// Append `value` to the array at `key`, then drop elements from the front until at
    /// most `max` remain (a bounded log).
    PushMax {
        key: String,
        max: usize,
        value: Value,
    },
    /// Insert `value` at the front of the array at `key`, creating it if missing.
    Unshift { key: String, value: Value },
    /// Replace whatever is stored at `key` with an empty array.
    Clear { key: String },
}
pub fn apply_mutation_in_memory(data: &mut HashMap<String, Value>, mutation: &AtomicMutation) {
match mutation {
AtomicMutation::Increment { key, value } => {
let current = data.get(key).and_then(|v| v.as_i64()).unwrap_or(0);
data.insert(key.clone(), serde_json::json!(current + value));
}
AtomicMutation::Set { key, value } => {
data.insert(key.clone(), value.clone());
}
AtomicMutation::Push { key, value } => {
let arr = data
.entry(key.clone())
.or_insert_with(|| serde_json::json!([]));
if let Some(arr) = arr.as_array_mut() {
arr.push(value.clone());
}
}
AtomicMutation::PushMax { key, max, value } => {
let arr = data
.entry(key.clone())
.or_insert_with(|| serde_json::json!([]));
if let Some(arr) = arr.as_array_mut() {
arr.push(value.clone());
while arr.len() > *max {
arr.remove(0);
}
}
}
AtomicMutation::Unshift { key, value } => {
let arr = data
.entry(key.clone())
.or_insert_with(|| serde_json::json!([]));
if let Some(arr) = arr.as_array_mut() {
arr.insert(0, value.clone());
}
}
AtomicMutation::Clear { key } => {
data.insert(key.clone(), serde_json::json!([]));
}
}
}
/// Session store backed by a Cloudflare Workers KV namespace, accessed via the REST API.
///
/// Note: this type does not derive `Debug`, so the API token cannot be debug-printed.
#[derive(Clone)]
pub struct KvSessionStore {
    /// Cloudflare account that owns the namespace.
    account_id: String,
    /// KV namespace holding the `session:<id>` keys.
    namespace_id: String,
    /// API token sent as a bearer token on every request.
    api_token: String,
    /// Session lifetime in seconds; passed to `Session::new` by `create`.
    max_age: i64,
}
impl KvSessionStore {
pub fn new(config: &CloudflareKvConfig, max_age_seconds: i64) -> Self {
Self {
account_id: config.account_id.clone(),
namespace_id: config.namespace_id.clone(),
api_token: config.api_token.clone(),
max_age: max_age_seconds,
}
}
fn base_url(&self) -> String {
format!(
"https://api.cloudflare.com/client/v4/accounts/{}/storage/kv/namespaces/{}",
self.account_id, self.namespace_id
)
}
fn key(&self, session_id: &str) -> String {
format!("session:{}", session_id)
}
fn client() -> &'static reqwest::Client {
use std::sync::OnceLock;
static CLIENT: OnceLock<reqwest::Client> = OnceLock::new();
CLIENT.get_or_init(|| {
crate::http_client::build_http_client(Some(std::time::Duration::from_secs(10)))
.expect("failed to build Cloudflare KV HTTP client")
})
}
pub async fn create(&self) -> Result<Session> {
let session = Session::new(self.max_age);
self.put_session(&session).await?;
Ok(session)
}
pub async fn get(&self, id: &str) -> Result<Option<Session>> {
let url = format!("{}/values/{}", self.base_url(), self.key(id));
let response = Self::client()
.get(&url)
.bearer_auth(&self.api_token)
.send()
.await
.map_err(|e| crate::Error::Session(format!("KV read failed: {}", e)))?;
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Ok(None);
}
if !response.status().is_success() {
return Err(crate::Error::Session(format!(
"KV read error: HTTP {}",
response.status()
)));
}
let body = response
.text()
.await
.map_err(|e| crate::Error::Session(format!("KV read body failed: {}", e)))?;
match serde_json::from_str::<Session>(&body) {
Ok(session) if session.is_expired() => {
let _ = self.delete(&session.id).await;
Ok(None)
}
Ok(session) => Ok(Some(session)),
Err(e) => {
tracing::warn!("KV session deserialize failed: {}", e);
Ok(None)
}
}
}
pub async fn get_or_create(&self, id: Option<&str>) -> Result<Session> {
if let Some(session_id) = id {
if let Some(session) = self.get(session_id).await? {
self.touch(&session.id).await?;
return Ok(session);
}
}
self.create().await
}
pub async fn update(&self, id: &str, data: HashMap<String, Value>) -> Result<()> {
if let Some(mut session) = self.get(id).await? {
session.data = data;
session.last_accessed = Utc::now();
self.put_session(&session).await?;
}
Ok(())
}
pub async fn touch(&self, id: &str) -> Result<()> {
if let Some(mut session) = self.get(id).await? {
session.last_accessed = Utc::now();
self.put_session(&session).await?;
}
Ok(())
}
pub async fn delete(&self, id: &str) -> Result<()> {
let url = format!("{}/values/{}", self.base_url(), self.key(id));
Self::client()
.delete(&url)
.bearer_auth(&self.api_token)
.send()
.await
.map_err(|e| crate::Error::Session(format!("KV delete failed: {}", e)))?;
Ok(())
}
async fn put_session(&self, session: &Session) -> Result<()> {
let url = format!(
"{}/values/{}?expiration_ttl={}",
self.base_url(),
self.key(&session.id),
self.max_age
);
let body = serde_json::to_string(session)
.map_err(|e| crate::Error::Session(format!("KV serialize failed: {}", e)))?;
let response = Self::client()
.put(&url)
.bearer_auth(&self.api_token)
.header("Content-Type", "application/json")
.body(body)
.send()
.await
.map_err(|e| crate::Error::Session(format!("KV write failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(crate::Error::Session(format!(
"KV write error: HTTP {} — {}",
status, body
)));
}
Ok(())
}
}
/// Extracts the value of cookie `cookie_name` from a raw `Cookie` header.
///
/// Pairs are split on `;` and trimmed. The first pair that starts with `cookie_name=`
/// wins. Returns `None` when there is no header or no matching cookie; an empty value
/// yields `Some("")`.
pub fn parse_session_cookie(cookie_header: Option<&str>, cookie_name: &str) -> Option<String> {
    let header = cookie_header?;
    header
        .split(';')
        .map(str::trim)
        .find_map(|pair| pair.strip_prefix(cookie_name)?.strip_prefix('='))
        .map(str::to_owned)
}
/// Builds a `Set-Cookie` header value for the session cookie.
///
/// The cookie is always `HttpOnly`, `SameSite=Strict` and scoped to `Path=/`, with the
/// given `Max-Age` in seconds. `Secure` is appended only when `secure` is true.
pub fn build_session_cookie(
    session_id: &str,
    cookie_name: &str,
    max_age: i64,
    secure: bool,
) -> String {
    let base = format!(
        "{}={}; HttpOnly; SameSite=Strict; Path=/; Max-Age={}",
        cookie_name, session_id, max_age
    );
    if !secure {
        return base;
    }
    base + "; Secure"
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one atomic mutation against `store` and returns the resulting data map,
    /// panicking if the store reports an error.
    async fn mutate(
        store: &SqliteSessionStore,
        id: &str,
        mutation: AtomicMutation,
    ) -> HashMap<String, Value> {
        store.apply_atomic_mutation(id, &mutation).await.unwrap()
    }

    /// KV configuration with placeholder credentials; these tests never hit the network.
    fn sample_kv_config() -> CloudflareKvConfig {
        CloudflareKvConfig {
            account_id: "acc123".to_string(),
            namespace_id: "ns456".to_string(),
            api_token: "token789".to_string(),
        }
    }

    #[test]
    fn test_generate_session_id() {
        let id = generate_session_id();
        assert_eq!(id.len(), 128);
        assert!(id.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[tokio::test]
    async fn test_session_store() {
        let store = SqliteSessionStore::in_memory(3600).unwrap();
        let session = store.create().await.unwrap();
        assert_eq!(session.id.len(), 128);

        let fetched = store
            .get(&session.id)
            .await
            .unwrap()
            .expect("freshly created session should be retrievable");
        assert_eq!(fetched.id, session.id);

        store.delete(&session.id).await.unwrap();
        assert!(store.get(&session.id).await.unwrap().is_none());
    }

    #[test]
    fn test_parse_session_cookie() {
        let header = Some("w_session=abc123; other=value");
        assert_eq!(
            parse_session_cookie(header, "w_session"),
            Some("abc123".to_string())
        );
        assert_eq!(parse_session_cookie(header, "missing"), None);
    }

    #[test]
    fn test_kv_key_format() {
        let store = KvSessionStore::new(&sample_kv_config(), 3600);
        assert_eq!(store.key("abc123"), "session:abc123");
    }

    #[test]
    fn test_kv_base_url() {
        let store = KvSessionStore::new(&sample_kv_config(), 3600);
        assert_eq!(
            store.base_url(),
            "https://api.cloudflare.com/client/v4/accounts/acc123/storage/kv/namespaces/ns456"
        );
    }

    #[test]
    fn test_session_serialization_roundtrip() {
        let original = Session::new(3600);
        let json = serde_json::to_string(&original).unwrap();
        let restored: Session = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.id, original.id);
        // A brand-new session holds only its CSRF token.
        assert_eq!(restored.data.len(), 1);
        assert!(restored.data.contains_key(CSRF_TOKEN_KEY));
    }

    #[tokio::test]
    async fn test_atomic_increment() {
        let store = SqliteSessionStore::in_memory(3600).unwrap();
        let session = store.create().await.unwrap();
        // (delta, expected running total)
        for (delta, expected) in [(1, 1), (5, 6), (-2, 4)] {
            let data = mutate(
                &store,
                &session.id,
                AtomicMutation::Increment {
                    key: "counter".to_string(),
                    value: delta,
                },
            )
            .await;
            assert_eq!(data.get("counter").and_then(Value::as_i64), Some(expected));
        }
    }

    #[tokio::test]
    async fn test_atomic_set() {
        let store = SqliteSessionStore::in_memory(3600).unwrap();
        let session = store.create().await.unwrap();
        for name in ["Alice", "Bob"] {
            let data = mutate(
                &store,
                &session.id,
                AtomicMutation::Set {
                    key: "name".to_string(),
                    value: serde_json::json!(name),
                },
            )
            .await;
            assert_eq!(data.get("name").and_then(Value::as_str), Some(name));
        }
    }

    #[tokio::test]
    async fn test_atomic_push() {
        let store = SqliteSessionStore::in_memory(3600).unwrap();
        let session = store.create().await.unwrap();
        let push = |v: &str| AtomicMutation::Push {
            key: "items".to_string(),
            value: serde_json::json!(v),
        };

        let after_first = mutate(&store, &session.id, push("first")).await;
        let items = after_first.get("items").and_then(Value::as_array).unwrap();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].as_str(), Some("first"));

        let after_second = mutate(&store, &session.id, push("second")).await;
        let items = after_second.get("items").and_then(Value::as_array).unwrap();
        assert_eq!(items.len(), 2);
        assert_eq!(items[1].as_str(), Some("second"));
    }

    #[tokio::test]
    async fn test_atomic_push_max() {
        let store = SqliteSessionStore::in_memory(3600).unwrap();
        let session = store.create().await.unwrap();
        for i in 1..=3 {
            mutate(
                &store,
                &session.id,
                AtomicMutation::PushMax {
                    key: "log".to_string(),
                    max: 2,
                    value: serde_json::json!(i),
                },
            )
            .await;
        }
        // Re-read through `get` to check the persisted value, not just the returned map.
        let stored = store.get(&session.id).await.unwrap().unwrap();
        let log = stored.data.get("log").and_then(Value::as_array).unwrap();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0].as_i64(), Some(2));
        assert_eq!(log[1].as_i64(), Some(3));
    }

    #[tokio::test]
    async fn test_atomic_unshift() {
        let store = SqliteSessionStore::in_memory(3600).unwrap();
        let session = store.create().await.unwrap();
        let unshift = |v: &str| AtomicMutation::Unshift {
            key: "stack".to_string(),
            value: serde_json::json!(v),
        };

        mutate(&store, &session.id, unshift("first")).await;
        let data = mutate(&store, &session.id, unshift("second")).await;

        let stack = data.get("stack").and_then(Value::as_array).unwrap();
        assert_eq!(stack.len(), 2);
        assert_eq!(stack[0].as_str(), Some("second"));
        assert_eq!(stack[1].as_str(), Some("first"));
    }

    #[tokio::test]
    async fn test_atomic_clear() {
        let store = SqliteSessionStore::in_memory(3600).unwrap();
        let session = store.create().await.unwrap();
        let push = |v: &str| AtomicMutation::Push {
            key: "items".to_string(),
            value: serde_json::json!(v),
        };

        mutate(&store, &session.id, push("a")).await;
        mutate(&store, &session.id, push("b")).await;
        let data = mutate(
            &store,
            &session.id,
            AtomicMutation::Clear {
                key: "items".to_string(),
            },
        )
        .await;

        let items = data.get("items").and_then(Value::as_array).unwrap();
        assert!(items.is_empty());
    }

    #[test]
    fn test_apply_mutation_in_memory() {
        let mut data = HashMap::new();

        apply_mutation_in_memory(
            &mut data,
            &AtomicMutation::Increment {
                key: "x".to_string(),
                value: 3,
            },
        );
        assert_eq!(data.get("x").and_then(Value::as_i64), Some(3));

        apply_mutation_in_memory(
            &mut data,
            &AtomicMutation::Set {
                key: "name".to_string(),
                value: serde_json::json!("test"),
            },
        );
        assert_eq!(data.get("name").and_then(Value::as_str), Some("test"));

        for n in 1..=2 {
            apply_mutation_in_memory(
                &mut data,
                &AtomicMutation::Push {
                    key: "list".to_string(),
                    value: serde_json::json!(n),
                },
            );
        }
        let list = data.get("list").and_then(Value::as_array).unwrap();
        assert_eq!(list.len(), 2);
    }
}