use std::{
collections::VecDeque,
future::Future,
net::SocketAddr,
sync::{Arc, Mutex, OnceLock, mpsc},
time::Duration,
};
use bytes::{Buf, BytesMut};
use datum::{
NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
StreamRefMessage, StreamRefOutbound, StreamRefPayload, StreamRefPayloadBatch,
StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer, StreamRefSettings,
StreamResult,
actor::stream_ref_proto::{StreamRefOutboundPoll, StreamRefProtoEndpointWake},
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::{TcpListener, TcpStream, ToSocketAddrs},
runtime::{Handle, Runtime},
sync::mpsc as tokio_mpsc,
};
use crate::QuicBidirectionalStream;
/// Width of the big-endian length prefix in front of each carrier frame (the
/// encoders below write a `u32`); presumably also read by the frame decoder.
const FRAME_LEN_BYTES: usize = 4;
/// Upper bound on one encoded carrier frame. The endpoint tasks also stop pulling
/// outbound items once this many bytes are buffered for writing.
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 * 1024 * 1024;
/// Read chunk size and initial buffer capacity used by the TCP carrier.
const STREAM_REF_TCP_CHUNK_SIZE: usize = 8192;
/// Upper bound on the QUIC carrier's per-`read` buffer length.
const STREAM_REF_QUIC_READ_BUFFER_BYTES: usize = 2048;
/// Frame budget passed to `try_next_outbound` on every outbound poll.
const STREAM_REF_OUTBOUND_BATCH_FRAMES: usize = 64;
/// Delay after which the QUIC carrier re-polls an endpoint that answered `Pending`.
const STREAM_REF_OUTBOUND_RECHECK_INTERVAL: Duration = Duration::from_millis(5);
/// High bit of the length prefix: set for compact batch frames, clear for protobuf frames.
const COMPACT_FRAME_FLAG: u32 = 0x8000_0000;
/// Largest length a prefix can carry while keeping `COMPACT_FRAME_FLAG` clear.
const COMPACT_FRAME_LEN_MASK: u32 = 0x7fff_ffff;
/// Version byte written first in every compact frame payload.
const COMPACT_FRAME_VERSION: u8 = 1;
/// Kind byte identifying a compact `SequencedOnNext` batch.
const COMPACT_SEQUENCED_ON_NEXT_BATCH: u8 = 1;
/// Compact batch header: version (1) + kind (1) + stream ref id (16) + first seq_nr (8)
/// + element count (2).
const COMPACT_BATCH_HEADER_BYTES: usize = 1 + 1 + 16 + 8 + 2;
/// Big-endian `u32` length prefix in front of every element of a compact batch.
const COMPACT_BATCH_ELEMENT_LEN_BYTES: usize = 4;
/// How a carrier slices the bytes it reads before handing them to the frame decoder.
#[derive(Clone, Copy)]
struct CarrierReadMode {
    /// Fixed chunk length used when `emit_available` is false; always non-zero.
    chunk_size: usize,
    /// Hand each read to the decoder as-is instead of re-chunking it to `chunk_size`.
    emit_available: bool,
    /// Fail the endpoint with `StreamError::AbruptTermination` when the read side hits EOF.
    fail_on_eof: bool,
}

impl CarrierReadMode {
    /// Builds a read mode from its three settings.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    fn new(chunk_size: usize, emit_available: bool, fail_on_eof: bool) -> Self {
        let mode = Self {
            chunk_size,
            emit_available,
            fail_on_eof,
        };
        assert!(mode.chunk_size > 0, "chunk size must be greater than zero");
        mode
    }
}
/// Running totals of StreamRefs protocol messages written to a carrier.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefProtocolMessageCounts {
    /// Number of `CumulativeDemand` frames.
    pub cumulative_demand: u64,
    /// Number of `SequencedOnNext` elements; every element of a compact batch counts once.
    pub sequenced_on_next: u64,
    /// Number of `Ack` frames.
    pub ack: u64,
}

/// Shared collector for [`StreamRefProtocolMessageCounts`].
///
/// Clones share the same counters, so a caller can keep one clone for reading
/// while a carrier task records into another.
#[derive(Clone, Default)]
pub struct StreamRefProtocolDiagnostics {
    counts: Arc<Mutex<StreamRefProtocolMessageCounts>>,
}

impl StreamRefProtocolDiagnostics {
    /// Creates a collector whose counters all start at zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of the current totals.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn snapshot(&self) -> StreamRefProtocolMessageCounts {
        let guard = self
            .counts
            .lock()
            .expect("stream ref protocol diagnostics poisoned");
        *guard
    }

    /// Adds `delta` to the shared totals, saturating at `u64::MAX`.
    ///
    /// An all-zero `delta` returns early without taking the lock.
    fn record_counts(&self, delta: StreamRefProtocolMessageCounts) {
        let StreamRefProtocolMessageCounts {
            cumulative_demand,
            sequenced_on_next,
            ack,
        } = delta;
        if delta != StreamRefProtocolMessageCounts::default() {
            let mut totals = self
                .counts
                .lock()
                .expect("stream ref protocol diagnostics poisoned");
            totals.cumulative_demand = totals.cumulative_demand.saturating_add(cumulative_demand);
            totals.sequenced_on_next = totals.sequenced_on_next.saturating_add(sequenced_on_next);
            totals.ack = totals.ack.saturating_add(ack);
        }
    }
}
fn outbound_counts(outbound: &StreamRefOutbound) -> StreamRefProtocolMessageCounts {
let mut counts = StreamRefProtocolMessageCounts::default();
match outbound {
StreamRefOutbound::Frame(frame) => match &frame.message {
StreamRefMessage::CumulativeDemand { .. } => {
counts.cumulative_demand = 1;
}
StreamRefMessage::SequencedOnNext { .. } => {
counts.sequenced_on_next = 1;
}
StreamRefMessage::Ack => {
counts.ack = 1;
}
StreamRefMessage::OnSubscribeHandshake
| StreamRefMessage::RemoteStreamCompleted { .. }
| StreamRefMessage::RemoteStreamFailure { .. } => {}
},
StreamRefOutbound::SequencedBatch(batch) => {
counts.sequenced_on_next = batch.count() as u64;
}
}
counts
}
/// Message counts for one encoded outbound item.
///
/// The counts are held back until every byte of the item has been accepted by
/// the carrier's writer.
#[derive(Clone, Copy)]
struct PendingDiagnostic {
    /// Counts to record once `remaining` reaches zero.
    counts: StreamRefProtocolMessageCounts,
    /// Encoded bytes of this item that have not been written yet.
    remaining: usize,
}
/// Completion handle for a StreamRefs endpoint driven over a QUIC bidirectional stream.
///
/// The carrier task sends its final result once, over a `std::sync::mpsc` channel.
#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
pub struct StreamRefQuicHandle {
    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
}

impl StreamRefQuicHandle {
    /// Blocks the current thread until the carrier task finishes, and returns its result.
    ///
    /// Returns `Err(StreamError::AbruptTermination)` in two cases: the task went away
    /// without sending a result (e.g. it panicked or its runtime was shut down), or the
    /// result was already taken by [`Self::try_wait`]. This is a blocking call and should
    /// not be made on an async executor thread.
    pub fn wait(self) -> StreamResult<NotUsed> {
        self.receiver
            .recv()
            .unwrap_or(Err(StreamError::AbruptTermination))
    }

    /// Returns the carrier task's result without blocking, if it is available.
    ///
    /// Returns `None` while the task is still running. When the task is gone and no
    /// result is left to receive, returns `Some(Err(StreamError::AbruptTermination))`,
    /// matching [`Self::wait`]. A task that died without reporting therefore terminates
    /// polling loops instead of leaving them spinning on `None` forever.
    #[must_use]
    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
        match self.receiver.try_recv() {
            Ok(result) => Some(result),
            Err(mpsc::TryRecvError::Empty) => None,
            // The sender is dropped without sending if the spawned task panics or is
            // cancelled; report that the same way `wait` does.
            Err(mpsc::TryRecvError::Disconnected) => Some(Err(StreamError::AbruptTermination)),
        }
    }
}
/// Address information for a TCP StreamRefs listener.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefTcpBinding {
    local_addr: SocketAddr,
}

impl StreamRefTcpBinding {
    /// Returns the address the listener is bound to.
    ///
    /// This is the address reported by the bound socket, so it includes the
    /// OS-assigned port when the caller asked for port 0.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
/// Completion handle for a StreamRefs endpoint driven over a TCP connection.
///
/// The carrier task sends its final result once, over a `std::sync::mpsc` channel.
#[must_use = "wait for the TCP StreamRefs carrier to observe completion or failure"]
pub struct StreamRefTcpHandle {
    receiver: mpsc::Receiver<StreamResult<NotUsed>>,
}

impl StreamRefTcpHandle {
    /// Blocks the current thread until the carrier task finishes, and returns its result.
    ///
    /// Returns `Err(StreamError::AbruptTermination)` in two cases: the task went away
    /// without sending a result (e.g. it panicked or its runtime was shut down), or the
    /// result was already taken by [`Self::try_wait`]. This is a blocking call and should
    /// not be made on an async executor thread.
    pub fn wait(self) -> StreamResult<NotUsed> {
        self.receiver
            .recv()
            .unwrap_or(Err(StreamError::AbruptTermination))
    }

    /// Returns the carrier task's result without blocking, if it is available.
    ///
    /// Returns `None` while the task is still running. When the task is gone and no
    /// result is left to receive, returns `Some(Err(StreamError::AbruptTermination))`,
    /// matching [`Self::wait`]. A task that died without reporting therefore terminates
    /// polling loops instead of leaving them spinning on `None` forever.
    #[must_use]
    pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
        match self.receiver.try_recv() {
            Ok(result) => Some(result),
            Err(mpsc::TryRecvError::Empty) => None,
            // The sender is dropped without sending if the spawned task panics or is
            // cancelled; report that the same way `wait` does.
            Err(mpsc::TryRecvError::Disconnected) => Some(Err(StreamError::AbruptTermination)),
        }
    }
}
/// Serves `source_ref` as the producing side of a StreamRef over an existing QUIC
/// bidirectional stream.
///
/// The carrier task is spawned on the runtime handle obtained from `stream`. Use the
/// returned [`StreamRefQuicHandle`] to observe its outcome.
///
/// # Errors
/// Propagates any error from building the protocol producer for `source_ref`.
pub fn serve_source_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    source_ref: SourceRef<T>,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> StreamResult<StreamRefQuicHandle>
where
    T: StreamRefPayload,
{
    let endpoint = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
    let carrier = drive_stream_ref_endpoint_over_quic(stream, endpoint, None);
    Ok(carrier)
}
/// Serves a plain `source` as the producing side of a StreamRef over an existing QUIC
/// bidirectional stream.
///
/// The carrier task is spawned on the runtime handle obtained from `stream`. Use the
/// returned [`StreamRefQuicHandle`] to observe its outcome.
///
/// # Errors
/// Propagates any error from building the protocol producer for `source`.
pub fn serve_source_over_quic<T, Mat>(
    stream: QuicBidirectionalStream,
    source: Source<T, Mat>,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> StreamResult<StreamRefQuicHandle>
where
    T: StreamRefPayload,
    Mat: Send + 'static,
{
    let endpoint = StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?;
    let carrier = drive_stream_ref_endpoint_over_quic(stream, endpoint, None);
    Ok(carrier)
}
/// Opens the consuming side of a StreamRef over an existing QUIC bidirectional stream.
///
/// Returns the protocol consumer's [`Source`], which is fed by frames arriving on
/// `stream`, together with a handle for the carrier task.
pub fn source_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> (Source<T, NotUsed>, StreamRefQuicHandle)
where
    T: StreamRefPayload,
{
    let endpoint = StreamRefProtoConsumer::new(stream_ref_id, settings);
    let elements = endpoint.source();
    (
        elements,
        drive_stream_ref_endpoint_over_quic(stream, endpoint, None),
    )
}
/// Serves the receiving side of a SinkRef over an existing QUIC bidirectional stream.
///
/// Returns the protocol consumer's [`Source`], which yields what the remote sink side
/// sends, together with a handle for the carrier task.
pub fn serve_sink_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> (Source<T, NotUsed>, StreamRefQuicHandle)
where
    T: StreamRefPayload,
{
    let endpoint = StreamRefProtoConsumer::new(stream_ref_id, settings);
    let received = endpoint.source();
    (
        received,
        drive_stream_ref_endpoint_over_quic(stream, endpoint, None),
    )
}
/// Opens the sending side of a SinkRef over an existing QUIC bidirectional stream.
///
/// Returns the lazy protocol producer's [`Sink`]; elements pushed into it are sent
/// over `stream`. Also returns a handle for the carrier task.
pub fn sink_ref_over_quic<T>(
    stream: QuicBidirectionalStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
where
    T: StreamRefPayload,
{
    let endpoint = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
    let input = endpoint.sink();
    (
        input,
        drive_stream_ref_endpoint_over_quic(stream, endpoint, None),
    )
}
pub fn serve_source_ref_over_tcp<T, A>(
addr: A,
source_ref: SourceRef<T>,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
where
T: StreamRefPayload,
A: ToSocketAddrs + Send + 'static,
{
serve_source_ref_over_tcp_with_diagnostics(addr, source_ref, stream_ref_id, settings, None)
}
/// Binds a TCP listener on `addr` and serves `source_ref` to the first peer that
/// connects.
///
/// Binding happens synchronously on the module's dedicated StreamRefs TCP runtime.
/// The carrier task also runs on that runtime. When `diagnostics` is given, protocol
/// messages are counted as their bytes are written to the socket.
///
/// # Errors
/// Fails in three cases: the producer cannot be built, the TCP runtime cannot be
/// started, or the listener cannot be bound.
pub fn serve_source_ref_over_tcp_with_diagnostics<T, A>(
    addr: A,
    source_ref: SourceRef<T>,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
where
    T: StreamRefPayload,
    A: ToSocketAddrs + Send + 'static,
{
    let endpoint = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
    let (listener, binding, runtime) = bind_tcp_listener(addr)?;
    let carrier =
        drive_stream_ref_endpoint_over_tcp_listener(listener, runtime, endpoint, diagnostics);
    Ok((binding, carrier))
}
pub fn serve_source_ref_over_tcp_stream<T>(
stream: TcpStream,
source_ref: SourceRef<T>,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<StreamRefTcpHandle>
where
T: StreamRefPayload,
{
serve_source_ref_over_tcp_stream_with_diagnostics(
stream,
source_ref,
stream_ref_id,
settings,
None,
)
}
/// Serves `source_ref` over an already-connected TCP stream.
///
/// The carrier task is spawned on the Tokio runtime current at the call site.
///
/// # Errors
/// Fails if the producer cannot be built or if no Tokio runtime is current.
pub fn serve_source_ref_over_tcp_stream_with_diagnostics<T>(
    stream: TcpStream,
    source_ref: SourceRef<T>,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<StreamRefTcpHandle>
where
    T: StreamRefPayload,
{
    let endpoint = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
    let runtime = current_tokio_handle()?;
    let carrier = drive_stream_ref_endpoint_over_tcp_stream(stream, runtime, endpoint, diagnostics);
    Ok(carrier)
}
pub fn source_ref_over_tcp<T, A>(
addr: A,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
where
T: StreamRefPayload,
A: ToSocketAddrs + Send + 'static,
{
source_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
}
/// Connects to a StreamRef producer at `addr` and returns the protocol consumer's
/// [`Source`] plus a carrier handle.
///
/// The connect happens synchronously on the module's dedicated StreamRefs TCP runtime.
/// The carrier task also runs on that runtime.
///
/// # Errors
/// Fails if the TCP runtime cannot be started or the connection cannot be established.
pub fn source_ref_over_tcp_with_diagnostics<T, A>(
    addr: A,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
where
    T: StreamRefPayload,
    A: ToSocketAddrs + Send + 'static,
{
    let endpoint = StreamRefProtoConsumer::new(stream_ref_id, settings);
    let elements = endpoint.source();
    let (stream, runtime) = connect_tcp_stream(addr)?;
    let carrier = drive_stream_ref_endpoint_over_tcp_stream(stream, runtime, endpoint, diagnostics);
    Ok((elements, carrier))
}
pub fn source_ref_over_tcp_stream<T>(
stream: TcpStream,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
where
T: StreamRefPayload,
{
source_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
}
/// Consumes a StreamRef over an already-connected TCP stream.
///
/// Returns the protocol consumer's [`Source`] plus a carrier handle. The carrier task
/// is spawned on the Tokio runtime current at the call site.
///
/// # Errors
/// Fails if no Tokio runtime is current.
pub fn source_ref_over_tcp_stream_with_diagnostics<T>(
    stream: TcpStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
where
    T: StreamRefPayload,
{
    let endpoint = StreamRefProtoConsumer::new(stream_ref_id, settings);
    let elements = endpoint.source();
    let runtime = current_tokio_handle()?;
    let carrier = drive_stream_ref_endpoint_over_tcp_stream(stream, runtime, endpoint, diagnostics);
    Ok((elements, carrier))
}
pub fn serve_sink_ref_over_tcp<T, A>(
addr: A,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
where
T: StreamRefPayload,
A: ToSocketAddrs + Send + 'static,
{
serve_sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
}
/// Connects to a SinkRef peer listening on `addr` and receives the elements it sends.
///
/// The peer is typically one created by [`sink_ref_over_tcp`], which binds the
/// listener. The connect happens synchronously on the module's dedicated StreamRefs TCP
/// runtime, and the carrier task also runs there.
///
/// # Errors
/// Fails if the TCP runtime cannot be started or the connection cannot be established.
pub fn serve_sink_ref_over_tcp_with_diagnostics<T, A>(
    addr: A,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
where
    T: StreamRefPayload,
    A: ToSocketAddrs + Send + 'static,
{
    let endpoint = StreamRefProtoConsumer::new(stream_ref_id, settings);
    let received = endpoint.source();
    let (stream, runtime) = connect_tcp_stream(addr)?;
    let carrier = drive_stream_ref_endpoint_over_tcp_stream(stream, runtime, endpoint, diagnostics);
    Ok((received, carrier))
}
pub fn serve_sink_ref_over_tcp_stream<T>(
stream: TcpStream,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
where
T: StreamRefPayload,
{
serve_sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
}
/// Serves the receiving side of a SinkRef over an already-connected TCP stream.
///
/// The carrier task is spawned on the Tokio runtime current at the call site.
///
/// # Errors
/// Fails if no Tokio runtime is current.
pub fn serve_sink_ref_over_tcp_stream_with_diagnostics<T>(
    stream: TcpStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
where
    T: StreamRefPayload,
{
    let endpoint = StreamRefProtoConsumer::new(stream_ref_id, settings);
    let received = endpoint.source();
    let runtime = current_tokio_handle()?;
    let carrier = drive_stream_ref_endpoint_over_tcp_stream(stream, runtime, endpoint, diagnostics);
    Ok((received, carrier))
}
pub fn sink_ref_over_tcp<T, A>(
addr: A,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<(
Sink<T, StreamCompletion<NotUsed>>,
StreamRefTcpBinding,
StreamRefTcpHandle,
)>
where
T: StreamRefPayload,
A: ToSocketAddrs + Send + 'static,
{
sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
}
/// Binds a TCP listener on `addr` and returns a [`Sink`] whose elements are sent to
/// the first peer that connects.
///
/// Also returns the listener's binding and a carrier handle. Binding happens
/// synchronously on the module's dedicated StreamRefs TCP runtime, which also runs the
/// carrier task.
///
/// # Errors
/// Fails if the TCP runtime cannot be started or the listener cannot be bound.
pub fn sink_ref_over_tcp_with_diagnostics<T, A>(
    addr: A,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<(
    Sink<T, StreamCompletion<NotUsed>>,
    StreamRefTcpBinding,
    StreamRefTcpHandle,
)>
where
    T: StreamRefPayload,
    A: ToSocketAddrs + Send + 'static,
{
    let endpoint = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
    let input = endpoint.sink();
    let (listener, binding, runtime) = bind_tcp_listener(addr)?;
    let carrier =
        drive_stream_ref_endpoint_over_tcp_listener(listener, runtime, endpoint, diagnostics);
    Ok((input, binding, carrier))
}
pub fn sink_ref_over_tcp_stream<T>(
stream: TcpStream,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
where
T: StreamRefPayload,
{
sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
}
/// Opens the sending side of a SinkRef over an already-connected TCP stream.
///
/// Returns the lazy producer's [`Sink`] plus a carrier handle. The carrier task is
/// spawned on the Tokio runtime current at the call site.
///
/// # Errors
/// Fails if no Tokio runtime is current.
pub fn sink_ref_over_tcp_stream_with_diagnostics<T>(
    stream: TcpStream,
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
where
    T: StreamRefPayload,
{
    let endpoint = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
    let input = endpoint.sink();
    let runtime = current_tokio_handle()?;
    let carrier = drive_stream_ref_endpoint_over_tcp_stream(stream, runtime, endpoint, diagnostics);
    Ok((input, carrier))
}
/// Splits `stream` into its StreamRefs parts and spawns the QUIC carrier task for
/// `endpoint`.
///
/// The task runs on the runtime handle that comes out of the stream. EOF on the QUIC
/// read side does not fail the endpoint.
fn drive_stream_ref_endpoint_over_quic<E>(
    stream: QuicBidirectionalStream,
    endpoint: E,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamRefQuicHandle
where
    E: StreamRefProtoEndpointWake,
{
    let (reader, writer, runtime, chunk_size, emit_available) = stream.into_stream_ref_parts();
    let fail_on_eof = false;
    let read_mode = CarrierReadMode::new(chunk_size, emit_available, fail_on_eof);
    let task =
        run_stream_ref_endpoint_quic_task(reader, writer, endpoint, read_mode, diagnostics);
    let receiver = spawn_endpoint_task(&runtime, task);
    StreamRefQuicHandle { receiver }
}
fn drive_stream_ref_endpoint_over_tcp_listener<E>(
listener: TcpListener,
handle: Handle,
endpoint: E,
diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamRefTcpHandle
where
E: StreamRefProtoEndpointWake,
{
StreamRefTcpHandle {
receiver: spawn_endpoint_task(&handle, async move {
let (stream, _) = listener.accept().await.map_err(io_error)?;
run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
}),
}
}
/// Spawns the TCP carrier task for `endpoint` over an already-connected `stream` on
/// `handle`.
fn drive_stream_ref_endpoint_over_tcp_stream<E>(
    stream: TcpStream,
    handle: Handle,
    endpoint: E,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamRefTcpHandle
where
    E: StreamRefProtoEndpointWake,
{
    let task = run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics);
    let receiver = spawn_endpoint_task(&handle, task);
    StreamRefTcpHandle { receiver }
}
/// Runs the QUIC carrier loop for `endpoint` until it finishes or fails.
///
/// A wake channel is installed on the endpoint for the duration of the loop and
/// cleared afterwards. If the loop fails, the endpoint's connection is failed with the
/// same error before the error is returned.
async fn run_stream_ref_endpoint_quic_task<R, W, E>(
    reader: R,
    writer: W,
    endpoint: E,
    read_mode: CarrierReadMode,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<NotUsed>
where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
    E: StreamRefProtoEndpointWake,
{
    let (wake_tx, wake_rx) = tokio_mpsc::channel(1);
    endpoint.install_outbound_wake(wake_tx.clone());
    // Pre-load one wake so the first `select!` pass wakes immediately.
    let _ = wake_tx.try_send(());
    let read_len = read_mode
        .chunk_size
        .clamp(1, STREAM_REF_QUIC_READ_BUFFER_BYTES);
    let task = QuicEndpointTask {
        reader,
        writer,
        endpoint: endpoint.clone(),
        diagnostics,
        read_mode,
        decoder: FrameDecoder::default(),
        read_buffer: vec![0_u8; read_len],
        pending_tail: Vec::new(),
        write_buffer: BytesMut::new(),
        encode_buffer: Vec::new(),
        pending_diagnostics: VecDeque::new(),
        read_closed: false,
        inbound_seen: false,
        outbound_written: false,
        recheck_outbound: false,
        outbound_closed: false,
        write_shutdown: false,
        wake_receiver: wake_rx,
    };
    let outcome = task.run().await;
    endpoint.clear_outbound_wake();
    if let Err(failure) = &outcome {
        endpoint.fail_connection(failure.clone());
    }
    outcome
}
/// State for the task that drives one StreamRefs endpoint over a QUIC stream's read
/// and write halves.
struct QuicEndpointTask<R, W, E>
where
    E: StreamRefProtoEndpointWake,
{
    // Read half of the carrier stream.
    reader: R,
    // Write half of the carrier stream.
    writer: W,
    // Protocol endpoint: receives decoded inbound frames, yields outbound items.
    endpoint: E,
    // Optional message counters, updated as outbound bytes are written.
    diagnostics: Option<StreamRefProtocolDiagnostics>,
    // How read bytes are sliced before decoding.
    read_mode: CarrierReadMode,
    // Incremental decoder for inbound carrier frames.
    decoder: FrameDecoder,
    // Scratch buffer filled by each `read` call.
    read_buffer: Vec<u8>,
    // Bytes held back until a full `read_mode.chunk_size` chunk is available; only
    // filled when `read_mode.emit_available` is false.
    pending_tail: Vec<u8>,
    // Encoded outbound bytes not yet accepted by `writer`.
    write_buffer: BytesMut,
    // Scratch buffer reused to encode one outbound item at a time.
    encode_buffer: Vec<u8>,
    // Diagnostics for buffered outbound items, in write order.
    pending_diagnostics: VecDeque<PendingDiagnostic>,
    // The reader has returned EOF.
    read_closed: bool,
    // At least one inbound read has been fed to the endpoint.
    inbound_seen: bool,
    // At least one outbound item has been buffered for writing.
    outbound_written: bool,
    // Re-poll the endpoint after `STREAM_REF_OUTBOUND_RECHECK_INTERVAL`.
    recheck_outbound: bool,
    // The endpoint reported its outbound side as closed.
    outbound_closed: bool,
    // `writer.shutdown()` has completed.
    write_shutdown: bool,
    // Receives the notifications sent through the sender installed with
    // `install_outbound_wake`.
    wake_receiver: tokio_mpsc::Receiver<()>,
}
impl<R, W, E> QuicEndpointTask<R, W, E>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
    E: StreamRefProtoEndpointWake,
{
    /// Drives the endpoint until the read side is closed and the write side is shut down.
    ///
    /// Each iteration does the following:
    /// - Pulls outbound items into `write_buffer`.
    /// - Shuts the writer down once the outbound side is closed and fully written.
    /// - Waits, in priority order, for one of: a wake, the re-poll timer, a completed
    ///   write, or inbound bytes.
    async fn run(mut self) -> StreamResult<NotUsed> {
        loop {
            if self.read_closed && self.write_shutdown {
                return Ok(NotUsed);
            }
            self.drain_outbound()?;
            if self.outbound_closed && self.write_buffer.is_empty() && !self.write_shutdown {
                self.writer.shutdown().await.map_err(io_error)?;
                self.write_shutdown = true;
                continue;
            }
            // `biased` polls the branches top to bottom. The guards keep at least one
            // branch enabled whenever the early return above has not been taken.
            tokio::select! {
                biased;
                wake = self.wake_receiver.recv(), if !self.outbound_closed => {
                    // `Some(())` needs no work here: the loop top drains outbound anyway.
                    if wake.is_none() {
                        self.drain_outbound()?;
                    }
                }
                // Fallback poll for endpoints that answered `Pending` in a state where
                // `drain_outbound` asked for a recheck.
                _ = tokio::time::sleep(STREAM_REF_OUTBOUND_RECHECK_INTERVAL), if self.recheck_outbound && !self.outbound_closed && self.write_buffer.is_empty() => {
                    self.drain_outbound()?;
                }
                written = self.writer.write(&self.write_buffer), if !self.write_buffer.is_empty() => {
                    self.handle_written(written)?;
                }
                read = self.reader.read(&mut self.read_buffer), if !self.read_closed => {
                    match read {
                        Ok(0) => self.handle_eof()?,
                        Ok(read) => {
                            self.feed_read_buffer(read)?;
                            self.drain_outbound()?;
                        }
                        Err(error) => {
                            let error = io_error(error);
                            // Once our write side is shut down, a "connection lost" read
                            // error is treated as the peer tearing down: clean finish.
                            if self.write_shutdown && is_quic_teardown_loss(&error) {
                                return Ok(NotUsed);
                            }
                            return Err(error);
                        }
                    }
                }
            }
        }
    }

    /// Moves outbound items from the endpoint into `write_buffer`.
    ///
    /// Stops when the buffer holds at least `MAX_STREAM_REF_FRAME_BYTES`, the endpoint is
    /// `Pending`, or the endpoint's outbound side is `Closed`. Diagnostics for each item
    /// are queued so they can be recorded once its bytes are written.
    fn drain_outbound(&mut self) -> StreamResult<()> {
        self.recheck_outbound = false;
        while !self.outbound_closed && self.write_buffer.len() < MAX_STREAM_REF_FRAME_BYTES {
            match self
                .endpoint
                .try_next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
            {
                StreamRefOutboundPoll::Ready(Ok(outbound)) => {
                    encode_carrier_outbound_into(&outbound, &mut self.encode_buffer)?;
                    let encoded_len = self.encode_buffer.len();
                    if encoded_len == 0 {
                        continue;
                    }
                    if self.diagnostics.is_some() {
                        self.pending_diagnostics.push_back(PendingDiagnostic {
                            remaining: encoded_len,
                            counts: outbound_counts(&outbound),
                        });
                    }
                    self.outbound_written = true;
                    self.write_buffer.extend_from_slice(&self.encode_buffer);
                }
                StreamRefOutboundPoll::Ready(Err(error)) => return Err(error),
                StreamRefOutboundPoll::Pending => {
                    // Arm the timed re-poll in three cases: before anything has been
                    // written, after inbound traffic has been seen, or after EOF.
                    // NOTE(review): presumably the endpoint may not signal a wake in
                    // these states — confirm against the endpoint implementation.
                    self.recheck_outbound =
                        !self.outbound_written || self.inbound_seen || self.read_closed;
                    break;
                }
                StreamRefOutboundPoll::Closed => {
                    self.outbound_closed = true;
                    break;
                }
            }
        }
        Ok(())
    }

    /// Applies the result of one `writer.write` call.
    ///
    /// Consumes the written bytes from `write_buffer` and updates diagnostics. A
    /// zero-byte write is reported as a failure.
    fn handle_written(&mut self, written: Result<usize, std::io::Error>) -> StreamResult<()> {
        match written {
            Ok(0) => Err(StreamError::Failed(
                "StreamRefs QUIC stream accepted zero write bytes".to_owned(),
            )),
            Ok(written) => {
                self.write_buffer.advance(written);
                self.record_written_bytes(written);
                Ok(())
            }
            Err(error) => Err(io_error(error)),
        }
    }

    /// Credits `written` bytes against the pending diagnostics queue.
    ///
    /// Each item's counts are recorded once all of its bytes have been written.
    fn record_written_bytes(&mut self, mut written: usize) {
        let Some(diagnostics) = &self.diagnostics else {
            return;
        };
        while written > 0 {
            let Some(front) = self.pending_diagnostics.front_mut() else {
                return;
            };
            if written < front.remaining {
                front.remaining -= written;
                return;
            }
            written -= front.remaining;
            let counts = front.counts;
            self.pending_diagnostics.pop_front();
            diagnostics.record_counts(counts);
        }
    }

    /// Feeds the first `read` bytes of `read_buffer` to the decoder, following
    /// `read_mode`, and marks inbound traffic as seen.
    fn feed_read_buffer(&mut self, read: usize) -> StreamResult<()> {
        feed_read_bytes(
            &mut self.decoder,
            &self.endpoint,
            self.read_mode,
            &mut self.pending_tail,
            &self.read_buffer[..read],
        )?;
        self.inbound_seen = true;
        Ok(())
    }

    /// Handles EOF on the reader.
    ///
    /// Flushes any partial chunk to the decoder and, if `fail_on_eof` is set, fails the
    /// endpoint. Then marks the read side closed and requests an outbound recheck.
    fn handle_eof(&mut self) -> StreamResult<()> {
        if !self.pending_tail.is_empty() {
            feed_inbound_chunk(&mut self.decoder, &self.endpoint, &self.pending_tail)?;
            self.pending_tail.clear();
        }
        if self.read_mode.fail_on_eof {
            self.endpoint
                .fail_connection(StreamError::AbruptTermination);
        }
        self.read_closed = true;
        self.recheck_outbound = true;
        Ok(())
    }
}
/// Runs the TCP carrier loop for `endpoint` over `stream` until it finishes or fails.
///
/// TCP reads feed whatever bytes are available straight to the decoder, and EOF fails
/// the endpoint with `AbruptTermination`. A wake channel is installed on the endpoint
/// for the duration of the loop and cleared afterwards. A loop error also fails the
/// endpoint's connection before being returned.
async fn run_stream_ref_endpoint_tcp_task<E>(
    stream: TcpStream,
    endpoint: E,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<NotUsed>
where
    E: StreamRefProtoEndpointWake,
{
    let (wake_tx, wake_rx) = tokio_mpsc::channel(1);
    endpoint.install_outbound_wake(wake_tx.clone());
    // Pre-load one wake so the first `select!` pass wakes immediately.
    let _ = wake_tx.try_send(());
    let task = TcpEndpointTask {
        stream,
        endpoint: endpoint.clone(),
        diagnostics,
        read_mode: CarrierReadMode::new(STREAM_REF_TCP_CHUNK_SIZE, true, true),
        decoder: FrameDecoder::default(),
        read_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
        pending_tail: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
        write_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
        encode_buffer: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
        pending_diagnostics: VecDeque::new(),
        outbound_closed: false,
        write_shutdown: false,
        wake_receiver: wake_rx,
    };
    let outcome = task.run().await;
    endpoint.clear_outbound_wake();
    if let Err(failure) = &outcome {
        endpoint.fail_connection(failure.clone());
    }
    outcome
}
/// State for the task that drives one StreamRefs endpoint over a TCP socket.
struct TcpEndpointTask<E>
where
    E: StreamRefProtoEndpointWake,
{
    // Connected socket, used for both directions via readiness + `try_*` calls.
    stream: TcpStream,
    // Protocol endpoint: receives decoded inbound frames, yields outbound items.
    endpoint: E,
    // Optional message counters, updated as outbound bytes are written.
    diagnostics: Option<StreamRefProtocolDiagnostics>,
    // How read bytes are sliced before decoding.
    read_mode: CarrierReadMode,
    // Incremental decoder for inbound carrier frames.
    decoder: FrameDecoder,
    // Bytes from the latest `try_read_buf`, cleared after each decode pass.
    read_buffer: BytesMut,
    // Bytes held back until a full chunk is available; only filled when
    // `read_mode.emit_available` is false.
    pending_tail: Vec<u8>,
    // Encoded outbound bytes not yet accepted by the socket.
    write_buffer: BytesMut,
    // Scratch buffer reused to encode one outbound item at a time.
    encode_buffer: Vec<u8>,
    // Diagnostics for buffered outbound items, in write order.
    pending_diagnostics: VecDeque<PendingDiagnostic>,
    // The endpoint reported its outbound side as closed.
    outbound_closed: bool,
    // The socket's write half has been shut down.
    write_shutdown: bool,
    // Receives the notifications sent through the sender installed with
    // `install_outbound_wake`.
    wake_receiver: tokio_mpsc::Receiver<()>,
}
impl<E> TcpEndpointTask<E>
where
    E: StreamRefProtoEndpointWake,
{
    /// Drives the endpoint over the TCP socket until the read side reaches EOF.
    ///
    /// Each iteration does the following:
    /// - Pulls outbound items into `write_buffer`.
    /// - Writes as much of it as the socket accepts without waiting.
    /// - Shuts down the write half once the outbound side is closed and fully written.
    /// - Waits, in priority order, for a wake, inbound data, or write readiness.
    ///
    /// Waiting for writability happens only inside the `select!`, next to the read
    /// branch, so a stalled write never stops the task from reading. If each peer
    /// instead waited for writability before reading, two peers with full send buffers
    /// would each wait for the other to read and deadlock.
    ///
    /// NOTE(review): at EOF the task returns immediately, so bytes still in
    /// `write_buffer` are not flushed. Confirm the protocol never needs them.
    async fn run(mut self) -> StreamResult<NotUsed> {
        self.stream.set_nodelay(true).map_err(io_error)?;
        loop {
            self.drain_outbound()?;
            self.flush_write_buffer().await?;
            tokio::select! {
                biased;
                wake = self.wake_receiver.recv() => {
                    // `Some(())` needs no work here: the loop top drains outbound anyway.
                    if wake.is_none() && !self.outbound_closed {
                        self.drain_outbound()?;
                    }
                }
                ready = self.stream.readable() => {
                    ready.map_err(io_error)?;
                    if self.read_available()? {
                        return Ok(NotUsed);
                    }
                }
                // `flush_write_buffer` has already performed any pending shutdown, so
                // readiness only matters while bytes remain to be written.
                ready = self.stream.writable(), if !self.write_buffer.is_empty() => {
                    ready.map_err(io_error)?;
                    self.flush_ready_write_buffer()?;
                }
            }
        }
    }

    /// Moves outbound items from the endpoint into `write_buffer`.
    ///
    /// Stops when the buffer holds at least `MAX_STREAM_REF_FRAME_BYTES`, the endpoint is
    /// `Pending`, or the endpoint's outbound side is `Closed`. Diagnostics for each item
    /// are queued so they can be recorded once its bytes are written.
    fn drain_outbound(&mut self) -> StreamResult<()> {
        while !self.outbound_closed && self.write_buffer.len() < MAX_STREAM_REF_FRAME_BYTES {
            match self
                .endpoint
                .try_next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
            {
                StreamRefOutboundPoll::Ready(Ok(outbound)) => {
                    encode_carrier_outbound_into(&outbound, &mut self.encode_buffer)?;
                    let encoded_len = self.encode_buffer.len();
                    if encoded_len == 0 {
                        continue;
                    }
                    if self.diagnostics.is_some() {
                        self.pending_diagnostics.push_back(PendingDiagnostic {
                            remaining: encoded_len,
                            counts: outbound_counts(&outbound),
                        });
                    }
                    self.write_buffer.extend_from_slice(&self.encode_buffer);
                }
                StreamRefOutboundPoll::Ready(Err(error)) => return Err(error),
                StreamRefOutboundPoll::Pending => break,
                StreamRefOutboundPoll::Closed => {
                    self.outbound_closed = true;
                    break;
                }
            }
        }
        Ok(())
    }

    /// Writes as much of `write_buffer` as the socket accepts right now.
    ///
    /// Once the outbound side is closed and nothing is left to send, also shuts down
    /// the write half. This never waits for write readiness; `run` does that inside its
    /// `select!`, so reads keep flowing while a write is stalled.
    async fn flush_write_buffer(&mut self) -> StreamResult<()> {
        self.flush_ready_write_buffer()?;
        if self.outbound_closed && self.write_buffer.is_empty() && !self.write_shutdown {
            self.stream.shutdown().await.map_err(io_error)?;
            self.write_shutdown = true;
        }
        Ok(())
    }

    /// Writes buffered bytes with `try_write` until the buffer is empty or the socket
    /// reports `WouldBlock`.
    ///
    /// A zero-byte write is reported as a failure.
    fn flush_ready_write_buffer(&mut self) -> StreamResult<()> {
        while !self.write_buffer.is_empty() {
            match self.stream.try_write(&self.write_buffer) {
                Ok(0) => {
                    return Err(StreamError::Failed(
                        "StreamRefs TCP socket accepted zero write bytes".to_owned(),
                    ));
                }
                Ok(written) => {
                    self.write_buffer.advance(written);
                    self.record_written_bytes(written);
                }
                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
                Err(error) => return Err(io_error(error)),
            }
        }
        Ok(())
    }

    /// Credits `written` bytes against the pending diagnostics queue.
    ///
    /// Each item's counts are recorded once all of its bytes have been written.
    fn record_written_bytes(&mut self, mut written: usize) {
        let Some(diagnostics) = &self.diagnostics else {
            return;
        };
        while written > 0 {
            let Some(front) = self.pending_diagnostics.front_mut() else {
                return;
            };
            if written < front.remaining {
                front.remaining -= written;
                return;
            }
            written -= front.remaining;
            let counts = front.counts;
            self.pending_diagnostics.pop_front();
            diagnostics.record_counts(counts);
        }
    }

    /// Reads until the socket reports `WouldBlock` or EOF.
    ///
    /// Decodes each read immediately and drains any outbound work it triggers. Returns
    /// `Ok(true)` on EOF and `Ok(false)` when no more data is ready.
    fn read_available(&mut self) -> StreamResult<bool> {
        loop {
            self.read_buffer.reserve(self.read_mode.chunk_size);
            match self.stream.try_read_buf(&mut self.read_buffer) {
                Ok(0) => return self.handle_eof(),
                Ok(_) => {
                    feed_read_bytes(
                        &mut self.decoder,
                        &self.endpoint,
                        self.read_mode,
                        &mut self.pending_tail,
                        &self.read_buffer,
                    )?;
                    self.read_buffer.clear();
                    self.drain_outbound()?;
                }
                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(false),
                Err(error) => return Err(io_error(error)),
            }
        }
    }

    /// Handles EOF on the socket.
    ///
    /// Flushes any partial chunk to the decoder and, if `fail_on_eof` is set, fails the
    /// endpoint with `AbruptTermination`. Always returns `Ok(true)`.
    fn handle_eof(&mut self) -> StreamResult<bool> {
        if !self.pending_tail.is_empty() {
            feed_inbound_chunk(&mut self.decoder, &self.endpoint, &self.pending_tail)?;
            self.pending_tail.clear();
        }
        if self.read_mode.fail_on_eof {
            self.endpoint
                .fail_connection(StreamError::AbruptTermination);
        }
        Ok(true)
    }
}
/// Passes freshly read bytes to the decoder according to `read_mode`.
///
/// With `emit_available`, the whole read is decoded at once, prefixed by any leftover
/// `pending_tail`. Otherwise the bytes are regrouped into exact `chunk_size` pieces:
/// - `pending_tail` is topped up first.
/// - Full chunks are decoded in order.
/// - The short remainder is kept in `pending_tail` for the next call.
fn feed_read_bytes<E>(
    decoder: &mut FrameDecoder,
    endpoint: &E,
    read_mode: CarrierReadMode,
    pending_tail: &mut Vec<u8>,
    read_buffer: &[u8],
) -> StreamResult<()>
where
    E: StreamRefProtoEndpoint,
{
    if read_mode.emit_available {
        if pending_tail.is_empty() {
            return feed_inbound_chunk(decoder, endpoint, read_buffer);
        }
        pending_tail.extend_from_slice(read_buffer);
        feed_inbound_chunk(decoder, endpoint, pending_tail)?;
        pending_tail.clear();
        return Ok(());
    }

    let chunk_size = read_mode.chunk_size;
    let mut rest = read_buffer;
    if !pending_tail.is_empty() {
        // Complete the partial chunk left over from the previous read, if possible.
        let take = (chunk_size - pending_tail.len()).min(rest.len());
        let (head, tail) = rest.split_at(take);
        pending_tail.extend_from_slice(head);
        rest = tail;
        if pending_tail.len() == chunk_size {
            feed_inbound_chunk(decoder, endpoint, pending_tail)?;
            pending_tail.clear();
        }
    }

    let mut chunks = rest.chunks_exact(chunk_size);
    for chunk in chunks.by_ref() {
        feed_inbound_chunk(decoder, endpoint, chunk)?;
    }
    pending_tail.extend_from_slice(chunks.remainder());
    Ok(())
}
fn feed_inbound_chunk<E>(decoder: &mut FrameDecoder, endpoint: &E, chunk: &[u8]) -> StreamResult<()>
where
E: StreamRefProtoEndpoint,
{
decoder.push_chunk(chunk, endpoint)
}
fn bind_tcp_listener<A>(addr: A) -> StreamResult<(TcpListener, StreamRefTcpBinding, Handle)>
where
A: ToSocketAddrs + Send + 'static,
{
let runtime = stream_ref_tcp_runtime()?;
let listener = runtime
.block_on(async { TcpListener::bind(addr).await })
.map_err(io_error)?;
let local_addr = listener.local_addr().map_err(io_error)?;
Ok((
listener,
StreamRefTcpBinding { local_addr },
runtime.handle().clone(),
))
}
fn connect_tcp_stream<A>(addr: A) -> StreamResult<(TcpStream, Handle)>
where
A: ToSocketAddrs + Send + 'static,
{
let runtime = stream_ref_tcp_runtime()?;
let stream = runtime
.block_on(async { TcpStream::connect(addr).await })
.map_err(io_error)?;
stream.set_nodelay(true).map_err(io_error)?;
Ok((stream, runtime.handle().clone()))
}
fn stream_ref_tcp_runtime() -> StreamResult<&'static Runtime> {
static RUNTIME: OnceLock<Result<Runtime, String>> = OnceLock::new();
match RUNTIME.get_or_init(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("datum-streamref-tcp")
.enable_all()
.build()
.map_err(|error| error.to_string())
}) {
Ok(runtime) => Ok(runtime),
Err(error) => Err(StreamError::Failed(format!(
"failed to start StreamRefs TCP runtime: {error}"
))),
}
}
fn current_tokio_handle() -> StreamResult<Handle> {
Handle::try_current().map_err(|error| {
StreamError::Failed(format!(
"StreamRefs TCP stream helper requires a current Tokio runtime: {error}"
))
})
}
/// Converts an I/O error into `StreamError::Failed`, carrying its display text.
fn io_error(error: std::io::Error) -> StreamError {
    let message = format!("{error}");
    StreamError::Failed(message)
}
fn is_quic_teardown_loss(error: &StreamError) -> bool {
matches!(error, StreamError::Failed(message) if message == "connection lost")
}
/// Spawns `run` on `handle` and returns a blocking receiver for its result.
///
/// The result is sent once `run` completes. If the task panics or is dropped first,
/// the sender is dropped without sending, and the receiver sees a disconnection.
fn spawn_endpoint_task<F>(handle: &Handle, run: F) -> mpsc::Receiver<StreamResult<NotUsed>>
where
    F: Future<Output = StreamResult<NotUsed>> + Send + 'static,
{
    let (result_tx, result_rx) = mpsc::channel();
    handle.spawn(async move {
        // The handle may already be gone; nobody is waiting in that case.
        let _ = result_tx.send(run.await);
    });
    result_rx
}
/// Replaces the contents of `bytes` with the carrier encoding of one outbound item.
///
/// Single frames are encoded as length-prefixed protobuf. Payload batches are encoded
/// as one or more compact frames.
fn encode_carrier_outbound_into(
    outbound: &StreamRefOutbound,
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    bytes.clear();
    match outbound {
        StreamRefOutbound::SequencedBatch(batch) => append_compact_payload_batch(batch, bytes),
        StreamRefOutbound::Frame(frame) => append_protobuf_carrier_frame(frame, bytes),
    }
}
/// Test helper: encodes `frames` in order.
///
/// Each maximal run of consecutive `SequencedOnNext` frames for the same stream ref is
/// packed into compact batches. Every other frame is written as a protobuf frame.
#[cfg(test)]
fn encode_carrier_frames(frames: &[StreamRefFrame]) -> StreamResult<Vec<u8>> {
    let mut bytes = Vec::new();
    let mut rest = frames;
    while let Some(first) = rest.first() {
        if sequenced_on_next(first).is_some() {
            let (run, tail) = rest.split_at(sequenced_run_end(rest, 0));
            append_compact_sequenced_batches(run, &mut bytes)?;
            rest = tail;
        } else {
            append_protobuf_carrier_frame(first, &mut bytes)?;
            rest = &rest[1..];
        }
    }
    Ok(bytes)
}
/// Appends `batch` to `bytes` as one or more compact `SequencedOnNext` frames.
///
/// Elements are packed greedily in order. A frame is closed before any element that
/// would push its payload past `MAX_STREAM_REF_FRAME_BYTES`; the payload counts the
/// header plus a 4-byte length prefix and the data for each element. A frame is also
/// closed once it holds `u16::MAX` elements.
///
/// # Errors
/// Returns `StreamError::LimitExceeded` in two cases: a single element cannot fit in a
/// frame on its own, or a length computation overflows `usize`. Frames appended before
/// the failing element stay in `bytes`.
fn append_compact_payload_batch(
    batch: &StreamRefPayloadBatch,
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let mut start = 0;
    while start < batch.count() {
        let mut end = start;
        // Payload length of the frame being built, starting with its fixed header.
        let mut payload_len = COMPACT_BATCH_HEADER_BYTES;
        while end < batch.count() {
            let element_len = COMPACT_BATCH_ELEMENT_LEN_BYTES
                .checked_add(batch.payload_len(end))
                .ok_or(StreamError::LimitExceeded {
                    max: MAX_STREAM_REF_FRAME_BYTES as u64,
                })?;
            let next_payload_len =
                payload_len
                    .checked_add(element_len)
                    .ok_or(StreamError::LimitExceeded {
                        max: MAX_STREAM_REF_FRAME_BYTES as u64,
                    })?;
            // Close a non-empty frame when this element would overflow it or when the
            // u16 element count is already full.
            if end > start
                && (next_payload_len > MAX_STREAM_REF_FRAME_BYTES
                    || end - start >= u16::MAX as usize)
            {
                break;
            }
            // A lone element that still does not fit can never be sent.
            if next_payload_len > MAX_STREAM_REF_FRAME_BYTES {
                return Err(StreamError::LimitExceeded {
                    max: MAX_STREAM_REF_FRAME_BYTES as u64,
                });
            }
            payload_len = next_payload_len;
            end += 1;
        }
        append_compact_payload_batch_slice(batch, start, end, payload_len, bytes)?;
        start = end;
    }
    Ok(())
}
/// Appends elements `start..end` of `batch` as a single compact frame.
///
/// The frame is written in this order:
/// - Length prefix `COMPACT_FRAME_FLAG | payload_len` (big-endian `u32`).
/// - Version and kind bytes.
/// - Stream ref id.
/// - First seq_nr (big-endian `u64`).
/// - Element count (big-endian `u16`).
/// - For each element: a big-endian `u32` length followed by its bytes.
///
/// `payload_len` must already account for the header and every element.
///
/// # Errors
/// Fails if `payload_len`, the element count, or an element length does not fit its
/// field, or if the first seq_nr overflows `u64`.
fn append_compact_payload_batch_slice(
    batch: &StreamRefPayloadBatch,
    start: usize,
    end: usize,
    payload_len: usize,
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let frame_len = u32::try_from(payload_len).map_err(|_| StreamError::LimitExceeded {
        max: MAX_STREAM_REF_FRAME_BYTES as u64,
    })?;
    let element_count = u16::try_from(end - start).map_err(|_| StreamError::LimitExceeded {
        max: u64::from(u16::MAX),
    })?;
    let Some(first_seq) = batch.first_seq_nr().checked_add(start as u64) else {
        return Err(StreamError::Failed(
            "compact StreamRefs seq_nr overflow".to_owned(),
        ));
    };

    bytes.extend_from_slice(&(COMPACT_FRAME_FLAG | frame_len).to_be_bytes());
    bytes.extend_from_slice(&[COMPACT_FRAME_VERSION, COMPACT_SEQUENCED_ON_NEXT_BATCH]);
    bytes.extend(batch.stream_ref_id().to_bytes());
    bytes.extend_from_slice(&first_seq.to_be_bytes());
    bytes.extend_from_slice(&element_count.to_be_bytes());

    for position in start..end {
        let element = batch.payload(position);
        let element_len = u32::try_from(element.len()).map_err(|_| StreamError::LimitExceeded {
            max: u64::from(u32::MAX),
        })?;
        bytes.extend_from_slice(&element_len.to_be_bytes());
        bytes.extend(element);
    }
    Ok(())
}
/// Appends one protobuf-encoded frame to `bytes`, prefixed with its big-endian `u32`
/// length.
///
/// The prefix's high bit (`COMPACT_FRAME_FLAG`) always stays clear, which is how a
/// reader tells protobuf frames apart from compact batches.
///
/// # Errors
/// Returns `StreamError::LimitExceeded { max: MAX_STREAM_REF_FRAME_BYTES }` if the
/// encoded frame exceeds `MAX_STREAM_REF_FRAME_BYTES` or cannot be expressed within
/// `COMPACT_FRAME_LEN_MASK`. Every oversize case reports the same limit, and nothing is
/// appended on error.
fn append_protobuf_carrier_frame(frame: &StreamRefFrame, bytes: &mut Vec<u8>) -> StreamResult<()> {
    let limit_exceeded = || StreamError::LimitExceeded {
        max: MAX_STREAM_REF_FRAME_BYTES as u64,
    };
    let payload = frame.encode_to_vec();
    if payload.len() > MAX_STREAM_REF_FRAME_BYTES {
        return Err(limit_exceeded());
    }
    // Defensive: the length must also leave the compact flag bit clear, even if the
    // frame limit is ever raised above `COMPACT_FRAME_LEN_MASK`.
    let len = u32::try_from(payload.len())
        .ok()
        .filter(|&len| len <= COMPACT_FRAME_LEN_MASK)
        .ok_or_else(limit_exceeded)?;
    bytes.extend_from_slice(&len.to_be_bytes());
    bytes.extend(payload);
    Ok(())
}
/// Test helper: packs a run of `SequencedOnNext` frames into compact batches.
///
/// It uses the same greedy split as `append_compact_payload_batch`: a batch is closed
/// before any element that would push its payload past `MAX_STREAM_REF_FRAME_BYTES`,
/// or once it holds `u16::MAX` elements.
///
/// # Panics
/// Panics if any frame in `frames` is not a `SequencedOnNext` frame.
///
/// # Errors
/// Returns `StreamError::LimitExceeded` if a single element cannot fit in a batch on
/// its own, or if a length computation overflows.
#[cfg(test)]
fn append_compact_sequenced_batches(
    frames: &[StreamRefFrame],
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let mut start = 0;
    while start < frames.len() {
        let mut end = start;
        // Payload length of the batch being built, starting with its fixed header.
        let mut payload_len = COMPACT_BATCH_HEADER_BYTES;
        while end < frames.len() {
            let (_, payload) = sequenced_on_next(&frames[end]).expect("sequenced frame");
            let element_len = COMPACT_BATCH_ELEMENT_LEN_BYTES
                .checked_add(payload.len())
                .ok_or(StreamError::LimitExceeded {
                    max: MAX_STREAM_REF_FRAME_BYTES as u64,
                })?;
            let next_payload_len =
                payload_len
                    .checked_add(element_len)
                    .ok_or(StreamError::LimitExceeded {
                        max: MAX_STREAM_REF_FRAME_BYTES as u64,
                    })?;
            // Close a non-empty batch when this element would overflow it or the u16
            // element count is already full.
            if end > start
                && (next_payload_len > MAX_STREAM_REF_FRAME_BYTES
                    || end - start >= u16::MAX as usize)
            {
                break;
            }
            // A lone element that still does not fit can never be sent.
            if next_payload_len > MAX_STREAM_REF_FRAME_BYTES {
                return Err(StreamError::LimitExceeded {
                    max: MAX_STREAM_REF_FRAME_BYTES as u64,
                });
            }
            payload_len = next_payload_len;
            end += 1;
        }
        append_compact_sequenced_batch(&frames[start..end], payload_len, bytes)?;
        start = end;
    }
    Ok(())
}
/// Appends a single compact batch carrier frame containing every frame in `frames`.
///
/// The layout written, in order:
/// 1. The 4-byte big-endian prefix `COMPACT_FRAME_FLAG | payload_len`.
/// 2. The version byte, then the kind byte.
/// 3. The stream ref id of the first frame.
/// 4. The big-endian first sequence number.
/// 5. The big-endian `u16` element count.
/// 6. Each payload, preceded by its big-endian `u32` length.
///
/// `payload_len` is written as given and is not recomputed here.
///
/// Panics if `frames` is empty or contains a frame that is not `SequencedOnNext`.
#[cfg(test)]
fn append_compact_sequenced_batch(
    frames: &[StreamRefFrame],
    payload_len: usize,
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let head = &frames[0];
    let (first_seq, _) = sequenced_on_next(head).expect("sequenced frame");
    let prefix = COMPACT_FRAME_FLAG
        | u32::try_from(payload_len).map_err(|_| StreamError::LimitExceeded {
            max: MAX_STREAM_REF_FRAME_BYTES as u64,
        })?;
    let element_count = u16::try_from(frames.len()).map_err(|_| StreamError::LimitExceeded {
        max: u16::MAX as u64,
    })?;
    bytes.extend(prefix.to_be_bytes());
    bytes.extend([COMPACT_FRAME_VERSION, COMPACT_SEQUENCED_ON_NEXT_BATCH]);
    bytes.extend(head.stream_ref_id.to_bytes());
    bytes.extend(first_seq.to_be_bytes());
    bytes.extend(element_count.to_be_bytes());
    for frame in frames {
        let (_, element) = sequenced_on_next(frame).expect("sequenced frame");
        let element_len = u32::try_from(element.len()).map_err(|_| StreamError::LimitExceeded {
            max: u32::MAX as u64,
        })?;
        bytes.extend(element_len.to_be_bytes());
        bytes.extend(element);
    }
    Ok(())
}
/// Returns the exclusive end index of the run of frames starting at `start` that can share
/// one compact batch.
///
/// A run is made of consecutive `SequencedOnNext` frames for the same stream ref whose
/// sequence numbers increase by exactly one. The frame at `start` is always counted, so the
/// result is at least `start + 1`; this also holds when `start` is out of range.
#[cfg(test)]
fn sequenced_run_end(frames: &[StreamRefFrame], start: usize) -> usize {
    let Some(first) = frames.get(start) else {
        return start + 1;
    };
    let Some((mut previous_seq, _)) = sequenced_on_next(first) else {
        return start + 1;
    };
    let mut end = start + 1;
    while let Some(frame) = frames.get(end) {
        let Some((next_seq, _)) = sequenced_on_next(frame) else {
            break;
        };
        // `checked_add` rather than `saturating_add`: a frame at `u64::MAX` has no successor,
        // so a following frame that repeats `u64::MAX` must not be merged into this run (the
        // resulting batch would overflow the sequence number when decoded).
        if frame.stream_ref_id != first.stream_ref_id
            || previous_seq.checked_add(1) != Some(next_seq)
        {
            break;
        }
        previous_seq = next_seq;
        end += 1;
    }
    end
}
/// Returns the sequence number and payload bytes of a `SequencedOnNext` frame, or `None` for
/// any other message kind.
#[cfg(test)]
fn sequenced_on_next(frame: &StreamRefFrame) -> Option<(u64, &[u8])> {
    if let StreamRefMessage::SequencedOnNext { seq_nr, payload } = &frame.message {
        return Some((*seq_nr, payload.bytes.as_slice()));
    }
    None
}
/// Incremental decoder for length-prefixed StreamRefs carrier frames.
///
/// Bytes arrive in arbitrary chunks through `push_chunk`. Each frame is dispatched to the
/// endpoint as soon as it is completely buffered. A trailing partial frame is kept until later
/// chunks complete it.
#[derive(Default)]
struct FrameDecoder {
    /// Received bytes; everything before `offset` has already been decoded.
    buffer: BytesMut,
    /// Index into `buffer` of the first byte that has not been decoded yet.
    offset: usize,
}
impl FrameDecoder {
    /// Size of the already-decoded prefix beyond which it is released from `buffer` even
    /// though undecoded bytes remain after it.
    const RELEASE_THRESHOLD: usize = 64 * 1024;

    /// Appends `chunk` to the internal buffer, then decodes and dispatches every complete
    /// frame now available.
    ///
    /// Protobuf frames go to `endpoint.handle_frame`. Compact frames go to
    /// `decode_compact_carrier_frame`.
    ///
    /// # Errors
    ///
    /// Propagates header limit errors, decode errors, and endpoint errors. Frames dispatched
    /// before the failing one stay consumed.
    fn push_chunk<E>(&mut self, chunk: &[u8], endpoint: &E) -> StreamResult<()>
    where
        E: StreamRefProtoEndpoint,
    {
        self.buffer.extend_from_slice(chunk);
        loop {
            let Some(CarrierFrameHeader { kind, len }) = self.peek_header()? else {
                break;
            };
            let body_start = self.offset + FRAME_LEN_BYTES;
            let frame_end = body_start + len;
            if frame_end > self.buffer.len() {
                // Only part of this frame has arrived; wait for more bytes.
                break;
            }
            let body = &self.buffer[body_start..frame_end];
            match kind {
                CarrierFrameKind::Protobuf => {
                    endpoint.handle_frame(StreamRefFrame::decode(body)?)?
                }
                CarrierFrameKind::Compact => decode_compact_carrier_frame(body, endpoint)?,
            }
            self.offset = frame_end;
        }
        // Release decoded bytes when nothing is pending, or when the decoded prefix has grown
        // large; otherwise keep advancing `offset` to avoid touching the buffer on every call.
        let fully_consumed = self.offset == self.buffer.len();
        if self.offset > 0 && (fully_consumed || self.offset >= Self::RELEASE_THRESHOLD) {
            self.buffer.advance(self.offset);
            self.offset = 0;
        }
        Ok(())
    }

    /// Reads the 4-byte big-endian prefix at `offset` without consuming it.
    ///
    /// Returns `Ok(None)` when fewer than `FRAME_LEN_BYTES` undecoded bytes are buffered. The
    /// `COMPACT_FRAME_FLAG` bit selects the frame kind; the remaining bits are the body length.
    ///
    /// # Errors
    ///
    /// `StreamError::LimitExceeded` when the announced body length exceeds
    /// `MAX_STREAM_REF_FRAME_BYTES`.
    fn peek_header(&self) -> StreamResult<Option<CarrierFrameHeader>> {
        let Some(prefix) = self.buffer.get(self.offset..self.offset + FRAME_LEN_BYTES) else {
            return Ok(None);
        };
        let raw = u32::from_be_bytes(prefix.try_into().expect("frame header length"));
        let kind = match raw & COMPACT_FRAME_FLAG {
            0 => CarrierFrameKind::Protobuf,
            _ => CarrierFrameKind::Compact,
        };
        let len = (raw & COMPACT_FRAME_LEN_MASK) as usize;
        if len > MAX_STREAM_REF_FRAME_BYTES {
            return Err(StreamError::LimitExceeded {
                max: MAX_STREAM_REF_FRAME_BYTES as u64,
            });
        }
        Ok(Some(CarrierFrameHeader { kind, len }))
    }
}
/// A decoded 4-byte carrier frame prefix: which encoding follows and how long its body is.
#[derive(Copy, Clone)]
struct CarrierFrameHeader {
    /// Encoding of the frame body, selected by the compact flag bit of the prefix.
    kind: CarrierFrameKind,
    /// Body length in bytes, not counting the length prefix itself.
    len: usize,
}
/// Encoding of a carrier frame body.
#[derive(Copy, Clone)]
enum CarrierFrameKind {
    /// A protobuf-encoded `StreamRefFrame` (compact flag bit clear).
    Protobuf,
    /// A compact `SequencedOnNext` batch (compact flag bit set).
    Compact,
}
/// Decodes a compact `SequencedOnNext` batch carrier frame and hands its payloads to
/// `endpoint.handle_sequenced_on_next_batch`.
///
/// `payload` is the frame body, without the 4-byte length prefix. Its layout is:
/// - version byte, then kind byte;
/// - stream ref id (bytes 2..18);
/// - big-endian `u64` first sequence number;
/// - big-endian `u16` element count;
/// - `count` elements, each a big-endian `u32` length followed by that many bytes.
///
/// The element slices passed to the endpoint borrow from `payload`.
///
/// # Errors
///
/// Returns `StreamError::Failed` when any of the following holds:
/// - the body is shorter than the batch header;
/// - the version or kind is unsupported;
/// - the count is zero;
/// - an element length or element body is truncated;
/// - a sequence number would overflow `u64`;
/// - bytes remain after the last element.
///
/// Errors from `StreamRefId::from_bytes` and from the endpoint are propagated unchanged.
fn decode_compact_carrier_frame<E>(payload: &[u8], endpoint: &E) -> StreamResult<()>
where
    E: StreamRefProtoEndpoint,
{
    if payload.len() < COMPACT_BATCH_HEADER_BYTES {
        return Err(StreamError::Failed(
            "compact StreamRefs carrier frame too short".to_owned(),
        ));
    }
    let version = payload[0];
    if version != COMPACT_FRAME_VERSION {
        return Err(StreamError::Failed(format!(
            "unsupported compact StreamRefs carrier frame version: {version}"
        )));
    }
    let kind = payload[1];
    if kind != COMPACT_SEQUENCED_ON_NEXT_BATCH {
        return Err(StreamError::Failed(format!(
            "unsupported compact StreamRefs carrier frame kind: {kind}"
        )));
    }
    let stream_ref_id = StreamRefId::from_bytes(&payload[2..18])?;
    let first_seq = u64::from_be_bytes(payload[18..26].try_into().expect("seq len"));
    let count = u16::from_be_bytes(payload[26..28].try_into().expect("count len")) as usize;
    if count == 0 {
        return Err(StreamError::Failed(
            "compact StreamRefs carrier batch is empty".to_owned(),
        ));
    }
    let mut offset = COMPACT_BATCH_HEADER_BYTES;
    // `count` comes straight off the wire. Every element needs at least its length prefix, so
    // never reserve room for more elements than this frame could actually contain; otherwise
    // a tiny malformed frame could force a large allocation before being rejected.
    let max_elements =
        (payload.len() - COMPACT_BATCH_HEADER_BYTES) / COMPACT_BATCH_ELEMENT_LEN_BYTES;
    let mut payloads = Vec::with_capacity(count.min(max_elements));
    for index in 0..count {
        if payload.len().saturating_sub(offset) < COMPACT_BATCH_ELEMENT_LEN_BYTES {
            return Err(StreamError::Failed(
                "compact StreamRefs carrier batch has truncated payload length".to_owned(),
            ));
        }
        let payload_len = u32::from_be_bytes(
            payload[offset..offset + COMPACT_BATCH_ELEMENT_LEN_BYTES]
                .try_into()
                .expect("payload len"),
        ) as usize;
        offset += COMPACT_BATCH_ELEMENT_LEN_BYTES;
        if payload.len().saturating_sub(offset) < payload_len {
            return Err(StreamError::Failed(
                "compact StreamRefs carrier batch has truncated payload".to_owned(),
            ));
        }
        // Reject batches whose implied sequence numbers would wrap past u64::MAX.
        first_seq
            .checked_add(index as u64)
            .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
        payloads.push(&payload[offset..offset + payload_len]);
        offset += payload_len;
    }
    if offset != payload.len() {
        return Err(StreamError::Failed(
            "compact StreamRefs carrier batch has trailing bytes".to_owned(),
        ));
    }
    endpoint.handle_sequenced_on_next_batch(stream_ref_id, first_seq, &payloads)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Endpoint that records every frame handed to it, in arrival order.
    #[derive(Clone)]
    struct RecordingEndpoint {
        stream_ref_id: StreamRefId,
        frames: Arc<Mutex<Vec<StreamRefFrame>>>,
    }

    impl RecordingEndpoint {
        fn new(stream_ref_id: StreamRefId) -> Self {
            Self {
                stream_ref_id,
                frames: Arc::default(),
            }
        }

        /// Snapshot of the frames received so far.
        fn frames(&self) -> Vec<StreamRefFrame> {
            self.frames.lock().expect("recording endpoint").clone()
        }
    }

    impl StreamRefProtoEndpoint for RecordingEndpoint {
        fn stream_ref_id(&self) -> StreamRefId {
            self.stream_ref_id
        }

        fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
            None
        }

        fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
            self.frames.lock().expect("recording endpoint").push(frame);
            Ok(())
        }

        fn fail_connection(&self, _error: StreamError) {}
    }

    /// Builds one `SequencedOnNext` frame per sequence number for stream ref `id`, with the
    /// payload bytes produced by `payload`.
    fn sequenced_frames(
        id: u128,
        seq_nrs: impl IntoIterator<Item = u64>,
        payload: impl Fn(u64) -> Vec<u8>,
    ) -> Vec<StreamRefFrame> {
        seq_nrs
            .into_iter()
            .map(|seq_nr| {
                StreamRefFrame::new(
                    StreamRefId::from_u128(id),
                    datum::StreamRefMessage::SequencedOnNext {
                        seq_nr,
                        payload: datum::StreamRefPayloadBytes {
                            bytes: payload(seq_nr),
                        },
                    },
                )
            })
            .collect()
    }

    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let frame = StreamRefFrame::new(
            StreamRefId::from_u128(1),
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        );
        let bytes = encode_carrier_frames(std::slice::from_ref(&frame)).unwrap();
        let (head, tail) = bytes.split_at(bytes.len() / 2);
        let endpoint = RecordingEndpoint::new(StreamRefId::from_u128(1));
        let mut decoder = FrameDecoder::default();

        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());

        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), vec![frame]);
    }

    #[test]
    fn compact_carrier_batch_round_trips_sequenced_frames() {
        let frames = sequenced_frames(7, 0..3, |seq_nr| seq_nr.to_be_bytes().to_vec());
        let bytes = encode_carrier_frames(&frames).unwrap();
        let prefix = u32::from_be_bytes(bytes[..FRAME_LEN_BYTES].try_into().unwrap());
        assert_ne!(prefix & COMPACT_FRAME_FLAG, 0);

        let endpoint = RecordingEndpoint::new(StreamRefId::from_u128(7));
        let mut decoder = FrameDecoder::default();
        decoder.push_chunk(&bytes, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }

    #[test]
    fn compact_carrier_batch_reassembles_split_frames() {
        let frames = sequenced_frames(8, 4..8, |seq_nr| vec![seq_nr as u8]);
        let bytes = encode_carrier_frames(&frames).unwrap();
        // Split inside the compact batch header so the first chunk is incomplete.
        let (head, tail) = bytes.split_at(FRAME_LEN_BYTES + 5);
        let endpoint = RecordingEndpoint::new(StreamRefId::from_u128(8));
        let mut decoder = FrameDecoder::default();

        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());

        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }
}