1use serde::{Deserialize, Serialize};
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10pub type State = HashMap<String, Value>;
12
13#[derive(Clone)]
15pub enum Reducer {
16 Overwrite,
18 Append,
20 Sum,
22 Custom(Arc<dyn Fn(Value, Value) -> Value + Send + Sync>),
24}
25
26#[allow(clippy::derivable_impls)]
28impl Default for Reducer {
29 fn default() -> Self {
30 Self::Overwrite
31 }
32}
33
34impl std::fmt::Debug for Reducer {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 Self::Overwrite => write!(f, "Overwrite"),
38 Self::Append => write!(f, "Append"),
39 Self::Sum => write!(f, "Sum"),
40 Self::Custom(_) => write!(f, "Custom"),
41 }
42 }
43}
44
45#[derive(Clone)]
47pub struct Channel {
48 pub name: String,
50 pub reducer: Reducer,
52 pub default: Option<Value>,
54}
55
56impl Channel {
57 pub fn new(name: &str) -> Self {
59 Self { name: name.to_string(), reducer: Reducer::Overwrite, default: None }
60 }
61
62 pub fn list(name: &str) -> Self {
64 Self { name: name.to_string(), reducer: Reducer::Append, default: Some(json!([])) }
65 }
66
67 pub fn counter(name: &str) -> Self {
69 Self { name: name.to_string(), reducer: Reducer::Sum, default: Some(json!(0)) }
70 }
71
72 pub fn with_reducer(mut self, reducer: Reducer) -> Self {
74 self.reducer = reducer;
75 self
76 }
77
78 pub fn with_default(mut self, default: Value) -> Self {
80 self.default = Some(default);
81 self
82 }
83}
84
85#[derive(Clone, Default)]
87pub struct StateSchema {
88 pub channels: HashMap<String, Channel>,
90}
91
92impl StateSchema {
93 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn builder() -> StateSchemaBuilder {
100 StateSchemaBuilder::default()
101 }
102
103 pub fn simple(channels: &[&str]) -> Self {
105 let mut schema = Self::new();
106 for name in channels {
107 schema.channels.insert((*name).to_string(), Channel::new(name));
108 }
109 schema
110 }
111
112 pub fn get_reducer(&self, channel: &str) -> &Reducer {
114 self.channels.get(channel).map(|c| &c.reducer).unwrap_or(&Reducer::Overwrite)
115 }
116
117 pub fn get_default(&self, channel: &str) -> Option<&Value> {
119 self.channels.get(channel).and_then(|c| c.default.as_ref())
120 }
121
122 pub fn apply_update(&self, state: &mut State, key: &str, value: Value) {
124 let reducer = self.get_reducer(key);
125 let current = state.get(key).cloned().unwrap_or(Value::Null);
126
127 let new_value = match reducer {
128 Reducer::Overwrite => value,
129 Reducer::Append => {
130 let mut arr = match current {
131 Value::Array(a) => a,
132 Value::Null => vec![],
133 _ => vec![current],
134 };
135 match value {
136 Value::Array(items) => arr.extend(items),
137 _ => arr.push(value),
138 }
139 Value::Array(arr)
140 }
141 Reducer::Sum => {
142 let current_num = current.as_f64().unwrap_or(0.0);
143 let add_num = value.as_f64().unwrap_or(0.0);
144 json!(current_num + add_num)
145 }
146 Reducer::Custom(f) => f(current, value),
147 };
148
149 state.insert(key.to_string(), new_value);
150 }
151
152 pub fn initialize_state(&self) -> State {
154 let mut state = State::new();
155 for (name, channel) in &self.channels {
156 if let Some(default) = &channel.default {
157 state.insert(name.clone(), default.clone());
158 }
159 }
160 state
161 }
162}
163
164#[derive(Default)]
166pub struct StateSchemaBuilder {
167 channels: HashMap<String, Channel>,
168}
169
170impl StateSchemaBuilder {
171 pub fn channel(mut self, name: &str) -> Self {
173 self.channels.insert(name.to_string(), Channel::new(name));
174 self
175 }
176
177 pub fn list_channel(mut self, name: &str) -> Self {
179 self.channels.insert(name.to_string(), Channel::list(name));
180 self
181 }
182
183 pub fn counter_channel(mut self, name: &str) -> Self {
185 self.channels.insert(name.to_string(), Channel::counter(name));
186 self
187 }
188
189 pub fn channel_with_reducer(mut self, name: &str, reducer: Reducer) -> Self {
191 self.channels.insert(name.to_string(), Channel::new(name).with_reducer(reducer));
192 self
193 }
194
195 pub fn channel_with_default(mut self, name: &str, default: Value) -> Self {
197 self.channels.insert(name.to_string(), Channel::new(name).with_default(default));
198 self
199 }
200
201 pub fn build(self) -> StateSchema {
203 StateSchema { channels: self.channels }
204 }
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct Checkpoint {
210 pub thread_id: String,
212 pub checkpoint_id: String,
214 pub state: State,
216 pub step: usize,
218 pub pending_nodes: Vec<String>,
220 pub metadata: HashMap<String, Value>,
222 pub created_at: chrono::DateTime<chrono::Utc>,
224}
225
226impl Checkpoint {
227 pub fn new(thread_id: &str, state: State, step: usize, pending_nodes: Vec<String>) -> Self {
229 Self {
230 thread_id: thread_id.to_string(),
231 checkpoint_id: uuid::Uuid::new_v4().to_string(),
232 state,
233 step,
234 pending_nodes,
235 metadata: HashMap::new(),
236 created_at: chrono::Utc::now(),
237 }
238 }
239
240 pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
242 self.metadata.insert(key.to_string(), value);
243 self
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_overwrite_reducer() {
        let schema = StateSchema::simple(&["value"]);
        let mut state = State::new();

        schema.apply_update(&mut state, "value", json!(1));
        assert_eq!(state.get("value"), Some(&json!(1)));

        schema.apply_update(&mut state, "value", json!(2));
        assert_eq!(state.get("value"), Some(&json!(2)));
    }

    #[test]
    fn test_unregistered_key_defaults_to_overwrite() {
        let schema = StateSchema::new();
        let mut state = State::new();

        assert!(matches!(schema.get_reducer("missing"), Reducer::Overwrite));
        assert_eq!(schema.get_default("missing"), None);

        schema.apply_update(&mut state, "missing", json!("a"));
        schema.apply_update(&mut state, "missing", json!("b"));
        assert_eq!(state.get("missing"), Some(&json!("b")));
    }

    #[test]
    fn test_append_reducer() {
        let schema = StateSchema::builder().list_channel("messages").build();
        let mut state = schema.initialize_state();

        schema.apply_update(&mut state, "messages", json!({"role": "user", "content": "hi"}));
        assert_eq!(state.get("messages"), Some(&json!([{"role": "user", "content": "hi"}])));

        schema.apply_update(
            &mut state,
            "messages",
            json!([{"role": "assistant", "content": "hello"}]),
        );
        assert_eq!(
            state.get("messages"),
            Some(&json!([
                {"role": "user", "content": "hi"},
                {"role": "assistant", "content": "hello"}
            ]))
        );
    }

    #[test]
    fn test_append_on_absent_key_and_existing_scalar() {
        let schema = StateSchema::builder().list_channel("items").build();

        // Key never initialised: treated as an empty list.
        let mut state = State::new();
        schema.apply_update(&mut state, "items", json!(1));
        assert_eq!(state.get("items"), Some(&json!([1])));

        // Existing non-array value becomes the first element.
        let mut state = State::new();
        state.insert("items".to_string(), json!("first"));
        schema.apply_update(&mut state, "items", json!(["second", "third"]));
        assert_eq!(state.get("items"), Some(&json!(["first", "second", "third"])));
    }

    #[test]
    fn test_sum_reducer() {
        let schema = StateSchema::builder().counter_channel("count").build();
        let mut state = schema.initialize_state();

        assert_eq!(state.get("count"), Some(&json!(0)));

        schema.apply_update(&mut state, "count", json!(5));
        assert_eq!(state.get("count"), Some(&json!(5.0)));

        schema.apply_update(&mut state, "count", json!(3));
        assert_eq!(state.get("count"), Some(&json!(8.0)));
    }

    #[test]
    fn test_custom_reducer() {
        let schema = StateSchema::builder()
            .channel_with_reducer(
                "max",
                Reducer::Custom(Arc::new(|a, b| {
                    let a_num = a.as_f64().unwrap_or(f64::MIN);
                    let b_num = b.as_f64().unwrap_or(f64::MIN);
                    json!(a_num.max(b_num))
                })),
            )
            .build();
        let mut state = State::new();

        schema.apply_update(&mut state, "max", json!(5));
        schema.apply_update(&mut state, "max", json!(3));
        schema.apply_update(&mut state, "max", json!(8));
        assert_eq!(state.get("max"), Some(&json!(8.0)));
    }

    #[test]
    fn test_initialize_state_only_seeds_channels_with_defaults() {
        let schema = StateSchema::builder()
            .channel("plain")
            .list_channel("messages")
            .counter_channel("count")
            .channel_with_default("mode", json!("auto"))
            .build();
        let state = schema.initialize_state();

        assert_eq!(state.len(), 3);
        assert!(!state.contains_key("plain"));
        assert_eq!(state.get("messages"), Some(&json!([])));
        assert_eq!(state.get("count"), Some(&json!(0)));
        assert_eq!(state.get("mode"), Some(&json!("auto")));
    }

    #[test]
    fn test_channel_builders() {
        let channel = Channel::new("x").with_reducer(Reducer::Sum).with_default(json!(10));
        assert_eq!(channel.name, "x");
        assert!(matches!(channel.reducer, Reducer::Sum));
        assert_eq!(channel.default, Some(json!(10)));
    }

    #[test]
    fn test_reducer_default_and_debug() {
        assert!(matches!(Reducer::default(), Reducer::Overwrite));
        assert_eq!(format!("{:?}", Reducer::Append), "Append");
        let custom = Reducer::Custom(Arc::new(|current: Value, _incoming: Value| current));
        assert_eq!(format!("{:?}", custom), "Custom");
    }

    #[test]
    fn test_checkpoint_new_and_metadata() {
        let mut state = State::new();
        state.insert("k".to_string(), json!(1));
        let cp = Checkpoint::new("thread-1", state.clone(), 4, vec!["node_a".to_string()])
            .with_metadata("source", json!("test"));

        assert_eq!(cp.thread_id, "thread-1");
        assert_eq!(cp.state, state);
        assert_eq!(cp.step, 4);
        assert_eq!(cp.pending_nodes, vec!["node_a".to_string()]);
        assert_eq!(cp.metadata.get("source"), Some(&json!("test")));

        let other = Checkpoint::new("thread-1", State::new(), 4, Vec::new());
        assert_ne!(cp.checkpoint_id, other.checkpoint_id);
    }
}