#![forbid(unsafe_code)]
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// A set of pending changes to a session's key/value state.
///
/// Each key maps to its change: `Some(value)` means the key was set to
/// `value`, `None` means the key was removed. Being a `HashMap`, it holds at
/// most one entry per key.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct StateDelta {
    /// Change per key: `Some` for a set, `None` for a removal.
    pub changes: HashMap<String, Option<Value>>,
}
impl StateDelta {
pub fn is_empty(&self) -> bool {
self.changes.is_empty()
}
pub fn len(&self) -> usize {
self.changes.len()
}
}
/// Key/value session storage holding `serde_json::Value`s, with change tracking.
///
/// Only `data` takes part in (de)serialization. The `delta` field is always
/// skipped, so a freshly deserialized state starts with an empty delta.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionState {
    /// Current contents of the session.
    data: HashMap<String, Value>,
    /// Changes recorded since the last flush; never (de)serialized.
    #[serde(skip_serializing, skip_deserializing)]
    delta: StateDelta,
}
impl SessionState {
pub fn new() -> Self {
Self::default()
}
pub fn with_data(data: HashMap<String, Value>) -> Self {
Self {
data,
delta: StateDelta::default(),
}
}
pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
self.data
.get(key)
.and_then(|v| serde_json::from_value(v.clone()).ok())
}
pub fn get_raw(&self, key: &str) -> Option<&Value> {
self.data.get(key)
}
pub fn set<T: Serialize>(&mut self, key: &str, value: T) -> Result<(), serde_json::Error> {
let val = serde_json::to_value(value)?;
self.data.insert(key.to_string(), val.clone());
self.delta.changes.insert(key.to_string(), Some(val));
Ok(())
}
pub fn remove(&mut self, key: &str) {
if self.data.remove(key).is_some() {
self.delta.changes.insert(key.to_string(), None);
}
}
pub fn contains(&self, key: &str) -> bool {
self.data.contains_key(key)
}
pub fn keys(&self) -> impl Iterator<Item = &str> {
self.data.keys().map(String::as_str)
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn clear(&mut self) {
for key in self.data.keys() {
self.delta.changes.insert(key.clone(), None);
}
self.data.clear();
}
pub const fn delta(&self) -> &StateDelta {
&self.delta
}
pub fn flush_delta(&mut self) -> StateDelta {
std::mem::take(&mut self.delta)
}
pub fn snapshot(&self) -> Value {
serde_json::to_value(&self.data).expect("HashMap<String, Value> is always serializable")
}
pub fn restore_from_snapshot(snapshot: Value) -> Result<Self, serde_json::Error> {
let data: HashMap<String, Value> = serde_json::from_value(snapshot)?;
Ok(Self {
data,
delta: StateDelta::default(),
})
}
}
// Compile-time guarantee that the state types can cross and be shared between
// threads. If a future field breaks `Send`/`Sync`, the build fails here.
const _: () = {
    const fn require_send_and_sync<T: Send + Sync>() {}
    require_send_and_sync::<StateDelta>();
    require_send_and_sync::<SessionState>();
    require_send_and_sync::<std::sync::Arc<std::sync::RwLock<SessionState>>>();
};
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn delta_default_is_empty() {
        let d = StateDelta::default();
        assert!(d.is_empty());
        assert_eq!(d.len(), 0);
    }

    #[test]
    fn delta_serde_roundtrip() {
        let mut d = StateDelta::default();
        d.changes.insert("a".into(), Some(json!(1)));
        d.changes.insert("b".into(), None);
        let json = serde_json::to_string(&d).unwrap();
        let d2: StateDelta = serde_json::from_str(&json).unwrap();
        assert_eq!(d2.len(), 2);
        assert_eq!(d2.changes["a"], Some(json!(1)));
        assert_eq!(d2.changes["b"], None);
    }

    #[test]
    fn set_and_get_typed() {
        let mut s = SessionState::new();
        s.set("count", 42_i64).unwrap();
        assert_eq!(s.get::<i64>("count"), Some(42));
    }

    #[test]
    fn get_typed_collection() {
        let mut s = SessionState::new();
        s.set("m", json!({"x": 1, "y": 2})).unwrap();
        let m: HashMap<String, i64> = s.get("m").unwrap();
        assert_eq!(m.len(), 2);
        assert_eq!(m["x"], 1);
        assert_eq!(m["y"], 2);
    }

    #[test]
    fn get_raw_returns_value_ref() {
        let mut s = SessionState::new();
        s.set("key", "hello").unwrap();
        assert_eq!(s.get_raw("key"), Some(&json!("hello")));
    }

    #[test]
    fn get_missing_returns_none() {
        let s = SessionState::new();
        assert_eq!(s.get::<String>("nope"), None);
    }

    #[test]
    fn get_wrong_type_returns_none() {
        let mut s = SessionState::new();
        s.set("key", "hello").unwrap();
        assert_eq!(s.get::<i64>("key"), None);
        assert_eq!(s.get::<String>("key"), Some("hello".to_string()));
    }

    #[test]
    fn reads_do_not_record_delta() {
        let data: HashMap<String, Value> =
            std::iter::once(("k".to_string(), json!([1, 2]))).collect();
        let s = SessionState::with_data(data);
        assert_eq!(s.get::<Vec<i64>>("k"), Some(vec![1, 2]));
        assert!(s.get_raw("k").is_some());
        assert!(s.contains("k"));
        let _ = s.snapshot();
        assert!(s.delta().is_empty());
    }

    #[test]
    fn remove_existing_key() {
        let mut s = SessionState::new();
        s.set("x", 1).unwrap();
        s.remove("x");
        assert!(!s.contains("x"));
        assert!(s.is_empty());
    }

    #[test]
    fn remove_absent_key_is_noop() {
        let mut s = SessionState::new();
        s.remove("nope");
        assert!(s.delta().is_empty());
    }

    #[test]
    fn remove_after_flush_records_tombstone() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        let _ = s.flush_delta();
        s.remove("k");
        assert_eq!(s.delta().len(), 1);
        assert_eq!(s.delta().changes["k"], None);
    }

    #[test]
    fn contains_keys_len_is_empty() {
        let mut s = SessionState::new();
        assert!(s.is_empty());
        s.set("a", 1).unwrap();
        s.set("b", 2).unwrap();
        assert!(s.contains("a"));
        assert!(!s.contains("c"));
        assert_eq!(s.len(), 2);
        assert!(!s.is_empty());
        let keys: Vec<&str> = s.keys().collect();
        assert!(keys.contains(&"a"));
        assert!(keys.contains(&"b"));
    }

    #[test]
    fn clear_records_all_removals() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        s.set("b", 2).unwrap();
        // Discard the pending sets so only the removals from `clear` remain.
        let _ = s.flush_delta();
        s.clear();
        assert!(s.is_empty());
        assert_eq!(s.delta().len(), 2);
        assert_eq!(s.delta().changes["a"], None);
        assert_eq!(s.delta().changes["b"], None);
    }

    #[test]
    fn clear_overrides_pending_sets() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        s.clear();
        assert!(s.is_empty());
        assert_eq!(s.delta().len(), 1);
        assert_eq!(s.delta().changes["a"], None);
    }

    #[test]
    fn clear_on_empty_records_nothing() {
        let mut s = SessionState::new();
        s.clear();
        assert!(s.is_empty());
        assert!(s.delta().is_empty());
    }

    #[test]
    fn delta_set_set_last_wins() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        s.set("k", 2).unwrap();
        assert_eq!(s.delta().changes["k"], Some(json!(2)));
        assert_eq!(s.delta().len(), 1);
    }

    #[test]
    fn delta_set_remove_is_none() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        s.remove("k");
        assert_eq!(s.delta().changes["k"], None);
    }

    #[test]
    fn delta_remove_set_is_some() {
        let data: HashMap<String, Value> =
            std::iter::once(("k".to_string(), json!(1))).collect();
        let mut s = SessionState::with_data(data);
        s.remove("k");
        s.set("k", 99).unwrap();
        assert_eq!(s.delta().changes["k"], Some(json!(99)));
    }

    #[test]
    fn flush_delta_returns_and_resets() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        let d = s.flush_delta();
        assert_eq!(d.len(), 1);
        assert!(s.delta().is_empty());
    }

    #[test]
    fn flush_empty_delta_returns_empty() {
        let mut s = SessionState::new();
        let d = s.flush_delta();
        assert!(d.is_empty());
    }

    #[test]
    fn with_data_pre_seeds_without_delta() {
        let data: HashMap<String, Value> = std::iter::once(("x".into(), json!(42))).collect();
        let s = SessionState::with_data(data);
        assert_eq!(s.get::<i64>("x"), Some(42));
        assert!(s.delta().is_empty());
    }

    #[test]
    fn snapshot_is_json_object() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        assert_eq!(s.snapshot(), json!({"a": 1}));
    }

    #[test]
    fn snapshot_restore_roundtrip() {
        let mut s = SessionState::new();
        s.set("name", "alice").unwrap();
        s.set("age", 30).unwrap();
        let snap = s.snapshot();
        let s2 = SessionState::restore_from_snapshot(snap).unwrap();
        assert_eq!(s2.get::<String>("name"), Some("alice".to_string()));
        assert_eq!(s2.get::<i64>("age"), Some(30));
        assert!(s2.delta().is_empty());
    }

    #[test]
    fn serde_roundtrip_skips_delta() {
        let mut s = SessionState::new();
        s.set("k", "v").unwrap();
        assert!(!s.delta().is_empty());
        let json = serde_json::to_string(&s).unwrap();
        let s2: SessionState = serde_json::from_str(&json).unwrap();
        assert_eq!(s2.get::<String>("k"), Some("v".to_string()));
        assert!(s2.delta().is_empty());
    }

    #[test]
    fn set_returns_error_on_serialization_failure() {
        use serde::ser::{self, Serializer};

        struct Unserializable;

        impl Serialize for Unserializable {
            fn serialize<S: Serializer>(&self, _s: S) -> Result<S::Ok, S::Error> {
                Err(ser::Error::custom("intentional serialization failure"))
            }
        }

        let mut s = SessionState::new();
        let result = s.set("bad", Unserializable);
        assert!(result.is_err());
        assert!(!s.contains("bad"));
        assert!(s.delta().is_empty());
    }

    #[test]
    fn nested_json_roundtrip() {
        let mut s = SessionState::new();
        let nested = json!({
            "user": {"name": "bob", "scores": [1, 2, 3]},
            "active": true
        });
        s.set("profile", nested.clone()).unwrap();
        let snap = s.snapshot();
        let s2 = SessionState::restore_from_snapshot(snap).unwrap();
        assert_eq!(s2.get_raw("profile"), Some(&nested));
    }

    #[test]
    fn restore_from_corrupt_snapshot_returns_error() {
        let err = SessionState::restore_from_snapshot(json!(["not", "an", "object"])).unwrap_err();
        assert!(err.to_string().contains("map"));
    }
}