Skip to main content

lellm_graph/
state.rs

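//! State and execution results.
//!
//! Contains the core types for the graph's shared state (merged in from lellm-runtime) and the graph-specific execution result types.
//!
//! v0.4+: `State` changed from a type alias to a struct wrapper so that it can implement the `WorkflowState` trait.
//! Full compatibility with `HashMap<String, Value>` is preserved through `Deref`/`DerefMut`.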
7
8use std::collections::HashMap;
9use std::ops::{Deref, DerefMut};
10use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
11
12use serde_json::Value;
13
14// ─── State 类型 ─────────────────────────────────────────────────
15
/// Shared graph state — a struct wrapper that supports the `WorkflowState` trait.
///
/// The wrapper exposes the `HashMap<String, Value>` API through
/// `Deref`/`DerefMut`, so code written against the plain map is intended to
/// keep working unchanged.
#[derive(Debug, Clone, Default)]
pub struct State {
    // The underlying key → JSON value map; private to this module.
    inner: HashMap<String, Value>,
}
24
25/// 手动实现 Serialize/Deserialize — 序列化底层 HashMap,保持兼容。
26impl serde::Serialize for State {
27    fn serialize<SER: serde::Serializer>(&self, serializer: SER) -> Result<SER::Ok, SER::Error> {
28        self.inner.serialize(serializer)
29    }
30}
31
32impl<'de> serde::Deserialize<'de> for State {
33    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
34        let map = HashMap::deserialize(deserializer)?;
35        Ok(State { inner: map })
36    }
37}
38
39impl State {
40    /// 创建空状态。
41    pub fn new() -> Self {
42        Self {
43            inner: HashMap::new(),
44        }
45    }
46}
47
48impl Deref for State {
49    type Target = HashMap<String, Value>;
50
51    fn deref(&self) -> &Self::Target {
52        &self.inner
53    }
54}
55
56impl DerefMut for State {
57    fn deref_mut(&mut self) -> &mut Self::Target {
58        &mut self.inner
59    }
60}
61
62impl From<HashMap<String, Value>> for State {
63    fn from(map: HashMap<String, Value>) -> Self {
64        Self { inner: map }
65    }
66}
67
68impl From<State> for HashMap<String, Value> {
69    fn from(state: State) -> Self {
70        state.inner
71    }
72}
73
74// ─── WorkflowState for State ────────────────────────────────────
75
/// Mutation for `State` — a change expressed at the `HashMap` level.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum StateMutation {
    /// Set a key to a value (fields: key, value).
    Put(String, Value),
    /// Delete a key.
    Delete(String),
}
84
85impl crate::workflow_state::StateMutation<State> for StateMutation {
86    fn apply(self, state: &mut State) {
87        match self {
88            StateMutation::Put(key, value) => {
89                state.insert(key, value);
90            }
91            StateMutation::Delete(key) => {
92                state.remove(&key);
93            }
94        }
95    }
96}
97
98impl crate::workflow_state::WorkflowState for State {
99    /// State 本身就是可序列化的 Checkpoint(向后兼容)。
100    type Checkpoint = State;
101    type Mutation = StateMutation;
102
103    fn snapshot(&self) -> State {
104        self.clone()
105    }
106
107    fn restore(checkpoint: State) -> Self {
108        checkpoint
109    }
110}
111
112/// State 的默认合并策略 — 逐 key 合并,后续分支覆盖同 key。
113#[derive(Clone)]
114pub struct StateMerge;
115
116impl crate::workflow_state::MergeStrategy<State> for StateMerge {
117    fn merge(branches: Vec<State>) -> Result<State, crate::workflow_state::WorkflowError> {
118        let mut merged: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
119        for state in branches {
120            merged.extend(state.inner);
121        }
122        Ok(State {
123            inner: merged.into_iter().collect(),
124        })
125    }
126
127    fn default_instance() -> Self {
128        StateMerge
129    }
130}
131
132// ─── StateError ─────────────────────────────────────────────────
133
/// Errors produced by `State` operations.
#[derive(Debug, thiserror::Error)]
pub enum StateError {
    /// The requested key does not exist (field: the key).
    #[error("state key '{0}' is missing")]
    MissingKey(String),

    /// Deserializing the stored value failed (fields: key, error text).
    #[error("failed to deserialize state key '{0}': {1}")]
    Deserialize(String, String),

    /// A reducer failed to merge values (fields: key, error text).
    #[error("reducer conflict on key '{0}': {1}")]
    ReducerConflict(String, String),

    /// Applying a delta failed, e.g. because of a type mismatch
    /// (fields: key, error text).
    #[error("failed to apply delta on key '{0}': {1}")]
    DeltaApply(String, String),

    /// Conflicting parallel writes to the same key; `writers` lists the
    /// concurrent writers and is rendered comma-separated in the message.
    #[error("state conflict on key '{key}': concurrent writers [{}]", writers.join(", "))]
    StateConflict { key: String, writers: Vec<String> },
}
157
158// ─── StateExt ───────────────────────────────────────────────────
159
/// Extension-method trait for `State`.
///
/// Provides typed read/write helpers on top of the raw JSON-valued map.
/// The behavior notes below describe the `State` implementation in this file.
pub trait StateExt {
    /// Returns the value at `key` as `&str`, or `None` if the key is absent
    /// or the value is not a JSON string.
    fn get_str(&self, key: &str) -> Option<&str>;
    /// Returns the value at `key` as `bool`, or `None` if the key is absent
    /// or the value is not a JSON boolean.
    fn get_bool(&self, key: &str) -> Option<bool>;
    /// Returns the value at `key` as `u64` (per `Value::as_u64`), or `None`.
    fn get_u64(&self, key: &str) -> Option<u64>;
    /// Returns the value at `key` as `i64` (per `Value::as_i64`), or `None`.
    fn get_i64(&self, key: &str) -> Option<i64>;
    /// Returns the value at `key` as `f64` (per `Value::as_f64`), or `None`.
    fn get_f64(&self, key: &str) -> Option<f64>;

    /// Deserializes the value at `key` into `T`.
    ///
    /// # Errors
    ///
    /// `StateError::MissingKey` if the key is absent;
    /// `StateError::Deserialize` if the value cannot be deserialized as `T`.
    fn get_json<T>(&self, key: &str) -> Result<T, StateError>
    where
        T: serde::de::DeserializeOwned;

    /// Like `get_json`; for `State` it simply delegates to `get_json`.
    fn require<T>(&self, key: &str) -> Result<T, StateError>
    where
        T: serde::de::DeserializeOwned;

    /// Serializes `value` to JSON and stores it under `key`, replacing any
    /// existing value. For `State`, a serialization failure is logged as a
    /// warning and `null` is stored instead.
    fn set<T>(&mut self, key: impl Into<String>, value: T)
    where
        T: serde::Serialize;

    /// Removes `key`, returning the previous value if there was one.
    fn remove(&mut self, key: &str) -> Option<serde_json::Value>;
    /// Returns `true` if `key` is present.
    fn contains(&self, key: &str) -> bool;

    /// Merges `value` into the entry at `key` using `reducer`.
    ///
    /// If the key is absent, `value` is stored as-is and the reducer is not
    /// called. A reducer error is returned unchanged.
    fn reduce(
        &mut self,
        key: &str,
        value: serde_json::Value,
        reducer: &StateReducer,
    ) -> Result<(), String>;

    /// Appends the elements of the JSON array `items` to the array at `key`.
    ///
    /// If the key is absent, `items` is stored as the new array. Returns an
    /// error if `items` or the existing value is not a JSON array.
    fn append_array(&mut self, key: &str, items: serde_json::Value) -> Result<(), String>;
}
194
195impl StateExt for State {
196    fn get_str(&self, key: &str) -> Option<&str> {
197        self.inner.get(key).and_then(|v| v.as_str())
198    }
199
200    fn get_bool(&self, key: &str) -> Option<bool> {
201        self.inner.get(key).and_then(|v| v.as_bool())
202    }
203
204    fn get_u64(&self, key: &str) -> Option<u64> {
205        self.inner.get(key).and_then(|v| v.as_u64())
206    }
207
208    fn get_i64(&self, key: &str) -> Option<i64> {
209        self.inner.get(key).and_then(|v| v.as_i64())
210    }
211
212    fn get_f64(&self, key: &str) -> Option<f64> {
213        self.inner.get(key).and_then(|v| v.as_f64())
214    }
215
216    fn get_json<T>(&self, key: &str) -> Result<T, StateError>
217    where
218        T: serde::de::DeserializeOwned,
219    {
220        let value = self
221            .inner
222            .get(key)
223            .ok_or_else(|| StateError::MissingKey(key.to_string()))?;
224        serde_json::from_value(value.clone())
225            .map_err(|e| StateError::Deserialize(key.to_string(), e.to_string()))
226    }
227
228    fn require<T>(&self, key: &str) -> Result<T, StateError>
229    where
230        T: serde::de::DeserializeOwned,
231    {
232        self.get_json(key)
233    }
234
235    fn set<T>(&mut self, key: impl Into<String>, value: T)
236    where
237        T: serde::Serialize,
238    {
239        let key_str = key.into();
240        let json = match serde_json::to_value(value) {
241            Ok(v) => v,
242            Err(e) => {
243                tracing::warn!(key = %key_str, error = %e, "failed to serialize state value, storing null");
244                serde_json::Value::Null
245            }
246        };
247        self.inner.insert(key_str, json);
248    }
249
250    fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
251        self.inner.remove(key)
252    }
253
254    fn contains(&self, key: &str) -> bool {
255        self.inner.contains_key(key)
256    }
257
258    fn reduce(
259        &mut self,
260        key: &str,
261        value: serde_json::Value,
262        reducer: &StateReducer,
263    ) -> Result<(), String> {
264        if let Some(existing) = self.inner.get(key) {
265            let merged = reducer(existing, &value)?;
266            self.inner.insert(key.to_string(), merged);
267        } else {
268            self.inner.insert(key.to_string(), value);
269        }
270        Ok(())
271    }
272
273    fn append_array(&mut self, key: &str, items: serde_json::Value) -> Result<(), String> {
274        let new_items = items.as_array().ok_or("append_array expects an array")?;
275        if let Some(existing) = self.inner.get(key) {
276            let mut arr = existing
277                .as_array()
278                .ok_or("append_array: existing value is not an array")?
279                .clone();
280            arr.extend(new_items.iter().cloned());
281            self.inner
282                .insert(key.to_string(), serde_json::Value::Array(arr));
283        } else {
284            self.inner.insert(key.to_string(), items);
285        }
286        Ok(())
287    }
288}
289
/// Type alias for a state reducer — a boxed, thread-safe function that merges
/// an existing value (first argument) with a new value (second argument) and
/// returns the merged value, or an error message on failure.
pub type StateReducer = Box<dyn Fn(&Value, &Value) -> Result<Value, String> + Send + Sync>;
292
293/// 内置 Reducer:数组追加。
294pub fn array_reducer(existing: &Value, new: &Value) -> Result<Value, String> {
295    let base = existing
296        .as_array()
297        .ok_or("array_reducer: existing is not an array")?;
298    let items = new
299        .as_array()
300        .ok_or("array_reducer: new value is not an array")?;
301    let mut merged = base.clone();
302    merged.extend(items.iter().cloned());
303    Ok(Value::Array(merged))
304}
305
306// ─── GraphResult ────────────────────────────────────────────────
307
/// Result of executing a graph.
///
/// # Type parameters
///
/// - `S` — the typed state (defaults to `State`, the `HashMap`-backed state,
///   for backward compatibility)
#[derive(Debug)]
pub struct GraphResult<S: crate::workflow_state::WorkflowState = State> {
    /// Trace ID of this execution (associates all SpanIds of the run)
    pub trace_id: crate::ids::TraceId,
    /// Final state
    pub state: S,
    /// Execution log
    pub execution_log: Vec<ExecutionEntry>,
    /// Time taken by the execution
    pub duration: Duration,
    /// Execution trace (optional; populated when a TraceSink is enabled)
    pub trace: Option<crate::trace::ExecutionTrace<S::Mutation>>,
}
326
/// Execution record for a single node.
///
/// Uses `Instant` for timing at runtime; `to_json_value` converts the
/// instants to ISO-8601-style strings when the entry is serialized.
#[derive(Debug, Clone)]
pub struct ExecutionEntry {
    /// Global step number (which step of the run this was)
    pub step: usize,
    /// Node name
    pub node_name: String,
    /// Start time
    pub start_time: Instant,
    /// End time
    pub end_time: Instant,
    /// Whether the node succeeded
    pub success: bool,
    /// Error message (on failure)
    pub error: Option<String>,
}
345
346impl ExecutionEntry {
347    /// 执行耗时
348    pub fn elapsed(&self) -> Duration {
349        self.end_time.duration_since(self.start_time)
350    }
351
352    /// 序列化为 JSON Value(Instant → ISO-8601 字符串)。
353    /// 供 Checkpoint 持久化使用。
354    pub fn to_json_value(&self) -> serde_json::Value {
355        serde_json::json!({
356            "step": self.step,
357            "node_name": self.node_name,
358            "start_time": instant_to_iso(&self.start_time),
359            "end_time": instant_to_iso(&self.end_time),
360            "success": self.success,
361            "error": self.error,
362        })
363    }
364}
365
/// Converts an `Instant` to an ISO-8601 UTC timestamp string
/// (`YYYY-MM-DDTHH:MM:SSZ`) without depending on chrono.
///
/// `Instant` has no wall-clock meaning, so the time is estimated as
/// "current system time minus the time elapsed since `instant`". The result
/// is therefore only as accurate as the system clock at the moment of the
/// call. Falls back to the Unix epoch if the system clock is before the epoch
/// or the subtraction would underflow.
fn instant_to_iso(instant: &Instant) -> String {
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // Subtract at full precision and truncate to whole seconds only once.
    let secs = now.saturating_sub(instant.elapsed()).as_secs();
    unix_secs_to_iso(secs)
}

/// Formats seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC),
/// using the proleptic Gregorian calendar with correct leap-year handling.
fn unix_secs_to_iso(secs: u64) -> String {
    let days = secs / 86_400;
    let second_of_day = secs % 86_400;

    // Days-to-civil conversion (Howard Hinnant's `civil_from_days`), shifted
    // so each 400-year era starts on March 1st; all intermediate values are
    // non-negative because `days` counts from 1970-01-01.
    let shifted = days + 719_468;
    let era = shifted / 146_097;
    let day_of_era = shifted % 146_097;
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March (0 = March … 11 = February).
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };
    // January and February belong to the next civil year in the shifted calendar.
    let year = year_of_era + era * 400 + if month <= 2 { 1 } else { 0 };

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year,
        month,
        day,
        second_of_day / 3_600,
        second_of_day % 3_600 / 60,
        second_of_day % 60
    )
}