Skip to main content

rs_adk/
state.rs

1//! Typed key-value state container for agents.
2//!
3//! Supports optional delta tracking for transactional state management
4//! and prefix-scoped accessors for namespace isolation.
5
6use std::collections::HashMap;
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use dashmap::DashMap;
11use serde_json::Value;
12
/// A compile-time typed state key that eliminates typo bugs and type mismatches.
///
/// The key string is fixed at compile time. The value type `T` lives only in the
/// type system (no `T` is ever stored), so a `StateKey` is just a `&'static str`
/// with a type tag. It is `Copy` for every `T`.
///
/// Create as a const and use with `State::get_key()` / `State::set_key()`:
///
/// ```rust,ignore
/// const TURN_COUNT: StateKey<u32> = StateKey::new("session:turn_count");
/// const SENTIMENT: StateKey<String> = StateKey::new("derived:sentiment");
///
/// state.set_key(&TURN_COUNT, 5);
/// let count: Option<u32> = state.get_key(&TURN_COUNT);
/// ```
pub struct StateKey<T> {
    key: &'static str,
    // `fn() -> T` rather than `T`: no `T` is owned, so the key stays covariant in
    // `T` and `Send + Sync` regardless of what `T` is.
    _phantom: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would require `T: Clone`/`T: Copy`
// even though only a `&'static str` is stored.
impl<T> Clone for StateKey<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for StateKey<T> {}

// Implemented by hand so that `T` does not need to be `Debug`.
impl<T> std::fmt::Debug for StateKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateKey").field("key", &self.key).finish()
    }
}

impl<T> StateKey<T> {
    /// Create a new typed state key. Usable in `const` context.
    pub const fn new(key: &'static str) -> Self {
        Self {
            key,
            _phantom: PhantomData,
        }
    }

    /// The underlying string key (including any `prefix:` scope).
    pub const fn key(&self) -> &'static str {
        self.key
    }
}
43
44/// A concurrent, type-safe state container that agents read from and write to.
45///
46/// By default, `set()` writes directly to the inner store. When delta tracking
47/// is enabled via `with_delta_tracking()`, writes go to a separate delta map
48/// that can be committed or rolled back.
49#[derive(Debug, Clone)]
50pub struct State {
51    inner: Arc<DashMap<String, Value>>,
52    delta: Arc<DashMap<String, Value>>,
53    track_delta: bool,
54}
55
56impl Default for State {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl State {
63    /// Create a new empty state container.
64    pub fn new() -> Self {
65        Self {
66            inner: Arc::new(DashMap::new()),
67            delta: Arc::new(DashMap::new()),
68            track_delta: false,
69        }
70    }
71
72    /// Create a new State with delta tracking enabled.
73    /// Writes go to the delta map; reads check delta first, then inner.
74    pub fn with_delta_tracking(&self) -> State {
75        State {
76            inner: self.inner.clone(),
77            delta: Arc::new(DashMap::new()),
78            track_delta: true,
79        }
80    }
81
82    /// Get a value by key, attempting to deserialize to the requested type.
83    /// When delta tracking is enabled, checks delta first, then inner.
84    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
85        self.get_raw(key)
86            .and_then(|v| serde_json::from_value(v).ok())
87    }
88
89    /// Borrow a value by key without cloning, applying `f` to the reference.
90    ///
91    /// This is the zero-copy alternative to `get_raw()`. The closure receives
92    /// a `&Value` directly from the DashMap ref-guard, avoiding the
93    /// `Value::clone()` + `serde_json::from_value()` overhead of `get()`.
94    ///
95    /// Lookup order: delta (if tracking) → inner → derived fallback.
96    pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
97    where
98        F: FnOnce(&Value) -> R,
99    {
100        if self.track_delta {
101            if let Some(ref_multi) = self.delta.get(key) {
102                return Some(f(ref_multi.value()));
103            }
104        }
105        if let Some(ref_multi) = self.inner.get(key) {
106            return Some(f(ref_multi.value()));
107        }
108        if !key.contains(':') {
109            let mut derived_key = String::with_capacity(8 + key.len());
110            use std::fmt::Write;
111            let _ = write!(derived_key, "derived:{}", key);
112            if self.track_delta {
113                if let Some(ref_multi) = self.delta.get(&derived_key) {
114                    return Some(f(ref_multi.value()));
115                }
116            }
117            if let Some(ref_multi) = self.inner.get(&derived_key) {
118                return Some(f(ref_multi.value()));
119            }
120        }
121        None
122    }
123
124    /// Get a raw JSON value by key.
125    /// When delta tracking is enabled, checks delta first, then inner.
126    /// If the key is not found and doesn't contain a prefix, also checks `derived:{key}`
127    /// as a transparent fallback for computed variables.
128    pub fn get_raw(&self, key: &str) -> Option<Value> {
129        if self.track_delta {
130            if let Some(v) = self.delta.get(key) {
131                return Some(v.value().clone());
132            }
133        }
134        if let Some(v) = self.inner.get(key) {
135            return Some(v.value().clone());
136        }
137        // Transparent derived fallback: if key has no prefix, check derived:{key}
138        if !key.contains(':') {
139            use std::fmt::Write;
140            let mut derived_key = String::with_capacity(8 + key.len());
141            let _ = write!(derived_key, "derived:{}", key);
142            if self.track_delta {
143                if let Some(v) = self.delta.get(&derived_key) {
144                    return Some(v.value().clone());
145                }
146            }
147            return self.inner.get(&derived_key).map(|v| v.value().clone());
148        }
149        None
150    }
151
152    /// Get a typed value using a `StateKey<T>`.
153    pub fn get_key<T: serde::de::DeserializeOwned>(&self, key: &StateKey<T>) -> Option<T> {
154        self.get(key.key())
155    }
156
157    /// Set a typed value using a `StateKey<T>`.
158    pub fn set_key<T: serde::Serialize>(&self, key: &StateKey<T>, value: T) {
159        self.set(key.key(), value);
160    }
161
162    /// Zero-copy borrow using a `StateKey<T>`.
163    pub fn with_key<T, F, R>(&self, key: &StateKey<T>, f: F) -> Option<R>
164    where
165        F: FnOnce(&Value) -> R,
166    {
167        self.with(key.key(), f)
168    }
169
170    /// Set a value by key.
171    /// When delta tracking is enabled, writes to delta instead of inner.
172    pub fn set(&self, key: impl Into<String>, value: impl serde::Serialize) {
173        let v = serde_json::to_value(value).expect("value must be serializable");
174        if self.track_delta {
175            self.delta.insert(key.into(), v);
176        } else {
177            self.inner.insert(key.into(), v);
178        }
179    }
180
181    /// Set a value directly in the committed store, bypassing delta tracking.
182    pub fn set_committed(&self, key: impl Into<String>, value: impl serde::Serialize) {
183        let v = serde_json::to_value(value).expect("value must be serializable");
184        self.inner.insert(key.into(), v);
185    }
186
187    /// Atomically read-modify-write a value.
188    ///
189    /// If the key doesn't exist, `default` is used as the initial value.
190    /// The function `f` receives the current value and returns the new value.
191    /// Returns the new value after modification.
192    pub fn modify<T, F>(&self, key: &str, default: T, f: F) -> T
193    where
194        T: serde::Serialize + serde::de::DeserializeOwned,
195        F: FnOnce(T) -> T,
196    {
197        // Read current value from whichever store has it
198        let current: T = self.get(key).unwrap_or(default);
199        let new_val = f(current);
200        self.set(key, &new_val);
201        new_val
202    }
203
204    /// Check if a key exists (in delta or inner).
205    pub fn contains(&self, key: &str) -> bool {
206        if self.track_delta && self.delta.contains_key(key) {
207            return true;
208        }
209        self.inner.contains_key(key)
210    }
211
212    /// Remove a key.
213    pub fn remove(&self, key: &str) -> Option<Value> {
214        if self.track_delta {
215            // Remove from delta if present, but also check inner
216            let from_delta = self.delta.remove(key).map(|(_, v)| v);
217            let from_inner = self.inner.remove(key).map(|(_, v)| v);
218            from_delta.or(from_inner)
219        } else {
220            self.inner.remove(key).map(|(_, v)| v)
221        }
222    }
223
224    /// Get all keys (from both inner and delta when tracking).
225    pub fn keys(&self) -> Vec<String> {
226        if !self.track_delta || self.delta.is_empty() {
227            return self.inner.iter().map(|r| r.key().clone()).collect();
228        }
229        let mut seen =
230            std::collections::HashSet::with_capacity(self.inner.len() + self.delta.len());
231        let mut keys = Vec::with_capacity(self.inner.len() + self.delta.len());
232        for entry in self.inner.iter() {
233            let key = entry.key().clone();
234            seen.insert(key.clone());
235            keys.push(key);
236        }
237        for entry in self.delta.iter() {
238            let key = entry.key().clone();
239            if seen.insert(key.clone()) {
240                keys.push(key);
241            }
242        }
243        keys
244    }
245
246    /// Create a new State containing only the specified keys.
247    pub fn pick(&self, keys: &[&str]) -> State {
248        let new = State::new();
249        for key in keys {
250            if let Some(v) = self.get_raw(key) {
251                new.set(*key, v);
252            }
253        }
254        new
255    }
256
257    /// Merge another state into this one (other's values overwrite on conflict).
258    pub fn merge(&self, other: &State) {
259        for entry in other.inner.iter() {
260            self.inner
261                .insert(entry.key().clone(), entry.value().clone());
262        }
263    }
264
265    /// Rename a key.
266    pub fn rename(&self, from: &str, to: &str) {
267        if let Some(v) = self.remove(from) {
268            if self.track_delta {
269                self.delta.insert(to.to_string(), v);
270            } else {
271                self.inner.insert(to.to_string(), v);
272            }
273        }
274    }
275
276    // ── Delta methods ──────────────────────────────────────────────────────
277
278    /// Whether delta tracking is enabled.
279    pub fn is_tracking_delta(&self) -> bool {
280        self.track_delta
281    }
282
283    /// Whether there are uncommitted delta changes.
284    pub fn has_delta(&self) -> bool {
285        self.track_delta && !self.delta.is_empty()
286    }
287
288    /// Get a snapshot of the current delta.
289    pub fn delta(&self) -> HashMap<String, Value> {
290        self.delta
291            .iter()
292            .map(|entry| (entry.key().clone(), entry.value().clone()))
293            .collect()
294    }
295
296    /// Commit delta changes into the inner store, then clear the delta.
297    pub fn commit(&self) {
298        for entry in self.delta.iter() {
299            self.inner
300                .insert(entry.key().clone(), entry.value().clone());
301        }
302        self.delta.clear();
303    }
304
305    /// Discard all uncommitted delta changes.
306    pub fn rollback(&self) {
307        self.delta.clear();
308    }
309
310    // ── Prefix accessors ───────────────────────────────────────────────────
311
312    /// Access state with the `app:` prefix scope.
313    pub fn app(&self) -> PrefixedState<'_> {
314        PrefixedState {
315            state: self,
316            prefix: "app:",
317        }
318    }
319
320    /// Access state with the `user:` prefix scope.
321    pub fn user(&self) -> PrefixedState<'_> {
322        PrefixedState {
323            state: self,
324            prefix: "user:",
325        }
326    }
327
328    /// Access state with the `temp:` prefix scope.
329    pub fn temp(&self) -> PrefixedState<'_> {
330        PrefixedState {
331            state: self,
332            prefix: "temp:",
333        }
334    }
335
336    /// Access state with the `session:` prefix scope (auto-tracked signals).
337    pub fn session(&self) -> PrefixedState<'_> {
338        PrefixedState {
339            state: self,
340            prefix: "session:",
341        }
342    }
343
344    /// Access state with the `turn:` prefix scope (reset each turn).
345    pub fn turn(&self) -> PrefixedState<'_> {
346        PrefixedState {
347            state: self,
348            prefix: "turn:",
349        }
350    }
351
352    /// Access state with the `bg:` prefix scope (background tasks).
353    pub fn bg(&self) -> PrefixedState<'_> {
354        PrefixedState {
355            state: self,
356            prefix: "bg:",
357        }
358    }
359
360    /// Access read-only state with the `derived:` prefix scope (computed vars only).
361    pub fn derived(&self) -> ReadOnlyPrefixedState<'_> {
362        ReadOnlyPrefixedState {
363            state: self,
364            prefix: "derived:",
365        }
366    }
367
368    // ── Utility methods ───────────────────────────────────────────────────
369
370    /// Snapshot the values of specific keys. Returns HashMap of key -> current value.
371    /// Used by watchers to capture state before mutations.
372    pub fn snapshot_values(&self, keys: &[&str]) -> HashMap<String, Value> {
373        keys.iter()
374            .filter_map(|&k| self.get_raw(k).map(|v| (k.to_string(), v)))
375            .collect()
376    }
377
378    /// Diff current state against a previous snapshot.
379    /// Returns Vec of (key, old_value, new_value) for keys that changed.
380    pub fn diff_values(
381        &self,
382        prev: &HashMap<String, Value>,
383        keys: &[&str],
384    ) -> Vec<(String, Value, Value)> {
385        keys.iter()
386            .filter_map(|&k| {
387                let old = prev.get(k);
388                let new = self.get_raw(k);
389                match (old, new) {
390                    (Some(o), Some(n)) if o != &n => Some((k.to_string(), o.clone(), n)),
391                    (None, Some(n)) => Some((k.to_string(), Value::Null, n)),
392                    (Some(o), None) => Some((k.to_string(), o.clone(), Value::Null)),
393                    _ => None,
394                }
395            })
396            .collect()
397    }
398
399    /// Export all state as a HashMap (for persistence/serialization).
400    pub fn to_hashmap(&self) -> std::collections::HashMap<String, serde_json::Value> {
401        self.inner
402            .iter()
403            .map(|entry| (entry.key().clone(), entry.value().clone()))
404            .collect()
405    }
406
407    /// Restore state from a HashMap (for persistence/deserialization).
408    pub fn from_hashmap(&self, map: std::collections::HashMap<String, serde_json::Value>) {
409        for (key, value) in map {
410            self.inner.insert(key, value);
411        }
412    }
413
414    /// Remove all keys with the given prefix.
415    pub fn clear_prefix(&self, prefix: &str) {
416        let keys_to_remove: Vec<String> = self
417            .inner
418            .iter()
419            .filter(|entry| entry.key().starts_with(prefix))
420            .map(|entry| entry.key().clone())
421            .collect();
422        for key in keys_to_remove {
423            self.inner.remove(&key);
424        }
425        if self.track_delta {
426            let delta_keys: Vec<String> = self
427                .delta
428                .iter()
429                .filter(|entry| entry.key().starts_with(prefix))
430                .map(|entry| entry.key().clone())
431                .collect();
432            for key in delta_keys {
433                self.delta.remove(&key);
434            }
435        }
436    }
437}
438
439/// A borrowed view of state that automatically prepends a prefix to all keys.
440pub struct PrefixedState<'a> {
441    state: &'a State,
442    prefix: &'static str,
443}
444
445impl<'a> PrefixedState<'a> {
446    fn prefixed_key(&self, key: &str) -> String {
447        format!("{}{}", self.prefix, key)
448    }
449
450    /// Get a value by key (with prefix applied).
451    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
452        self.state.get(&self.prefixed_key(key))
453    }
454
455    /// Get a raw JSON value by key (with prefix applied).
456    pub fn get_raw(&self, key: &str) -> Option<Value> {
457        self.state.get_raw(&self.prefixed_key(key))
458    }
459
460    /// Zero-copy borrow a value by key (with prefix applied).
461    pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
462    where
463        F: FnOnce(&Value) -> R,
464    {
465        self.state.with(&self.prefixed_key(key), f)
466    }
467
468    /// Set a value by key (with prefix applied).
469    pub fn set(&self, key: impl AsRef<str>, value: impl serde::Serialize) {
470        self.state.set(self.prefixed_key(key.as_ref()), value);
471    }
472
473    /// Check if a key exists (with prefix applied).
474    pub fn contains(&self, key: &str) -> bool {
475        self.state.contains(&self.prefixed_key(key))
476    }
477
478    /// Remove a key (with prefix applied).
479    pub fn remove(&self, key: &str) -> Option<Value> {
480        self.state.remove(&self.prefixed_key(key))
481    }
482
483    /// Get all keys within this prefix scope (prefix stripped from results).
484    pub fn keys(&self) -> Vec<String> {
485        self.state
486            .keys()
487            .into_iter()
488            .filter_map(|k| k.strip_prefix(self.prefix).map(|s| s.to_string()))
489            .collect()
490    }
491}
492
493/// A borrowed, read-only view of state that automatically prepends a prefix to all keys.
494///
495/// Unlike `PrefixedState`, this does not expose `set()` or `remove()` methods,
496/// making it suitable for computed/derived state that user code should not mutate.
497pub struct ReadOnlyPrefixedState<'a> {
498    state: &'a State,
499    prefix: &'static str,
500}
501
502impl<'a> ReadOnlyPrefixedState<'a> {
503    fn prefixed_key(&self, key: &str) -> String {
504        format!("{}{}", self.prefix, key)
505    }
506
507    /// Get a value by key (with prefix applied).
508    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
509        self.state.get(&self.prefixed_key(key))
510    }
511
512    /// Get a raw JSON value by key (with prefix applied).
513    pub fn get_raw(&self, key: &str) -> Option<Value> {
514        self.state.get_raw(&self.prefixed_key(key))
515    }
516
517    /// Zero-copy borrow a value by key (with prefix applied).
518    pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
519    where
520        F: FnOnce(&Value) -> R,
521    {
522        self.state.with(&self.prefixed_key(key), f)
523    }
524
525    /// Check if a key exists (with prefix applied).
526    pub fn contains(&self, key: &str) -> bool {
527        self.state.contains(&self.prefixed_key(key))
528    }
529
530    /// Get all keys within this prefix scope (prefix stripped from results).
531    pub fn keys(&self) -> Vec<String> {
532        self.state
533            .keys()
534            .into_iter()
535            .filter_map(|k| k.strip_prefix(self.prefix).map(|s| s.to_string()))
536            .collect()
537    }
538}
539
540#[cfg(test)]
541mod tests {
542    use super::*;
543
544    #[test]
545    fn set_and_get_string() {
546        let state = State::new();
547        state.set("name", "Alice");
548        assert_eq!(state.get::<String>("name"), Some("Alice".to_string()));
549    }
550
551    #[test]
552    fn set_and_get_json() {
553        let state = State::new();
554        state.set("data", serde_json::json!({"temp": 22}));
555        let v: Value = state.get("data").unwrap();
556        assert_eq!(v["temp"], 22);
557    }
558
559    #[test]
560    fn pick_subset() {
561        let state = State::new();
562        state.set("a", 1);
563        state.set("b", 2);
564        state.set("c", 3);
565        let picked = state.pick(&["a", "c"]);
566        assert!(picked.contains("a"));
567        assert!(!picked.contains("b"));
568        assert!(picked.contains("c"));
569    }
570
571    #[test]
572    fn merge_states() {
573        let s1 = State::new();
574        s1.set("a", 1);
575        let s2 = State::new();
576        s2.set("b", 2);
577        s1.merge(&s2);
578        assert!(s1.contains("a"));
579        assert!(s1.contains("b"));
580    }
581
582    #[test]
583    fn rename_key() {
584        let state = State::new();
585        state.set("old", "value");
586        state.rename("old", "new");
587        assert!(!state.contains("old"));
588        assert_eq!(state.get::<String>("new"), Some("value".to_string()));
589    }
590
591    #[test]
592    fn remove_returns_value() {
593        let state = State::new();
594        state.set("key", 42);
595        let removed = state.remove("key");
596        assert!(removed.is_some());
597        assert!(!state.contains("key"));
598    }
599
600    #[test]
601    fn get_missing_returns_none() {
602        let state = State::new();
603        assert_eq!(state.get::<String>("nope"), None);
604    }
605
606    // ── Delta tracking tests ──────────────────────────────────────────────
607
608    #[test]
609    fn delta_tracking_writes_to_delta() {
610        let state = State::new();
611        state.set("committed", "yes");
612
613        let tracked = state.with_delta_tracking();
614        tracked.set("new_key", "new_value");
615
616        // New key visible through tracked state
617        assert_eq!(
618            tracked.get::<String>("new_key"),
619            Some("new_value".to_string())
620        );
621        // But NOT visible in original (non-delta) state's inner
622        assert!(!state.contains("new_key"));
623        // Committed key still visible through tracked state
624        assert_eq!(tracked.get::<String>("committed"), Some("yes".to_string()));
625    }
626
627    #[test]
628    fn delta_has_delta_reports_correctly() {
629        let state = State::new();
630        let tracked = state.with_delta_tracking();
631        assert!(!tracked.has_delta());
632
633        tracked.set("key", "val");
634        assert!(tracked.has_delta());
635    }
636
637    #[test]
638    fn delta_commit_merges_to_inner() {
639        let state = State::new();
640        let tracked = state.with_delta_tracking();
641        tracked.set("key", "val");
642        assert!(!state.contains("key"));
643
644        tracked.commit();
645        // Now visible in original state
646        assert_eq!(state.get::<String>("key"), Some("val".to_string()));
647        assert!(!tracked.has_delta());
648    }
649
650    #[test]
651    fn delta_rollback_discards_changes() {
652        let state = State::new();
653        let tracked = state.with_delta_tracking();
654        tracked.set("key", "val");
655        assert!(tracked.has_delta());
656
657        tracked.rollback();
658        assert!(!tracked.has_delta());
659        assert!(!state.contains("key"));
660        assert!(!tracked.contains("key"));
661    }
662
663    #[test]
664    fn delta_snapshot() {
665        let state = State::new();
666        let tracked = state.with_delta_tracking();
667        tracked.set("a", 1);
668        tracked.set("b", 2);
669
670        let snapshot = tracked.delta();
671        assert_eq!(snapshot.len(), 2);
672        assert!(snapshot.contains_key("a"));
673        assert!(snapshot.contains_key("b"));
674    }
675
676    #[test]
677    fn set_committed_bypasses_delta() {
678        let state = State::new();
679        let tracked = state.with_delta_tracking();
680        tracked.set_committed("direct", "value");
681
682        // Visible immediately in inner
683        assert_eq!(state.get::<String>("direct"), Some("value".to_string()));
684        // Not in delta
685        assert!(!tracked.has_delta());
686        // Still visible through tracked (reads inner too)
687        assert_eq!(tracked.get::<String>("direct"), Some("value".to_string()));
688    }
689
690    #[test]
691    fn no_delta_tracking_preserves_existing_behavior() {
692        let state = State::new();
693        assert!(!state.is_tracking_delta());
694        state.set("key", "val");
695        assert_eq!(state.get::<String>("key"), Some("val".to_string()));
696        assert!(!state.has_delta());
697    }
698
699    // ── Prefix tests ──────────────────────────────────────────────────────
700
701    #[test]
702    fn prefix_app_set_and_get() {
703        let state = State::new();
704        state.app().set("flag", true);
705
706        // Accessible via prefix accessor
707        assert_eq!(state.app().get::<bool>("flag"), Some(true));
708        // Also accessible via raw key
709        assert_eq!(state.get::<bool>("app:flag"), Some(true));
710    }
711
712    #[test]
713    fn prefix_user_set_and_get() {
714        let state = State::new();
715        state.user().set("name", "Alice");
716        assert_eq!(
717            state.user().get::<String>("name"),
718            Some("Alice".to_string())
719        );
720        assert_eq!(state.get::<String>("user:name"), Some("Alice".to_string()));
721    }
722
723    #[test]
724    fn prefix_temp_set_and_get() {
725        let state = State::new();
726        state.temp().set("scratch", 42);
727        assert_eq!(state.temp().get::<i32>("scratch"), Some(42));
728    }
729
730    #[test]
731    fn prefix_contains_and_remove() {
732        let state = State::new();
733        state.app().set("x", 1);
734        assert!(state.app().contains("x"));
735        state.app().remove("x");
736        assert!(!state.app().contains("x"));
737    }
738
739    #[test]
740    fn prefix_keys() {
741        let state = State::new();
742        state.app().set("a", 1);
743        state.app().set("b", 2);
744        state.user().set("c", 3);
745
746        let app_keys = state.app().keys();
747        assert_eq!(app_keys.len(), 2);
748        assert!(app_keys.contains(&"a".to_string()));
749        assert!(app_keys.contains(&"b".to_string()));
750
751        let user_keys = state.user().keys();
752        assert_eq!(user_keys.len(), 1);
753        assert!(user_keys.contains(&"c".to_string()));
754    }
755
756    #[test]
757    fn prefix_with_delta_tracking() {
758        let state = State::new();
759        let tracked = state.with_delta_tracking();
760        tracked.app().set("flag", true);
761
762        // Visible in tracked state via prefix
763        assert_eq!(tracked.app().get::<bool>("flag"), Some(true));
764        // In delta, not committed
765        assert!(tracked.has_delta());
766        assert!(!state.contains("app:flag"));
767
768        tracked.commit();
769        assert_eq!(state.get::<bool>("app:flag"), Some(true));
770    }
771
772    // ── New prefix accessor tests ────────────────────────────────────────
773
774    #[test]
775    fn prefix_session_set_and_get() {
776        let state = State::new();
777        state.session().set("turn_count", 5);
778        assert_eq!(state.session().get::<i32>("turn_count"), Some(5));
779        assert_eq!(state.get::<i32>("session:turn_count"), Some(5));
780    }
781
782    #[test]
783    fn prefix_turn_set_and_get() {
784        let state = State::new();
785        state.turn().set("transcript", "hello");
786        assert_eq!(
787            state.turn().get::<String>("transcript"),
788            Some("hello".to_string())
789        );
790        assert_eq!(
791            state.get::<String>("turn:transcript"),
792            Some("hello".to_string())
793        );
794    }
795
796    #[test]
797    fn prefix_bg_set_and_get() {
798        let state = State::new();
799        state.bg().set("task_id", "abc-123");
800        assert_eq!(
801            state.bg().get::<String>("task_id"),
802            Some("abc-123".to_string())
803        );
804        assert_eq!(
805            state.get::<String>("bg:task_id"),
806            Some("abc-123".to_string())
807        );
808    }
809
810    #[test]
811    fn prefix_session_contains_and_remove() {
812        let state = State::new();
813        state.session().set("x", 1);
814        assert!(state.session().contains("x"));
815        state.session().remove("x");
816        assert!(!state.session().contains("x"));
817    }
818
819    #[test]
820    fn prefix_turn_keys() {
821        let state = State::new();
822        state.turn().set("a", 1);
823        state.turn().set("b", 2);
824        state.session().set("c", 3);
825
826        let turn_keys = state.turn().keys();
827        assert_eq!(turn_keys.len(), 2);
828        assert!(turn_keys.contains(&"a".to_string()));
829        assert!(turn_keys.contains(&"b".to_string()));
830    }
831
832    // ── ReadOnlyPrefixedState (derived) tests ────────────────────────────
833
834    #[test]
835    fn derived_read_only_get() {
836        let state = State::new();
837        // Write via raw key (simulating ComputedRegistry)
838        state.set("derived:sentiment", "positive");
839        assert_eq!(
840            state.derived().get::<String>("sentiment"),
841            Some("positive".to_string())
842        );
843    }
844
845    #[test]
846    fn derived_read_only_get_raw() {
847        let state = State::new();
848        state.set("derived:score", serde_json::json!(0.95));
849        let raw = state.derived().get_raw("score");
850        assert!(raw.is_some());
851        assert_eq!(raw.unwrap(), serde_json::json!(0.95));
852    }
853
854    #[test]
855    fn derived_read_only_contains() {
856        let state = State::new();
857        state.set("derived:exists", true);
858        assert!(state.derived().contains("exists"));
859        assert!(!state.derived().contains("missing"));
860    }
861
862    #[test]
863    fn derived_read_only_keys() {
864        let state = State::new();
865        state.set("derived:a", 1);
866        state.set("derived:b", 2);
867        state.set("app:c", 3);
868
869        let derived_keys = state.derived().keys();
870        assert_eq!(derived_keys.len(), 2);
871        assert!(derived_keys.contains(&"a".to_string()));
872        assert!(derived_keys.contains(&"b".to_string()));
873    }
874
875    #[test]
876    fn derived_missing_key_returns_none() {
877        let state = State::new();
878        assert_eq!(state.derived().get::<String>("nope"), None);
879        assert_eq!(state.derived().get_raw("nope"), None);
880    }
881
882    // ── snapshot_values tests ────────────────────────────────────────────
883
884    #[test]
885    fn snapshot_values_captures_existing_keys() {
886        let state = State::new();
887        state.set("a", 1);
888        state.set("b", "hello");
889        state.set("c", true);
890
891        let snap = state.snapshot_values(&["a", "b", "missing"]);
892        assert_eq!(snap.len(), 2);
893        assert_eq!(snap.get("a"), Some(&serde_json::json!(1)));
894        assert_eq!(snap.get("b"), Some(&serde_json::json!("hello")));
895        assert!(!snap.contains_key("missing"));
896    }
897
898    #[test]
899    fn snapshot_values_empty_keys() {
900        let state = State::new();
901        state.set("a", 1);
902        let snap = state.snapshot_values(&[]);
903        assert!(snap.is_empty());
904    }
905
906    // ── diff_values tests ────────────────────────────────────────────────
907
908    #[test]
909    fn diff_values_detects_changed_value() {
910        let state = State::new();
911        state.set("x", 1);
912        let snap = state.snapshot_values(&["x"]);
913
914        state.set("x", 2);
915        let diffs = state.diff_values(&snap, &["x"]);
916        assert_eq!(diffs.len(), 1);
917        assert_eq!(diffs[0].0, "x");
918        assert_eq!(diffs[0].1, serde_json::json!(1));
919        assert_eq!(diffs[0].2, serde_json::json!(2));
920    }
921
922    #[test]
923    fn diff_values_detects_new_key() {
924        let state = State::new();
925        let snap = state.snapshot_values(&["y"]);
926
927        state.set("y", "new");
928        let diffs = state.diff_values(&snap, &["y"]);
929        assert_eq!(diffs.len(), 1);
930        assert_eq!(diffs[0].0, "y");
931        assert_eq!(diffs[0].1, Value::Null);
932        assert_eq!(diffs[0].2, serde_json::json!("new"));
933    }
934
935    #[test]
936    fn diff_values_detects_removed_key() {
937        let state = State::new();
938        state.set("z", 42);
939        let snap = state.snapshot_values(&["z"]);
940
941        state.remove("z");
942        let diffs = state.diff_values(&snap, &["z"]);
943        assert_eq!(diffs.len(), 1);
944        assert_eq!(diffs[0].0, "z");
945        assert_eq!(diffs[0].1, serde_json::json!(42));
946        assert_eq!(diffs[0].2, Value::Null);
947    }
948
949    #[test]
950    fn diff_values_no_change() {
951        let state = State::new();
952        state.set("stable", 10);
953        let snap = state.snapshot_values(&["stable"]);
954
955        // No mutation
956        let diffs = state.diff_values(&snap, &["stable"]);
957        assert!(diffs.is_empty());
958    }
959
960    #[test]
961    fn diff_values_multiple_keys_mixed_changes() {
962        let state = State::new();
963        state.set("a", 1);
964        state.set("b", 2);
965        let snap = state.snapshot_values(&["a", "b", "c"]);
966
967        state.set("a", 10); // changed
968                            // b unchanged
969        state.set("c", 3); // new
970
971        let diffs = state.diff_values(&snap, &["a", "b", "c"]);
972        assert_eq!(diffs.len(), 2); // a changed, c new; b unchanged
973        let diff_keys: Vec<&str> = diffs.iter().map(|(k, _, _)| k.as_str()).collect();
974        assert!(diff_keys.contains(&"a"));
975        assert!(diff_keys.contains(&"c"));
976    }
977
978    // ── clear_prefix tests ───────────────────────────────────────────────
979
980    #[test]
981    fn clear_prefix_removes_matching_keys() {
982        let state = State::new();
983        state.set("turn:a", 1);
984        state.set("turn:b", 2);
985        state.set("app:c", 3);
986        state.set("session:d", 4);
987
988        state.clear_prefix("turn:");
989        assert!(!state.contains("turn:a"));
990        assert!(!state.contains("turn:b"));
991        assert!(state.contains("app:c"));
992        assert!(state.contains("session:d"));
993    }
994
995    #[test]
996    fn clear_prefix_no_matching_keys_is_noop() {
997        let state = State::new();
998        state.set("app:x", 1);
999        state.clear_prefix("turn:");
1000        assert!(state.contains("app:x"));
1001    }
1002
1003    #[test]
1004    fn clear_prefix_also_clears_delta() {
1005        let state = State::new();
1006        state.set("turn:committed", 1);
1007        let tracked = state.with_delta_tracking();
1008        tracked.set("turn:delta_val", 2);
1009
1010        // Both committed and delta have turn: keys
1011        assert!(tracked.contains("turn:committed"));
1012        assert!(tracked.contains("turn:delta_val"));
1013
1014        tracked.clear_prefix("turn:");
1015        assert!(!tracked.contains("turn:committed"));
1016        assert!(!tracked.contains("turn:delta_val"));
1017    }
1018
1019    #[test]
1020    fn clear_prefix_via_turn_accessor() {
1021        let state = State::new();
1022        state.turn().set("x", 1);
1023        state.turn().set("y", 2);
1024        state.app().set("z", 3);
1025
1026        state.clear_prefix("turn:");
1027        assert!(state.turn().keys().is_empty());
1028        assert!(state.app().contains("z"));
1029    }
1030
1031    // ── modify() tests ──────────────────────────────────────────────────
1032
1033    #[test]
1034    fn modify_increment_existing() {
1035        let state = State::new();
1036        state.set("count", 5u32);
1037        let result = state.modify("count", 0u32, |n| n + 1);
1038        assert_eq!(result, 6);
1039        assert_eq!(state.get::<u32>("count"), Some(6));
1040    }
1041
1042    #[test]
1043    fn modify_uses_default_when_missing() {
1044        let state = State::new();
1045        let result = state.modify("new_count", 0u32, |n| n + 1);
1046        assert_eq!(result, 1);
1047        assert_eq!(state.get::<u32>("new_count"), Some(1));
1048    }
1049
1050    #[test]
1051    fn modify_with_delta_tracking() {
1052        let state = State::new();
1053        state.set("x", 10u32);
1054        let tracked = state.with_delta_tracking();
1055        let result = tracked.modify("x", 0u32, |n| n * 2);
1056        assert_eq!(result, 20);
1057        // Written to delta, not committed
1058        assert_eq!(tracked.get::<u32>("x"), Some(20));
1059        assert_eq!(state.get::<u32>("x"), Some(10)); // original unchanged
1060    }
1061
1062    // ── derived fallback tests ──────────────────────────────────────────
1063
1064    #[test]
1065    fn get_falls_back_to_derived_prefix() {
1066        let state = State::new();
1067        state.set("derived:risk", 0.85);
1068        // Access without prefix — should find derived:risk
1069        assert_eq!(state.get::<f64>("risk"), Some(0.85));
1070    }
1071
1072    #[test]
1073    fn get_prefers_direct_key_over_derived() {
1074        let state = State::new();
1075        state.set("score", 1.0);
1076        state.set("derived:score", 0.5);
1077        // Direct key should win
1078        assert_eq!(state.get::<f64>("score"), Some(1.0));
1079    }
1080
1081    #[test]
1082    fn get_derived_fallback_skipped_for_prefixed_keys() {
1083        let state = State::new();
1084        state.set("derived:risk", 0.85);
1085        // Prefixed key should NOT trigger fallback
1086        assert_eq!(state.get::<f64>("app:risk"), None);
1087    }
1088
1089    #[test]
1090    fn get_derived_fallback_with_delta_tracking() {
1091        let state = State::new();
1092        let tracked = state.with_delta_tracking();
1093        tracked.set("derived:computed_val", 42);
1094        assert_eq!(tracked.get::<i32>("computed_val"), Some(42));
1095    }
1096
1097    // ── with() zero-copy borrow tests ──────────────────────────────────
1098
1099    #[test]
1100    fn with_reads_from_inner() {
1101        let state = State::new();
1102        state.set("name", "Alice");
1103        let len = state.with("name", |v| v.as_str().unwrap().len());
1104        assert_eq!(len, Some(5));
1105    }
1106
1107    #[test]
1108    fn with_reads_from_delta_first() {
1109        let state = State::new();
1110        state.set("x", 1);
1111        let tracked = state.with_delta_tracking();
1112        tracked.set("x", 99);
1113        let val = tracked.with("x", |v| v.as_i64().unwrap());
1114        assert_eq!(val, Some(99));
1115    }
1116
1117    #[test]
1118    fn with_falls_back_to_inner_when_not_in_delta() {
1119        let state = State::new();
1120        state.set("committed", "yes");
1121        let tracked = state.with_delta_tracking();
1122        let val = tracked.with("committed", |v| v.as_str().unwrap().to_string());
1123        assert_eq!(val, Some("yes".to_string()));
1124    }
1125
1126    #[test]
1127    fn with_falls_back_to_derived() {
1128        let state = State::new();
1129        state.set("derived:risk", 0.85);
1130        let val = state.with("risk", |v| v.as_f64().unwrap());
1131        assert_eq!(val, Some(0.85));
1132    }
1133
1134    #[test]
1135    fn with_derived_fallback_skipped_for_prefixed() {
1136        let state = State::new();
1137        state.set("derived:risk", 0.85);
1138        let val = state.with("app:risk", |v| v.as_f64().unwrap());
1139        assert_eq!(val, None);
1140    }
1141
1142    #[test]
1143    fn with_returns_none_for_missing() {
1144        let state = State::new();
1145        let val = state.with("missing", |v| v.clone());
1146        assert_eq!(val, None);
1147    }
1148
1149    #[test]
1150    fn with_on_prefixed_state() {
1151        let state = State::new();
1152        state.app().set("flag", true);
1153        let val = state.app().with("flag", |v| v.as_bool().unwrap());
1154        assert_eq!(val, Some(true));
1155    }
1156
1157    #[test]
1158    fn with_on_read_only_prefixed_state() {
1159        let state = State::new();
1160        state.set("derived:score", serde_json::json!(0.95));
1161        let val = state.derived().with("score", |v| v.as_f64().unwrap());
1162        assert_eq!(val, Some(0.95));
1163    }
1164
1165    // ── StateKey typed key tests ───────────────────────────────────────
1166
    /// Typed key for a `u32` stored under the plain string key `"session:turn_count"`.
    const TURN_COUNT: StateKey<u32> = StateKey::new("session:turn_count");
    /// Typed key for a `String` stored under the plain string key `"user:name"`.
    const NAME: StateKey<String> = StateKey::new("user:name");
1169
1170    #[test]
1171    fn state_key_get_and_set() {
1172        let state = State::new();
1173        state.set_key(&TURN_COUNT, 5);
1174        assert_eq!(state.get_key(&TURN_COUNT), Some(5));
1175    }
1176
1177    #[test]
1178    fn state_key_get_missing() {
1179        let state = State::new();
1180        assert_eq!(state.get_key(&TURN_COUNT), None);
1181    }
1182
1183    #[test]
1184    fn state_key_string_type() {
1185        let state = State::new();
1186        state.set_key(&NAME, "Alice".to_string());
1187        assert_eq!(state.get_key(&NAME), Some("Alice".to_string()));
1188    }
1189
1190    #[test]
1191    fn state_key_with() {
1192        let state = State::new();
1193        state.set_key(&TURN_COUNT, 42);
1194        let val = state.with_key(&TURN_COUNT, |v| v.as_u64().unwrap());
1195        assert_eq!(val, Some(42));
1196    }
1197
1198    #[test]
1199    fn state_key_interop_with_raw() {
1200        let state = State::new();
1201        state.set_key(&TURN_COUNT, 10);
1202        // Can also read via raw key
1203        assert_eq!(state.get::<u32>("session:turn_count"), Some(10));
1204    }
1205}