atomr_streams/
stream_ref.rs1use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::Arc;
18
19use futures::stream::StreamExt;
20use tokio::sync::mpsc;
21
22use crate::source::Source;
23
24pub struct SourceRefHandle<T: Send + 'static> {
27 pub id: u64,
29 receiver: parking_lot::Mutex<Option<mpsc::Receiver<T>>>,
30}
31
32impl<T: Send + 'static> SourceRefHandle<T> {
33 pub fn advertise(source: Source<T>, buffer: usize) -> Self {
36 let id = next_ref_id();
37 let buffer = buffer.max(1);
38 let (tx, rx) = mpsc::channel::<T>(buffer);
39 let mut inner = source.into_boxed();
40 tokio::spawn(async move {
41 while let Some(item) = inner.next().await {
42 if tx.send(item).await.is_err() {
43 return;
44 }
45 }
46 });
47 Self { id, receiver: parking_lot::Mutex::new(Some(rx)) }
48 }
49
50 pub fn take_source(&self) -> Source<T> {
53 match self.receiver.lock().take() {
54 Some(rx) => Source { inner: rx_to_stream(rx).boxed() },
55 None => Source::empty(),
56 }
57 }
58}
59
60fn rx_to_stream<T: Send + 'static>(rx: mpsc::Receiver<T>) -> futures::stream::BoxStream<'static, T> {
61 futures::stream::unfold(rx, |mut rx| async move { rx.recv().await.map(|item| (item, rx)) }).boxed()
62}
63
64pub struct SinkRefHandle<T: Send + 'static> {
68 pub id: u64,
69 sender: mpsc::Sender<T>,
70 receiver: parking_lot::Mutex<Option<mpsc::Receiver<T>>>,
71}
72
73impl<T: Send + 'static> SinkRefHandle<T> {
74 pub fn new(buffer: usize) -> Self {
75 let buffer = buffer.max(1);
76 let (tx, rx) = mpsc::channel::<T>(buffer);
77 Self { id: next_ref_id(), sender: tx, receiver: parking_lot::Mutex::new(Some(rx)) }
78 }
79
80 pub fn attach(&self, source: Source<T>) {
83 let tx = self.sender.clone();
84 let mut inner = source.into_boxed();
85 tokio::spawn(async move {
86 while let Some(item) = inner.next().await {
87 if tx.send(item).await.is_err() {
88 return;
89 }
90 }
91 });
92 }
93
94 pub fn take_source(&self) -> Source<T> {
97 match self.receiver.lock().take() {
98 Some(rx) => Source { inner: rx_to_stream(rx).boxed() },
99 None => Source::empty(),
100 }
101 }
102}
103
/// Hands out a process-wide unique identifier for a new stream ref.
///
/// Ids start at 1 and are taken from a single atomic counter. `Relaxed`
/// ordering is enough because only the uniqueness of each returned value
/// matters, not its ordering relative to other memory operations.
fn next_ref_id() -> u64 {
    static REF_ID_COUNTER: AtomicU64 = AtomicU64::new(1);
    REF_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
108
109pub type SourceRef<T> = Arc<SourceRefHandle<T>>;
111pub type SinkRef<T> = Arc<SinkRefHandle<T>>;
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::time::Duration;

    /// Every element of an advertised source arrives at the consumer, in order.
    #[tokio::test]
    async fn source_ref_round_trips_elements() {
        let expected = vec![1, 2, 3, 4];
        let source_ref: SourceRef<i32> =
            Arc::new(SourceRefHandle::advertise(Source::from_iter(expected.clone()), 16));
        let received = Sink::collect(source_ref.take_source()).await;
        assert_eq!(received, expected);
    }

    /// Only the first `take_source` call gets the data; later calls are empty.
    #[tokio::test]
    async fn source_ref_take_twice_yields_empty_second() {
        let source_ref: SourceRef<i32> =
            Arc::new(SourceRefHandle::advertise(Source::from_iter(vec![1]), 1));
        // Claim and immediately discard the first take.
        drop(source_ref.take_source());
        let second_take = Sink::collect(source_ref.take_source());
        let collected = tokio::time::timeout(Duration::from_millis(20), second_take)
            .await
            .unwrap_or_default();
        assert!(collected.is_empty());
    }

    /// Elements from several attached sources are merged into one stream.
    #[tokio::test]
    async fn sink_ref_aggregates_attached_sources() {
        let sink_ref: SinkRef<i32> = Arc::new(SinkRefHandle::new(16));
        for batch in [vec![1, 2, 3], vec![10, 20]] {
            sink_ref.attach(Source::from_iter(batch));
        }
        let merged = sink_ref.take_source();
        // The handle holds its own sender. Dropping the last handle lets the
        // merged stream end once both forwarding tasks have finished.
        drop(sink_ref);
        let mut received = Sink::collect(merged).await;
        received.sort_unstable();
        assert_eq!(received, vec![1, 2, 3, 10, 20]);
    }

    /// Two handles created in the same process never share an id.
    #[tokio::test]
    async fn ref_ids_are_unique_per_node() {
        let make_ref = || -> SourceRef<i32> {
            Arc::new(SourceRefHandle::advertise(Source::from_iter(vec![1]), 1))
        };
        let (first, second) = (make_ref(), make_ref());
        assert_ne!(first.id, second.id);
    }
}