agent_chain_core/tracers/memory_stream.rs

//! Memory stream for communication between async tasks.
//!
//! This module lets two async tasks communicate over a Tokio `mpsc` channel.
//! The writer and reader can live in the same task or in different tasks.
//! Mirrors `langchain_core.tracers.memory_stream`.
6
7use std::sync::Arc;
8use tokio::sync::mpsc;
9
10/// A sender for the memory stream.
11#[derive(Debug)]
12pub struct SendStream<T> {
13    sender: mpsc::UnboundedSender<Option<T>>,
14}
15
16impl<T> SendStream<T> {
17    /// Send an item to the stream.
18    ///
19    /// # Arguments
20    ///
21    /// * `item` - The item to send.
22    ///
23    /// # Returns
24    ///
25    /// `Ok(())` if the item was sent successfully, `Err` if the receiver was dropped.
26    pub async fn send(&self, item: T) -> Result<(), mpsc::error::SendError<Option<T>>> {
27        self.send_nowait(item)
28    }
29
30    /// Send an item to the stream without waiting.
31    ///
32    /// # Arguments
33    ///
34    /// * `item` - The item to send.
35    ///
36    /// # Returns
37    ///
38    /// `Ok(())` if the item was sent successfully, `Err` if the receiver was dropped.
39    pub fn send_nowait(&self, item: T) -> Result<(), mpsc::error::SendError<Option<T>>> {
40        self.sender.send(Some(item))
41    }
42
43    /// Close the stream.
44    pub async fn aclose(&self) -> Result<(), mpsc::error::SendError<Option<T>>> {
45        self.close()
46    }
47
48    /// Close the stream.
49    pub fn close(&self) -> Result<(), mpsc::error::SendError<Option<T>>> {
50        self.sender.send(None)
51    }
52}
53
54impl<T> Clone for SendStream<T> {
55    fn clone(&self) -> Self {
56        Self {
57            sender: self.sender.clone(),
58        }
59    }
60}
61
62/// A receiver for the memory stream.
63#[derive(Debug)]
64pub struct ReceiveStream<T> {
65    receiver: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<Option<T>>>>,
66    is_closed: Arc<std::sync::atomic::AtomicBool>,
67}
68
69impl<T> ReceiveStream<T> {
70    /// Check if the stream is closed.
71    pub fn is_closed(&self) -> bool {
72        self.is_closed.load(std::sync::atomic::Ordering::SeqCst)
73    }
74}
75
76impl<T: Send + 'static> ReceiveStream<T> {
77    /// Create an async iterator over the stream.
78    pub fn into_stream(self) -> impl futures::Stream<Item = T> {
79        futures::stream::unfold(self, |state| async move {
80            if state.is_closed() {
81                return None;
82            }
83
84            let mut receiver = state.receiver.lock().await;
85            match receiver.recv().await {
86                Some(Some(item)) => {
87                    drop(receiver);
88                    Some((item, state))
89                }
90                Some(None) | None => {
91                    state
92                        .is_closed
93                        .store(true, std::sync::atomic::Ordering::SeqCst);
94                    None
95                }
96            }
97        })
98    }
99}
100
101/// A memory stream for communication between async tasks.
102///
103/// This stream uses unbounded channels to communicate between tasks.
104/// It is designed for single producer, single consumer scenarios.
105#[derive(Debug)]
106pub struct MemoryStream<T> {
107    sender: mpsc::UnboundedSender<Option<T>>,
108    receiver: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<Option<T>>>>,
109}
110
111impl<T> MemoryStream<T> {
112    /// Create a new memory stream.
113    pub fn new() -> Self {
114        let (sender, receiver) = mpsc::unbounded_channel();
115        Self {
116            sender,
117            receiver: Arc::new(tokio::sync::Mutex::new(receiver)),
118        }
119    }
120
121    /// Get a sender for the stream.
122    pub fn get_send_stream(&self) -> SendStream<T> {
123        SendStream {
124            sender: self.sender.clone(),
125        }
126    }
127
128    /// Get a receiver for the stream.
129    pub fn get_receive_stream(&self) -> ReceiveStream<T> {
130        ReceiveStream {
131            receiver: self.receiver.clone(),
132            is_closed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
133        }
134    }
135}
136
137impl<T> Default for MemoryStream<T> {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143/// A bounded memory stream with a maximum capacity.
144#[derive(Debug)]
145pub struct BoundedMemoryStream<T> {
146    sender: mpsc::Sender<Option<T>>,
147    receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Option<T>>>>,
148}
149
150impl<T> BoundedMemoryStream<T> {
151    /// Create a new bounded memory stream.
152    ///
153    /// # Arguments
154    ///
155    /// * `capacity` - The maximum number of items the stream can hold.
156    pub fn new(capacity: usize) -> Self {
157        let (sender, receiver) = mpsc::channel(capacity);
158        Self {
159            sender,
160            receiver: Arc::new(tokio::sync::Mutex::new(receiver)),
161        }
162    }
163
164    /// Get a sender for the stream.
165    pub fn get_send_stream(&self) -> BoundedSendStream<T> {
166        BoundedSendStream {
167            sender: self.sender.clone(),
168        }
169    }
170
171    /// Get a receiver for the stream.
172    pub fn get_receive_stream(&self) -> BoundedReceiveStream<T> {
173        BoundedReceiveStream {
174            receiver: self.receiver.clone(),
175            is_closed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
176        }
177    }
178}
179
180/// A bounded sender for the memory stream.
181#[derive(Debug, Clone)]
182pub struct BoundedSendStream<T> {
183    sender: mpsc::Sender<Option<T>>,
184}
185
186impl<T> BoundedSendStream<T> {
187    /// Send an item to the stream.
188    ///
189    /// # Arguments
190    ///
191    /// * `item` - The item to send.
192    pub async fn send(&self, item: T) -> Result<(), mpsc::error::SendError<Option<T>>> {
193        self.sender.send(Some(item)).await
194    }
195
196    /// Try to send an item without waiting.
197    pub fn try_send(&self, item: T) -> Result<(), mpsc::error::TrySendError<Option<T>>> {
198        self.sender.try_send(Some(item))
199    }
200
201    /// Close the stream.
202    pub async fn close(&self) -> Result<(), mpsc::error::SendError<Option<T>>> {
203        self.sender.send(None).await
204    }
205}
206
207/// A bounded receiver for the memory stream.
208#[derive(Debug)]
209pub struct BoundedReceiveStream<T> {
210    receiver: Arc<tokio::sync::Mutex<mpsc::Receiver<Option<T>>>>,
211    is_closed: Arc<std::sync::atomic::AtomicBool>,
212}
213
214impl<T> BoundedReceiveStream<T> {
215    /// Check if the stream is closed.
216    pub fn is_closed(&self) -> bool {
217        self.is_closed.load(std::sync::atomic::Ordering::SeqCst)
218    }
219}
220
221impl<T: Send + 'static> BoundedReceiveStream<T> {
222    /// Create an async iterator over the stream.
223    pub fn into_stream(self) -> impl futures::Stream<Item = T> {
224        futures::stream::unfold(self, |state| async move {
225            if state.is_closed() {
226                return None;
227            }
228
229            let mut receiver = state.receiver.lock().await;
230            match receiver.recv().await {
231                Some(Some(item)) => {
232                    drop(receiver);
233                    Some((item, state))
234                }
235                Some(None) | None => {
236                    state
237                        .is_closed
238                        .store(true, std::sync::atomic::Ordering::SeqCst);
239                    None
240                }
241            }
242        })
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Drain a stream to completion, preserving arrival order.
    async fn drain<S>(stream: S) -> Vec<S::Item>
    where
        S: futures::Stream,
    {
        stream.collect().await
    }

    #[tokio::test]
    async fn test_memory_stream_basic() {
        let stream = MemoryStream::<i32>::new();
        let tx = stream.get_send_stream();
        let rx = stream.get_receive_stream();

        for value in 1..=3 {
            tx.send_nowait(value).unwrap();
        }
        tx.close().unwrap();

        assert_eq!(drain(rx.into_stream()).await, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_memory_stream_async_send() {
        let stream = MemoryStream::<String>::new();
        let tx = stream.get_send_stream();
        let rx = stream.get_receive_stream();

        for word in ["hello", "world"] {
            tx.send(word.to_string()).await.unwrap();
        }
        tx.aclose().await.unwrap();

        let expected = vec!["hello".to_string(), "world".to_string()];
        assert_eq!(drain(rx.into_stream()).await, expected);
    }

    #[tokio::test]
    async fn test_bounded_memory_stream() {
        let stream = BoundedMemoryStream::<i32>::new(10);
        let tx = stream.get_send_stream();
        let rx = stream.get_receive_stream();

        tx.send(1).await.unwrap();
        tx.send(2).await.unwrap();
        tx.close().await.unwrap();

        assert_eq!(drain(rx.into_stream()).await, vec![1, 2]);
    }

    #[tokio::test]
    async fn test_send_stream_clone() {
        let stream = MemoryStream::<i32>::new();
        let first = stream.get_send_stream();
        let second = first.clone();
        let rx = stream.get_receive_stream();

        first.send_nowait(1).unwrap();
        second.send_nowait(2).unwrap();
        first.close().unwrap();

        assert_eq!(drain(rx.into_stream()).await, vec![1, 2]);
    }
}