// lellm_graph/statekey.rs

//! `StateKey<T>` — compile-time type-safe keys into `State`.
//!
//! Merged from `lellm-runtime` into `lellm-graph`, together with built-in
//! `StateKey` constants for commonly used keys.
4
5use serde::{Serialize, de::DeserializeOwned};
6use serde_json::Value;
7
8use crate::state::{State, StateError};
9
10// ─── Reducer ────────────────────────────────────────────────────
11
12/// Reducer 枚举 — 描述"这个 key 允许怎么合并"。
13#[allow(unpredictable_function_pointer_comparisons)]
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub enum Reducer {
16    /// 冲突即报错
17    Error,
18    /// 最后写入者胜
19    Replace,
20    /// 数组追加
21    Append,
22    /// 对象浅合并
23    MergeObject,
24    /// 数值求和
25    Sum,
26    /// 取最大值
27    Max,
28    /// 取最小值
29    Min,
30    /// 自定义合并函数
31    Custom(fn(&Value, &Value) -> Result<Value, String>),
32}
33
34impl std::fmt::Display for Reducer {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            Reducer::Error => write!(f, "Error"),
38            Reducer::Replace => write!(f, "Replace"),
39            Reducer::Append => write!(f, "Append"),
40            Reducer::MergeObject => write!(f, "MergeObject"),
41            Reducer::Sum => write!(f, "Sum"),
42            Reducer::Max => write!(f, "Max"),
43            Reducer::Min => write!(f, "Min"),
44            Reducer::Custom(_) => write!(f, "Custom"),
45        }
46    }
47}
48
49/// 编译期类型安全的 State 键。
50#[derive(Debug)]
51pub struct StateKey<T> {
52    name: &'static str,
53    reducer: Reducer,
54    _marker: std::marker::PhantomData<T>,
55}
56
57impl<T> StateKey<T> {
58    pub const fn new(name: &'static str, reducer: Reducer) -> Self {
59        Self {
60            name,
61            reducer,
62            _marker: std::marker::PhantomData,
63        }
64    }
65
66    pub const fn append(name: &'static str) -> Self {
67        Self::new(name, Reducer::Append)
68    }
69
70    pub const fn sum(name: &'static str) -> Self {
71        Self::new(name, Reducer::Sum)
72    }
73
74    pub const fn replace(name: &'static str) -> Self {
75        Self::new(name, Reducer::Replace)
76    }
77
78    pub const fn merge_object(name: &'static str) -> Self {
79        Self::new(name, Reducer::MergeObject)
80    }
81
82    pub const fn max(name: &'static str) -> Self {
83        Self::new(name, Reducer::Max)
84    }
85
86    pub const fn min(name: &'static str) -> Self {
87        Self::new(name, Reducer::Min)
88    }
89
90    pub const fn error(name: &'static str) -> Self {
91        Self::new(name, Reducer::Error)
92    }
93
94    pub fn name(&self) -> &str {
95        self.name
96    }
97
98    pub fn reducer(&self) -> &Reducer {
99        &self.reducer
100    }
101}
102
103/// StateKey 专用的 State 扩展方法。
104pub trait StateKeyExt {
105    fn set_sk<T>(&mut self, key: &StateKey<T>, value: T)
106    where
107        T: Serialize;
108
109    fn get_sk<T>(&self, key: &StateKey<T>) -> Option<T>
110    where
111        T: DeserializeOwned;
112
113    fn require_sk<T>(&self, key: &StateKey<T>) -> Result<T, StateError>
114    where
115        T: DeserializeOwned;
116
117    fn contains_sk<T>(&self, key: &StateKey<T>) -> bool;
118
119    fn remove_sk<T>(&mut self, key: &StateKey<T>) -> Option<serde_json::Value>;
120}
121
122impl StateKeyExt for State {
123    fn set_sk<T>(&mut self, key: &StateKey<T>, value: T)
124    where
125        T: Serialize,
126    {
127        let key_str = key.name().to_string();
128        let json = match serde_json::to_value(value) {
129            Ok(v) => v,
130            Err(e) => {
131                tracing::warn!(key = %key_str, error = %e, "failed to serialize state value, storing null");
132                serde_json::Value::Null
133            }
134        };
135        self.insert(key_str, json);
136    }
137
138    fn get_sk<T>(&self, key: &StateKey<T>) -> Option<T>
139    where
140        T: DeserializeOwned,
141    {
142        self.get(key.name())
143            .and_then(|v| serde_json::from_value(v.clone()).ok())
144    }
145
146    fn require_sk<T>(&self, key: &StateKey<T>) -> Result<T, StateError>
147    where
148        T: DeserializeOwned,
149    {
150        let value = self
151            .get(key.name())
152            .ok_or_else(|| StateError::MissingKey(key.name().to_string()))?;
153        serde_json::from_value(value.clone())
154            .map_err(|e| StateError::Deserialize(key.name().to_string(), e.to_string()))
155    }
156
157    fn contains_sk<T>(&self, key: &StateKey<T>) -> bool {
158        self.contains_key(key.name())
159    }
160
161    fn remove_sk<T>(&mut self, key: &StateKey<T>) -> Option<serde_json::Value> {
162        self.remove(key.name())
163    }
164}
165
166// ─── 内置常用 StateKey 常量 ───────────────────────────────────
167
168/// 消息列表 — Graph 中最通用的 State key。
169pub static SK_MESSAGES: StateKey<Vec<serde_json::Value>> =
170    StateKey::new("messages", Reducer::Append);
171
172/// 通用计数 — 循环计数器等场景。
173pub static SK_COUNT: StateKey<u64> = StateKey::new("count", Reducer::Sum);
174
175/// 执行步骤记录 — Barrier 多轮审批等场景。
176pub static SK_STEPS: StateKey<Vec<String>> = StateKey::new("steps", Reducer::Append);
177
178// ─── Agent 核心状态键(v0.3.1)─────────────────────────────────
179
180/// Agent 迭代轮次。
181pub static SK_ITERATIONS: StateKey<u32> = StateKey::replace("iterations");
182
183/// 当前轮待执行的工具调用(每轮清空,非历史累计)。
184pub static SK_PENDING_TOOL_CALLS: StateKey<Vec<serde_json::Value>> =
185    StateKey::replace("pending_tool_calls");
186
187/// 累计工具调用总数(整个 Agent Run)。
188pub static SK_TOTAL_TOOL_CALLS: StateKey<usize> = StateKey::sum("total_tool_calls");
189
190/// 累计输出 Token 数(Text,不含 Thinking)。
191pub static SK_OUTPUT_TOKENS: StateKey<usize> = StateKey::sum("output_tokens");
192
193/// 累计推理 Token 数(Thinking,不含 Text)。
194pub static SK_REASONING_TOKENS: StateKey<usize> = StateKey::sum("reasoning_tokens");