Skip to main content

langgraph_core_rs/channels/
topic.rs

1use parking_lot::RwLock;
2use serde_json::Value as JsonValue;
3use langgraph_checkpoint::error::ChannelError;
4use super::base::Channel;
5
6/// PubSub topic channel.
7///
8/// Can accumulate values (`accumulate=true`) or clear each step.
9/// Used for the TASKS channel (holds Send objects).
10pub struct Topic {
11    key: String,
12    values: RwLock<Vec<JsonValue>>,
13    accumulate: bool,
14}
15
16impl Topic {
17    pub fn new(key: impl Into<String>, accumulate: bool) -> Self {
18        Self {
19            key: key.into(),
20            values: RwLock::new(Vec::new()),
21            accumulate,
22        }
23    }
24}
25
26impl Channel for Topic {
27    fn checkpoint(&self) -> Option<JsonValue> {
28        let vals = self.values.read();
29        if vals.is_empty() {
30            None
31        } else {
32            Some(JsonValue::Array(vals.clone()))
33        }
34    }
35
36    fn from_checkpoint(&self, checkpoint: Option<&JsonValue>) -> Box<dyn Channel> {
37        let values = match checkpoint {
38            Some(JsonValue::Array(arr)) => arr.clone(),
39            Some(other) => vec![other.clone()],
40            None => Vec::new(),
41        };
42        Box::new(Self {
43            key: self.key.clone(),
44            values: RwLock::new(values),
45            accumulate: self.accumulate,
46        })
47    }
48
49    fn update(&self, values: &[JsonValue]) -> Result<bool, ChannelError> {
50        if values.is_empty() {
51            return Ok(false);
52        }
53        let mut guard = self.values.write();
54        for val in values {
55            match val {
56                JsonValue::Array(arr) => guard.extend(arr.iter().cloned()),
57                other => guard.push(other.clone()),
58            }
59        }
60        Ok(true)
61    }
62
63    fn get(&self) -> Result<JsonValue, ChannelError> {
64        let vals = self.values.read();
65        if vals.is_empty() {
66            Err(ChannelError::EmptyChannel)
67        } else {
68            Ok(JsonValue::Array(vals.clone()))
69        }
70    }
71
72    fn consume(&self) -> bool {
73        if !self.accumulate {
74            let changed = !self.values.read().is_empty();
75            self.values.write().clear();
76            changed
77        } else {
78            false
79        }
80    }
81
82    fn is_available(&self) -> bool {
83        !self.values.read().is_empty()
84    }
85
86    fn clone_channel(&self) -> Box<dyn Channel> {
87        Box::new(Self {
88            key: self.key.clone(),
89            values: RwLock::new(self.values.read().clone()),
90            accumulate: self.accumulate,
91        })
92    }
93
94    fn name(&self) -> &str {
95        &self.key
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_topic_accumulate() {
        let ch = Topic::new("tasks", true);
        ch.update(&[json!("a")]).unwrap();
        ch.update(&[json!("b")]).unwrap();
        assert_eq!(ch.get().unwrap(), json!(["a", "b"]));
        // consume doesn't clear when accumulate=true, and reports no change
        assert!(!ch.consume());
        assert!(ch.is_available());
    }

    #[test]
    fn test_topic_no_accumulate() {
        let ch = Topic::new("tasks", false);
        ch.update(&[json!("a")]).unwrap();
        ch.update(&[json!("b")]).unwrap();
        assert_eq!(ch.get().unwrap(), json!(["a", "b"]));
        // consume clears when accumulate=false; the change is reported exactly once
        assert!(ch.consume());
        assert!(!ch.is_available());
        assert!(!ch.consume());
    }

    #[test]
    fn test_topic_array_update() {
        let ch = Topic::new("tasks", true);
        ch.update(&[json!(["a", "b"]), json!(["c"])]).unwrap();
        assert_eq!(ch.get().unwrap(), json!(["a", "b", "c"]));
    }

    #[test]
    fn test_topic_empty() {
        let ch = Topic::new("tasks", false);
        assert!(!ch.is_available());
        assert!(ch.checkpoint().is_none());
        assert!(matches!(ch.get(), Err(ChannelError::EmptyChannel)));
        // an empty update is a no-op
        assert!(!ch.update(&[]).unwrap());
        assert!(!ch.is_available());
    }

    #[test]
    fn test_topic_checkpoint_roundtrip() {
        let ch = Topic::new("tasks", true);
        ch.update(&[json!("a"), json!({"k": 1})]).unwrap();
        let cp = ch.checkpoint();
        assert_eq!(cp, Some(json!(["a", {"k": 1}])));

        let restored = ch.from_checkpoint(cp.as_ref());
        assert_eq!(restored.name(), "tasks");
        assert_eq!(restored.get().unwrap(), json!(["a", {"k": 1}]));
        // the accumulate flag carries over to the restored channel
        assert!(!restored.consume());
        assert!(restored.is_available());
    }

    #[test]
    fn test_topic_from_checkpoint_scalar_and_none() {
        let ch = Topic::new("tasks", false);
        let restored = ch.from_checkpoint(Some(&json!("x")));
        assert_eq!(restored.get().unwrap(), json!(["x"]));

        let empty = ch.from_checkpoint(None);
        assert!(!empty.is_available());
    }

    #[test]
    fn test_topic_clone_is_independent() {
        let ch = Topic::new("tasks", false);
        ch.update(&[json!("a")]).unwrap();
        let copy = ch.clone_channel();
        assert!(ch.consume());
        assert!(!ch.is_available());
        // clearing the original must not affect the clone's snapshot
        assert_eq!(copy.get().unwrap(), json!(["a"]));
    }
}