Skip to main content

langgraph_core_rs/channels/
binop.rs

1use parking_lot::RwLock;
2use serde_json::Value as JsonValue;
3use langgraph_checkpoint::error::ChannelError;
4use super::base::Channel;
5
6/// Reducer function type: (current, update) -> new
7pub type ReducerFn = fn(&JsonValue, &JsonValue) -> JsonValue;
8
9/// Applies a binary operator to accumulate values.
10///
11/// Created when a state key uses a reducer (e.g., Annotated[list, add_messages]).
12/// Supports Overwrite to bypass the reducer.
13pub struct BinaryOperatorAggregate {
14    key: String,
15    value: RwLock<Option<JsonValue>>,
16    reducer: ReducerFn,
17}
18
19impl BinaryOperatorAggregate {
20    pub fn new(key: impl Into<String>, reducer: ReducerFn) -> Self {
21        Self {
22            key: key.into(),
23            value: RwLock::new(None),
24            reducer,
25        }
26    }
27}
28
29impl Channel for BinaryOperatorAggregate {
30    fn checkpoint(&self) -> Option<JsonValue> {
31        self.value.read().clone()
32    }
33
34    fn from_checkpoint(&self, checkpoint: Option<&JsonValue>) -> Box<dyn Channel> {
35        Box::new(Self {
36            key: self.key.clone(),
37            value: RwLock::new(checkpoint.cloned()),
38            reducer: self.reducer,
39        })
40    }
41
42    fn update(&self, values: &[JsonValue]) -> Result<bool, ChannelError> {
43        if values.is_empty() {
44            return Ok(false);
45        }
46
47        let mut guard = self.value.write();
48        let mut seen_overwrite = false;
49
50        for val in values {
51            // Check for Overwrite pattern: {"__overwrite__": value}
52            if let Some(obj) = val.as_object() {
53                if let Some(overwrite_val) = obj.get("__overwrite__") {
54                    if seen_overwrite {
55                        return Err(ChannelError::InvalidUpdate(
56                            "Received multiple Overwrite values in a single update".to_string(),
57                        ));
58                    }
59                    *guard = Some(overwrite_val.clone());
60                    seen_overwrite = true;
61                    continue;
62                }
63            }
64
65            // If we've seen an Overwrite, skip non-Overwrite values
66            if seen_overwrite {
67                continue;
68            }
69
70            match guard.as_ref() {
71                Some(current) => {
72                    let new_val = (self.reducer)(current, val);
73                    *guard = Some(new_val);
74                }
75                None => {
76                    *guard = Some(val.clone());
77                }
78            }
79        }
80        Ok(true)
81    }
82
83    fn get(&self) -> Result<JsonValue, ChannelError> {
84        self.value
85            .read()
86            .clone()
87            .ok_or(ChannelError::EmptyChannel)
88    }
89
90    fn is_available(&self) -> bool {
91        self.value.read().is_some()
92    }
93
94    fn clone_channel(&self) -> Box<dyn Channel> {
95        Box::new(Self {
96            key: self.key.clone(),
97            value: RwLock::new(self.value.read().clone()),
98            reducer: self.reducer,
99        })
100    }
101
102    fn name(&self) -> &str {
103        &self.key
104    }
105}
106
107/// Common reducer: append arrays
108pub fn append_reducer(current: &JsonValue, update: &JsonValue) -> JsonValue {
109    let mut result = match current {
110        JsonValue::Array(arr) => arr.clone(),
111        other => vec![other.clone()],
112    };
113    match update {
114        JsonValue::Array(arr) => result.extend(arr.iter().cloned()),
115        other => result.push(other.clone()),
116    }
117    JsonValue::Array(result)
118}
119
120/// Common reducer: merge objects
121pub fn merge_reducer(current: &JsonValue, update: &JsonValue) -> JsonValue {
122    match (current, update) {
123        (JsonValue::Object(curr), JsonValue::Object(upd)) => {
124            let mut merged = curr.clone();
125            for (k, v) in upd {
126                merged.insert(k.clone(), v.clone());
127            }
128            JsonValue::Object(merged)
129        }
130        _ => update.clone(),
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_append_reducer() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1, 2])]).unwrap();
        ch.update(&[json!([3, 4])]).unwrap();
        assert_eq!(ch.get().unwrap(), json!([1, 2, 3, 4]));
    }

    #[test]
    fn test_append_reducer_wraps_scalars() {
        assert_eq!(append_reducer(&json!(1), &json!(2)), json!([1, 2]));
        assert_eq!(append_reducer(&json!([1]), &json!(2)), json!([1, 2]));
    }

    #[test]
    fn test_merge_reducer() {
        let ch = BinaryOperatorAggregate::new("state", merge_reducer);
        ch.update(&[json!({"a": 1})]).unwrap();
        ch.update(&[json!({"b": 2})]).unwrap();
        assert_eq!(ch.get().unwrap(), json!({"a": 1, "b": 2}));
    }

    #[test]
    fn test_merge_reducer_conflicts_and_non_objects() {
        assert_eq!(merge_reducer(&json!({"a": 1}), &json!({"a": 2})), json!({"a": 2}));
        assert_eq!(merge_reducer(&json!({"a": 1}), &json!(5)), json!(5));
    }

    #[test]
    fn test_empty_update_is_noop() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        assert!(!ch.update(&[]).unwrap());
        assert!(!ch.is_available());
        assert!(ch.get().is_err());
    }

    #[test]
    fn test_first_update_stored_as_is() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        assert!(ch.update(&[json!(7)]).unwrap());
        assert_eq!(ch.get().unwrap(), json!(7));
    }

    #[test]
    fn test_batch_reduced_in_order() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1]), json!(2), json!([3, 4])]).unwrap();
        assert_eq!(ch.get().unwrap(), json!([1, 2, 3, 4]));
    }

    #[test]
    fn test_overwrite() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1, 2])]).unwrap();
        ch.update(&[json!({"__overwrite__": [99]})]).unwrap();
        assert_eq!(ch.get().unwrap(), json!([99]));
    }

    #[test]
    fn test_overwrite_discards_rest_of_batch() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1, 2])]).unwrap();
        ch.update(&[json!([3]), json!({"__overwrite__": [99]}), json!([4])])
            .unwrap();
        assert_eq!(ch.get().unwrap(), json!([99]));
    }

    #[test]
    fn test_multiple_overwrites_rejected() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        let result = ch.update(&[
            json!({"__overwrite__": [1]}),
            json!({"__overwrite__": [2]}),
        ]);
        assert!(matches!(result, Err(ChannelError::InvalidUpdate(_))));
    }

    #[test]
    fn test_checkpoint_restore() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1, 2])]).unwrap();

        let cp = ch.checkpoint();
        let restored = ch.from_checkpoint(cp.as_ref());
        assert_eq!(restored.get().unwrap(), json!([1, 2]));

        restored.update(&[json!([3])]).unwrap();
        assert_eq!(restored.get().unwrap(), json!([1, 2, 3]));
    }

    #[test]
    fn test_restore_from_empty_checkpoint() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        let restored = ch.from_checkpoint(None);
        assert!(!restored.is_available());
        assert_eq!(restored.name(), "items");
    }

    #[test]
    fn test_clone_channel_is_independent() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1])]).unwrap();
        let copy = ch.clone_channel();
        copy.update(&[json!([2])]).unwrap();
        assert_eq!(ch.get().unwrap(), json!([1]));
        assert_eq!(copy.get().unwrap(), json!([1, 2]));
    }
}