// agent_chain_core/tracers/memory_stream.rs
use std::sync::Arc;

use tokio::sync::mpsc;

10#[derive(Debug)]
12pub struct SendStream<T> {
13 sender: mpsc::UnboundedSender<Option<T>>,
14}
15
16impl<T> SendStream<T> {
17 pub async fn send(&self, item: T) -> Result<(), mpsc::error::SendError<Option<T>>> {
27 self.send_nowait(item)
28 }
29
30 pub fn send_nowait(&self, item: T) -> Result<(), mpsc::error::SendError<Option<T>>> {
40 self.sender.send(Some(item))
41 }
42
43 pub async fn aclose(&self) -> Result<(), mpsc::error::SendError<Option<T>>> {
45 self.close()
46 }
47
48 pub fn close(&self) -> Result<(), mpsc::error::SendError<Option<T>>> {
50 self.sender.send(None)
51 }
52}
53
54impl<T> Clone for SendStream<T> {
55 fn clone(&self) -> Self {
56 Self {
57 sender: self.sender.clone(),
58 }
59 }
60}
61
62#[derive(Debug)]
64pub struct ReceiveStream<T> {
65 receiver: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<Option<T>>>>,
66 is_closed: Arc<std::sync::atomic::AtomicBool>,
67}
68
69impl<T> ReceiveStream<T> {
70 pub fn is_closed(&self) -> bool {
72 self.is_closed.load(std::sync::atomic::Ordering::SeqCst)
73 }
74}
75
76impl<T: Send + 'static> ReceiveStream<T> {
77 pub fn into_stream(self) -> impl futures::Stream<Item = T> {
79 futures::stream::unfold(self, |state| async move {
80 if state.is_closed() {
81 return None;
82 }
83
84 let mut receiver = state.receiver.lock().await;
85 match receiver.recv().await {
86 Some(Some(item)) => {
87 drop(receiver);
88 Some((item, state))
89 }
90 Some(None) | None => {
91 state
92 .is_closed
93 .store(true, std::sync::atomic::Ordering::SeqCst);
94 None
95 }
96 }
97 })
98 }
99}
100
101#[derive(Debug)]
106pub struct MemoryStream<T> {
107 sender: mpsc::UnboundedSender<Option<T>>,
108 receiver: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<Option<T>>>>,
109}
110
111impl<T> MemoryStream<T> {
112 pub fn new() -> Self {
114 let (sender, receiver) = mpsc::unbounded_channel();
115 Self {
116 sender,
117 receiver: Arc::new(tokio::sync::Mutex::new(receiver)),
118 }
119 }
120
121 pub fn get_send_stream(&self) -> SendStream<T> {
123 SendStream {
124 sender: self.sender.clone(),
125 }
126 }
127
128 pub fn get_receive_stream(&self) -> ReceiveStream<T> {
130 ReceiveStream {
131 receiver: self.receiver.clone(),
132 is_closed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
133 }
134 }
135}
136
137impl<T> Default for MemoryStream<T> {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143#[derive(Debug)]
145pub struct BoundedMemoryStream<T> {
146 sender: mpsc::Sender<Option<T>>,
147 receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Option<T>>>>,
148}
149
150impl<T> BoundedMemoryStream<T> {
151 pub fn new(capacity: usize) -> Self {
157 let (sender, receiver) = mpsc::channel(capacity);
158 Self {
159 sender,
160 receiver: Arc::new(tokio::sync::Mutex::new(receiver)),
161 }
162 }
163
164 pub fn get_send_stream(&self) -> BoundedSendStream<T> {
166 BoundedSendStream {
167 sender: self.sender.clone(),
168 }
169 }
170
171 pub fn get_receive_stream(&self) -> BoundedReceiveStream<T> {
173 BoundedReceiveStream {
174 receiver: self.receiver.clone(),
175 is_closed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
176 }
177 }
178}
179
180#[derive(Debug, Clone)]
182pub struct BoundedSendStream<T> {
183 sender: mpsc::Sender<Option<T>>,
184}
185
186impl<T> BoundedSendStream<T> {
187 pub async fn send(&self, item: T) -> Result<(), mpsc::error::SendError<Option<T>>> {
193 self.sender.send(Some(item)).await
194 }
195
196 pub fn try_send(&self, item: T) -> Result<(), mpsc::error::TrySendError<Option<T>>> {
198 self.sender.try_send(Some(item))
199 }
200
201 pub async fn close(&self) -> Result<(), mpsc::error::SendError<Option<T>>> {
203 self.sender.send(None).await
204 }
205}
206
207#[derive(Debug)]
209pub struct BoundedReceiveStream<T> {
210 receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Option<T>>>>,
211 is_closed: Arc<std::sync::atomic::AtomicBool>,
212}
213
214impl<T> BoundedReceiveStream<T> {
215 pub fn is_closed(&self) -> bool {
217 self.is_closed.load(std::sync::atomic::Ordering::SeqCst)
218 }
219}
220
221impl<T: Send + 'static> BoundedReceiveStream<T> {
222 pub fn into_stream(self) -> impl futures::Stream<Item = T> {
224 futures::stream::unfold(self, |state| async move {
225 if state.is_closed() {
226 return None;
227 }
228
229 let mut receiver = state.receiver.lock().await;
230 match receiver.recv().await {
231 Some(Some(item)) => {
232 drop(receiver);
233 Some((item, state))
234 }
235 Some(None) | None => {
236 state
237 .is_closed
238 .store(true, std::sync::atomic::Ordering::SeqCst);
239 None
240 }
241 }
242 })
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    use std::pin::pin;

    /// Drives `source` to completion and gathers every yielded item in order.
    async fn drain<S: futures::Stream>(source: S) -> Vec<S::Item> {
        let mut source = pin!(source);
        let mut collected = Vec::new();
        while let Some(item) = source.next().await {
            collected.push(item);
        }
        collected
    }

    #[tokio::test]
    async fn test_memory_stream_basic() {
        let stream = MemoryStream::<i32>::new();
        let sender = stream.get_send_stream();
        let receiver = stream.get_receive_stream();

        for value in [1, 2, 3] {
            sender.send_nowait(value).unwrap();
        }
        sender.close().unwrap();

        assert_eq!(drain(receiver.into_stream()).await, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_memory_stream_async_send() {
        let stream = MemoryStream::<String>::new();
        let sender = stream.get_send_stream();
        let receiver = stream.get_receive_stream();

        for word in ["hello", "world"] {
            sender.send(word.to_string()).await.unwrap();
        }
        sender.aclose().await.unwrap();

        assert_eq!(
            drain(receiver.into_stream()).await,
            vec!["hello".to_string(), "world".to_string()]
        );
    }

    #[tokio::test]
    async fn test_bounded_memory_stream() {
        let stream = BoundedMemoryStream::<i32>::new(10);
        let sender = stream.get_send_stream();
        let receiver = stream.get_receive_stream();

        for value in [1, 2] {
            sender.send(value).await.unwrap();
        }
        sender.close().await.unwrap();

        assert_eq!(drain(receiver.into_stream()).await, vec![1, 2]);
    }

    #[tokio::test]
    async fn test_send_stream_clone() {
        let stream = MemoryStream::<i32>::new();
        let sender1 = stream.get_send_stream();
        let sender2 = sender1.clone();
        let receiver = stream.get_receive_stream();

        sender1.send_nowait(1).unwrap();
        sender2.send_nowait(2).unwrap();
        sender1.close().unwrap();

        assert_eq!(drain(receiver.into_stream()).await, vec![1, 2]);
    }
}