Skip to main content

juncture_core/state/
trait_.rs

1use crate::error::InvalidUpdateError;
2use std::sync::Arc;
3
/// Version counters for individual state fields, used for change tracking.
///
/// Entry `i` holds the version of the field whose declaration index is `i`
/// (0-based). The Pregel scheduler compares these numbers to decide which
/// nodes must run again after a field they subscribe to has changed.
///
/// `#[derive(State)]` generates a concrete `FieldVersions` type for the
/// struct alongside its `State` implementation.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct FieldVersions(pub Vec<u64>);

impl FieldVersions {
    /// Build a `FieldVersions` holding `num_fields` counters, all set to zero.
    #[must_use]
    pub fn new(num_fields: usize) -> Self {
        let counters = vec![0_u64; num_fields];
        Self(counters)
    }

    /// Return the version of field `field_idx`.
    ///
    /// An index past the end yields `0`, the same as a never-bumped field.
    #[must_use]
    pub fn get(&self, field_idx: usize) -> u64 {
        self.0.get(field_idx).copied().unwrap_or_default()
    }

    /// Advance the version of field `field_idx` by one, wrapping on overflow.
    ///
    /// An index past the end is silently ignored.
    pub fn bump(&mut self, field_idx: usize) {
        let Some(version) = self.0.get_mut(field_idx) else {
            return;
        };
        *version = version.wrapping_add(1);
    }
}
35
36/// State trait for graph state management
37///
38/// All states used in Juncture graphs must implement this trait.
39/// The #[derive(State)] macro automatically generates the implementation.
40pub trait State: Clone + Default + Send + Sync + std::fmt::Debug + 'static {
41    /// Partial update type generated by #[derive(State)]
42    type Update: Default + Clone + Send + Sync + 'static;
43
44    /// Field version tracking type generated by `#[derive(State)]`.
45    type FieldVersions: Default + Clone + Send + Sync + 'static;
46
47    /// Get the current field version numbers.
48    ///
49    /// Returns a snapshot of per-field version counters. The Pregel engine
50    /// uses these to determine which nodes need to re-execute when their
51    /// subscribed fields have new versions since their last execution.
52    ///
53    /// Default returns `Self::FieldVersions::default()` (all zeros).
54    /// `#[derive(State)]` generates a proper implementation.
55    #[must_use]
56    fn field_versions(&self) -> Self::FieldVersions {
57        Self::FieldVersions::default()
58    }
59
60    /// Increment version numbers for the fields marked in `changed`.
61    ///
62    /// Called after each superstep to bump version counters for fields
63    /// that were modified. The Pregel engine uses version comparisons to
64    /// implement reactive scheduling.
65    ///
66    /// Default is a no-op. `#[derive(State)]` generates a proper
67    /// implementation.
68    fn bump_versions(&mut self, _changed: &FieldsChanged) {
69        // no-op default: version tracking is delegated to the engine
70    }
71
72    /// Apply an update to this state, returning which fields changed
73    fn apply(&mut self, update: Self::Update) -> FieldsChanged;
74
75    /// Apply an update to this state, returning which fields changed or an error
76    ///
77    /// Unlike `apply()`, this method returns a structured error when reducer
78    /// constraints are violated (e.g., multiple writers on a replace channel).
79    /// The default implementation delegates to `apply()` for backward compatibility.
80    ///
81    /// # Errors
82    ///
83    /// Returns `InvalidUpdateError` if the update violates reducer constraints,
84    /// such as `InvalidUpdateError::MultipleOverwrite` when multiple nodes
85    /// write to a replace channel in the same superstep.
86    fn try_apply(&mut self, update: Self::Update) -> Result<FieldsChanged, InvalidUpdateError> {
87        Ok(self.apply(update))
88    }
89
90    /// Reset ephemeral fields (called after each superstep)
91    fn reset_ephemeral(&mut self);
92
93    /// Finish a specific field (called when graph execution completes)
94    ///
95    /// This allows channels to finalize their state. For example,
96    /// `LastValueAfterFinishChannel` only makes its value available after
97    /// `finish()` is called.
98    ///
99    /// Default implementation is a no-op for channels that don't need finish semantics.
100    fn finish_field(&mut self, _field_idx: usize) {}
101
102    /// Consume a specific field (called after `apply_writes()` per superstep)
103    ///
104    /// This marks a channel's value as consumed by the framework after writes
105    /// have been applied. For `EphemeralChannel`, this sets the `consumed` flag
106    /// to `true`, signaling that the value has been read. The consumed flag is
107    /// reset on the next `update()`.
108    ///
109    /// Called in `after_tick()` for each field that changed in the superstep.
110    ///
111    /// Default implementation is a no-op for field types that don't need
112    /// consume semantics.
113    fn consume_field(&mut self, _field_idx: usize) {}
114
115    /// Indices of fields that use the `ephemeral` reducer.
116    ///
117    /// Used by the Pregel engine to call `consume_field()` only for fields
118    /// that need consume semantics, avoiding unnecessary work. The proc-macro
119    /// generates this as a static slice from `#[reducer(ephemeral)]` annotations.
120    /// Default returns an empty slice for manually implemented states.
121    #[must_use]
122    fn consume_field_indices() -> &'static [usize] {
123        &[]
124    }
125
126    /// Schema version for migration
127    #[must_use]
128    fn schema_version() -> u32 {
129        1
130    }
131
132    /// Migrate from older schema version
133    #[must_use]
134    fn migrate(_from_version: u32, value: serde_json::Value) -> serde_json::Value {
135        value
136    }
137
138    /// Indices of fields that use the `replace` reducer.
139    ///
140    /// Used by the Pregel engine to detect multiple writers in a single
141    /// superstep. The proc-macro generates this as a static slice.
142    /// Default returns an empty slice for manually implemented states.
143    #[must_use]
144    fn replace_field_indices() -> &'static [usize] {
145        &[]
146    }
147
148    /// Indices of fields that use the `replace_after_finish` reducer.
149    ///
150    /// Used by the Pregel engine to call `finish_field()` only for fields
151    /// that need finish semantics. The proc-macro generates this as a
152    /// static slice. Default returns an empty slice for manually
153    /// implemented states.
154    #[must_use]
155    fn replace_after_finish_field_indices() -> &'static [usize] {
156        &[]
157    }
158
159    /// Check if a specific field is set (Some) in an update.
160    ///
161    /// Provides efficient field-level inspection without serialization,
162    /// used by the Pregel engine for multi-writer conflict detection.
163    /// The proc-macro generates an optimized match-based implementation.
164    /// Default returns `false` for manually implemented states.
165    #[must_use]
166    fn field_is_set(_update: &Self::Update, _field_idx: usize) -> bool {
167        false
168    }
169
170    /// Number of fields in this state type.
171    ///
172    /// Used by `validate_keys()` to verify that all reducer index arrays
173    /// reference valid field positions. The proc-macro generates a constant
174    /// based on the struct's field count. Default returns `0` for manually
175    /// implemented states.
176    #[must_use]
177    fn field_count() -> usize {
178        0
179    }
180
181    /// Names of fields in declaration order.
182    ///
183    /// Used by `validate_keys()` to report which field indices are invalid.
184    /// The proc-macro generates a static slice of string literals matching
185    /// the struct field names. Default returns an empty slice for manually
186    /// implemented states.
187    #[must_use]
188    fn field_names() -> &'static [&'static str] {
189        &[]
190    }
191
192    /// Field indices and snapshot frequencies for `DeltaChannel` fields.
193    ///
194    /// Returns `(field_index, snapshot_frequency)` pairs identifying which
195    /// fields use [`DeltaChannel`](super::channel::DeltaChannel) and their
196    /// configured snapshot interval. The Pregel engine uses this to track
197    /// per-channel delta counters and decide when to persist a full snapshot
198    /// versus an incremental delta.
199    ///
200    /// The proc-macro generates a static slice from `#[delta(frequency = N)]`
201    /// annotations. Default returns an empty slice for manually implemented
202    /// states (no delta channels).
203    #[must_use]
204    fn delta_channel_specs() -> &'static [(usize, usize)] {
205        &[]
206    }
207}
208
/// Bitmask recording which state fields changed.
///
/// Bit `i` of the inner `u64` corresponds to the field with declaration
/// index `i`, so only indices `0..64` can be tracked. For states with more
/// fields, enable the "wide-state" feature to use `FixedBitSet` instead.
#[derive(Clone, Debug, Default)]
pub struct FieldsChanged(pub u64);

impl FieldsChanged {
    /// Number of field indices representable in the `u64` mask.
    const CAPACITY: usize = 64;

    /// Returns `true` when no field is marked as changed.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.0 == 0
    }

    /// Returns `true` if the field at `index` is marked as changed.
    ///
    /// Indices `>= 64` cannot be stored in the mask, so they report `false`.
    /// An unchecked `1 << index` would instead panic in debug builds and test
    /// an unrelated bit in release builds.
    #[must_use]
    pub const fn has_field(&self, index: usize) -> bool {
        index < Self::CAPACITY && (self.0 & (1u64 << index)) != 0
    }

    /// Mark the field at `index` as changed.
    ///
    /// # Panics
    ///
    /// Panics if `index >= 64`, because such a field cannot be represented in
    /// the `u64` mask.
    #[allow(
        clippy::missing_const_for_fn,
        reason = "mutable methods cannot be const"
    )]
    pub fn set_field(&mut self, index: usize) {
        // Checked explicitly: an unchecked shift by >= 64 panics only in debug
        // builds and silently sets bit `index % 64` in release builds.
        assert!(
            index < Self::CAPACITY,
            "FieldsChanged can track at most 64 fields, got field index {index}"
        );
        self.0 |= 1u64 << index;
    }

    /// Also mark as changed every field that is marked in `other` (bitwise OR).
    #[allow(
        clippy::missing_const_for_fn,
        reason = "mutable methods cannot be const"
    )]
    pub fn merge(&mut self, other: &Self) {
        self.0 |= other.0;
    }
}
243
244/// Copy-on-write state wrapper (default state wrapper)
245///
246/// For large states (e.g., long conversation histories), cloning the entire
247/// state for each node spawn is expensive. `CowState` uses `Arc` to share
248/// immutable state and only clones when modified.
249///
250/// This is the DEFAULT state wrapper in Juncture, not just an optimization.
251#[derive(Debug)]
252pub struct CowState<S: State> {
253    /// Shared immutable state
254    shared: Arc<S>,
255    /// Pending local modifications
256    pending: Option<S::Update>,
257}
258
259impl<S: State> CowState<S> {
260    /// Create `CowState` from shared state
261    #[must_use]
262    pub const fn new(state: Arc<S>) -> Self {
263        Self {
264            shared: state,
265            pending: None,
266        }
267    }
268
269    /// Get current state (read-only)
270    pub fn get(&self) -> &S {
271        &self.shared
272    }
273
274    /// Get mutable access to the state, cloning the inner Arc if shared
275    ///
276    /// Uses clone-on-write semantics: if the Arc reference count is greater
277    /// than one, the inner state is cloned before returning a mutable reference.
278    /// This ensures no other `CowState` instances are affected by mutations.
279    pub fn get_mut(&mut self) -> &mut S
280    where
281        S: Clone,
282    {
283        Arc::make_mut(&mut self.shared)
284    }
285
286    /// Apply an update (deferred until commit)
287    ///
288    /// Note: For proper merge semantics, this implementation simply stores
289    /// the latest update. The proc-macro generates more sophisticated
290    /// merge logic for complex update types.
291    pub fn update(&mut self, changes: S::Update) {
292        // Store the new update, replacing any previous pending update
293        // The State trait's apply() method handles proper merging when commit() is called
294        self.pending = Some(changes);
295    }
296
297    /// Commit updates and return new shared state
298    pub fn commit(self) -> Arc<S> {
299        if let Some(pending) = self.pending {
300            let mut state = (*self.shared).clone();
301            state.apply(pending);
302            Arc::new(state)
303        } else {
304            self.shared
305        }
306    }
307
308    /// Commit updates and return new shared state, propagating reducer errors
309    ///
310    /// Unlike `commit()`, this method returns a structured error when reducer
311    /// constraints are violated (e.g., multiple writers on a replace channel).
312    ///
313    /// # Errors
314    ///
315    /// Returns `InvalidUpdateError` if the update violates reducer constraints.
316    pub fn try_commit(self) -> Result<Arc<S>, InvalidUpdateError> {
317        if let Some(pending) = self.pending {
318            let mut state = (*self.shared).clone();
319            let _changed = state.try_apply(pending)?;
320            Ok(Arc::new(state))
321        } else {
322            Ok(self.shared)
323        }
324    }
325}
326
327impl<S: State> Clone for CowState<S> {
328    fn clone(&self) -> Self {
329        Self {
330            shared: Arc::clone(&self.shared),
331            pending: None,
332        }
333    }
334}
335
336impl<S: State> std::ops::Deref for CowState<S> {
337    type Target = S;
338
339    fn deref(&self) -> &Self::Target {
340        &self.shared
341    }
342}
343
344/// Trait for converting input schema into full State
345///
346/// Types implementing this trait can be used as the input type `I` for
347/// [`StateGraph<S, I, O>`](crate::graph::StateGraph). The default blanket
348/// implementation converts `S` into `S` (identity), ensuring full backward
349/// compatibility when `I = S`.
350pub trait IntoState<S: State>: Clone + Send + Sync + 'static {
351    /// Convert `self` into the full state type `S`.
352    fn into_state(self) -> S;
353}
354
355/// Blanket implementation: any `State` type converts to itself (identity).
356impl<S: State> IntoState<S> for S {
357    fn into_state(self) -> S {
358        self
359    }
360}
361
362/// Trait for extracting output schema from full State
363///
364/// Types implementing this trait can be used as the output type `O` for
365/// [`StateGraph<S, I, O>`](crate::graph::StateGraph). The default blanket
366/// implementation extracts `S` from `S` via `Clone`, ensuring full backward
367/// compatibility when `O = S`.
368pub trait FromState<S: State>: Clone + Send + Sync + 'static {
369    /// Extract `Self` from a reference to the full state type `S`.
370    fn from_state(state: &S) -> Self;
371}
372
373/// Blanket implementation: any `State` type extracts from itself via `Clone`.
374impl<S: State> FromState<S> for S {
375    fn from_state(state: &S) -> Self {
376        state.clone()
377    }
378}
379
380// Rust guideline compliant 2026-05-22