// langgraph_core_rs/channels/binop.rs

use langgraph_checkpoint::error::ChannelError;
use parking_lot::RwLock;
use serde_json::Value as JsonValue;

use super::base::Channel;

6pub type ReducerFn = fn(&JsonValue, &JsonValue) -> JsonValue;
8
9pub struct BinaryOperatorAggregate {
14 key: String,
15 value: RwLock<Option<JsonValue>>,
16 reducer: ReducerFn,
17}
18
19impl BinaryOperatorAggregate {
20 pub fn new(key: impl Into<String>, reducer: ReducerFn) -> Self {
21 Self {
22 key: key.into(),
23 value: RwLock::new(None),
24 reducer,
25 }
26 }
27}
28
29impl Channel for BinaryOperatorAggregate {
30 fn checkpoint(&self) -> Option<JsonValue> {
31 self.value.read().clone()
32 }
33
34 fn from_checkpoint(&self, checkpoint: Option<&JsonValue>) -> Box<dyn Channel> {
35 Box::new(Self {
36 key: self.key.clone(),
37 value: RwLock::new(checkpoint.cloned()),
38 reducer: self.reducer,
39 })
40 }
41
42 fn update(&self, values: &[JsonValue]) -> Result<bool, ChannelError> {
43 if values.is_empty() {
44 return Ok(false);
45 }
46
47 let mut guard = self.value.write();
48 let mut seen_overwrite = false;
49
50 for val in values {
51 if let Some(obj) = val.as_object() {
53 if let Some(overwrite_val) = obj.get("__overwrite__") {
54 if seen_overwrite {
55 return Err(ChannelError::InvalidUpdate(
56 "Received multiple Overwrite values in a single update".to_string(),
57 ));
58 }
59 *guard = Some(overwrite_val.clone());
60 seen_overwrite = true;
61 continue;
62 }
63 }
64
65 if seen_overwrite {
67 continue;
68 }
69
70 match guard.as_ref() {
71 Some(current) => {
72 let new_val = (self.reducer)(current, val);
73 *guard = Some(new_val);
74 }
75 None => {
76 *guard = Some(val.clone());
77 }
78 }
79 }
80 Ok(true)
81 }
82
83 fn get(&self) -> Result<JsonValue, ChannelError> {
84 self.value
85 .read()
86 .clone()
87 .ok_or(ChannelError::EmptyChannel)
88 }
89
90 fn is_available(&self) -> bool {
91 self.value.read().is_some()
92 }
93
94 fn clone_channel(&self) -> Box<dyn Channel> {
95 Box::new(Self {
96 key: self.key.clone(),
97 value: RwLock::new(self.value.read().clone()),
98 reducer: self.reducer,
99 })
100 }
101
102 fn name(&self) -> &str {
103 &self.key
104 }
105}
106
107pub fn append_reducer(current: &JsonValue, update: &JsonValue) -> JsonValue {
109 let mut result = match current {
110 JsonValue::Array(arr) => arr.clone(),
111 other => vec![other.clone()],
112 };
113 match update {
114 JsonValue::Array(arr) => result.extend(arr.iter().cloned()),
115 other => result.push(other.clone()),
116 }
117 JsonValue::Array(result)
118}
119
120pub fn merge_reducer(current: &JsonValue, update: &JsonValue) -> JsonValue {
122 match (current, update) {
123 (JsonValue::Object(curr), JsonValue::Object(upd)) => {
124 let mut merged = curr.clone();
125 for (k, v) in upd {
126 merged.insert(k.clone(), v.clone());
127 }
128 JsonValue::Object(merged)
129 }
130 _ => update.clone(),
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_append_reducer() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1, 2])]).unwrap();
        ch.update(&[json!([3, 4])]).unwrap();
        assert_eq!(ch.get().unwrap(), json!([1, 2, 3, 4]));
    }

    #[test]
    fn test_merge_reducer() {
        let ch = BinaryOperatorAggregate::new("state", merge_reducer);
        ch.update(&[json!({"a": 1})]).unwrap();
        ch.update(&[json!({"b": 2})]).unwrap();
        assert_eq!(ch.get().unwrap(), json!({"a": 1, "b": 2}));
    }

    #[test]
    fn test_overwrite() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1, 2])]).unwrap();
        ch.update(&[json!({"__overwrite__": [99]})]).unwrap();
        assert_eq!(ch.get().unwrap(), json!([99]));
    }

    #[test]
    fn test_checkpoint_restore() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1, 2])]).unwrap();

        let cp = ch.checkpoint();
        let restored = ch.from_checkpoint(cp.as_ref());
        assert_eq!(restored.get().unwrap(), json!([1, 2]));

        restored.update(&[json!([3])]).unwrap();
        assert_eq!(restored.get().unwrap(), json!([1, 2, 3]));
    }

    // An empty batch must not change state and reports "not updated".
    #[test]
    fn test_empty_update_is_noop() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        assert!(!ch.update(&[]).unwrap());
        assert!(!ch.is_available());
        assert!(ch.get().is_err());
        assert_eq!(ch.checkpoint(), None);
    }

    // The very first value is stored verbatim, not passed through the reducer.
    #[test]
    fn test_first_value_stored_verbatim() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!(5)]).unwrap();
        assert_eq!(ch.get().unwrap(), json!(5));
        ch.update(&[json!(6)]).unwrap();
        assert_eq!(ch.get().unwrap(), json!([5, 6]));
    }

    #[test]
    fn test_batch_is_folded_in_order() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        assert!(ch.update(&[json!([1]), json!([2]), json!(3)]).unwrap());
        assert_eq!(ch.get().unwrap(), json!([1, 2, 3]));
    }

    #[test]
    fn test_values_after_overwrite_are_ignored() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1]), json!({"__overwrite__": [9]}), json!([2])])
            .unwrap();
        assert_eq!(ch.get().unwrap(), json!([9]));
    }

    #[test]
    fn test_multiple_overwrites_rejected() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        let result = ch.update(&[
            json!({"__overwrite__": [1]}),
            json!({"__overwrite__": [2]}),
        ]);
        assert!(matches!(result, Err(ChannelError::InvalidUpdate(_))));
    }

    #[test]
    fn test_clone_channel_is_independent() {
        let ch = BinaryOperatorAggregate::new("items", append_reducer);
        ch.update(&[json!([1])]).unwrap();

        let copy = ch.clone_channel();
        copy.update(&[json!([2])]).unwrap();

        assert_eq!(ch.get().unwrap(), json!([1]));
        assert_eq!(copy.get().unwrap(), json!([1, 2]));
        assert_eq!(copy.name(), "items");
    }
}