Skip to main content

datum_net/
stream_ref.rs

1//! StreamRefs over QUIC.
2//!
3//! `datum-core` owns the protobuf protocol and state machine. This module is
4//! the carrier: it length-prefixes protobuf frames and pumps them over one
5//! reliable, ordered QUIC bidirectional byte stream.
6
7use std::{
8    collections::VecDeque,
9    sync::{Arc, Mutex, mpsc},
10    thread,
11};
12
13use datum::{
14    NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
15    StreamRefPayload, StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer,
16    StreamRefSettings, StreamResult,
17};
18
19use crate::QuicBidirectionalStream;
20
/// Width of the big-endian `u32` length prefix written in front of every
/// carrier frame.
const FRAME_LEN_BYTES: usize = std::mem::size_of::<u32>();
/// Largest protobuf payload (16 MiB) the inbound frame decoder accepts.
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 << 20;
23
24/// Completion handle for a StreamRefs-over-QUIC carrier.
25#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
26pub struct StreamRefQuicHandle {
27    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
28}
29
30impl StreamRefQuicHandle {
31    pub fn wait(self) -> StreamResult<NotUsed> {
32        self.receiver
33            .recv()
34            .unwrap_or(Err(StreamError::AbruptTermination))
35    }
36
37    #[must_use]
38    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
39        self.receiver.try_recv().ok()
40    }
41}
42
43/// Serves a local `SourceRef` over an accepted or opened QUIC bidi stream.
44pub fn serve_source_ref_over_quic<T>(
45    stream: QuicBidirectionalStream,
46    source_ref: SourceRef<T>,
47    stream_ref_id: StreamRefId,
48    settings: StreamRefSettings,
49) -> StreamResult<StreamRefQuicHandle>
50where
51    T: StreamRefPayload,
52{
53    let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
54    Ok(drive_stream_ref_endpoint(stream, producer))
55}
56
57/// Serves a local `Source` over an accepted or opened QUIC bidi stream.
58pub fn serve_source_over_quic<T, Mat>(
59    stream: QuicBidirectionalStream,
60    source: Source<T, Mat>,
61    stream_ref_id: StreamRefId,
62    settings: StreamRefSettings,
63) -> StreamResult<StreamRefQuicHandle>
64where
65    T: StreamRefPayload,
66    Mat: Send + 'static,
67{
68    let producer = StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?;
69    Ok(drive_stream_ref_endpoint(stream, producer))
70}
71
72/// Creates a local source fed by a remote QUIC StreamRef producer.
73pub fn source_ref_over_quic<T>(
74    stream: QuicBidirectionalStream,
75    stream_ref_id: StreamRefId,
76    settings: StreamRefSettings,
77) -> (Source<T, NotUsed>, StreamRefQuicHandle)
78where
79    T: StreamRefPayload,
80{
81    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
82    let source = consumer.source();
83    let handle = drive_stream_ref_endpoint(stream, consumer);
84    (source, handle)
85}
86
87/// Serves a local `SinkRef` receiver over an accepted or opened QUIC bidi
88/// stream, returning a [`Source`] of inbound elements.
89///
90/// This is the local/receiver side of the SinkRef-over-QUIC pair: the remote
91/// sender pushes elements into a [`sink_ref_over_quic`](fn.sink_ref_over_quic)
92/// `Sink`, and this side surfaces them as a `Source`. The caller runs the
93/// returned source into a local `Sink` (for example `Sink::collect` or a fold).
94///
95/// The wiring is the transport-agnostic consumer seam
96/// ([`StreamRefProtoConsumer`]); the only difference from
97/// [`source_ref_over_quic`] is naming and intent — both surface inbound remote
98/// elements as a local `Source`, so a SinkRef receiver is symmetric with a
99/// SourceRef receiver over the same carrier.
100pub fn serve_sink_ref_over_quic<T>(
101    stream: QuicBidirectionalStream,
102    stream_ref_id: StreamRefId,
103    settings: StreamRefSettings,
104) -> (Source<T, NotUsed>, StreamRefQuicHandle)
105where
106    T: StreamRefPayload,
107{
108    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
109    let source = consumer.source();
110    let handle = drive_stream_ref_endpoint(stream, consumer);
111    (source, handle)
112}
113
114/// Creates a local `Sink` that sends its incoming elements over QUIC to a
115/// remote `SinkRef` receiver.
116///
117/// This is the remote/sender side of the SinkRef-over-QUIC pair: elements
118/// pushed into the returned [`Sink`] are framed as `SequencedOnNext` and sent
119/// to the [`serve_sink_ref_over_quic`] receiver. The materialized value is a
120/// [`StreamCompletion`] that resolves when the producer reaches its terminal
121/// state (all elements sent and acknowledged, the receiver cancelled/failed,
122/// or the carrier failed).
123///
124/// The wiring is the transport-agnostic producer seam
125/// ([`StreamRefProtoProducer`]) in lazy-input mode: the input stream is
126/// attached when the sink is materialized, and the producer waits (without
127/// spinning) until both the input is attached and the remote has subscribed
128/// with demand.
129pub fn sink_ref_over_quic<T>(
130    stream: QuicBidirectionalStream,
131    stream_ref_id: StreamRefId,
132    settings: StreamRefSettings,
133) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
134where
135    T: StreamRefPayload,
136{
137    let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
138    let sink = producer.sink();
139    let handle = drive_stream_ref_endpoint(stream, producer);
140    (sink, handle)
141}
142
143fn drive_stream_ref_endpoint<E>(stream: QuicBidirectionalStream, endpoint: E) -> StreamRefQuicHandle
144where
145    E: StreamRefProtoEndpoint,
146{
147    let (byte_source, byte_sink) = stream.into_parts();
148    let (sender, receiver) = mpsc::channel();
149
150    let outbound_endpoint = endpoint.clone();
151    let outbound_thread = thread::spawn(move || {
152        let result = outbound_frames(outbound_endpoint.clone())
153            .run_with(byte_sink)
154            .and_then(|completion| completion.wait());
155        if let Err(error) = &result {
156            outbound_endpoint.fail_connection(error.clone());
157        }
158        result
159    });
160
161    let inbound_endpoint = endpoint.clone();
162    let inbound_thread = thread::spawn(move || {
163        let result = inbound_frames(byte_source)
164            .run_with(Sink::foreach_result({
165                let inbound_endpoint = inbound_endpoint.clone();
166                move |frame| inbound_endpoint.handle_frame(frame)
167            }))
168            .and_then(|completion| completion.wait());
169        if let Err(error) = &result {
170            inbound_endpoint.fail_connection(error.clone());
171        }
172        result
173    });
174
175    thread::spawn(move || {
176        let outbound = join_carrier_thread(outbound_thread);
177        let inbound = join_carrier_thread(inbound_thread);
178        let result = match (outbound, inbound) {
179            (Err(error), _) => Err(error),
180            (_, Err(error)) => Err(error),
181            (Ok(()), Ok(())) => Ok(NotUsed),
182        };
183        let _ = sender.send(result);
184    });
185
186    StreamRefQuicHandle { receiver }
187}
188
189fn outbound_frames<E>(endpoint: E) -> Source<Vec<u8>, NotUsed>
190where
191    E: StreamRefProtoEndpoint,
192{
193    Source::unfold_resource(
194        move || Ok(endpoint.clone()),
195        |endpoint| match endpoint.next_frame() {
196            Some(Ok(frame)) => Ok(Some(encode_carrier_frame(frame)?)),
197            Some(Err(error)) => Err(error),
198            None => Ok(None),
199        },
200        |_endpoint| Ok(()),
201    )
202}
203
204fn inbound_frames(byte_source: Source<Vec<u8>, NotUsed>) -> Source<StreamRefFrame, NotUsed> {
205    let decoder = Arc::new(Mutex::new(FrameDecoder::default()));
206    byte_source.map_concat_result(move |chunk| {
207        decoder
208            .lock()
209            .expect("stream ref frame decoder poisoned")
210            .push_chunk(chunk)
211    })
212}
213
214fn encode_carrier_frame(frame: StreamRefFrame) -> StreamResult<Vec<u8>> {
215    let payload = frame.encode_to_vec();
216    let len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
217        max: u32::MAX as u64,
218    })?;
219    let mut bytes = Vec::with_capacity(FRAME_LEN_BYTES + payload.len());
220    bytes.extend(len.to_be_bytes());
221    bytes.extend(payload);
222    Ok(bytes)
223}
224
225#[derive(Default)]
226struct FrameDecoder {
227    buffer: VecDeque<u8>,
228}
229
230impl FrameDecoder {
231    fn push_chunk(&mut self, chunk: Vec<u8>) -> StreamResult<Vec<StreamRefFrame>> {
232        self.buffer.extend(chunk);
233        let mut frames = Vec::new();
234        while let Some(len) = self.peek_len()? {
235            if self.buffer.len() < FRAME_LEN_BYTES + len {
236                break;
237            }
238            self.buffer.drain(..FRAME_LEN_BYTES);
239            let payload = self.buffer.drain(..len).collect::<Vec<_>>();
240            frames.push(StreamRefFrame::decode(&payload)?);
241        }
242        Ok(frames)
243    }
244
245    fn peek_len(&self) -> StreamResult<Option<usize>> {
246        if self.buffer.len() < FRAME_LEN_BYTES {
247            return Ok(None);
248        }
249        let mut len = [0_u8; FRAME_LEN_BYTES];
250        for (target, source) in len.iter_mut().zip(self.buffer.iter().take(FRAME_LEN_BYTES)) {
251            *target = *source;
252        }
253        let len = u32::from_be_bytes(len) as usize;
254        if len > MAX_STREAM_REF_FRAME_BYTES {
255            return Err(StreamError::LimitExceeded {
256                max: MAX_STREAM_REF_FRAME_BYTES as u64,
257            });
258        }
259        Ok(Some(len))
260    }
261}
262
263fn join_carrier_thread(handle: thread::JoinHandle<StreamResult<NotUsed>>) -> StreamResult<()> {
264    match handle.join() {
265        Ok(Ok(NotUsed)) => Ok(()),
266        Ok(Err(error)) => Err(error),
267        Err(_) => Err(StreamError::Failed(
268            "StreamRefs QUIC carrier thread panicked".to_owned(),
269        )),
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let original = StreamRefFrame::new(
            StreamRefId::from_u128(1),
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        );
        let wire = encode_carrier_frame(original.clone()).unwrap();
        let (head, tail) = wire.split_at(wire.len() / 2);
        let mut decoder = FrameDecoder::default();

        // Half a frame must not yield anything yet.
        let from_head = decoder.push_chunk(head.to_vec()).unwrap();
        assert!(from_head.is_empty());

        // The remaining bytes complete exactly the original frame.
        let from_tail = decoder.push_chunk(tail.to_vec()).unwrap();
        assert_eq!(from_tail, vec![original]);
    }
}