Skip to main content

atomr_streams/
stream_ref.rs

1//! `SourceRef[T]` / `SinkRef[T]` — handles to streams that can cross
2//! process boundaries.
3//!
4//! Phase 12.9 of `docs/full-port-plan.md`. Akka.NET parity:
5//! `Akka.Streams.StreamRefs.{SourceRef, SinkRef}`. The wire-level
6//! transport (sequence numbers, demand windows, retransmission) is
7//! a follow-on; this module ships the in-process scaffolding that
8//! lets a `Source<T>` be advertised over an mpsc channel and pulled
9//! by a remote attacher.
10//!
11//! For Phase 5.D / Phase 6.D's wire integration, the channel handles
12//! get serialized as `RemoteEnvelope`s; both ends use the same
13//! `SourceRefHandle` shape so the local-only and cross-process
14//! flavours share an API.
15
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::Arc;
18
19use futures::stream::StreamExt;
20use tokio::sync::mpsc;
21
22use crate::source::Source;
23
24/// Producer-side advertisement of a `Source<T>`. The owner pumps
25/// elements; consumers subscribe via [`SourceRefHandle::take_source`].
26pub struct SourceRefHandle<T: Send + 'static> {
27    /// Globally-unique stream ref id (unique per node).
28    pub id: u64,
29    receiver: parking_lot::Mutex<Option<mpsc::Receiver<T>>>,
30}
31
32impl<T: Send + 'static> SourceRefHandle<T> {
33    /// Advertise `source` as a stream ref. Returns the handle the
34    /// caller serializes/sends to the consumer side.
35    pub fn advertise(source: Source<T>, buffer: usize) -> Self {
36        let id = next_ref_id();
37        let buffer = buffer.max(1);
38        let (tx, rx) = mpsc::channel::<T>(buffer);
39        let mut inner = source.into_boxed();
40        tokio::spawn(async move {
41            while let Some(item) = inner.next().await {
42                if tx.send(item).await.is_err() {
43                    return;
44                }
45            }
46        });
47        Self { id, receiver: parking_lot::Mutex::new(Some(rx)) }
48    }
49
50    /// Take the consumer source. Calling more than once yields
51    /// `Source::empty()` (the receiver only exists once).
52    pub fn take_source(&self) -> Source<T> {
53        match self.receiver.lock().take() {
54            Some(rx) => Source { inner: rx_to_stream(rx).boxed() },
55            None => Source::empty(),
56        }
57    }
58}
59
60fn rx_to_stream<T: Send + 'static>(rx: mpsc::Receiver<T>) -> futures::stream::BoxStream<'static, T> {
61    futures::stream::unfold(rx, |mut rx| async move { rx.recv().await.map(|item| (item, rx)) }).boxed()
62}
63
64/// Consumer-side advertisement of a `Sink<T>`. The producer attaches
65/// a source via [`SinkRefHandle::attach`] which then pumps into the
66/// receiver-owned stream.
67pub struct SinkRefHandle<T: Send + 'static> {
68    pub id: u64,
69    sender: mpsc::Sender<T>,
70    receiver: parking_lot::Mutex<Option<mpsc::Receiver<T>>>,
71}
72
73impl<T: Send + 'static> SinkRefHandle<T> {
74    pub fn new(buffer: usize) -> Self {
75        let buffer = buffer.max(1);
76        let (tx, rx) = mpsc::channel::<T>(buffer);
77        Self { id: next_ref_id(), sender: tx, receiver: parking_lot::Mutex::new(Some(rx)) }
78    }
79
80    /// Producer-side: attach `source` so its elements drain into the
81    /// sink. Multiple attaches are merged.
82    pub fn attach(&self, source: Source<T>) {
83        let tx = self.sender.clone();
84        let mut inner = source.into_boxed();
85        tokio::spawn(async move {
86            while let Some(item) = inner.next().await {
87                if tx.send(item).await.is_err() {
88                    return;
89                }
90            }
91        });
92    }
93
94    /// Consumer-side: take the source that drains every attached
95    /// producer.
96    pub fn take_source(&self) -> Source<T> {
97        match self.receiver.lock().take() {
98            Some(rx) => Source { inner: rx_to_stream(rx).boxed() },
99            None => Source::empty(),
100        }
101    }
102}
103
/// Hand out the next stream ref id.
///
/// Ids come from a single process-wide counter that starts at 1 and grows
/// by one per call, so no two calls in the same process return the same
/// value.
fn next_ref_id() -> u64 {
    // Only the uniqueness of the returned values matters here, so relaxed
    // ordering on the counter is sufficient.
    static COUNTER: AtomicU64 = AtomicU64::new(1);
    COUNTER.fetch_add(1, Ordering::Relaxed)
}
108
109// `Arc` re-export so callers can pass handles between actors.
110pub type SourceRef<T> = Arc<SourceRefHandle<T>>;
111pub type SinkRef<T> = Arc<SinkRefHandle<T>>;
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::time::Duration;

    /// Elements advertised through a source ref reach the consumer intact
    /// and in order.
    #[tokio::test]
    async fn source_ref_round_trips_elements() {
        let s = Source::from_iter(vec![1, 2, 3, 4]);
        let handle: SourceRef<i32> = Arc::new(SourceRefHandle::advertise(s, 16));
        let consumed = Sink::collect(handle.take_source()).await;
        assert_eq!(consumed, vec![1, 2, 3, 4]);
    }

    /// The receiver can only be taken once; a second take must produce a
    /// source that completes without yielding anything.
    #[tokio::test]
    async fn source_ref_take_twice_yields_empty_second() {
        let s = Source::from_iter(vec![1]);
        let handle: SourceRef<i32> = Arc::new(SourceRefHandle::advertise(s, 1));
        let _ = handle.take_source();
        // A timeout must fail the test: falling back to a default `Vec`
        // would make a second source that never completes look "empty".
        let v = tokio::time::timeout(Duration::from_secs(1), Sink::collect(handle.take_source()))
            .await
            .expect("second take_source should complete without hanging");
        assert!(v.is_empty());
    }

    /// Every source attached to a sink ref ends up in the single merged
    /// consumer stream.
    #[tokio::test]
    async fn sink_ref_aggregates_attached_sources() {
        let sink: SinkRef<i32> = Arc::new(SinkRefHandle::new(16));
        sink.attach(Source::from_iter(vec![1, 2, 3]));
        sink.attach(Source::from_iter(vec![10, 20]));
        let merged = sink.take_source();
        // Drop the handle so its retained sender is released — without
        // this the merged source never sees `Closed` and we'd hang.
        drop(sink);
        let mut got = Sink::collect(merged).await;
        // The interleaving of the two pumps is not fixed, so compare the
        // sorted contents.
        got.sort_unstable();
        assert_eq!(got, vec![1, 2, 3, 10, 20]);
    }

    /// Two handles created in the same process never share an id.
    #[tokio::test]
    async fn ref_ids_are_unique_per_node() {
        let s1: SourceRef<i32> = Arc::new(SourceRefHandle::advertise(Source::from_iter(vec![1]), 1));
        let s2: SourceRef<i32> = Arc::new(SourceRefHandle::advertise(Source::from_iter(vec![1]), 1));
        assert_ne!(s1.id, s2.id);
    }
}