1use std::{
8 collections::VecDeque,
9 future::Future,
10 net::SocketAddr,
11 sync::{
12 Arc, Mutex, OnceLock,
13 atomic::{AtomicBool, Ordering},
14 mpsc,
15 },
16 thread,
17};
18
19use bytes::{Buf, BytesMut};
20use datum::{
21 NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
22 StreamRefMessage, StreamRefOutbound, StreamRefPayload, StreamRefPayloadBatch,
23 StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer, StreamRefSettings,
24 StreamResult,
25 actor::stream_ref_proto::{StreamRefOutboundPoll, StreamRefProtoEndpointWake},
26};
27use tokio::{
28 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
29 net::{TcpListener, TcpStream, ToSocketAddrs},
30 runtime::{Handle, Runtime},
31 sync::mpsc as tokio_mpsc,
32};
33
34use crate::QuicBidirectionalStream;
35
/// Size in bytes of the length prefix that precedes a carrier frame; compact frames
/// write it as a big-endian `u32` (see `append_compact_payload_batch_slice`).
const FRAME_LEN_BYTES: usize = 4;
/// Upper bound (16 MiB) on the encoded payload of a single carrier frame. The TCP
/// carrier also stops draining outbound items once its write buffer reaches this size.
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 * 1024 * 1024;
/// Read chunk size of the TCP carrier and initial capacity of carrier I/O buffers.
const STREAM_REF_TCP_CHUNK_SIZE: usize = 8192;
/// Batch-size argument passed to `next_outbound` / `try_next_outbound`.
/// NOTE(review): presumably the maximum number of frames coalesced per outbound
/// item — confirm against the endpoint implementation.
const STREAM_REF_OUTBOUND_BATCH_FRAMES: usize = 64;

/// High bit of the length prefix; set on compact (non-protobuf) frames.
const COMPACT_FRAME_FLAG: u32 = 0x8000_0000;
/// Mask selecting the length bits of the prefix; protobuf frame lengths must fit in it.
const COMPACT_FRAME_LEN_MASK: u32 = 0x7fff_ffff;
/// Version byte written right after the prefix of every compact frame.
const COMPACT_FRAME_VERSION: u8 = 1;
/// Kind byte identifying a compact batch of `SequencedOnNext` elements.
const COMPACT_SEQUENCED_ON_NEXT_BATCH: u8 = 1;
/// Compact batch header: version (1) + kind (1) + stream ref id (16) +
/// first seq_nr (8, `u64`) + element count (2, `u16`).
const COMPACT_BATCH_HEADER_BYTES: usize = 1 + 1 + 16 + 8 + 2;
/// Per-element length prefix inside a compact batch (big-endian `u32`).
const COMPACT_BATCH_ELEMENT_LEN_BYTES: usize = 4;
50
/// Describes how a carrier turns raw socket reads into decoder input.
#[derive(Clone, Copy)]
struct CarrierReadMode {
    /// Read buffer size; when `emit_available` is false the decoder is fed in
    /// slices of exactly this many bytes.
    chunk_size: usize,
    /// When true, every read is handed to the decoder as soon as it arrives.
    emit_available: bool,
    /// When true, end-of-stream fails the endpoint with `AbruptTermination`.
    fail_on_eof: bool,
}

impl CarrierReadMode {
    /// Builds a read mode from its three settings.
    ///
    /// # Panics
    ///
    /// Panics when `chunk_size` is zero.
    fn new(chunk_size: usize, emit_available: bool, fail_on_eof: bool) -> Self {
        assert!(chunk_size > 0, "chunk size must be greater than zero");
        let mode = Self {
            chunk_size,
            emit_available,
            fail_on_eof,
        };
        mode
    }
}
68
/// Counts of StreamRefs protocol messages, grouped by message kind.
///
/// Used both as running totals (see `StreamRefProtocolDiagnostics::snapshot`) and
/// as the per-item delta produced by `outbound_counts`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefProtocolMessageCounts {
    /// Number of `CumulativeDemand` messages.
    pub cumulative_demand: u64,
    /// Number of `SequencedOnNext` elements; a compact batch counts once per element.
    pub sequenced_on_next: u64,
    /// Number of `Ack` messages.
    pub ack: u64,
}
77
78#[derive(Clone, Default)]
84pub struct StreamRefProtocolDiagnostics {
85 counts: Arc<Mutex<StreamRefProtocolMessageCounts>>,
86}
87
88impl StreamRefProtocolDiagnostics {
89 #[must_use]
90 pub fn new() -> Self {
91 Self::default()
92 }
93
94 #[must_use]
95 pub fn snapshot(&self) -> StreamRefProtocolMessageCounts {
96 *self
97 .counts
98 .lock()
99 .expect("stream ref protocol diagnostics poisoned")
100 }
101
102 fn record_written_outbound(&self, outbound: &StreamRefOutbound) {
103 self.record_counts(outbound_counts(outbound));
104 }
105
106 fn record_counts(&self, delta: StreamRefProtocolMessageCounts) {
107 if delta == StreamRefProtocolMessageCounts::default() {
108 return;
109 }
110 let mut counts = self
111 .counts
112 .lock()
113 .expect("stream ref protocol diagnostics poisoned");
114 counts.cumulative_demand = counts
115 .cumulative_demand
116 .saturating_add(delta.cumulative_demand);
117 counts.sequenced_on_next = counts
118 .sequenced_on_next
119 .saturating_add(delta.sequenced_on_next);
120 counts.ack = counts.ack.saturating_add(delta.ack);
121 }
122}
123
124fn outbound_counts(outbound: &StreamRefOutbound) -> StreamRefProtocolMessageCounts {
125 let mut counts = StreamRefProtocolMessageCounts::default();
126 match outbound {
127 StreamRefOutbound::Frame(frame) => match &frame.message {
128 StreamRefMessage::CumulativeDemand { .. } => {
129 counts.cumulative_demand = 1;
130 }
131 StreamRefMessage::SequencedOnNext { .. } => {
132 counts.sequenced_on_next = 1;
133 }
134 StreamRefMessage::Ack => {
135 counts.ack = 1;
136 }
137 StreamRefMessage::OnSubscribeHandshake
138 | StreamRefMessage::RemoteStreamCompleted { .. }
139 | StreamRefMessage::RemoteStreamFailure { .. } => {}
140 },
141 StreamRefOutbound::SequencedBatch(batch) => {
142 counts.sequenced_on_next = batch.count() as u64;
143 }
144 }
145 counts
146}
147
/// Diagnostics for one encoded outbound item whose bytes are still (partly) in the
/// TCP write buffer. Its counts are recorded only after all of its bytes are written.
#[derive(Clone, Copy)]
struct PendingDiagnostic {
    /// Encoded bytes of this item not yet accepted by the socket.
    remaining: usize,
    /// Message counts to record once `remaining` is fully written.
    counts: StreamRefProtocolMessageCounts,
}
153
154#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
156pub struct StreamRefQuicHandle {
157 receiver: mpsc::Receiver<StreamResult<NotUsed>>,
158}
159
160impl StreamRefQuicHandle {
161 pub fn wait(self) -> StreamResult<NotUsed> {
162 self.receiver
163 .recv()
164 .unwrap_or(Err(StreamError::AbruptTermination))
165 }
166
167 #[must_use]
168 pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
169 self.receiver.try_recv().ok()
170 }
171}
172
/// The local address a StreamRefs TCP listener was bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefTcpBinding {
    local_addr: SocketAddr,
}

impl StreamRefTcpBinding {
    /// Returns the bound local address, including the OS-assigned port when the
    /// listener was bound to port 0.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
185
186#[must_use = "wait for the TCP StreamRefs carrier to observe completion or failure"]
188pub struct StreamRefTcpHandle {
189 receiver: mpsc::Receiver<StreamResult<NotUsed>>,
190}
191
192impl StreamRefTcpHandle {
193 pub fn wait(self) -> StreamResult<NotUsed> {
194 self.receiver
195 .recv()
196 .unwrap_or(Err(StreamError::AbruptTermination))
197 }
198
199 #[must_use]
200 pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
201 self.receiver.try_recv().ok()
202 }
203}
204
205pub fn serve_source_ref_over_quic<T>(
207 stream: QuicBidirectionalStream,
208 source_ref: SourceRef<T>,
209 stream_ref_id: StreamRefId,
210 settings: StreamRefSettings,
211) -> StreamResult<StreamRefQuicHandle>
212where
213 T: StreamRefPayload,
214{
215 let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
216 Ok(drive_stream_ref_endpoint_over_quic(stream, producer, None))
217}
218
219pub fn serve_source_over_quic<T, Mat>(
221 stream: QuicBidirectionalStream,
222 source: Source<T, Mat>,
223 stream_ref_id: StreamRefId,
224 settings: StreamRefSettings,
225) -> StreamResult<StreamRefQuicHandle>
226where
227 T: StreamRefPayload,
228 Mat: Send + 'static,
229{
230 let producer = StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?;
231 Ok(drive_stream_ref_endpoint_over_quic(stream, producer, None))
232}
233
234pub fn source_ref_over_quic<T>(
236 stream: QuicBidirectionalStream,
237 stream_ref_id: StreamRefId,
238 settings: StreamRefSettings,
239) -> (Source<T, NotUsed>, StreamRefQuicHandle)
240where
241 T: StreamRefPayload,
242{
243 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
244 let source = consumer.source();
245 let handle = drive_stream_ref_endpoint_over_quic(stream, consumer, None);
246 (source, handle)
247}
248
249pub fn serve_sink_ref_over_quic<T>(
257 stream: QuicBidirectionalStream,
258 stream_ref_id: StreamRefId,
259 settings: StreamRefSettings,
260) -> (Source<T, NotUsed>, StreamRefQuicHandle)
261where
262 T: StreamRefPayload,
263{
264 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
265 let source = consumer.source();
266 let handle = drive_stream_ref_endpoint_over_quic(stream, consumer, None);
267 (source, handle)
268}
269
270pub fn sink_ref_over_quic<T>(
273 stream: QuicBidirectionalStream,
274 stream_ref_id: StreamRefId,
275 settings: StreamRefSettings,
276) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
277where
278 T: StreamRefPayload,
279{
280 let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
281 let sink = producer.sink();
282 let handle = drive_stream_ref_endpoint_over_quic(stream, producer, None);
283 (sink, handle)
284}
285
286pub fn serve_source_ref_over_tcp<T, A>(
293 addr: A,
294 source_ref: SourceRef<T>,
295 stream_ref_id: StreamRefId,
296 settings: StreamRefSettings,
297) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
298where
299 T: StreamRefPayload,
300 A: ToSocketAddrs + Send + 'static,
301{
302 serve_source_ref_over_tcp_with_diagnostics(addr, source_ref, stream_ref_id, settings, None)
303}
304
305pub fn serve_source_ref_over_tcp_with_diagnostics<T, A>(
306 addr: A,
307 source_ref: SourceRef<T>,
308 stream_ref_id: StreamRefId,
309 settings: StreamRefSettings,
310 diagnostics: Option<StreamRefProtocolDiagnostics>,
311) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
312where
313 T: StreamRefPayload,
314 A: ToSocketAddrs + Send + 'static,
315{
316 let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
317 let (listener, binding, handle) = bind_tcp_listener(addr)?;
318 Ok((
319 binding,
320 drive_stream_ref_endpoint_over_tcp_listener(listener, handle, producer, diagnostics),
321 ))
322}
323
324pub fn serve_source_ref_over_tcp_stream<T>(
330 stream: TcpStream,
331 source_ref: SourceRef<T>,
332 stream_ref_id: StreamRefId,
333 settings: StreamRefSettings,
334) -> StreamResult<StreamRefTcpHandle>
335where
336 T: StreamRefPayload,
337{
338 serve_source_ref_over_tcp_stream_with_diagnostics(
339 stream,
340 source_ref,
341 stream_ref_id,
342 settings,
343 None,
344 )
345}
346
347pub fn serve_source_ref_over_tcp_stream_with_diagnostics<T>(
348 stream: TcpStream,
349 source_ref: SourceRef<T>,
350 stream_ref_id: StreamRefId,
351 settings: StreamRefSettings,
352 diagnostics: Option<StreamRefProtocolDiagnostics>,
353) -> StreamResult<StreamRefTcpHandle>
354where
355 T: StreamRefPayload,
356{
357 let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
358 let handle = current_tokio_handle()?;
359 Ok(drive_stream_ref_endpoint_over_tcp_stream(
360 stream,
361 handle,
362 producer,
363 diagnostics,
364 ))
365}
366
367pub fn source_ref_over_tcp<T, A>(
372 addr: A,
373 stream_ref_id: StreamRefId,
374 settings: StreamRefSettings,
375) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
376where
377 T: StreamRefPayload,
378 A: ToSocketAddrs + Send + 'static,
379{
380 source_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
381}
382
383pub fn source_ref_over_tcp_with_diagnostics<T, A>(
384 addr: A,
385 stream_ref_id: StreamRefId,
386 settings: StreamRefSettings,
387 diagnostics: Option<StreamRefProtocolDiagnostics>,
388) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
389where
390 T: StreamRefPayload,
391 A: ToSocketAddrs + Send + 'static,
392{
393 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
394 let source = consumer.source();
395 let (stream, handle) = connect_tcp_stream(addr)?;
396 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
397 Ok((source, handle))
398}
399
400pub fn source_ref_over_tcp_stream<T>(
403 stream: TcpStream,
404 stream_ref_id: StreamRefId,
405 settings: StreamRefSettings,
406) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
407where
408 T: StreamRefPayload,
409{
410 source_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
411}
412
413pub fn source_ref_over_tcp_stream_with_diagnostics<T>(
414 stream: TcpStream,
415 stream_ref_id: StreamRefId,
416 settings: StreamRefSettings,
417 diagnostics: Option<StreamRefProtocolDiagnostics>,
418) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
419where
420 T: StreamRefPayload,
421{
422 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
423 let source = consumer.source();
424 let handle = current_tokio_handle()?;
425 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
426 Ok((source, handle))
427}
428
429pub fn serve_sink_ref_over_tcp<T, A>(
435 addr: A,
436 stream_ref_id: StreamRefId,
437 settings: StreamRefSettings,
438) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
439where
440 T: StreamRefPayload,
441 A: ToSocketAddrs + Send + 'static,
442{
443 serve_sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
444}
445
446pub fn serve_sink_ref_over_tcp_with_diagnostics<T, A>(
447 addr: A,
448 stream_ref_id: StreamRefId,
449 settings: StreamRefSettings,
450 diagnostics: Option<StreamRefProtocolDiagnostics>,
451) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
452where
453 T: StreamRefPayload,
454 A: ToSocketAddrs + Send + 'static,
455{
456 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
457 let source = consumer.source();
458 let (stream, handle) = connect_tcp_stream(addr)?;
459 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
460 Ok((source, handle))
461}
462
463pub fn serve_sink_ref_over_tcp_stream<T>(
466 stream: TcpStream,
467 stream_ref_id: StreamRefId,
468 settings: StreamRefSettings,
469) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
470where
471 T: StreamRefPayload,
472{
473 serve_sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
474}
475
476pub fn serve_sink_ref_over_tcp_stream_with_diagnostics<T>(
477 stream: TcpStream,
478 stream_ref_id: StreamRefId,
479 settings: StreamRefSettings,
480 diagnostics: Option<StreamRefProtocolDiagnostics>,
481) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
482where
483 T: StreamRefPayload,
484{
485 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
486 let source = consumer.source();
487 let handle = current_tokio_handle()?;
488 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
489 Ok((source, handle))
490}
491
492pub fn sink_ref_over_tcp<T, A>(
500 addr: A,
501 stream_ref_id: StreamRefId,
502 settings: StreamRefSettings,
503) -> StreamResult<(
504 Sink<T, StreamCompletion<NotUsed>>,
505 StreamRefTcpBinding,
506 StreamRefTcpHandle,
507)>
508where
509 T: StreamRefPayload,
510 A: ToSocketAddrs + Send + 'static,
511{
512 sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
513}
514
515pub fn sink_ref_over_tcp_with_diagnostics<T, A>(
516 addr: A,
517 stream_ref_id: StreamRefId,
518 settings: StreamRefSettings,
519 diagnostics: Option<StreamRefProtocolDiagnostics>,
520) -> StreamResult<(
521 Sink<T, StreamCompletion<NotUsed>>,
522 StreamRefTcpBinding,
523 StreamRefTcpHandle,
524)>
525where
526 T: StreamRefPayload,
527 A: ToSocketAddrs + Send + 'static,
528{
529 let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
530 let sink = producer.sink();
531 let (listener, binding, handle) = bind_tcp_listener(addr)?;
532 let handle =
533 drive_stream_ref_endpoint_over_tcp_listener(listener, handle, producer, diagnostics);
534 Ok((sink, binding, handle))
535}
536
537pub fn sink_ref_over_tcp_stream<T>(
540 stream: TcpStream,
541 stream_ref_id: StreamRefId,
542 settings: StreamRefSettings,
543) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
544where
545 T: StreamRefPayload,
546{
547 sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
548}
549
550pub fn sink_ref_over_tcp_stream_with_diagnostics<T>(
551 stream: TcpStream,
552 stream_ref_id: StreamRefId,
553 settings: StreamRefSettings,
554 diagnostics: Option<StreamRefProtocolDiagnostics>,
555) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
556where
557 T: StreamRefPayload,
558{
559 let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
560 let sink = producer.sink();
561 let handle = current_tokio_handle()?;
562 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, producer, diagnostics);
563 Ok((sink, handle))
564}
565
566fn drive_stream_ref_endpoint_over_quic<E>(
567 stream: QuicBidirectionalStream,
568 endpoint: E,
569 diagnostics: Option<StreamRefProtocolDiagnostics>,
570) -> StreamRefQuicHandle
571where
572 E: StreamRefProtoEndpoint,
573{
574 let (reader, writer, handle, chunk_size, emit_available) = stream.into_stream_ref_parts();
575 let read_mode = CarrierReadMode::new(chunk_size, emit_available, false);
576 StreamRefQuicHandle {
577 receiver: spawn_carrier_thread(move || {
578 run_stream_ref_endpoint_io(reader, writer, endpoint, handle, read_mode, diagnostics)
579 }),
580 }
581}
582
583fn drive_stream_ref_endpoint_over_tcp_listener<E>(
584 listener: TcpListener,
585 handle: Handle,
586 endpoint: E,
587 diagnostics: Option<StreamRefProtocolDiagnostics>,
588) -> StreamRefTcpHandle
589where
590 E: StreamRefProtoEndpointWake,
591{
592 StreamRefTcpHandle {
593 receiver: spawn_tcp_endpoint_task(&handle, async move {
594 let (stream, _) = listener.accept().await.map_err(io_error)?;
595 run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
596 }),
597 }
598}
599
600fn drive_stream_ref_endpoint_over_tcp_stream<E>(
601 stream: TcpStream,
602 handle: Handle,
603 endpoint: E,
604 diagnostics: Option<StreamRefProtocolDiagnostics>,
605) -> StreamRefTcpHandle
606where
607 E: StreamRefProtoEndpointWake,
608{
609 StreamRefTcpHandle {
610 receiver: spawn_tcp_endpoint_task(&handle, async move {
611 run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
612 }),
613 }
614}
615
/// Drives `endpoint` over a split reader/writer pair using two OS threads that each
/// block on `handle`. One thread writes outbound frames; the other reads and
/// decodes inbound bytes.
///
/// Blocks until both threads have been joined. Returns the outbound thread's
/// error if it failed, otherwise the inbound thread's error, otherwise
/// `Ok(NotUsed)`.
fn run_stream_ref_endpoint_io<R, W, E>(
    reader: R,
    writer: W,
    endpoint: E,
    handle: Handle,
    read_mode: CarrierReadMode,
    diagnostics: Option<StreamRefProtocolDiagnostics>,
) -> StreamResult<NotUsed>
where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
    E: StreamRefProtoEndpoint,
{
    // Set once `run_outbound_frames` returned Ok, i.e. after the writer was flushed
    // and shut down. The reader uses it to tolerate a QUIC "connection lost"
    // that follows our own clean shutdown.
    let outbound_completed = Arc::new(AtomicBool::new(false));
    let outbound_endpoint = endpoint.clone();
    let outbound_handle = handle.clone();
    let outbound_completed_for_writer = Arc::clone(&outbound_completed);
    let outbound_thread = thread::spawn(move || {
        let result = run_outbound_frames(
            writer,
            outbound_endpoint.clone(),
            outbound_handle,
            diagnostics,
        );
        match &result {
            // Release pairs with the Acquire load in the inbound thread.
            Ok(_) => outbound_completed_for_writer.store(true, Ordering::Release),
            // Report writer failures to the endpoint (presumably so the inbound side
            // stops as well).
            Err(error) => outbound_endpoint.fail_connection(error.clone()),
        }
        result
    });

    let inbound_endpoint = endpoint.clone();
    let outbound_completed_for_reader = Arc::clone(&outbound_completed);
    let inbound_thread = thread::spawn(move || {
        let result = run_inbound_frames(reader, inbound_endpoint.clone(), handle, read_mode);
        if let Err(error) = &result {
            // After our outbound side finished cleanly, losing the QUIC connection is
            // treated as normal teardown rather than a failure.
            if outbound_completed_for_reader.load(Ordering::Acquire) && is_quic_teardown_loss(error)
            {
                return Ok(NotUsed);
            }
            inbound_endpoint.fail_connection(error.clone());
        }
        result
    });

    // NOTE(review): both joins block. If one side fails and `fail_connection` does not
    // unblock the other side's pending read / `next_outbound`, this waits until the
    // peer closes the stream — confirm the endpoint's failure path wakes both threads.
    let outbound = join_carrier_thread(outbound_thread);
    let inbound = join_carrier_thread(inbound_thread);
    // Outbound errors take precedence over inbound ones.
    match (outbound, inbound) {
        (Err(error), _) => Err(error),
        (_, Err(error)) => Err(error),
        (Ok(()), Ok(())) => Ok(NotUsed),
    }
}
671
672async fn run_stream_ref_endpoint_tcp_task<E>(
673 stream: TcpStream,
674 endpoint: E,
675 diagnostics: Option<StreamRefProtocolDiagnostics>,
676) -> StreamResult<NotUsed>
677where
678 E: StreamRefProtoEndpointWake,
679{
680 let (wake_sender, wake_receiver) = tokio_mpsc::channel(1);
681 endpoint.install_outbound_wake(wake_sender.clone());
682 let _ = wake_sender.try_send(());
683
684 let result = TcpEndpointTask {
685 stream,
686 endpoint: endpoint.clone(),
687 diagnostics,
688 read_mode: CarrierReadMode::new(STREAM_REF_TCP_CHUNK_SIZE, true, true),
689 decoder: FrameDecoder::default(),
690 read_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
691 pending_tail: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
692 write_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
693 encode_buffer: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
694 pending_diagnostics: VecDeque::new(),
695 outbound_closed: false,
696 write_shutdown: false,
697 wake_receiver,
698 }
699 .run()
700 .await;
701
702 endpoint.clear_outbound_wake();
703 if let Err(error) = &result {
704 endpoint.fail_connection(error.clone());
705 }
706 result
707}
708
/// State for one TCP StreamRefs connection.
///
/// A single async task drives it with `run`, multiplexing endpoint wake-ups,
/// socket reads and socket writes.
struct TcpEndpointTask<E>
where
    E: StreamRefProtoEndpointWake,
{
    /// The connected socket, used for both directions.
    stream: TcpStream,
    /// Endpoint that receives decoded inbound frames and is polled for outbound items.
    endpoint: E,
    /// Optional recorder; counts are recorded only once their bytes are written.
    diagnostics: Option<StreamRefProtocolDiagnostics>,
    read_mode: CarrierReadMode,
    /// Incremental decoder for inbound bytes.
    decoder: FrameDecoder,
    /// Scratch buffer for one `try_read_buf`; cleared after every successful read.
    read_buffer: BytesMut,
    /// Bytes not yet handed to the decoder (see `feed_read_bytes`).
    pending_tail: Vec<u8>,
    /// Encoded outbound bytes not yet accepted by the socket.
    write_buffer: BytesMut,
    /// Scratch buffer for encoding one outbound item.
    encode_buffer: Vec<u8>,
    /// Diagnostics for items still in `write_buffer`, in write order.
    pending_diagnostics: VecDeque<PendingDiagnostic>,
    /// Set once the endpoint reports `Closed` for outbound.
    outbound_closed: bool,
    /// Set once the socket's write half has been shut down.
    write_shutdown: bool,
    /// Wake-ups from the endpoint (installed via `install_outbound_wake`); presumably
    /// signalled when outbound items may be ready.
    wake_receiver: tokio_mpsc::Receiver<()>,
}

impl<E> TcpEndpointTask<E>
where
    E: StreamRefProtoEndpointWake,
{
    /// Runs the connection until the peer closes its side (EOF) or an error occurs.
    ///
    /// Each iteration works in three steps:
    /// 1. Drain ready outbound items.
    /// 2. Flush pending bytes, or shut down the write half once outbound is closed
    ///    and fully written.
    /// 3. Wait for a wake-up, readability, or (when there is something to write)
    ///    writability.
    ///
    /// `biased` makes wake-ups take priority.
    async fn run(mut self) -> StreamResult<NotUsed> {
        self.stream.set_nodelay(true).map_err(io_error)?;
        loop {
            self.drain_outbound()?;
            // NOTE(review): `flush_write_buffer` awaits `writable()` outside the
            // `select!`, so inbound bytes and wake-ups are not processed while the
            // socket send buffer is full — confirm this head-of-line blocking is acceptable.
            if !self.write_buffer.is_empty() || (self.outbound_closed && !self.write_shutdown) {
                self.flush_write_buffer().await?;
            }

            tokio::select! {
                biased;
                wake = self.wake_receiver.recv() => {
                    // A `Some(())` wake needs no work here: the next iteration drains
                    // outbound. `None` (closed channel) is not expected while the
                    // creating function still holds a sender.
                    if wake.is_none() && !self.outbound_closed {
                        self.drain_outbound()?;
                    }
                }
                ready = self.stream.readable() => {
                    ready.map_err(io_error)?;
                    // `true` means EOF: the connection is finished.
                    if self.read_available()? {
                        return Ok(NotUsed);
                    }
                }
                ready = self.stream.writable(), if !self.write_buffer.is_empty() || (self.outbound_closed && !self.write_shutdown) => {
                    ready.map_err(io_error)?;
                    self.flush_ready_write_buffer()?;
                }
            }
        }
    }

    /// Encodes ready outbound items into `write_buffer`.
    ///
    /// Stops when any of these happens:
    /// - the endpoint has nothing ready;
    /// - the endpoint reports `Closed` (which sets `outbound_closed`);
    /// - the buffer reaches `MAX_STREAM_REF_FRAME_BYTES`.
    ///
    /// Items that encode to zero bytes are skipped.
    fn drain_outbound(&mut self) -> StreamResult<()> {
        while !self.outbound_closed && self.write_buffer.len() < MAX_STREAM_REF_FRAME_BYTES {
            match self
                .endpoint
                .try_next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
            {
                StreamRefOutboundPoll::Ready(Ok(outbound)) => {
                    encode_carrier_outbound_into(&outbound, &mut self.encode_buffer)?;
                    let encoded_len = self.encode_buffer.len();
                    if encoded_len == 0 {
                        continue;
                    }
                    // Remember how many bytes belong to this item so its counts are
                    // recorded only after they are all written.
                    if self.diagnostics.is_some() {
                        self.pending_diagnostics.push_back(PendingDiagnostic {
                            remaining: encoded_len,
                            counts: outbound_counts(&outbound),
                        });
                    }
                    self.write_buffer.extend_from_slice(&self.encode_buffer);
                }
                StreamRefOutboundPoll::Ready(Err(error)) => return Err(error),
                StreamRefOutboundPoll::Pending => break,
                StreamRefOutboundPoll::Closed => {
                    self.outbound_closed = true;
                    break;
                }
            }
        }
        Ok(())
    }

    /// Waits for writability and writes as much of `write_buffer` as the socket
    /// accepts. Once outbound is closed and the buffer is empty, it shuts down the
    /// write half (at most once).
    async fn flush_write_buffer(&mut self) -> StreamResult<()> {
        if !self.write_buffer.is_empty() {
            self.stream.writable().await.map_err(io_error)?;
            self.flush_ready_write_buffer()?;
        }
        if self.outbound_closed && self.write_buffer.is_empty() && !self.write_shutdown {
            self.stream.shutdown().await.map_err(io_error)?;
            self.write_shutdown = true;
        }
        Ok(())
    }

    /// Writes without waiting until `write_buffer` is empty or the socket would
    /// block. A zero-byte write is treated as an error.
    fn flush_ready_write_buffer(&mut self) -> StreamResult<()> {
        while !self.write_buffer.is_empty() {
            match self.stream.try_write(&self.write_buffer) {
                Ok(0) => {
                    return Err(StreamError::Failed(
                        "StreamRefs TCP socket accepted zero write bytes".to_owned(),
                    ));
                }
                Ok(written) => {
                    self.write_buffer.advance(written);
                    self.record_written_bytes(written);
                }
                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
                Err(error) => return Err(io_error(error)),
            }
        }

        Ok(())
    }

    /// Attributes `written` bytes to pending items in FIFO order. Each item's counts
    /// are recorded once all of its bytes have been written.
    fn record_written_bytes(&mut self, mut written: usize) {
        let Some(diagnostics) = &self.diagnostics else {
            return;
        };
        while written > 0 {
            let Some(front) = self.pending_diagnostics.front_mut() else {
                return;
            };
            if written < front.remaining {
                front.remaining -= written;
                return;
            }
            written -= front.remaining;
            let counts = front.counts;
            self.pending_diagnostics.pop_front();
            diagnostics.record_counts(counts);
        }
    }

    /// Reads until the socket would block, feeding each read to the decoder and
    /// draining outbound after each one.
    ///
    /// Returns `Ok(true)` when EOF was reached and `Ok(false)` on `WouldBlock`.
    fn read_available(&mut self) -> StreamResult<bool> {
        loop {
            self.read_buffer.reserve(self.read_mode.chunk_size);
            match self.stream.try_read_buf(&mut self.read_buffer) {
                Ok(0) => return self.handle_eof(),
                Ok(_) => {
                    feed_read_bytes(
                        &mut self.decoder,
                        &self.endpoint,
                        self.read_mode,
                        &mut self.pending_tail,
                        &self.read_buffer,
                    )?;
                    self.read_buffer.clear();
                    self.drain_outbound()?;
                }
                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(false),
                Err(error) => return Err(io_error(error)),
            }
        }
    }

    /// Handles EOF. It feeds any leftover `pending_tail` bytes, fails the endpoint
    /// with `AbruptTermination` when `fail_on_eof` is set, and returns `Ok(true)`.
    fn handle_eof(&mut self) -> StreamResult<bool> {
        if !self.pending_tail.is_empty() {
            feed_inbound_chunk(&mut self.decoder, &self.endpoint, &self.pending_tail)?;
            self.pending_tail.clear();
        }
        if self.read_mode.fail_on_eof {
            self.endpoint
                .fail_connection(StreamError::AbruptTermination);
        }
        Ok(true)
    }
}
877
878fn run_outbound_frames<W, E>(
879 mut writer: W,
880 endpoint: E,
881 handle: Handle,
882 diagnostics: Option<StreamRefProtocolDiagnostics>,
883) -> StreamResult<NotUsed>
884where
885 W: AsyncWrite + Unpin + Send + 'static,
886 E: StreamRefProtoEndpoint,
887{
888 let mut bytes = Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE);
889 loop {
890 let Some(outbound) =
891 endpoint.next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
892 else {
893 handle.block_on(async {
894 writer.flush().await.map_err(io_error)?;
895 writer.shutdown().await.map_err(io_error)
896 })?;
897 return Ok(NotUsed);
898 };
899 let outbound = outbound?;
900 encode_carrier_outbound_into(&outbound, &mut bytes)?;
901 handle.block_on(async {
902 writer.write_all(&bytes).await.map_err(io_error)?;
903 writer.flush().await.map_err(io_error)
904 })?;
905 if let Some(diagnostics) = &diagnostics {
906 diagnostics.record_written_outbound(&outbound);
907 }
908 }
909}
910
911fn run_inbound_frames<R, E>(
912 mut reader: R,
913 endpoint: E,
914 handle: Handle,
915 read_mode: CarrierReadMode,
916) -> StreamResult<NotUsed>
917where
918 R: AsyncRead + Unpin + Send + 'static,
919 E: StreamRefProtoEndpoint,
920{
921 handle.block_on(async move {
922 let mut buffer = vec![0_u8; read_mode.chunk_size];
923 let mut pending_tail = Vec::with_capacity(read_mode.chunk_size);
924 let mut decoder = FrameDecoder::default();
925
926 loop {
927 let read = reader.read(&mut buffer).await.map_err(io_error)?;
928 if read == 0 {
929 if !pending_tail.is_empty() {
930 feed_inbound_chunk(&mut decoder, &endpoint, &pending_tail)?;
931 pending_tail.clear();
932 }
933 if read_mode.fail_on_eof {
934 endpoint.fail_connection(StreamError::AbruptTermination);
935 }
936 return Ok(NotUsed);
937 }
938 feed_read_bytes(
939 &mut decoder,
940 &endpoint,
941 read_mode,
942 &mut pending_tail,
943 &buffer[..read],
944 )?;
945 }
946 })
947}
948
949fn feed_read_bytes<E>(
950 decoder: &mut FrameDecoder,
951 endpoint: &E,
952 read_mode: CarrierReadMode,
953 pending_tail: &mut Vec<u8>,
954 read_buffer: &[u8],
955) -> StreamResult<()>
956where
957 E: StreamRefProtoEndpoint,
958{
959 if read_mode.emit_available {
960 if !pending_tail.is_empty() {
961 pending_tail.extend_from_slice(read_buffer);
962 feed_inbound_chunk(decoder, endpoint, pending_tail)?;
963 pending_tail.clear();
964 return Ok(());
965 }
966 return feed_inbound_chunk(decoder, endpoint, read_buffer);
967 }
968
969 let mut offset = 0;
970 if !pending_tail.is_empty() {
971 let needed = read_mode.chunk_size - pending_tail.len();
972 let take = needed.min(read_buffer.len());
973 pending_tail.extend_from_slice(&read_buffer[..take]);
974 offset += take;
975 if pending_tail.len() == read_mode.chunk_size {
976 feed_inbound_chunk(decoder, endpoint, pending_tail)?;
977 pending_tail.clear();
978 }
979 }
980
981 while offset + read_mode.chunk_size <= read_buffer.len() {
982 let next = offset + read_mode.chunk_size;
983 feed_inbound_chunk(decoder, endpoint, &read_buffer[offset..next])?;
984 offset = next;
985 }
986
987 if offset < read_buffer.len() {
988 pending_tail.extend_from_slice(&read_buffer[offset..]);
989 }
990 Ok(())
991}
992
993fn feed_inbound_chunk<E>(decoder: &mut FrameDecoder, endpoint: &E, chunk: &[u8]) -> StreamResult<()>
994where
995 E: StreamRefProtoEndpoint,
996{
997 decoder.push_chunk(chunk, endpoint)
998}
999
1000fn bind_tcp_listener<A>(addr: A) -> StreamResult<(TcpListener, StreamRefTcpBinding, Handle)>
1001where
1002 A: ToSocketAddrs + Send + 'static,
1003{
1004 let runtime = stream_ref_tcp_runtime()?;
1005 let listener = runtime
1006 .block_on(async { TcpListener::bind(addr).await })
1007 .map_err(io_error)?;
1008 let local_addr = listener.local_addr().map_err(io_error)?;
1009 Ok((
1010 listener,
1011 StreamRefTcpBinding { local_addr },
1012 runtime.handle().clone(),
1013 ))
1014}
1015
1016fn connect_tcp_stream<A>(addr: A) -> StreamResult<(TcpStream, Handle)>
1017where
1018 A: ToSocketAddrs + Send + 'static,
1019{
1020 let runtime = stream_ref_tcp_runtime()?;
1021 let stream = runtime
1022 .block_on(async { TcpStream::connect(addr).await })
1023 .map_err(io_error)?;
1024 stream.set_nodelay(true).map_err(io_error)?;
1025 Ok((stream, runtime.handle().clone()))
1026}
1027
1028fn stream_ref_tcp_runtime() -> StreamResult<&'static Runtime> {
1029 static RUNTIME: OnceLock<Result<Runtime, String>> = OnceLock::new();
1030 match RUNTIME.get_or_init(|| {
1031 tokio::runtime::Builder::new_multi_thread()
1032 .thread_name("datum-streamref-tcp")
1033 .enable_all()
1034 .build()
1035 .map_err(|error| error.to_string())
1036 }) {
1037 Ok(runtime) => Ok(runtime),
1038 Err(error) => Err(StreamError::Failed(format!(
1039 "failed to start StreamRefs TCP runtime: {error}"
1040 ))),
1041 }
1042}
1043
1044fn current_tokio_handle() -> StreamResult<Handle> {
1045 Handle::try_current().map_err(|error| {
1046 StreamError::Failed(format!(
1047 "StreamRefs TCP stream helper requires a current Tokio runtime: {error}"
1048 ))
1049 })
1050}
1051
1052fn io_error(error: std::io::Error) -> StreamError {
1053 StreamError::Failed(error.to_string())
1054}
1055
1056fn is_quic_teardown_loss(error: &StreamError) -> bool {
1057 matches!(error, StreamError::Failed(message) if message == "connection lost")
1058}
1059
1060fn spawn_carrier_thread<F>(run: F) -> mpsc::Receiver<StreamResult<NotUsed>>
1061where
1062 F: FnOnce() -> StreamResult<NotUsed> + Send + 'static,
1063{
1064 let (sender, receiver) = mpsc::channel();
1065 thread::spawn(move || {
1066 let result = run();
1067 let _ = sender.send(result);
1068 });
1069 receiver
1070}
1071
1072fn spawn_tcp_endpoint_task<F>(handle: &Handle, run: F) -> mpsc::Receiver<StreamResult<NotUsed>>
1073where
1074 F: Future<Output = StreamResult<NotUsed>> + Send + 'static,
1075{
1076 let (sender, receiver) = mpsc::channel();
1077 handle.spawn(async move {
1078 let result = run.await;
1079 let _ = sender.send(result);
1080 });
1081 receiver
1082}
1083
1084fn encode_carrier_outbound_into(
1085 outbound: &StreamRefOutbound,
1086 bytes: &mut Vec<u8>,
1087) -> StreamResult<()> {
1088 bytes.clear();
1089 match outbound {
1090 StreamRefOutbound::Frame(frame) => append_protobuf_carrier_frame(frame, bytes)?,
1091 StreamRefOutbound::SequencedBatch(batch) => {
1092 append_compact_payload_batch(batch, bytes)?;
1093 }
1094 }
1095 Ok(())
1096}
1097
/// Test helper: encodes `frames` back to back.
///
/// Each maximal run of `SequencedOnNext` frames becomes compact batches; every
/// other frame is encoded with the protobuf carrier framing.
#[cfg(test)]
fn encode_carrier_frames(frames: &[StreamRefFrame]) -> StreamResult<Vec<u8>> {
    let mut encoded = Vec::new();
    let mut cursor = 0;
    while let Some(frame) = frames.get(cursor) {
        if sequenced_on_next(frame).is_none() {
            append_protobuf_carrier_frame(frame, &mut encoded)?;
            cursor += 1;
            continue;
        }
        let run_end = sequenced_run_end(frames, cursor);
        append_compact_sequenced_batches(&frames[cursor..run_end], &mut encoded)?;
        cursor = run_end;
    }
    Ok(encoded)
}
1114
1115fn append_compact_payload_batch(
1116 batch: &StreamRefPayloadBatch,
1117 bytes: &mut Vec<u8>,
1118) -> StreamResult<()> {
1119 let mut start = 0;
1120 while start < batch.count() {
1121 let mut end = start;
1122 let mut payload_len = COMPACT_BATCH_HEADER_BYTES;
1123 while end < batch.count() {
1124 let element_len = COMPACT_BATCH_ELEMENT_LEN_BYTES
1125 .checked_add(batch.payload_len(end))
1126 .ok_or(StreamError::LimitExceeded {
1127 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1128 })?;
1129 let next_payload_len =
1130 payload_len
1131 .checked_add(element_len)
1132 .ok_or(StreamError::LimitExceeded {
1133 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1134 })?;
1135 if end > start
1136 && (next_payload_len > MAX_STREAM_REF_FRAME_BYTES
1137 || end - start >= u16::MAX as usize)
1138 {
1139 break;
1140 }
1141 if next_payload_len > MAX_STREAM_REF_FRAME_BYTES {
1142 return Err(StreamError::LimitExceeded {
1143 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1144 });
1145 }
1146 payload_len = next_payload_len;
1147 end += 1;
1148 }
1149 append_compact_payload_batch_slice(batch, start, end, payload_len, bytes)?;
1150 start = end;
1151 }
1152 Ok(())
1153}
1154
1155fn append_compact_payload_batch_slice(
1156 batch: &StreamRefPayloadBatch,
1157 start: usize,
1158 end: usize,
1159 payload_len: usize,
1160 bytes: &mut Vec<u8>,
1161) -> StreamResult<()> {
1162 let payload_len = u32::try_from(payload_len).map_err(|_| StreamError::LimitExceeded {
1163 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1164 })?;
1165 let count = u16::try_from(end - start).map_err(|_| StreamError::LimitExceeded {
1166 max: u16::MAX as u64,
1167 })?;
1168 let first_seq = batch
1169 .first_seq_nr()
1170 .checked_add(start as u64)
1171 .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
1172 bytes.extend((COMPACT_FRAME_FLAG | payload_len).to_be_bytes());
1173 bytes.push(COMPACT_FRAME_VERSION);
1174 bytes.push(COMPACT_SEQUENCED_ON_NEXT_BATCH);
1175 bytes.extend(batch.stream_ref_id().to_bytes());
1176 bytes.extend(first_seq.to_be_bytes());
1177 bytes.extend(count.to_be_bytes());
1178 for index in start..end {
1179 let payload = batch.payload(index);
1180 let payload_len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
1181 max: u32::MAX as u64,
1182 })?;
1183 bytes.extend(payload_len.to_be_bytes());
1184 bytes.extend(payload);
1185 }
1186 Ok(())
1187}
1188
1189fn append_protobuf_carrier_frame(frame: &StreamRefFrame, bytes: &mut Vec<u8>) -> StreamResult<()> {
1190 let payload = frame.encode_to_vec();
1191 let len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
1192 max: COMPACT_FRAME_LEN_MASK as u64,
1193 })?;
1194 if payload.len() > MAX_STREAM_REF_FRAME_BYTES || len > COMPACT_FRAME_LEN_MASK {
1195 return Err(StreamError::LimitExceeded {
1196 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1197 });
1198 }
1199 bytes.extend(len.to_be_bytes());
1200 bytes.extend(payload);
1201 Ok(())
1202}
1203
#[cfg(test)]
/// Test helper that appends a run of `SequencedOnNext` frames as compact batches.
///
/// Frames are packed greedily until the next one would exceed `MAX_STREAM_REF_FRAME_BYTES` or the
/// batch would reach `u16::MAX` elements.
///
/// # Panics
/// Panics if any frame in `frames` is not `SequencedOnNext`.
fn append_compact_sequenced_batches(
    frames: &[StreamRefFrame],
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let limit_exceeded = || StreamError::LimitExceeded {
        max: MAX_STREAM_REF_FRAME_BYTES as u64,
    };
    let mut remaining = frames;
    while !remaining.is_empty() {
        let mut taken = 0;
        let mut frame_len = COMPACT_BATCH_HEADER_BYTES;
        for frame in remaining {
            let (_, payload) = sequenced_on_next(frame).expect("sequenced frame");
            let grown = payload
                .len()
                .checked_add(COMPACT_BATCH_ELEMENT_LEN_BYTES)
                .and_then(|element_len| frame_len.checked_add(element_len))
                .ok_or_else(limit_exceeded)?;
            let oversized = grown > MAX_STREAM_REF_FRAME_BYTES;
            if taken > 0 && (oversized || taken >= u16::MAX as usize) {
                break;
            }
            if oversized {
                return Err(limit_exceeded());
            }
            frame_len = grown;
            taken += 1;
        }
        let (batch, rest) = remaining.split_at(taken);
        append_compact_sequenced_batch(batch, frame_len, bytes)?;
        remaining = rest;
    }
    Ok(())
}
1245
#[cfg(test)]
/// Test helper that appends `frames` as a single compact `SequencedOnNext` carrier frame.
///
/// `payload_len` is the frame body length measured by the caller. The stream ref id and first
/// seq_nr are taken from the first frame. An empty `frames` slice writes nothing, because the
/// decoder rejects zero-element batches.
///
/// # Errors
/// Returns `StreamError::LimitExceeded` if `payload_len` exceeds `MAX_STREAM_REF_FRAME_BYTES`,
/// if there are more than `u16::MAX` frames, or if a payload is longer than `u32::MAX`.
///
/// # Panics
/// Panics if any frame is not `SequencedOnNext`.
fn append_compact_sequenced_batch(
    frames: &[StreamRefFrame],
    payload_len: usize,
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let Some(first) = frames.first() else {
        return Ok(());
    };
    let (first_seq, _) = sequenced_on_next(first).expect("sequenced frame");
    // The length shares the prefix word with COMPACT_FRAME_FLAG; keep it within the carrier limit
    // so it can never spill into the flag bit.
    if payload_len > MAX_STREAM_REF_FRAME_BYTES {
        return Err(StreamError::LimitExceeded {
            max: MAX_STREAM_REF_FRAME_BYTES as u64,
        });
    }
    let payload_len = u32::try_from(payload_len).map_err(|_| StreamError::LimitExceeded {
        max: MAX_STREAM_REF_FRAME_BYTES as u64,
    })?;
    let count = u16::try_from(frames.len()).map_err(|_| StreamError::LimitExceeded {
        max: u16::MAX as u64,
    })?;
    bytes.extend((COMPACT_FRAME_FLAG | payload_len).to_be_bytes());
    bytes.push(COMPACT_FRAME_VERSION);
    bytes.push(COMPACT_SEQUENCED_ON_NEXT_BATCH);
    bytes.extend(first.stream_ref_id.to_bytes());
    bytes.extend(first_seq.to_be_bytes());
    bytes.extend(count.to_be_bytes());
    for frame in frames {
        let (_, payload) = sequenced_on_next(frame).expect("sequenced frame");
        let element_len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
            max: u32::MAX as u64,
        })?;
        bytes.extend(element_len.to_be_bytes());
        bytes.extend(payload);
    }
    Ok(())
}
1275
#[cfg(test)]
/// Test helper that returns the exclusive end index of the contiguous `SequencedOnNext` run
/// starting at `start`.
///
/// A run continues while each following frame is `SequencedOnNext`, targets the same stream ref
/// as `frames[start]`, and carries exactly the previous seq_nr plus one. The result is always at
/// least `start + 1`.
fn sequenced_run_end(frames: &[StreamRefFrame], start: usize) -> usize {
    let mut end = start + 1;
    while end < frames.len() {
        let (Some((previous_seq, _)), Some((next_seq, _))) = (
            sequenced_on_next(&frames[end - 1]),
            sequenced_on_next(&frames[end]),
        ) else {
            break;
        };
        // `checked_add` rather than `saturating_add`: a repeated `u64::MAX` must end the run.
        // Otherwise the batch's seq_nrs overflow and the decoder rejects it.
        if frames[end].stream_ref_id != frames[start].stream_ref_id
            || previous_seq.checked_add(1) != Some(next_seq)
        {
            break;
        }
        end += 1;
    }
    end
}
1295
#[cfg(test)]
/// Test helper that returns the seq_nr and payload bytes of a `SequencedOnNext` frame, or `None`
/// for any other message.
fn sequenced_on_next(frame: &StreamRefFrame) -> Option<(u64, &[u8])> {
    if let StreamRefMessage::SequencedOnNext { seq_nr, payload } = &frame.message {
        Some((*seq_nr, payload.bytes.as_slice()))
    } else {
        None
    }
}
1305
1306#[derive(Default)]
1307struct FrameDecoder {
1308 buffer: BytesMut,
1309 offset: usize,
1310}
1311
1312impl FrameDecoder {
1313 fn push_chunk<E>(&mut self, chunk: &[u8], endpoint: &E) -> StreamResult<()>
1314 where
1315 E: StreamRefProtoEndpoint,
1316 {
1317 self.buffer.extend_from_slice(chunk);
1318 while let Some(header) = self.peek_header()? {
1319 if self.buffer.len().saturating_sub(self.offset) < FRAME_LEN_BYTES + header.len {
1320 break;
1321 }
1322 let payload_start = self.offset + FRAME_LEN_BYTES;
1323 let payload_end = payload_start + header.len;
1324 let payload = &self.buffer[payload_start..payload_end];
1325 match header.kind {
1326 CarrierFrameKind::Protobuf => {
1327 endpoint.handle_frame(StreamRefFrame::decode(payload)?)?;
1328 }
1329 CarrierFrameKind::Compact => {
1330 decode_compact_carrier_frame(payload, endpoint)?;
1331 }
1332 }
1333 self.offset = payload_end;
1334 }
1335 if self.offset > 0 && (self.offset == self.buffer.len() || self.offset >= 64 * 1024) {
1336 self.buffer.advance(self.offset);
1337 self.offset = 0;
1338 }
1339 Ok(())
1340 }
1341
1342 fn peek_header(&self) -> StreamResult<Option<CarrierFrameHeader>> {
1343 if self.buffer.len().saturating_sub(self.offset) < FRAME_LEN_BYTES {
1344 return Ok(None);
1345 }
1346 let len = self.buffer[self.offset..self.offset + FRAME_LEN_BYTES]
1347 .try_into()
1348 .expect("frame header length");
1349 let raw_len = u32::from_be_bytes(len);
1350 let kind = if raw_len & COMPACT_FRAME_FLAG == 0 {
1351 CarrierFrameKind::Protobuf
1352 } else {
1353 CarrierFrameKind::Compact
1354 };
1355 let len = (raw_len & COMPACT_FRAME_LEN_MASK) as usize;
1356 if len > MAX_STREAM_REF_FRAME_BYTES {
1357 return Err(StreamError::LimitExceeded {
1358 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1359 });
1360 }
1361 Ok(Some(CarrierFrameHeader { kind, len }))
1362 }
1363}
1364
/// Decoded carrier length prefix: the body encoding and the body length.
#[derive(Clone, Copy)]
struct CarrierFrameHeader {
    /// Length of the frame body in bytes, not counting the 4-byte prefix.
    len: usize,
    /// Encoding of the frame body, selected by the prefix's high bit.
    kind: CarrierFrameKind,
}

/// Encoding used by a carrier frame body.
#[derive(Clone, Copy)]
enum CarrierFrameKind {
    /// High bit clear: the body is a protobuf-encoded `StreamRefFrame`.
    Protobuf,
    /// High bit set: the body is a compact `SequencedOnNext` batch.
    Compact,
}
1376
1377fn decode_compact_carrier_frame<E>(payload: &[u8], endpoint: &E) -> StreamResult<()>
1378where
1379 E: StreamRefProtoEndpoint,
1380{
1381 if payload.len() < COMPACT_BATCH_HEADER_BYTES {
1382 return Err(StreamError::Failed(
1383 "compact StreamRefs carrier frame too short".to_owned(),
1384 ));
1385 }
1386 let version = payload[0];
1387 if version != COMPACT_FRAME_VERSION {
1388 return Err(StreamError::Failed(format!(
1389 "unsupported compact StreamRefs carrier frame version: {version}"
1390 )));
1391 }
1392 let kind = payload[1];
1393 if kind != COMPACT_SEQUENCED_ON_NEXT_BATCH {
1394 return Err(StreamError::Failed(format!(
1395 "unsupported compact StreamRefs carrier frame kind: {kind}"
1396 )));
1397 }
1398 let stream_ref_id = StreamRefId::from_bytes(&payload[2..18])?;
1399 let first_seq = u64::from_be_bytes(payload[18..26].try_into().expect("seq len"));
1400 let count = u16::from_be_bytes(payload[26..28].try_into().expect("count len")) as usize;
1401 if count == 0 {
1402 return Err(StreamError::Failed(
1403 "compact StreamRefs carrier batch is empty".to_owned(),
1404 ));
1405 }
1406
1407 let mut offset = COMPACT_BATCH_HEADER_BYTES;
1408 let mut payloads = Vec::with_capacity(count);
1409 for index in 0..count {
1410 if payload.len().saturating_sub(offset) < COMPACT_BATCH_ELEMENT_LEN_BYTES {
1411 return Err(StreamError::Failed(
1412 "compact StreamRefs carrier batch has truncated payload length".to_owned(),
1413 ));
1414 }
1415 let payload_len = u32::from_be_bytes(
1416 payload[offset..offset + COMPACT_BATCH_ELEMENT_LEN_BYTES]
1417 .try_into()
1418 .expect("payload len"),
1419 ) as usize;
1420 offset += COMPACT_BATCH_ELEMENT_LEN_BYTES;
1421 if payload.len().saturating_sub(offset) < payload_len {
1422 return Err(StreamError::Failed(
1423 "compact StreamRefs carrier batch has truncated payload".to_owned(),
1424 ));
1425 }
1426 first_seq
1427 .checked_add(index as u64)
1428 .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
1429 payloads.push(&payload[offset..offset + payload_len]);
1430 offset += payload_len;
1431 }
1432 if offset != payload.len() {
1433 return Err(StreamError::Failed(
1434 "compact StreamRefs carrier batch has trailing bytes".to_owned(),
1435 ));
1436 }
1437 endpoint.handle_sequenced_on_next_batch(stream_ref_id, first_seq, &payloads)
1438}
1439
1440fn join_carrier_thread(handle: thread::JoinHandle<StreamResult<NotUsed>>) -> StreamResult<()> {
1441 match handle.join() {
1442 Ok(Ok(NotUsed)) => Ok(()),
1443 Ok(Err(error)) => Err(error),
1444 Err(_) => Err(StreamError::Failed(
1445 "StreamRefs carrier thread panicked".to_owned(),
1446 )),
1447 }
1448}
1449
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Endpoint stub that records every frame delivered to it and never yields outbound frames.
    #[derive(Clone)]
    struct RecordingEndpoint {
        stream_ref_id: StreamRefId,
        frames: Arc<Mutex<Vec<StreamRefFrame>>>,
    }

    impl RecordingEndpoint {
        fn new(stream_ref_id: StreamRefId) -> Self {
            let frames = Arc::new(Mutex::new(Vec::new()));
            Self {
                stream_ref_id,
                frames,
            }
        }

        /// Snapshot of the frames received so far, in delivery order.
        fn frames(&self) -> Vec<StreamRefFrame> {
            let recorded = self.frames.lock().expect("recording endpoint");
            recorded.clone()
        }
    }

    impl StreamRefProtoEndpoint for RecordingEndpoint {
        fn stream_ref_id(&self) -> StreamRefId {
            self.stream_ref_id
        }

        fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
            None
        }

        fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
            self.frames.lock().expect("recording endpoint").push(frame);
            Ok(())
        }

        fn fail_connection(&self, _error: StreamError) {}
    }

    /// Builds a `SequencedOnNext` frame for stream ref `id` carrying `bytes` as its payload.
    fn sequenced_frame(id: u128, seq_nr: u64, bytes: Vec<u8>) -> StreamRefFrame {
        StreamRefFrame::new(
            StreamRefId::from_u128(id),
            datum::StreamRefMessage::SequencedOnNext {
                seq_nr,
                payload: datum::StreamRefPayloadBytes { bytes },
            },
        )
    }

    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let stream_ref_id = StreamRefId::from_u128(1);
        let frame = StreamRefFrame::new(
            stream_ref_id,
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        );
        let encoded = encode_carrier_frames(std::slice::from_ref(&frame)).unwrap();
        let (head, tail) = encoded.split_at(encoded.len() / 2);
        let mut decoder = FrameDecoder::default();
        let endpoint = RecordingEndpoint::new(stream_ref_id);

        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());
        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), vec![frame]);
    }

    #[test]
    fn compact_carrier_batch_round_trips_sequenced_frames() {
        let frames: Vec<_> = (0_u64..3)
            .map(|seq_nr| sequenced_frame(7, seq_nr, seq_nr.to_be_bytes().to_vec()))
            .collect();
        let encoded = encode_carrier_frames(&frames).unwrap();
        let prefix = u32::from_be_bytes(encoded[..FRAME_LEN_BYTES].try_into().unwrap());
        assert_ne!(prefix & COMPACT_FRAME_FLAG, 0);

        let mut decoder = FrameDecoder::default();
        let endpoint = RecordingEndpoint::new(StreamRefId::from_u128(7));
        decoder.push_chunk(&encoded, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }

    #[test]
    fn compact_carrier_batch_reassembles_split_frames() {
        let frames: Vec<_> = (4_u64..8)
            .map(|seq_nr| sequenced_frame(8, seq_nr, vec![seq_nr as u8]))
            .collect();
        let encoded = encode_carrier_frames(&frames).unwrap();
        // Cut inside the compact batch header so the decoder has to wait for the remainder.
        let (head, tail) = encoded.split_at(FRAME_LEN_BYTES + 5);
        let mut decoder = FrameDecoder::default();
        let endpoint = RecordingEndpoint::new(StreamRefId::from_u128(8));

        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());
        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }
}