use std::marker::PhantomData;

use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;

use crate::state::{State, StateError};

10#[allow(unpredictable_function_pointer_comparisons)]
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub enum Reducer {
16 Error,
18 Replace,
20 Append,
22 MergeObject,
24 Sum,
26 Max,
28 Min,
30 Custom(fn(&Value, &Value) -> Result<Value, String>),
32}
33
34impl std::fmt::Display for Reducer {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 Reducer::Error => write!(f, "Error"),
38 Reducer::Replace => write!(f, "Replace"),
39 Reducer::Append => write!(f, "Append"),
40 Reducer::MergeObject => write!(f, "MergeObject"),
41 Reducer::Sum => write!(f, "Sum"),
42 Reducer::Max => write!(f, "Max"),
43 Reducer::Min => write!(f, "Min"),
44 Reducer::Custom(_) => write!(f, "Custom"),
45 }
46 }
47}
48
49#[derive(Debug)]
51pub struct StateKey<T> {
52 name: &'static str,
53 reducer: Reducer,
54 _marker: std::marker::PhantomData<T>,
55}
56
57impl<T> StateKey<T> {
58 pub const fn new(name: &'static str, reducer: Reducer) -> Self {
59 Self {
60 name,
61 reducer,
62 _marker: std::marker::PhantomData,
63 }
64 }
65
66 pub const fn append(name: &'static str) -> Self {
67 Self::new(name, Reducer::Append)
68 }
69
70 pub const fn sum(name: &'static str) -> Self {
71 Self::new(name, Reducer::Sum)
72 }
73
74 pub const fn replace(name: &'static str) -> Self {
75 Self::new(name, Reducer::Replace)
76 }
77
78 pub const fn merge_object(name: &'static str) -> Self {
79 Self::new(name, Reducer::MergeObject)
80 }
81
82 pub const fn max(name: &'static str) -> Self {
83 Self::new(name, Reducer::Max)
84 }
85
86 pub const fn min(name: &'static str) -> Self {
87 Self::new(name, Reducer::Min)
88 }
89
90 pub const fn error(name: &'static str) -> Self {
91 Self::new(name, Reducer::Error)
92 }
93
94 pub fn name(&self) -> &str {
95 self.name
96 }
97
98 pub fn reducer(&self) -> &Reducer {
99 &self.reducer
100 }
101}
102
103pub trait StateKeyExt {
105 fn set_sk<T>(&mut self, key: &StateKey<T>, value: T)
106 where
107 T: Serialize;
108
109 fn get_sk<T>(&self, key: &StateKey<T>) -> Option<T>
110 where
111 T: DeserializeOwned;
112
113 fn require_sk<T>(&self, key: &StateKey<T>) -> Result<T, StateError>
114 where
115 T: DeserializeOwned;
116
117 fn contains_sk<T>(&self, key: &StateKey<T>) -> bool;
118
119 fn remove_sk<T>(&mut self, key: &StateKey<T>) -> Option<serde_json::Value>;
120}
121
122impl StateKeyExt for State {
123 fn set_sk<T>(&mut self, key: &StateKey<T>, value: T)
124 where
125 T: Serialize,
126 {
127 let key_str = key.name().to_string();
128 let json = match serde_json::to_value(value) {
129 Ok(v) => v,
130 Err(e) => {
131 tracing::warn!(key = %key_str, error = %e, "failed to serialize state value, storing null");
132 serde_json::Value::Null
133 }
134 };
135 self.insert(key_str, json);
136 }
137
138 fn get_sk<T>(&self, key: &StateKey<T>) -> Option<T>
139 where
140 T: DeserializeOwned,
141 {
142 self.get(key.name())
143 .and_then(|v| serde_json::from_value(v.clone()).ok())
144 }
145
146 fn require_sk<T>(&self, key: &StateKey<T>) -> Result<T, StateError>
147 where
148 T: DeserializeOwned,
149 {
150 let value = self
151 .get(key.name())
152 .ok_or_else(|| StateError::MissingKey(key.name().to_string()))?;
153 serde_json::from_value(value.clone())
154 .map_err(|e| StateError::Deserialize(key.name().to_string(), e.to_string()))
155 }
156
157 fn contains_sk<T>(&self, key: &StateKey<T>) -> bool {
158 self.contains_key(key.name())
159 }
160
161 fn remove_sk<T>(&mut self, key: &StateKey<T>) -> Option<serde_json::Value> {
162 self.remove(key.name())
163 }
164}
165
166pub static SK_MESSAGES: StateKey<Vec<serde_json::Value>> =
170 StateKey::new("messages", Reducer::Append);
171
172pub static SK_COUNT: StateKey<u64> = StateKey::new("count", Reducer::Sum);
174
175pub static SK_STEPS: StateKey<Vec<String>> = StateKey::new("steps", Reducer::Append);
177
178pub static SK_ITERATIONS: StateKey<u32> = StateKey::replace("iterations");
182
183pub static SK_PENDING_TOOL_CALLS: StateKey<Vec<serde_json::Value>> =
185 StateKey::replace("pending_tool_calls");
186
187pub static SK_TOTAL_TOOL_CALLS: StateKey<usize> = StateKey::sum("total_tool_calls");
189
190pub static SK_OUTPUT_TOKENS: StateKey<usize> = StateKey::sum("output_tokens");
192
193pub static SK_REASONING_TOKENS: StateKey<usize> = StateKey::sum("reasoning_tokens");