atomr_streams/
stream_ref.rs1use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16
17use futures::stream::StreamExt;
18use tokio::sync::mpsc;
19
20use crate::source::Source;
21
22pub struct SourceRefHandle<T: Send + 'static> {
25 pub id: u64,
27 receiver: parking_lot::Mutex<Option<mpsc::Receiver<T>>>,
28}
29
30impl<T: Send + 'static> SourceRefHandle<T> {
31 pub fn advertise(source: Source<T>, buffer: usize) -> Self {
34 let id = next_ref_id();
35 let buffer = buffer.max(1);
36 let (tx, rx) = mpsc::channel::<T>(buffer);
37 let mut inner = source.into_boxed();
38 tokio::spawn(async move {
39 while let Some(item) = inner.next().await {
40 if tx.send(item).await.is_err() {
41 return;
42 }
43 }
44 });
45 Self { id, receiver: parking_lot::Mutex::new(Some(rx)) }
46 }
47
48 pub fn take_source(&self) -> Source<T> {
51 match self.receiver.lock().take() {
52 Some(rx) => Source { inner: rx_to_stream(rx).boxed() },
53 None => Source::empty(),
54 }
55 }
56}
57
58fn rx_to_stream<T: Send + 'static>(rx: mpsc::Receiver<T>) -> futures::stream::BoxStream<'static, T> {
59 futures::stream::unfold(rx, |mut rx| async move { rx.recv().await.map(|item| (item, rx)) }).boxed()
60}
61
62pub struct SinkRefHandle<T: Send + 'static> {
66 pub id: u64,
67 sender: mpsc::Sender<T>,
68 receiver: parking_lot::Mutex<Option<mpsc::Receiver<T>>>,
69}
70
71impl<T: Send + 'static> SinkRefHandle<T> {
72 pub fn new(buffer: usize) -> Self {
73 let buffer = buffer.max(1);
74 let (tx, rx) = mpsc::channel::<T>(buffer);
75 Self { id: next_ref_id(), sender: tx, receiver: parking_lot::Mutex::new(Some(rx)) }
76 }
77
78 pub fn attach(&self, source: Source<T>) {
81 let tx = self.sender.clone();
82 let mut inner = source.into_boxed();
83 tokio::spawn(async move {
84 while let Some(item) = inner.next().await {
85 if tx.send(item).await.is_err() {
86 return;
87 }
88 }
89 });
90 }
91
92 pub fn take_source(&self) -> Source<T> {
95 match self.receiver.lock().take() {
96 Some(rx) => Source { inner: rx_to_stream(rx).boxed() },
97 None => Source::empty(),
98 }
99 }
100}
101
/// Returns a fresh identifier for a source or sink ref handle.
///
/// Identifiers start at 1 and increase by one per call across the whole process.
fn next_ref_id() -> u64 {
    // Only uniqueness matters here, not ordering relative to other memory operations, so a
    // relaxed atomic read-modify-write is sufficient.
    static COUNTER: AtomicU64 = AtomicU64::new(1);
    COUNTER.fetch_add(1, Ordering::Relaxed)
}
106
/// Reference-counted, shareable [`SourceRefHandle`].
pub type SourceRef<T> = Arc<SourceRefHandle<T>>;
/// Reference-counted, shareable [`SinkRefHandle`].
pub type SinkRef<T> = Arc<SinkRefHandle<T>>;
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::time::Duration;

    #[tokio::test]
    async fn source_ref_round_trips_elements() {
        let s = Source::from_iter(vec![1, 2, 3, 4]);
        let handle: SourceRef<i32> = Arc::new(SourceRefHandle::advertise(s, 16));
        let consumed = Sink::collect(handle.take_source()).await;
        assert_eq!(consumed, vec![1, 2, 3, 4]);
    }

    #[tokio::test]
    async fn source_ref_take_twice_yields_empty_second() {
        let s = Source::from_iter(vec![1]);
        let handle: SourceRef<i32> = Arc::new(SourceRefHandle::advertise(s, 1));
        let _ = handle.take_source();
        // The timeout only guards against a hang. Timing out must fail the test rather than
        // be mistaken for an empty result.
        let v = tokio::time::timeout(Duration::from_secs(1), Sink::collect(handle.take_source()))
            .await
            .expect("second take_source should complete immediately");
        assert!(v.is_empty());
    }

    #[tokio::test]
    async fn sink_ref_aggregates_attached_sources() {
        let sink: SinkRef<i32> = Arc::new(SinkRefHandle::new(16));
        sink.attach(Source::from_iter(vec![1, 2, 3]));
        sink.attach(Source::from_iter(vec![10, 20]));
        let merged = sink.take_source();
        // The handle owns a sender too; the merged stream only terminates once it is dropped
        // and both forwarding tasks have finished.
        drop(sink);
        let mut got = Sink::collect(merged).await;
        got.sort();
        assert_eq!(got, vec![1, 2, 3, 10, 20]);
    }

    #[tokio::test]
    async fn sink_ref_take_twice_yields_empty_second() {
        let sink: SinkRef<i32> = Arc::new(SinkRefHandle::new(4));
        let _first = sink.take_source();
        let second = tokio::time::timeout(Duration::from_secs(1), Sink::collect(sink.take_source()))
            .await
            .expect("second take_source should complete immediately");
        assert!(second.is_empty());
    }

    #[tokio::test]
    async fn ref_ids_are_unique_per_node() {
        let s1: SourceRef<i32> = Arc::new(SourceRefHandle::advertise(Source::from_iter(vec![1]), 1));
        let s2: SourceRef<i32> = Arc::new(SourceRefHandle::advertise(Source::from_iter(vec![1]), 1));
        assert_ne!(s1.id, s2.id);
    }
}