1#![forbid(unsafe_code)]
7
8use std::collections::HashMap;
9
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13#[derive(Debug, Default, Clone, Serialize, Deserialize)]
19pub struct StateDelta {
20 pub changes: HashMap<String, Option<Value>>,
23}
24
25impl StateDelta {
26 pub fn is_empty(&self) -> bool {
28 self.changes.is_empty()
29 }
30
31 pub fn len(&self) -> usize {
33 self.changes.len()
34 }
35}
36
37#[derive(Debug, Default, Clone, Serialize, Deserialize)]
45pub struct SessionState {
46 data: HashMap<String, Value>,
47 #[serde(skip)]
48 delta: StateDelta,
49}
50
51impl SessionState {
52 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn with_data(data: HashMap<String, Value>) -> Self {
61 Self {
62 data,
63 delta: StateDelta::default(),
64 }
65 }
66
67 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
70 self.data
71 .get(key)
72 .and_then(|v| serde_json::from_value(v.clone()).ok())
73 }
74
75 pub fn get_raw(&self, key: &str) -> Option<&Value> {
77 self.data.get(key)
78 }
79
80 pub fn set<T: Serialize>(&mut self, key: &str, value: T) -> Result<(), serde_json::Error> {
84 let val = serde_json::to_value(value)?;
85 self.data.insert(key.to_string(), val.clone());
86 self.delta.changes.insert(key.to_string(), Some(val));
87 Ok(())
88 }
89
90 pub fn remove(&mut self, key: &str) {
92 if self.data.remove(key).is_some() {
93 self.delta.changes.insert(key.to_string(), None);
94 }
95 }
96
97 pub fn contains(&self, key: &str) -> bool {
99 self.data.contains_key(key)
100 }
101
102 pub fn keys(&self) -> impl Iterator<Item = &str> {
104 self.data.keys().map(String::as_str)
105 }
106
107 pub fn len(&self) -> usize {
109 self.data.len()
110 }
111
112 pub fn is_empty(&self) -> bool {
114 self.data.is_empty()
115 }
116
117 pub fn clear(&mut self) {
119 for key in self.data.keys() {
120 self.delta.changes.insert(key.clone(), None);
121 }
122 self.data.clear();
123 }
124
125 pub const fn delta(&self) -> &StateDelta {
127 &self.delta
128 }
129
130 pub fn flush_delta(&mut self) -> StateDelta {
132 std::mem::take(&mut self.delta)
133 }
134
135 pub fn snapshot(&self) -> Value {
137 serde_json::to_value(&self.data).expect("HashMap<String, Value> is always serializable")
138 }
139
140 pub fn restore_from_snapshot(snapshot: Value) -> Result<Self, serde_json::Error> {
143 let data: HashMap<String, Value> = serde_json::from_value(snapshot)?;
144 Ok(Self {
145 data,
146 delta: StateDelta::default(),
147 })
148 }
149}
150
151const _: () = {
154 const fn assert_send_sync<T: Send + Sync>() {}
155 assert_send_sync::<SessionState>();
156 assert_send_sync::<StateDelta>();
157 assert_send_sync::<std::sync::Arc<std::sync::RwLock<SessionState>>>();
158};
159
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- StateDelta ---------------------------------------------------------

    #[test]
    fn delta_default_is_empty() {
        let d = StateDelta::default();
        assert!(d.is_empty());
        assert_eq!(d.len(), 0);
    }

    #[test]
    fn delta_serde_roundtrip() {
        let mut d = StateDelta::default();
        d.changes.insert("a".into(), Some(json!(1)));
        d.changes.insert("b".into(), None);
        let json = serde_json::to_string(&d).unwrap();
        let d2: StateDelta = serde_json::from_str(&json).unwrap();
        assert_eq!(d2.len(), 2);
        assert_eq!(d2.changes["a"], Some(json!(1)));
        assert_eq!(d2.changes["b"], None);
    }

    // ---- basic get / set / remove -------------------------------------------

    #[test]
    fn set_and_get_typed() {
        let mut s = SessionState::new();
        s.set("count", 42_i64).unwrap();
        assert_eq!(s.get::<i64>("count"), Some(42));
    }

    #[test]
    fn get_raw_returns_value_ref() {
        let mut s = SessionState::new();
        s.set("key", "hello").unwrap();
        assert_eq!(s.get_raw("key"), Some(&json!("hello")));
    }

    #[test]
    fn get_missing_returns_none() {
        let s = SessionState::new();
        assert_eq!(s.get::<String>("nope"), None);
    }

    #[test]
    fn get_wrong_type_returns_none() {
        let mut s = SessionState::new();
        s.set("key", "hello").unwrap();
        assert_eq!(s.get::<i64>("key"), None);
        // A failed typed read must not disturb the stored value.
        assert_eq!(s.get::<String>("key"), Some("hello".to_string()));
    }

    #[test]
    fn reads_do_not_record_delta() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        s.flush_delta();
        let _ = s.get::<i64>("k");
        let _ = s.get_raw("k");
        let _ = s.contains("k");
        assert!(s.delta().is_empty());
    }

    #[test]
    fn remove_existing_key() {
        let mut s = SessionState::new();
        s.set("x", 1).unwrap();
        s.remove("x");
        assert!(!s.contains("x"));
        assert!(s.is_empty());
    }

    #[test]
    fn remove_absent_key_is_noop() {
        let mut s = SessionState::new();
        s.remove("nope");
        assert!(s.delta().is_empty());
    }

    #[test]
    fn remove_after_flush_records_removal() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        s.flush_delta();
        s.remove("k");
        assert_eq!(s.delta().len(), 1);
        assert_eq!(s.delta().changes["k"], None);
    }

    #[test]
    fn contains_keys_len_is_empty() {
        let mut s = SessionState::new();
        assert!(s.is_empty());
        s.set("a", 1).unwrap();
        s.set("b", 2).unwrap();
        assert!(s.contains("a"));
        assert!(!s.contains("c"));
        assert_eq!(s.len(), 2);
        assert!(!s.is_empty());
        let keys: Vec<&str> = s.keys().collect();
        assert!(keys.contains(&"a"));
        assert!(keys.contains(&"b"));
    }

    // ---- clear --------------------------------------------------------------

    #[test]
    fn clear_records_all_removals() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        s.set("b", 2).unwrap();
        // Start from an empty delta so only the removals from `clear` are observed.
        s.flush_delta();
        s.clear();
        assert!(s.is_empty());
        assert_eq!(s.delta().len(), 2);
        assert_eq!(s.delta().changes["a"], None);
        assert_eq!(s.delta().changes["b"], None);
    }

    #[test]
    fn clear_on_empty_state_records_nothing() {
        let mut s = SessionState::new();
        s.clear();
        assert!(s.is_empty());
        assert!(s.delta().is_empty());
    }

    #[test]
    fn clear_then_set_records_new_value() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        s.flush_delta();
        s.clear();
        s.set("a", 5).unwrap();
        assert_eq!(s.delta().len(), 1);
        assert_eq!(s.delta().changes["a"], Some(json!(5)));
    }

    // ---- delta collapsing ---------------------------------------------------

    #[test]
    fn delta_set_set_last_wins() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        s.set("k", 2).unwrap();
        assert_eq!(s.delta().changes["k"], Some(json!(2)));
        assert_eq!(s.delta().len(), 1);
    }

    #[test]
    fn delta_set_remove_is_none() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        s.remove("k");
        assert_eq!(s.delta().changes["k"], None);
    }

    #[test]
    fn delta_remove_set_is_some() {
        let mut s =
            SessionState::with_data(std::iter::once(("k".to_string(), json!(1))).collect());
        s.remove("k");
        s.set("k", 99).unwrap();
        assert_eq!(s.delta().changes["k"], Some(json!(99)));
    }

    // ---- flushing -----------------------------------------------------------

    #[test]
    fn flush_delta_returns_and_resets() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        let d = s.flush_delta();
        assert_eq!(d.len(), 1);
        assert!(s.delta().is_empty());
    }

    #[test]
    fn flush_empty_delta_returns_empty() {
        let mut s = SessionState::new();
        let d = s.flush_delta();
        assert!(d.is_empty());
    }

    // ---- construction, snapshots and serde ----------------------------------

    #[test]
    fn with_data_pre_seeds_without_delta() {
        let data: HashMap<String, Value> = std::iter::once(("x".into(), json!(42))).collect();
        let s = SessionState::with_data(data);
        assert_eq!(s.get::<i64>("x"), Some(42));
        assert!(s.delta().is_empty());
    }

    #[test]
    fn snapshot_restore_roundtrip() {
        let mut s = SessionState::new();
        s.set("name", "alice").unwrap();
        s.set("age", 30).unwrap();
        let snap = s.snapshot();
        let s2 = SessionState::restore_from_snapshot(snap).unwrap();
        assert_eq!(s2.get::<String>("name"), Some("alice".to_string()));
        assert_eq!(s2.get::<i64>("age"), Some(30));
        assert!(s2.delta().is_empty());
    }

    #[test]
    fn snapshot_of_empty_state_is_empty_object() {
        assert_eq!(SessionState::new().snapshot(), json!({}));
    }

    #[test]
    fn serde_roundtrip_skips_delta() {
        let mut s = SessionState::new();
        s.set("k", "v").unwrap();
        assert!(!s.delta().is_empty());
        let json = serde_json::to_string(&s).unwrap();
        let s2: SessionState = serde_json::from_str(&json).unwrap();
        assert_eq!(s2.get::<String>("k"), Some("v".to_string()));
        assert!(s2.delta().is_empty());
    }

    #[test]
    fn set_returns_error_on_serialization_failure() {
        use serde::ser::{self, Serializer};

        struct Unserializable;

        impl Serialize for Unserializable {
            fn serialize<S: Serializer>(&self, _s: S) -> Result<S::Ok, S::Error> {
                Err(ser::Error::custom("intentional serialization failure"))
            }
        }

        let mut s = SessionState::new();
        let result = s.set("bad", Unserializable);
        assert!(result.is_err());
        assert!(!s.contains("bad"));
        // A failed `set` must not leak a phantom change into the delta either.
        assert!(s.delta().is_empty());
    }

    #[test]
    fn nested_json_roundtrip() {
        let mut s = SessionState::new();
        let nested = json!({
            "user": {"name": "bob", "scores": [1, 2, 3]},
            "active": true
        });
        s.set("profile", nested.clone()).unwrap();
        let snap = s.snapshot();
        let s2 = SessionState::restore_from_snapshot(snap).unwrap();
        assert_eq!(s2.get_raw("profile"), Some(&nested));
    }

    #[test]
    fn restore_from_corrupt_snapshot_returns_error() {
        let err = SessionState::restore_from_snapshot(json!(["not", "an", "object"])).unwrap_err();
        assert!(err.to_string().contains("map"));
    }
}