1use std::borrow::Cow;
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::state::StateError;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub enum DeltaSource {
15 Node { node_id: String },
17 Hook { node_id: String, hook_name: String },
19 ReducerMerge,
21 ResumeReplay,
23}
24
25impl std::fmt::Display for DeltaSource {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 match self {
28 DeltaSource::Node { node_id } => write!(f, "node:{}", node_id),
29 DeltaSource::Hook { node_id, hook_name } => {
30 write!(f, "hook:{}:{}", node_id, hook_name)
31 }
32 DeltaSource::ReducerMerge => write!(f, "reducer_merge"),
33 DeltaSource::ResumeReplay => write!(f, "resume_replay"),
34 }
35 }
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct StateDelta {
41 pub key: Cow<'static, str>,
42 pub op: DeltaOp,
43 pub value: Value,
44 pub source: DeltaSource,
45}
46
47impl StateDelta {
48 pub fn put(key: impl Into<String>, value: Value) -> Self {
49 Self {
50 key: Cow::Owned(key.into()),
51 op: DeltaOp::Put,
52 value,
53 source: DeltaSource::Node {
54 node_id: String::new(),
55 },
56 }
57 }
58
59 pub fn delete(key: impl Into<String>) -> Self {
60 Self {
61 key: Cow::Owned(key.into()),
62 op: DeltaOp::Delete,
63 value: Value::Null,
64 source: DeltaSource::Node {
65 node_id: String::new(),
66 },
67 }
68 }
69
70 pub fn put_with_source(key: impl Into<String>, value: Value, source: DeltaSource) -> Self {
71 Self {
72 key: Cow::Owned(key.into()),
73 op: DeltaOp::Put,
74 value,
75 source,
76 }
77 }
78
79 pub fn delete_with_source(key: impl Into<String>, source: DeltaSource) -> Self {
80 Self {
81 key: Cow::Owned(key.into()),
82 op: DeltaOp::Delete,
83 value: Value::Null,
84 source,
85 }
86 }
87
88 pub fn with_writer(mut self, writer: impl Into<String>) -> Self {
89 self.source = DeltaSource::Node {
90 node_id: writer.into(),
91 };
92 self
93 }
94
95 pub fn with_source(mut self, source: DeltaSource) -> Self {
96 self.source = source;
97 self
98 }
99}
100
101#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
103pub enum DeltaOp {
104 Put,
106 Delete,
108}
109
110#[allow(unpredictable_function_pointer_comparisons)]
112#[derive(Debug, Clone, Copy, PartialEq, Eq)]
113pub enum Reducer {
114 Error,
116 Replace,
118 Append,
120 MergeObject,
122 Sum,
124 Max,
126 Min,
128 Custom(fn(&Value, &Value) -> Result<Value, String>),
130}
131
132impl std::fmt::Display for Reducer {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 match self {
135 Reducer::Error => write!(f, "Error"),
136 Reducer::Replace => write!(f, "Replace"),
137 Reducer::Append => write!(f, "Append"),
138 Reducer::MergeObject => write!(f, "MergeObject"),
139 Reducer::Sum => write!(f, "Sum"),
140 Reducer::Max => write!(f, "Max"),
141 Reducer::Min => write!(f, "Min"),
142 Reducer::Custom(_) => write!(f, "Custom"),
143 }
144 }
145}
146
147type CustomReducerFn = Box<dyn Fn(&Value, &Value) -> Result<Value, String> + Send + Sync>;
149
150#[derive(Default)]
152pub struct ReducerRegistry {
153 reducers: std::collections::HashMap<String, Reducer>,
154 custom_reducers: std::collections::HashMap<String, CustomReducerFn>,
155}
156
157impl std::fmt::Debug for ReducerRegistry {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("ReducerRegistry")
160 .field("reducers", &self.reducers)
161 .field(
162 "custom_reducers",
163 &format!("{} entries", self.custom_reducers.len()),
164 )
165 .finish()
166 }
167}
168
169impl ReducerRegistry {
170 pub fn new() -> Self {
171 Self::default()
172 }
173
174 pub fn register(&mut self, key: &str, reducer: Reducer) {
175 self.reducers.insert(key.to_string(), reducer);
176 }
177
178 pub fn register_custom(
179 &mut self,
180 key: &str,
181 f: impl Fn(&Value, &Value) -> Result<Value, String> + Send + Sync + 'static,
182 ) {
183 self.custom_reducers.insert(key.to_string(), Box::new(f));
184 }
185
186 pub fn get(&self, key: &str) -> &Reducer {
187 self.reducers.get(key).unwrap_or(&Reducer::Error)
188 }
189
190 pub fn apply_custom(
191 &self,
192 key: &str,
193 existing: &Value,
194 new_val: &Value,
195 ) -> Result<Option<Value>, String> {
196 if let Some(f) = self.custom_reducers.get(key) {
197 Ok(Some(f(existing, new_val)?))
198 } else {
199 Ok(None)
200 }
201 }
202
203 pub fn apply_delta(
204 &self,
205 state: &mut std::collections::HashMap<String, Value>,
206 delta: &StateDelta,
207 ) -> Result<(), StateError> {
208 match delta.op {
209 DeltaOp::Put => {
210 state.insert(delta.key.to_string(), delta.value.clone());
211 }
212 DeltaOp::Delete => {
213 state.remove(delta.key.as_ref());
214 }
215 }
216 Ok(())
217 }
218
219 pub fn merge_deltas(
220 &self,
221 state: &mut std::collections::HashMap<String, Value>,
222 deltas: &[StateDelta],
223 ) -> Result<(), StateError> {
224 let mut grouped: std::collections::HashMap<&str, Vec<&StateDelta>> =
225 std::collections::HashMap::new();
226 for delta in deltas {
227 grouped.entry(&delta.key).or_default().push(delta);
228 }
229
230 for (key, key_deltas) in grouped {
231 if key_deltas.len() > 1 {
232 self.merge_by_reducer(state, key, &key_deltas, self.get(key))?;
233 } else if let Some(delta) = key_deltas.first() {
234 self.apply_delta(state, delta)?;
235 }
236 }
237
238 Ok(())
239 }
240
241 fn merge_by_reducer(
242 &self,
243 state: &mut std::collections::HashMap<String, Value>,
244 key: &str,
245 key_deltas: &[&StateDelta],
246 reducer: &Reducer,
247 ) -> Result<(), StateError> {
248 match reducer {
249 Reducer::Error => {
250 let writers: Vec<String> =
251 key_deltas.iter().map(|d| d.source.to_string()).collect();
252 Err(StateError::StateConflict {
253 key: key.to_string(),
254 writers,
255 })
256 }
257 Reducer::Replace => {
258 if let Some(last) = key_deltas.last() {
259 state.insert(key.to_string(), last.value.clone());
260 }
261 Ok(())
262 }
263 Reducer::Append => {
264 let mut all_items = Vec::new();
265 for d in key_deltas {
266 if let Some(arr) = d.value.as_array() {
267 all_items.extend(arr.iter().cloned());
268 }
269 }
270 if let Some(existing) = state.get(key).and_then(|v| v.as_array()) {
271 let mut merged = existing.clone();
272 merged.extend(all_items);
273 state.insert(key.to_string(), Value::Array(merged));
274 } else if !all_items.is_empty() {
275 state.insert(key.to_string(), Value::Array(all_items));
276 }
277 Ok(())
278 }
279 Reducer::MergeObject => {
280 let mut merged = state
281 .get(key)
282 .and_then(|v| v.as_object().cloned())
283 .unwrap_or_default();
284 for d in key_deltas {
285 if let Some(obj) = d.value.as_object() {
286 for (k, v) in obj {
287 merged.insert(k.clone(), v.clone());
288 }
289 }
290 }
291 state.insert(key.to_string(), Value::Object(merged));
292 Ok(())
293 }
294 Reducer::Sum | Reducer::Max | Reducer::Min => {
295 let existing_val = state.get(key).and_then(|v| v.as_f64()).unwrap_or(0.0);
296 let values: Vec<f64> = key_deltas.iter().filter_map(|d| d.value.as_f64()).collect();
297
298 let result = if values.is_empty() {
299 existing_val
300 } else {
301 let sum: f64 = values.iter().sum();
302 match reducer {
303 Reducer::Sum => existing_val + sum,
304 Reducer::Max => existing_val.max(
305 *values
306 .iter()
307 .max_by(|a, b| a.partial_cmp(b).unwrap())
308 .unwrap(),
309 ),
310 Reducer::Min => existing_val.min(
311 *values
312 .iter()
313 .min_by(|a, b| a.partial_cmp(b).unwrap())
314 .unwrap(),
315 ),
316 _ => unreachable!(),
317 }
318 };
319 state.insert(key.to_string(), Value::from(result));
320 Ok(())
321 }
322 Reducer::Custom(f) => {
323 let mut current = state.get(key).cloned().unwrap_or(Value::Null);
324 for d in key_deltas {
325 current = f(¤t, &d.value)
326 .map_err(|e| StateError::ReducerConflict(key.to_string(), e))?;
327 }
328 state.insert(key.to_string(), current);
329 Ok(())
330 }
331 }
332 }
333}