Skip to main content

datum_net/
stream_ref.rs

1//! Remote StreamRefs carriers.
2//!
3//! `datum-core` owns the protobuf protocol and state machine. This module is
4//! the carrier layer: it length-prefixes protobuf frames and pumps them over
5//! reliable, ordered bidirectional byte streams such as QUIC or plaintext TCP.
6
7use std::{
8    collections::VecDeque,
9    future::Future,
10    net::SocketAddr,
11    sync::{Arc, Mutex, OnceLock, mpsc},
12    time::Duration,
13};
14
15use bytes::{Buf, BytesMut};
16use datum::{
17    NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
18    StreamRefMessage, StreamRefOutbound, StreamRefPayload, StreamRefPayloadBatch,
19    StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer, StreamRefSettings,
20    StreamResult,
21    actor::stream_ref_proto::{StreamRefOutboundPoll, StreamRefProtoEndpointWake},
22};
23use tokio::{
24    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
25    net::{TcpListener, TcpStream, ToSocketAddrs},
26    runtime::{Handle, Runtime},
27    sync::mpsc as tokio_mpsc,
28};
29
30use crate::QuicBidirectionalStream;
31
/// Width, in bytes, of the length prefix that precedes each carrier frame.
const FRAME_LEN_BYTES: usize = 4;
/// 16 MiB. Passed to the endpoint as the byte limit when polling outbound
/// frames, and used as the write-buffer backlog threshold at which draining
/// pauses.
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 << 20;
/// Chunk size and initial buffer capacity used by the TCP carrier task.
const STREAM_REF_TCP_CHUNK_SIZE: usize = 8 * 1024;
/// Upper bound applied to the QUIC carrier's read buffer length.
const STREAM_REF_QUIC_READ_BUFFER_BYTES: usize = 2 * 1024;
/// Frame-count limit passed to the endpoint on every outbound poll.
const STREAM_REF_OUTBOUND_BATCH_FRAMES: usize = 64;
/// Delay before the QUIC carrier re-polls an endpoint that reported `Pending`.
const STREAM_REF_OUTBOUND_RECHECK_INTERVAL: Duration = Duration::from_millis(5);

// Carrier wire v1 keeps protobuf for control frames. A high bit in the
// length-prefix marks compact SequencedOnNext batches: version, kind,
// stream_ref_id, first seqNr, count, then length-prefixed payloads.
/// Top bit of the 32-bit length prefix; set for compact frames.
const COMPACT_FRAME_FLAG: u32 = 1 << 31;
/// Every bit of the length prefix except [`COMPACT_FRAME_FLAG`].
const COMPACT_FRAME_LEN_MASK: u32 = !COMPACT_FRAME_FLAG;
const COMPACT_FRAME_VERSION: u8 = 1;
const COMPACT_SEQUENCED_ON_NEXT_BATCH: u8 = 1;
/// version (1) + kind (1) + stream_ref_id (16) + first seqNr (8) + count (2).
const COMPACT_BATCH_HEADER_BYTES: usize = 1 + 1 + 16 + 8 + 2;
/// Width, in bytes, of the length prefix in front of each batched payload.
const COMPACT_BATCH_ELEMENT_LEN_BYTES: usize = 4;
48
/// Read-side settings handed to a carrier task.
#[derive(Clone, Copy)]
struct CarrierReadMode {
    /// Read chunk size in bytes; never zero (enforced by [`CarrierReadMode::new`]).
    chunk_size: usize,
    /// Forwarded to the inbound feed helper.
    // NOTE(review): exact effect lives in `feed_read_bytes`, not shown here.
    emit_available: bool,
    /// When set, EOF on the read side fails the endpoint with
    /// `StreamError::AbruptTermination`.
    fail_on_eof: bool,
}

impl CarrierReadMode {
    /// Bundles the three read settings.
    ///
    /// # Panics
    ///
    /// Panics when `chunk_size` is zero.
    fn new(chunk_size: usize, emit_available: bool, fail_on_eof: bool) -> Self {
        if chunk_size == 0 {
            panic!("chunk size must be greater than zero");
        }
        Self {
            chunk_size,
            emit_available,
            fail_on_eof,
        }
    }
}
66
/// Counts selected StreamRefs protocol messages successfully written by a
/// carrier endpoint.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefProtocolMessageCounts {
    /// Number of `CumulativeDemand` messages.
    pub cumulative_demand: u64,
    /// Number of `SequencedOnNext` messages, including those sent in batches.
    pub sequenced_on_next: u64,
    /// Number of `Ack` messages.
    pub ack: u64,
}

/// Shared collector for StreamRefs protocol message-count diagnostics.
///
/// The carrier records counts only after the encoded frame has been written
/// successfully. This is intentionally transport-level instrumentation; it does
/// not change the protobuf protocol or wire format.
///
/// Clones share the same underlying counters.
#[derive(Clone, Default)]
pub struct StreamRefProtocolDiagnostics {
    counts: Arc<Mutex<StreamRefProtocolMessageCounts>>,
}

impl StreamRefProtocolDiagnostics {
    /// Creates a collector with every counter at zero.
    #[must_use]
    pub fn new() -> Self {
        Self {
            counts: Arc::default(),
        }
    }

    /// Returns a copy of the counters as they are right now.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn snapshot(&self) -> StreamRefProtocolMessageCounts {
        let totals = self
            .counts
            .lock()
            .expect("stream ref protocol diagnostics poisoned");
        *totals
    }

    /// Adds `delta` to the running totals, saturating at `u64::MAX`.
    ///
    /// An all-zero `delta` returns before the lock is taken.
    fn record_counts(&self, delta: StreamRefProtocolMessageCounts) {
        let StreamRefProtocolMessageCounts {
            cumulative_demand,
            sequenced_on_next,
            ack,
        } = delta;
        if cumulative_demand == 0 && sequenced_on_next == 0 && ack == 0 {
            return;
        }
        let mut totals = self
            .counts
            .lock()
            .expect("stream ref protocol diagnostics poisoned");
        totals.cumulative_demand = totals.cumulative_demand.saturating_add(cumulative_demand);
        totals.sequenced_on_next = totals.sequenced_on_next.saturating_add(sequenced_on_next);
        totals.ack = totals.ack.saturating_add(ack);
    }
}
117
118fn outbound_counts(outbound: &StreamRefOutbound) -> StreamRefProtocolMessageCounts {
119    let mut counts = StreamRefProtocolMessageCounts::default();
120    match outbound {
121        StreamRefOutbound::Frame(frame) => match &frame.message {
122            StreamRefMessage::CumulativeDemand { .. } => {
123                counts.cumulative_demand = 1;
124            }
125            StreamRefMessage::SequencedOnNext { .. } => {
126                counts.sequenced_on_next = 1;
127            }
128            StreamRefMessage::Ack => {
129                counts.ack = 1;
130            }
131            StreamRefMessage::OnSubscribeHandshake
132            | StreamRefMessage::RemoteStreamCompleted { .. }
133            | StreamRefMessage::RemoteStreamFailure { .. } => {}
134        },
135        StreamRefOutbound::SequencedBatch(batch) => {
136            counts.sequenced_on_next = batch.count() as u64;
137        }
138    }
139    counts
140}
141
/// Diagnostics entry for one encoded outbound item that has been queued in a
/// write buffer but not yet fully written.
#[derive(Clone, Copy)]
struct PendingDiagnostic {
    // Bytes of the encoded item still waiting to be written; starts at the
    // encoded length and shrinks as writes complete.
    remaining: usize,
    // Counts to hand to the diagnostics collector once `remaining` is used up.
    counts: StreamRefProtocolMessageCounts,
}
147
148/// Completion handle for a StreamRefs-over-QUIC carrier.
149#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
150pub struct StreamRefQuicHandle {
151    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
152}
153
154impl StreamRefQuicHandle {
155    pub fn wait(self) -> StreamResult<NotUsed> {
156        self.receiver
157            .recv()
158            .unwrap_or(Err(StreamError::AbruptTermination))
159    }
160
161    #[must_use]
162    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
163        self.receiver.try_recv().ok()
164    }
165}
166
/// Local TCP listener binding used by StreamRefs-over-TCP producer endpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefTcpBinding {
    local_addr: SocketAddr,
}

impl StreamRefTcpBinding {
    /// Returns the socket address stored in this binding.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = self;
        *local_addr
    }
}
179
180/// Completion handle for a StreamRefs-over-TCP carrier.
181#[must_use = "wait for the TCP StreamRefs carrier to observe completion or failure"]
182pub struct StreamRefTcpHandle {
183    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
184}
185
186impl StreamRefTcpHandle {
187    pub fn wait(self) -> StreamResult<NotUsed> {
188        self.receiver
189            .recv()
190            .unwrap_or(Err(StreamError::AbruptTermination))
191    }
192
193    #[must_use]
194    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
195        self.receiver.try_recv().ok()
196    }
197}
198
199/// Serves a local `SourceRef` over an accepted or opened QUIC bidi stream.
200pub fn serve_source_ref_over_quic<T>(
201    stream: QuicBidirectionalStream,
202    source_ref: SourceRef<T>,
203    stream_ref_id: StreamRefId,
204    settings: StreamRefSettings,
205) -> StreamResult<StreamRefQuicHandle>
206where
207    T: StreamRefPayload,
208{
209    let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
210    Ok(drive_stream_ref_endpoint_over_quic(stream, producer, None))
211}
212
213/// Serves a local `Source` over an accepted or opened QUIC bidi stream.
214pub fn serve_source_over_quic<T, Mat>(
215    stream: QuicBidirectionalStream,
216    source: Source<T, Mat>,
217    stream_ref_id: StreamRefId,
218    settings: StreamRefSettings,
219) -> StreamResult<StreamRefQuicHandle>
220where
221    T: StreamRefPayload,
222    Mat: Send + 'static,
223{
224    let producer = StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?;
225    Ok(drive_stream_ref_endpoint_over_quic(stream, producer, None))
226}
227
228/// Creates a local source fed by a remote QUIC StreamRef producer.
229pub fn source_ref_over_quic<T>(
230    stream: QuicBidirectionalStream,
231    stream_ref_id: StreamRefId,
232    settings: StreamRefSettings,
233) -> (Source<T, NotUsed>, StreamRefQuicHandle)
234where
235    T: StreamRefPayload,
236{
237    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
238    let source = consumer.source();
239    let handle = drive_stream_ref_endpoint_over_quic(stream, consumer, None);
240    (source, handle)
241}
242
243/// Serves a local `SinkRef` receiver over an accepted or opened QUIC bidi
244/// stream, returning a [`Source`] of inbound elements.
245///
246/// This is the local/receiver side of the SinkRef-over-QUIC pair: the remote
247/// sender pushes elements into a [`sink_ref_over_quic`](fn.sink_ref_over_quic)
248/// `Sink`, and this side surfaces them as a `Source`. The caller runs the
249/// returned source into a local `Sink` (for example `Sink::collect` or a fold).
250pub fn serve_sink_ref_over_quic<T>(
251    stream: QuicBidirectionalStream,
252    stream_ref_id: StreamRefId,
253    settings: StreamRefSettings,
254) -> (Source<T, NotUsed>, StreamRefQuicHandle)
255where
256    T: StreamRefPayload,
257{
258    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
259    let source = consumer.source();
260    let handle = drive_stream_ref_endpoint_over_quic(stream, consumer, None);
261    (source, handle)
262}
263
264/// Creates a local `Sink` that sends its incoming elements over QUIC to a
265/// remote `SinkRef` receiver.
266pub fn sink_ref_over_quic<T>(
267    stream: QuicBidirectionalStream,
268    stream_ref_id: StreamRefId,
269    settings: StreamRefSettings,
270) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
271where
272    T: StreamRefPayload,
273{
274    let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
275    let sink = producer.sink();
276    let handle = drive_stream_ref_endpoint_over_quic(stream, producer, None);
277    (sink, handle)
278}
279
280/// Serves a local `SourceRef` over a one-shot plaintext TCP listener.
281///
282/// The listener binds immediately, accepts one connection, then runs the same
283/// StreamRefs protocol used by the QUIC carrier. The remote receiver should use
284/// [`source_ref_over_tcp`] to open the TCP connection and send the initial
285/// subscribe+demand frames.
286pub fn serve_source_ref_over_tcp<T, A>(
287    addr: A,
288    source_ref: SourceRef<T>,
289    stream_ref_id: StreamRefId,
290    settings: StreamRefSettings,
291) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
292where
293    T: StreamRefPayload,
294    A: ToSocketAddrs + Send + 'static,
295{
296    serve_source_ref_over_tcp_with_diagnostics(addr, source_ref, stream_ref_id, settings, None)
297}
298
299pub fn serve_source_ref_over_tcp_with_diagnostics<T, A>(
300    addr: A,
301    source_ref: SourceRef<T>,
302    stream_ref_id: StreamRefId,
303    settings: StreamRefSettings,
304    diagnostics: Option<StreamRefProtocolDiagnostics>,
305) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
306where
307    T: StreamRefPayload,
308    A: ToSocketAddrs + Send + 'static,
309{
310    let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
311    let (listener, binding, handle) = bind_tcp_listener(addr)?;
312    Ok((
313        binding,
314        drive_stream_ref_endpoint_over_tcp_listener(listener, handle, producer, diagnostics),
315    ))
316}
317
318/// Serves a local `SourceRef` over an already-connected Tokio TCP stream.
319///
320/// This is the stream-shaped counterpart to [`serve_source_ref_over_tcp`],
321/// intended for callers that own listener lifecycle separately, such as a
322/// benchmark server that reuses one bound listener across operations.
323pub fn serve_source_ref_over_tcp_stream<T>(
324    stream: TcpStream,
325    source_ref: SourceRef<T>,
326    stream_ref_id: StreamRefId,
327    settings: StreamRefSettings,
328) -> StreamResult<StreamRefTcpHandle>
329where
330    T: StreamRefPayload,
331{
332    serve_source_ref_over_tcp_stream_with_diagnostics(
333        stream,
334        source_ref,
335        stream_ref_id,
336        settings,
337        None,
338    )
339}
340
341pub fn serve_source_ref_over_tcp_stream_with_diagnostics<T>(
342    stream: TcpStream,
343    source_ref: SourceRef<T>,
344    stream_ref_id: StreamRefId,
345    settings: StreamRefSettings,
346    diagnostics: Option<StreamRefProtocolDiagnostics>,
347) -> StreamResult<StreamRefTcpHandle>
348where
349    T: StreamRefPayload,
350{
351    let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
352    let handle = current_tokio_handle()?;
353    Ok(drive_stream_ref_endpoint_over_tcp_stream(
354        stream,
355        handle,
356        producer,
357        diagnostics,
358    ))
359}
360
361/// Creates a local source fed by a remote plaintext TCP StreamRef producer.
362///
363/// This receiver side opens the TCP connection and, once the returned source is
364/// materialized, sends the initial handshake and cumulative demand.
365pub fn source_ref_over_tcp<T, A>(
366    addr: A,
367    stream_ref_id: StreamRefId,
368    settings: StreamRefSettings,
369) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
370where
371    T: StreamRefPayload,
372    A: ToSocketAddrs + Send + 'static,
373{
374    source_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
375}
376
377pub fn source_ref_over_tcp_with_diagnostics<T, A>(
378    addr: A,
379    stream_ref_id: StreamRefId,
380    settings: StreamRefSettings,
381    diagnostics: Option<StreamRefProtocolDiagnostics>,
382) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
383where
384    T: StreamRefPayload,
385    A: ToSocketAddrs + Send + 'static,
386{
387    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
388    let source = consumer.source();
389    let (stream, handle) = connect_tcp_stream(addr)?;
390    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
391    Ok((source, handle))
392}
393
394/// Creates a local source fed by a remote producer over an already-connected
395/// Tokio TCP stream.
396pub fn source_ref_over_tcp_stream<T>(
397    stream: TcpStream,
398    stream_ref_id: StreamRefId,
399    settings: StreamRefSettings,
400) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
401where
402    T: StreamRefPayload,
403{
404    source_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
405}
406
407pub fn source_ref_over_tcp_stream_with_diagnostics<T>(
408    stream: TcpStream,
409    stream_ref_id: StreamRefId,
410    settings: StreamRefSettings,
411    diagnostics: Option<StreamRefProtocolDiagnostics>,
412) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
413where
414    T: StreamRefPayload,
415{
416    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
417    let source = consumer.source();
418    let handle = current_tokio_handle()?;
419    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
420    Ok((source, handle))
421}
422
423/// Serves a remote `SinkRef` receiver over plaintext TCP.
424///
425/// This receiver side opens the TCP connection to a sender created with
426/// [`sink_ref_over_tcp`] and, once the returned source is materialized, sends
427/// the initial handshake and cumulative demand.
428pub fn serve_sink_ref_over_tcp<T, A>(
429    addr: A,
430    stream_ref_id: StreamRefId,
431    settings: StreamRefSettings,
432) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
433where
434    T: StreamRefPayload,
435    A: ToSocketAddrs + Send + 'static,
436{
437    serve_sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
438}
439
440pub fn serve_sink_ref_over_tcp_with_diagnostics<T, A>(
441    addr: A,
442    stream_ref_id: StreamRefId,
443    settings: StreamRefSettings,
444    diagnostics: Option<StreamRefProtocolDiagnostics>,
445) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
446where
447    T: StreamRefPayload,
448    A: ToSocketAddrs + Send + 'static,
449{
450    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
451    let source = consumer.source();
452    let (stream, handle) = connect_tcp_stream(addr)?;
453    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
454    Ok((source, handle))
455}
456
457/// Serves a remote `SinkRef` receiver over an already-connected Tokio TCP
458/// stream.
459pub fn serve_sink_ref_over_tcp_stream<T>(
460    stream: TcpStream,
461    stream_ref_id: StreamRefId,
462    settings: StreamRefSettings,
463) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
464where
465    T: StreamRefPayload,
466{
467    serve_sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
468}
469
470pub fn serve_sink_ref_over_tcp_stream_with_diagnostics<T>(
471    stream: TcpStream,
472    stream_ref_id: StreamRefId,
473    settings: StreamRefSettings,
474    diagnostics: Option<StreamRefProtocolDiagnostics>,
475) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
476where
477    T: StreamRefPayload,
478{
479    let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
480    let source = consumer.source();
481    let handle = current_tokio_handle()?;
482    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
483    Ok((source, handle))
484}
485
486/// Creates a local `Sink` whose incoming elements are sent over a one-shot
487/// plaintext TCP listener to a remote `SinkRef` receiver.
488///
489/// The producer/listener side waits for the receiver to open the TCP
490/// connection. This mirrors the QUIC SinkRef direction where the receiver
491/// opens the bidi stream so its handshake+demand establish the stream before
492/// the producer has elements to send.
493pub fn sink_ref_over_tcp<T, A>(
494    addr: A,
495    stream_ref_id: StreamRefId,
496    settings: StreamRefSettings,
497) -> StreamResult<(
498    Sink<T, StreamCompletion<NotUsed>>,
499    StreamRefTcpBinding,
500    StreamRefTcpHandle,
501)>
502where
503    T: StreamRefPayload,
504    A: ToSocketAddrs + Send + 'static,
505{
506    sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
507}
508
509pub fn sink_ref_over_tcp_with_diagnostics<T, A>(
510    addr: A,
511    stream_ref_id: StreamRefId,
512    settings: StreamRefSettings,
513    diagnostics: Option<StreamRefProtocolDiagnostics>,
514) -> StreamResult<(
515    Sink<T, StreamCompletion<NotUsed>>,
516    StreamRefTcpBinding,
517    StreamRefTcpHandle,
518)>
519where
520    T: StreamRefPayload,
521    A: ToSocketAddrs + Send + 'static,
522{
523    let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
524    let sink = producer.sink();
525    let (listener, binding, handle) = bind_tcp_listener(addr)?;
526    let handle =
527        drive_stream_ref_endpoint_over_tcp_listener(listener, handle, producer, diagnostics);
528    Ok((sink, binding, handle))
529}
530
531/// Creates a local `Sink` whose incoming elements are sent over an
532/// already-connected Tokio TCP stream to a remote `SinkRef` receiver.
533pub fn sink_ref_over_tcp_stream<T>(
534    stream: TcpStream,
535    stream_ref_id: StreamRefId,
536    settings: StreamRefSettings,
537) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
538where
539    T: StreamRefPayload,
540{
541    sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
542}
543
544pub fn sink_ref_over_tcp_stream_with_diagnostics<T>(
545    stream: TcpStream,
546    stream_ref_id: StreamRefId,
547    settings: StreamRefSettings,
548    diagnostics: Option<StreamRefProtocolDiagnostics>,
549) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
550where
551    T: StreamRefPayload,
552{
553    let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
554    let sink = producer.sink();
555    let handle = current_tokio_handle()?;
556    let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, producer, diagnostics);
557    Ok((sink, handle))
558}
559
560fn drive_stream_ref_endpoint_over_quic<E>(
561    stream: QuicBidirectionalStream,
562    endpoint: E,
563    diagnostics: Option<StreamRefProtocolDiagnostics>,
564) -> StreamRefQuicHandle
565where
566    E: StreamRefProtoEndpointWake,
567{
568    let (reader, writer, handle, chunk_size, emit_available) = stream.into_stream_ref_parts();
569    let read_mode = CarrierReadMode::new(chunk_size, emit_available, false);
570    StreamRefQuicHandle {
571        receiver: spawn_endpoint_task(&handle, async move {
572            run_stream_ref_endpoint_quic_task(reader, writer, endpoint, read_mode, diagnostics)
573                .await
574        }),
575    }
576}
577
578fn drive_stream_ref_endpoint_over_tcp_listener<E>(
579    listener: TcpListener,
580    handle: Handle,
581    endpoint: E,
582    diagnostics: Option<StreamRefProtocolDiagnostics>,
583) -> StreamRefTcpHandle
584where
585    E: StreamRefProtoEndpointWake,
586{
587    StreamRefTcpHandle {
588        receiver: spawn_endpoint_task(&handle, async move {
589            let (stream, _) = listener.accept().await.map_err(io_error)?;
590            run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
591        }),
592    }
593}
594
595fn drive_stream_ref_endpoint_over_tcp_stream<E>(
596    stream: TcpStream,
597    handle: Handle,
598    endpoint: E,
599    diagnostics: Option<StreamRefProtocolDiagnostics>,
600) -> StreamRefTcpHandle
601where
602    E: StreamRefProtoEndpointWake,
603{
604    StreamRefTcpHandle {
605        receiver: spawn_endpoint_task(&handle, async move {
606            run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
607        }),
608    }
609}
610
611async fn run_stream_ref_endpoint_quic_task<R, W, E>(
612    reader: R,
613    writer: W,
614    endpoint: E,
615    read_mode: CarrierReadMode,
616    diagnostics: Option<StreamRefProtocolDiagnostics>,
617) -> StreamResult<NotUsed>
618where
619    R: AsyncRead + Unpin + Send + 'static,
620    W: AsyncWrite + Unpin + Send + 'static,
621    E: StreamRefProtoEndpointWake,
622{
623    let (wake_sender, wake_receiver) = tokio_mpsc::channel(1);
624    endpoint.install_outbound_wake(wake_sender.clone());
625    let _ = wake_sender.try_send(());
626
627    let result = QuicEndpointTask {
628        reader,
629        writer,
630        endpoint: endpoint.clone(),
631        diagnostics,
632        read_mode,
633        decoder: FrameDecoder::default(),
634        read_buffer: vec![
635            0_u8;
636            read_mode
637                .chunk_size
638                .clamp(1, STREAM_REF_QUIC_READ_BUFFER_BYTES)
639        ],
640        pending_tail: Vec::new(),
641        write_buffer: BytesMut::new(),
642        encode_buffer: Vec::new(),
643        pending_diagnostics: VecDeque::new(),
644        read_closed: false,
645        inbound_seen: false,
646        outbound_written: false,
647        recheck_outbound: false,
648        outbound_closed: false,
649        write_shutdown: false,
650        wake_receiver,
651    }
652    .run()
653    .await;
654
655    endpoint.clear_outbound_wake();
656    if let Err(error) = &result {
657        endpoint.fail_connection(error.clone());
658    }
659    result
660}
661
/// State of one running QUIC carrier: the two stream halves, the protocol
/// endpoint, the buffers in between and the flags that steer the event loop.
struct QuicEndpointTask<R, W, E>
where
    E: StreamRefProtoEndpointWake,
{
    // Read and write halves of the carrier stream.
    reader: R,
    writer: W,
    // Protocol endpoint: polled for outbound items and fed inbound bytes.
    endpoint: E,
    // Optional collector for message counts of fully written frames.
    diagnostics: Option<StreamRefProtocolDiagnostics>,
    read_mode: CarrierReadMode,
    // Frame decoder handed to the inbound feed helpers.
    decoder: FrameDecoder,
    // Scratch buffer each read lands in.
    read_buffer: Vec<u8>,
    // Inbound bytes held back by the feed helper; flushed to the decoder at EOF.
    pending_tail: Vec<u8>,
    // Encoded outbound bytes not yet accepted by `writer`.
    write_buffer: BytesMut,
    // Scratch buffer for encoding one outbound item.
    encode_buffer: Vec<u8>,
    // One entry per queued outbound item, in write order (only filled when
    // `diagnostics` is set).
    pending_diagnostics: VecDeque<PendingDiagnostic>,
    // Set once the reader returned EOF.
    read_closed: bool,
    // Set once inbound bytes have been fed successfully.
    inbound_seen: bool,
    // Set once any outbound item has been queued for writing.
    outbound_written: bool,
    // Arms the timed re-poll of the endpoint after it reported `Pending`.
    recheck_outbound: bool,
    // Set once the endpoint reported `Closed` for its outbound side.
    outbound_closed: bool,
    // Set once `writer.shutdown()` has completed.
    write_shutdown: bool,
    // Receives wake-ups signalling that the endpoint may have outbound work.
    wake_receiver: tokio_mpsc::Receiver<()>,
}
685
686impl<R, W, E> QuicEndpointTask<R, W, E>
687where
688    R: AsyncRead + Unpin,
689    W: AsyncWrite + Unpin,
690    E: StreamRefProtoEndpointWake,
691{
692    async fn run(mut self) -> StreamResult<NotUsed> {
693        loop {
694            if self.read_closed && self.write_shutdown {
695                return Ok(NotUsed);
696            }
697
698            self.drain_outbound()?;
699            if self.outbound_closed && self.write_buffer.is_empty() && !self.write_shutdown {
700                self.writer.shutdown().await.map_err(io_error)?;
701                self.write_shutdown = true;
702                continue;
703            }
704            tokio::select! {
705                biased;
706                wake = self.wake_receiver.recv(), if !self.outbound_closed => {
707                    if wake.is_none() {
708                        self.drain_outbound()?;
709                    }
710                }
711                _ = tokio::time::sleep(STREAM_REF_OUTBOUND_RECHECK_INTERVAL), if self.recheck_outbound && !self.outbound_closed && self.write_buffer.is_empty() => {
712                    self.drain_outbound()?;
713                }
714                written = self.writer.write(&self.write_buffer), if !self.write_buffer.is_empty() => {
715                    self.handle_written(written)?;
716                }
717                read = self.reader.read(&mut self.read_buffer), if !self.read_closed => {
718                    match read {
719                        Ok(0) => self.handle_eof()?,
720                        Ok(read) => {
721                            self.feed_read_buffer(read)?;
722                            self.drain_outbound()?;
723                        }
724                        Err(error) => {
725                            let error = io_error(error);
726                            if self.write_shutdown && is_quic_teardown_loss(&error) {
727                                return Ok(NotUsed);
728                            }
729                            return Err(error);
730                        }
731                    }
732                }
733            }
734        }
735    }
736
737    fn drain_outbound(&mut self) -> StreamResult<()> {
738        self.recheck_outbound = false;
739        while !self.outbound_closed && self.write_buffer.len() < MAX_STREAM_REF_FRAME_BYTES {
740            match self
741                .endpoint
742                .try_next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
743            {
744                StreamRefOutboundPoll::Ready(Ok(outbound)) => {
745                    encode_carrier_outbound_into(&outbound, &mut self.encode_buffer)?;
746                    let encoded_len = self.encode_buffer.len();
747                    if encoded_len == 0 {
748                        continue;
749                    }
750                    if self.diagnostics.is_some() {
751                        self.pending_diagnostics.push_back(PendingDiagnostic {
752                            remaining: encoded_len,
753                            counts: outbound_counts(&outbound),
754                        });
755                    }
756                    self.outbound_written = true;
757                    self.write_buffer.extend_from_slice(&self.encode_buffer);
758                }
759                StreamRefOutboundPoll::Ready(Err(error)) => return Err(error),
760                StreamRefOutboundPoll::Pending => {
761                    self.recheck_outbound =
762                        !self.outbound_written || self.inbound_seen || self.read_closed;
763                    break;
764                }
765                StreamRefOutboundPoll::Closed => {
766                    self.outbound_closed = true;
767                    break;
768                }
769            }
770        }
771        Ok(())
772    }
773
774    fn handle_written(&mut self, written: Result<usize, std::io::Error>) -> StreamResult<()> {
775        match written {
776            Ok(0) => Err(StreamError::Failed(
777                "StreamRefs QUIC stream accepted zero write bytes".to_owned(),
778            )),
779            Ok(written) => {
780                self.write_buffer.advance(written);
781                self.record_written_bytes(written);
782                Ok(())
783            }
784            Err(error) => Err(io_error(error)),
785        }
786    }
787
788    fn record_written_bytes(&mut self, mut written: usize) {
789        let Some(diagnostics) = &self.diagnostics else {
790            return;
791        };
792        while written > 0 {
793            let Some(front) = self.pending_diagnostics.front_mut() else {
794                return;
795            };
796            if written < front.remaining {
797                front.remaining -= written;
798                return;
799            }
800            written -= front.remaining;
801            let counts = front.counts;
802            self.pending_diagnostics.pop_front();
803            diagnostics.record_counts(counts);
804        }
805    }
806
807    fn feed_read_buffer(&mut self, read: usize) -> StreamResult<()> {
808        feed_read_bytes(
809            &mut self.decoder,
810            &self.endpoint,
811            self.read_mode,
812            &mut self.pending_tail,
813            &self.read_buffer[..read],
814        )?;
815        self.inbound_seen = true;
816        Ok(())
817    }
818
819    fn handle_eof(&mut self) -> StreamResult<()> {
820        if !self.pending_tail.is_empty() {
821            feed_inbound_chunk(&mut self.decoder, &self.endpoint, &self.pending_tail)?;
822            self.pending_tail.clear();
823        }
824        if self.read_mode.fail_on_eof {
825            self.endpoint
826                .fail_connection(StreamError::AbruptTermination);
827        }
828        self.read_closed = true;
829        self.recheck_outbound = true;
830        Ok(())
831    }
832}
833
834async fn run_stream_ref_endpoint_tcp_task<E>(
835    stream: TcpStream,
836    endpoint: E,
837    diagnostics: Option<StreamRefProtocolDiagnostics>,
838) -> StreamResult<NotUsed>
839where
840    E: StreamRefProtoEndpointWake,
841{
842    let (wake_sender, wake_receiver) = tokio_mpsc::channel(1);
843    endpoint.install_outbound_wake(wake_sender.clone());
844    let _ = wake_sender.try_send(());
845
846    let result = TcpEndpointTask {
847        stream,
848        endpoint: endpoint.clone(),
849        diagnostics,
850        read_mode: CarrierReadMode::new(STREAM_REF_TCP_CHUNK_SIZE, true, true),
851        decoder: FrameDecoder::default(),
852        read_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
853        pending_tail: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
854        write_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
855        encode_buffer: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
856        pending_diagnostics: VecDeque::new(),
857        outbound_closed: false,
858        write_shutdown: false,
859        wake_receiver,
860    }
861    .run()
862    .await;
863
864    endpoint.clear_outbound_wake();
865    if let Err(error) = &result {
866        endpoint.fail_connection(error.clone());
867    }
868    result
869}
870
/// State of one running TCP carrier: the socket, the protocol endpoint, the
/// buffers in between and the flags that steer the event loop.
struct TcpEndpointTask<E>
where
    E: StreamRefProtoEndpointWake,
{
    // Connected socket used for both directions.
    stream: TcpStream,
    // Protocol endpoint: polled for outbound items.
    // NOTE(review): presumably also fed inbound bytes, as in the QUIC task; confirm.
    endpoint: E,
    // Optional collector for protocol message counts.
    diagnostics: Option<StreamRefProtocolDiagnostics>,
    read_mode: CarrierReadMode,
    decoder: FrameDecoder,
    // Inbound-side buffers.
    // NOTE(review): how the read path uses these is not visible here; confirm.
    read_buffer: BytesMut,
    pending_tail: Vec<u8>,
    // Encoded outbound bytes waiting to be flushed to the socket.
    write_buffer: BytesMut,
    // Scratch buffer for encoding one outbound item.
    encode_buffer: Vec<u8>,
    // One entry per queued outbound item (only filled when `diagnostics` is set).
    pending_diagnostics: VecDeque<PendingDiagnostic>,
    // Loop flags for the outbound direction; both start out false.
    outbound_closed: bool,
    write_shutdown: bool,
    // Receives wake-ups signalling that the endpoint may have outbound work.
    wake_receiver: tokio_mpsc::Receiver<()>,
}
889
890impl<E> TcpEndpointTask<E>
891where
892    E: StreamRefProtoEndpointWake,
893{
894    async fn run(mut self) -> StreamResult<NotUsed> {
895        self.stream.set_nodelay(true).map_err(io_error)?;
896        loop {
897            self.drain_outbound()?;
898            if !self.write_buffer.is_empty() || (self.outbound_closed && !self.write_shutdown) {
899                self.flush_write_buffer().await?;
900            }
901
902            tokio::select! {
903                biased;
904                wake = self.wake_receiver.recv() => {
905                    if wake.is_none() && !self.outbound_closed {
906                        self.drain_outbound()?;
907                    }
908                }
909                ready = self.stream.readable() => {
910                    ready.map_err(io_error)?;
911                    if self.read_available()? {
912                        return Ok(NotUsed);
913                    }
914                }
915                ready = self.stream.writable(), if !self.write_buffer.is_empty() || (self.outbound_closed && !self.write_shutdown) => {
916                    ready.map_err(io_error)?;
917                    self.flush_ready_write_buffer()?;
918                }
919            }
920        }
921    }
922
923    fn drain_outbound(&mut self) -> StreamResult<()> {
924        while !self.outbound_closed && self.write_buffer.len() < MAX_STREAM_REF_FRAME_BYTES {
925            match self
926                .endpoint
927                .try_next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
928            {
929                StreamRefOutboundPoll::Ready(Ok(outbound)) => {
930                    encode_carrier_outbound_into(&outbound, &mut self.encode_buffer)?;
931                    let encoded_len = self.encode_buffer.len();
932                    if encoded_len == 0 {
933                        continue;
934                    }
935                    if self.diagnostics.is_some() {
936                        self.pending_diagnostics.push_back(PendingDiagnostic {
937                            remaining: encoded_len,
938                            counts: outbound_counts(&outbound),
939                        });
940                    }
941                    self.write_buffer.extend_from_slice(&self.encode_buffer);
942                }
943                StreamRefOutboundPoll::Ready(Err(error)) => return Err(error),
944                StreamRefOutboundPoll::Pending => break,
945                StreamRefOutboundPoll::Closed => {
946                    self.outbound_closed = true;
947                    break;
948                }
949            }
950        }
951        Ok(())
952    }
953
954    async fn flush_write_buffer(&mut self) -> StreamResult<()> {
955        if !self.write_buffer.is_empty() {
956            self.stream.writable().await.map_err(io_error)?;
957            self.flush_ready_write_buffer()?;
958        }
959        if self.outbound_closed && self.write_buffer.is_empty() && !self.write_shutdown {
960            self.stream.shutdown().await.map_err(io_error)?;
961            self.write_shutdown = true;
962        }
963        Ok(())
964    }
965
966    fn flush_ready_write_buffer(&mut self) -> StreamResult<()> {
967        while !self.write_buffer.is_empty() {
968            match self.stream.try_write(&self.write_buffer) {
969                Ok(0) => {
970                    return Err(StreamError::Failed(
971                        "StreamRefs TCP socket accepted zero write bytes".to_owned(),
972                    ));
973                }
974                Ok(written) => {
975                    self.write_buffer.advance(written);
976                    self.record_written_bytes(written);
977                }
978                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
979                Err(error) => return Err(io_error(error)),
980            }
981        }
982
983        Ok(())
984    }
985
986    fn record_written_bytes(&mut self, mut written: usize) {
987        let Some(diagnostics) = &self.diagnostics else {
988            return;
989        };
990        while written > 0 {
991            let Some(front) = self.pending_diagnostics.front_mut() else {
992                return;
993            };
994            if written < front.remaining {
995                front.remaining -= written;
996                return;
997            }
998            written -= front.remaining;
999            let counts = front.counts;
1000            self.pending_diagnostics.pop_front();
1001            diagnostics.record_counts(counts);
1002        }
1003    }
1004
1005    fn read_available(&mut self) -> StreamResult<bool> {
1006        loop {
1007            self.read_buffer.reserve(self.read_mode.chunk_size);
1008            match self.stream.try_read_buf(&mut self.read_buffer) {
1009                Ok(0) => return self.handle_eof(),
1010                Ok(_) => {
1011                    feed_read_bytes(
1012                        &mut self.decoder,
1013                        &self.endpoint,
1014                        self.read_mode,
1015                        &mut self.pending_tail,
1016                        &self.read_buffer,
1017                    )?;
1018                    self.read_buffer.clear();
1019                    self.drain_outbound()?;
1020                }
1021                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(false),
1022                Err(error) => return Err(io_error(error)),
1023            }
1024        }
1025    }
1026
1027    fn handle_eof(&mut self) -> StreamResult<bool> {
1028        if !self.pending_tail.is_empty() {
1029            feed_inbound_chunk(&mut self.decoder, &self.endpoint, &self.pending_tail)?;
1030            self.pending_tail.clear();
1031        }
1032        if self.read_mode.fail_on_eof {
1033            self.endpoint
1034                .fail_connection(StreamError::AbruptTermination);
1035        }
1036        Ok(true)
1037    }
1038}
1039
1040fn feed_read_bytes<E>(
1041    decoder: &mut FrameDecoder,
1042    endpoint: &E,
1043    read_mode: CarrierReadMode,
1044    pending_tail: &mut Vec<u8>,
1045    read_buffer: &[u8],
1046) -> StreamResult<()>
1047where
1048    E: StreamRefProtoEndpoint,
1049{
1050    if read_mode.emit_available {
1051        if !pending_tail.is_empty() {
1052            pending_tail.extend_from_slice(read_buffer);
1053            feed_inbound_chunk(decoder, endpoint, pending_tail)?;
1054            pending_tail.clear();
1055            return Ok(());
1056        }
1057        return feed_inbound_chunk(decoder, endpoint, read_buffer);
1058    }
1059
1060    let mut offset = 0;
1061    if !pending_tail.is_empty() {
1062        let needed = read_mode.chunk_size - pending_tail.len();
1063        let take = needed.min(read_buffer.len());
1064        pending_tail.extend_from_slice(&read_buffer[..take]);
1065        offset += take;
1066        if pending_tail.len() == read_mode.chunk_size {
1067            feed_inbound_chunk(decoder, endpoint, pending_tail)?;
1068            pending_tail.clear();
1069        }
1070    }
1071
1072    while offset + read_mode.chunk_size <= read_buffer.len() {
1073        let next = offset + read_mode.chunk_size;
1074        feed_inbound_chunk(decoder, endpoint, &read_buffer[offset..next])?;
1075        offset = next;
1076    }
1077
1078    if offset < read_buffer.len() {
1079        pending_tail.extend_from_slice(&read_buffer[offset..]);
1080    }
1081    Ok(())
1082}
1083
1084fn feed_inbound_chunk<E>(decoder: &mut FrameDecoder, endpoint: &E, chunk: &[u8]) -> StreamResult<()>
1085where
1086    E: StreamRefProtoEndpoint,
1087{
1088    decoder.push_chunk(chunk, endpoint)
1089}
1090
1091fn bind_tcp_listener<A>(addr: A) -> StreamResult<(TcpListener, StreamRefTcpBinding, Handle)>
1092where
1093    A: ToSocketAddrs + Send + 'static,
1094{
1095    let runtime = stream_ref_tcp_runtime()?;
1096    let listener = runtime
1097        .block_on(async { TcpListener::bind(addr).await })
1098        .map_err(io_error)?;
1099    let local_addr = listener.local_addr().map_err(io_error)?;
1100    Ok((
1101        listener,
1102        StreamRefTcpBinding { local_addr },
1103        runtime.handle().clone(),
1104    ))
1105}
1106
1107fn connect_tcp_stream<A>(addr: A) -> StreamResult<(TcpStream, Handle)>
1108where
1109    A: ToSocketAddrs + Send + 'static,
1110{
1111    let runtime = stream_ref_tcp_runtime()?;
1112    let stream = runtime
1113        .block_on(async { TcpStream::connect(addr).await })
1114        .map_err(io_error)?;
1115    stream.set_nodelay(true).map_err(io_error)?;
1116    Ok((stream, runtime.handle().clone()))
1117}
1118
1119fn stream_ref_tcp_runtime() -> StreamResult<&'static Runtime> {
1120    static RUNTIME: OnceLock<Result<Runtime, String>> = OnceLock::new();
1121    match RUNTIME.get_or_init(|| {
1122        tokio::runtime::Builder::new_multi_thread()
1123            .thread_name("datum-streamref-tcp")
1124            .enable_all()
1125            .build()
1126            .map_err(|error| error.to_string())
1127    }) {
1128        Ok(runtime) => Ok(runtime),
1129        Err(error) => Err(StreamError::Failed(format!(
1130            "failed to start StreamRefs TCP runtime: {error}"
1131        ))),
1132    }
1133}
1134
1135fn current_tokio_handle() -> StreamResult<Handle> {
1136    Handle::try_current().map_err(|error| {
1137        StreamError::Failed(format!(
1138            "StreamRefs TCP stream helper requires a current Tokio runtime: {error}"
1139        ))
1140    })
1141}
1142
1143fn io_error(error: std::io::Error) -> StreamError {
1144    StreamError::Failed(error.to_string())
1145}
1146
1147fn is_quic_teardown_loss(error: &StreamError) -> bool {
1148    matches!(error, StreamError::Failed(message) if message == "connection lost")
1149}
1150
1151fn spawn_endpoint_task<F>(handle: &Handle, run: F) -> mpsc::Receiver<StreamResult<NotUsed>>
1152where
1153    F: Future<Output = StreamResult<NotUsed>> + Send + 'static,
1154{
1155    let (sender, receiver) = mpsc::channel();
1156    handle.spawn(async move {
1157        let result = run.await;
1158        let _ = sender.send(result);
1159    });
1160    receiver
1161}
1162
1163fn encode_carrier_outbound_into(
1164    outbound: &StreamRefOutbound,
1165    bytes: &mut Vec<u8>,
1166) -> StreamResult<()> {
1167    bytes.clear();
1168    match outbound {
1169        StreamRefOutbound::Frame(frame) => append_protobuf_carrier_frame(frame, bytes)?,
1170        StreamRefOutbound::SequencedBatch(batch) => {
1171            append_compact_payload_batch(batch, bytes)?;
1172        }
1173    }
1174    Ok(())
1175}
1176
/// Test helper: encodes `frames` into one carrier byte stream.
///
/// Runs of contiguous `SequencedOnNext` frames become compact batches; every
/// other frame is written as a length-prefixed protobuf frame.
#[cfg(test)]
fn encode_carrier_frames(frames: &[StreamRefFrame]) -> StreamResult<Vec<u8>> {
    let mut encoded = Vec::new();
    let mut cursor = 0;
    while let Some(frame) = frames.get(cursor) {
        cursor = if sequenced_on_next(frame).is_some() {
            let run_end = sequenced_run_end(frames, cursor);
            append_compact_sequenced_batches(&frames[cursor..run_end], &mut encoded)?;
            run_end
        } else {
            append_protobuf_carrier_frame(frame, &mut encoded)?;
            cursor + 1
        };
    }
    Ok(encoded)
}
1193
1194fn append_compact_payload_batch(
1195    batch: &StreamRefPayloadBatch,
1196    bytes: &mut Vec<u8>,
1197) -> StreamResult<()> {
1198    let mut start = 0;
1199    while start < batch.count() {
1200        let mut end = start;
1201        let mut payload_len = COMPACT_BATCH_HEADER_BYTES;
1202        while end < batch.count() {
1203            let element_len = COMPACT_BATCH_ELEMENT_LEN_BYTES
1204                .checked_add(batch.payload_len(end))
1205                .ok_or(StreamError::LimitExceeded {
1206                    max: MAX_STREAM_REF_FRAME_BYTES as u64,
1207                })?;
1208            let next_payload_len =
1209                payload_len
1210                    .checked_add(element_len)
1211                    .ok_or(StreamError::LimitExceeded {
1212                        max: MAX_STREAM_REF_FRAME_BYTES as u64,
1213                    })?;
1214            if end > start
1215                && (next_payload_len > MAX_STREAM_REF_FRAME_BYTES
1216                    || end - start >= u16::MAX as usize)
1217            {
1218                break;
1219            }
1220            if next_payload_len > MAX_STREAM_REF_FRAME_BYTES {
1221                return Err(StreamError::LimitExceeded {
1222                    max: MAX_STREAM_REF_FRAME_BYTES as u64,
1223                });
1224            }
1225            payload_len = next_payload_len;
1226            end += 1;
1227        }
1228        append_compact_payload_batch_slice(batch, start, end, payload_len, bytes)?;
1229        start = end;
1230    }
1231    Ok(())
1232}
1233
1234fn append_compact_payload_batch_slice(
1235    batch: &StreamRefPayloadBatch,
1236    start: usize,
1237    end: usize,
1238    payload_len: usize,
1239    bytes: &mut Vec<u8>,
1240) -> StreamResult<()> {
1241    let payload_len = u32::try_from(payload_len).map_err(|_| StreamError::LimitExceeded {
1242        max: MAX_STREAM_REF_FRAME_BYTES as u64,
1243    })?;
1244    let count = u16::try_from(end - start).map_err(|_| StreamError::LimitExceeded {
1245        max: u16::MAX as u64,
1246    })?;
1247    let first_seq = batch
1248        .first_seq_nr()
1249        .checked_add(start as u64)
1250        .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
1251    bytes.extend((COMPACT_FRAME_FLAG | payload_len).to_be_bytes());
1252    bytes.push(COMPACT_FRAME_VERSION);
1253    bytes.push(COMPACT_SEQUENCED_ON_NEXT_BATCH);
1254    bytes.extend(batch.stream_ref_id().to_bytes());
1255    bytes.extend(first_seq.to_be_bytes());
1256    bytes.extend(count.to_be_bytes());
1257    for index in start..end {
1258        let payload = batch.payload(index);
1259        let payload_len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
1260            max: u32::MAX as u64,
1261        })?;
1262        bytes.extend(payload_len.to_be_bytes());
1263        bytes.extend(payload);
1264    }
1265    Ok(())
1266}
1267
1268fn append_protobuf_carrier_frame(frame: &StreamRefFrame, bytes: &mut Vec<u8>) -> StreamResult<()> {
1269    let payload = frame.encode_to_vec();
1270    let len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
1271        max: COMPACT_FRAME_LEN_MASK as u64,
1272    })?;
1273    if payload.len() > MAX_STREAM_REF_FRAME_BYTES || len > COMPACT_FRAME_LEN_MASK {
1274        return Err(StreamError::LimitExceeded {
1275            max: MAX_STREAM_REF_FRAME_BYTES as u64,
1276        });
1277    }
1278    bytes.extend(len.to_be_bytes());
1279    bytes.extend(payload);
1280    Ok(())
1281}
1282
/// Test helper: appends a run of `SequencedOnNext` frames to `bytes` as one or
/// more compact carrier frames.
///
/// A new compact frame is started when the next payload would exceed
/// `MAX_STREAM_REF_FRAME_BYTES` or the current one already holds `u16::MAX`
/// payloads. Panics if any frame is not a `SequencedOnNext` message.
#[cfg(test)]
fn append_compact_sequenced_batches(
    frames: &[StreamRefFrame],
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    fn frame_too_large() -> StreamError {
        StreamError::LimitExceeded {
            max: MAX_STREAM_REF_FRAME_BYTES as u64,
        }
    }

    let total = frames.len();
    let mut slice_start = 0;
    while slice_start < total {
        let mut slice_end = slice_start;
        let mut frame_len = COMPACT_BATCH_HEADER_BYTES;
        while slice_end < total {
            let (_, element) = sequenced_on_next(&frames[slice_end]).expect("sequenced frame");
            let grown_len = COMPACT_BATCH_ELEMENT_LEN_BYTES
                .checked_add(element.len())
                .and_then(|element_len| frame_len.checked_add(element_len))
                .ok_or_else(frame_too_large)?;
            let over_limit = grown_len > MAX_STREAM_REF_FRAME_BYTES;
            let taken = slice_end - slice_start;
            if taken > 0 && (over_limit || taken >= u16::MAX as usize) {
                // Close the current frame; this payload opens the next one.
                break;
            }
            if over_limit {
                return Err(frame_too_large());
            }
            frame_len = grown_len;
            slice_end += 1;
        }
        append_compact_sequenced_batch(&frames[slice_start..slice_end], frame_len, bytes)?;
        slice_start = slice_end;
    }
    Ok(())
}
1324
/// Test helper: appends `frames` to `bytes` as a single compact carrier frame.
///
/// The stream ref id and first sequence number are taken from `frames[0]`;
/// `payload_len` is the frame length written into the prefix. Panics if
/// `frames` is empty or contains a frame that is not `SequencedOnNext`.
#[cfg(test)]
fn append_compact_sequenced_batch(
    frames: &[StreamRefFrame],
    payload_len: usize,
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let head = &frames[0];
    let (first_seq, _) = sequenced_on_next(head).expect("sequenced frame");
    let Ok(frame_len) = u32::try_from(payload_len) else {
        return Err(StreamError::LimitExceeded {
            max: MAX_STREAM_REF_FRAME_BYTES as u64,
        });
    };
    let Ok(count) = u16::try_from(frames.len()) else {
        return Err(StreamError::LimitExceeded {
            max: u16::MAX as u64,
        });
    };

    bytes.extend_from_slice(&(COMPACT_FRAME_FLAG | frame_len).to_be_bytes());
    bytes.extend_from_slice(&[COMPACT_FRAME_VERSION, COMPACT_SEQUENCED_ON_NEXT_BATCH]);
    bytes.extend(head.stream_ref_id.to_bytes());
    bytes.extend_from_slice(&first_seq.to_be_bytes());
    bytes.extend_from_slice(&count.to_be_bytes());
    for frame in frames {
        let (_, element) = sequenced_on_next(frame).expect("sequenced frame");
        let Ok(element_len) = u32::try_from(element.len()) else {
            return Err(StreamError::LimitExceeded {
                max: u32::MAX as u64,
            });
        };
        bytes.extend_from_slice(&element_len.to_be_bytes());
        bytes.extend_from_slice(element);
    }
    Ok(())
}
1354
/// Test helper: returns the exclusive end index of the run of contiguous
/// `SequencedOnNext` frames that begins at `start`.
///
/// A run continues while each following frame is `SequencedOnNext`, shares the
/// stream ref id of `frames[start]`, and carries exactly the previous sequence
/// number plus one. The caller is expected to pass a `start` that indexes a
/// sequenced frame; the result is always `start + 1` or greater.
#[cfg(test)]
fn sequenced_run_end(frames: &[StreamRefFrame], start: usize) -> usize {
    let mut end = start + 1;
    while end < frames.len() {
        let Some((previous_seq, _)) = sequenced_on_next(&frames[end - 1]) else {
            break;
        };
        let Some((next_seq, _)) = sequenced_on_next(&frames[end]) else {
            break;
        };
        // `checked_add` rather than `saturating_add`: at `u64::MAX` there is no
        // successor, so a repeated `u64::MAX` must end the run instead of being
        // treated as contiguous.
        let contiguous = previous_seq.checked_add(1) == Some(next_seq);
        if frames[end].stream_ref_id != frames[start].stream_ref_id || !contiguous {
            break;
        }
        end += 1;
    }
    end
}
1374
/// Test helper: returns the sequence number and payload bytes of a
/// `SequencedOnNext` frame, or `None` for any other message.
#[cfg(test)]
fn sequenced_on_next(frame: &StreamRefFrame) -> Option<(u64, &[u8])> {
    let StreamRefMessage::SequencedOnNext { seq_nr, payload } = &frame.message else {
        return None;
    };
    Some((*seq_nr, payload.bytes.as_slice()))
}
1384
1385#[derive(Default)]
1386struct FrameDecoder {
1387    buffer: BytesMut,
1388    offset: usize,
1389}
1390
1391impl FrameDecoder {
1392    fn push_chunk<E>(&mut self, chunk: &[u8], endpoint: &E) -> StreamResult<()>
1393    where
1394        E: StreamRefProtoEndpoint,
1395    {
1396        self.buffer.extend_from_slice(chunk);
1397        while let Some(header) = self.peek_header()? {
1398            if self.buffer.len().saturating_sub(self.offset) < FRAME_LEN_BYTES + header.len {
1399                break;
1400            }
1401            let payload_start = self.offset + FRAME_LEN_BYTES;
1402            let payload_end = payload_start + header.len;
1403            let payload = &self.buffer[payload_start..payload_end];
1404            match header.kind {
1405                CarrierFrameKind::Protobuf => {
1406                    endpoint.handle_frame(StreamRefFrame::decode(payload)?)?;
1407                }
1408                CarrierFrameKind::Compact => {
1409                    decode_compact_carrier_frame(payload, endpoint)?;
1410                }
1411            }
1412            self.offset = payload_end;
1413        }
1414        if self.offset > 0 && (self.offset == self.buffer.len() || self.offset >= 64 * 1024) {
1415            self.buffer.advance(self.offset);
1416            self.offset = 0;
1417        }
1418        Ok(())
1419    }
1420
1421    fn peek_header(&self) -> StreamResult<Option<CarrierFrameHeader>> {
1422        if self.buffer.len().saturating_sub(self.offset) < FRAME_LEN_BYTES {
1423            return Ok(None);
1424        }
1425        let len = self.buffer[self.offset..self.offset + FRAME_LEN_BYTES]
1426            .try_into()
1427            .expect("frame header length");
1428        let raw_len = u32::from_be_bytes(len);
1429        let kind = if raw_len & COMPACT_FRAME_FLAG == 0 {
1430            CarrierFrameKind::Protobuf
1431        } else {
1432            CarrierFrameKind::Compact
1433        };
1434        let len = (raw_len & COMPACT_FRAME_LEN_MASK) as usize;
1435        if len > MAX_STREAM_REF_FRAME_BYTES {
1436            return Err(StreamError::LimitExceeded {
1437                max: MAX_STREAM_REF_FRAME_BYTES as u64,
1438            });
1439        }
1440        Ok(Some(CarrierFrameHeader { kind, len }))
1441    }
1442}
1443
/// Parsed length prefix of one carrier frame.
#[derive(Clone, Copy)]
struct CarrierFrameHeader {
    /// Encoding of the frame body, selected by the prefix's high bit.
    kind: CarrierFrameKind,
    /// Body length in bytes, with the flag bit masked off.
    len: usize,
}
1449
/// Encoding used for the body of a carrier frame.
#[derive(Clone, Copy)]
enum CarrierFrameKind {
    /// Body is a protobuf-encoded `StreamRefFrame` (flag bit clear).
    Protobuf,
    /// Body is a compact batch (flag bit set).
    Compact,
}
1455
1456fn decode_compact_carrier_frame<E>(payload: &[u8], endpoint: &E) -> StreamResult<()>
1457where
1458    E: StreamRefProtoEndpoint,
1459{
1460    if payload.len() < COMPACT_BATCH_HEADER_BYTES {
1461        return Err(StreamError::Failed(
1462            "compact StreamRefs carrier frame too short".to_owned(),
1463        ));
1464    }
1465    let version = payload[0];
1466    if version != COMPACT_FRAME_VERSION {
1467        return Err(StreamError::Failed(format!(
1468            "unsupported compact StreamRefs carrier frame version: {version}"
1469        )));
1470    }
1471    let kind = payload[1];
1472    if kind != COMPACT_SEQUENCED_ON_NEXT_BATCH {
1473        return Err(StreamError::Failed(format!(
1474            "unsupported compact StreamRefs carrier frame kind: {kind}"
1475        )));
1476    }
1477    let stream_ref_id = StreamRefId::from_bytes(&payload[2..18])?;
1478    let first_seq = u64::from_be_bytes(payload[18..26].try_into().expect("seq len"));
1479    let count = u16::from_be_bytes(payload[26..28].try_into().expect("count len")) as usize;
1480    if count == 0 {
1481        return Err(StreamError::Failed(
1482            "compact StreamRefs carrier batch is empty".to_owned(),
1483        ));
1484    }
1485
1486    let mut offset = COMPACT_BATCH_HEADER_BYTES;
1487    let mut payloads = Vec::with_capacity(count);
1488    for index in 0..count {
1489        if payload.len().saturating_sub(offset) < COMPACT_BATCH_ELEMENT_LEN_BYTES {
1490            return Err(StreamError::Failed(
1491                "compact StreamRefs carrier batch has truncated payload length".to_owned(),
1492            ));
1493        }
1494        let payload_len = u32::from_be_bytes(
1495            payload[offset..offset + COMPACT_BATCH_ELEMENT_LEN_BYTES]
1496                .try_into()
1497                .expect("payload len"),
1498        ) as usize;
1499        offset += COMPACT_BATCH_ELEMENT_LEN_BYTES;
1500        if payload.len().saturating_sub(offset) < payload_len {
1501            return Err(StreamError::Failed(
1502                "compact StreamRefs carrier batch has truncated payload".to_owned(),
1503            ));
1504        }
1505        first_seq
1506            .checked_add(index as u64)
1507            .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
1508        payloads.push(&payload[offset..offset + payload_len]);
1509        offset += payload_len;
1510    }
1511    if offset != payload.len() {
1512        return Err(StreamError::Failed(
1513            "compact StreamRefs carrier batch has trailing bytes".to_owned(),
1514        ));
1515    }
1516    endpoint.handle_sequenced_on_next_batch(stream_ref_id, first_seq, &payloads)
1517}
1518
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Endpoint double that records every frame passed to `handle_frame`.
    #[derive(Clone)]
    struct RecordingEndpoint {
        stream_ref_id: StreamRefId,
        frames: Arc<Mutex<Vec<StreamRefFrame>>>,
    }

    impl RecordingEndpoint {
        fn new(stream_ref_id: StreamRefId) -> Self {
            let frames = Arc::new(Mutex::new(Vec::new()));
            Self {
                stream_ref_id,
                frames,
            }
        }

        /// Snapshot of the frames recorded so far.
        fn frames(&self) -> Vec<StreamRefFrame> {
            let recorded = self.frames.lock().expect("recording endpoint");
            recorded.clone()
        }
    }

    impl StreamRefProtoEndpoint for RecordingEndpoint {
        fn stream_ref_id(&self) -> StreamRefId {
            self.stream_ref_id
        }

        fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
            None
        }

        fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
            let mut recorded = self.frames.lock().expect("recording endpoint");
            recorded.push(frame);
            Ok(())
        }

        fn fail_connection(&self, _error: StreamError) {}
    }

    /// Builds a `SequencedOnNext` frame carrying `bytes` as its payload.
    fn sequenced_frame(stream_ref_id: StreamRefId, seq_nr: u64, bytes: Vec<u8>) -> StreamRefFrame {
        StreamRefFrame::new(
            stream_ref_id,
            datum::StreamRefMessage::SequencedOnNext {
                seq_nr,
                payload: datum::StreamRefPayloadBytes { bytes },
            },
        )
    }

    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let frame = StreamRefFrame::new(
            StreamRefId::from_u128(1),
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        );
        let bytes = encode_carrier_frames(std::slice::from_ref(&frame)).unwrap();
        let (first_half, second_half) = bytes.split_at(bytes.len() / 2);
        let endpoint = RecordingEndpoint::new(StreamRefId::from_u128(1));
        let mut decoder = FrameDecoder::default();

        // Nothing may be emitted until the whole frame has arrived.
        decoder.push_chunk(first_half, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());
        decoder.push_chunk(second_half, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), vec![frame]);
    }

    #[test]
    fn compact_carrier_batch_round_trips_sequenced_frames() {
        let frames: Vec<_> = (0_u64..3)
            .map(|seq_nr| {
                sequenced_frame(
                    StreamRefId::from_u128(7),
                    seq_nr,
                    seq_nr.to_be_bytes().to_vec(),
                )
            })
            .collect();
        let bytes = encode_carrier_frames(&frames).unwrap();
        let header = u32::from_be_bytes(bytes[..FRAME_LEN_BYTES].try_into().unwrap());
        assert_ne!(header & COMPACT_FRAME_FLAG, 0);

        let endpoint = RecordingEndpoint::new(StreamRefId::from_u128(7));
        let mut decoder = FrameDecoder::default();
        decoder.push_chunk(&bytes, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }

    #[test]
    fn compact_carrier_batch_reassembles_split_frames() {
        let frames: Vec<_> = (4_u64..8)
            .map(|seq_nr| sequenced_frame(StreamRefId::from_u128(8), seq_nr, vec![seq_nr as u8]))
            .collect();
        let bytes = encode_carrier_frames(&frames).unwrap();
        // Split inside the compact batch header, just past the length prefix.
        let (head, tail) = bytes.split_at(FRAME_LEN_BYTES + 5);
        let endpoint = RecordingEndpoint::new(StreamRefId::from_u128(8));
        let mut decoder = FrameDecoder::default();

        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());
        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }
}