1use crate::{DocCell, Op, Patch, Path, TireaResult, TrackedPatch};
7use serde_json::Value;
8use std::sync::{Arc, Mutex};
9
10type CollectHook<'a> = Arc<dyn Fn(&Op) -> TireaResult<()> + Send + Sync + 'a>;
11
12pub struct PatchSink<'a> {
23 ops: Option<&'a Mutex<Vec<Op>>>,
24 on_collect: Option<CollectHook<'a>>,
25}
26
27impl<'a> PatchSink<'a> {
28 #[doc(hidden)]
30 pub fn new(ops: &'a Mutex<Vec<Op>>) -> Self {
31 Self {
32 ops: Some(ops),
33 on_collect: None,
34 }
35 }
36
37 #[doc(hidden)]
41 pub fn new_with_hook(ops: &'a Mutex<Vec<Op>>, hook: CollectHook<'a>) -> Self {
42 Self {
43 ops: Some(ops),
44 on_collect: Some(hook),
45 }
46 }
47
48 #[doc(hidden)]
52 pub fn child(&self) -> Self {
53 Self {
54 ops: self.ops,
55 on_collect: self.on_collect.clone(),
56 }
57 }
58
59 #[doc(hidden)]
63 pub fn read_only() -> Self {
64 Self {
65 ops: None,
66 on_collect: None,
67 }
68 }
69
70 #[inline]
72 pub fn collect(&self, op: Op) -> TireaResult<()> {
73 let ops = self.ops.ok_or_else(|| {
74 crate::TireaError::invalid_operation("write attempted on read-only state reference")
75 })?;
76 let mut guard = ops.lock().map_err(|_| {
77 crate::TireaError::invalid_operation("state operation collector mutex poisoned")
78 })?;
79 guard.push(op.clone());
80 drop(guard);
81 if let Some(hook) = &self.on_collect {
82 hook(&op)?;
83 }
84 Ok(())
85 }
86
87 #[doc(hidden)]
89 pub fn inner(&self) -> &'a Mutex<Vec<Op>> {
90 self.ops
91 .expect("PatchSink::inner called on read-only sink (programming error)")
92 }
93}
94
95pub struct StateContext<'a> {
97 doc: &'a DocCell,
98 ops: Mutex<Vec<Op>>,
99}
100
101impl<'a> StateContext<'a> {
102 pub fn new(doc: &'a DocCell) -> Self {
104 Self {
105 doc,
106 ops: Mutex::new(Vec::new()),
107 }
108 }
109
110 pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
112 let base = parse_path(path);
113 let hook: CollectHook<'_> = Arc::new(|op: &Op| self.doc.apply(op));
114 T::state_ref(self.doc, base, PatchSink::new_with_hook(&self.ops, hook))
115 }
116
117 pub fn state_of<T: State>(&self) -> T::Ref<'_> {
122 assert!(
123 !T::PATH.is_empty(),
124 "State type has no bound path; use state::<T>(path) instead"
125 );
126 self.state::<T>(T::PATH)
127 }
128
129 pub fn take_patch(&self) -> Patch {
131 let ops = std::mem::take(&mut *self.ops.lock().unwrap());
132 Patch::with_ops(ops)
133 }
134
135 pub fn take_tracked_patch(&self, source: impl Into<String>) -> TrackedPatch {
137 TrackedPatch::new(self.take_patch()).with_source(source)
138 }
139
140 pub fn has_changes(&self) -> bool {
142 !self.ops.lock().unwrap().is_empty()
143 }
144
145 pub fn ops_count(&self) -> usize {
147 self.ops.lock().unwrap().len()
148 }
149}
150
151pub fn parse_path(path: &str) -> Path {
153 if path.is_empty() {
154 return Path::root();
155 }
156
157 let mut result = Path::root();
158 for segment in path.split('.') {
159 if !segment.is_empty() {
160 result = result.key(segment);
161 }
162 }
163 result
164}
165
166pub trait State: Sized {
191 type Ref<'a>;
193
194 const PATH: &'static str = "";
199
200 fn state_ref<'a>(doc: &'a DocCell, base: Path, sink: PatchSink<'a>) -> Self::Ref<'a>;
208
209 fn from_value(value: &Value) -> TireaResult<Self>;
211
212 fn to_value(&self) -> TireaResult<Value>;
214
215 fn to_patch(&self) -> TireaResult<Patch> {
217 Ok(Patch::with_ops(vec![Op::set(
218 Path::root(),
219 self.to_value()?,
220 )]))
221 }
222}
223
224pub trait StateExt: State {
226 fn at_root<'a>(doc: &'a DocCell, sink: PatchSink<'a>) -> Self::Ref<'a> {
228 Self::state_ref(doc, Path::root(), sink)
229 }
230}
231
232impl<T: State> StateExt for T {}
233
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Builds a `set` op targeting the top-level `key`.
    fn set_at(key: &'static str, value: i32) -> Op {
        Op::set(Path::root().key(key), Value::from(value))
    }

    // Returns a hook that logs the `Debug` form of each op it sees, plus the log.
    fn recording_hook() -> (CollectHook<'static>, Arc<Mutex<Vec<String>>>) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let hook_log = Arc::clone(&log);
        let hook: CollectHook<'static> = Arc::new(move |op: &Op| -> TireaResult<()> {
            hook_log.lock().unwrap().push(format!("{:?}", op));
            Ok(())
        });
        (hook, log)
    }

    // Asserts that writing through `sink` fails with `InvalidOperation`.
    fn assert_write_rejected(sink: &PatchSink<'_>) {
        let err = sink.collect(set_at("x", 1)).unwrap_err();
        assert!(matches!(err, crate::TireaError::InvalidOperation { .. }));
    }

    #[test]
    fn test_patch_sink_collect() {
        let ops = Mutex::new(Vec::new());
        let sink = PatchSink::new(&ops);

        for (key, value) in [("a", 1), ("b", 2)] {
            sink.collect(set_at(key, value)).unwrap();
        }

        assert_eq!(ops.lock().unwrap().len(), 2);
    }

    #[test]
    fn test_patch_sink_collect_hook() {
        let ops = Mutex::new(Vec::new());
        let (hook, seen) = recording_hook();
        let sink = PatchSink::new_with_hook(&ops, hook);

        sink.collect(set_at("a", 1)).unwrap();
        sink.collect(Op::delete(Path::root().key("b"))).unwrap();

        assert_eq!(ops.lock().unwrap().len(), 2);
        assert_eq!(seen.lock().unwrap().len(), 2);
    }

    #[test]
    fn test_patch_sink_child_preserves_collect_and_hook() {
        let ops = Mutex::new(Vec::new());
        let (hook, seen) = recording_hook();
        let parent = PatchSink::new_with_hook(&ops, hook);

        parent.child().collect(set_at("nested", 1)).unwrap();

        assert_eq!(ops.lock().unwrap().len(), 1);
        assert_eq!(seen.lock().unwrap().len(), 1);
    }

    #[test]
    fn test_patch_sink_read_only_child_collect_errors() {
        assert_write_rejected(&PatchSink::read_only().child());
    }

    #[test]
    fn test_patch_sink_read_only_collect_errors() {
        assert_write_rejected(&PatchSink::read_only());
    }

    #[test]
    #[should_panic(expected = "read-only sink")]
    fn test_patch_sink_read_only_inner_panics() {
        PatchSink::read_only().inner();
    }

    #[test]
    fn test_parse_path_empty() {
        assert!(parse_path("").is_empty());
    }

    #[test]
    fn test_parse_path_nested() {
        let rendered = parse_path("tool_calls.call_123.data").to_string();
        assert_eq!(rendered, "$.tool_calls.call_123.data");
    }

    #[test]
    fn test_state_context_collects_ops() {
        // Minimal `State` whose reference can only write `<base>.value`.
        struct Counter;

        struct CounterRef<'a> {
            base: Path,
            sink: PatchSink<'a>,
        }

        impl CounterRef<'_> {
            fn set_value(&self, value: i64) -> TireaResult<()> {
                let target = self.base.clone().key("value");
                self.sink.collect(Op::set(target, Value::from(value)))
            }
        }

        impl State for Counter {
            type Ref<'a> = CounterRef<'a>;

            fn state_ref<'a>(_: &'a DocCell, base: Path, sink: PatchSink<'a>) -> Self::Ref<'a> {
                CounterRef { base, sink }
            }

            fn from_value(_: &Value) -> TireaResult<Self> {
                Ok(Counter)
            }

            fn to_value(&self) -> TireaResult<Value> {
                Ok(Value::Null)
            }
        }

        let doc = DocCell::new(json!({"counter": {"value": 1}}));
        let ctx = StateContext::new(&doc);
        ctx.state::<Counter>("counter").set_value(2).unwrap();

        assert!(ctx.has_changes());
        assert_eq!(ctx.ops_count(), 1);
        assert_eq!(ctx.take_patch().len(), 1);
    }
}