Skip to main content

lellm_graph/
state.rs

1//! State 和执行结果。
2//!
3//! 提供扁平 KV 状态管理,以及显式的 Reducer 合并机制(P1)。
4
5use std::collections::HashMap;
6use std::time::{Duration, Instant};
7
8/// Graph 共享状态。
9pub type State = HashMap<String, serde_json::Value>;
10
11/// State 操作错误。
12#[derive(Debug, thiserror::Error)]
13pub enum StateError {
14    /// Key 不存在
15    #[error("state key '{0}' is missing")]
16    MissingKey(String),
17
18    /// 反序列化失败
19    #[error("failed to deserialize state key '{0}': {1}")]
20    Deserialize(String, String),
21}
22
23/// State Reducer 类型别名 — 将已有值与新值合并。
24///
25/// 类似于 LangGraph 的 `operator.add`,但保持显式:
26/// ```rust,ignore
27/// // 追加消息列表
28/// state.reduce("messages", new_msgs, |existing, new| {
29///     let mut msgs: Vec<Value> = serde_json::from_value(existing.clone())?;
30///     let additions: Vec<Value> = serde_json::from_value(new.clone())?;
31///     msgs.extend(additions);
32///     Ok(serde_json::to_value(msgs)?)
33/// });
34/// ```
35pub type StateReducer = Box<
36    dyn Fn(&serde_json::Value, &serde_json::Value) -> Result<serde_json::Value, String>
37        + Send
38        + Sync,
39>;
40
41/// State 扩展方法 — 通过 Trait 为 HashMap 添加强类型访问与 Reducer 能力。
42pub trait StateExt {
43    // ─── 强类型 Getter ────────────────────────────────────────
44
45    /// 获取 String 值。
46    fn get_str(&self, key: &str) -> Option<&str>;
47
48    /// 获取 bool 值。
49    fn get_bool(&self, key: &str) -> Option<bool>;
50
51    /// 获取 u64 值。
52    fn get_u64(&self, key: &str) -> Option<u64>;
53
54    /// 获取 i64 值。
55    fn get_i64(&self, key: &str) -> Option<i64>;
56
57    /// 获取 f64 值。
58    fn get_f64(&self, key: &str) -> Option<f64>;
59
60    /// 反序列化为强类型。
61    fn get_json<T>(&self, key: &str) -> Result<T, StateError>
62    where
63        T: serde::de::DeserializeOwned;
64
65    /// 获取并反序列化为强类型。key 不存在时返回错误。
66    fn require<T>(&self, key: &str) -> Result<T, StateError>
67    where
68        T: serde::de::DeserializeOwned;
69
70    /// 设置值(自动序列化)。
71    fn set<T>(&mut self, key: impl Into<String>, value: T)
72    where
73        T: serde::Serialize;
74
75    /// 移除并返回值。
76    fn remove(&mut self, key: &str) -> Option<serde_json::Value>;
77
78    /// 检查 key 是否存在。
79    fn contains(&self, key: &str) -> bool;
80
81    // ─── Reducer ──────────────────────────────────────────────
82
83    /// 使用 Reducer 合并值到指定 key。
84    fn reduce(
85        &mut self,
86        key: &str,
87        value: serde_json::Value,
88        reducer: &StateReducer,
89    ) -> Result<(), String>;
90
91    /// 追加模式 — 内置的数组追加 Reducer。
92    fn append_array(&mut self, key: &str, items: serde_json::Value) -> Result<(), String>;
93}
94
95impl StateExt for State {
96    fn get_str(&self, key: &str) -> Option<&str> {
97        self.get(key).and_then(|v| v.as_str())
98    }
99
100    fn get_bool(&self, key: &str) -> Option<bool> {
101        self.get(key).and_then(|v| v.as_bool())
102    }
103
104    fn get_u64(&self, key: &str) -> Option<u64> {
105        self.get(key).and_then(|v| v.as_u64())
106    }
107
108    fn get_i64(&self, key: &str) -> Option<i64> {
109        self.get(key).and_then(|v| v.as_i64())
110    }
111
112    fn get_f64(&self, key: &str) -> Option<f64> {
113        self.get(key).and_then(|v| v.as_f64())
114    }
115
116    fn get_json<T>(&self, key: &str) -> Result<T, StateError>
117    where
118        T: serde::de::DeserializeOwned,
119    {
120        let value = self
121            .get(key)
122            .ok_or_else(|| StateError::MissingKey(key.to_string()))?;
123        serde_json::from_value(value.clone())
124            .map_err(|e| StateError::Deserialize(key.to_string(), e.to_string()))
125    }
126
127    fn require<T>(&self, key: &str) -> Result<T, StateError>
128    where
129        T: serde::de::DeserializeOwned,
130    {
131        self.get_json(key)
132    }
133
134    fn set<T>(&mut self, key: impl Into<String>, value: T)
135    where
136        T: serde::Serialize,
137    {
138        let json = serde_json::to_value(value).unwrap_or(serde_json::Value::Null);
139        HashMap::insert(self, key.into(), json);
140    }
141
142    fn remove(&mut self, key: &str) -> Option<serde_json::Value> {
143        HashMap::remove(self, key)
144    }
145
146    fn contains(&self, key: &str) -> bool {
147        self.contains_key(key)
148    }
149
150    fn reduce(
151        &mut self,
152        key: &str,
153        value: serde_json::Value,
154        reducer: &StateReducer,
155    ) -> Result<(), String> {
156        if let Some(existing) = self.get(key) {
157            let merged = reducer(existing, &value)?;
158            self.insert(key.to_string(), merged);
159        } else {
160            self.insert(key.to_string(), value);
161        }
162        Ok(())
163    }
164
165    fn append_array(&mut self, key: &str, items: serde_json::Value) -> Result<(), String> {
166        let new_items = items.as_array().ok_or("append_array expects an array")?;
167        if let Some(existing) = self.get(key) {
168            let mut arr = existing
169                .as_array()
170                .ok_or("append_array: existing value is not an array")?
171                .clone();
172            arr.extend(new_items.iter().cloned());
173            self.insert(key.to_string(), serde_json::Value::Array(arr));
174        } else {
175            self.insert(key.to_string(), items);
176        }
177        Ok(())
178    }
179}
180
181/// 内置 Reducer:数组追加(类似 LangGraph 的 `operator.add` for lists)。
182///
183/// ```rust,ignore
184/// use lellm_graph::{State, StateExt, array_reducer};
185/// let mut state = State::new();
186/// state.insert("items", json!([1, 2]));
187/// state.reduce("items", json!([3, 4]), &array_reducer())?;
188/// // state["items"] == [1, 2, 3, 4]
189/// ```
190pub fn array_reducer() -> StateReducer {
191    Box::new(|existing: &serde_json::Value, new: &serde_json::Value| {
192        let mut arr = existing
193            .as_array()
194            .ok_or("existing value is not an array")?
195            .clone();
196        let additions = new.as_array().ok_or("new value is not an array")?;
197        arr.extend(additions.iter().cloned());
198        Ok(serde_json::Value::Array(arr))
199    })
200}
201
202/// Graph 执行结果。
203#[derive(Debug)]
204pub struct GraphResult {
205    /// 最终状态
206    pub state: State,
207    /// 执行日志
208    pub execution_log: Vec<ExecutionEntry>,
209    /// 执行耗时
210    pub duration: Duration,
211}
212
213/// 单个节点执行记录。
214#[derive(Debug, Clone)]
215pub struct ExecutionEntry {
216    /// 节点名称
217    pub node_name: String,
218    /// 开始时间
219    pub start_time: Instant,
220    /// 结束时间
221    pub end_time: Instant,
222    /// 是否成功
223    pub success: bool,
224}
225
226impl ExecutionEntry {
227    /// 执行耗时
228    pub fn elapsed(&self) -> Duration {
229        self.end_time.duration_since(self.start_time)
230    }
231}