rust_langgraph/channels/
topic.rs1use super::BaseChannel;
6use crate::errors::Result;
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9use std::marker::PhantomData;
10
/// Channel state that keeps every value written to it, in insertion order.
///
/// The element type `T` is chosen by the caller; trait bounds are placed on
/// the `impl` blocks that need them rather than on the struct itself.
#[derive(Clone, Debug)]
pub struct Topic<T> {
    /// Values collected so far, oldest first.
    values: Vec<T>,
    /// Zero-sized marker for the element type `T`.
    _phantom: PhantomData<T>,
}
36
37impl<T> Topic<T> {
38 pub fn new() -> Self {
40 Self {
41 values: Vec::new(),
42 _phantom: PhantomData,
43 }
44 }
45
46 pub fn with_values(values: Vec<T>) -> Self {
48 Self {
49 values,
50 _phantom: PhantomData,
51 }
52 }
53
54 pub fn len(&self) -> usize {
56 self.values.len()
57 }
58
59 pub fn is_empty(&self) -> bool {
61 self.values.is_empty()
62 }
63}
64
65impl<T> Default for Topic<T> {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl<T> BaseChannel for Topic<T>
72where
73 T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Debug + 'static,
74{
75 fn get(&self) -> Result<Option<serde_json::Value>> {
76 if self.values.is_empty() {
77 Ok(None)
78 } else {
79 Ok(Some(serde_json::to_value(&self.values)?))
80 }
81 }
82
83 fn update(&mut self, values: Vec<serde_json::Value>) -> Result<()> {
84 for value in values {
85 let typed_value: T = serde_json::from_value(value)?;
86 self.values.push(typed_value);
87 }
88 Ok(())
89 }
90
91 fn checkpoint(&self) -> Result<serde_json::Value> {
92 serde_json::to_value(&self.values).map_err(Into::into)
93 }
94
95 fn from_checkpoint(data: serde_json::Value) -> Result<Box<dyn BaseChannel>> {
96 let values: Vec<T> = serde_json::from_value(data)?;
97 Ok(Box::new(Self::with_values(values)))
98 }
99
100 fn type_name(&self) -> &'static str {
101 "Topic"
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh topic reads as empty; one update is visible through `len` and `get`.
    #[test]
    fn test_topic_basic() {
        let mut topic = Topic::<i32>::new();
        let initial = topic.get().unwrap();
        assert!(initial.is_none());
        assert_eq!(topic.len(), 0);

        let batch = vec![serde_json::json!(1)];
        topic.update(batch).unwrap();
        assert_eq!(topic.len(), 1);

        let snapshot = topic.get().unwrap().unwrap();
        let decoded: Vec<i32> = serde_json::from_value(snapshot).unwrap();
        assert_eq!(decoded, vec![1]);
    }

    /// Successive updates append to the existing values in write order.
    #[test]
    fn test_topic_accumulation() {
        let mut topic = Topic::<String>::new();

        let batches = vec![
            vec![serde_json::json!("first")],
            vec![serde_json::json!("second")],
            vec![serde_json::json!("third"), serde_json::json!("fourth")],
        ];
        for batch in batches {
            topic.update(batch).unwrap();
        }

        let snapshot = topic.get().unwrap().unwrap();
        let decoded: Vec<String> = serde_json::from_value(snapshot).unwrap();
        assert_eq!(decoded, vec!["first", "second", "third", "fourth"]);
    }

    /// A checkpoint restored through `from_checkpoint` yields the same values.
    #[test]
    fn test_topic_checkpoint() {
        let mut topic = Topic::<i32>::new();
        let batch = vec![serde_json::json!(1), serde_json::json!(2)];
        topic.update(batch).unwrap();

        let saved = topic.checkpoint().unwrap();
        let restored = Topic::<i32>::from_checkpoint(saved).unwrap();

        let snapshot = restored.get().unwrap().unwrap();
        let decoded: Vec<i32> = serde_json::from_value(snapshot).unwrap();
        assert_eq!(decoded, vec![1, 2]);
    }
}