Skip to main content

lellm_graph/
statekey.rs

//! `StateKey<T>` — a compile-time type-safe key for `State`.
//!
//! Merged from `lellm-runtime` into `lellm-graph`, with a set of commonly used built-in `StateKey` constants added.
4
5use serde::{Serialize, de::DeserializeOwned};
6
7use crate::delta::Reducer;
8use crate::state::{State, StateError};
9
10/// 编译期类型安全的 State 键。
11#[derive(Debug)]
12pub struct StateKey<T> {
13    name: &'static str,
14    reducer: Reducer,
15    _marker: std::marker::PhantomData<T>,
16}
17
18impl<T> StateKey<T> {
19    pub const fn new(name: &'static str, reducer: Reducer) -> Self {
20        Self {
21            name,
22            reducer,
23            _marker: std::marker::PhantomData,
24        }
25    }
26
27    pub const fn append(name: &'static str) -> Self {
28        Self::new(name, Reducer::Append)
29    }
30
31    pub const fn sum(name: &'static str) -> Self {
32        Self::new(name, Reducer::Sum)
33    }
34
35    pub const fn replace(name: &'static str) -> Self {
36        Self::new(name, Reducer::Replace)
37    }
38
39    pub const fn merge_object(name: &'static str) -> Self {
40        Self::new(name, Reducer::MergeObject)
41    }
42
43    pub const fn max(name: &'static str) -> Self {
44        Self::new(name, Reducer::Max)
45    }
46
47    pub const fn min(name: &'static str) -> Self {
48        Self::new(name, Reducer::Min)
49    }
50
51    pub const fn error(name: &'static str) -> Self {
52        Self::new(name, Reducer::Error)
53    }
54
55    pub fn name(&self) -> &str {
56        self.name
57    }
58
59    pub fn reducer(&self) -> &Reducer {
60        &self.reducer
61    }
62}
63
64/// StateKey 专用的 State 扩展方法。
65pub trait StateKeyExt {
66    fn set_sk<T>(&mut self, key: &StateKey<T>, value: T)
67    where
68        T: Serialize;
69
70    fn get_sk<T>(&self, key: &StateKey<T>) -> Option<T>
71    where
72        T: DeserializeOwned;
73
74    fn require_sk<T>(&self, key: &StateKey<T>) -> Result<T, StateError>
75    where
76        T: DeserializeOwned;
77
78    fn contains_sk<T>(&self, key: &StateKey<T>) -> bool;
79
80    fn remove_sk<T>(&mut self, key: &StateKey<T>) -> Option<serde_json::Value>;
81}
82
83impl StateKeyExt for State {
84    fn set_sk<T>(&mut self, key: &StateKey<T>, value: T)
85    where
86        T: Serialize,
87    {
88        let key_str = key.name().to_string();
89        let json = match serde_json::to_value(value) {
90            Ok(v) => v,
91            Err(e) => {
92                tracing::warn!(key = %key_str, error = %e, "failed to serialize state value, storing null");
93                serde_json::Value::Null
94            }
95        };
96        self.insert(key_str, json);
97    }
98
99    fn get_sk<T>(&self, key: &StateKey<T>) -> Option<T>
100    where
101        T: DeserializeOwned,
102    {
103        self.get(key.name())
104            .and_then(|v| serde_json::from_value(v.clone()).ok())
105    }
106
107    fn require_sk<T>(&self, key: &StateKey<T>) -> Result<T, StateError>
108    where
109        T: DeserializeOwned,
110    {
111        let value = self
112            .get(key.name())
113            .ok_or_else(|| StateError::MissingKey(key.name().to_string()))?;
114        serde_json::from_value(value.clone())
115            .map_err(|e| StateError::Deserialize(key.name().to_string(), e.to_string()))
116    }
117
118    fn contains_sk<T>(&self, key: &StateKey<T>) -> bool {
119        self.contains_key(key.name())
120    }
121
122    fn remove_sk<T>(&mut self, key: &StateKey<T>) -> Option<serde_json::Value> {
123        self.remove(key.name())
124    }
125}
126
127// ─── 内置常用 StateKey 常量 ───────────────────────────────────
128
129/// 消息列表 — Graph 中最通用的 State key。
130pub static SK_MESSAGES: StateKey<Vec<serde_json::Value>> =
131    StateKey::new("messages", Reducer::Append);
132
133/// 通用计数 — 循环计数器等场景。
134pub static SK_COUNT: StateKey<u64> = StateKey::new("count", Reducer::Sum);
135
136/// 执行步骤记录 — Barrier 多轮审批等场景。
137pub static SK_STEPS: StateKey<Vec<String>> = StateKey::new("steps", Reducer::Append);
138
139// ─── Agent 核心状态键(v0.3.1)─────────────────────────────────
140
141/// Agent 迭代轮次。
142pub static SK_ITERATIONS: StateKey<u32> = StateKey::replace("iterations");
143
144/// 当前轮待执行的工具调用(每轮清空,非历史累计)。
145pub static SK_PENDING_TOOL_CALLS: StateKey<Vec<serde_json::Value>> =
146    StateKey::replace("pending_tool_calls");
147
148/// 累计工具调用总数(整个 Agent Run)。
149pub static SK_TOTAL_TOOL_CALLS: StateKey<usize> = StateKey::sum("total_tool_calls");
150
151/// 累计输出 Token 数(Text,不含 Thinking)。
152pub static SK_OUTPUT_TOKENS: StateKey<usize> = StateKey::sum("output_tokens");
153
154/// 累计推理 Token 数(Thinking,不含 Text)。
155pub static SK_REASONING_TOKENS: StateKey<usize> = StateKey::sum("reasoning_tokens");