Skip to main content

lellm_graph/
state.rs

//! State and execution results.
//!
//! Contains the core types for the graph's shared state (merged from `lellm-runtime`) and the
//! execution result types specific to the graph.
4
5use std::collections::HashMap;
6use std::time::{Duration, Instant};
7
8use serde_json::Value;
9
10// ─── State 类型 ─────────────────────────────────────────────────
11
12/// Graph 共享状态。
13pub type State = HashMap<String, Value>;
14
15// ─── StateError ─────────────────────────────────────────────────
16
17/// State 操作错误。
18#[derive(Debug, thiserror::Error)]
19pub enum StateError {
20    /// Key 不存在
21    #[error("state key '{0}' is missing")]
22    MissingKey(String),
23
24    /// 反序列化失败
25    #[error("failed to deserialize state key '{0}': {1}")]
26    Deserialize(String, String),
27
28    /// Reducer 合并失败
29    #[error("reducer conflict on key '{0}': {1}")]
30    ReducerConflict(String, String),
31
32    /// Delta 应用失败(类型不匹配等)
33    #[error("failed to apply delta on key '{0}': {1}")]
34    DeltaApply(String, String),
35
36    /// 并行状态冲突
37    #[error("state conflict on key '{key}': concurrent writers [{}]", writers.join(", "))]
38    StateConflict { key: String, writers: Vec<String> },
39}
40
41// ─── StateExt ───────────────────────────────────────────────────
42
43/// State 扩展方法 trait。
44///
45/// 为 `State`(`HashMap<String, Value>`)提供类型安全的读写方法。
46pub trait StateExt {
47    fn get_str(&self, key: &str) -> Option<&str>;
48    fn get_bool(&self, key: &str) -> Option<bool>;
49    fn get_u64(&self, key: &str) -> Option<u64>;
50    fn get_i64(&self, key: &str) -> Option<i64>;
51    fn get_f64(&self, key: &str) -> Option<f64>;
52
53    fn get_json<T>(&self, key: &str) -> Result<T, StateError>
54    where
55        T: serde::de::DeserializeOwned;
56
57    fn require<T>(&self, key: &str) -> Result<T, StateError>
58    where
59        T: serde::de::DeserializeOwned;
60
61    fn set<T>(&mut self, key: impl Into<String>, value: T)
62    where
63        T: serde::Serialize;
64
65    fn remove(&mut self, key: &str) -> Option<serde_json::Value>;
66    fn contains(&self, key: &str) -> bool;
67
68    fn reduce(
69        &mut self,
70        key: &str,
71        value: serde_json::Value,
72        reducer: &StateReducer,
73    ) -> Result<(), String>;
74
75    fn append_array(&mut self, key: &str, items: serde_json::Value) -> Result<(), String>;
76}
77
78impl StateExt for State {
79    fn get_str(&self, key: &str) -> Option<&str> {
80        self.get(key).and_then(|v| v.as_str())
81    }
82
83    fn get_bool(&self, key: &str) -> Option<bool> {
84        self.get(key).and_then(|v| v.as_bool())
85    }
86
87    fn get_u64(&self, key: &str) -> Option<u64> {
88        self.get(key).and_then(|v| v.as_u64())
89    }
90
91    fn get_i64(&self, key: &str) -> Option<i64> {
92        self.get(key).and_then(|v| v.as_i64())
93    }
94
95    fn get_f64(&self, key: &str) -> Option<f64> {
96        self.get(key).and_then(|v| v.as_f64())
97    }
98
99    fn get_json<T>(&self, key: &str) -> Result<T, StateError>
100    where
101        T: serde::de::DeserializeOwned,
102    {
103        let value = self
104            .get(key)
105            .ok_or_else(|| StateError::MissingKey(key.to_string()))?;
106        serde_json::from_value(value.clone())
107            .map_err(|e| StateError::Deserialize(key.to_string(), e.to_string()))
108    }
109
110    fn require<T>(&self, key: &str) -> Result<T, StateError>
111    where
112        T: serde::de::DeserializeOwned,
113    {
114        self.get_json(key)
115    }
116
117    fn set<T>(&mut self, key: impl Into<String>, value: T)
118    where
119        T: serde::Serialize,
120    {
121        let key_str = key.into();
122        let json = match serde_json::to_value(value) {
123            Ok(v) => v,
124            Err(e) => {
125                tracing::warn!(key = %key_str, error = %e, "failed to serialize state value, storing null");
126                serde_json::Value::Null
127            }
128        };
129        HashMap::insert(self, key_str, json);
130    }
131
132    fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
133        HashMap::remove(self, key)
134    }
135
136    fn contains(&self, key: &str) -> bool {
137        self.contains_key(key)
138    }
139
140    fn reduce(
141        &mut self,
142        key: &str,
143        value: serde_json::Value,
144        reducer: &StateReducer,
145    ) -> Result<(), String> {
146        if let Some(existing) = self.get(key) {
147            let merged = reducer(existing, &value)?;
148            self.insert(key.to_string(), merged);
149        } else {
150            self.insert(key.to_string(), value);
151        }
152        Ok(())
153    }
154
155    fn append_array(&mut self, key: &str, items: serde_json::Value) -> Result<(), String> {
156        let new_items = items.as_array().ok_or("append_array expects an array")?;
157        if let Some(existing) = self.get(key) {
158            let mut arr = existing
159                .as_array()
160                .ok_or("append_array: existing value is not an array")?
161                .clone();
162            arr.extend(new_items.iter().cloned());
163            self.insert(key.to_string(), serde_json::Value::Array(arr));
164        } else {
165            self.insert(key.to_string(), items);
166        }
167        Ok(())
168    }
169}
170
171/// State Reducer 类型别名 — 将已有值与新值合并。
172pub type StateReducer = Box<dyn Fn(&Value, &Value) -> Result<Value, String> + Send + Sync>;
173
174/// 内置 Reducer:数组追加。
175pub fn array_reducer(existing: &Value, new: &Value) -> Result<Value, String> {
176    let base = existing
177        .as_array()
178        .ok_or("array_reducer: existing is not an array")?;
179    let items = new
180        .as_array()
181        .ok_or("array_reducer: new value is not an array")?;
182    let mut merged = base.clone();
183    merged.extend(items.iter().cloned());
184    Ok(Value::Array(merged))
185}
186
187// ─── GraphResult ────────────────────────────────────────────────
188
189/// Graph 执行结果。
190#[derive(Debug)]
191pub struct GraphResult {
192    /// 执行追踪 ID(关联本次执行的所有 SpanId)
193    pub trace_id: crate::ids::TraceId,
194    /// 最终状态
195    pub state: State,
196    /// 执行日志
197    pub execution_log: Vec<ExecutionEntry>,
198    /// 执行耗时
199    pub duration: Duration,
200}
201
/// Record of a single node execution.
#[derive(Debug, Clone)]
pub struct ExecutionEntry {
    /// Global step index at which the node ran.
    pub step: usize,
    /// Name of the node that ran.
    pub node_name: String,
    /// When the node started.
    pub start_time: Instant,
    /// When the node finished.
    pub end_time: Instant,
    /// Whether the node completed successfully.
    pub success: bool,
}

impl ExecutionEntry {
    /// Time spent executing the node (`end_time - start_time`).
    ///
    /// Returns [`Duration::ZERO`] if `end_time` is earlier than `start_time`. The saturating
    /// variant is used because `Instant::duration_since` documents that future Rust
    /// versions may panic in that case.
    pub fn elapsed(&self) -> Duration {
        self.end_time.saturating_duration_since(self.start_time)
    }
}