Skip to main content

lellm_graph/
delta.rs

1//! StateDelta + Reducer — 键级状态增量与合并策略。
2//!
3//! 从 lellm-runtime 合并到 lellm-graph。
4
5use std::borrow::Cow;
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::state::StateError;
11
12/// Delta 来源 — 追踪谁产生了这个修改。
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub enum DeltaSource {
15    /// 节点执行产生
16    Node { node_id: String },
17    /// Agent Hook 产生
18    Hook { node_id: String, hook_name: String },
19    /// Reducer 合并产生
20    ReducerMerge,
21    /// 恢复时重放
22    ResumeReplay,
23}
24
25impl std::fmt::Display for DeltaSource {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        match self {
28            DeltaSource::Node { node_id } => write!(f, "node:{}", node_id),
29            DeltaSource::Hook { node_id, hook_name } => {
30                write!(f, "hook:{}:{}", node_id, hook_name)
31            }
32            DeltaSource::ReducerMerge => write!(f, "reducer_merge"),
33            DeltaSource::ResumeReplay => write!(f, "resume_replay"),
34        }
35    }
36}
37
38/// 状态增量 — 节点对 State 的修改意图。
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct StateDelta {
41    pub key: Cow<'static, str>,
42    pub op: DeltaOp,
43    pub value: Value,
44    pub source: DeltaSource,
45}
46
47impl StateDelta {
48    pub fn put(key: impl Into<String>, value: Value) -> Self {
49        Self {
50            key: Cow::Owned(key.into()),
51            op: DeltaOp::Put,
52            value,
53            source: DeltaSource::Node {
54                node_id: String::new(),
55            },
56        }
57    }
58
59    pub fn delete(key: impl Into<String>) -> Self {
60        Self {
61            key: Cow::Owned(key.into()),
62            op: DeltaOp::Delete,
63            value: Value::Null,
64            source: DeltaSource::Node {
65                node_id: String::new(),
66            },
67        }
68    }
69
70    pub fn put_with_source(key: impl Into<String>, value: Value, source: DeltaSource) -> Self {
71        Self {
72            key: Cow::Owned(key.into()),
73            op: DeltaOp::Put,
74            value,
75            source,
76        }
77    }
78
79    pub fn delete_with_source(key: impl Into<String>, source: DeltaSource) -> Self {
80        Self {
81            key: Cow::Owned(key.into()),
82            op: DeltaOp::Delete,
83            value: Value::Null,
84            source,
85        }
86    }
87
88    pub fn with_writer(mut self, writer: impl Into<String>) -> Self {
89        self.source = DeltaSource::Node {
90            node_id: writer.into(),
91        };
92        self
93    }
94
95    pub fn with_source(mut self, source: DeltaSource) -> Self {
96        self.source = source;
97        self
98    }
99}
100
101/// Delta 操作类型。
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
103pub enum DeltaOp {
104    /// 覆盖写入
105    Put,
106    /// 删除
107    Delete,
108}
109
110/// Reducer 枚举 — 描述"这个 key 允许怎么合并"。
111#[allow(unpredictable_function_pointer_comparisons)]
112#[derive(Debug, Clone, Copy, PartialEq, Eq)]
113pub enum Reducer {
114    /// 冲突即报错
115    Error,
116    /// 最后写入者胜
117    Replace,
118    /// 数组追加
119    Append,
120    /// 对象浅合并
121    MergeObject,
122    /// 数值求和
123    Sum,
124    /// 取最大值
125    Max,
126    /// 取最小值
127    Min,
128    /// 自定义合并函数
129    Custom(fn(&Value, &Value) -> Result<Value, String>),
130}
131
132impl std::fmt::Display for Reducer {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        match self {
135            Reducer::Error => write!(f, "Error"),
136            Reducer::Replace => write!(f, "Replace"),
137            Reducer::Append => write!(f, "Append"),
138            Reducer::MergeObject => write!(f, "MergeObject"),
139            Reducer::Sum => write!(f, "Sum"),
140            Reducer::Max => write!(f, "Max"),
141            Reducer::Min => write!(f, "Min"),
142            Reducer::Custom(_) => write!(f, "Custom"),
143        }
144    }
145}
146
147/// 自定义 Reducer 闭包类型。
148type CustomReducerFn = Box<dyn Fn(&Value, &Value) -> Result<Value, String> + Send + Sync>;
149
150/// Reducer 注册表 — 管理每个 key 的合并策略。
151#[derive(Default)]
152pub struct ReducerRegistry {
153    reducers: std::collections::HashMap<String, Reducer>,
154    custom_reducers: std::collections::HashMap<String, CustomReducerFn>,
155}
156
157impl std::fmt::Debug for ReducerRegistry {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("ReducerRegistry")
160            .field("reducers", &self.reducers)
161            .field(
162                "custom_reducers",
163                &format!("{} entries", self.custom_reducers.len()),
164            )
165            .finish()
166    }
167}
168
169impl ReducerRegistry {
170    pub fn new() -> Self {
171        Self::default()
172    }
173
174    pub fn register(&mut self, key: &str, reducer: Reducer) {
175        self.reducers.insert(key.to_string(), reducer);
176    }
177
178    pub fn register_custom(
179        &mut self,
180        key: &str,
181        f: impl Fn(&Value, &Value) -> Result<Value, String> + Send + Sync + 'static,
182    ) {
183        self.custom_reducers.insert(key.to_string(), Box::new(f));
184    }
185
186    pub fn get(&self, key: &str) -> &Reducer {
187        self.reducers.get(key).unwrap_or(&Reducer::Error)
188    }
189
190    pub fn apply_custom(
191        &self,
192        key: &str,
193        existing: &Value,
194        new_val: &Value,
195    ) -> Result<Option<Value>, String> {
196        if let Some(f) = self.custom_reducers.get(key) {
197            Ok(Some(f(existing, new_val)?))
198        } else {
199            Ok(None)
200        }
201    }
202
203    pub fn apply_delta(
204        &self,
205        state: &mut std::collections::HashMap<String, Value>,
206        delta: &StateDelta,
207    ) -> Result<(), StateError> {
208        match delta.op {
209            DeltaOp::Put => {
210                state.insert(delta.key.to_string(), delta.value.clone());
211            }
212            DeltaOp::Delete => {
213                state.remove(delta.key.as_ref());
214            }
215        }
216        Ok(())
217    }
218
219    pub fn merge_deltas(
220        &self,
221        state: &mut std::collections::HashMap<String, Value>,
222        deltas: &[StateDelta],
223    ) -> Result<(), StateError> {
224        let mut grouped: std::collections::HashMap<&str, Vec<&StateDelta>> =
225            std::collections::HashMap::new();
226        for delta in deltas {
227            grouped.entry(&delta.key).or_default().push(delta);
228        }
229
230        for (key, key_deltas) in grouped {
231            if key_deltas.len() > 1 {
232                self.merge_by_reducer(state, key, &key_deltas, self.get(key))?;
233            } else if let Some(delta) = key_deltas.first() {
234                self.apply_delta(state, delta)?;
235            }
236        }
237
238        Ok(())
239    }
240
241    fn merge_by_reducer(
242        &self,
243        state: &mut std::collections::HashMap<String, Value>,
244        key: &str,
245        key_deltas: &[&StateDelta],
246        reducer: &Reducer,
247    ) -> Result<(), StateError> {
248        match reducer {
249            Reducer::Error => {
250                let writers: Vec<String> =
251                    key_deltas.iter().map(|d| d.source.to_string()).collect();
252                Err(StateError::StateConflict {
253                    key: key.to_string(),
254                    writers,
255                })
256            }
257            Reducer::Replace => {
258                if let Some(last) = key_deltas.last() {
259                    state.insert(key.to_string(), last.value.clone());
260                }
261                Ok(())
262            }
263            Reducer::Append => {
264                let mut all_items = Vec::new();
265                for d in key_deltas {
266                    if let Some(arr) = d.value.as_array() {
267                        all_items.extend(arr.iter().cloned());
268                    }
269                }
270                if let Some(existing) = state.get(key).and_then(|v| v.as_array()) {
271                    let mut merged = existing.clone();
272                    merged.extend(all_items);
273                    state.insert(key.to_string(), Value::Array(merged));
274                } else if !all_items.is_empty() {
275                    state.insert(key.to_string(), Value::Array(all_items));
276                }
277                Ok(())
278            }
279            Reducer::MergeObject => {
280                let mut merged = state
281                    .get(key)
282                    .and_then(|v| v.as_object().cloned())
283                    .unwrap_or_default();
284                for d in key_deltas {
285                    if let Some(obj) = d.value.as_object() {
286                        for (k, v) in obj {
287                            merged.insert(k.clone(), v.clone());
288                        }
289                    }
290                }
291                state.insert(key.to_string(), Value::Object(merged));
292                Ok(())
293            }
294            Reducer::Sum | Reducer::Max | Reducer::Min => {
295                let existing_val = state.get(key).and_then(|v| v.as_f64()).unwrap_or(0.0);
296                let values: Vec<f64> = key_deltas.iter().filter_map(|d| d.value.as_f64()).collect();
297
298                let result = if values.is_empty() {
299                    existing_val
300                } else {
301                    let sum: f64 = values.iter().sum();
302                    match reducer {
303                        Reducer::Sum => existing_val + sum,
304                        Reducer::Max => existing_val.max(
305                            *values
306                                .iter()
307                                .max_by(|a, b| a.partial_cmp(b).unwrap())
308                                .unwrap(),
309                        ),
310                        Reducer::Min => existing_val.min(
311                            *values
312                                .iter()
313                                .min_by(|a, b| a.partial_cmp(b).unwrap())
314                                .unwrap(),
315                        ),
316                        _ => unreachable!(),
317                    }
318                };
319                state.insert(key.to_string(), Value::from(result));
320                Ok(())
321            }
322            Reducer::Custom(f) => {
323                let mut current = state.get(key).cloned().unwrap_or(Value::Null);
324                for d in key_deltas {
325                    current = f(&current, &d.value)
326                        .map_err(|e| StateError::ReducerConflict(key.to_string(), e))?;
327                }
328                state.insert(key.to_string(), current);
329                Ok(())
330            }
331        }
332    }
333}