adk_graph/
state.rs

1//! State management for graph execution
2//!
3//! Provides typed state with reducers for controlling how updates are merged.
4
5use serde::{Deserialize, Serialize};
6use serde_json::{json, Value};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10/// Graph state - a map of channel names to values
11pub type State = HashMap<String, Value>;
12
13/// Reducer determines how state updates are merged
14#[derive(Clone)]
15pub enum Reducer {
16    /// Replace the value entirely (default)
17    Overwrite,
18    /// Append to a list
19    Append,
20    /// Sum numeric values
21    Sum,
22    /// Custom merge function
23    Custom(Arc<dyn Fn(Value, Value) -> Value + Send + Sync>),
24}
25
26// Cannot derive Default because of the Custom variant with Arc<dyn Fn>
27#[allow(clippy::derivable_impls)]
28impl Default for Reducer {
29    fn default() -> Self {
30        Self::Overwrite
31    }
32}
33
34impl std::fmt::Debug for Reducer {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            Self::Overwrite => write!(f, "Overwrite"),
38            Self::Append => write!(f, "Append"),
39            Self::Sum => write!(f, "Sum"),
40            Self::Custom(_) => write!(f, "Custom"),
41        }
42    }
43}
44
45/// Channel definition for a state field
46#[derive(Clone)]
47pub struct Channel {
48    /// Channel name
49    pub name: String,
50    /// Reducer for merging updates
51    pub reducer: Reducer,
52    /// Default value
53    pub default: Option<Value>,
54}
55
56impl Channel {
57    /// Create a new channel with overwrite semantics
58    pub fn new(name: &str) -> Self {
59        Self { name: name.to_string(), reducer: Reducer::Overwrite, default: None }
60    }
61
62    /// Create a list channel with append semantics
63    pub fn list(name: &str) -> Self {
64        Self { name: name.to_string(), reducer: Reducer::Append, default: Some(json!([])) }
65    }
66
67    /// Create a counter channel with sum semantics
68    pub fn counter(name: &str) -> Self {
69        Self { name: name.to_string(), reducer: Reducer::Sum, default: Some(json!(0)) }
70    }
71
72    /// Set the reducer
73    pub fn with_reducer(mut self, reducer: Reducer) -> Self {
74        self.reducer = reducer;
75        self
76    }
77
78    /// Set the default value
79    pub fn with_default(mut self, default: Value) -> Self {
80        self.default = Some(default);
81        self
82    }
83}
84
85/// State schema defines channels and their reducers
86#[derive(Clone, Default)]
87pub struct StateSchema {
88    /// Channel definitions
89    pub channels: HashMap<String, Channel>,
90}
91
92impl StateSchema {
93    /// Create a new empty schema
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    /// Create a schema builder
99    pub fn builder() -> StateSchemaBuilder {
100        StateSchemaBuilder::default()
101    }
102
103    /// Create a simple schema with just channel names (all overwrite)
104    pub fn simple(channels: &[&str]) -> Self {
105        let mut schema = Self::new();
106        for name in channels {
107            schema.channels.insert((*name).to_string(), Channel::new(name));
108        }
109        schema
110    }
111
112    /// Get the reducer for a channel
113    pub fn get_reducer(&self, channel: &str) -> &Reducer {
114        self.channels.get(channel).map(|c| &c.reducer).unwrap_or(&Reducer::Overwrite)
115    }
116
117    /// Get the default value for a channel
118    pub fn get_default(&self, channel: &str) -> Option<&Value> {
119        self.channels.get(channel).and_then(|c| c.default.as_ref())
120    }
121
122    /// Apply an update to state using the appropriate reducer
123    pub fn apply_update(&self, state: &mut State, key: &str, value: Value) {
124        let reducer = self.get_reducer(key);
125        let current = state.get(key).cloned().unwrap_or(Value::Null);
126
127        let new_value = match reducer {
128            Reducer::Overwrite => value,
129            Reducer::Append => {
130                let mut arr = match current {
131                    Value::Array(a) => a,
132                    Value::Null => vec![],
133                    _ => vec![current],
134                };
135                match value {
136                    Value::Array(items) => arr.extend(items),
137                    _ => arr.push(value),
138                }
139                Value::Array(arr)
140            }
141            Reducer::Sum => {
142                let current_num = current.as_f64().unwrap_or(0.0);
143                let add_num = value.as_f64().unwrap_or(0.0);
144                json!(current_num + add_num)
145            }
146            Reducer::Custom(f) => f(current, value),
147        };
148
149        state.insert(key.to_string(), new_value);
150    }
151
152    /// Initialize state with default values
153    pub fn initialize_state(&self) -> State {
154        let mut state = State::new();
155        for (name, channel) in &self.channels {
156            if let Some(default) = &channel.default {
157                state.insert(name.clone(), default.clone());
158            }
159        }
160        state
161    }
162}
163
164/// Builder for StateSchema
165#[derive(Default)]
166pub struct StateSchemaBuilder {
167    channels: HashMap<String, Channel>,
168}
169
170impl StateSchemaBuilder {
171    /// Add a channel with overwrite semantics
172    pub fn channel(mut self, name: &str) -> Self {
173        self.channels.insert(name.to_string(), Channel::new(name));
174        self
175    }
176
177    /// Add a channel with append semantics (for lists)
178    pub fn list_channel(mut self, name: &str) -> Self {
179        self.channels.insert(name.to_string(), Channel::list(name));
180        self
181    }
182
183    /// Add a counter channel with sum semantics
184    pub fn counter_channel(mut self, name: &str) -> Self {
185        self.channels.insert(name.to_string(), Channel::counter(name));
186        self
187    }
188
189    /// Add a channel with custom reducer
190    pub fn channel_with_reducer(mut self, name: &str, reducer: Reducer) -> Self {
191        self.channels.insert(name.to_string(), Channel::new(name).with_reducer(reducer));
192        self
193    }
194
195    /// Add a channel with default value
196    pub fn channel_with_default(mut self, name: &str, default: Value) -> Self {
197        self.channels.insert(name.to_string(), Channel::new(name).with_default(default));
198        self
199    }
200
201    /// Build the schema
202    pub fn build(self) -> StateSchema {
203        StateSchema { channels: self.channels }
204    }
205}
206
207/// Checkpoint data structure for persistence
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct Checkpoint {
210    /// Thread identifier
211    pub thread_id: String,
212    /// Unique checkpoint ID
213    pub checkpoint_id: String,
214    /// State at this checkpoint
215    pub state: State,
216    /// Step number
217    pub step: usize,
218    /// Nodes pending execution
219    pub pending_nodes: Vec<String>,
220    /// Additional metadata
221    pub metadata: HashMap<String, Value>,
222    /// Creation timestamp
223    pub created_at: chrono::DateTime<chrono::Utc>,
224}
225
226impl Checkpoint {
227    /// Create a new checkpoint
228    pub fn new(thread_id: &str, state: State, step: usize, pending_nodes: Vec<String>) -> Self {
229        Self {
230            thread_id: thread_id.to_string(),
231            checkpoint_id: uuid::Uuid::new_v4().to_string(),
232            state,
233            step,
234            pending_nodes,
235            metadata: HashMap::new(),
236            created_at: chrono::Utc::now(),
237        }
238    }
239
240    /// Add metadata to the checkpoint
241    pub fn with_metadata(mut self, key: &str, value: Value) -> Self {
242        self.metadata.insert(key.to_string(), value);
243        self
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_overwrite_reducer() {
        let schema = StateSchema::simple(&["value"]);
        let mut state = State::new();

        schema.apply_update(&mut state, "value", json!(1));
        assert_eq!(state.get("value"), Some(&json!(1)));

        schema.apply_update(&mut state, "value", json!(2));
        assert_eq!(state.get("value"), Some(&json!(2)));
    }

    #[test]
    fn test_append_reducer() {
        let schema = StateSchema::builder().list_channel("messages").build();
        let mut state = schema.initialize_state();

        schema.apply_update(&mut state, "messages", json!({"role": "user", "content": "hi"}));
        assert_eq!(state.get("messages"), Some(&json!([{"role": "user", "content": "hi"}])));

        schema.apply_update(
            &mut state,
            "messages",
            json!([{"role": "assistant", "content": "hello"}]),
        );
        assert_eq!(
            state.get("messages"),
            Some(&json!([
                {"role": "user", "content": "hi"},
                {"role": "assistant", "content": "hello"}
            ]))
        );
    }

    #[test]
    fn test_append_wraps_existing_scalar() {
        let schema = StateSchema::builder().list_channel("items").build();
        let mut state = State::new();
        state.insert("items".to_string(), json!("first"));

        schema.apply_update(&mut state, "items", json!("second"));
        assert_eq!(state.get("items"), Some(&json!(["first", "second"])));
    }

    #[test]
    fn test_append_to_missing_key_starts_empty() {
        let schema = StateSchema::builder().list_channel("items").build();
        let mut state = State::new();

        schema.apply_update(&mut state, "items", json!([1, 2]));
        assert_eq!(state.get("items"), Some(&json!([1, 2])));
    }

    #[test]
    fn test_sum_reducer() {
        let schema = StateSchema::builder().counter_channel("count").build();
        let mut state = schema.initialize_state();

        assert_eq!(state.get("count"), Some(&json!(0)));

        schema.apply_update(&mut state, "count", json!(5));
        assert_eq!(state.get("count"), Some(&json!(5.0)));

        schema.apply_update(&mut state, "count", json!(3));
        assert_eq!(state.get("count"), Some(&json!(8.0)));
    }

    #[test]
    fn test_sum_treats_non_numeric_as_zero() {
        let schema = StateSchema::builder().counter_channel("count").build();
        let mut state = schema.initialize_state();

        schema.apply_update(&mut state, "count", json!("not a number"));
        assert_eq!(state.get("count"), Some(&json!(0.0)));
    }

    #[test]
    fn test_custom_reducer() {
        let schema = StateSchema::builder()
            .channel_with_reducer(
                "max",
                Reducer::Custom(Arc::new(|a, b| {
                    let a_num = a.as_f64().unwrap_or(f64::MIN);
                    let b_num = b.as_f64().unwrap_or(f64::MIN);
                    json!(a_num.max(b_num))
                })),
            )
            .build();
        let mut state = State::new();

        schema.apply_update(&mut state, "max", json!(5));
        schema.apply_update(&mut state, "max", json!(3));
        schema.apply_update(&mut state, "max", json!(8));
        assert_eq!(state.get("max"), Some(&json!(8.0)));
    }

    #[test]
    fn test_undeclared_key_overwrites() {
        let schema = StateSchema::builder().list_channel("messages").build();
        assert!(matches!(schema.get_reducer("unknown"), Reducer::Overwrite));

        let mut state = State::new();
        schema.apply_update(&mut state, "unknown", json!([1]));
        schema.apply_update(&mut state, "unknown", json!([2]));
        assert_eq!(state.get("unknown"), Some(&json!([2])));
    }

    #[test]
    fn test_initialize_state_only_includes_defaults() {
        let schema = StateSchema::builder()
            .channel("plain")
            .channel_with_default("mode", json!("fast"))
            .build();
        let state = schema.initialize_state();

        assert_eq!(state.len(), 1);
        assert_eq!(state.get("mode"), Some(&json!("fast")));
        assert!(schema.get_default("plain").is_none());
    }

    #[test]
    fn test_checkpoint_metadata() {
        let checkpoint = Checkpoint::new("thread-1", State::new(), 3, vec!["next".to_string()])
            .with_metadata("source", json!("test"));

        assert_eq!(checkpoint.thread_id, "thread-1");
        assert_eq!(checkpoint.step, 3);
        assert_eq!(checkpoint.pending_nodes, vec!["next".to_string()]);
        assert_eq!(checkpoint.metadata.get("source"), Some(&json!("test")));
    }
}