1use std::collections::HashMap;
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use dashmap::DashMap;
11use serde_json::Value;
12
/// A typed handle for a state key.
///
/// `StateKey<T>` pairs a string key with the Rust type `T` stored under it, so
/// `State::get_key` / `State::set_key` can be called without spelling out the
/// type at every call site. Only the key string is stored; no `T` is ever held.
///
/// `PhantomData<fn() -> T>` keeps the handle `Send`/`Sync` and covariant in `T`
/// whatever `T` is, since function pointers are always `Send + Sync`.
pub struct StateKey<T> {
    key: &'static str,
    _phantom: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive]` would add `T: Clone` / `T: Copy` /
// `T: Debug` bounds even though no `T` is stored, making e.g.
// `StateKey<SomeNonDebugType>` needlessly unprintable and uncopyable.
impl<T> Clone for StateKey<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for StateKey<T> {}

impl<T> std::fmt::Debug for StateKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateKey").field("key", &self.key).finish()
    }
}

impl<T> StateKey<T> {
    /// Creates a typed key for `key`.
    ///
    /// This is a `const fn`, so keys can be declared as `const` items.
    pub const fn new(key: &'static str) -> Self {
        Self {
            key,
            _phantom: PhantomData,
        }
    }

    /// Returns the underlying string key.
    pub const fn key(&self) -> &'static str {
        self.key
    }
}
43
44#[derive(Debug, Clone)]
50pub struct State {
51 inner: Arc<DashMap<String, Value>>,
52 delta: Arc<DashMap<String, Value>>,
53 track_delta: bool,
54}
55
56impl Default for State {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl State {
63 pub fn new() -> Self {
65 Self {
66 inner: Arc::new(DashMap::new()),
67 delta: Arc::new(DashMap::new()),
68 track_delta: false,
69 }
70 }
71
72 pub fn with_delta_tracking(&self) -> State {
75 State {
76 inner: self.inner.clone(),
77 delta: Arc::new(DashMap::new()),
78 track_delta: true,
79 }
80 }
81
82 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
85 self.get_raw(key)
86 .and_then(|v| serde_json::from_value(v).ok())
87 }
88
89 pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
97 where
98 F: FnOnce(&Value) -> R,
99 {
100 if self.track_delta {
101 if let Some(ref_multi) = self.delta.get(key) {
102 return Some(f(ref_multi.value()));
103 }
104 }
105 if let Some(ref_multi) = self.inner.get(key) {
106 return Some(f(ref_multi.value()));
107 }
108 if !key.contains(':') {
109 let mut derived_key = String::with_capacity(8 + key.len());
110 use std::fmt::Write;
111 let _ = write!(derived_key, "derived:{}", key);
112 if self.track_delta {
113 if let Some(ref_multi) = self.delta.get(&derived_key) {
114 return Some(f(ref_multi.value()));
115 }
116 }
117 if let Some(ref_multi) = self.inner.get(&derived_key) {
118 return Some(f(ref_multi.value()));
119 }
120 }
121 None
122 }
123
124 pub fn get_raw(&self, key: &str) -> Option<Value> {
129 if self.track_delta {
130 if let Some(v) = self.delta.get(key) {
131 return Some(v.value().clone());
132 }
133 }
134 if let Some(v) = self.inner.get(key) {
135 return Some(v.value().clone());
136 }
137 if !key.contains(':') {
139 use std::fmt::Write;
140 let mut derived_key = String::with_capacity(8 + key.len());
141 let _ = write!(derived_key, "derived:{}", key);
142 if self.track_delta {
143 if let Some(v) = self.delta.get(&derived_key) {
144 return Some(v.value().clone());
145 }
146 }
147 return self.inner.get(&derived_key).map(|v| v.value().clone());
148 }
149 None
150 }
151
152 pub fn get_key<T: serde::de::DeserializeOwned>(&self, key: &StateKey<T>) -> Option<T> {
154 self.get(key.key())
155 }
156
157 pub fn set_key<T: serde::Serialize>(&self, key: &StateKey<T>, value: T) {
159 self.set(key.key(), value);
160 }
161
162 pub fn with_key<T, F, R>(&self, key: &StateKey<T>, f: F) -> Option<R>
164 where
165 F: FnOnce(&Value) -> R,
166 {
167 self.with(key.key(), f)
168 }
169
170 pub fn set(&self, key: impl Into<String>, value: impl serde::Serialize) {
173 let v = serde_json::to_value(value).expect("value must be serializable");
174 if self.track_delta {
175 self.delta.insert(key.into(), v);
176 } else {
177 self.inner.insert(key.into(), v);
178 }
179 }
180
181 pub fn set_committed(&self, key: impl Into<String>, value: impl serde::Serialize) {
183 let v = serde_json::to_value(value).expect("value must be serializable");
184 self.inner.insert(key.into(), v);
185 }
186
187 pub fn modify<T, F>(&self, key: &str, default: T, f: F) -> T
193 where
194 T: serde::Serialize + serde::de::DeserializeOwned,
195 F: FnOnce(T) -> T,
196 {
197 let current: T = self.get(key).unwrap_or(default);
199 let new_val = f(current);
200 self.set(key, &new_val);
201 new_val
202 }
203
204 pub fn contains(&self, key: &str) -> bool {
206 if self.track_delta && self.delta.contains_key(key) {
207 return true;
208 }
209 self.inner.contains_key(key)
210 }
211
212 pub fn remove(&self, key: &str) -> Option<Value> {
214 if self.track_delta {
215 let from_delta = self.delta.remove(key).map(|(_, v)| v);
217 let from_inner = self.inner.remove(key).map(|(_, v)| v);
218 from_delta.or(from_inner)
219 } else {
220 self.inner.remove(key).map(|(_, v)| v)
221 }
222 }
223
224 pub fn keys(&self) -> Vec<String> {
226 if !self.track_delta || self.delta.is_empty() {
227 return self.inner.iter().map(|r| r.key().clone()).collect();
228 }
229 let mut seen =
230 std::collections::HashSet::with_capacity(self.inner.len() + self.delta.len());
231 let mut keys = Vec::with_capacity(self.inner.len() + self.delta.len());
232 for entry in self.inner.iter() {
233 let key = entry.key().clone();
234 seen.insert(key.clone());
235 keys.push(key);
236 }
237 for entry in self.delta.iter() {
238 let key = entry.key().clone();
239 if seen.insert(key.clone()) {
240 keys.push(key);
241 }
242 }
243 keys
244 }
245
246 pub fn pick(&self, keys: &[&str]) -> State {
248 let new = State::new();
249 for key in keys {
250 if let Some(v) = self.get_raw(key) {
251 new.set(*key, v);
252 }
253 }
254 new
255 }
256
257 pub fn merge(&self, other: &State) {
259 for entry in other.inner.iter() {
260 self.inner
261 .insert(entry.key().clone(), entry.value().clone());
262 }
263 }
264
265 pub fn rename(&self, from: &str, to: &str) {
267 if let Some(v) = self.remove(from) {
268 if self.track_delta {
269 self.delta.insert(to.to_string(), v);
270 } else {
271 self.inner.insert(to.to_string(), v);
272 }
273 }
274 }
275
276 pub fn is_tracking_delta(&self) -> bool {
280 self.track_delta
281 }
282
283 pub fn has_delta(&self) -> bool {
285 self.track_delta && !self.delta.is_empty()
286 }
287
288 pub fn delta(&self) -> HashMap<String, Value> {
290 self.delta
291 .iter()
292 .map(|entry| (entry.key().clone(), entry.value().clone()))
293 .collect()
294 }
295
296 pub fn commit(&self) {
298 for entry in self.delta.iter() {
299 self.inner
300 .insert(entry.key().clone(), entry.value().clone());
301 }
302 self.delta.clear();
303 }
304
305 pub fn rollback(&self) {
307 self.delta.clear();
308 }
309
310 pub fn app(&self) -> PrefixedState<'_> {
314 PrefixedState {
315 state: self,
316 prefix: "app:",
317 }
318 }
319
320 pub fn user(&self) -> PrefixedState<'_> {
322 PrefixedState {
323 state: self,
324 prefix: "user:",
325 }
326 }
327
328 pub fn temp(&self) -> PrefixedState<'_> {
330 PrefixedState {
331 state: self,
332 prefix: "temp:",
333 }
334 }
335
336 pub fn session(&self) -> PrefixedState<'_> {
338 PrefixedState {
339 state: self,
340 prefix: "session:",
341 }
342 }
343
344 pub fn turn(&self) -> PrefixedState<'_> {
346 PrefixedState {
347 state: self,
348 prefix: "turn:",
349 }
350 }
351
352 pub fn bg(&self) -> PrefixedState<'_> {
354 PrefixedState {
355 state: self,
356 prefix: "bg:",
357 }
358 }
359
360 pub fn derived(&self) -> ReadOnlyPrefixedState<'_> {
362 ReadOnlyPrefixedState {
363 state: self,
364 prefix: "derived:",
365 }
366 }
367
368 pub fn snapshot_values(&self, keys: &[&str]) -> HashMap<String, Value> {
373 keys.iter()
374 .filter_map(|&k| self.get_raw(k).map(|v| (k.to_string(), v)))
375 .collect()
376 }
377
378 pub fn diff_values(
381 &self,
382 prev: &HashMap<String, Value>,
383 keys: &[&str],
384 ) -> Vec<(String, Value, Value)> {
385 keys.iter()
386 .filter_map(|&k| {
387 let old = prev.get(k);
388 let new = self.get_raw(k);
389 match (old, new) {
390 (Some(o), Some(n)) if o != &n => Some((k.to_string(), o.clone(), n)),
391 (None, Some(n)) => Some((k.to_string(), Value::Null, n)),
392 (Some(o), None) => Some((k.to_string(), o.clone(), Value::Null)),
393 _ => None,
394 }
395 })
396 .collect()
397 }
398
399 pub fn to_hashmap(&self) -> std::collections::HashMap<String, serde_json::Value> {
401 self.inner
402 .iter()
403 .map(|entry| (entry.key().clone(), entry.value().clone()))
404 .collect()
405 }
406
407 pub fn from_hashmap(&self, map: std::collections::HashMap<String, serde_json::Value>) {
409 for (key, value) in map {
410 self.inner.insert(key, value);
411 }
412 }
413
414 pub fn clear_prefix(&self, prefix: &str) {
416 let keys_to_remove: Vec<String> = self
417 .inner
418 .iter()
419 .filter(|entry| entry.key().starts_with(prefix))
420 .map(|entry| entry.key().clone())
421 .collect();
422 for key in keys_to_remove {
423 self.inner.remove(&key);
424 }
425 if self.track_delta {
426 let delta_keys: Vec<String> = self
427 .delta
428 .iter()
429 .filter(|entry| entry.key().starts_with(prefix))
430 .map(|entry| entry.key().clone())
431 .collect();
432 for key in delta_keys {
433 self.delta.remove(&key);
434 }
435 }
436 }
437}
438
439pub struct PrefixedState<'a> {
441 state: &'a State,
442 prefix: &'static str,
443}
444
445impl<'a> PrefixedState<'a> {
446 fn prefixed_key(&self, key: &str) -> String {
447 format!("{}{}", self.prefix, key)
448 }
449
450 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
452 self.state.get(&self.prefixed_key(key))
453 }
454
455 pub fn get_raw(&self, key: &str) -> Option<Value> {
457 self.state.get_raw(&self.prefixed_key(key))
458 }
459
460 pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
462 where
463 F: FnOnce(&Value) -> R,
464 {
465 self.state.with(&self.prefixed_key(key), f)
466 }
467
468 pub fn set(&self, key: impl AsRef<str>, value: impl serde::Serialize) {
470 self.state.set(self.prefixed_key(key.as_ref()), value);
471 }
472
473 pub fn contains(&self, key: &str) -> bool {
475 self.state.contains(&self.prefixed_key(key))
476 }
477
478 pub fn remove(&self, key: &str) -> Option<Value> {
480 self.state.remove(&self.prefixed_key(key))
481 }
482
483 pub fn keys(&self) -> Vec<String> {
485 self.state
486 .keys()
487 .into_iter()
488 .filter_map(|k| k.strip_prefix(self.prefix).map(|s| s.to_string()))
489 .collect()
490 }
491}
492
493pub struct ReadOnlyPrefixedState<'a> {
498 state: &'a State,
499 prefix: &'static str,
500}
501
502impl<'a> ReadOnlyPrefixedState<'a> {
503 fn prefixed_key(&self, key: &str) -> String {
504 format!("{}{}", self.prefix, key)
505 }
506
507 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
509 self.state.get(&self.prefixed_key(key))
510 }
511
512 pub fn get_raw(&self, key: &str) -> Option<Value> {
514 self.state.get_raw(&self.prefixed_key(key))
515 }
516
517 pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
519 where
520 F: FnOnce(&Value) -> R,
521 {
522 self.state.with(&self.prefixed_key(key), f)
523 }
524
525 pub fn contains(&self, key: &str) -> bool {
527 self.state.contains(&self.prefixed_key(key))
528 }
529
530 pub fn keys(&self) -> Vec<String> {
532 self.state
533 .keys()
534 .into_iter()
535 .filter_map(|k| k.strip_prefix(self.prefix).map(|s| s.to_string()))
536 .collect()
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Basic get / set / remove ----

    #[test]
    fn set_and_get_string() {
        let state = State::new();
        state.set("name", "Alice");
        assert_eq!(state.get::<String>("name"), Some("Alice".to_string()));
    }

    #[test]
    fn set_and_get_json() {
        let state = State::new();
        state.set("data", serde_json::json!({"temp": 22}));
        let v: Value = state.get("data").unwrap();
        assert_eq!(v["temp"], 22);
    }

    #[test]
    fn get_returns_none_on_type_mismatch() {
        let state = State::new();
        state.set("name", "Alice");
        assert_eq!(state.get::<u32>("name"), None);
    }

    #[test]
    fn pick_subset() {
        let state = State::new();
        state.set("a", 1);
        state.set("b", 2);
        state.set("c", 3);
        let picked = state.pick(&["a", "c"]);
        assert!(picked.contains("a"));
        assert!(!picked.contains("b"));
        assert!(picked.contains("c"));
        assert_eq!(picked.get::<i32>("a"), Some(1));
        assert_eq!(picked.get::<i32>("c"), Some(3));
    }

    #[test]
    fn merge_states() {
        let s1 = State::new();
        s1.set("a", 1);
        let s2 = State::new();
        s2.set("b", 2);
        s1.merge(&s2);
        assert!(s1.contains("a"));
        assert!(s1.contains("b"));
    }

    #[test]
    fn merge_overwrites_existing_keys() {
        let s1 = State::new();
        s1.set("a", 1);
        let s2 = State::new();
        s2.set("a", 2);
        s1.merge(&s2);
        assert_eq!(s1.get::<i32>("a"), Some(2));
    }

    #[test]
    fn rename_key() {
        let state = State::new();
        state.set("old", "value");
        state.rename("old", "new");
        assert!(!state.contains("old"));
        assert_eq!(state.get::<String>("new"), Some("value".to_string()));
    }

    #[test]
    fn rename_missing_key_is_noop() {
        let state = State::new();
        state.rename("absent", "new");
        assert!(!state.contains("new"));
        assert!(state.keys().is_empty());
    }

    #[test]
    fn remove_returns_value() {
        let state = State::new();
        state.set("key", 42);
        let removed = state.remove("key");
        assert_eq!(removed, Some(serde_json::json!(42)));
        assert!(!state.contains("key"));
    }

    #[test]
    fn get_missing_returns_none() {
        let state = State::new();
        assert_eq!(state.get::<String>("nope"), None);
    }

    #[test]
    fn contains_does_not_fall_back_to_derived() {
        let state = State::new();
        state.set("derived:risk", 0.5);
        // `get` falls back to `derived:`, `contains` checks the exact key only.
        assert_eq!(state.get::<f64>("risk"), Some(0.5));
        assert!(!state.contains("risk"));
    }

    // ---- Delta tracking ----

    #[test]
    fn delta_tracking_writes_to_delta() {
        let state = State::new();
        state.set("committed", "yes");

        let tracked = state.with_delta_tracking();
        tracked.set("new_key", "new_value");

        // The tracked handle sees its own pending write...
        assert_eq!(
            tracked.get::<String>("new_key"),
            Some("new_value".to_string())
        );
        // ...but the base handle does not until commit.
        assert!(!state.contains("new_key"));
        // Committed values remain readable through the tracked handle.
        assert_eq!(tracked.get::<String>("committed"), Some("yes".to_string()));
    }

    #[test]
    fn delta_has_delta_reports_correctly() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        assert!(!tracked.has_delta());

        tracked.set("key", "val");
        assert!(tracked.has_delta());
    }

    #[test]
    fn delta_commit_merges_to_inner() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set("key", "val");
        assert!(!state.contains("key"));

        tracked.commit();
        assert_eq!(state.get::<String>("key"), Some("val".to_string()));
        assert!(!tracked.has_delta());
    }

    #[test]
    fn delta_rollback_discards_changes() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set("key", "val");
        assert!(tracked.has_delta());

        tracked.rollback();
        assert!(!tracked.has_delta());
        assert!(!state.contains("key"));
        assert!(!tracked.contains("key"));
    }

    #[test]
    fn delta_rollback_keeps_committed_values() {
        let state = State::new();
        state.set("x", 1);
        let tracked = state.with_delta_tracking();
        tracked.set("x", 2);
        assert_eq!(tracked.get::<i32>("x"), Some(2));

        tracked.rollback();
        assert_eq!(tracked.get::<i32>("x"), Some(1));
        assert_eq!(state.get::<i32>("x"), Some(1));
    }

    #[test]
    fn delta_keys_union_without_duplicates() {
        let state = State::new();
        state.set("a", 1);
        let tracked = state.with_delta_tracking();
        tracked.set("a", 2);
        tracked.set("b", 3);

        let mut keys = tracked.keys();
        keys.sort();
        assert_eq!(keys, vec!["a", "b"]);
    }

    #[test]
    fn delta_snapshot() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set("a", 1);
        tracked.set("b", 2);

        let snapshot = tracked.delta();
        assert_eq!(snapshot.len(), 2);
        assert!(snapshot.contains_key("a"));
        assert!(snapshot.contains_key("b"));
    }

    #[test]
    fn set_committed_bypasses_delta() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set_committed("direct", "value");

        // Visible on the base handle immediately...
        assert_eq!(state.get::<String>("direct"), Some("value".to_string()));
        // ...without creating a pending delta entry...
        assert!(!tracked.has_delta());
        // ...and still readable through the tracked handle.
        assert_eq!(tracked.get::<String>("direct"), Some("value".to_string()));
    }

    #[test]
    fn no_delta_tracking_preserves_existing_behavior() {
        let state = State::new();
        assert!(!state.is_tracking_delta());
        state.set("key", "val");
        assert_eq!(state.get::<String>("key"), Some("val".to_string()));
        assert!(!state.has_delta());
    }

    // ---- Prefixed views ----

    #[test]
    fn prefix_app_set_and_get() {
        let state = State::new();
        state.app().set("flag", true);

        assert_eq!(state.app().get::<bool>("flag"), Some(true));
        // The view stores under the fully prefixed key.
        assert_eq!(state.get::<bool>("app:flag"), Some(true));
    }

    #[test]
    fn prefix_user_set_and_get() {
        let state = State::new();
        state.user().set("name", "Alice");
        assert_eq!(
            state.user().get::<String>("name"),
            Some("Alice".to_string())
        );
        assert_eq!(state.get::<String>("user:name"), Some("Alice".to_string()));
    }

    #[test]
    fn prefix_temp_set_and_get() {
        let state = State::new();
        state.temp().set("scratch", 42);
        assert_eq!(state.temp().get::<i32>("scratch"), Some(42));
    }

    #[test]
    fn prefix_contains_and_remove() {
        let state = State::new();
        state.app().set("x", 1);
        assert!(state.app().contains("x"));
        assert_eq!(state.app().remove("x"), Some(serde_json::json!(1)));
        assert!(!state.app().contains("x"));
    }

    #[test]
    fn prefix_keys() {
        let state = State::new();
        state.app().set("a", 1);
        state.app().set("b", 2);
        state.user().set("c", 3);

        let app_keys = state.app().keys();
        assert_eq!(app_keys.len(), 2);
        assert!(app_keys.contains(&"a".to_string()));
        assert!(app_keys.contains(&"b".to_string()));

        let user_keys = state.user().keys();
        assert_eq!(user_keys.len(), 1);
        assert!(user_keys.contains(&"c".to_string()));
    }

    #[test]
    fn prefix_with_delta_tracking() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.app().set("flag", true);

        assert_eq!(tracked.app().get::<bool>("flag"), Some(true));
        // Prefixed writes go through the tracked handle's delta.
        assert!(tracked.has_delta());
        assert!(!state.contains("app:flag"));

        tracked.commit();
        assert_eq!(state.get::<bool>("app:flag"), Some(true));
    }

    #[test]
    fn prefix_session_set_and_get() {
        let state = State::new();
        state.session().set("turn_count", 5);
        assert_eq!(state.session().get::<i32>("turn_count"), Some(5));
        assert_eq!(state.get::<i32>("session:turn_count"), Some(5));
    }

    #[test]
    fn prefix_turn_set_and_get() {
        let state = State::new();
        state.turn().set("transcript", "hello");
        assert_eq!(
            state.turn().get::<String>("transcript"),
            Some("hello".to_string())
        );
        assert_eq!(
            state.get::<String>("turn:transcript"),
            Some("hello".to_string())
        );
    }

    #[test]
    fn prefix_bg_set_and_get() {
        let state = State::new();
        state.bg().set("task_id", "abc-123");
        assert_eq!(
            state.bg().get::<String>("task_id"),
            Some("abc-123".to_string())
        );
        assert_eq!(
            state.get::<String>("bg:task_id"),
            Some("abc-123".to_string())
        );
    }

    #[test]
    fn prefix_session_contains_and_remove() {
        let state = State::new();
        state.session().set("x", 1);
        assert!(state.session().contains("x"));
        state.session().remove("x");
        assert!(!state.session().contains("x"));
    }

    #[test]
    fn prefix_turn_keys() {
        let state = State::new();
        state.turn().set("a", 1);
        state.turn().set("b", 2);
        state.session().set("c", 3);

        let turn_keys = state.turn().keys();
        assert_eq!(turn_keys.len(), 2);
        assert!(turn_keys.contains(&"a".to_string()));
        assert!(turn_keys.contains(&"b".to_string()));
    }

    // ---- Read-only derived view ----

    #[test]
    fn derived_read_only_get() {
        let state = State::new();
        state.set("derived:sentiment", "positive");
        assert_eq!(
            state.derived().get::<String>("sentiment"),
            Some("positive".to_string())
        );
    }

    #[test]
    fn derived_read_only_get_raw() {
        let state = State::new();
        state.set("derived:score", serde_json::json!(0.95));
        let raw = state.derived().get_raw("score");
        assert!(raw.is_some());
        assert_eq!(raw.unwrap(), serde_json::json!(0.95));
    }

    #[test]
    fn derived_read_only_contains() {
        let state = State::new();
        state.set("derived:exists", true);
        assert!(state.derived().contains("exists"));
        assert!(!state.derived().contains("missing"));
    }

    #[test]
    fn derived_read_only_keys() {
        let state = State::new();
        state.set("derived:a", 1);
        state.set("derived:b", 2);
        state.set("app:c", 3);

        let derived_keys = state.derived().keys();
        assert_eq!(derived_keys.len(), 2);
        assert!(derived_keys.contains(&"a".to_string()));
        assert!(derived_keys.contains(&"b".to_string()));
    }

    #[test]
    fn derived_missing_key_returns_none() {
        let state = State::new();
        assert_eq!(state.derived().get::<String>("nope"), None);
        assert_eq!(state.derived().get_raw("nope"), None);
    }

    // ---- Snapshots and diffs ----

    #[test]
    fn snapshot_values_captures_existing_keys() {
        let state = State::new();
        state.set("a", 1);
        state.set("b", "hello");
        state.set("c", true);

        let snap = state.snapshot_values(&["a", "b", "missing"]);
        assert_eq!(snap.len(), 2);
        assert_eq!(snap.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(snap.get("b"), Some(&serde_json::json!("hello")));
        assert!(!snap.contains_key("missing"));
    }

    #[test]
    fn snapshot_values_empty_keys() {
        let state = State::new();
        state.set("a", 1);
        let snap = state.snapshot_values(&[]);
        assert!(snap.is_empty());
    }

    #[test]
    fn diff_values_detects_changed_value() {
        let state = State::new();
        state.set("x", 1);
        let snap = state.snapshot_values(&["x"]);

        state.set("x", 2);
        let diffs = state.diff_values(&snap, &["x"]);
        assert_eq!(diffs.len(), 1);
        assert_eq!(diffs[0].0, "x");
        assert_eq!(diffs[0].1, serde_json::json!(1));
        assert_eq!(diffs[0].2, serde_json::json!(2));
    }

    #[test]
    fn diff_values_detects_new_key() {
        let state = State::new();
        let snap = state.snapshot_values(&["y"]);

        state.set("y", "new");
        let diffs = state.diff_values(&snap, &["y"]);
        assert_eq!(diffs.len(), 1);
        assert_eq!(diffs[0].0, "y");
        assert_eq!(diffs[0].1, Value::Null);
        assert_eq!(diffs[0].2, serde_json::json!("new"));
    }

    #[test]
    fn diff_values_detects_removed_key() {
        let state = State::new();
        state.set("z", 42);
        let snap = state.snapshot_values(&["z"]);

        state.remove("z");
        let diffs = state.diff_values(&snap, &["z"]);
        assert_eq!(diffs.len(), 1);
        assert_eq!(diffs[0].0, "z");
        assert_eq!(diffs[0].1, serde_json::json!(42));
        assert_eq!(diffs[0].2, Value::Null);
    }

    #[test]
    fn diff_values_no_change() {
        let state = State::new();
        state.set("stable", 10);
        let snap = state.snapshot_values(&["stable"]);

        let diffs = state.diff_values(&snap, &["stable"]);
        assert!(diffs.is_empty());
    }

    #[test]
    fn diff_values_multiple_keys_mixed_changes() {
        let state = State::new();
        state.set("a", 1);
        state.set("b", 2);
        let snap = state.snapshot_values(&["a", "b", "c"]);

        state.set("a", 10); // changed
        state.set("c", 3); // added; "b" untouched
        let diffs = state.diff_values(&snap, &["a", "b", "c"]);

        assert_eq!(diffs.len(), 2);
        let diff_keys: Vec<&str> = diffs.iter().map(|(k, _, _)| k.as_str()).collect();
        assert!(diff_keys.contains(&"a"));
        assert!(diff_keys.contains(&"c"));
    }

    // ---- clear_prefix ----

    #[test]
    fn clear_prefix_removes_matching_keys() {
        let state = State::new();
        state.set("turn:a", 1);
        state.set("turn:b", 2);
        state.set("app:c", 3);
        state.set("session:d", 4);

        state.clear_prefix("turn:");
        assert!(!state.contains("turn:a"));
        assert!(!state.contains("turn:b"));
        assert!(state.contains("app:c"));
        assert!(state.contains("session:d"));
    }

    #[test]
    fn clear_prefix_no_matching_keys_is_noop() {
        let state = State::new();
        state.set("app:x", 1);
        state.clear_prefix("turn:");
        assert!(state.contains("app:x"));
    }

    #[test]
    fn clear_prefix_also_clears_delta() {
        let state = State::new();
        state.set("turn:committed", 1);
        let tracked = state.with_delta_tracking();
        tracked.set("turn:delta_val", 2);

        assert!(tracked.contains("turn:committed"));
        assert!(tracked.contains("turn:delta_val"));

        tracked.clear_prefix("turn:");
        assert!(!tracked.contains("turn:committed"));
        assert!(!tracked.contains("turn:delta_val"));
    }

    #[test]
    fn clear_prefix_via_turn_accessor() {
        let state = State::new();
        state.turn().set("x", 1);
        state.turn().set("y", 2);
        state.app().set("z", 3);

        state.clear_prefix("turn:");
        assert!(state.turn().keys().is_empty());
        assert!(state.app().contains("z"));
    }

    // ---- modify ----

    #[test]
    fn modify_increment_existing() {
        let state = State::new();
        state.set("count", 5u32);
        let result = state.modify("count", 0u32, |n| n + 1);
        assert_eq!(result, 6);
        assert_eq!(state.get::<u32>("count"), Some(6));
    }

    #[test]
    fn modify_uses_default_when_missing() {
        let state = State::new();
        let result = state.modify("new_count", 0u32, |n| n + 1);
        assert_eq!(result, 1);
        assert_eq!(state.get::<u32>("new_count"), Some(1));
    }

    #[test]
    fn modify_with_delta_tracking() {
        let state = State::new();
        state.set("x", 10u32);
        let tracked = state.with_delta_tracking();
        let result = tracked.modify("x", 0u32, |n| n * 2);
        assert_eq!(result, 20);
        assert_eq!(tracked.get::<u32>("x"), Some(20));
        // The committed value is untouched until commit.
        assert_eq!(state.get::<u32>("x"), Some(10));
    }

    // ---- derived: fallback on get ----

    #[test]
    fn get_falls_back_to_derived_prefix() {
        let state = State::new();
        state.set("derived:risk", 0.85);
        assert_eq!(state.get::<f64>("risk"), Some(0.85));
    }

    #[test]
    fn get_prefers_direct_key_over_derived() {
        let state = State::new();
        state.set("score", 1.0);
        state.set("derived:score", 0.5);
        assert_eq!(state.get::<f64>("score"), Some(1.0));
    }

    #[test]
    fn get_derived_fallback_skipped_for_prefixed_keys() {
        let state = State::new();
        state.set("derived:risk", 0.85);
        assert_eq!(state.get::<f64>("app:risk"), None);
    }

    #[test]
    fn get_derived_fallback_with_delta_tracking() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set("derived:computed_val", 42);
        assert_eq!(tracked.get::<i32>("computed_val"), Some(42));
    }

    // ---- with (borrowing reads) ----

    #[test]
    fn with_reads_from_inner() {
        let state = State::new();
        state.set("name", "Alice");
        let len = state.with("name", |v| v.as_str().unwrap().len());
        assert_eq!(len, Some(5));
    }

    #[test]
    fn with_reads_from_delta_first() {
        let state = State::new();
        state.set("x", 1);
        let tracked = state.with_delta_tracking();
        tracked.set("x", 99);
        let val = tracked.with("x", |v| v.as_i64().unwrap());
        assert_eq!(val, Some(99));
    }

    #[test]
    fn with_falls_back_to_inner_when_not_in_delta() {
        let state = State::new();
        state.set("committed", "yes");
        let tracked = state.with_delta_tracking();
        let val = tracked.with("committed", |v| v.as_str().unwrap().to_string());
        assert_eq!(val, Some("yes".to_string()));
    }

    #[test]
    fn with_falls_back_to_derived() {
        let state = State::new();
        state.set("derived:risk", 0.85);
        let val = state.with("risk", |v| v.as_f64().unwrap());
        assert_eq!(val, Some(0.85));
    }

    #[test]
    fn with_derived_fallback_skipped_for_prefixed() {
        let state = State::new();
        state.set("derived:risk", 0.85);
        let val = state.with("app:risk", |v| v.as_f64().unwrap());
        assert_eq!(val, None);
    }

    #[test]
    fn with_returns_none_for_missing() {
        let state = State::new();
        let val = state.with("missing", |v| v.clone());
        assert_eq!(val, None);
    }

    #[test]
    fn with_on_prefixed_state() {
        let state = State::new();
        state.app().set("flag", true);
        let val = state.app().with("flag", |v| v.as_bool().unwrap());
        assert_eq!(val, Some(true));
    }

    #[test]
    fn with_on_read_only_prefixed_state() {
        let state = State::new();
        state.set("derived:score", serde_json::json!(0.95));
        let val = state.derived().with("score", |v| v.as_f64().unwrap());
        assert_eq!(val, Some(0.95));
    }

    // ---- Typed StateKey ----

    const TURN_COUNT: StateKey<u32> = StateKey::new("session:turn_count");
    const NAME: StateKey<String> = StateKey::new("user:name");

    #[test]
    fn state_key_get_and_set() {
        let state = State::new();
        state.set_key(&TURN_COUNT, 5);
        assert_eq!(state.get_key(&TURN_COUNT), Some(5));
    }

    #[test]
    fn state_key_get_missing() {
        let state = State::new();
        assert_eq!(state.get_key(&TURN_COUNT), None);
    }

    #[test]
    fn state_key_string_type() {
        let state = State::new();
        state.set_key(&NAME, "Alice".to_string());
        assert_eq!(state.get_key(&NAME), Some("Alice".to_string()));
    }

    #[test]
    fn state_key_with() {
        let state = State::new();
        state.set_key(&TURN_COUNT, 42);
        let val = state.with_key(&TURN_COUNT, |v| v.as_u64().unwrap());
        assert_eq!(val, Some(42));
    }

    #[test]
    fn state_key_interop_with_raw() {
        let state = State::new();
        state.set_key(&TURN_COUNT, 10);
        // A typed key is just its string key underneath.
        assert_eq!(state.get::<u32>("session:turn_count"), Some(10));
    }
}