Skip to main content

tirea_state/
state.rs

//! State trait for typed state access.
//!
//! The `State` trait provides a unified interface for typed access to JSON documents.
//! It is typically implemented via the derive macro `#[derive(State)]`.
5
6use crate::{DocCell, Op, Patch, Path, TireaResult, TrackedPatch};
7use serde_json::Value;
8use std::sync::{Arc, Mutex};
9
10type CollectHook<'a> = Arc<dyn Fn(&Op) -> TireaResult<()> + Send + Sync + 'a>;
11
12/// Collector for patch operations.
13///
14/// `PatchSink` collects operations that will be combined into a `Patch`.
15/// It is used internally by `StateRef` types to automatically collect
16/// all state modifications.
17///
18/// # Thread Safety
19///
20/// `PatchSink` uses a `Mutex` internally to support async contexts.
21/// In single-threaded usage, the lock overhead is minimal.
22pub struct PatchSink<'a> {
23    ops: Option<&'a Mutex<Vec<Op>>>,
24    on_collect: Option<CollectHook<'a>>,
25}
26
27impl<'a> PatchSink<'a> {
28    /// Create a new PatchSink wrapping a Mutex.
29    #[doc(hidden)]
30    pub fn new(ops: &'a Mutex<Vec<Op>>) -> Self {
31        Self {
32            ops: Some(ops),
33            on_collect: None,
34        }
35    }
36
37    /// Create a new PatchSink with a collect hook.
38    ///
39    /// The hook is invoked after each operation is collected.
40    #[doc(hidden)]
41    pub fn new_with_hook(ops: &'a Mutex<Vec<Op>>, hook: CollectHook<'a>) -> Self {
42        Self {
43            ops: Some(ops),
44            on_collect: Some(hook),
45        }
46    }
47
48    /// Create a child sink that shares the same collector and hook.
49    ///
50    /// Nested state refs use this so write-through behavior is preserved.
51    #[doc(hidden)]
52    pub fn child(&self) -> Self {
53        Self {
54            ops: self.ops,
55            on_collect: self.on_collect.clone(),
56        }
57    }
58
59    /// Create a read-only PatchSink that errors on collect.
60    ///
61    /// Used for `SealedState::get()` where writes are a programming error.
62    #[doc(hidden)]
63    pub fn read_only() -> Self {
64        Self {
65            ops: None,
66            on_collect: None,
67        }
68    }
69
70    /// Collect an operation.
71    #[inline]
72    pub fn collect(&self, op: Op) -> TireaResult<()> {
73        let ops = self.ops.ok_or_else(|| {
74            crate::TireaError::invalid_operation("write attempted on read-only state reference")
75        })?;
76        let mut guard = ops.lock().map_err(|_| {
77            crate::TireaError::invalid_operation("state operation collector mutex poisoned")
78        })?;
79        guard.push(op.clone());
80        drop(guard);
81        if let Some(hook) = &self.on_collect {
82            hook(&op)?;
83        }
84        Ok(())
85    }
86
87    /// Get the inner Mutex reference (for creating nested PatchSinks).
88    #[doc(hidden)]
89    pub fn inner(&self) -> &'a Mutex<Vec<Op>> {
90        self.ops
91            .expect("PatchSink::inner called on read-only sink (programming error)")
92    }
93}
94
95/// Pure state context with automatic patch collection.
96pub struct StateContext<'a> {
97    doc: &'a DocCell,
98    ops: Mutex<Vec<Op>>,
99}
100
101impl<'a> StateContext<'a> {
102    /// Create a new pure state context.
103    pub fn new(doc: &'a DocCell) -> Self {
104        Self {
105            doc,
106            ops: Mutex::new(Vec::new()),
107        }
108    }
109
110    /// Get a typed state reference at the specified path.
111    pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
112        let base = parse_path(path);
113        let hook: CollectHook<'_> = Arc::new(|op: &Op| self.doc.apply(op));
114        T::state_ref(self.doc, base, PatchSink::new_with_hook(&self.ops, hook))
115    }
116
117    /// Get a typed state reference at the type's canonical path.
118    ///
119    /// Requires `T` to have `#[tirea(path = "...")]` set.
120    /// Panics if `T::PATH` is empty.
121    pub fn state_of<T: State>(&self) -> T::Ref<'_> {
122        assert!(
123            !T::PATH.is_empty(),
124            "State type has no bound path; use state::<T>(path) instead"
125        );
126        self.state::<T>(T::PATH)
127    }
128
129    /// Extract collected operations as a plain patch.
130    pub fn take_patch(&self) -> Patch {
131        let ops = std::mem::take(&mut *self.ops.lock().unwrap());
132        Patch::with_ops(ops)
133    }
134
135    /// Extract collected operations as a tracked patch with a source.
136    pub fn take_tracked_patch(&self, source: impl Into<String>) -> TrackedPatch {
137        TrackedPatch::new(self.take_patch()).with_source(source)
138    }
139
140    /// Check if any operations have been collected.
141    pub fn has_changes(&self) -> bool {
142        !self.ops.lock().unwrap().is_empty()
143    }
144
145    /// Get the number of operations collected.
146    pub fn ops_count(&self) -> usize {
147        self.ops.lock().unwrap().len()
148    }
149}
150
151/// Parse a dot-separated path string into a `Path`.
152pub fn parse_path(path: &str) -> Path {
153    if path.is_empty() {
154        return Path::root();
155    }
156
157    let mut result = Path::root();
158    for segment in path.split('.') {
159        if !segment.is_empty() {
160            result = result.key(segment);
161        }
162    }
163    result
164}
165
166/// Trait for types that can create typed state references.
167///
168/// This trait is typically derived using `#[derive(State)]`.
169/// It provides the interface for creating `StateRef` types that
170/// allow typed read/write access to JSON documents.
171///
172/// # Example
173///
174/// ```ignore
175/// use tirea_state::State;
176/// use tirea_state_derive::State;
177///
178/// #[derive(State)]
179/// struct User {
180///     pub name: String,
181///     pub age: i64,
182/// }
183///
184/// // In a StateContext:
185/// let user = ctx.state::<User>("users.alice");
186/// let name = user.name()?;
187/// user.set_name("Alice");
188/// user.set_age(30);
189/// ```
190pub trait State: Sized {
191    /// The reference type that provides typed access.
192    type Ref<'a>;
193
194    /// Canonical JSON path for this state type.
195    ///
196    /// When set via `#[tirea(path = "...")]`, enables `state_of::<T>()` access
197    /// without an explicit path argument. Empty string means no bound path.
198    const PATH: &'static str = "";
199
200    /// Create a state reference at the specified path.
201    ///
202    /// # Arguments
203    ///
204    /// * `doc` - The JSON document to read from
205    /// * `base` - The base path for this state
206    /// * `sink` - The operation collector
207    fn state_ref<'a>(doc: &'a DocCell, base: Path, sink: PatchSink<'a>) -> Self::Ref<'a>;
208
209    /// Deserialize this type from a JSON value.
210    fn from_value(value: &Value) -> TireaResult<Self>;
211
212    /// Serialize this type to a JSON value.
213    fn to_value(&self) -> TireaResult<Value>;
214
215    /// Create a patch that sets this value at the root.
216    fn to_patch(&self) -> TireaResult<Patch> {
217        Ok(Patch::with_ops(vec![Op::set(
218            Path::root(),
219            self.to_value()?,
220        )]))
221    }
222}
223
224/// Extension trait providing convenience methods for State types.
225pub trait StateExt: State {
226    /// Create a state reference at the document root.
227    fn at_root<'a>(doc: &'a DocCell, sink: PatchSink<'a>) -> Self::Ref<'a> {
228        Self::state_ref(doc, Path::root(), sink)
229    }
230}
231
232impl<T: State> StateExt for T {}
233
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_patch_sink_collect() {
        let ops = Mutex::new(Vec::new());
        let sink = PatchSink::new(&ops);

        sink.collect(Op::set(Path::root().key("a"), Value::from(1)))
            .unwrap();
        sink.collect(Op::set(Path::root().key("b"), Value::from(2)))
            .unwrap();

        let collected = ops.lock().unwrap();
        assert_eq!(collected.len(), 2);
    }

    #[test]
    fn test_patch_sink_collect_hook() {
        let ops = Mutex::new(Vec::new());
        let seen = Arc::new(Mutex::new(Vec::new()));
        let seen_hook = seen.clone();
        let hook = Arc::new(move |op: &Op| {
            seen_hook.lock().unwrap().push(format!("{:?}", op));
            Ok(())
        });
        let sink = PatchSink::new_with_hook(&ops, hook);

        sink.collect(Op::set(Path::root().key("a"), Value::from(1)))
            .unwrap();
        sink.collect(Op::delete(Path::root().key("b"))).unwrap();

        let collected = ops.lock().unwrap();
        assert_eq!(collected.len(), 2);
        assert_eq!(seen.lock().unwrap().len(), 2);
    }

    /// A failing hook makes `collect` return its error, but the op was already
    /// pushed, so it stays in the buffer.
    #[test]
    fn test_patch_sink_hook_error_propagates_after_collect() {
        let ops = Mutex::new(Vec::new());
        let hook: CollectHook<'_> = Arc::new(|_: &Op| -> TireaResult<()> {
            Err(crate::TireaError::invalid_operation("hook rejected op"))
        });
        let sink = PatchSink::new_with_hook(&ops, hook);

        let err = sink
            .collect(Op::set(Path::root().key("a"), Value::from(1)))
            .unwrap_err();

        assert!(matches!(err, crate::TireaError::InvalidOperation { .. }));
        assert_eq!(ops.lock().unwrap().len(), 1);
    }

    #[test]
    fn test_patch_sink_child_preserves_collect_and_hook() {
        let ops = Mutex::new(Vec::new());
        let seen = Arc::new(Mutex::new(Vec::new()));
        let seen_hook = seen.clone();
        let hook = Arc::new(move |op: &Op| {
            seen_hook.lock().unwrap().push(format!("{:?}", op));
            Ok(())
        });
        let sink = PatchSink::new_with_hook(&ops, hook);
        let child = sink.child();

        child
            .collect(Op::set(Path::root().key("nested"), Value::from(1)))
            .unwrap();

        assert_eq!(ops.lock().unwrap().len(), 1);
        assert_eq!(seen.lock().unwrap().len(), 1);
    }

    #[test]
    fn test_patch_sink_read_only_child_collect_errors() {
        let sink = PatchSink::read_only();
        let child = sink.child();
        let err = child
            .collect(Op::set(Path::root().key("x"), Value::from(1)))
            .unwrap_err();
        assert!(matches!(err, crate::TireaError::InvalidOperation { .. }));
    }

    #[test]
    fn test_patch_sink_read_only_collect_errors() {
        let sink = PatchSink::read_only();
        let err = sink
            .collect(Op::set(Path::root().key("x"), Value::from(1)))
            .unwrap_err();
        assert!(matches!(err, crate::TireaError::InvalidOperation { .. }));
    }

    #[test]
    #[should_panic(expected = "read-only sink")]
    fn test_patch_sink_read_only_inner_panics() {
        let sink = PatchSink::read_only();
        let _ = sink.inner();
    }

    #[test]
    fn test_parse_path_empty() {
        let path = parse_path("");
        assert!(path.is_empty());
    }

    #[test]
    fn test_parse_path_nested() {
        let path = parse_path("tool_calls.call_123.data");
        assert_eq!(path.to_string(), "$.tool_calls.call_123.data");
    }

    /// Leading, trailing and doubled dots produce no extra segments.
    #[test]
    fn test_parse_path_skips_empty_segments() {
        assert_eq!(parse_path("a..b.").to_string(), "$.a.b");
        assert!(parse_path(".").is_empty());
    }

    #[test]
    fn test_state_context_collects_ops() {
        struct Counter;

        struct CounterRef<'a> {
            base: Path,
            sink: PatchSink<'a>,
        }

        impl<'a> CounterRef<'a> {
            fn set_value(&self, value: i64) -> TireaResult<()> {
                self.sink
                    .collect(Op::set(self.base.clone().key("value"), Value::from(value)))
            }
        }

        impl State for Counter {
            type Ref<'a> = CounterRef<'a>;

            fn state_ref<'a>(_: &'a DocCell, base: Path, sink: PatchSink<'a>) -> Self::Ref<'a> {
                CounterRef { base, sink }
            }

            fn from_value(_: &Value) -> TireaResult<Self> {
                Ok(Counter)
            }

            fn to_value(&self) -> TireaResult<Value> {
                Ok(Value::Null)
            }
        }

        let doc = DocCell::new(json!({"counter": {"value": 1}}));
        let ctx = StateContext::new(&doc);
        let counter = ctx.state::<Counter>("counter");
        counter.set_value(2).unwrap();

        assert!(ctx.has_changes());
        assert_eq!(ctx.ops_count(), 1);
        assert_eq!(ctx.take_patch().len(), 1);
        // Taking the patch drains the buffer.
        assert!(!ctx.has_changes());
        assert_eq!(ctx.ops_count(), 0);
    }

    /// `state_of` must reject types that have no bound path.
    #[test]
    #[should_panic(expected = "no bound path")]
    fn test_state_of_without_bound_path_panics() {
        struct Unbound;

        impl State for Unbound {
            type Ref<'a> = ();

            fn state_ref<'a>(_: &'a DocCell, _: Path, _: PatchSink<'a>) -> Self::Ref<'a> {}

            fn from_value(_: &Value) -> TireaResult<Self> {
                Ok(Unbound)
            }

            fn to_value(&self) -> TireaResult<Value> {
                Ok(Value::Null)
            }
        }

        let doc = DocCell::new(json!({}));
        let ctx = StateContext::new(&doc);
        ctx.state_of::<Unbound>();
    }
}