Skip to main content

datum_net/
stream_ref.rs

1//! Remote StreamRefs carriers.
2//!
3//! `datum-core` owns the protobuf protocol and state machine. This module is
4//! the carrier layer: it length-prefixes protobuf frames and pumps them over
5//! reliable, ordered bidirectional byte streams such as QUIC or plaintext TCP.
6
7use std::{
8    collections::VecDeque,
9    future::Future,
10    net::SocketAddr,
11    sync::{
12        Arc, Mutex, OnceLock,
13        atomic::{AtomicBool, Ordering},
14        mpsc,
15    },
16    thread,
17};
18
19use bytes::{Buf, BytesMut};
20use datum::{
21    NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
22    StreamRefMessage, StreamRefOutbound, StreamRefPayload, StreamRefPayloadBatch,
23    StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer, StreamRefSettings,
24    StreamResult,
25    actor::stream_ref_proto::{StreamRefOutboundPoll, StreamRefProtoEndpointWake},
26};
27use tokio::{
28    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
29    net::{TcpListener, TcpStream, ToSocketAddrs},
30    runtime::{Handle, Runtime},
31    sync::mpsc as tokio_mpsc,
32};
33
34use crate::QuicBidirectionalStream;
35
// Width of a frame's length prefix — presumably in bytes, going by the name;
// the encoder/decoder that consumes it is outside this view.
const FRAME_LEN_BYTES: usize = 4;
// 16 MiB. Passed as the second argument when polling the endpoint for
// outbound data (presumably a byte budget — TODO confirm), and used by the
// TCP task as the cap on buffered encoded bytes before it stops draining.
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 * 1024 * 1024;
// Read chunk size for the TCP carrier; also the initial capacity of its
// read, write, encode, and tail buffers.
const STREAM_REF_TCP_CHUNK_SIZE: usize = 8192;
// Passed as the first argument when polling the endpoint for outbound data
// (presumably a per-poll frame budget — TODO confirm).
const STREAM_REF_OUTBOUND_BATCH_FRAMES: usize = 64;

// Carrier wire v1 keeps protobuf for control frames. A high bit in the
// length-prefix marks compact SequencedOnNext batches: version, kind,
// stream_ref_id, first seqNr, count, then length-prefixed payloads.
const COMPACT_FRAME_FLAG: u32 = 0x8000_0000;
// Complement of `COMPACT_FRAME_FLAG`: the low 31 bits of the prefix.
const COMPACT_FRAME_LEN_MASK: u32 = 0x7fff_ffff;
const COMPACT_FRAME_VERSION: u8 = 1;
const COMPACT_SEQUENCED_ON_NEXT_BATCH: u8 = 1;
// NOTE(review): the terms appear to line up with the header fields listed
// above — version (1), kind (1), stream_ref_id (16), first seqNr (8),
// count (2) — confirm against the encoder.
const COMPACT_BATCH_HEADER_BYTES: usize = 1 + 1 + 16 + 8 + 2;
// Presumably the width of each payload's length prefix inside a compact
// batch — TODO confirm against the encoder.
const COMPACT_BATCH_ELEMENT_LEN_BYTES: usize = 4;
50
/// Read-side settings handed to a carrier's inbound pump.
#[derive(Clone, Copy)]
struct CarrierReadMode {
    /// Size of each read chunk; [`CarrierReadMode::new`] rejects zero.
    chunk_size: usize,
    /// Forwarded to the inbound feed helper together with the rest of the
    /// mode. NOTE(review): presumably controls whether partially filled
    /// chunks are emitted right away — confirm against that helper.
    emit_available: bool,
    /// When set, reaching EOF marks the endpoint's connection as failed.
    fail_on_eof: bool,
}

impl CarrierReadMode {
    /// Builds a read mode from its three settings.
    ///
    /// # Panics
    ///
    /// Panics when `chunk_size` is zero.
    fn new(chunk_size: usize, emit_available: bool, fail_on_eof: bool) -> Self {
        if chunk_size == 0 {
            panic!("chunk size must be greater than zero");
        }
        Self {
            chunk_size,
            emit_available,
            fail_on_eof,
        }
    }
}
68
/// Counts selected StreamRefs protocol messages successfully written by a
/// carrier endpoint.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefProtocolMessageCounts {
    /// Number of `CumulativeDemand` frames.
    pub cumulative_demand: u64,
    /// Number of `SequencedOnNext` messages: one per single frame, plus the
    /// element count of every sequenced batch.
    pub sequenced_on_next: u64,
    /// Number of `Ack` frames.
    pub ack: u64,
}
77
78/// Shared collector for StreamRefs protocol message-count diagnostics.
79///
80/// The carrier records counts only after the encoded frame has been written
81/// successfully. This is intentionally transport-level instrumentation; it does
82/// not change the protobuf protocol or wire format.
83#[derive(Clone, Default)]
84pub struct StreamRefProtocolDiagnostics {
85    counts: Arc<Mutex<StreamRefProtocolMessageCounts>>,
86}
87
88impl StreamRefProtocolDiagnostics {
89    #[must_use]
90    pub fn new() -> Self {
91        Self::default()
92    }
93
94    #[must_use]
95    pub fn snapshot(&self) -> StreamRefProtocolMessageCounts {
96        *self
97            .counts
98            .lock()
99            .expect("stream ref protocol diagnostics poisoned")
100    }
101
102    fn record_written_outbound(&self, outbound: &StreamRefOutbound) {
103        self.record_counts(outbound_counts(outbound));
104    }
105
106    fn record_counts(&self, delta: StreamRefProtocolMessageCounts) {
107        if delta == StreamRefProtocolMessageCounts::default() {
108            return;
109        }
110        let mut counts = self
111            .counts
112            .lock()
113            .expect("stream ref protocol diagnostics poisoned");
114        counts.cumulative_demand = counts
115            .cumulative_demand
116            .saturating_add(delta.cumulative_demand);
117        counts.sequenced_on_next = counts
118            .sequenced_on_next
119            .saturating_add(delta.sequenced_on_next);
120        counts.ack = counts.ack.saturating_add(delta.ack);
121    }
122}
123
124fn outbound_counts(outbound: &StreamRefOutbound) -> StreamRefProtocolMessageCounts {
125    let mut counts = StreamRefProtocolMessageCounts::default();
126    match outbound {
127        StreamRefOutbound::Frame(frame) => match &frame.message {
128            StreamRefMessage::CumulativeDemand { .. } => {
129                counts.cumulative_demand = 1;
130            }
131            StreamRefMessage::SequencedOnNext { .. } => {
132                counts.sequenced_on_next = 1;
133            }
134            StreamRefMessage::Ack => {
135                counts.ack = 1;
136            }
137            StreamRefMessage::OnSubscribeHandshake
138            | StreamRefMessage::RemoteStreamCompleted { .. }
139            | StreamRefMessage::RemoteStreamFailure { .. } => {}
140        },
141        StreamRefOutbound::SequencedBatch(batch) => {
142            counts.sequenced_on_next = batch.count() as u64;
143        }
144    }
145    counts
146}
147
/// Counts for one encoded outbound item that is queued but not yet fully
/// written to the socket.
#[derive(Clone, Copy)]
struct PendingDiagnostic {
    /// Encoded bytes of this item that still have to be written.
    remaining: usize,
    /// Counts to record once `remaining` has been written out.
    counts: StreamRefProtocolMessageCounts,
}
153
154/// Completion handle for a StreamRefs-over-QUIC carrier.
155#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
156pub struct StreamRefQuicHandle {
157    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
158}
159
160impl StreamRefQuicHandle {
161    pub fn wait(self) -> StreamResult<NotUsed> {
162        self.receiver
163            .recv()
164            .unwrap_or(Err(StreamError::AbruptTermination))
165    }
166
167    #[must_use]
168    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
169        self.receiver.try_recv().ok()
170    }
171}
172
/// Address information for the local TCP listener behind a StreamRefs-over-TCP
/// producer endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefTcpBinding {
    local_addr: SocketAddr,
}

impl StreamRefTcpBinding {
    /// Returns the socket address stored in this binding.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
185
186/// Completion handle for a StreamRefs-over-TCP carrier.
187#[must_use = "wait for the TCP StreamRefs carrier to observe completion or failure"]
188pub struct StreamRefTcpHandle {
189    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
190}
191
192impl StreamRefTcpHandle {
193    pub fn wait(self) -> StreamResult<NotUsed> {
194        self.receiver
195            .recv()
196            .unwrap_or(Err(StreamError::AbruptTermination))
197    }
198
199    #[must_use]
200    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
201        self.receiver.try_recv().ok()
202    }
203}
204
205/// Serves a local `SourceRef` over an accepted or opened QUIC bidi stream.
206pub fn serve_source_ref_over_quic<T>(
207    stream: QuicBidirectionalStream,
208    source_ref: SourceRef<T>,
209    stream_ref_id: StreamRefId,
210    settings: StreamRefSettings,
211) -> StreamResult<StreamRefQuicHandle>
212where
213    T: StreamRefPayload,
214{
215    let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
216    Ok(drive_stream_ref_endpoint_over_quic(stream, producer, None))
217}
218
219/// Serves a local `Source` over an accepted or opened QUIC bidi stream.
220pub fn serve_source_over_quic<T, Mat>(
221    stream: QuicBidirectionalStream,
222    source: Source<T, Mat>,
223    stream_ref_id: StreamRefId,
224    settings: StreamRefSettings,
225) -> StreamResult<StreamRefQuicHandle>
226where
227    T: StreamRefPayload,
228    Mat: Send + 'static,
229{
230    let producer = StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?;
231    Ok(drive_stream_ref_endpoint_over_quic(stream, producer, None))
232}
233
234/// Creates a local source fed by a remote QUIC StreamRef producer.
235pub fn source_ref_over_quic<T>(
236    stream: QuicBidirectionalStream,
237    stream_ref_id: StreamRefId,
238    settings: StreamRefSettings,
239) -> (Source<T, NotUsed>, StreamRefQuicHandle)
240where
241    T: StreamRefPayload,
242{
243    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
244    let source = consumer.source();
245    let handle = drive_stream_ref_endpoint_over_quic(stream, consumer, None);
246    (source, handle)
247}
248
249/// Serves a local `SinkRef` receiver over an accepted or opened QUIC bidi
250/// stream, returning a [`Source`] of inbound elements.
251///
252/// This is the local/receiver side of the SinkRef-over-QUIC pair: the remote
253/// sender pushes elements into a [`sink_ref_over_quic`](fn.sink_ref_over_quic)
254/// `Sink`, and this side surfaces them as a `Source`. The caller runs the
255/// returned source into a local `Sink` (for example `Sink::collect` or a fold).
256pub fn serve_sink_ref_over_quic<T>(
257    stream: QuicBidirectionalStream,
258    stream_ref_id: StreamRefId,
259    settings: StreamRefSettings,
260) -> (Source<T, NotUsed>, StreamRefQuicHandle)
261where
262    T: StreamRefPayload,
263{
264    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
265    let source = consumer.source();
266    let handle = drive_stream_ref_endpoint_over_quic(stream, consumer, None);
267    (source, handle)
268}
269
270/// Creates a local `Sink` that sends its incoming elements over QUIC to a
271/// remote `SinkRef` receiver.
272pub fn sink_ref_over_quic<T>(
273    stream: QuicBidirectionalStream,
274    stream_ref_id: StreamRefId,
275    settings: StreamRefSettings,
276) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
277where
278    T: StreamRefPayload,
279{
280    let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
281    let sink = producer.sink();
282    let handle = drive_stream_ref_endpoint_over_quic(stream, producer, None);
283    (sink, handle)
284}
285
286/// Serves a local `SourceRef` over a one-shot plaintext TCP listener.
287///
288/// The listener binds immediately, accepts one connection, then runs the same
289/// StreamRefs protocol used by the QUIC carrier. The remote receiver should use
290/// [`source_ref_over_tcp`] to open the TCP connection and send the initial
291/// subscribe+demand frames.
292pub fn serve_source_ref_over_tcp<T, A>(
293    addr: A,
294    source_ref: SourceRef<T>,
295    stream_ref_id: StreamRefId,
296    settings: StreamRefSettings,
297) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
298where
299    T: StreamRefPayload,
300    A: ToSocketAddrs + Send + 'static,
301{
302    serve_source_ref_over_tcp_with_diagnostics(addr, source_ref, stream_ref_id, settings, None)
303}
304
305pub fn serve_source_ref_over_tcp_with_diagnostics<T, A>(
306    addr: A,
307    source_ref: SourceRef<T>,
308    stream_ref_id: StreamRefId,
309    settings: StreamRefSettings,
310    diagnostics: Option<StreamRefProtocolDiagnostics>,
311) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
312where
313    T: StreamRefPayload,
314    A: ToSocketAddrs + Send + 'static,
315{
316    let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
317    let (listener, binding, handle) = bind_tcp_listener(addr)?;
318    Ok((
319        binding,
320        drive_stream_ref_endpoint_over_tcp_listener(listener, handle, producer, diagnostics),
321    ))
322}
323
324/// Serves a local `SourceRef` over an already-connected Tokio TCP stream.
325///
326/// This is the stream-shaped counterpart to [`serve_source_ref_over_tcp`],
327/// intended for callers that own listener lifecycle separately, such as a
328/// benchmark server that reuses one bound listener across operations.
329pub fn serve_source_ref_over_tcp_stream<T>(
330    stream: TcpStream,
331    source_ref: SourceRef<T>,
332    stream_ref_id: StreamRefId,
333    settings: StreamRefSettings,
334) -> StreamResult<StreamRefTcpHandle>
335where
336    T: StreamRefPayload,
337{
338    serve_source_ref_over_tcp_stream_with_diagnostics(
339        stream,
340        source_ref,
341        stream_ref_id,
342        settings,
343        None,
344    )
345}
346
347pub fn serve_source_ref_over_tcp_stream_with_diagnostics<T>(
348    stream: TcpStream,
349    source_ref: SourceRef<T>,
350    stream_ref_id: StreamRefId,
351    settings: StreamRefSettings,
352    diagnostics: Option<StreamRefProtocolDiagnostics>,
353) -> StreamResult<StreamRefTcpHandle>
354where
355    T: StreamRefPayload,
356{
357    let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
358    let handle = current_tokio_handle()?;
359    Ok(drive_stream_ref_endpoint_over_tcp_stream(
360        stream,
361        handle,
362        producer,
363        diagnostics,
364    ))
365}
366
367/// Creates a local source fed by a remote plaintext TCP StreamRef producer.
368///
369/// This receiver side opens the TCP connection and, once the returned source is
370/// materialized, sends the initial handshake and cumulative demand.
371pub fn source_ref_over_tcp<T, A>(
372    addr: A,
373    stream_ref_id: StreamRefId,
374    settings: StreamRefSettings,
375) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
376where
377    T: StreamRefPayload,
378    A: ToSocketAddrs + Send + 'static,
379{
380    source_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
381}
382
383pub fn source_ref_over_tcp_with_diagnostics<T, A>(
384    addr: A,
385    stream_ref_id: StreamRefId,
386    settings: StreamRefSettings,
387    diagnostics: Option<StreamRefProtocolDiagnostics>,
388) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
389where
390    T: StreamRefPayload,
391    A: ToSocketAddrs + Send + 'static,
392{
393    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
394    let source = consumer.source();
395    let (stream, handle) = connect_tcp_stream(addr)?;
396    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
397    Ok((source, handle))
398}
399
400/// Creates a local source fed by a remote producer over an already-connected
401/// Tokio TCP stream.
402pub fn source_ref_over_tcp_stream<T>(
403    stream: TcpStream,
404    stream_ref_id: StreamRefId,
405    settings: StreamRefSettings,
406) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
407where
408    T: StreamRefPayload,
409{
410    source_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
411}
412
413pub fn source_ref_over_tcp_stream_with_diagnostics<T>(
414    stream: TcpStream,
415    stream_ref_id: StreamRefId,
416    settings: StreamRefSettings,
417    diagnostics: Option<StreamRefProtocolDiagnostics>,
418) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
419where
420    T: StreamRefPayload,
421{
422    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
423    let source = consumer.source();
424    let handle = current_tokio_handle()?;
425    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
426    Ok((source, handle))
427}
428
429/// Serves a remote `SinkRef` receiver over plaintext TCP.
430///
431/// This receiver side opens the TCP connection to a sender created with
432/// [`sink_ref_over_tcp`] and, once the returned source is materialized, sends
433/// the initial handshake and cumulative demand.
434pub fn serve_sink_ref_over_tcp<T, A>(
435    addr: A,
436    stream_ref_id: StreamRefId,
437    settings: StreamRefSettings,
438) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
439where
440    T: StreamRefPayload,
441    A: ToSocketAddrs + Send + 'static,
442{
443    serve_sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
444}
445
446pub fn serve_sink_ref_over_tcp_with_diagnostics<T, A>(
447    addr: A,
448    stream_ref_id: StreamRefId,
449    settings: StreamRefSettings,
450    diagnostics: Option<StreamRefProtocolDiagnostics>,
451) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
452where
453    T: StreamRefPayload,
454    A: ToSocketAddrs + Send + 'static,
455{
456    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
457    let source = consumer.source();
458    let (stream, handle) = connect_tcp_stream(addr)?;
459    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
460    Ok((source, handle))
461}
462
463/// Serves a remote `SinkRef` receiver over an already-connected Tokio TCP
464/// stream.
465pub fn serve_sink_ref_over_tcp_stream<T>(
466    stream: TcpStream,
467    stream_ref_id: StreamRefId,
468    settings: StreamRefSettings,
469) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
470where
471    T: StreamRefPayload,
472{
473    serve_sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
474}
475
476pub fn serve_sink_ref_over_tcp_stream_with_diagnostics<T>(
477    stream: TcpStream,
478    stream_ref_id: StreamRefId,
479    settings: StreamRefSettings,
480    diagnostics: Option<StreamRefProtocolDiagnostics>,
481) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
482where
483    T: StreamRefPayload,
484{
485    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
486    let source = consumer.source();
487    let handle = current_tokio_handle()?;
488    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
489    Ok((source, handle))
490}
491
492/// Creates a local `Sink` whose incoming elements are sent over a one-shot
493/// plaintext TCP listener to a remote `SinkRef` receiver.
494///
495/// The producer/listener side waits for the receiver to open the TCP
496/// connection. This mirrors the QUIC SinkRef direction where the receiver
497/// opens the bidi stream so its handshake+demand establish the stream before
498/// the producer has elements to send.
499pub fn sink_ref_over_tcp<T, A>(
500    addr: A,
501    stream_ref_id: StreamRefId,
502    settings: StreamRefSettings,
503) -> StreamResult<(
504    Sink<T, StreamCompletion<NotUsed>>,
505    StreamRefTcpBinding,
506    StreamRefTcpHandle,
507)>
508where
509    T: StreamRefPayload,
510    A: ToSocketAddrs + Send + 'static,
511{
512    sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
513}
514
515pub fn sink_ref_over_tcp_with_diagnostics<T, A>(
516    addr: A,
517    stream_ref_id: StreamRefId,
518    settings: StreamRefSettings,
519    diagnostics: Option<StreamRefProtocolDiagnostics>,
520) -> StreamResult<(
521    Sink<T, StreamCompletion<NotUsed>>,
522    StreamRefTcpBinding,
523    StreamRefTcpHandle,
524)>
525where
526    T: StreamRefPayload,
527    A: ToSocketAddrs + Send + 'static,
528{
529    let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
530    let sink = producer.sink();
531    let (listener, binding, handle) = bind_tcp_listener(addr)?;
532    let handle =
533        drive_stream_ref_endpoint_over_tcp_listener(listener, handle, producer, diagnostics);
534    Ok((sink, binding, handle))
535}
536
537/// Creates a local `Sink` whose incoming elements are sent over an
538/// already-connected Tokio TCP stream to a remote `SinkRef` receiver.
539pub fn sink_ref_over_tcp_stream<T>(
540    stream: TcpStream,
541    stream_ref_id: StreamRefId,
542    settings: StreamRefSettings,
543) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
544where
545    T: StreamRefPayload,
546{
547    sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
548}
549
550pub fn sink_ref_over_tcp_stream_with_diagnostics<T>(
551    stream: TcpStream,
552    stream_ref_id: StreamRefId,
553    settings: StreamRefSettings,
554    diagnostics: Option<StreamRefProtocolDiagnostics>,
555) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
556where
557    T: StreamRefPayload,
558{
559    let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
560    let sink = producer.sink();
561    let handle = current_tokio_handle()?;
562    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, producer, diagnostics);
563    Ok((sink, handle))
564}
565
566fn drive_stream_ref_endpoint_over_quic<E>(
567    stream: QuicBidirectionalStream,
568    endpoint: E,
569    diagnostics: Option<StreamRefProtocolDiagnostics>,
570) -> StreamRefQuicHandle
571where
572    E: StreamRefProtoEndpoint,
573{
574    let (reader, writer, handle, chunk_size, emit_available) = stream.into_stream_ref_parts();
575    let read_mode = CarrierReadMode::new(chunk_size, emit_available, false);
576    StreamRefQuicHandle {
577        receiver: spawn_carrier_thread(move || {
578            run_stream_ref_endpoint_io(reader, writer, endpoint, handle, read_mode, diagnostics)
579        }),
580    }
581}
582
583fn drive_stream_ref_endpoint_over_tcp_listener<E>(
584    listener: TcpListener,
585    handle: Handle,
586    endpoint: E,
587    diagnostics: Option<StreamRefProtocolDiagnostics>,
588) -> StreamRefTcpHandle
589where
590    E: StreamRefProtoEndpointWake,
591{
592    StreamRefTcpHandle {
593        receiver: spawn_tcp_endpoint_task(&handle, async move {
594            let (stream, _) = listener.accept().await.map_err(io_error)?;
595            run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
596        }),
597    }
598}
599
600fn drive_stream_ref_endpoint_over_tcp_stream<E>(
601    stream: TcpStream,
602    handle: Handle,
603    endpoint: E,
604    diagnostics: Option<StreamRefProtocolDiagnostics>,
605) -> StreamRefTcpHandle
606where
607    E: StreamRefProtoEndpointWake,
608{
609    StreamRefTcpHandle {
610        receiver: spawn_tcp_endpoint_task(&handle, async move {
611            run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
612        }),
613    }
614}
615
/// Pumps `endpoint` over a split reader/writer pair using two OS threads, one
/// per direction, and blocks until both have finished.
///
/// The outbound thread runs `run_outbound_frames`; the inbound thread runs
/// `run_inbound_frames`. A direction that ends in an error marks the endpoint
/// as failed with that error, with one exception: an inbound error that
/// `is_quic_teardown_loss` accepts is treated as success once the outbound
/// direction has already completed successfully.
///
/// When both directions fail, the outbound error is the one returned.
fn run_stream_ref_endpoint_io<R, W, E>(
    reader: R,
    writer: W,
    endpoint: E,
    handle: Handle,
    read_mode: CarrierReadMode,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<NotUsed>
where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
    E: StreamRefProtoEndpoint,
{
    // Quinn may report a read-side connection loss while a peer tears down the
    // connection after the StreamRefs protocol has already completed.
    // This flag lets the reader tell that case apart: it is set only after
    // the outbound direction has returned successfully.
    let outbound_completed = Arc::new(AtomicBool::new(false));
    let outbound_endpoint = endpoint.clone();
    let outbound_handle = handle.clone();
    let outbound_completed_for_writer = Arc::clone(&outbound_completed);
    let outbound_thread = thread::spawn(move || {
        let result = run_outbound_frames(
            writer,
            outbound_endpoint.clone(),
            outbound_handle,
            diagnostics,
        );
        match &result {
            // Release pairs with the Acquire load in the inbound thread.
            Ok(_) => outbound_completed_for_writer.store(true, Ordering::Release),
            Err(error) => outbound_endpoint.fail_connection(error.clone()),
        }
        result
    });

    let inbound_endpoint = endpoint.clone();
    let outbound_completed_for_reader = Arc::clone(&outbound_completed);
    let inbound_thread = thread::spawn(move || {
        let result = run_inbound_frames(reader, inbound_endpoint.clone(), handle, read_mode);
        if let Err(error) = &result {
            // A teardown loss seen after the outbound side finished is
            // reported as success and does not fail the endpoint.
            if outbound_completed_for_reader.load(Ordering::Acquire) && is_quic_teardown_loss(error)
            {
                return Ok(NotUsed);
            }
            inbound_endpoint.fail_connection(error.clone());
        }
        result
    });

    // Join order is outbound first, then inbound; the match below also gives
    // the outbound error precedence.
    let outbound = join_carrier_thread(outbound_thread);
    let inbound = join_carrier_thread(inbound_thread);
    match (outbound, inbound) {
        (Err(error), _) => Err(error),
        (_, Err(error)) => Err(error),
        (Ok(()), Ok(())) => Ok(NotUsed),
    }
}
671
672async fn run_stream_ref_endpoint_tcp_task<E>(
673    stream: TcpStream,
674    endpoint: E,
675    diagnostics: Option<StreamRefProtocolDiagnostics>,
676) -> StreamResult<NotUsed>
677where
678    E: StreamRefProtoEndpointWake,
679{
680    let (wake_sender, wake_receiver) = tokio_mpsc::channel(1);
681    endpoint.install_outbound_wake(wake_sender.clone());
682    let _ = wake_sender.try_send(());
683
684    let result = TcpEndpointTask {
685        stream,
686        endpoint: endpoint.clone(),
687        diagnostics,
688        read_mode: CarrierReadMode::new(STREAM_REF_TCP_CHUNK_SIZE, true, true),
689        decoder: FrameDecoder::default(),
690        read_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
691        pending_tail: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
692        write_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
693        encode_buffer: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
694        pending_diagnostics: VecDeque::new(),
695        outbound_closed: false,
696        write_shutdown: false,
697        wake_receiver,
698    }
699    .run()
700    .await;
701
702    endpoint.clear_outbound_wake();
703    if let Err(error) = &result {
704        endpoint.fail_connection(error.clone());
705    }
706    result
707}
708
/// State of one StreamRefs carrier running over a single TCP connection.
struct TcpEndpointTask<E>
where
    E: StreamRefProtoEndpointWake,
{
    /// The connected socket that is read from and written to.
    stream: TcpStream,
    /// Protocol endpoint: polled for outbound items and fed inbound bytes.
    endpoint: E,
    /// Optional collector for counts of fully written messages.
    diagnostics: Option<StreamRefProtocolDiagnostics>,
    /// Read settings (chunk size, EOF handling).
    read_mode: CarrierReadMode,
    /// Decoder state carried across reads.
    decoder: FrameDecoder,
    /// Scratch buffer for one socket read; cleared after each feed.
    read_buffer: BytesMut,
    /// Bytes handed to the feed helper alongside each read and flushed into
    /// the decoder at EOF if non-empty. NOTE(review): presumably holds input
    /// not yet emitted to the decoder — confirm against `feed_read_bytes`.
    pending_tail: Vec<u8>,
    /// Encoded outbound bytes not yet written to the socket.
    write_buffer: BytesMut,
    /// Scratch buffer that receives the encoding of one outbound item.
    encode_buffer: Vec<u8>,
    /// Per-item byte counts awaiting write; filled only when `diagnostics`
    /// is set.
    pending_diagnostics: VecDeque<PendingDiagnostic>,
    /// Set once the endpoint reports its outbound side as closed.
    outbound_closed: bool,
    /// Set once the socket's write half has been shut down.
    write_shutdown: bool,
    /// Receives wake signals; each one re-runs the loop, which drains again.
    wake_receiver: tokio_mpsc::Receiver<()>,
}
727
impl<E> TcpEndpointTask<E>
where
    E: StreamRefProtoEndpointWake,
{
    /// Event loop of the TCP carrier.
    ///
    /// Each iteration first drains outbound items into the write buffer and
    /// flushes if there is anything to write or a shutdown is still owed, then
    /// waits for a wake signal, socket readability, or (only while output is
    /// pending) socket writability. Returns `Ok(NotUsed)` when the read side
    /// reaches EOF, and the first error otherwise.
    async fn run(mut self) -> StreamResult<NotUsed> {
        self.stream.set_nodelay(true).map_err(io_error)?;
        loop {
            self.drain_outbound()?;
            // NOTE(review): `flush_write_buffer` awaits writability outside
            // the `select!`, so inbound data is not read while the socket is
            // unwritable — confirm this cannot stall both peers at once.
            if !self.write_buffer.is_empty() || (self.outbound_closed && !self.write_shutdown) {
                self.flush_write_buffer().await?;
            }

            tokio::select! {
                // Branches are polled in the order written below.
                biased;
                wake = self.wake_receiver.recv() => {
                    // A received wake needs no work here: the top of the loop
                    // drains again. Only a closed channel drains in place.
                    if wake.is_none() && !self.outbound_closed {
                        self.drain_outbound()?;
                    }
                }
                ready = self.stream.readable() => {
                    ready.map_err(io_error)?;
                    // `true` means EOF was reached: the carrier is done.
                    if self.read_available()? {
                        return Ok(NotUsed);
                    }
                }
                // Enabled only while bytes are buffered or a shutdown is owed.
                ready = self.stream.writable(), if !self.write_buffer.is_empty() || (self.outbound_closed && !self.write_shutdown) => {
                    ready.map_err(io_error)?;
                    self.flush_ready_write_buffer()?;
                }
            }
        }
    }

    /// Polls the endpoint for outbound items and appends their encodings to
    /// the write buffer.
    ///
    /// Stops when the endpoint has nothing ready, when it reports itself
    /// closed (recorded in `outbound_closed`), or when the write buffer has
    /// reached `MAX_STREAM_REF_FRAME_BYTES`.
    fn drain_outbound(&mut self) -> StreamResult<()> {
        while !self.outbound_closed && self.write_buffer.len() < MAX_STREAM_REF_FRAME_BYTES {
            match self
                .endpoint
                .try_next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
            {
                StreamRefOutboundPoll::Ready(Ok(outbound)) => {
                    encode_carrier_outbound_into(&outbound, &mut self.encode_buffer)?;
                    let encoded_len = self.encode_buffer.len();
                    // Nothing was encoded: nothing to queue or to count.
                    if encoded_len == 0 {
                        continue;
                    }
                    if self.diagnostics.is_some() {
                        // Counts are recorded later, once all of these bytes
                        // have been written to the socket.
                        self.pending_diagnostics.push_back(PendingDiagnostic {
                            remaining: encoded_len,
                            counts: outbound_counts(&outbound),
                        });
                    }
                    self.write_buffer.extend_from_slice(&self.encode_buffer);
                }
                StreamRefOutboundPoll::Ready(Err(error)) => return Err(error),
                StreamRefOutboundPoll::Pending => break,
                StreamRefOutboundPoll::Closed => {
                    self.outbound_closed = true;
                    break;
                }
            }
        }
        Ok(())
    }

    /// Waits for writability and writes what the socket accepts; once the
    /// outbound side is closed and the buffer is empty, shuts down the write
    /// half (at most once).
    async fn flush_write_buffer(&mut self) -> StreamResult<()> {
        if !self.write_buffer.is_empty() {
            self.stream.writable().await.map_err(io_error)?;
            self.flush_ready_write_buffer()?;
        }
        // A partial write leaves bytes buffered, which postpones the shutdown
        // to a later call.
        if self.outbound_closed && self.write_buffer.is_empty() && !self.write_shutdown {
            self.stream.shutdown().await.map_err(io_error)?;
            self.write_shutdown = true;
        }
        Ok(())
    }

    /// Writes buffered bytes without blocking until the buffer is empty or
    /// the socket reports `WouldBlock`.
    ///
    /// A write that accepts zero bytes is reported as a failure.
    fn flush_ready_write_buffer(&mut self) -> StreamResult<()> {
        while !self.write_buffer.is_empty() {
            match self.stream.try_write(&self.write_buffer) {
                Ok(0) => {
                    return Err(StreamError::Failed(
                        "StreamRefs TCP socket accepted zero write bytes".to_owned(),
                    ));
                }
                Ok(written) => {
                    self.write_buffer.advance(written);
                    self.record_written_bytes(written);
                }
                // Not an error: the remaining bytes stay buffered.
                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
                Err(error) => return Err(io_error(error)),
            }
        }

        Ok(())
    }

    /// Attributes `written` bytes to the queued items in FIFO order and
    /// records the counts of every item that is now fully written.
    ///
    /// Does nothing when no diagnostics collector is configured.
    fn record_written_bytes(&mut self, mut written: usize) {
        let Some(diagnostics) = &self.diagnostics else {
            return;
        };
        while written > 0 {
            let Some(front) = self.pending_diagnostics.front_mut() else {
                return;
            };
            // The front item is only partially written: remember the rest.
            if written < front.remaining {
                front.remaining -= written;
                return;
            }
            written -= front.remaining;
            let counts = front.counts;
            self.pending_diagnostics.pop_front();
            diagnostics.record_counts(counts);
        }
    }

    /// Reads and feeds inbound bytes until the socket would block.
    ///
    /// Returns `Ok(true)` when EOF was reached and `Ok(false)` when the socket
    /// has no more data for now.
    fn read_available(&mut self) -> StreamResult<bool> {
        loop {
            self.read_buffer.reserve(self.read_mode.chunk_size);
            match self.stream.try_read_buf(&mut self.read_buffer) {
                Ok(0) => return self.handle_eof(),
                Ok(_) => {
                    feed_read_bytes(
                        &mut self.decoder,
                        &self.endpoint,
                        self.read_mode,
                        &mut self.pending_tail,
                        &self.read_buffer,
                    )?;
                    self.read_buffer.clear();
                    // Pick up any outbound items before the next read.
                    self.drain_outbound()?;
                }
                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(false),
                Err(error) => return Err(io_error(error)),
            }
        }
    }

    /// Handles EOF on the read side: feeds any bytes still held in
    /// `pending_tail`, then, if the read mode asks for it, marks the endpoint
    /// as failed with `AbruptTermination`. Always returns `Ok(true)` unless
    /// feeding the tail fails.
    fn handle_eof(&mut self) -> StreamResult<bool> {
        if !self.pending_tail.is_empty() {
            feed_inbound_chunk(&mut self.decoder, &self.endpoint, &self.pending_tail)?;
            self.pending_tail.clear();
        }
        // NOTE(review): this runs on every EOF, including after a normal
        // completion — presumably a no-op for an endpoint that has already
        // completed; confirm against `fail_connection`.
        if self.read_mode.fail_on_eof {
            self.endpoint
                .fail_connection(StreamError::AbruptTermination);
        }
        Ok(true)
    }
}
877
878fn run_outbound_frames<W, E>(
879    mut writer: W,
880    endpoint: E,
881    handle: Handle,
882    diagnostics: Option<StreamRefProtocolDiagnostics>,
883) -> StreamResult<NotUsed>
884where
885    W: AsyncWrite + Unpin + Send + 'static,
886    E: StreamRefProtoEndpoint,
887{
888    let mut bytes = Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE);
889    loop {
890        let Some(outbound) =
891            endpoint.next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
892        else {
893            handle.block_on(async {
894                writer.flush().await.map_err(io_error)?;
895                writer.shutdown().await.map_err(io_error)
896            })?;
897            return Ok(NotUsed);
898        };
899        let outbound = outbound?;
900        encode_carrier_outbound_into(&outbound, &mut bytes)?;
901        handle.block_on(async {
902            writer.write_all(&bytes).await.map_err(io_error)?;
903            writer.flush().await.map_err(io_error)
904        })?;
905        if let Some(diagnostics) = &diagnostics {
906            diagnostics.record_written_outbound(&outbound);
907        }
908    }
909}
910
911fn run_inbound_frames<R, E>(
912    mut reader: R,
913    endpoint: E,
914    handle: Handle,
915    read_mode: CarrierReadMode,
916) -> StreamResult<NotUsed>
917where
918    R: AsyncRead + Unpin + Send + 'static,
919    E: StreamRefProtoEndpoint,
920{
921    handle.block_on(async move {
922        let mut buffer = vec![0_u8; read_mode.chunk_size];
923        let mut pending_tail = Vec::with_capacity(read_mode.chunk_size);
924        let mut decoder = FrameDecoder::default();
925
926        loop {
927            let read = reader.read(&mut buffer).await.map_err(io_error)?;
928            if read == 0 {
929                if !pending_tail.is_empty() {
930                    feed_inbound_chunk(&mut decoder, &endpoint, &pending_tail)?;
931                    pending_tail.clear();
932                }
933                if read_mode.fail_on_eof {
934                    endpoint.fail_connection(StreamError::AbruptTermination);
935                }
936                return Ok(NotUsed);
937            }
938            feed_read_bytes(
939                &mut decoder,
940                &endpoint,
941                read_mode,
942                &mut pending_tail,
943                &buffer[..read],
944            )?;
945        }
946    })
947}
948
949fn feed_read_bytes<E>(
950    decoder: &mut FrameDecoder,
951    endpoint: &E,
952    read_mode: CarrierReadMode,
953    pending_tail: &mut Vec<u8>,
954    read_buffer: &[u8],
955) -> StreamResult<()>
956where
957    E: StreamRefProtoEndpoint,
958{
959    if read_mode.emit_available {
960        if !pending_tail.is_empty() {
961            pending_tail.extend_from_slice(read_buffer);
962            feed_inbound_chunk(decoder, endpoint, pending_tail)?;
963            pending_tail.clear();
964            return Ok(());
965        }
966        return feed_inbound_chunk(decoder, endpoint, read_buffer);
967    }
968
969    let mut offset = 0;
970    if !pending_tail.is_empty() {
971        let needed = read_mode.chunk_size - pending_tail.len();
972        let take = needed.min(read_buffer.len());
973        pending_tail.extend_from_slice(&read_buffer[..take]);
974        offset += take;
975        if pending_tail.len() == read_mode.chunk_size {
976            feed_inbound_chunk(decoder, endpoint, pending_tail)?;
977            pending_tail.clear();
978        }
979    }
980
981    while offset + read_mode.chunk_size <= read_buffer.len() {
982        let next = offset + read_mode.chunk_size;
983        feed_inbound_chunk(decoder, endpoint, &read_buffer[offset..next])?;
984        offset = next;
985    }
986
987    if offset < read_buffer.len() {
988        pending_tail.extend_from_slice(&read_buffer[offset..]);
989    }
990    Ok(())
991}
992
993fn feed_inbound_chunk<E>(decoder: &mut FrameDecoder, endpoint: &E, chunk: &[u8]) -> StreamResult<()>
994where
995    E: StreamRefProtoEndpoint,
996{
997    decoder.push_chunk(chunk, endpoint)
998}
999
1000fn bind_tcp_listener<A>(addr: A) -> StreamResult<(TcpListener, StreamRefTcpBinding, Handle)>
1001where
1002    A: ToSocketAddrs + Send + 'static,
1003{
1004    let runtime = stream_ref_tcp_runtime()?;
1005    let listener = runtime
1006        .block_on(async { TcpListener::bind(addr).await })
1007        .map_err(io_error)?;
1008    let local_addr = listener.local_addr().map_err(io_error)?;
1009    Ok((
1010        listener,
1011        StreamRefTcpBinding { local_addr },
1012        runtime.handle().clone(),
1013    ))
1014}
1015
1016fn connect_tcp_stream<A>(addr: A) -> StreamResult<(TcpStream, Handle)>
1017where
1018    A: ToSocketAddrs + Send + 'static,
1019{
1020    let runtime = stream_ref_tcp_runtime()?;
1021    let stream = runtime
1022        .block_on(async { TcpStream::connect(addr).await })
1023        .map_err(io_error)?;
1024    stream.set_nodelay(true).map_err(io_error)?;
1025    Ok((stream, runtime.handle().clone()))
1026}
1027
1028fn stream_ref_tcp_runtime() -> StreamResult<&'static Runtime> {
1029    static RUNTIME: OnceLock<Result<Runtime, String>> = OnceLock::new();
1030    match RUNTIME.get_or_init(|| {
1031        tokio::runtime::Builder::new_multi_thread()
1032            .thread_name("datum-streamref-tcp")
1033            .enable_all()
1034            .build()
1035            .map_err(|error| error.to_string())
1036    }) {
1037        Ok(runtime) => Ok(runtime),
1038        Err(error) => Err(StreamError::Failed(format!(
1039            "failed to start StreamRefs TCP runtime: {error}"
1040        ))),
1041    }
1042}
1043
1044fn current_tokio_handle() -> StreamResult<Handle> {
1045    Handle::try_current().map_err(|error| {
1046        StreamError::Failed(format!(
1047            "StreamRefs TCP stream helper requires a current Tokio runtime: {error}"
1048        ))
1049    })
1050}
1051
1052fn io_error(error: std::io::Error) -> StreamError {
1053    StreamError::Failed(error.to_string())
1054}
1055
1056fn is_quic_teardown_loss(error: &StreamError) -> bool {
1057    matches!(error, StreamError::Failed(message) if message == "connection lost")
1058}
1059
1060fn spawn_carrier_thread<F>(run: F) -> mpsc::Receiver<StreamResult<NotUsed>>
1061where
1062    F: FnOnce() -> StreamResult<NotUsed> + Send + 'static,
1063{
1064    let (sender, receiver) = mpsc::channel();
1065    thread::spawn(move || {
1066        let result = run();
1067        let _ = sender.send(result);
1068    });
1069    receiver
1070}
1071
1072fn spawn_tcp_endpoint_task<F>(handle: &Handle, run: F) -> mpsc::Receiver<StreamResult<NotUsed>>
1073where
1074    F: Future<Output = StreamResult<NotUsed>> + Send + 'static,
1075{
1076    let (sender, receiver) = mpsc::channel();
1077    handle.spawn(async move {
1078        let result = run.await;
1079        let _ = sender.send(result);
1080    });
1081    receiver
1082}
1083
1084fn encode_carrier_outbound_into(
1085    outbound: &StreamRefOutbound,
1086    bytes: &mut Vec<u8>,
1087) -> StreamResult<()> {
1088    bytes.clear();
1089    match outbound {
1090        StreamRefOutbound::Frame(frame) => append_protobuf_carrier_frame(frame, bytes)?,
1091        StreamRefOutbound::SequencedBatch(batch) => {
1092            append_compact_payload_batch(batch, bytes)?;
1093        }
1094    }
1095    Ok(())
1096}
1097
/// Test helper: encodes `frames` into carrier bytes.
///
/// Runs of `SequencedOnNext` frames (as delimited by `sequenced_run_end`) are
/// written as compact batches; every other frame is written as a protobuf
/// carrier frame.
#[cfg(test)]
fn encode_carrier_frames(frames: &[StreamRefFrame]) -> StreamResult<Vec<u8>> {
    let mut encoded = Vec::new();
    let mut cursor = 0;
    while let Some(frame) = frames.get(cursor) {
        cursor = if sequenced_on_next(frame).is_some() {
            let run_end = sequenced_run_end(frames, cursor);
            append_compact_sequenced_batches(&frames[cursor..run_end], &mut encoded)?;
            run_end
        } else {
            append_protobuf_carrier_frame(frame, &mut encoded)?;
            cursor + 1
        };
    }
    Ok(encoded)
}
1114
1115fn append_compact_payload_batch(
1116    batch: &StreamRefPayloadBatch,
1117    bytes: &mut Vec<u8>,
1118) -> StreamResult<()> {
1119    let mut start = 0;
1120    while start < batch.count() {
1121        let mut end = start;
1122        let mut payload_len = COMPACT_BATCH_HEADER_BYTES;
1123        while end < batch.count() {
1124            let element_len = COMPACT_BATCH_ELEMENT_LEN_BYTES
1125                .checked_add(batch.payload_len(end))
1126                .ok_or(StreamError::LimitExceeded {
1127                    max: MAX_STREAM_REF_FRAME_BYTES as u64,
1128                })?;
1129            let next_payload_len =
1130                payload_len
1131                    .checked_add(element_len)
1132                    .ok_or(StreamError::LimitExceeded {
1133                        max: MAX_STREAM_REF_FRAME_BYTES as u64,
1134                    })?;
1135            if end > start
1136                && (next_payload_len > MAX_STREAM_REF_FRAME_BYTES
1137                    || end - start >= u16::MAX as usize)
1138            {
1139                break;
1140            }
1141            if next_payload_len > MAX_STREAM_REF_FRAME_BYTES {
1142                return Err(StreamError::LimitExceeded {
1143                    max: MAX_STREAM_REF_FRAME_BYTES as u64,
1144                });
1145            }
1146            payload_len = next_payload_len;
1147            end += 1;
1148        }
1149        append_compact_payload_batch_slice(batch, start, end, payload_len, bytes)?;
1150        start = end;
1151    }
1152    Ok(())
1153}
1154
1155fn append_compact_payload_batch_slice(
1156    batch: &StreamRefPayloadBatch,
1157    start: usize,
1158    end: usize,
1159    payload_len: usize,
1160    bytes: &mut Vec<u8>,
1161) -> StreamResult<()> {
1162    let payload_len = u32::try_from(payload_len).map_err(|_| StreamError::LimitExceeded {
1163        max: MAX_STREAM_REF_FRAME_BYTES as u64,
1164    })?;
1165    let count = u16::try_from(end - start).map_err(|_| StreamError::LimitExceeded {
1166        max: u16::MAX as u64,
1167    })?;
1168    let first_seq = batch
1169        .first_seq_nr()
1170        .checked_add(start as u64)
1171        .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
1172    bytes.extend((COMPACT_FRAME_FLAG | payload_len).to_be_bytes());
1173    bytes.push(COMPACT_FRAME_VERSION);
1174    bytes.push(COMPACT_SEQUENCED_ON_NEXT_BATCH);
1175    bytes.extend(batch.stream_ref_id().to_bytes());
1176    bytes.extend(first_seq.to_be_bytes());
1177    bytes.extend(count.to_be_bytes());
1178    for index in start..end {
1179        let payload = batch.payload(index);
1180        let payload_len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
1181            max: u32::MAX as u64,
1182        })?;
1183        bytes.extend(payload_len.to_be_bytes());
1184        bytes.extend(payload);
1185    }
1186    Ok(())
1187}
1188
1189fn append_protobuf_carrier_frame(frame: &StreamRefFrame, bytes: &mut Vec<u8>) -> StreamResult<()> {
1190    let payload = frame.encode_to_vec();
1191    let len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
1192        max: COMPACT_FRAME_LEN_MASK as u64,
1193    })?;
1194    if payload.len() > MAX_STREAM_REF_FRAME_BYTES || len > COMPACT_FRAME_LEN_MASK {
1195        return Err(StreamError::LimitExceeded {
1196            max: MAX_STREAM_REF_FRAME_BYTES as u64,
1197        });
1198    }
1199    bytes.extend(len.to_be_bytes());
1200    bytes.extend(payload);
1201    Ok(())
1202}
1203
/// Test helper: appends a run of `SequencedOnNext` frames as compact carrier frames.
///
/// Frames are packed greedily: a carrier frame is closed as soon as adding the
/// next payload would push it past `MAX_STREAM_REF_FRAME_BYTES` or once it
/// already holds `u16::MAX` elements.
///
/// # Panics
///
/// Panics if any frame is not a `SequencedOnNext` frame.
#[cfg(test)]
fn append_compact_sequenced_batches(
    frames: &[StreamRefFrame],
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let too_large = || StreamError::LimitExceeded {
        max: MAX_STREAM_REF_FRAME_BYTES as u64,
    };

    let mut remaining = frames;
    while !remaining.is_empty() {
        let mut taken = 0;
        let mut payload_len = COMPACT_BATCH_HEADER_BYTES;
        for frame in remaining {
            let (_, payload) = sequenced_on_next(frame).expect("sequenced frame");
            // Size of the carrier payload if this frame were included.
            let grown = COMPACT_BATCH_ELEMENT_LEN_BYTES
                .checked_add(payload.len())
                .and_then(|element_len| payload_len.checked_add(element_len))
                .ok_or_else(too_large)?;
            let over_budget = grown > MAX_STREAM_REF_FRAME_BYTES;
            if taken == 0 {
                // The first element of a carrier frame has nowhere else to go.
                if over_budget {
                    return Err(too_large());
                }
            } else if over_budget || taken >= u16::MAX as usize {
                break;
            }
            payload_len = grown;
            taken += 1;
        }
        let (batch, rest) = remaining.split_at(taken);
        append_compact_sequenced_batch(batch, payload_len, bytes)?;
        remaining = rest;
    }
    Ok(())
}
1245
/// Test helper: appends `frames` to `bytes` as one compact carrier frame.
///
/// The stream ref id and first sequence number are taken from `frames[0]`;
/// `payload_len` is the size of everything after the 4-byte prefix, as
/// computed by the caller.
///
/// # Panics
///
/// Panics if `frames` is empty or contains a frame that is not
/// `SequencedOnNext`.
#[cfg(test)]
fn append_compact_sequenced_batch(
    frames: &[StreamRefFrame],
    payload_len: usize,
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let head = &frames[0];
    let (first_seq, _) = sequenced_on_next(head).expect("sequenced frame");
    let prefix_len = match u32::try_from(payload_len) {
        Ok(len) => len,
        Err(_) => {
            return Err(StreamError::LimitExceeded {
                max: MAX_STREAM_REF_FRAME_BYTES as u64,
            });
        }
    };
    let count = match u16::try_from(frames.len()) {
        Ok(count) => count,
        Err(_) => {
            return Err(StreamError::LimitExceeded {
                max: u16::MAX as u64,
            });
        }
    };

    // Carrier prefix followed by the compact batch header.
    bytes.extend((COMPACT_FRAME_FLAG | prefix_len).to_be_bytes());
    bytes.push(COMPACT_FRAME_VERSION);
    bytes.push(COMPACT_SEQUENCED_ON_NEXT_BATCH);
    bytes.extend(head.stream_ref_id.to_bytes());
    bytes.extend(first_seq.to_be_bytes());
    bytes.extend(count.to_be_bytes());

    // Length-prefixed payloads, in frame order.
    for frame in frames {
        let (_, element) = sequenced_on_next(frame).expect("sequenced frame");
        let element_len = u32::try_from(element.len()).map_err(|_| StreamError::LimitExceeded {
            max: u32::MAX as u64,
        })?;
        bytes.extend(element_len.to_be_bytes());
        bytes.extend(element);
    }
    Ok(())
}
1275
/// Test helper: returns the exclusive end index of the sequenced run starting at `start`.
///
/// The run extends while each following frame is a `SequencedOnNext` frame
/// with the same `stream_ref_id` as `frames[start]` and a sequence number
/// exactly one greater than its predecessor's. The result is always at least
/// `start + 1`.
#[cfg(test)]
fn sequenced_run_end(frames: &[StreamRefFrame], start: usize) -> usize {
    let mut end = start + 1;
    while end < frames.len() {
        let (Some((previous_seq, _)), Some((next_seq, _))) = (
            sequenced_on_next(&frames[end - 1]),
            sequenced_on_next(&frames[end]),
        ) else {
            break;
        };
        // `checked_add` rather than `saturating_add`: at `u64::MAX` there is no
        // successor, so a repeated `u64::MAX` must end the run instead of being
        // mistaken for the next sequence number.
        let contiguous = previous_seq.checked_add(1) == Some(next_seq);
        if frames[end].stream_ref_id != frames[start].stream_ref_id || !contiguous {
            break;
        }
        end += 1;
    }
    end
}
1295
/// Test helper: returns the sequence number and payload bytes of a `SequencedOnNext` frame.
///
/// Returns `None` for every other message kind.
#[cfg(test)]
fn sequenced_on_next(frame: &StreamRefFrame) -> Option<(u64, &[u8])> {
    if let StreamRefMessage::SequencedOnNext { seq_nr, payload } = &frame.message {
        Some((*seq_nr, payload.bytes.as_slice()))
    } else {
        None
    }
}
1305
/// Incremental decoder for length-prefixed carrier frames.
///
/// Bytes are appended with `push_chunk`; every frame that is complete in the
/// buffer is decoded and handed to the endpoint.
#[derive(Default)]
struct FrameDecoder {
    /// Bytes received so far that have not yet been discarded.
    buffer: BytesMut,
    /// Index into `buffer` of the first byte not yet consumed by a decoded frame.
    offset: usize,
}
1311
1312impl FrameDecoder {
1313    fn push_chunk<E>(&mut self, chunk: &[u8], endpoint: &E) -> StreamResult<()>
1314    where
1315        E: StreamRefProtoEndpoint,
1316    {
1317        self.buffer.extend_from_slice(chunk);
1318        while let Some(header) = self.peek_header()? {
1319            if self.buffer.len().saturating_sub(self.offset) < FRAME_LEN_BYTES + header.len {
1320                break;
1321            }
1322            let payload_start = self.offset + FRAME_LEN_BYTES;
1323            let payload_end = payload_start + header.len;
1324            let payload = &self.buffer[payload_start..payload_end];
1325            match header.kind {
1326                CarrierFrameKind::Protobuf => {
1327                    endpoint.handle_frame(StreamRefFrame::decode(payload)?)?;
1328                }
1329                CarrierFrameKind::Compact => {
1330                    decode_compact_carrier_frame(payload, endpoint)?;
1331                }
1332            }
1333            self.offset = payload_end;
1334        }
1335        if self.offset > 0 && (self.offset == self.buffer.len() || self.offset >= 64 * 1024) {
1336            self.buffer.advance(self.offset);
1337            self.offset = 0;
1338        }
1339        Ok(())
1340    }
1341
1342    fn peek_header(&self) -> StreamResult<Option<CarrierFrameHeader>> {
1343        if self.buffer.len().saturating_sub(self.offset) < FRAME_LEN_BYTES {
1344            return Ok(None);
1345        }
1346        let len = self.buffer[self.offset..self.offset + FRAME_LEN_BYTES]
1347            .try_into()
1348            .expect("frame header length");
1349        let raw_len = u32::from_be_bytes(len);
1350        let kind = if raw_len & COMPACT_FRAME_FLAG == 0 {
1351            CarrierFrameKind::Protobuf
1352        } else {
1353            CarrierFrameKind::Compact
1354        };
1355        let len = (raw_len & COMPACT_FRAME_LEN_MASK) as usize;
1356        if len > MAX_STREAM_REF_FRAME_BYTES {
1357            return Err(StreamError::LimitExceeded {
1358                max: MAX_STREAM_REF_FRAME_BYTES as u64,
1359            });
1360        }
1361        Ok(Some(CarrierFrameHeader { kind, len }))
1362    }
1363}
1364
/// Parsed form of the 4-byte carrier length prefix.
#[derive(Clone, Copy)]
struct CarrierFrameHeader {
    /// Framing selected by the flag bit of the prefix.
    kind: CarrierFrameKind,
    /// Payload length in bytes, excluding the prefix itself.
    len: usize,
}
1370
/// Payload encoding of a carrier frame, chosen by `COMPACT_FRAME_FLAG` in the length prefix.
#[derive(Clone, Copy)]
enum CarrierFrameKind {
    /// Flag bit clear: the payload is a protobuf-encoded `StreamRefFrame`.
    Protobuf,
    /// Flag bit set: the payload is a compact batch.
    Compact,
}
1376
1377fn decode_compact_carrier_frame<E>(payload: &[u8], endpoint: &E) -> StreamResult<()>
1378where
1379    E: StreamRefProtoEndpoint,
1380{
1381    if payload.len() < COMPACT_BATCH_HEADER_BYTES {
1382        return Err(StreamError::Failed(
1383            "compact StreamRefs carrier frame too short".to_owned(),
1384        ));
1385    }
1386    let version = payload[0];
1387    if version != COMPACT_FRAME_VERSION {
1388        return Err(StreamError::Failed(format!(
1389            "unsupported compact StreamRefs carrier frame version: {version}"
1390        )));
1391    }
1392    let kind = payload[1];
1393    if kind != COMPACT_SEQUENCED_ON_NEXT_BATCH {
1394        return Err(StreamError::Failed(format!(
1395            "unsupported compact StreamRefs carrier frame kind: {kind}"
1396        )));
1397    }
1398    let stream_ref_id = StreamRefId::from_bytes(&payload[2..18])?;
1399    let first_seq = u64::from_be_bytes(payload[18..26].try_into().expect("seq len"));
1400    let count = u16::from_be_bytes(payload[26..28].try_into().expect("count len")) as usize;
1401    if count == 0 {
1402        return Err(StreamError::Failed(
1403            "compact StreamRefs carrier batch is empty".to_owned(),
1404        ));
1405    }
1406
1407    let mut offset = COMPACT_BATCH_HEADER_BYTES;
1408    let mut payloads = Vec::with_capacity(count);
1409    for index in 0..count {
1410        if payload.len().saturating_sub(offset) < COMPACT_BATCH_ELEMENT_LEN_BYTES {
1411            return Err(StreamError::Failed(
1412                "compact StreamRefs carrier batch has truncated payload length".to_owned(),
1413            ));
1414        }
1415        let payload_len = u32::from_be_bytes(
1416            payload[offset..offset + COMPACT_BATCH_ELEMENT_LEN_BYTES]
1417                .try_into()
1418                .expect("payload len"),
1419        ) as usize;
1420        offset += COMPACT_BATCH_ELEMENT_LEN_BYTES;
1421        if payload.len().saturating_sub(offset) < payload_len {
1422            return Err(StreamError::Failed(
1423                "compact StreamRefs carrier batch has truncated payload".to_owned(),
1424            ));
1425        }
1426        first_seq
1427            .checked_add(index as u64)
1428            .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
1429        payloads.push(&payload[offset..offset + payload_len]);
1430        offset += payload_len;
1431    }
1432    if offset != payload.len() {
1433        return Err(StreamError::Failed(
1434            "compact StreamRefs carrier batch has trailing bytes".to_owned(),
1435        ));
1436    }
1437    endpoint.handle_sequenced_on_next_batch(stream_ref_id, first_seq, &payloads)
1438}
1439
1440fn join_carrier_thread(handle: thread::JoinHandle<StreamResult<NotUsed>>) -> StreamResult<()> {
1441    match handle.join() {
1442        Ok(Ok(NotUsed)) => Ok(()),
1443        Ok(Err(error)) => Err(error),
1444        Err(_) => Err(StreamError::Failed(
1445            "StreamRefs carrier thread panicked".to_owned(),
1446        )),
1447    }
1448}
1449
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Endpoint double that records every frame passed to `handle_frame`.
    ///
    /// NOTE(review): compact batches are delivered through
    /// `handle_sequenced_on_next_batch`, which is not overridden here. The
    /// compact tests presumably rely on the trait's default implementation
    /// forwarding to `handle_frame`; confirm against the trait definition.
    #[derive(Clone)]
    struct RecordingEndpoint {
        stream_ref_id: StreamRefId,
        frames: Arc<Mutex<Vec<StreamRefFrame>>>,
    }

    impl RecordingEndpoint {
        fn new(stream_ref_id: StreamRefId) -> Self {
            let frames = Arc::default();
            Self {
                stream_ref_id,
                frames,
            }
        }

        /// Snapshot of the frames recorded so far.
        fn frames(&self) -> Vec<StreamRefFrame> {
            let recorded = self.frames.lock().expect("recording endpoint");
            recorded.to_vec()
        }
    }

    impl StreamRefProtoEndpoint for RecordingEndpoint {
        fn stream_ref_id(&self) -> StreamRefId {
            self.stream_ref_id
        }

        fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
            None
        }

        fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
            let mut recorded = self.frames.lock().expect("recording endpoint");
            recorded.push(frame);
            Ok(())
        }

        fn fail_connection(&self, _error: StreamError) {}
    }

    /// Builds a `SequencedOnNext` frame carrying `bytes`.
    fn sequenced_frame(stream_ref_id: StreamRefId, seq_nr: u64, bytes: Vec<u8>) -> StreamRefFrame {
        StreamRefFrame::new(
            stream_ref_id,
            datum::StreamRefMessage::SequencedOnNext {
                seq_nr,
                payload: datum::StreamRefPayloadBytes { bytes },
            },
        )
    }

    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let stream_ref_id = StreamRefId::from_u128(1);
        let frame = StreamRefFrame::new(
            stream_ref_id,
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        );
        let bytes = encode_carrier_frames(std::slice::from_ref(&frame)).unwrap();
        let (head, tail) = bytes.split_at(bytes.len() / 2);
        let endpoint = RecordingEndpoint::new(stream_ref_id);
        let mut decoder = FrameDecoder::default();

        // Half a frame must not produce anything yet.
        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());
        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), vec![frame]);
    }

    #[test]
    fn compact_carrier_batch_round_trips_sequenced_frames() {
        let stream_ref_id = StreamRefId::from_u128(7);
        let frames: Vec<_> = (0_u64..3)
            .map(|seq_nr| sequenced_frame(stream_ref_id, seq_nr, seq_nr.to_be_bytes().to_vec()))
            .collect();
        let bytes = encode_carrier_frames(&frames).unwrap();

        // The run of sequenced frames must have been written in compact form.
        let header = u32::from_be_bytes(bytes[..FRAME_LEN_BYTES].try_into().unwrap());
        assert_ne!(header & COMPACT_FRAME_FLAG, 0);

        let endpoint = RecordingEndpoint::new(stream_ref_id);
        let mut decoder = FrameDecoder::default();
        decoder.push_chunk(&bytes, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }

    #[test]
    fn compact_carrier_batch_reassembles_split_frames() {
        let stream_ref_id = StreamRefId::from_u128(8);
        let frames: Vec<_> = (4_u64..8)
            .map(|seq_nr| sequenced_frame(stream_ref_id, seq_nr, vec![seq_nr as u8]))
            .collect();
        let bytes = encode_carrier_frames(&frames).unwrap();
        // Split inside the compact batch header, just past the carrier prefix.
        let (head, tail) = bytes.split_at(FRAME_LEN_BYTES + 5);
        let endpoint = RecordingEndpoint::new(stream_ref_id);
        let mut decoder = FrameDecoder::default();

        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());
        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }
}