1use std::{
8 collections::VecDeque,
9 future::Future,
10 net::SocketAddr,
11 sync::{Arc, Mutex, OnceLock, mpsc},
12 time::Duration,
13};
14
15use bytes::{Buf, BytesMut};
16use datum::{
17 NotUsed, Sink, Source, SourceRef, StreamCompletion, StreamError, StreamRefFrame, StreamRefId,
18 StreamRefMessage, StreamRefOutbound, StreamRefPayload, StreamRefPayloadBatch,
19 StreamRefProtoConsumer, StreamRefProtoEndpoint, StreamRefProtoProducer, StreamRefSettings,
20 StreamResult,
21 actor::stream_ref_proto::{StreamRefOutboundPoll, StreamRefProtoEndpointWake},
22};
23use tokio::{
24 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
25 net::{TcpListener, TcpStream, ToSocketAddrs},
26 runtime::{Handle, Runtime},
27 sync::mpsc as tokio_mpsc,
28};
29
30use crate::QuicBidirectionalStream;
31
/// Size, in bytes, of a frame length prefix.
const FRAME_LEN_BYTES: usize = 4;
/// Upper bound (16 MiB) used when draining outbound messages: it caps both the byte budget
/// handed to `try_next_outbound` and how many encoded bytes may sit in a write buffer
/// before draining stops.
const MAX_STREAM_REF_FRAME_BYTES: usize = 16 << 20;
/// Read chunk size and initial buffer capacity for the TCP carrier.
const STREAM_REF_TCP_CHUNK_SIZE: usize = 8 * 1024;
/// Largest read buffer the QUIC carrier allocates, whatever chunk size the stream reports.
const STREAM_REF_QUIC_READ_BUFFER_BYTES: usize = 2 * 1024;
/// Maximum number of frames requested from the endpoint per outbound poll.
const STREAM_REF_OUTBOUND_BATCH_FRAMES: usize = 64;
/// Delay before the QUIC carrier re-polls an endpoint that last reported `Pending`.
const STREAM_REF_OUTBOUND_RECHECK_INTERVAL: Duration = Duration::from_millis(5);

/// High bit marking a compact frame; the remaining 31 bits form `COMPACT_FRAME_LEN_MASK`.
const COMPACT_FRAME_FLAG: u32 = 1 << 31;
/// Complement of `COMPACT_FRAME_FLAG` (0x7fff_ffff).
const COMPACT_FRAME_LEN_MASK: u32 = !COMPACT_FRAME_FLAG;
/// Version byte written into compact frames.
const COMPACT_FRAME_VERSION: u8 = 1;
/// Compact-frame kind tag for a batch of sequenced `OnNext` payloads.
const COMPACT_SEQUENCED_ON_NEXT_BATCH: u8 = 1;
/// Fixed header size of a compact batch (28 bytes); the per-field breakdown is defined by
/// the encoder.
const COMPACT_BATCH_HEADER_BYTES: usize = 1 + 1 + 16 + 8 + 2;
/// Size of the length prefix in front of each element of a compact batch.
const COMPACT_BATCH_ELEMENT_LEN_BYTES: usize = 4;
48
/// Controls how a carrier turns raw socket reads into decoder input.
#[derive(Clone, Copy)]
struct CarrierReadMode {
    /// Fixed slice length used when `emit_available` is false; never zero.
    chunk_size: usize,
    /// When true, each read is forwarded whole instead of being re-cut into
    /// `chunk_size`-sized pieces.
    emit_available: bool,
    /// When true, reaching end-of-stream fails the endpoint connection with
    /// `StreamError::AbruptTermination`.
    fail_on_eof: bool,
}

impl CarrierReadMode {
    /// Builds a read mode.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is zero.
    fn new(chunk_size: usize, emit_available: bool, fail_on_eof: bool) -> Self {
        if chunk_size == 0 {
            panic!("chunk size must be greater than zero");
        }
        Self {
            fail_on_eof,
            emit_available,
            chunk_size,
        }
    }
}
66
/// Per-kind tallies of StreamRefs protocol messages written by a carrier.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefProtocolMessageCounts {
    /// Number of `CumulativeDemand` messages.
    pub cumulative_demand: u64,
    /// Number of sequenced `OnNext` elements (single frames plus every element of a batch).
    pub sequenced_on_next: u64,
    /// Number of `Ack` messages.
    pub ack: u64,
}

/// Shared, cloneable collector for [`StreamRefProtocolMessageCounts`].
///
/// Clones share one underlying counter set, so a clone handed to a carrier can be read
/// back through the original.
#[derive(Clone, Default)]
pub struct StreamRefProtocolDiagnostics {
    counts: Arc<Mutex<StreamRefProtocolMessageCounts>>,
}

impl StreamRefProtocolDiagnostics {
    /// Creates a collector with every counter at zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of the current counters.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[must_use]
    pub fn snapshot(&self) -> StreamRefProtocolMessageCounts {
        let guard = self
            .counts
            .lock()
            .expect("stream ref protocol diagnostics poisoned");
        *guard
    }

    /// Adds `delta` to the counters with saturating arithmetic; an all-zero delta skips
    /// taking the lock.
    fn record_counts(&self, delta: StreamRefProtocolMessageCounts) {
        let StreamRefProtocolMessageCounts {
            cumulative_demand,
            sequenced_on_next,
            ack,
        } = delta;
        if delta == StreamRefProtocolMessageCounts::default() {
            return;
        }
        let mut totals = self
            .counts
            .lock()
            .expect("stream ref protocol diagnostics poisoned");
        totals.cumulative_demand = totals.cumulative_demand.saturating_add(cumulative_demand);
        totals.sequenced_on_next = totals.sequenced_on_next.saturating_add(sequenced_on_next);
        totals.ack = totals.ack.saturating_add(ack);
    }
}
117
118fn outbound_counts(outbound: &StreamRefOutbound) -> StreamRefProtocolMessageCounts {
119 let mut counts = StreamRefProtocolMessageCounts::default();
120 match outbound {
121 StreamRefOutbound::Frame(frame) => match &frame.message {
122 StreamRefMessage::CumulativeDemand { .. } => {
123 counts.cumulative_demand = 1;
124 }
125 StreamRefMessage::SequencedOnNext { .. } => {
126 counts.sequenced_on_next = 1;
127 }
128 StreamRefMessage::Ack => {
129 counts.ack = 1;
130 }
131 StreamRefMessage::OnSubscribeHandshake
132 | StreamRefMessage::RemoteStreamCompleted { .. }
133 | StreamRefMessage::RemoteStreamFailure { .. } => {}
134 },
135 StreamRefOutbound::SequencedBatch(batch) => {
136 counts.sequenced_on_next = batch.count() as u64;
137 }
138 }
139 counts
140}
141
/// Message counts for one encoded outbound item, held back until every byte of that item
/// has been written to the transport.
#[derive(Clone, Copy)]
struct PendingDiagnostic {
    /// Encoded bytes of this item not yet reported as written.
    remaining: usize,
    /// Counts recorded into the diagnostics once `remaining` is fully consumed.
    counts: StreamRefProtocolMessageCounts,
}
147
148#[must_use = "wait for the QUIC StreamRefs carrier to observe completion or failure"]
150pub struct StreamRefQuicHandle {
151 receiver: mpsc::Receiver<StreamResult<NotUsed>>,
152}
153
154impl StreamRefQuicHandle {
155 pub fn wait(self) -> StreamResult<NotUsed> {
156 self.receiver
157 .recv()
158 .unwrap_or(Err(StreamError::AbruptTermination))
159 }
160
161 #[must_use]
162 pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
163 self.receiver.try_recv().ok()
164 }
165}
166
/// Address information for a TCP StreamRefs listener that has been bound.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefTcpBinding {
    local_addr: SocketAddr,
}

impl StreamRefTcpBinding {
    /// Returns the local address the listener is bound to.
    ///
    /// This is the address to connect to when the requested port was `0`.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
179
180#[must_use = "wait for the TCP StreamRefs carrier to observe completion or failure"]
182pub struct StreamRefTcpHandle {
183 receiver: mpsc::Receiver<StreamResult<NotUsed>>,
184}
185
186impl StreamRefTcpHandle {
187 pub fn wait(self) -> StreamResult<NotUsed> {
188 self.receiver
189 .recv()
190 .unwrap_or(Err(StreamError::AbruptTermination))
191 }
192
193 #[must_use]
194 pub fn try_wait(&self) -> Option<StreamResult<NotUsed>> {
195 self.receiver.try_recv().ok()
196 }
197}
198
199pub fn serve_source_ref_over_quic<T>(
201 stream: QuicBidirectionalStream,
202 source_ref: SourceRef<T>,
203 stream_ref_id: StreamRefId,
204 settings: StreamRefSettings,
205) -> StreamResult<StreamRefQuicHandle>
206where
207 T: StreamRefPayload,
208{
209 let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
210 Ok(drive_stream_ref_endpoint_over_quic(stream, producer, None))
211}
212
213pub fn serve_source_over_quic<T, Mat>(
215 stream: QuicBidirectionalStream,
216 source: Source<T, Mat>,
217 stream_ref_id: StreamRefId,
218 settings: StreamRefSettings,
219) -> StreamResult<StreamRefQuicHandle>
220where
221 T: StreamRefPayload,
222 Mat: Send + 'static,
223{
224 let producer = StreamRefProtoProducer::from_source(source, stream_ref_id, settings)?;
225 Ok(drive_stream_ref_endpoint_over_quic(stream, producer, None))
226}
227
228pub fn source_ref_over_quic<T>(
230 stream: QuicBidirectionalStream,
231 stream_ref_id: StreamRefId,
232 settings: StreamRefSettings,
233) -> (Source<T, NotUsed>, StreamRefQuicHandle)
234where
235 T: StreamRefPayload,
236{
237 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
238 let source = consumer.source();
239 let handle = drive_stream_ref_endpoint_over_quic(stream, consumer, None);
240 (source, handle)
241}
242
243pub fn serve_sink_ref_over_quic<T>(
251 stream: QuicBidirectionalStream,
252 stream_ref_id: StreamRefId,
253 settings: StreamRefSettings,
254) -> (Source<T, NotUsed>, StreamRefQuicHandle)
255where
256 T: StreamRefPayload,
257{
258 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
259 let source = consumer.source();
260 let handle = drive_stream_ref_endpoint_over_quic(stream, consumer, None);
261 (source, handle)
262}
263
264pub fn sink_ref_over_quic<T>(
267 stream: QuicBidirectionalStream,
268 stream_ref_id: StreamRefId,
269 settings: StreamRefSettings,
270) -> (Sink<T, StreamCompletion<NotUsed>>, StreamRefQuicHandle)
271where
272 T: StreamRefPayload,
273{
274 let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
275 let sink = producer.sink();
276 let handle = drive_stream_ref_endpoint_over_quic(stream, producer, None);
277 (sink, handle)
278}
279
280pub fn serve_source_ref_over_tcp<T, A>(
287 addr: A,
288 source_ref: SourceRef<T>,
289 stream_ref_id: StreamRefId,
290 settings: StreamRefSettings,
291) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
292where
293 T: StreamRefPayload,
294 A: ToSocketAddrs + Send + 'static,
295{
296 serve_source_ref_over_tcp_with_diagnostics(addr, source_ref, stream_ref_id, settings, None)
297}
298
299pub fn serve_source_ref_over_tcp_with_diagnostics<T, A>(
300 addr: A,
301 source_ref: SourceRef<T>,
302 stream_ref_id: StreamRefId,
303 settings: StreamRefSettings,
304 diagnostics: Option<StreamRefProtocolDiagnostics>,
305) -> StreamResult<(StreamRefTcpBinding, StreamRefTcpHandle)>
306where
307 T: StreamRefPayload,
308 A: ToSocketAddrs + Send + 'static,
309{
310 let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
311 let (listener, binding, handle) = bind_tcp_listener(addr)?;
312 Ok((
313 binding,
314 drive_stream_ref_endpoint_over_tcp_listener(listener, handle, producer, diagnostics),
315 ))
316}
317
318pub fn serve_source_ref_over_tcp_stream<T>(
324 stream: TcpStream,
325 source_ref: SourceRef<T>,
326 stream_ref_id: StreamRefId,
327 settings: StreamRefSettings,
328) -> StreamResult<StreamRefTcpHandle>
329where
330 T: StreamRefPayload,
331{
332 serve_source_ref_over_tcp_stream_with_diagnostics(
333 stream,
334 source_ref,
335 stream_ref_id,
336 settings,
337 None,
338 )
339}
340
341pub fn serve_source_ref_over_tcp_stream_with_diagnostics<T>(
342 stream: TcpStream,
343 source_ref: SourceRef<T>,
344 stream_ref_id: StreamRefId,
345 settings: StreamRefSettings,
346 diagnostics: Option<StreamRefProtocolDiagnostics>,
347) -> StreamResult<StreamRefTcpHandle>
348where
349 T: StreamRefPayload,
350{
351 let producer = StreamRefProtoProducer::from_source_ref(source_ref, stream_ref_id, settings)?;
352 let handle = current_tokio_handle()?;
353 Ok(drive_stream_ref_endpoint_over_tcp_stream(
354 stream,
355 handle,
356 producer,
357 diagnostics,
358 ))
359}
360
361pub fn source_ref_over_tcp<T, A>(
366 addr: A,
367 stream_ref_id: StreamRefId,
368 settings: StreamRefSettings,
369) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
370where
371 T: StreamRefPayload,
372 A: ToSocketAddrs + Send + 'static,
373{
374 source_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
375}
376
377pub fn source_ref_over_tcp_with_diagnostics<T, A>(
378 addr: A,
379 stream_ref_id: StreamRefId,
380 settings: StreamRefSettings,
381 diagnostics: Option<StreamRefProtocolDiagnostics>,
382) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
383where
384 T: StreamRefPayload,
385 A: ToSocketAddrs + Send + 'static,
386{
387 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
388 let source = consumer.source();
389 let (stream, handle) = connect_tcp_stream(addr)?;
390 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
391 Ok((source, handle))
392}
393
394pub fn source_ref_over_tcp_stream<T>(
397 stream: TcpStream,
398 stream_ref_id: StreamRefId,
399 settings: StreamRefSettings,
400) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
401where
402 T: StreamRefPayload,
403{
404 source_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
405}
406
407pub fn source_ref_over_tcp_stream_with_diagnostics<T>(
408 stream: TcpStream,
409 stream_ref_id: StreamRefId,
410 settings: StreamRefSettings,
411 diagnostics: Option<StreamRefProtocolDiagnostics>,
412) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
413where
414 T: StreamRefPayload,
415{
416 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
417 let source = consumer.source();
418 let handle = current_tokio_handle()?;
419 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
420 Ok((source, handle))
421}
422
423pub fn serve_sink_ref_over_tcp<T, A>(
429 addr: A,
430 stream_ref_id: StreamRefId,
431 settings: StreamRefSettings,
432) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
433where
434 T: StreamRefPayload,
435 A: ToSocketAddrs + Send + 'static,
436{
437 serve_sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
438}
439
440pub fn serve_sink_ref_over_tcp_with_diagnostics<T, A>(
441 addr: A,
442 stream_ref_id: StreamRefId,
443 settings: StreamRefSettings,
444 diagnostics: Option<StreamRefProtocolDiagnostics>,
445) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
446where
447 T: StreamRefPayload,
448 A: ToSocketAddrs + Send + 'static,
449{
450 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
451 let source = consumer.source();
452 let (stream, handle) = connect_tcp_stream(addr)?;
453 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
454 Ok((source, handle))
455}
456
457pub fn serve_sink_ref_over_tcp_stream<T>(
460 stream: TcpStream,
461 stream_ref_id: StreamRefId,
462 settings: StreamRefSettings,
463) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
464where
465 T: StreamRefPayload,
466{
467 serve_sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
468}
469
470pub fn serve_sink_ref_over_tcp_stream_with_diagnostics<T>(
471 stream: TcpStream,
472 stream_ref_id: StreamRefId,
473 settings: StreamRefSettings,
474 diagnostics: Option<StreamRefProtocolDiagnostics>,
475) -> StreamResult<(Source<T, NotUsed>, StreamRefTcpHandle)>
476where
477 T: StreamRefPayload,
478{
479 let consumer = StreamRefProtoConsumer::new(stream_ref_id, settings);
480 let source = consumer.source();
481 let handle = current_tokio_handle()?;
482 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, consumer, diagnostics);
483 Ok((source, handle))
484}
485
486pub fn sink_ref_over_tcp<T, A>(
494 addr: A,
495 stream_ref_id: StreamRefId,
496 settings: StreamRefSettings,
497) -> StreamResult<(
498 Sink<T, StreamCompletion<NotUsed>>,
499 StreamRefTcpBinding,
500 StreamRefTcpHandle,
501)>
502where
503 T: StreamRefPayload,
504 A: ToSocketAddrs + Send + 'static,
505{
506 sink_ref_over_tcp_with_diagnostics(addr, stream_ref_id, settings, None)
507}
508
509pub fn sink_ref_over_tcp_with_diagnostics<T, A>(
510 addr: A,
511 stream_ref_id: StreamRefId,
512 settings: StreamRefSettings,
513 diagnostics: Option<StreamRefProtocolDiagnostics>,
514) -> StreamResult<(
515 Sink<T, StreamCompletion<NotUsed>>,
516 StreamRefTcpBinding,
517 StreamRefTcpHandle,
518)>
519where
520 T: StreamRefPayload,
521 A: ToSocketAddrs + Send + 'static,
522{
523 let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
524 let sink = producer.sink();
525 let (listener, binding, handle) = bind_tcp_listener(addr)?;
526 let handle =
527 drive_stream_ref_endpoint_over_tcp_listener(listener, handle, producer, diagnostics);
528 Ok((sink, binding, handle))
529}
530
531pub fn sink_ref_over_tcp_stream<T>(
534 stream: TcpStream,
535 stream_ref_id: StreamRefId,
536 settings: StreamRefSettings,
537) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
538where
539 T: StreamRefPayload,
540{
541 sink_ref_over_tcp_stream_with_diagnostics(stream, stream_ref_id, settings, None)
542}
543
544pub fn sink_ref_over_tcp_stream_with_diagnostics<T>(
545 stream: TcpStream,
546 stream_ref_id: StreamRefId,
547 settings: StreamRefSettings,
548 diagnostics: Option<StreamRefProtocolDiagnostics>,
549) -> StreamResult<(Sink<T, StreamCompletion<NotUsed>>, StreamRefTcpHandle)>
550where
551 T: StreamRefPayload,
552{
553 let producer = StreamRefProtoProducer::new_lazy(stream_ref_id, settings);
554 let sink = producer.sink();
555 let handle = current_tokio_handle()?;
556 let handle = drive_stream_ref_endpoint_over_tcp_stream(stream, handle, producer, diagnostics);
557 Ok((sink, handle))
558}
559
560fn drive_stream_ref_endpoint_over_quic<E>(
561 stream: QuicBidirectionalStream,
562 endpoint: E,
563 diagnostics: Option<StreamRefProtocolDiagnostics>,
564) -> StreamRefQuicHandle
565where
566 E: StreamRefProtoEndpointWake,
567{
568 let (reader, writer, handle, chunk_size, emit_available) = stream.into_stream_ref_parts();
569 let read_mode = CarrierReadMode::new(chunk_size, emit_available, false);
570 StreamRefQuicHandle {
571 receiver: spawn_endpoint_task(&handle, async move {
572 run_stream_ref_endpoint_quic_task(reader, writer, endpoint, read_mode, diagnostics)
573 .await
574 }),
575 }
576}
577
578fn drive_stream_ref_endpoint_over_tcp_listener<E>(
579 listener: TcpListener,
580 handle: Handle,
581 endpoint: E,
582 diagnostics: Option<StreamRefProtocolDiagnostics>,
583) -> StreamRefTcpHandle
584where
585 E: StreamRefProtoEndpointWake,
586{
587 StreamRefTcpHandle {
588 receiver: spawn_endpoint_task(&handle, async move {
589 let (stream, _) = listener.accept().await.map_err(io_error)?;
590 run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
591 }),
592 }
593}
594
595fn drive_stream_ref_endpoint_over_tcp_stream<E>(
596 stream: TcpStream,
597 handle: Handle,
598 endpoint: E,
599 diagnostics: Option<StreamRefProtocolDiagnostics>,
600) -> StreamRefTcpHandle
601where
602 E: StreamRefProtoEndpointWake,
603{
604 StreamRefTcpHandle {
605 receiver: spawn_endpoint_task(&handle, async move {
606 run_stream_ref_endpoint_tcp_task(stream, endpoint, diagnostics).await
607 }),
608 }
609}
610
611async fn run_stream_ref_endpoint_quic_task<R, W, E>(
612 reader: R,
613 writer: W,
614 endpoint: E,
615 read_mode: CarrierReadMode,
616 diagnostics: Option<StreamRefProtocolDiagnostics>,
617) -> StreamResult<NotUsed>
618where
619 R: AsyncRead + Unpin + Send + 'static,
620 W: AsyncWrite + Unpin + Send + 'static,
621 E: StreamRefProtoEndpointWake,
622{
623 let (wake_sender, wake_receiver) = tokio_mpsc::channel(1);
624 endpoint.install_outbound_wake(wake_sender.clone());
625 let _ = wake_sender.try_send(());
626
627 let result = QuicEndpointTask {
628 reader,
629 writer,
630 endpoint: endpoint.clone(),
631 diagnostics,
632 read_mode,
633 decoder: FrameDecoder::default(),
634 read_buffer: vec![
635 0_u8;
636 read_mode
637 .chunk_size
638 .clamp(1, STREAM_REF_QUIC_READ_BUFFER_BYTES)
639 ],
640 pending_tail: Vec::new(),
641 write_buffer: BytesMut::new(),
642 encode_buffer: Vec::new(),
643 pending_diagnostics: VecDeque::new(),
644 read_closed: false,
645 inbound_seen: false,
646 outbound_written: false,
647 recheck_outbound: false,
648 outbound_closed: false,
649 write_shutdown: false,
650 wake_receiver,
651 }
652 .run()
653 .await;
654
655 endpoint.clear_outbound_wake();
656 if let Err(error) = &result {
657 endpoint.fail_connection(error.clone());
658 }
659 result
660}
661
/// State for one StreamRefs endpoint driven over the read and write halves of a QUIC
/// stream.
struct QuicEndpointTask<R, W, E>
where
    E: StreamRefProtoEndpointWake,
{
    /// Inbound half; read into `read_buffer`.
    reader: R,
    /// Outbound half; receives bytes from `write_buffer`.
    writer: W,
    /// Protocol endpoint being driven.
    endpoint: E,
    /// Optional counters; `pending_diagnostics` is only populated when this is `Some`.
    diagnostics: Option<StreamRefProtocolDiagnostics>,
    /// How reads are turned into decoder input.
    read_mode: CarrierReadMode,
    /// Inbound frame decoder fed with the bytes read from `reader`.
    decoder: FrameDecoder,
    /// Fixed-size scratch buffer for a single read.
    read_buffer: Vec<u8>,
    /// Inbound bytes held back until a full `read_mode.chunk_size` chunk is available
    /// (only filled when `read_mode.emit_available` is false).
    pending_tail: Vec<u8>,
    /// Encoded outbound bytes not yet accepted by `writer`.
    write_buffer: BytesMut,
    /// Scratch buffer for encoding one outbound item.
    encode_buffer: Vec<u8>,
    /// Per-item byte counts waiting to be fully written before their counts are recorded.
    pending_diagnostics: VecDeque<PendingDiagnostic>,
    /// Set once a read returned zero bytes.
    read_closed: bool,
    /// Set once any non-empty read has been fed to the decoder.
    inbound_seen: bool,
    /// Set once any outbound item has been encoded into `write_buffer`.
    outbound_written: bool,
    /// Arms the timed re-poll of outbound in `run`.
    recheck_outbound: bool,
    /// Set once the endpoint reported its outbound side closed.
    outbound_closed: bool,
    /// Set once `writer.shutdown()` has completed.
    write_shutdown: bool,
    /// Wake-ups from the sender installed via `install_outbound_wake`.
    wake_receiver: tokio_mpsc::Receiver<()>,
}
685
impl<R, W, E> QuicEndpointTask<R, W, E>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
    E: StreamRefProtoEndpointWake,
{
    /// Drives the QUIC carrier until both directions are finished.
    ///
    /// Returns `Ok(NotUsed)` once the read side has hit end-of-stream and the write side
    /// has been shut down. It also returns `Ok(NotUsed)` when, after write shutdown, a read
    /// fails with the "connection lost" teardown error. Any other endpoint, encoding, or
    /// I/O error is returned as-is.
    async fn run(mut self) -> StreamResult<NotUsed> {
        loop {
            if self.read_closed && self.write_shutdown {
                return Ok(NotUsed);
            }

            self.drain_outbound()?;
            // Half-close the write side only after every buffered byte has been written.
            if self.outbound_closed && self.write_buffer.is_empty() && !self.write_shutdown {
                self.writer.shutdown().await.map_err(io_error)?;
                self.write_shutdown = true;
                continue;
            }
            // `biased` polls the branches in order: wake-ups, timed recheck, write, read.
            tokio::select! {
                biased;
                wake = self.wake_receiver.recv(), if !self.outbound_closed => {
                    // A `Some` wake needs no action here: the next iteration drains outbound.
                    if wake.is_none() {
                        self.drain_outbound()?;
                    }
                }
                // Timed re-poll while the endpoint reports `Pending` and nothing is queued
                // for writing. NOTE(review): presumably a fallback for missed wake-ups.
                _ = tokio::time::sleep(STREAM_REF_OUTBOUND_RECHECK_INTERVAL), if self.recheck_outbound && !self.outbound_closed && self.write_buffer.is_empty() => {
                    self.drain_outbound()?;
                }
                written = self.writer.write(&self.write_buffer), if !self.write_buffer.is_empty() => {
                    self.handle_written(written)?;
                }
                read = self.reader.read(&mut self.read_buffer), if !self.read_closed => {
                    match read {
                        Ok(0) => self.handle_eof()?,
                        Ok(read) => {
                            self.feed_read_buffer(read)?;
                            // Inbound frames may have produced new outbound messages.
                            self.drain_outbound()?;
                        }
                        Err(error) => {
                            let error = io_error(error);
                            // After our own write shutdown, a "connection lost" read error
                            // is treated as normal teardown.
                            if self.write_shutdown && is_quic_teardown_loss(&error) {
                                return Ok(NotUsed);
                            }
                            return Err(error);
                        }
                    }
                }
            }
        }
    }

    /// Polls the endpoint for outbound items and encodes them into `write_buffer`.
    ///
    /// Stops when the endpoint reports `Pending` or `Closed`, or once `write_buffer` holds
    /// at least `MAX_STREAM_REF_FRAME_BYTES`. On `Pending` it decides whether `run` should
    /// arm the timed recheck.
    fn drain_outbound(&mut self) -> StreamResult<()> {
        self.recheck_outbound = false;
        while !self.outbound_closed && self.write_buffer.len() < MAX_STREAM_REF_FRAME_BYTES {
            match self
                .endpoint
                .try_next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
            {
                StreamRefOutboundPoll::Ready(Ok(outbound)) => {
                    encode_carrier_outbound_into(&outbound, &mut self.encode_buffer)?;
                    let encoded_len = self.encode_buffer.len();
                    // Items that encode to nothing are skipped without touching state.
                    if encoded_len == 0 {
                        continue;
                    }
                    if self.diagnostics.is_some() {
                        self.pending_diagnostics.push_back(PendingDiagnostic {
                            remaining: encoded_len,
                            counts: outbound_counts(&outbound),
                        });
                    }
                    self.outbound_written = true;
                    self.write_buffer.extend_from_slice(&self.encode_buffer);
                }
                StreamRefOutboundPoll::Ready(Err(error)) => return Err(error),
                StreamRefOutboundPoll::Pending => {
                    // Arm the recheck if nothing has been written yet, inbound has been
                    // seen, or the read side is closed.
                    self.recheck_outbound =
                        !self.outbound_written || self.inbound_seen || self.read_closed;
                    break;
                }
                StreamRefOutboundPoll::Closed => {
                    self.outbound_closed = true;
                    break;
                }
            }
        }
        Ok(())
    }

    /// Consumes the result of one `writer.write` call.
    ///
    /// Advances `write_buffer` and records diagnostics. A zero-byte write is an error.
    fn handle_written(&mut self, written: Result<usize, std::io::Error>) -> StreamResult<()> {
        match written {
            Ok(0) => Err(StreamError::Failed(
                "StreamRefs QUIC stream accepted zero write bytes".to_owned(),
            )),
            Ok(written) => {
                self.write_buffer.advance(written);
                self.record_written_bytes(written);
                Ok(())
            }
            Err(error) => Err(io_error(error)),
        }
    }

    /// Attributes `written` bytes to the queued diagnostic entries in FIFO order.
    ///
    /// An entry's counts are recorded once all of its bytes have been written.
    fn record_written_bytes(&mut self, mut written: usize) {
        let Some(diagnostics) = &self.diagnostics else {
            return;
        };
        while written > 0 {
            let Some(front) = self.pending_diagnostics.front_mut() else {
                return;
            };
            if written < front.remaining {
                front.remaining -= written;
                return;
            }
            written -= front.remaining;
            let counts = front.counts;
            self.pending_diagnostics.pop_front();
            diagnostics.record_counts(counts);
        }
    }

    /// Feeds the first `read` bytes of `read_buffer` to the decoder and marks inbound as
    /// seen.
    fn feed_read_buffer(&mut self, read: usize) -> StreamResult<()> {
        feed_read_bytes(
            &mut self.decoder,
            &self.endpoint,
            self.read_mode,
            &mut self.pending_tail,
            &self.read_buffer[..read],
        )?;
        self.inbound_seen = true;
        Ok(())
    }

    /// Handles end-of-stream on the read side.
    ///
    /// Flushes any held-back partial chunk into the decoder. Fails the endpoint with
    /// `AbruptTermination` only if `read_mode.fail_on_eof` is set. Marks reading as closed
    /// and arms the outbound recheck.
    fn handle_eof(&mut self) -> StreamResult<()> {
        if !self.pending_tail.is_empty() {
            feed_inbound_chunk(&mut self.decoder, &self.endpoint, &self.pending_tail)?;
            self.pending_tail.clear();
        }
        if self.read_mode.fail_on_eof {
            self.endpoint
                .fail_connection(StreamError::AbruptTermination);
        }
        self.read_closed = true;
        self.recheck_outbound = true;
        Ok(())
    }
}
833
834async fn run_stream_ref_endpoint_tcp_task<E>(
835 stream: TcpStream,
836 endpoint: E,
837 diagnostics: Option<StreamRefProtocolDiagnostics>,
838) -> StreamResult<NotUsed>
839where
840 E: StreamRefProtoEndpointWake,
841{
842 let (wake_sender, wake_receiver) = tokio_mpsc::channel(1);
843 endpoint.install_outbound_wake(wake_sender.clone());
844 let _ = wake_sender.try_send(());
845
846 let result = TcpEndpointTask {
847 stream,
848 endpoint: endpoint.clone(),
849 diagnostics,
850 read_mode: CarrierReadMode::new(STREAM_REF_TCP_CHUNK_SIZE, true, true),
851 decoder: FrameDecoder::default(),
852 read_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
853 pending_tail: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
854 write_buffer: BytesMut::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
855 encode_buffer: Vec::with_capacity(STREAM_REF_TCP_CHUNK_SIZE),
856 pending_diagnostics: VecDeque::new(),
857 outbound_closed: false,
858 write_shutdown: false,
859 wake_receiver,
860 }
861 .run()
862 .await;
863
864 endpoint.clear_outbound_wake();
865 if let Err(error) = &result {
866 endpoint.fail_connection(error.clone());
867 }
868 result
869}
870
/// State for one StreamRefs endpoint driven over a single TCP connection.
struct TcpEndpointTask<E>
where
    E: StreamRefProtoEndpointWake,
{
    /// Connection carrying both directions of the protocol.
    stream: TcpStream,
    /// Protocol endpoint being driven.
    endpoint: E,
    /// Optional counters; `pending_diagnostics` is only populated when this is `Some`.
    diagnostics: Option<StreamRefProtocolDiagnostics>,
    /// How reads are turned into decoder input.
    read_mode: CarrierReadMode,
    /// Inbound frame decoder fed with the bytes read from `stream`.
    decoder: FrameDecoder,
    /// Receives the bytes of the current read; cleared after they are fed.
    read_buffer: BytesMut,
    /// Inbound bytes held back until a full chunk is available (only filled when
    /// `read_mode.emit_available` is false).
    pending_tail: Vec<u8>,
    /// Encoded outbound bytes not yet accepted by the socket.
    write_buffer: BytesMut,
    /// Scratch buffer for encoding one outbound item.
    encode_buffer: Vec<u8>,
    /// Per-item byte counts waiting to be fully written before their counts are recorded.
    pending_diagnostics: VecDeque<PendingDiagnostic>,
    /// Set once the endpoint reported its outbound side closed.
    outbound_closed: bool,
    /// Set once the socket's write half has been shut down.
    write_shutdown: bool,
    /// Wake-ups from the sender installed via `install_outbound_wake`.
    wake_receiver: tokio_mpsc::Receiver<()>,
}
889
890impl<E> TcpEndpointTask<E>
891where
892 E: StreamRefProtoEndpointWake,
893{
894 async fn run(mut self) -> StreamResult<NotUsed> {
895 self.stream.set_nodelay(true).map_err(io_error)?;
896 loop {
897 self.drain_outbound()?;
898 if !self.write_buffer.is_empty() || (self.outbound_closed && !self.write_shutdown) {
899 self.flush_write_buffer().await?;
900 }
901
902 tokio::select! {
903 biased;
904 wake = self.wake_receiver.recv() => {
905 if wake.is_none() && !self.outbound_closed {
906 self.drain_outbound()?;
907 }
908 }
909 ready = self.stream.readable() => {
910 ready.map_err(io_error)?;
911 if self.read_available()? {
912 return Ok(NotUsed);
913 }
914 }
915 ready = self.stream.writable(), if !self.write_buffer.is_empty() || (self.outbound_closed && !self.write_shutdown) => {
916 ready.map_err(io_error)?;
917 self.flush_ready_write_buffer()?;
918 }
919 }
920 }
921 }
922
923 fn drain_outbound(&mut self) -> StreamResult<()> {
924 while !self.outbound_closed && self.write_buffer.len() < MAX_STREAM_REF_FRAME_BYTES {
925 match self
926 .endpoint
927 .try_next_outbound(STREAM_REF_OUTBOUND_BATCH_FRAMES, MAX_STREAM_REF_FRAME_BYTES)
928 {
929 StreamRefOutboundPoll::Ready(Ok(outbound)) => {
930 encode_carrier_outbound_into(&outbound, &mut self.encode_buffer)?;
931 let encoded_len = self.encode_buffer.len();
932 if encoded_len == 0 {
933 continue;
934 }
935 if self.diagnostics.is_some() {
936 self.pending_diagnostics.push_back(PendingDiagnostic {
937 remaining: encoded_len,
938 counts: outbound_counts(&outbound),
939 });
940 }
941 self.write_buffer.extend_from_slice(&self.encode_buffer);
942 }
943 StreamRefOutboundPoll::Ready(Err(error)) => return Err(error),
944 StreamRefOutboundPoll::Pending => break,
945 StreamRefOutboundPoll::Closed => {
946 self.outbound_closed = true;
947 break;
948 }
949 }
950 }
951 Ok(())
952 }
953
954 async fn flush_write_buffer(&mut self) -> StreamResult<()> {
955 if !self.write_buffer.is_empty() {
956 self.stream.writable().await.map_err(io_error)?;
957 self.flush_ready_write_buffer()?;
958 }
959 if self.outbound_closed && self.write_buffer.is_empty() && !self.write_shutdown {
960 self.stream.shutdown().await.map_err(io_error)?;
961 self.write_shutdown = true;
962 }
963 Ok(())
964 }
965
966 fn flush_ready_write_buffer(&mut self) -> StreamResult<()> {
967 while !self.write_buffer.is_empty() {
968 match self.stream.try_write(&self.write_buffer) {
969 Ok(0) => {
970 return Err(StreamError::Failed(
971 "StreamRefs TCP socket accepted zero write bytes".to_owned(),
972 ));
973 }
974 Ok(written) => {
975 self.write_buffer.advance(written);
976 self.record_written_bytes(written);
977 }
978 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
979 Err(error) => return Err(io_error(error)),
980 }
981 }
982
983 Ok(())
984 }
985
986 fn record_written_bytes(&mut self, mut written: usize) {
987 let Some(diagnostics) = &self.diagnostics else {
988 return;
989 };
990 while written > 0 {
991 let Some(front) = self.pending_diagnostics.front_mut() else {
992 return;
993 };
994 if written < front.remaining {
995 front.remaining -= written;
996 return;
997 }
998 written -= front.remaining;
999 let counts = front.counts;
1000 self.pending_diagnostics.pop_front();
1001 diagnostics.record_counts(counts);
1002 }
1003 }
1004
1005 fn read_available(&mut self) -> StreamResult<bool> {
1006 loop {
1007 self.read_buffer.reserve(self.read_mode.chunk_size);
1008 match self.stream.try_read_buf(&mut self.read_buffer) {
1009 Ok(0) => return self.handle_eof(),
1010 Ok(_) => {
1011 feed_read_bytes(
1012 &mut self.decoder,
1013 &self.endpoint,
1014 self.read_mode,
1015 &mut self.pending_tail,
1016 &self.read_buffer,
1017 )?;
1018 self.read_buffer.clear();
1019 self.drain_outbound()?;
1020 }
1021 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(false),
1022 Err(error) => return Err(io_error(error)),
1023 }
1024 }
1025 }
1026
1027 fn handle_eof(&mut self) -> StreamResult<bool> {
1028 if !self.pending_tail.is_empty() {
1029 feed_inbound_chunk(&mut self.decoder, &self.endpoint, &self.pending_tail)?;
1030 self.pending_tail.clear();
1031 }
1032 if self.read_mode.fail_on_eof {
1033 self.endpoint
1034 .fail_connection(StreamError::AbruptTermination);
1035 }
1036 Ok(true)
1037 }
1038}
1039
1040fn feed_read_bytes<E>(
1041 decoder: &mut FrameDecoder,
1042 endpoint: &E,
1043 read_mode: CarrierReadMode,
1044 pending_tail: &mut Vec<u8>,
1045 read_buffer: &[u8],
1046) -> StreamResult<()>
1047where
1048 E: StreamRefProtoEndpoint,
1049{
1050 if read_mode.emit_available {
1051 if !pending_tail.is_empty() {
1052 pending_tail.extend_from_slice(read_buffer);
1053 feed_inbound_chunk(decoder, endpoint, pending_tail)?;
1054 pending_tail.clear();
1055 return Ok(());
1056 }
1057 return feed_inbound_chunk(decoder, endpoint, read_buffer);
1058 }
1059
1060 let mut offset = 0;
1061 if !pending_tail.is_empty() {
1062 let needed = read_mode.chunk_size - pending_tail.len();
1063 let take = needed.min(read_buffer.len());
1064 pending_tail.extend_from_slice(&read_buffer[..take]);
1065 offset += take;
1066 if pending_tail.len() == read_mode.chunk_size {
1067 feed_inbound_chunk(decoder, endpoint, pending_tail)?;
1068 pending_tail.clear();
1069 }
1070 }
1071
1072 while offset + read_mode.chunk_size <= read_buffer.len() {
1073 let next = offset + read_mode.chunk_size;
1074 feed_inbound_chunk(decoder, endpoint, &read_buffer[offset..next])?;
1075 offset = next;
1076 }
1077
1078 if offset < read_buffer.len() {
1079 pending_tail.extend_from_slice(&read_buffer[offset..]);
1080 }
1081 Ok(())
1082}
1083
1084fn feed_inbound_chunk<E>(decoder: &mut FrameDecoder, endpoint: &E, chunk: &[u8]) -> StreamResult<()>
1085where
1086 E: StreamRefProtoEndpoint,
1087{
1088 decoder.push_chunk(chunk, endpoint)
1089}
1090
1091fn bind_tcp_listener<A>(addr: A) -> StreamResult<(TcpListener, StreamRefTcpBinding, Handle)>
1092where
1093 A: ToSocketAddrs + Send + 'static,
1094{
1095 let runtime = stream_ref_tcp_runtime()?;
1096 let listener = runtime
1097 .block_on(async { TcpListener::bind(addr).await })
1098 .map_err(io_error)?;
1099 let local_addr = listener.local_addr().map_err(io_error)?;
1100 Ok((
1101 listener,
1102 StreamRefTcpBinding { local_addr },
1103 runtime.handle().clone(),
1104 ))
1105}
1106
1107fn connect_tcp_stream<A>(addr: A) -> StreamResult<(TcpStream, Handle)>
1108where
1109 A: ToSocketAddrs + Send + 'static,
1110{
1111 let runtime = stream_ref_tcp_runtime()?;
1112 let stream = runtime
1113 .block_on(async { TcpStream::connect(addr).await })
1114 .map_err(io_error)?;
1115 stream.set_nodelay(true).map_err(io_error)?;
1116 Ok((stream, runtime.handle().clone()))
1117}
1118
1119fn stream_ref_tcp_runtime() -> StreamResult<&'static Runtime> {
1120 static RUNTIME: OnceLock<Result<Runtime, String>> = OnceLock::new();
1121 match RUNTIME.get_or_init(|| {
1122 tokio::runtime::Builder::new_multi_thread()
1123 .thread_name("datum-streamref-tcp")
1124 .enable_all()
1125 .build()
1126 .map_err(|error| error.to_string())
1127 }) {
1128 Ok(runtime) => Ok(runtime),
1129 Err(error) => Err(StreamError::Failed(format!(
1130 "failed to start StreamRefs TCP runtime: {error}"
1131 ))),
1132 }
1133}
1134
1135fn current_tokio_handle() -> StreamResult<Handle> {
1136 Handle::try_current().map_err(|error| {
1137 StreamError::Failed(format!(
1138 "StreamRefs TCP stream helper requires a current Tokio runtime: {error}"
1139 ))
1140 })
1141}
1142
1143fn io_error(error: std::io::Error) -> StreamError {
1144 StreamError::Failed(error.to_string())
1145}
1146
1147fn is_quic_teardown_loss(error: &StreamError) -> bool {
1148 matches!(error, StreamError::Failed(message) if message == "connection lost")
1149}
1150
1151fn spawn_endpoint_task<F>(handle: &Handle, run: F) -> mpsc::Receiver<StreamResult<NotUsed>>
1152where
1153 F: Future<Output = StreamResult<NotUsed>> + Send + 'static,
1154{
1155 let (sender, receiver) = mpsc::channel();
1156 handle.spawn(async move {
1157 let result = run.await;
1158 let _ = sender.send(result);
1159 });
1160 receiver
1161}
1162
1163fn encode_carrier_outbound_into(
1164 outbound: &StreamRefOutbound,
1165 bytes: &mut Vec<u8>,
1166) -> StreamResult<()> {
1167 bytes.clear();
1168 match outbound {
1169 StreamRefOutbound::Frame(frame) => append_protobuf_carrier_frame(frame, bytes)?,
1170 StreamRefOutbound::SequencedBatch(batch) => {
1171 append_compact_payload_batch(batch, bytes)?;
1172 }
1173 }
1174 Ok(())
1175}
1176
/// Test helper: encodes `frames` into one carrier byte stream.
///
/// Runs of contiguous `SequencedOnNext` frames for the same stream ref are
/// packed into compact batches; every other frame is written as a protobuf
/// carrier frame.
#[cfg(test)]
fn encode_carrier_frames(frames: &[StreamRefFrame]) -> StreamResult<Vec<u8>> {
    let mut encoded = Vec::new();
    let mut cursor = 0;
    while let Some(frame) = frames.get(cursor) {
        cursor = if sequenced_on_next(frame).is_some() {
            let run_end = sequenced_run_end(frames, cursor);
            append_compact_sequenced_batches(&frames[cursor..run_end], &mut encoded)?;
            run_end
        } else {
            append_protobuf_carrier_frame(frame, &mut encoded)?;
            cursor + 1
        };
    }
    Ok(encoded)
}
1193
1194fn append_compact_payload_batch(
1195 batch: &StreamRefPayloadBatch,
1196 bytes: &mut Vec<u8>,
1197) -> StreamResult<()> {
1198 let mut start = 0;
1199 while start < batch.count() {
1200 let mut end = start;
1201 let mut payload_len = COMPACT_BATCH_HEADER_BYTES;
1202 while end < batch.count() {
1203 let element_len = COMPACT_BATCH_ELEMENT_LEN_BYTES
1204 .checked_add(batch.payload_len(end))
1205 .ok_or(StreamError::LimitExceeded {
1206 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1207 })?;
1208 let next_payload_len =
1209 payload_len
1210 .checked_add(element_len)
1211 .ok_or(StreamError::LimitExceeded {
1212 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1213 })?;
1214 if end > start
1215 && (next_payload_len > MAX_STREAM_REF_FRAME_BYTES
1216 || end - start >= u16::MAX as usize)
1217 {
1218 break;
1219 }
1220 if next_payload_len > MAX_STREAM_REF_FRAME_BYTES {
1221 return Err(StreamError::LimitExceeded {
1222 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1223 });
1224 }
1225 payload_len = next_payload_len;
1226 end += 1;
1227 }
1228 append_compact_payload_batch_slice(batch, start, end, payload_len, bytes)?;
1229 start = end;
1230 }
1231 Ok(())
1232}
1233
1234fn append_compact_payload_batch_slice(
1235 batch: &StreamRefPayloadBatch,
1236 start: usize,
1237 end: usize,
1238 payload_len: usize,
1239 bytes: &mut Vec<u8>,
1240) -> StreamResult<()> {
1241 let payload_len = u32::try_from(payload_len).map_err(|_| StreamError::LimitExceeded {
1242 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1243 })?;
1244 let count = u16::try_from(end - start).map_err(|_| StreamError::LimitExceeded {
1245 max: u16::MAX as u64,
1246 })?;
1247 let first_seq = batch
1248 .first_seq_nr()
1249 .checked_add(start as u64)
1250 .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
1251 bytes.extend((COMPACT_FRAME_FLAG | payload_len).to_be_bytes());
1252 bytes.push(COMPACT_FRAME_VERSION);
1253 bytes.push(COMPACT_SEQUENCED_ON_NEXT_BATCH);
1254 bytes.extend(batch.stream_ref_id().to_bytes());
1255 bytes.extend(first_seq.to_be_bytes());
1256 bytes.extend(count.to_be_bytes());
1257 for index in start..end {
1258 let payload = batch.payload(index);
1259 let payload_len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
1260 max: u32::MAX as u64,
1261 })?;
1262 bytes.extend(payload_len.to_be_bytes());
1263 bytes.extend(payload);
1264 }
1265 Ok(())
1266}
1267
1268fn append_protobuf_carrier_frame(frame: &StreamRefFrame, bytes: &mut Vec<u8>) -> StreamResult<()> {
1269 let payload = frame.encode_to_vec();
1270 let len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
1271 max: COMPACT_FRAME_LEN_MASK as u64,
1272 })?;
1273 if payload.len() > MAX_STREAM_REF_FRAME_BYTES || len > COMPACT_FRAME_LEN_MASK {
1274 return Err(StreamError::LimitExceeded {
1275 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1276 });
1277 }
1278 bytes.extend(len.to_be_bytes());
1279 bytes.extend(payload);
1280 Ok(())
1281}
1282
/// Test helper: appends a run of `SequencedOnNext` frames as one or more
/// compact carrier frames.
///
/// The run is packed greedily under the same limits as the production
/// encoder: at most `MAX_STREAM_REF_FRAME_BYTES` of body and `u16::MAX`
/// elements per frame.
///
/// # Panics
/// Panics if any frame in `frames` is not `SequencedOnNext`.
#[cfg(test)]
fn append_compact_sequenced_batches(
    frames: &[StreamRefFrame],
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let over_limit = || StreamError::LimitExceeded {
        max: MAX_STREAM_REF_FRAME_BYTES as u64,
    };
    let mut remaining = frames;
    while !remaining.is_empty() {
        let mut taken = 0;
        let mut body_len = COMPACT_BATCH_HEADER_BYTES;
        for frame in remaining {
            let (_, payload) = sequenced_on_next(frame).expect("sequenced frame");
            let grown = COMPACT_BATCH_ELEMENT_LEN_BYTES
                .checked_add(payload.len())
                .and_then(|element_len| body_len.checked_add(element_len))
                .ok_or_else(over_limit)?;
            let too_large = grown > MAX_STREAM_REF_FRAME_BYTES;
            if taken > 0 && (too_large || taken >= u16::MAX as usize) {
                break;
            }
            if too_large {
                return Err(over_limit());
            }
            body_len = grown;
            taken += 1;
        }
        let (batch, rest) = remaining.split_at(taken);
        append_compact_sequenced_batch(batch, body_len, bytes)?;
        remaining = rest;
    }
    Ok(())
}
1324
/// Test helper: writes `frames` as a single compact carrier frame whose body
/// is `payload_len` bytes long.
///
/// The stream ref id and first sequence number come from `frames[0]`.
///
/// # Panics
/// Panics if `frames` is empty or holds a non-`SequencedOnNext` frame.
#[cfg(test)]
fn append_compact_sequenced_batch(
    frames: &[StreamRefFrame],
    payload_len: usize,
    bytes: &mut Vec<u8>,
) -> StreamResult<()> {
    let head = &frames[0];
    let (first_seq, _) = sequenced_on_next(head).expect("sequenced frame");
    let frame_len = u32::try_from(payload_len).map_err(|_| StreamError::LimitExceeded {
        max: MAX_STREAM_REF_FRAME_BYTES as u64,
    })?;
    let count = u16::try_from(frames.len()).map_err(|_| StreamError::LimitExceeded {
        max: u16::MAX as u64,
    })?;
    bytes.extend((COMPACT_FRAME_FLAG | frame_len).to_be_bytes());
    bytes.extend([COMPACT_FRAME_VERSION, COMPACT_SEQUENCED_ON_NEXT_BATCH]);
    bytes.extend(head.stream_ref_id.to_bytes());
    bytes.extend(first_seq.to_be_bytes());
    bytes.extend(count.to_be_bytes());
    frames.iter().try_for_each(|frame| {
        let (_, payload) = sequenced_on_next(frame).expect("sequenced frame");
        let element_len = u32::try_from(payload.len()).map_err(|_| StreamError::LimitExceeded {
            max: u32::MAX as u64,
        })?;
        bytes.extend(element_len.to_be_bytes());
        bytes.extend_from_slice(payload);
        Ok(())
    })
}
1354
/// Test helper: returns the exclusive end index of the run that starts at
/// `start`.
///
/// The run consists of `SequencedOnNext` frames for the same stream ref whose
/// sequence numbers increase by exactly one. `frames[start]` itself is always
/// part of the run.
#[cfg(test)]
fn sequenced_run_end(frames: &[StreamRefFrame], start: usize) -> usize {
    let mut end = start + 1;
    while end < frames.len() {
        let (Some((previous_seq, _)), Some((next_seq, _))) = (
            sequenced_on_next(&frames[end - 1]),
            sequenced_on_next(&frames[end]),
        ) else {
            break;
        };
        // `checked_add`, not `saturating_add`: seq_nr u64::MAX has no successor,
        // so a following frame that repeats u64::MAX must not extend the run.
        let contiguous = previous_seq.checked_add(1) == Some(next_seq);
        if frames[end].stream_ref_id != frames[start].stream_ref_id || !contiguous {
            break;
        }
        end += 1;
    }
    end
}
1374
/// Test helper: returns `(seq_nr, payload bytes)` when `frame` is a
/// `SequencedOnNext` message, otherwise `None`.
#[cfg(test)]
fn sequenced_on_next(frame: &StreamRefFrame) -> Option<(u64, &[u8])> {
    if let StreamRefMessage::SequencedOnNext { seq_nr, payload } = &frame.message {
        return Some((*seq_nr, &payload.bytes[..]));
    }
    None
}
1384
1385#[derive(Default)]
1386struct FrameDecoder {
1387 buffer: BytesMut,
1388 offset: usize,
1389}
1390
1391impl FrameDecoder {
1392 fn push_chunk<E>(&mut self, chunk: &[u8], endpoint: &E) -> StreamResult<()>
1393 where
1394 E: StreamRefProtoEndpoint,
1395 {
1396 self.buffer.extend_from_slice(chunk);
1397 while let Some(header) = self.peek_header()? {
1398 if self.buffer.len().saturating_sub(self.offset) < FRAME_LEN_BYTES + header.len {
1399 break;
1400 }
1401 let payload_start = self.offset + FRAME_LEN_BYTES;
1402 let payload_end = payload_start + header.len;
1403 let payload = &self.buffer[payload_start..payload_end];
1404 match header.kind {
1405 CarrierFrameKind::Protobuf => {
1406 endpoint.handle_frame(StreamRefFrame::decode(payload)?)?;
1407 }
1408 CarrierFrameKind::Compact => {
1409 decode_compact_carrier_frame(payload, endpoint)?;
1410 }
1411 }
1412 self.offset = payload_end;
1413 }
1414 if self.offset > 0 && (self.offset == self.buffer.len() || self.offset >= 64 * 1024) {
1415 self.buffer.advance(self.offset);
1416 self.offset = 0;
1417 }
1418 Ok(())
1419 }
1420
1421 fn peek_header(&self) -> StreamResult<Option<CarrierFrameHeader>> {
1422 if self.buffer.len().saturating_sub(self.offset) < FRAME_LEN_BYTES {
1423 return Ok(None);
1424 }
1425 let len = self.buffer[self.offset..self.offset + FRAME_LEN_BYTES]
1426 .try_into()
1427 .expect("frame header length");
1428 let raw_len = u32::from_be_bytes(len);
1429 let kind = if raw_len & COMPACT_FRAME_FLAG == 0 {
1430 CarrierFrameKind::Protobuf
1431 } else {
1432 CarrierFrameKind::Compact
1433 };
1434 let len = (raw_len & COMPACT_FRAME_LEN_MASK) as usize;
1435 if len > MAX_STREAM_REF_FRAME_BYTES {
1436 return Err(StreamError::LimitExceeded {
1437 max: MAX_STREAM_REF_FRAME_BYTES as u64,
1438 });
1439 }
1440 Ok(Some(CarrierFrameHeader { kind, len }))
1441 }
1442}
1443
/// Parsed 4-byte carrier header: the body encoding and the number of body
/// bytes that follow the header.
#[derive(Clone, Copy)]
struct CarrierFrameHeader {
    /// Body encoding, taken from the header's high bit.
    kind: CarrierFrameKind,
    /// Body length in bytes (low 31 bits of the header), not counting the
    /// 4-byte header itself.
    len: usize,
}
1449
/// Encoding of a carrier frame body, selected by `COMPACT_FRAME_FLAG` in the
/// frame header.
#[derive(Clone, Copy)]
enum CarrierFrameKind {
    /// High bit clear: the body is an encoded `StreamRefFrame`.
    Protobuf,
    /// High bit set: the body is a compact `SequencedOnNext` batch.
    Compact,
}
1455
1456fn decode_compact_carrier_frame<E>(payload: &[u8], endpoint: &E) -> StreamResult<()>
1457where
1458 E: StreamRefProtoEndpoint,
1459{
1460 if payload.len() < COMPACT_BATCH_HEADER_BYTES {
1461 return Err(StreamError::Failed(
1462 "compact StreamRefs carrier frame too short".to_owned(),
1463 ));
1464 }
1465 let version = payload[0];
1466 if version != COMPACT_FRAME_VERSION {
1467 return Err(StreamError::Failed(format!(
1468 "unsupported compact StreamRefs carrier frame version: {version}"
1469 )));
1470 }
1471 let kind = payload[1];
1472 if kind != COMPACT_SEQUENCED_ON_NEXT_BATCH {
1473 return Err(StreamError::Failed(format!(
1474 "unsupported compact StreamRefs carrier frame kind: {kind}"
1475 )));
1476 }
1477 let stream_ref_id = StreamRefId::from_bytes(&payload[2..18])?;
1478 let first_seq = u64::from_be_bytes(payload[18..26].try_into().expect("seq len"));
1479 let count = u16::from_be_bytes(payload[26..28].try_into().expect("count len")) as usize;
1480 if count == 0 {
1481 return Err(StreamError::Failed(
1482 "compact StreamRefs carrier batch is empty".to_owned(),
1483 ));
1484 }
1485
1486 let mut offset = COMPACT_BATCH_HEADER_BYTES;
1487 let mut payloads = Vec::with_capacity(count);
1488 for index in 0..count {
1489 if payload.len().saturating_sub(offset) < COMPACT_BATCH_ELEMENT_LEN_BYTES {
1490 return Err(StreamError::Failed(
1491 "compact StreamRefs carrier batch has truncated payload length".to_owned(),
1492 ));
1493 }
1494 let payload_len = u32::from_be_bytes(
1495 payload[offset..offset + COMPACT_BATCH_ELEMENT_LEN_BYTES]
1496 .try_into()
1497 .expect("payload len"),
1498 ) as usize;
1499 offset += COMPACT_BATCH_ELEMENT_LEN_BYTES;
1500 if payload.len().saturating_sub(offset) < payload_len {
1501 return Err(StreamError::Failed(
1502 "compact StreamRefs carrier batch has truncated payload".to_owned(),
1503 ));
1504 }
1505 first_seq
1506 .checked_add(index as u64)
1507 .ok_or_else(|| StreamError::Failed("compact StreamRefs seq_nr overflow".to_owned()))?;
1508 payloads.push(&payload[offset..offset + payload_len]);
1509 offset += payload_len;
1510 }
1511 if offset != payload.len() {
1512 return Err(StreamError::Failed(
1513 "compact StreamRefs carrier batch has trailing bytes".to_owned(),
1514 ));
1515 }
1516 endpoint.handle_sequenced_on_next_batch(stream_ref_id, first_seq, &payloads)
1517}
1518
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Test endpoint that records every frame delivered through `handle_frame`.
    #[derive(Clone)]
    struct RecordingEndpoint {
        stream_ref_id: StreamRefId,
        frames: Arc<Mutex<Vec<StreamRefFrame>>>,
    }

    impl RecordingEndpoint {
        fn new(stream_ref_id: StreamRefId) -> Self {
            Self {
                stream_ref_id,
                frames: Arc::default(),
            }
        }

        /// Snapshot of the frames recorded so far.
        fn frames(&self) -> Vec<StreamRefFrame> {
            let recorded = self.frames.lock().expect("recording endpoint");
            recorded.to_vec()
        }
    }

    impl StreamRefProtoEndpoint for RecordingEndpoint {
        fn stream_ref_id(&self) -> StreamRefId {
            self.stream_ref_id
        }

        fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
            None
        }

        fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
            let mut recorded = self.frames.lock().expect("recording endpoint");
            recorded.push(frame);
            Ok(())
        }

        fn fail_connection(&self, _error: StreamError) {}
    }

    /// Builds a `SequencedOnNext` frame for `stream_ref_id` carrying `bytes`.
    fn sequenced_frame(stream_ref_id: StreamRefId, seq_nr: u64, bytes: Vec<u8>) -> StreamRefFrame {
        StreamRefFrame::new(
            stream_ref_id,
            datum::StreamRefMessage::SequencedOnNext {
                seq_nr,
                payload: datum::StreamRefPayloadBytes { bytes },
            },
        )
    }

    #[test]
    fn carrier_frame_decoder_reassembles_split_frames() {
        let stream_ref_id = StreamRefId::from_u128(1);
        let frame = StreamRefFrame::new(
            stream_ref_id,
            datum::StreamRefMessage::CumulativeDemand { seq_nr: 32 },
        );
        let bytes = encode_carrier_frames(std::slice::from_ref(&frame)).unwrap();
        let (head, tail) = bytes.split_at(bytes.len() / 2);
        let endpoint = RecordingEndpoint::new(stream_ref_id);
        let mut decoder = FrameDecoder::default();

        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());
        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), vec![frame]);
    }

    #[test]
    fn compact_carrier_batch_round_trips_sequenced_frames() {
        let stream_ref_id = StreamRefId::from_u128(7);
        let frames: Vec<StreamRefFrame> = (0_u64..3)
            .map(|seq_nr| sequenced_frame(stream_ref_id, seq_nr, seq_nr.to_be_bytes().to_vec()))
            .collect();
        let bytes = encode_carrier_frames(&frames).unwrap();
        let header = u32::from_be_bytes(bytes[..FRAME_LEN_BYTES].try_into().unwrap());
        assert_ne!(header & COMPACT_FRAME_FLAG, 0);

        let endpoint = RecordingEndpoint::new(stream_ref_id);
        let mut decoder = FrameDecoder::default();
        decoder.push_chunk(&bytes, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }

    #[test]
    fn compact_carrier_batch_reassembles_split_frames() {
        let stream_ref_id = StreamRefId::from_u128(8);
        let frames: Vec<StreamRefFrame> = (4_u64..8)
            .map(|seq_nr| sequenced_frame(stream_ref_id, seq_nr, vec![seq_nr as u8]))
            .collect();
        let bytes = encode_carrier_frames(&frames).unwrap();
        // Split inside the compact batch header, just past the carrier length word.
        let (head, tail) = bytes.split_at(FRAME_LEN_BYTES + 5);
        let endpoint = RecordingEndpoint::new(stream_ref_id);
        let mut decoder = FrameDecoder::default();

        decoder.push_chunk(head, &endpoint).unwrap();
        assert!(endpoint.frames().is_empty());
        decoder.push_chunk(tail, &endpoint).unwrap();
        assert_eq!(endpoint.frames(), frames);
    }
}