1use std::{
9 any::Any,
10 collections::VecDeque,
11 fmt,
12 sync::{
13 Arc, Condvar, Mutex, MutexGuard,
14 atomic::{AtomicU64, Ordering},
15 },
16 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
17};
18
19use crate::stream::{
20 BoxStream, Materializer, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult,
21 TerminalSinkConsumerDyn, TerminalSourceHookDyn, TerminalSourceStatus,
22};
23use futures::channel::oneshot;
24use prost::Message as ProstMessage;
25use tokio::sync::mpsc as tokio_mpsc;
26
27use super::{SourceRef, StreamRefSettings};
28
/// Process-wide counter consumed by `StreamRefId::new`.
///
/// Starts at 1 and is advanced (with `Relaxed` ordering) once per generated id.
static STREAM_REF_PROTO_ID: AtomicU64 = AtomicU64::new(1);
30
31pub trait StreamRefPayload: Send + 'static {
37 fn encode_stream_ref_payload(self) -> Vec<u8>;
38
39 fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>)
40 where
41 Self: Sized,
42 {
43 bytes.extend(self.encode_stream_ref_payload());
44 }
45
46 fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self>
47 where
48 Self: Sized;
49
50 fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self>
51 where
52 Self: Sized,
53 {
54 Self::decode_stream_ref_payload(bytes.to_vec())
55 }
56}
57
58macro_rules! impl_stream_ref_payload_numeric {
59 ($($ty:ty),* $(,)?) => {
60 $(
61 impl StreamRefPayload for $ty {
62 fn encode_stream_ref_payload(self) -> Vec<u8> {
63 self.to_be_bytes().to_vec()
64 }
65
66 fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>) {
67 bytes.extend(self.to_be_bytes());
68 }
69
70 fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
71 let data: [u8; std::mem::size_of::<Self>()] =
72 bytes.as_slice().try_into().map_err(|_| {
73 StreamError::Failed(format!(
74 "invalid {} stream ref payload length: {}",
75 stringify!($ty),
76 bytes.len()
77 ))
78 })?;
79 Ok(Self::from_be_bytes(data))
80 }
81
82 fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self> {
83 let data: [u8; std::mem::size_of::<Self>()] =
84 bytes.try_into().map_err(|_| {
85 StreamError::Failed(format!(
86 "invalid {} stream ref payload length: {}",
87 stringify!($ty),
88 bytes.len()
89 ))
90 })?;
91 Ok(Self::from_be_bytes(data))
92 }
93 }
94 )*
95 };
96}
97
98impl_stream_ref_payload_numeric!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, f32, f64);
99
100impl StreamRefPayload for bool {
101 fn encode_stream_ref_payload(self) -> Vec<u8> {
102 vec![u8::from(self)]
103 }
104
105 fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>) {
106 bytes.push(u8::from(self));
107 }
108
109 fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
110 match bytes.as_slice() {
111 [0] => Ok(false),
112 [1] => Ok(true),
113 _ => Err(StreamError::Failed(
114 "invalid bool stream ref payload".to_owned(),
115 )),
116 }
117 }
118
119 fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self> {
120 match bytes {
121 [0] => Ok(false),
122 [1] => Ok(true),
123 _ => Err(StreamError::Failed(
124 "invalid bool stream ref payload".to_owned(),
125 )),
126 }
127 }
128}
129
130impl StreamRefPayload for String {
131 fn encode_stream_ref_payload(self) -> Vec<u8> {
132 self.into_bytes()
133 }
134
135 fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>) {
136 bytes.extend(self.into_bytes());
137 }
138
139 fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
140 String::from_utf8(bytes)
141 .map_err(|error| StreamError::Failed(format!("invalid UTF-8 payload: {error}")))
142 }
143
144 fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self> {
145 String::from_utf8(bytes.to_vec())
146 .map_err(|error| StreamError::Failed(format!("invalid UTF-8 payload: {error}")))
147 }
148}
149
150impl StreamRefPayload for Vec<u8> {
151 fn encode_stream_ref_payload(self) -> Vec<u8> {
152 self
153 }
154
155 fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>) {
156 bytes.extend(self);
157 }
158
159 fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
160 Ok(bytes)
161 }
162
163 fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self> {
164 Ok(bytes.to_vec())
165 }
166}
167
168#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
170pub struct StreamRefId(u128);
171
172impl StreamRefId {
173 #[must_use]
175 pub fn new() -> Self {
176 let sequence = STREAM_REF_PROTO_ID.fetch_add(1, Ordering::Relaxed) as u128;
177 let timestamp = SystemTime::now()
178 .duration_since(UNIX_EPOCH)
179 .map(|duration| duration.as_nanos())
180 .unwrap_or_default();
181 let pid = std::process::id() as u128;
182 Self(timestamp ^ (pid << 32) ^ sequence)
183 }
184
185 #[must_use]
188 pub const fn from_u128(value: u128) -> Self {
189 Self(value)
190 }
191
192 #[must_use]
193 pub const fn as_u128(self) -> u128 {
194 self.0
195 }
196
197 #[must_use]
198 pub fn to_bytes(self) -> [u8; 16] {
199 self.0.to_be_bytes()
200 }
201
202 pub fn from_bytes(bytes: &[u8]) -> StreamResult<Self> {
203 let value: [u8; 16] = bytes.try_into().map_err(|_| {
204 StreamError::Failed("stream ref id must be exactly 16 bytes".to_owned())
205 })?;
206 Ok(Self(u128::from_be_bytes(value)))
207 }
208}
209
210impl Default for StreamRefId {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216impl fmt::Display for StreamRefId {
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 write!(f, "{:032x}", self.0)
219 }
220}
221
/// Encoded bytes of one element, as produced by `StreamRefPayload`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamRefPayloadBytes {
    pub bytes: Vec<u8>,
}

/// The messages of the stream-ref protocol.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamRefMessage {
    /// Subscription handshake between the two endpoints.
    OnSubscribeHandshake,
    /// Demand signal carrying a sequence number (named "cumulative" — presumably the running
    /// total of requested elements; confirm against the demand handler).
    CumulativeDemand {
        seq_nr: u64,
    },
    /// One element; batches number their elements `first_seq_nr + index`.
    SequencedOnNext {
        seq_nr: u64,
        payload: StreamRefPayloadBytes,
    },
    /// Successful end of the stream; the producer sends the number of elements it sent.
    RemoteStreamCompleted {
        seq_nr: u64,
    },
    /// Failed end of the stream; `cause` is textual and decoded lossily by `failure_text`.
    RemoteStreamFailure {
        cause: Vec<u8>,
    },
    /// Acknowledgement exchanged around termination (the producer waits for one after sending
    /// its terminal message).
    Ack,
}

impl StreamRefMessage {
    /// Returns the failure cause as text for `RemoteStreamFailure`, replacing invalid UTF-8
    /// with U+FFFD; `None` for every other message.
    #[must_use]
    pub fn failure_text(&self) -> Option<String> {
        if let Self::RemoteStreamFailure { cause } = self {
            let text = String::from_utf8_lossy(cause);
            Some(text.into_owned())
        } else {
            None
        }
    }

    /// Whether this is an `Ack`.
    fn is_ack(&self) -> bool {
        self == &Self::Ack
    }
}
264
265#[derive(Debug, Clone, PartialEq, Eq)]
267pub struct StreamRefFrame {
268 pub stream_ref_id: StreamRefId,
269 pub message: StreamRefMessage,
270}
271
272impl StreamRefFrame {
273 #[must_use]
274 pub fn new(stream_ref_id: StreamRefId, message: StreamRefMessage) -> Self {
275 Self {
276 stream_ref_id,
277 message,
278 }
279 }
280
281 #[must_use]
282 pub fn encode_to_vec(&self) -> Vec<u8> {
283 self.to_wire().encode_to_vec()
284 }
285
286 pub fn decode(bytes: &[u8]) -> StreamResult<Self> {
287 Self::from_wire(WireStreamRefFrame::decode(bytes).map_err(|error| {
288 StreamError::Failed(format!("invalid stream ref protobuf frame: {error}"))
289 })?)
290 }
291
292 fn to_wire(&self) -> WireStreamRefFrame {
293 WireStreamRefFrame {
294 stream_ref_id: self.stream_ref_id.to_bytes().to_vec(),
295 message: Some(match &self.message {
296 StreamRefMessage::OnSubscribeHandshake => {
297 wire_stream_ref_frame::Message::OnSubscribeHandshake(
298 WireOnSubscribeHandshake {},
299 )
300 }
301 StreamRefMessage::CumulativeDemand { seq_nr } => {
302 wire_stream_ref_frame::Message::CumulativeDemand(WireCumulativeDemand {
303 seq_nr: *seq_nr,
304 })
305 }
306 StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
307 wire_stream_ref_frame::Message::SequencedOnNext(WireSequencedOnNext {
308 seq_nr: *seq_nr,
309 payload: Some(WirePayload {
310 enclosed_message: payload.bytes.clone(),
311 }),
312 })
313 }
314 StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
315 wire_stream_ref_frame::Message::RemoteStreamCompleted(
316 WireRemoteStreamCompleted { seq_nr: *seq_nr },
317 )
318 }
319 StreamRefMessage::RemoteStreamFailure { cause } => {
320 wire_stream_ref_frame::Message::RemoteStreamFailure(WireRemoteStreamFailure {
321 cause: cause.clone(),
322 })
323 }
324 StreamRefMessage::Ack => wire_stream_ref_frame::Message::Ack(WireAck {}),
325 }),
326 }
327 }
328
329 fn from_wire(wire: WireStreamRefFrame) -> StreamResult<Self> {
330 let stream_ref_id = StreamRefId::from_bytes(&wire.stream_ref_id)?;
331 let message = match wire.message.ok_or_else(|| {
332 StreamError::Failed("stream ref protobuf frame has no message".to_owned())
333 })? {
334 wire_stream_ref_frame::Message::OnSubscribeHandshake(_) => {
335 StreamRefMessage::OnSubscribeHandshake
336 }
337 wire_stream_ref_frame::Message::CumulativeDemand(message) => {
338 StreamRefMessage::CumulativeDemand {
339 seq_nr: message.seq_nr,
340 }
341 }
342 wire_stream_ref_frame::Message::SequencedOnNext(message) => {
343 let payload = message.payload.ok_or_else(|| {
344 StreamError::Failed("SequencedOnNext missing payload".to_owned())
345 })?;
346 StreamRefMessage::SequencedOnNext {
347 seq_nr: message.seq_nr,
348 payload: StreamRefPayloadBytes {
349 bytes: payload.enclosed_message,
350 },
351 }
352 }
353 wire_stream_ref_frame::Message::RemoteStreamCompleted(message) => {
354 StreamRefMessage::RemoteStreamCompleted {
355 seq_nr: message.seq_nr,
356 }
357 }
358 wire_stream_ref_frame::Message::RemoteStreamFailure(message) => {
359 StreamRefMessage::RemoteStreamFailure {
360 cause: message.cause,
361 }
362 }
363 wire_stream_ref_frame::Message::Ack(_) => StreamRefMessage::Ack,
364 };
365 Ok(Self {
366 stream_ref_id,
367 message,
368 })
369 }
370}
371
/// Hand-written prost definition of a stream-ref frame on the wire.
///
/// Field tags define the protobuf encoding; changing one makes frames incompatible with peers
/// that use the old tags.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireStreamRefFrame {
    /// 16-byte big-endian id (`StreamRefId::to_bytes`); other lengths are rejected on decode.
    #[prost(bytes = "vec", tag = "1")]
    stream_ref_id: Vec<u8>,
    /// Exactly one protocol message; `None` is rejected by `StreamRefFrame::from_wire`.
    #[prost(oneof = "wire_stream_ref_frame::Message", tags = "2, 3, 4, 5, 6, 7")]
    message: Option<wire_stream_ref_frame::Message>,
}

/// Oneof payload of `WireStreamRefFrame`, one variant per `StreamRefMessage` variant.
mod wire_stream_ref_frame {
    #[derive(Clone, PartialEq, prost::Oneof)]
    pub enum Message {
        #[prost(message, tag = "2")]
        OnSubscribeHandshake(super::WireOnSubscribeHandshake),
        #[prost(message, tag = "3")]
        CumulativeDemand(super::WireCumulativeDemand),
        #[prost(message, tag = "4")]
        SequencedOnNext(super::WireSequencedOnNext),
        #[prost(message, tag = "5")]
        RemoteStreamCompleted(super::WireRemoteStreamCompleted),
        #[prost(message, tag = "6")]
        RemoteStreamFailure(super::WireRemoteStreamFailure),
        #[prost(message, tag = "7")]
        Ack(super::WireAck),
    }
}

/// Wrapper around one element's encoded bytes.
#[derive(Clone, PartialEq, ProstMessage)]
struct WirePayload {
    #[prost(bytes = "vec", tag = "1")]
    enclosed_message: Vec<u8>,
}

/// Empty marker message for `OnSubscribeHandshake`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireOnSubscribeHandshake {}

#[derive(Clone, PartialEq, ProstMessage)]
struct WireCumulativeDemand {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
}

/// One element; a missing `payload` is rejected by `StreamRefFrame::from_wire`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireSequencedOnNext {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
    #[prost(message, optional, tag = "2")]
    payload: Option<WirePayload>,
}

#[derive(Clone, PartialEq, ProstMessage)]
struct WireRemoteStreamFailure {
    #[prost(bytes = "vec", tag = "1")]
    cause: Vec<u8>,
}

#[derive(Clone, PartialEq, ProstMessage)]
struct WireRemoteStreamCompleted {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
}

/// Empty marker message for `Ack`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireAck {}
435
436#[derive(Debug, Clone)]
437struct StreamRefPayloadSegment {
438 offset: usize,
439 len: usize,
440}
441
442#[derive(Debug, Clone)]
444pub struct StreamRefPayloadBatch {
445 stream_ref_id: StreamRefId,
446 first_seq_nr: u64,
447 payloads: Vec<u8>,
448 segments: Vec<StreamRefPayloadSegment>,
449}
450
451impl StreamRefPayloadBatch {
452 #[must_use]
453 pub fn new(stream_ref_id: StreamRefId, first_seq_nr: u64) -> Self {
454 Self {
455 stream_ref_id,
456 first_seq_nr,
457 payloads: Vec::new(),
458 segments: Vec::new(),
459 }
460 }
461
462 #[must_use]
463 pub fn stream_ref_id(&self) -> StreamRefId {
464 self.stream_ref_id
465 }
466
467 #[must_use]
468 pub fn first_seq_nr(&self) -> u64 {
469 self.first_seq_nr
470 }
471
472 #[must_use]
473 pub fn count(&self) -> usize {
474 self.segments.len()
475 }
476
477 #[must_use]
478 pub fn is_empty(&self) -> bool {
479 self.segments.is_empty()
480 }
481
482 #[must_use]
483 pub fn payload_len(&self, index: usize) -> usize {
484 self.segments[index].len
485 }
486
487 #[must_use]
488 pub fn payload(&self, index: usize) -> &[u8] {
489 let segment = &self.segments[index];
490 &self.payloads[segment.offset..segment.offset + segment.len]
491 }
492
493 pub fn push_payload<T>(&mut self, item: T) -> StreamResult<()>
494 where
495 T: StreamRefPayload,
496 {
497 let offset = self.payloads.len();
498 item.encode_stream_ref_payload_into(&mut self.payloads);
499 let len = self.payloads.len().saturating_sub(offset);
500 if len > u32::MAX as usize {
501 return Err(StreamError::LimitExceeded {
502 max: u32::MAX as u64,
503 });
504 }
505 self.segments.push(StreamRefPayloadSegment { offset, len });
506 Ok(())
507 }
508
509 fn into_single_frame(self) -> StreamResult<StreamRefFrame> {
510 if self.count() != 1 {
511 return Err(StreamError::Failed(
512 "stream ref batch cannot be converted into a single frame".to_owned(),
513 ));
514 }
515 Ok(StreamRefFrame::new(
516 self.stream_ref_id,
517 StreamRefMessage::SequencedOnNext {
518 seq_nr: self.first_seq_nr,
519 payload: StreamRefPayloadBytes {
520 bytes: self.payload(0).to_vec(),
521 },
522 },
523 ))
524 }
525}
526
527#[derive(Debug, Clone)]
530pub enum StreamRefOutbound {
531 Frame(StreamRefFrame),
532 SequencedBatch(StreamRefPayloadBatch),
533}
534
535impl StreamRefOutbound {
536 fn into_single_frame(self) -> StreamResult<StreamRefFrame> {
537 match self {
538 Self::Frame(frame) => Ok(frame),
539 Self::SequencedBatch(batch) => batch.into_single_frame(),
540 }
541 }
542}
543
544pub trait StreamRefProtoEndpoint: Clone + Send + Sync + 'static {
546 fn stream_ref_id(&self) -> StreamRefId;
547 fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>>;
548 fn next_outbound(
549 &self,
550 max_data_elements: usize,
551 _max_data_bytes: usize,
552 ) -> Option<StreamResult<StreamRefOutbound>> {
553 let _ = max_data_elements;
554 self.next_frame()
555 .map(|frame| frame.map(StreamRefOutbound::Frame))
556 }
557 fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()>;
558 fn handle_sequenced_on_next_batch(
559 &self,
560 stream_ref_id: StreamRefId,
561 first_seq_nr: u64,
562 payloads: &[&[u8]],
563 ) -> StreamResult<()> {
564 for (index, payload) in payloads.iter().enumerate() {
565 let seq_nr = first_seq_nr.checked_add(index as u64).ok_or_else(|| {
566 StreamError::Failed("stream ref batch sequence overflow".to_owned())
567 })?;
568 self.handle_frame(StreamRefFrame::new(
569 stream_ref_id,
570 StreamRefMessage::SequencedOnNext {
571 seq_nr,
572 payload: StreamRefPayloadBytes {
573 bytes: payload.to_vec(),
574 },
575 },
576 ))?;
577 }
578 Ok(())
579 }
580 fn fail_connection(&self, error: StreamError);
581}
582
/// Result of a non-blocking outbound poll.
#[doc(hidden)]
pub enum StreamRefOutboundPoll {
    /// An outbound item (or a terminal error) is available now.
    Ready(StreamResult<StreamRefOutbound>),
    /// Nothing to send yet; the caller should wait for a wake signal and poll again.
    Pending,
    /// The endpoint is finished and will produce nothing more.
    Closed,
}

/// Non-blocking extension of `StreamRefProtoEndpoint` for async transports.
///
/// The transport installs a wake channel and polls with `try_next_outbound` instead of blocking
/// in `next_outbound`. In the producer implementation, a `()` is sent (with `try_send`) on the
/// installed channel whenever the endpoint's state changes.
#[doc(hidden)]
pub trait StreamRefProtoEndpointWake: StreamRefProtoEndpoint {
    /// Installs (replacing any previous) the channel used to signal outbound readiness.
    fn install_outbound_wake(&self, sender: tokio_mpsc::Sender<()>);
    /// Removes the installed wake channel, if any.
    fn clear_outbound_wake(&self);
    /// Polls for the next outbound item without blocking, honouring the batch limits.
    fn try_next_outbound(
        &self,
        max_data_elements: usize,
        max_data_bytes: usize,
    ) -> StreamRefOutboundPoll;
}
602
603#[derive(Default)]
604struct OutboundWake {
605 sender: Mutex<Option<tokio_mpsc::Sender<()>>>,
606}
607
608impl OutboundWake {
609 fn install(&self, sender: tokio_mpsc::Sender<()>) {
610 *self
611 .sender
612 .lock()
613 .unwrap_or_else(|poison| poison.into_inner()) = Some(sender);
614 }
615
616 fn clear(&self) {
617 *self
618 .sender
619 .lock()
620 .unwrap_or_else(|poison| poison.into_inner()) = None;
621 }
622
623 fn wake(&self) {
624 let sender = self
625 .sender
626 .lock()
627 .unwrap_or_else(|poison| poison.into_inner())
628 .clone();
629 if let Some(sender) = sender {
630 let _ = sender.try_send(());
631 }
632 }
633}
634
635pub struct StreamRefProtoProducer<T>
642where
643 T: StreamRefPayload,
644{
645 shared: Arc<ProducerShared<T>>,
646}
647
648impl<T> Clone for StreamRefProtoProducer<T>
649where
650 T: StreamRefPayload,
651{
652 fn clone(&self) -> Self {
653 Self {
654 shared: Arc::clone(&self.shared),
655 }
656 }
657}
658
659impl<T> StreamRefProtoProducer<T>
660where
661 T: StreamRefPayload,
662{
663 pub fn from_source_ref(
664 source_ref: SourceRef<T>,
665 stream_ref_id: StreamRefId,
666 settings: StreamRefSettings,
667 ) -> StreamResult<Self> {
668 Self::from_source(
669 super::stream_ref::proto_source(&source_ref),
670 stream_ref_id,
671 settings,
672 )
673 }
674
675 pub fn from_source<Mat>(
676 source: Source<T, Mat>,
677 stream_ref_id: StreamRefId,
678 settings: StreamRefSettings,
679 ) -> StreamResult<Self>
680 where
681 Mat: Send + 'static,
682 {
683 let materializer = Materializer::new();
684 let (input, materialized) = Arc::clone(&source.factory).create(&materializer)?;
685 Ok(Self {
686 shared: Arc::new(ProducerShared {
687 stream_ref_id,
688 settings,
689 input: Mutex::new(Some(input)),
690 state: Mutex::new(ProducerState {
691 partner_seen: false,
692 cumulative_demand: 0,
693 first_demand_deadline: None,
694 sent: 0,
695 terminal_sent: false,
696 waiting_for_ack: false,
697 ack_deadline: None,
698 stopped: None,
699 ack_queued: false,
700 pending_terminal: None,
701 done: false,
702 input_attached: true,
703 terminal_result: None,
704 }),
705 changed: Condvar::new(),
706 outbound_wake: OutboundWake::default(),
707 completion: Mutex::new(None),
708 _materializer: materializer,
709 _materialized: Mutex::new(Some(Box::new(materialized))),
710 }),
711 })
712 }
713
714 #[must_use]
722 pub fn new_lazy(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
723 Self {
724 shared: Arc::new(ProducerShared {
725 stream_ref_id,
726 settings,
727 input: Mutex::new(None),
728 state: Mutex::new(ProducerState {
729 partner_seen: false,
730 cumulative_demand: 0,
731 first_demand_deadline: None,
732 sent: 0,
733 terminal_sent: false,
734 waiting_for_ack: false,
735 ack_deadline: None,
736 stopped: None,
737 ack_queued: false,
738 pending_terminal: None,
739 done: false,
740 input_attached: false,
741 terminal_result: None,
742 }),
743 changed: Condvar::new(),
744 outbound_wake: OutboundWake::default(),
745 completion: Mutex::new(None),
746 _materializer: Materializer::new(),
747 _materialized: Mutex::new(None),
748 }),
749 }
750 }
751
752 #[must_use]
760 pub fn sink(&self) -> Sink<T, StreamCompletion<NotUsed>> {
761 let shared = Arc::clone(&self.shared);
762 Sink::from_runner(move |input, _materializer| {
763 let (sender, receiver) = oneshot::channel();
764 *shared
765 .completion
766 .lock()
767 .unwrap_or_else(|poison| poison.into_inner()) = Some(sender);
768 shared.attach_input(input);
769 Ok(StreamCompletion::from_receiver(receiver, None))
770 })
771 }
772}
773
774impl<T> StreamRefProtoEndpoint for StreamRefProtoProducer<T>
775where
776 T: StreamRefPayload,
777{
778 fn stream_ref_id(&self) -> StreamRefId {
779 self.shared.stream_ref_id
780 }
781
782 fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
783 self.shared.next_frame()
784 }
785
786 fn next_outbound(
787 &self,
788 max_data_elements: usize,
789 max_data_bytes: usize,
790 ) -> Option<StreamResult<StreamRefOutbound>> {
791 self.shared.next_outbound(max_data_elements, max_data_bytes)
792 }
793
794 fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
795 self.shared.handle_frame(frame)
796 }
797
798 fn fail_connection(&self, error: StreamError) {
799 self.shared.fail_connection(error);
800 }
801}
802
803impl<T> StreamRefProtoEndpointWake for StreamRefProtoProducer<T>
804where
805 T: StreamRefPayload,
806{
807 fn install_outbound_wake(&self, sender: tokio_mpsc::Sender<()>) {
808 self.shared.outbound_wake.install(sender);
809 }
810
811 fn clear_outbound_wake(&self) {
812 self.shared.outbound_wake.clear();
813 }
814
815 fn try_next_outbound(
816 &self,
817 max_data_elements: usize,
818 max_data_bytes: usize,
819 ) -> StreamRefOutboundPoll {
820 self.shared
821 .try_next_outbound(max_data_elements, max_data_bytes)
822 }
823}
824
/// State shared by every clone of a `StreamRefProtoProducer`.
struct ProducerShared<T>
where
    T: StreamRefPayload,
{
    stream_ref_id: StreamRefId,
    settings: StreamRefSettings,
    // Element source; `None` until attached (lazy producers) or after it has been dropped.
    input: Mutex<Option<BoxStream<T>>>,
    state: Mutex<ProducerState>,
    // Signalled (notify_all) on every state change, together with `outbound_wake`.
    changed: Condvar,
    outbound_wake: OutboundWake,
    // Completion channel of the sink's materialized `StreamCompletion`, if a sink was used.
    completion: Mutex<Option<oneshot::Sender<StreamResult<NotUsed>>>>,
    // Held only to keep them alive alongside the input (not read in the visible code).
    _materializer: Materializer,
    _materialized: Mutex<Option<Box<dyn Any + Send>>>,
}

/// Mutable protocol state of a producer, guarded by `ProducerShared::state`.
struct ProducerState {
    // Set when an `OnSubscribeHandshake` frame is received.
    partner_seen: bool,
    // Demand granted by the partner; elements are sent while `sent < cumulative_demand`.
    cumulative_demand: u64,
    // Lazily set by the non-blocking poll; the first-demand timeout is measured from it.
    first_demand_deadline: Option<Instant>,
    // Number of elements counted as sent; also the next element's sequence number.
    sent: u64,
    // A terminal message has been queued (set once by `note_terminal`).
    terminal_sent: bool,
    // After the terminal message, the producer waits for an ack until `ack_deadline`.
    waiting_for_ack: bool,
    ack_deadline: Option<Instant>,
    // Error that stops the producer; the next outbound poll finishes with it.
    stopped: Option<StreamError>,
    // When set, the next outbound poll emits an `Ack` frame and finishes the producer.
    ack_queued: bool,
    // Terminal frame waiting to be handed to the transport.
    pending_terminal: Option<StreamRefMessage>,
    // The producer has finished; outbound polls return `None` / `Closed`.
    done: bool,
    // Whether an input stream has ever been attached.
    input_attached: bool,
    // Final outcome recorded when the producer terminates.
    terminal_result: Option<StreamResult<NotUsed>>,
}

/// Outcome of a non-blocking batch pull.
enum ProducerBatchPoll {
    // A batch (or an error) is ready to send.
    Ready(StreamResult<StreamRefOutbound>),
    // Nothing is available without blocking.
    Pending,
    // The state changed (e.g. a terminal was queued); the caller should re-evaluate it.
    StateChanged,
}

/// Outcome of a non-blocking pull of one input element.
enum InputItemPoll<T> {
    // The input yielded an element or an error.
    Ready(StreamResult<T>),
    // No input is attached.
    Pending,
    // The input ended and a `RemoteStreamCompleted` terminal was queued.
    TerminalQueued,
}
867
868impl<T> ProducerShared<T>
869where
870 T: StreamRefPayload,
871{
872 fn lock_state(&self) -> MutexGuard<'_, ProducerState> {
873 self.state
874 .lock()
875 .unwrap_or_else(|poison| poison.into_inner())
876 }
877
878 fn lock_input(&self) -> MutexGuard<'_, Option<BoxStream<T>>> {
879 self.input
880 .lock()
881 .unwrap_or_else(|poison| poison.into_inner())
882 }
883
884 fn notify_changed(&self) {
885 self.changed.notify_all();
886 self.outbound_wake.wake();
887 }
888
889 fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
890 StreamRefFrame::new(self.stream_ref_id, message)
891 }
892
893 fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
894 self.next_outbound(1, usize::MAX)
895 .map(|outbound| outbound.and_then(StreamRefOutbound::into_single_frame))
896 }
897
898 fn next_outbound(
899 &self,
900 max_data_elements: usize,
901 max_data_bytes: usize,
902 ) -> Option<StreamResult<StreamRefOutbound>> {
903 let subscription_deadline = deadline_from_now(self.settings.subscription_timeout());
904 loop {
905 let mut state = self.lock_state();
906 if state.done {
907 return None;
908 }
909
910 if state.ack_queued {
911 state.ack_queued = false;
912 state.done = true;
913 state.terminal_result = Some(match state.stopped.clone() {
914 Some(error) => Err(error),
915 None => Ok(NotUsed),
916 });
917 self.notify_changed();
918 drop(state);
919 self.drop_input();
920 self.settle();
921 return Some(Ok(StreamRefOutbound::Frame(
922 self.frame(StreamRefMessage::Ack),
923 )));
924 }
925
926 if let Some(message) = state.pending_terminal.take() {
927 drop(state);
928 return Some(Ok(StreamRefOutbound::Frame(self.frame(message))));
929 }
930
931 if state.waiting_for_ack {
932 if state
933 .ack_deadline
934 .is_some_and(|deadline| Instant::now() >= deadline)
935 {
936 let timeout_error =
937 subscription_timeout_error("stream ref producer terminal ack");
938 state.done = true;
939 state.terminal_result = Some(Err(timeout_error.clone()));
940 self.notify_changed();
941 drop(state);
942 self.drop_input();
943 self.settle();
944 return Some(Err(timeout_error));
945 }
946 if let Some(remaining) = state
947 .ack_deadline
948 .and_then(|deadline| deadline.checked_duration_since(Instant::now()))
949 {
950 let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
951 drop(next);
952 } else {
953 drop(state);
954 }
955 continue;
956 }
957
958 if let Some(error) = state.stopped.clone() {
959 state.done = true;
960 state.terminal_result = Some(Err(error.clone()));
961 self.notify_changed();
962 drop(state);
963 self.drop_input();
964 self.settle();
965 return Some(Err(error));
966 }
967
968 if state.cumulative_demand > 0 && state.sent < state.cumulative_demand {
969 drop(state);
970 if let Some(outbound) =
971 self.pull_next_outbound_batch(max_data_elements.max(1), max_data_bytes.max(1))
972 {
973 return Some(outbound);
974 }
975 continue;
976 }
977
978 if state.cumulative_demand == 0 && Instant::now() >= subscription_deadline {
979 let timeout_error = subscription_timeout_error("stream ref producer first demand");
980 state.done = true;
981 state.terminal_result = Some(Err(timeout_error.clone()));
982 self.notify_changed();
983 drop(state);
984 self.drop_input();
985 self.settle();
986 return Some(Err(timeout_error));
987 }
988
989 if state.cumulative_demand == 0 {
990 let remaining = subscription_deadline.saturating_duration_since(Instant::now());
991 if remaining.is_zero() {
992 drop(state);
993 continue;
994 }
995 let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
996 drop(next);
997 } else {
998 let next = wait_unpoison(&self.changed, state);
999 drop(next);
1000 }
1001 }
1002 }
1003
1004 fn try_next_outbound(
1005 &self,
1006 max_data_elements: usize,
1007 max_data_bytes: usize,
1008 ) -> StreamRefOutboundPoll {
1009 loop {
1010 let mut state = self.lock_state();
1011 if state.done {
1012 return StreamRefOutboundPoll::Closed;
1013 }
1014
1015 if state.ack_queued {
1016 state.ack_queued = false;
1017 state.done = true;
1018 state.terminal_result = Some(match state.stopped.clone() {
1019 Some(error) => Err(error),
1020 None => Ok(NotUsed),
1021 });
1022 self.notify_changed();
1023 drop(state);
1024 self.drop_input();
1025 self.settle();
1026 return StreamRefOutboundPoll::Ready(Ok(StreamRefOutbound::Frame(
1027 self.frame(StreamRefMessage::Ack),
1028 )));
1029 }
1030
1031 if let Some(message) = state.pending_terminal.take() {
1032 drop(state);
1033 return StreamRefOutboundPoll::Ready(Ok(StreamRefOutbound::Frame(
1034 self.frame(message),
1035 )));
1036 }
1037
1038 if state.waiting_for_ack {
1039 if state
1040 .ack_deadline
1041 .is_some_and(|deadline| Instant::now() >= deadline)
1042 {
1043 let timeout_error =
1044 subscription_timeout_error("stream ref producer terminal ack");
1045 state.done = true;
1046 state.terminal_result = Some(Err(timeout_error.clone()));
1047 self.notify_changed();
1048 drop(state);
1049 self.drop_input();
1050 self.settle();
1051 return StreamRefOutboundPoll::Ready(Err(timeout_error));
1052 }
1053 return StreamRefOutboundPoll::Pending;
1054 }
1055
1056 if let Some(error) = state.stopped.clone() {
1057 state.done = true;
1058 state.terminal_result = Some(Err(error.clone()));
1059 self.notify_changed();
1060 drop(state);
1061 self.drop_input();
1062 self.settle();
1063 return StreamRefOutboundPoll::Ready(Err(error));
1064 }
1065
1066 if state.cumulative_demand > 0 && state.sent < state.cumulative_demand {
1067 drop(state);
1068 match self
1069 .try_pull_next_outbound_batch(max_data_elements.max(1), max_data_bytes.max(1))
1070 {
1071 ProducerBatchPoll::Ready(outbound) => {
1072 return StreamRefOutboundPoll::Ready(outbound);
1073 }
1074 ProducerBatchPoll::StateChanged => continue,
1075 ProducerBatchPoll::Pending => return StreamRefOutboundPoll::Pending,
1076 }
1077 }
1078
1079 if state.cumulative_demand == 0 {
1080 let deadline = *state
1081 .first_demand_deadline
1082 .get_or_insert_with(|| deadline_from_now(self.settings.subscription_timeout()));
1083 if Instant::now() >= deadline {
1084 let timeout_error =
1085 subscription_timeout_error("stream ref producer first demand");
1086 state.done = true;
1087 state.terminal_result = Some(Err(timeout_error.clone()));
1088 self.notify_changed();
1089 drop(state);
1090 self.drop_input();
1091 self.settle();
1092 return StreamRefOutboundPoll::Ready(Err(timeout_error));
1093 }
1094 }
1095
1096 return StreamRefOutboundPoll::Pending;
1097 }
1098 }
1099
1100 fn pull_next_outbound_batch(
1101 &self,
1102 max_data_elements: usize,
1103 max_data_bytes: usize,
1104 ) -> Option<StreamResult<StreamRefOutbound>> {
1105 let mut batch: Option<StreamRefPayloadBatch> = None;
1106 while batch
1107 .as_ref()
1108 .is_none_or(|batch| batch.count() < max_data_elements)
1109 {
1110 if batch
1111 .as_ref()
1112 .is_some_and(|batch| batch.payloads.len() >= max_data_bytes)
1113 {
1114 break;
1115 }
1116
1117 let seq_nr = {
1118 let state = self.lock_state();
1119 if state.done
1120 || state.stopped.is_some()
1121 || state.waiting_for_ack
1122 || state.sent >= state.cumulative_demand
1123 {
1124 break;
1125 }
1126 state.sent
1127 };
1128
1129 let item = match self.next_input_item() {
1130 Some(item) => item,
1131 None => break,
1132 };
1133
1134 match item {
1135 Ok(item) => {
1136 let mut state = self.lock_state();
1137 if state.done
1138 || state.stopped.is_some()
1139 || state.waiting_for_ack
1140 || state.sent != seq_nr
1141 {
1142 break;
1143 }
1144 state.sent = state.sent.saturating_add(1);
1145 drop(state);
1146
1147 let batch = batch.get_or_insert_with(|| {
1148 StreamRefPayloadBatch::new(self.stream_ref_id, seq_nr)
1149 });
1150 if let Err(error) = batch.push_payload(item) {
1151 return Some(Err(error));
1152 }
1153 }
1154 Err(error) => {
1155 let terminal = StreamRefMessage::RemoteStreamFailure {
1156 cause: failure_cause(&error),
1157 };
1158 self.note_terminal(Err(error), terminal.clone());
1159 return match batch {
1160 Some(batch) if !batch.is_empty() => {
1161 Some(Ok(StreamRefOutbound::SequencedBatch(batch)))
1162 }
1163 _ => None,
1164 };
1165 }
1166 }
1167 }
1168
1169 batch
1170 .filter(|batch| !batch.is_empty())
1171 .map(|batch| Ok(StreamRefOutbound::SequencedBatch(batch)))
1172 }
1173
1174 fn try_pull_next_outbound_batch(
1175 &self,
1176 max_data_elements: usize,
1177 max_data_bytes: usize,
1178 ) -> ProducerBatchPoll {
1179 let mut batch: Option<StreamRefPayloadBatch> = None;
1180 while batch
1181 .as_ref()
1182 .is_none_or(|batch| batch.count() < max_data_elements)
1183 {
1184 if batch
1185 .as_ref()
1186 .is_some_and(|batch| batch.payloads.len() >= max_data_bytes)
1187 {
1188 break;
1189 }
1190
1191 let seq_nr = {
1192 let state = self.lock_state();
1193 if state.done
1194 || state.stopped.is_some()
1195 || state.waiting_for_ack
1196 || state.sent >= state.cumulative_demand
1197 {
1198 break;
1199 }
1200 if !state.input_attached {
1201 return match batch {
1202 Some(batch) if !batch.is_empty() => {
1203 ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
1204 }
1205 _ => ProducerBatchPoll::Pending,
1206 };
1207 }
1208 state.sent
1209 };
1210
1211 let item = match self.try_next_input_item() {
1212 InputItemPoll::Ready(item) => item,
1213 InputItemPoll::Pending => {
1214 return match batch {
1215 Some(batch) if !batch.is_empty() => {
1216 ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
1217 }
1218 _ => ProducerBatchPoll::Pending,
1219 };
1220 }
1221 InputItemPoll::TerminalQueued => {
1222 return match batch {
1223 Some(batch) if !batch.is_empty() => {
1224 ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
1225 }
1226 _ => ProducerBatchPoll::StateChanged,
1227 };
1228 }
1229 };
1230
1231 match item {
1232 Ok(item) => {
1233 let mut state = self.lock_state();
1234 if state.done
1235 || state.stopped.is_some()
1236 || state.waiting_for_ack
1237 || state.sent != seq_nr
1238 {
1239 break;
1240 }
1241 state.sent = state.sent.saturating_add(1);
1242 drop(state);
1243
1244 let batch = batch.get_or_insert_with(|| {
1245 StreamRefPayloadBatch::new(self.stream_ref_id, seq_nr)
1246 });
1247 if let Err(error) = batch.push_payload(item) {
1248 return ProducerBatchPoll::Ready(Err(error));
1249 }
1250 }
1251 Err(error) => {
1252 let terminal = StreamRefMessage::RemoteStreamFailure {
1253 cause: failure_cause(&error),
1254 };
1255 self.note_terminal(Err(error), terminal);
1256 return match batch {
1257 Some(batch) if !batch.is_empty() => {
1258 ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
1259 }
1260 _ => ProducerBatchPoll::StateChanged,
1261 };
1262 }
1263 }
1264 }
1265
1266 match batch {
1267 Some(batch) if !batch.is_empty() => {
1268 ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
1269 }
1270 _ => ProducerBatchPoll::Pending,
1271 }
1272 }
1273
    /// Blocking pull of the next element from the attached input stream.
    ///
    /// Returns `None` in two situations:
    /// - no input is attached yet: it first waits until one is attached or the
    ///   endpoint is done, stopped, or has sent its terminal;
    /// - the input is exhausted: a `RemoteStreamCompleted` carrying the current
    ///   `sent` count is queued via `note_terminal`.
    ///
    /// NOTE(review): a `None` from the "no input yet" path does not by itself
    /// mean the stream ended — callers presumably re-check state; confirm.
    fn next_input_item(&self) -> Option<StreamResult<T>> {
        let mut input_guard = self.lock_input();
        if input_guard.is_none() {
            // Release the input lock before waiting so `attach_input` (which
            // locks the input slot) can install the stream.
            drop(input_guard);
            let mut state = self.lock_state();
            while !state.input_attached
                && !state.done
                && state.stopped.is_none()
                && !state.terminal_sent
            {
                state = wait_unpoison(&self.changed, state);
            }
            drop(state);
            return None;
        }

        // The input lock stays held while the upstream iterator yields.
        match input_guard.as_mut().expect("input attached").next() {
            Some(item) => Some(item),
            None => {
                // Must release the input lock first: `note_terminal` calls
                // `drop_input`, which locks it again.
                drop(input_guard);
                let seq_nr = self.lock_state().sent;
                self.note_terminal(
                    Ok(NotUsed),
                    StreamRefMessage::RemoteStreamCompleted { seq_nr },
                );
                None
            }
        }
    }
1307
    /// Pulls the next element without waiting for an input to be attached.
    ///
    /// Returns `Pending` when no input is attached and `Ready` with the upstream
    /// result otherwise. Returns `TerminalQueued` once the input is exhausted,
    /// after queuing a `RemoteStreamCompleted` that carries the current `sent` count.
    ///
    /// NOTE(review): `input.next()` may itself block if the upstream iterator
    /// does; "try" only covers the attachment check — confirm with callers.
    fn try_next_input_item(&self) -> InputItemPoll<T> {
        let mut input_guard = self.lock_input();
        let Some(input) = input_guard.as_mut() else {
            return InputItemPoll::Pending;
        };

        match input.next() {
            Some(item) => InputItemPoll::Ready(item),
            None => {
                // `note_terminal` re-locks the input via `drop_input`; release it first.
                drop(input_guard);
                let seq_nr = self.lock_state().sent;
                self.note_terminal(
                    Ok(NotUsed),
                    StreamRefMessage::RemoteStreamCompleted { seq_nr },
                );
                InputItemPoll::TerminalQueued
            }
        }
    }
1327
1328 fn note_terminal(&self, result: StreamResult<NotUsed>, terminal_message: StreamRefMessage) {
1329 self.drop_input();
1330 let mut state = self.lock_state();
1331 if state.done || state.terminal_sent {
1332 return;
1333 }
1334 state.terminal_sent = true;
1335 state.waiting_for_ack = true;
1336 state.terminal_result = Some(result);
1337 state.ack_deadline = Some(deadline_from_now(self.settings.subscription_timeout()));
1338 state.pending_terminal = Some(terminal_message);
1339 self.notify_changed();
1340 drop(state);
1341 }
1342
    /// Applies a control frame received from the remote consumer.
    ///
    /// - `OnSubscribeHandshake` marks the partner as seen.
    /// - `CumulativeDemand` does the same, and only ever raises `cumulative_demand`.
    /// - `RemoteStreamCompleted` / `RemoteStreamFailure` mean the consumer stopped.
    ///   Both go through `stop_from_consumer`; completion maps to `Cancelled`.
    /// - `Ack` acknowledges our terminal. If one was pending, the endpoint
    ///   finishes, the input is dropped and the completion is settled.
    ///
    /// # Errors
    /// Fails on a stream ref id mismatch, a zero demand `seq_nr`, or a
    /// `SequencedOnNext` (which only flows producer → consumer).
    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
        self.validate_frame_id(frame.stream_ref_id)?;
        match frame.message {
            StreamRefMessage::OnSubscribeHandshake => {
                let mut state = self.lock_state();
                state.partner_seen = true;
                self.notify_changed();
                drop(state);
                Ok(())
            }
            StreamRefMessage::CumulativeDemand { seq_nr } => {
                if seq_nr == 0 {
                    return Err(StreamError::Failed(
                        "CumulativeDemand seq_nr must be positive".to_owned(),
                    ));
                }
                let mut state = self.lock_state();
                state.partner_seen = true;
                // Demand is cumulative: stale/reordered lower values are ignored.
                if seq_nr > state.cumulative_demand {
                    state.cumulative_demand = seq_nr;
                }
                self.notify_changed();
                drop(state);
                Ok(())
            }
            StreamRefMessage::RemoteStreamCompleted { .. } => {
                // A completion from the consumer is a downstream cancellation.
                self.stop_from_consumer(StreamError::Cancelled);
                Ok(())
            }
            StreamRefMessage::RemoteStreamFailure { cause } => {
                self.stop_from_consumer(StreamError::Failed(
                    String::from_utf8_lossy(&cause).into_owned(),
                ));
                Ok(())
            }
            StreamRefMessage::Ack => {
                let mut state = self.lock_state();
                if state.waiting_for_ack {
                    state.waiting_for_ack = false;
                    state.done = true;
                    if state.terminal_result.is_none() {
                        state.terminal_result = Some(Ok(NotUsed));
                    }
                    self.notify_changed();
                    drop(state);
                    self.drop_input();
                    self.settle();
                } else {
                    // Unsolicited ack: nothing is pending, ignore it.
                    drop(state);
                }
                Ok(())
            }
            StreamRefMessage::SequencedOnNext { .. } => Err(StreamError::Failed(
                "producer endpoint cannot receive SequencedOnNext".to_owned(),
            )),
        }
    }
1400
1401 fn stop_from_consumer(&self, error: StreamError) {
1402 let mut state = self.lock_state();
1403 if !state.done {
1404 state.stopped = Some(error.clone());
1405 state.ack_queued = true;
1406 state.terminal_result = Some(Err(error));
1407 }
1408 self.notify_changed();
1409 drop(state);
1410 self.drop_input();
1411 }
1412
1413 fn fail_connection(&self, error: StreamError) {
1414 let mut state = self.lock_state();
1415 if !state.done {
1416 state.stopped = Some(error.clone());
1417 state.done = true;
1418 state.terminal_result = Some(Err(error));
1419 }
1420 self.notify_changed();
1421 drop(state);
1422 self.drop_input();
1423 self.settle();
1424 }
1425
1426 fn attach_input(&self, input: BoxStream<T>) {
1427 *self.lock_input() = Some(input);
1428 let mut state = self.lock_state();
1429 state.input_attached = true;
1430 self.notify_changed();
1431 drop(state);
1432 }
1433
1434 fn settle(&self) {
1435 let result = self.lock_state().terminal_result.clone();
1436 let sender = self
1437 .completion
1438 .lock()
1439 .unwrap_or_else(|poison| poison.into_inner())
1440 .take();
1441 if let (Some(sender), Some(result)) = (sender, result) {
1442 let _ = sender.send(result);
1443 }
1444 }
1445
    /// Detaches and drops the input stream, if one is attached.
    fn drop_input(&self) {
        // Kept as two statements on purpose: the guard temporary is released at
        // the end of the `let`, so the stream's destructor runs without the input
        // lock held. `drop(self.lock_input().take())` would run it under the lock.
        let input = self.lock_input().take();
        drop(input);
    }
1450
1451 fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
1452 if stream_ref_id == self.stream_ref_id {
1453 Ok(())
1454 } else {
1455 Err(StreamError::Failed(format!(
1456 "stream ref id mismatch: expected {}, got {}",
1457 self.stream_ref_id, stream_ref_id
1458 )))
1459 }
1460 }
1461}
1462
1463pub struct StreamRefProtoConsumer<T>
1469where
1470 T: StreamRefPayload,
1471{
1472 shared: Arc<ConsumerShared<T>>,
1473}
1474
1475impl<T> Clone for StreamRefProtoConsumer<T>
1476where
1477 T: StreamRefPayload,
1478{
1479 fn clone(&self) -> Self {
1480 Self {
1481 shared: Arc::clone(&self.shared),
1482 }
1483 }
1484}
1485
1486impl<T> StreamRefProtoConsumer<T>
1487where
1488 T: StreamRefPayload,
1489{
1490 #[must_use]
1491 pub fn new(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
1492 Self {
1493 shared: Arc::new(ConsumerShared {
1494 stream_ref_id,
1495 settings,
1496 state: Mutex::new(ConsumerState {
1497 source_taken: false,
1498 subscribed: false,
1499 queue: VecDeque::new(),
1500 direct_terminal: false,
1501 direct_consumer: None,
1502 direct_cancelled: None,
1503 terminal: None,
1504 expected_seq: 0,
1505 delivered: 0,
1506 cumulative_demand: 0,
1507 outbound: VecDeque::new(),
1508 finish_after_outbound_ack: false,
1509 waiting_cancel_ack: false,
1510 done: false,
1511 }),
1512 changed: Condvar::new(),
1513 outbound_wake: OutboundWake::default(),
1514 }),
1515 }
1516 }
1517
1518 #[must_use]
1519 pub fn source(&self) -> Source<T, NotUsed> {
1520 let shared_for_stream = Arc::clone(&self.shared);
1521 let shared_for_terminal = Arc::clone(&self.shared);
1522 Source::from_terminal_direct_materialized_factory(
1523 move |_materializer| {
1524 shared_for_stream
1525 .start_stream()
1526 .map(|stream| (Box::new(stream) as BoxStream<T>, NotUsed))
1527 },
1528 move |_materializer| {
1529 Ok((
1530 Arc::new(StreamRefProtoTerminalHook {
1531 shared: Arc::clone(&shared_for_terminal),
1532 stream: Mutex::new(None),
1533 }) as Arc<dyn TerminalSourceHookDyn<T>>,
1534 NotUsed,
1535 ))
1536 },
1537 )
1538 }
1539}
1540
1541impl<T> StreamRefProtoEndpoint for StreamRefProtoConsumer<T>
1542where
1543 T: StreamRefPayload,
1544{
1545 fn stream_ref_id(&self) -> StreamRefId {
1546 self.shared.stream_ref_id
1547 }
1548
1549 fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
1550 self.shared.next_frame()
1551 }
1552
1553 fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
1554 self.shared.handle_frame(frame)
1555 }
1556
1557 fn handle_sequenced_on_next_batch(
1558 &self,
1559 stream_ref_id: StreamRefId,
1560 first_seq_nr: u64,
1561 payloads: &[&[u8]],
1562 ) -> StreamResult<()> {
1563 self.shared
1564 .handle_sequenced_on_next_batch(stream_ref_id, first_seq_nr, payloads)
1565 }
1566
1567 fn fail_connection(&self, error: StreamError) {
1568 self.shared.fail_connection(error);
1569 }
1570}
1571
1572impl<T> StreamRefProtoEndpointWake for StreamRefProtoConsumer<T>
1573where
1574 T: StreamRefPayload,
1575{
1576 fn install_outbound_wake(&self, sender: tokio_mpsc::Sender<()>) {
1577 self.shared.outbound_wake.install(sender);
1578 }
1579
1580 fn clear_outbound_wake(&self) {
1581 self.shared.outbound_wake.clear();
1582 }
1583
1584 fn try_next_outbound(
1585 &self,
1586 _max_data_elements: usize,
1587 _max_data_bytes: usize,
1588 ) -> StreamRefOutboundPoll {
1589 self.shared.try_next_frame()
1590 }
1591}
1592
1593struct ConsumerShared<T>
1594where
1595 T: StreamRefPayload,
1596{
1597 stream_ref_id: StreamRefId,
1598 settings: StreamRefSettings,
1599 state: Mutex<ConsumerState<T>>,
1600 changed: Condvar,
1601 outbound_wake: OutboundWake,
1602}
1603
/// Mutable consumer-side protocol state, guarded by `ConsumerShared::state`.
struct ConsumerState<T> {
    /// Set once the source has been materialized (pull-based or direct).
    source_taken: bool,
    /// Set once `OnSubscribeHandshake` and the initial demand were queued.
    subscribed: bool,
    /// Received elements not yet pulled by `ConsumerStream` (non-direct path).
    queue: VecDeque<T>,
    /// `true` when elements are pushed straight into `direct_consumer`.
    direct_terminal: bool,
    /// Registered direct sink; taken out while it processes elements off-lock.
    direct_consumer: Option<Box<dyn TerminalSinkConsumerDyn<T>>>,
    /// Cancellation flag supplied with the direct sink, checked as elements arrive.
    direct_cancelled: Option<Arc<std::sync::atomic::AtomicBool>>,
    /// Terminal outcome once known; no further elements are accepted after it.
    terminal: Option<ConsumerTerminal>,
    /// Sequence number the next `SequencedOnNext` must carry.
    expected_seq: u64,
    /// Number of elements handed downstream so far.
    delivered: u64,
    /// Highest cumulative demand queued for the producer.
    cumulative_demand: u64,
    /// Control messages waiting to be sent to the producer.
    outbound: VecDeque<StreamRefMessage>,
    /// Set after a remote terminal: emitting the queued `Ack` marks the endpoint done.
    finish_after_outbound_ack: bool,
    /// Set after we failed/cancelled; the producer's `Ack` then marks us done.
    waiting_cancel_ack: bool,
    /// Endpoint finished: no more frames once `outbound` is drained.
    done: bool,
}
1620
/// How the consumer side of the stream ref ended.
#[derive(Clone)]
enum ConsumerTerminal {
    /// The producer completed with a matching sequence number.
    Complete,
    /// Remote failure, local failure, protocol violation, or cancellation.
    Error(StreamError),
}
1626
1627struct StreamRefProtoTerminalHook<T>
1628where
1629 T: StreamRefPayload,
1630{
1631 shared: Arc<ConsumerShared<T>>,
1632 stream: Mutex<Option<ConsumerStream<T>>>,
1633}
1634
impl<T> TerminalSourceHookDyn<T> for StreamRefProtoTerminalHook<T>
where
    T: StreamRefPayload,
{
    /// Refills `batch` with elements received from the remote producer.
    ///
    /// If the materializer shut down, or `cancelled` is set, the stream ref is
    /// cancelled and an error returned. The pull stream is started lazily on the
    /// first call.
    ///
    /// Returns `Completed` as soon as the remote side completes; `batch` may
    /// still hold elements pulled earlier in the same call. Otherwise returns
    /// `Active`.
    ///
    /// NOTE(review): each `next_item` blocks until an element or terminal
    /// arrives. The loop therefore waits for up to `buffer_capacity` elements
    /// (at most 64) even after some are already in `batch`, which adds latency
    /// with a slow producer. Confirm this is intended.
    fn drain_terminal_batch(
        &self,
        materializer: &Materializer,
        cancelled: &Arc<std::sync::atomic::AtomicBool>,
        batch: &mut Vec<T>,
    ) -> StreamResult<TerminalSourceStatus> {
        batch.clear();
        if materializer.is_shutdown() {
            self.shared.cancel_from_downstream();
            return Err(StreamError::AbruptTermination);
        }
        if cancelled.load(Ordering::SeqCst) {
            self.shared.cancel_from_downstream();
            return Err(StreamError::Cancelled);
        }

        // The stream mutex is held for the whole (blocking) drain below.
        let mut stream = self
            .stream
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        if stream.is_none() {
            *stream = Some(self.shared.start_stream()?);
        }
        let stream = stream.as_mut().expect("terminal stream present");
        for _ in 0..self.shared.settings.buffer_capacity().max(1) {
            // Errors propagate immediately; elements already pushed remain in `batch`.
            match stream.next_item()? {
                Some(item) => batch.push(item),
                None => return Ok(TerminalSourceStatus::Completed),
            }
            if batch.len() >= 64 {
                break;
            }
        }
        Ok(TerminalSourceStatus::Active)
    }

    /// This hook can push elements straight into a terminal sink.
    fn supports_direct_terminal(&self) -> bool {
        true
    }

    /// Registers `consumer` as the direct sink; always handled (`Some`).
    fn try_register_direct_terminal(
        &self,
        consumer: Box<dyn TerminalSinkConsumerDyn<T>>,
        cancelled: Arc<std::sync::atomic::AtomicBool>,
    ) -> Option<StreamResult<()>> {
        Some(self.shared.start_direct_terminal(consumer, cancelled))
    }

    /// Cancels the stream ref on behalf of downstream.
    fn cancel_terminal(&self) {
        self.shared.cancel_from_downstream();
    }
}
1691
1692impl<T> ConsumerShared<T>
1693where
1694 T: StreamRefPayload,
1695{
1696 fn lock_state(&self) -> MutexGuard<'_, ConsumerState<T>> {
1697 self.state
1698 .lock()
1699 .unwrap_or_else(|poison| poison.into_inner())
1700 }
1701
1702 fn notify_changed(&self) {
1703 self.changed.notify_all();
1704 self.outbound_wake.wake();
1705 }
1706
1707 fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
1708 StreamRefFrame::new(self.stream_ref_id, message)
1709 }
1710
1711 fn start_stream(self: &Arc<Self>) -> StreamResult<ConsumerStream<T>> {
1712 self.start_subscription()?;
1713 Ok(ConsumerStream {
1714 shared: Arc::clone(self),
1715 terminated: false,
1716 })
1717 }
1718
1719 fn start_subscription(self: &Arc<Self>) -> StreamResult<()> {
1720 {
1721 let mut state = self.lock_state();
1722 if state.source_taken {
1723 return Err(StreamError::Failed(
1724 "stream ref source has already been materialized".to_owned(),
1725 ));
1726 }
1727 state.source_taken = true;
1728 if !state.subscribed {
1729 state.subscribed = true;
1730 state
1731 .outbound
1732 .push_back(StreamRefMessage::OnSubscribeHandshake);
1733 if let Some(demand) = next_demand(&mut state, self.settings) {
1734 state
1735 .outbound
1736 .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
1737 }
1738 }
1739 self.notify_changed();
1740 }
1741 Ok(())
1742 }
1743
1744 fn start_direct_terminal(
1745 self: &Arc<Self>,
1746 consumer: Box<dyn TerminalSinkConsumerDyn<T>>,
1747 cancelled: Arc<std::sync::atomic::AtomicBool>,
1748 ) -> StreamResult<()> {
1749 let mut finish = None;
1750 {
1751 let mut state = self.lock_state();
1752 if state.source_taken {
1753 return Err(StreamError::Failed(
1754 "stream ref source has already been materialized".to_owned(),
1755 ));
1756 }
1757 state.source_taken = true;
1758 state.direct_terminal = true;
1759 state.direct_cancelled = Some(cancelled);
1760 state.direct_consumer = Some(consumer);
1761 if let Some(terminal) = state.terminal.clone() {
1762 finish = state
1763 .direct_consumer
1764 .take()
1765 .map(|consumer| (consumer, terminal_result(terminal)));
1766 } else if state.done {
1767 finish = state
1768 .direct_consumer
1769 .take()
1770 .map(|consumer| (consumer, Err(StreamError::AbruptTermination)));
1771 } else if !state.subscribed {
1772 state.subscribed = true;
1773 state
1774 .outbound
1775 .push_back(StreamRefMessage::OnSubscribeHandshake);
1776 if let Some(demand) = next_demand(&mut state, self.settings) {
1777 state
1778 .outbound
1779 .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
1780 }
1781 }
1782 self.notify_changed();
1783 }
1784 if let Some((consumer, result)) = finish {
1785 consumer.finish(result);
1786 }
1787 Ok(())
1788 }
1789
1790 fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
1791 loop {
1792 let mut state = self.lock_state();
1793 if let Some(message) = state.outbound.pop_front() {
1794 let finish_after_ack = message.is_ack() && state.finish_after_outbound_ack;
1795 if finish_after_ack {
1796 state.done = true;
1797 }
1798 drop(state);
1799 return Some(Ok(self.frame(message)));
1800 }
1801 if state.done {
1802 return None;
1803 }
1804 let next = wait_unpoison(&self.changed, state);
1805 drop(next);
1806 }
1807 }
1808
1809 fn try_next_frame(&self) -> StreamRefOutboundPoll {
1810 let mut state = self.lock_state();
1811 if let Some(message) = state.outbound.pop_front() {
1812 let finish_after_ack = message.is_ack() && state.finish_after_outbound_ack;
1813 if finish_after_ack {
1814 state.done = true;
1815 }
1816 drop(state);
1817 StreamRefOutboundPoll::Ready(Ok(StreamRefOutbound::Frame(self.frame(message))))
1818 } else if state.done {
1819 StreamRefOutboundPoll::Closed
1820 } else {
1821 StreamRefOutboundPoll::Pending
1822 }
1823 }
1824
1825 fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
1826 self.validate_frame_id(frame.stream_ref_id)?;
1827 match frame.message {
1828 StreamRefMessage::OnSubscribeHandshake => Ok(()),
1829 StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
1830 let item = T::decode_stream_ref_payload(payload.bytes)?;
1831 let demand_ceiling = self.lock_state().cumulative_demand;
1832 self.handle_decoded_on_next(seq_nr, item, demand_ceiling)
1833 }
1834 StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
1835 let finish = self.handle_remote_completed(seq_nr);
1836 if let Some((consumer, result)) = finish {
1837 consumer.finish(result);
1838 }
1839 Ok(())
1840 }
1841 StreamRefMessage::RemoteStreamFailure { cause } => {
1842 let error = StreamError::Failed(String::from_utf8_lossy(&cause).into_owned());
1843 let finish = self.handle_remote_failure(error);
1844 if let Some((consumer, result)) = finish {
1845 consumer.finish(result);
1846 }
1847 Ok(())
1848 }
1849 StreamRefMessage::Ack => {
1850 let mut state = self.lock_state();
1851 if state.waiting_cancel_ack {
1852 state.waiting_cancel_ack = false;
1853 state.done = true;
1854 }
1855 self.notify_changed();
1856 drop(state);
1857 Ok(())
1858 }
1859 StreamRefMessage::CumulativeDemand { .. } => Err(StreamError::Failed(
1860 "consumer endpoint cannot receive CumulativeDemand".to_owned(),
1861 )),
1862 }
1863 }
1864
1865 fn handle_sequenced_on_next_batch(
1866 &self,
1867 stream_ref_id: StreamRefId,
1868 first_seq_nr: u64,
1869 payloads: &[&[u8]],
1870 ) -> StreamResult<()> {
1871 self.validate_frame_id(stream_ref_id)?;
1872 if self.lock_state().direct_terminal {
1873 return self.handle_direct_sequenced_on_next_batch(first_seq_nr, payloads);
1874 }
1875 let mut state = self.lock_state();
1876 for (index, payload) in payloads.iter().enumerate() {
1877 if state.terminal.is_some() || state.done {
1878 break;
1879 }
1880 let seq_nr = first_seq_nr.checked_add(index as u64).ok_or_else(|| {
1881 StreamError::Failed("stream ref batch sequence overflow".to_owned())
1882 })?;
1883 let item = T::decode_stream_ref_payload_slice(payload)?;
1884 self.handle_decoded_on_next_locked(&mut state, seq_nr, item);
1885 }
1886 self.notify_changed();
1887 drop(state);
1888 Ok(())
1889 }
1890
1891 fn handle_direct_sequenced_on_next_batch(
1892 &self,
1893 first_seq_nr: u64,
1894 payloads: &[&[u8]],
1895 ) -> StreamResult<()> {
1896 if payloads.is_empty() {
1897 return Ok(());
1898 }
1899 let count = payloads.len() as u64;
1900 let last_seq_nr = first_seq_nr
1901 .checked_add(count - 1)
1902 .ok_or_else(|| StreamError::Failed("stream ref batch sequence overflow".to_owned()))?;
1903
1904 let mut consumer = {
1905 let mut state = self.lock_state();
1906 if state.terminal.is_some() || state.done {
1907 return Ok(());
1908 }
1909 if state
1910 .direct_cancelled
1911 .as_ref()
1912 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
1913 {
1914 let error = StreamError::Cancelled;
1915 self.fail_consumer_locked(&mut state, error.clone());
1916 let consumer = state.direct_consumer.take();
1917 self.notify_changed();
1918 drop(state);
1919 if let Some(consumer) = consumer {
1920 consumer.finish(Err(error));
1921 }
1922 return Ok(());
1923 }
1924
1925 let error = if first_seq_nr != state.expected_seq {
1926 Some(invalid_sequence_error(
1927 state.expected_seq,
1928 first_seq_nr,
1929 "stream ref element",
1930 ))
1931 } else if last_seq_nr >= state.cumulative_demand {
1932 Some(StreamError::Failed(
1933 "stream ref receive buffer overflowed demand window".to_owned(),
1934 ))
1935 } else {
1936 None
1937 };
1938 if let Some(error) = error {
1939 self.fail_consumer_locked(&mut state, error.clone());
1940 let consumer = state.direct_consumer.take();
1941 self.notify_changed();
1942 drop(state);
1943 if let Some(consumer) = consumer {
1944 consumer.finish(Err(error));
1945 }
1946 return Ok(());
1947 }
1948
1949 state.expected_seq = state.expected_seq.saturating_add(count);
1950 match state.direct_consumer.take() {
1951 Some(consumer) => consumer,
1952 None => return Ok(()),
1953 }
1954 };
1955
1956 let mut consumed = 0_u64;
1957 let mut consume_result = Ok(());
1958 for payload in payloads {
1959 match T::decode_stream_ref_payload_slice(payload)
1960 .and_then(|item| consumer.on_item(item))
1961 {
1962 Ok(()) => consumed = consumed.saturating_add(1),
1963 Err(error) => {
1964 consume_result = Err(error);
1965 break;
1966 }
1967 }
1968 }
1969
1970 let finish = {
1971 let mut state = self.lock_state();
1972 if let Some(terminal) = state.terminal.clone() {
1973 Some((consumer, terminal_result(terminal)))
1974 } else if state.done {
1975 Some((consumer, Err(StreamError::AbruptTermination)))
1976 } else {
1977 match consume_result {
1978 Ok(()) => {
1979 debug_assert_eq!(consumed, count);
1980 for _ in 0..count {
1981 state.delivered = state.delivered.saturating_add(1);
1982 if let Some(demand) = next_demand(&mut state, self.settings) {
1983 state
1984 .outbound
1985 .push_back(StreamRefMessage::CumulativeDemand {
1986 seq_nr: demand,
1987 });
1988 }
1989 }
1990 state.direct_consumer = Some(consumer);
1991 None
1992 }
1993 Err(error) => {
1994 self.fail_consumer_locked(&mut state, error.clone());
1995 Some((consumer, Err(error)))
1996 }
1997 }
1998 }
1999 };
2000 self.notify_changed();
2001 if let Some((consumer, result)) = finish {
2002 consumer.finish(result);
2003 }
2004 Ok(())
2005 }
2006
2007 fn handle_decoded_on_next(
2008 &self,
2009 seq_nr: u64,
2010 item: T,
2011 demand_ceiling: u64,
2012 ) -> StreamResult<()> {
2013 if self.lock_state().direct_terminal {
2014 self.handle_direct_decoded_on_next(seq_nr, item, demand_ceiling);
2015 } else {
2016 let mut state = self.lock_state();
2017 self.handle_decoded_on_next_locked(&mut state, seq_nr, item);
2018 self.notify_changed();
2019 drop(state);
2020 }
2021 Ok(())
2022 }
2023
2024 fn handle_decoded_on_next_locked(&self, state: &mut ConsumerState<T>, seq_nr: u64, item: T) {
2025 if state.terminal.is_some() || state.done {
2026 return;
2027 }
2028 if seq_nr != state.expected_seq {
2029 let error = invalid_sequence_error(state.expected_seq, seq_nr, "stream ref element");
2030 self.fail_consumer_locked(state, error);
2031 } else if state.queue.len() >= self.settings.buffer_capacity() {
2032 self.fail_consumer_locked(
2033 state,
2034 StreamError::Failed(
2035 "stream ref receive buffer overflowed demand window".to_owned(),
2036 ),
2037 );
2038 } else {
2039 state.expected_seq = state.expected_seq.saturating_add(1);
2040 state.queue.push_back(item);
2041 }
2042 }
2043
2044 fn handle_direct_decoded_on_next(&self, seq_nr: u64, item: T, demand_ceiling: u64) {
2045 let mut consumer = {
2046 let mut state = self.lock_state();
2047 if state.terminal.is_some() || state.done {
2048 return;
2049 }
2050 if state
2051 .direct_cancelled
2052 .as_ref()
2053 .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
2054 {
2055 let error = StreamError::Cancelled;
2056 self.fail_consumer_locked(&mut state, error.clone());
2057 let consumer = state.direct_consumer.take();
2058 self.notify_changed();
2059 drop(state);
2060 if let Some(consumer) = consumer {
2061 consumer.finish(Err(error));
2062 }
2063 return;
2064 }
2065 let error = if seq_nr != state.expected_seq {
2066 Some(invalid_sequence_error(
2067 state.expected_seq,
2068 seq_nr,
2069 "stream ref element",
2070 ))
2071 } else if seq_nr >= demand_ceiling {
2072 Some(StreamError::Failed(
2073 "stream ref receive buffer overflowed demand window".to_owned(),
2074 ))
2075 } else {
2076 None
2077 };
2078 if let Some(error) = error {
2079 self.fail_consumer_locked(&mut state, error.clone());
2080 let consumer = state.direct_consumer.take();
2081 self.notify_changed();
2082 drop(state);
2083 if let Some(consumer) = consumer {
2084 consumer.finish(Err(error));
2085 }
2086 return;
2087 }
2088 state.expected_seq = state.expected_seq.saturating_add(1);
2089 match state.direct_consumer.take() {
2090 Some(consumer) => consumer,
2091 None => return,
2092 }
2093 };
2094
2095 let consume_result = consumer.on_item(item);
2096 let finish = {
2097 let mut state = self.lock_state();
2098 if let Some(terminal) = state.terminal.clone() {
2099 Some((consumer, terminal_result(terminal)))
2100 } else if state.done {
2101 Some((consumer, Err(StreamError::AbruptTermination)))
2102 } else {
2103 match consume_result {
2104 Ok(()) => {
2105 state.delivered = state.delivered.saturating_add(1);
2106 if let Some(demand) = next_demand(&mut state, self.settings) {
2107 state
2108 .outbound
2109 .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
2110 }
2111 state.direct_consumer = Some(consumer);
2112 None
2113 }
2114 Err(error) => {
2115 self.fail_consumer_locked(&mut state, error.clone());
2116 Some((consumer, Err(error)))
2117 }
2118 }
2119 }
2120 };
2121 self.notify_changed();
2122 if let Some((consumer, result)) = finish {
2123 consumer.finish(result);
2124 }
2125 }
2126
2127 fn handle_remote_completed(
2128 &self,
2129 seq_nr: u64,
2130 ) -> Option<(Box<dyn TerminalSinkConsumerDyn<T>>, StreamResult<()>)> {
2131 let mut finish = None;
2132 let mut state = self.lock_state();
2133 if state.terminal.is_none() && !state.done {
2134 if seq_nr != state.expected_seq {
2135 state.queue.clear();
2136 let error =
2137 invalid_sequence_error(state.expected_seq, seq_nr, "stream ref completion");
2138 state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2139 finish = state
2140 .direct_consumer
2141 .take()
2142 .map(|consumer| (consumer, Err(error)));
2143 } else {
2144 state.terminal = Some(ConsumerTerminal::Complete);
2145 finish = state
2146 .direct_consumer
2147 .take()
2148 .map(|consumer| (consumer, Ok(())));
2149 }
2150 state.outbound.push_back(StreamRefMessage::Ack);
2151 state.finish_after_outbound_ack = true;
2152 }
2153 self.notify_changed();
2154 drop(state);
2155 finish
2156 }
2157
2158 fn handle_remote_failure(
2159 &self,
2160 error: StreamError,
2161 ) -> Option<(Box<dyn TerminalSinkConsumerDyn<T>>, StreamResult<()>)> {
2162 let mut state = self.lock_state();
2163 let mut finish = None;
2164 if state.terminal.is_none() && !state.done {
2165 state.queue.clear();
2166 state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2167 finish = state
2168 .direct_consumer
2169 .take()
2170 .map(|consumer| (consumer, Err(error)));
2171 state.outbound.push_back(StreamRefMessage::Ack);
2172 state.finish_after_outbound_ack = true;
2173 }
2174 self.notify_changed();
2175 drop(state);
2176 finish
2177 }
2178
2179 fn fail_consumer_locked(&self, state: &mut ConsumerState<T>, error: StreamError) {
2180 state.queue.clear();
2181 state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2182 state
2183 .outbound
2184 .push_back(StreamRefMessage::RemoteStreamFailure {
2185 cause: failure_cause(&error),
2186 });
2187 state.waiting_cancel_ack = true;
2188 }
2189
2190 fn cancel_from_downstream(&self) {
2191 let mut finish = None;
2192 let mut state = self.lock_state();
2193 if state.terminal.is_none() && !state.done {
2194 let seq_nr = state.expected_seq;
2195 let error = StreamError::Cancelled;
2196 state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2197 finish = state
2198 .direct_consumer
2199 .take()
2200 .map(|consumer| (consumer, Err(error)));
2201 state
2202 .outbound
2203 .push_back(StreamRefMessage::RemoteStreamCompleted { seq_nr });
2204 state.waiting_cancel_ack = true;
2205 }
2206 self.notify_changed();
2207 drop(state);
2208 if let Some((consumer, result)) = finish {
2209 consumer.finish(result);
2210 }
2211 }
2212
2213 fn fail_connection(&self, error: StreamError) {
2214 let mut finish = None;
2215 let mut state = self.lock_state();
2216 if state.terminal.is_none() {
2217 state.queue.clear();
2218 state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2219 finish = state
2220 .direct_consumer
2221 .take()
2222 .map(|consumer| (consumer, Err(error.clone())));
2223 }
2224 state.done = true;
2225 self.notify_changed();
2226 drop(state);
2227 if let Some((consumer, result)) = finish {
2228 consumer.finish(result);
2229 }
2230 }
2231
2232 fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
2233 if stream_ref_id == self.stream_ref_id {
2234 Ok(())
2235 } else {
2236 Err(StreamError::Failed(format!(
2237 "stream ref id mismatch: expected {}, got {}",
2238 self.stream_ref_id, stream_ref_id
2239 )))
2240 }
2241 }
2242}
2243
2244struct ConsumerStream<T>
2245where
2246 T: StreamRefPayload,
2247{
2248 shared: Arc<ConsumerShared<T>>,
2249 terminated: bool,
2250}
2251
impl<T> ConsumerStream<T>
where
    T: StreamRefPayload,
{
    /// Blocks until the next queued element or the terminal outcome is available.
    ///
    /// Whatever is in the queue is returned before the terminal is reported.
    /// Each delivered element may top up the demand sent to the producer. After
    /// the terminal has been returned (`Ok(None)` or `Err`), further calls
    /// return `Ok(None)`.
    fn next_item(&mut self) -> StreamResult<Option<T>> {
        if self.terminated {
            return Ok(None);
        }
        loop {
            let mut state = self.shared.lock_state();
            if let Some(item) = state.queue.pop_front() {
                state.delivered = state.delivered.saturating_add(1);
                if let Some(demand) = next_demand(&mut state, self.shared.settings) {
                    state
                        .outbound
                        .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
                    // Wake the outbound driver so the new demand is sent.
                    self.shared.notify_changed();
                }
                return Ok(Some(item));
            }

            if let Some(terminal) = state.terminal.clone() {
                self.terminated = true;
                return match terminal {
                    ConsumerTerminal::Complete => Ok(None),
                    ConsumerTerminal::Error(error) => Err(error),
                };
            }

            // Nothing available yet: sleep until the shared state changes.
            let next = wait_unpoison(&self.shared.changed, state);
            drop(next);
        }
    }

    /// Cancels the remote side unless the terminal was already observed; idempotent.
    fn close(&mut self) {
        if !self.terminated {
            self.shared.cancel_from_downstream();
            self.terminated = true;
        }
    }
}
2293
2294impl<T> Iterator for ConsumerStream<T>
2295where
2296 T: StreamRefPayload,
2297{
2298 type Item = StreamResult<T>;
2299
2300 fn next(&mut self) -> Option<Self::Item> {
2301 match self.next_item() {
2302 Ok(Some(item)) => Some(Ok(item)),
2303 Ok(None) => None,
2304 Err(error) => Some(Err(error)),
2305 }
2306 }
2307}
2308
2309impl<T> Drop for ConsumerStream<T>
2310where
2311 T: StreamRefPayload,
2312{
2313 fn drop(&mut self) {
2314 self.close();
2315 }
2316}
2317
2318fn next_demand<T>(state: &mut ConsumerState<T>, settings: StreamRefSettings) -> Option<u64> {
2319 if state.terminal.is_some() {
2323 return None;
2324 }
2325 let remaining_credit = state.cumulative_demand.saturating_sub(state.delivered);
2326 if state.cumulative_demand != 0 && remaining_credit > demand_replenish_threshold(settings) {
2327 return None;
2328 }
2329 let target = state
2330 .delivered
2331 .saturating_add(settings.buffer_capacity() as u64);
2332 if state.cumulative_demand >= target {
2333 return None;
2334 }
2335 state.cumulative_demand = target;
2336 Some(target)
2337}
2338
2339fn demand_replenish_threshold(settings: StreamRefSettings) -> u64 {
2340 (settings.buffer_capacity() as u64) / 2
2341}
2342
2343fn terminal_result(terminal: ConsumerTerminal) -> StreamResult<()> {
2344 match terminal {
2345 ConsumerTerminal::Complete => Ok(()),
2346 ConsumerTerminal::Error(error) => Err(error),
2347 }
2348}
2349
2350fn failure_cause(error: &StreamError) -> Vec<u8> {
2351 match error {
2352 StreamError::Failed(message) => message.clone().into_bytes(),
2353 other => other.to_string().into_bytes(),
2354 }
2355}
2356
2357fn subscription_timeout_error(side: &str) -> StreamError {
2358 StreamError::Failed(format!(
2359 "{side} remote side did not subscribe within subscription timeout"
2360 ))
2361}
2362
2363fn invalid_sequence_error(expected: u64, got: u64, context: &str) -> StreamError {
2364 StreamError::Failed(format!(
2365 "{context} sequence gap: expected sequence {expected}, got {got}"
2366 ))
2367}
2368
2369fn deadline_from_now(timeout: Duration) -> Instant {
2370 Instant::now()
2371 .checked_add(timeout)
2372 .unwrap_or_else(far_future)
2373}
2374
/// A practically unreachable deadline: about one (365-day) year from now.
fn far_future() -> Instant {
    const ONE_YEAR: Duration = Duration::from_secs(365 * 24 * 60 * 60);
    Instant::now() + ONE_YEAR
}
2378
/// `Condvar::wait_timeout` that treats a poisoned mutex as still usable.
///
/// Returns the reacquired guard together with the timeout result.
fn wait_timeout_unpoison<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    match condvar.wait_timeout(guard, timeout) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
2388
/// `Condvar::wait` that treats a poisoned mutex as still usable.
///
/// Returns the reacquired guard. Spurious wakeups are possible, so callers
/// loop on their condition.
fn wait_unpoison<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
2394
2395#[cfg(test)]
2396mod tests {
2397 use std::time::Duration;
2398
2399 use super::*;
2400 use crate::{Sink, Source, StreamRefs};
2401
2402 fn short_settings() -> StreamRefSettings {
2403 StreamRefSettings::default()
2404 .with_buffer_capacity(4)
2405 .with_subscription_timeout(Duration::from_millis(50))
2406 }
2407
2408 #[test]
2409 fn protobuf_frame_round_trip() {
2410 let frame = StreamRefFrame::new(
2411 StreamRefId::from_u128(42),
2412 StreamRefMessage::SequencedOnNext {
2413 seq_nr: 7,
2414 payload: StreamRefPayloadBytes {
2415 bytes: 99_u64.encode_stream_ref_payload(),
2416 },
2417 },
2418 );
2419
2420 let decoded = StreamRefFrame::decode(&frame.encode_to_vec()).unwrap();
2421 assert_eq!(decoded, frame);
2422 }
2423
2424 #[test]
2425 fn producer_consumer_seam_streams_with_low_watermark_demand() {
2426 let id = StreamRefId::from_u128(1);
2427 let settings = short_settings();
2428 let source_ref = Source::from_iter(0_u64..10)
2429 .run_with(StreamRefs::source_ref_with_settings(settings))
2430 .unwrap();
2431 let producer = StreamRefProtoProducer::from_source_ref(source_ref, id, settings).unwrap();
2432 let consumer = StreamRefProtoConsumer::<u64>::new(id, settings);
2433 let consumer_source = consumer.source();
2434
2435 let producer_thread = std::thread::spawn({
2436 let producer = producer.clone();
2437 let consumer = consumer.clone();
2438 move || {
2439 while let Some(frame) = producer.next_frame() {
2440 consumer.handle_frame(frame?)?;
2441 }
2442 Ok::<_, StreamError>(())
2443 }
2444 });
2445 let consumer_thread = std::thread::spawn({
2446 let producer = producer.clone();
2447 let consumer = consumer.clone();
2448 move || {
2449 while let Some(frame) = consumer.next_frame() {
2450 producer.handle_frame(frame?)?;
2451 }
2452 Ok::<_, StreamError>(())
2453 }
2454 });
2455
2456 assert_eq!(
2457 consumer_source.run_collect().unwrap(),
2458 (0_u64..10).collect::<Vec<_>>()
2459 );
2460 producer_thread.join().unwrap().unwrap();
2461 consumer_thread.join().unwrap().unwrap();
2462 }
2463
2464 #[test]
2465 fn strict_sequence_gap_fails_consumer_and_sends_failure() {
2466 let id = StreamRefId::from_u128(2);
2467 let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
2468 let source = consumer
2469 .source()
2470 .run_with(crate::testkit::TestSink::probe())
2471 .unwrap();
2472 source.request(1);
2473 consumer.next_frame().unwrap().unwrap();
2474 consumer.next_frame().unwrap().unwrap();
2475
2476 consumer
2477 .handle_frame(StreamRefFrame::new(
2478 id,
2479 StreamRefMessage::SequencedOnNext {
2480 seq_nr: 1,
2481 payload: StreamRefPayloadBytes {
2482 bytes: 1_u64.encode_stream_ref_payload(),
2483 },
2484 },
2485 ))
2486 .unwrap();
2487
2488 let outbound = consumer.next_frame().unwrap().unwrap();
2489 assert!(matches!(
2490 outbound.message,
2491 StreamRefMessage::RemoteStreamFailure { .. }
2492 ));
2493 assert!(matches!(source.expect_error(), StreamError::Failed(_)));
2494 }
2495
2496 #[test]
2497 fn direct_terminal_sequence_gap_uses_shared_failure_path() {
2498 let id = StreamRefId::from_u128(23);
2499 let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
2500 let completion = consumer
2501 .source()
2502 .run_with(Sink::fold(0_u64, |acc, item| acc + item))
2503 .unwrap();
2504 consumer.next_frame().unwrap().unwrap();
2505 consumer.next_frame().unwrap().unwrap();
2506
2507 consumer
2508 .handle_frame(StreamRefFrame::new(
2509 id,
2510 StreamRefMessage::SequencedOnNext {
2511 seq_nr: 1,
2512 payload: StreamRefPayloadBytes {
2513 bytes: 1_u64.encode_stream_ref_payload(),
2514 },
2515 },
2516 ))
2517 .unwrap();
2518
2519 let outbound = consumer.next_frame().unwrap().unwrap();
2520 assert!(matches!(
2521 outbound.message,
2522 StreamRefMessage::RemoteStreamFailure { .. }
2523 ));
2524 assert!(
2525 matches!(completion.wait(), Err(StreamError::Failed(message)) if message.contains("sequence gap"))
2526 );
2527 }
2528
2529 #[test]
2530 fn direct_terminal_batch_over_demand_ceiling_fails_consumer() {
2531 let id = StreamRefId::from_u128(24);
2532 let settings = short_settings();
2533 let consumer = StreamRefProtoConsumer::<u64>::new(id, settings);
2534 let completion = consumer
2535 .source()
2536 .run_with(Sink::fold(0_u64, |acc, item| acc + item))
2537 .unwrap();
2538 consumer.next_frame().unwrap().unwrap();
2539 consumer.next_frame().unwrap().unwrap();
2540
2541 let payloads = (0_u64..=settings.buffer_capacity() as u64)
2542 .map(u64::encode_stream_ref_payload)
2543 .collect::<Vec<_>>();
2544 let payload_slices = payloads.iter().map(Vec::as_slice).collect::<Vec<&[u8]>>();
2545 consumer
2546 .handle_sequenced_on_next_batch(id, 0, &payload_slices)
2547 .unwrap();
2548
2549 let mut saw_failure = false;
2550 for _ in 0..4 {
2551 let Some(outbound) = consumer.next_frame() else {
2552 break;
2553 };
2554 if matches!(
2555 outbound.unwrap().message,
2556 StreamRefMessage::RemoteStreamFailure { .. }
2557 ) {
2558 saw_failure = true;
2559 break;
2560 }
2561 }
2562 assert!(saw_failure);
2563 assert!(
2564 matches!(completion.wait(), Err(StreamError::Failed(message)) if message.contains("demand window"))
2565 );
2566 }
2567
2568 #[test]
2569 fn producer_batches_ready_elements_and_preserves_completion_order() {
2570 let id = StreamRefId::from_u128(20);
2571 let settings = StreamRefSettings::default().with_buffer_capacity(8);
2572 let producer =
2573 StreamRefProtoProducer::from_source(Source::from_iter(0_u64..6), id, settings).unwrap();
2574 producer
2575 .handle_frame(StreamRefFrame::new(
2576 id,
2577 StreamRefMessage::CumulativeDemand { seq_nr: 8 },
2578 ))
2579 .unwrap();
2580
2581 let first = producer.next_outbound(4, usize::MAX).unwrap().unwrap();
2582 let StreamRefOutbound::SequencedBatch(first) = first else {
2583 panic!("expected first data batch");
2584 };
2585 assert_eq!(first.first_seq_nr(), 0);
2586 assert_eq!(first.count(), 4);
2587 for index in 0..first.count() {
2588 assert_eq!(
2589 u64::decode_stream_ref_payload_slice(first.payload(index)).unwrap(),
2590 index as u64
2591 );
2592 }
2593
2594 let second = producer.next_outbound(4, usize::MAX).unwrap().unwrap();
2595 let StreamRefOutbound::SequencedBatch(second) = second else {
2596 panic!("expected second data batch");
2597 };
2598 assert_eq!(second.first_seq_nr(), 4);
2599 assert_eq!(second.count(), 2);
2600
2601 let completion = producer.next_outbound(4, usize::MAX).unwrap().unwrap();
2602 assert!(matches!(
2603 completion,
2604 StreamRefOutbound::Frame(StreamRefFrame {
2605 message: StreamRefMessage::RemoteStreamCompleted { seq_nr: 6 },
2606 ..
2607 })
2608 ));
2609 }
2610
2611 #[test]
2612 fn consumer_batch_ingress_preserves_order() {
2613 let id = StreamRefId::from_u128(21);
2614 let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
2615 let probe = consumer
2616 .source()
2617 .run_with(crate::testkit::TestSink::probe())
2618 .unwrap();
2619 probe.request(3);
2620 consumer.next_frame().unwrap().unwrap();
2621 consumer.next_frame().unwrap().unwrap();
2622
2623 let payloads = [10_u64, 11, 12]
2624 .into_iter()
2625 .map(u64::encode_stream_ref_payload)
2626 .collect::<Vec<_>>();
2627 let payload_slices = payloads.iter().map(Vec::as_slice).collect::<Vec<&[u8]>>();
2628 consumer
2629 .handle_sequenced_on_next_batch(id, 0, &payload_slices)
2630 .unwrap();
2631
2632 assert_eq!(probe.expect_next(), 10);
2633 assert_eq!(probe.expect_next(), 11);
2634 assert_eq!(probe.expect_next(), 12);
2635 }
2636
2637 #[test]
2638 fn consumer_batch_sequence_gap_uses_shared_failure_path() {
2639 let id = StreamRefId::from_u128(22);
2640 let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
2641 let source = consumer
2642 .source()
2643 .run_with(crate::testkit::TestSink::probe())
2644 .unwrap();
2645 source.request(1);
2646 consumer.next_frame().unwrap().unwrap();
2647 consumer.next_frame().unwrap().unwrap();
2648
2649 let payload = 1_u64.encode_stream_ref_payload();
2650 consumer
2651 .handle_sequenced_on_next_batch(id, 1, &[payload.as_slice()])
2652 .unwrap();
2653
2654 let outbound = consumer.next_frame().unwrap().unwrap();
2655 assert!(matches!(
2656 outbound.message,
2657 StreamRefMessage::RemoteStreamFailure { .. }
2658 ));
2659 assert!(matches!(source.expect_error(), StreamError::Failed(_)));
2660 }
2661
2662 #[test]
2663 fn producer_times_out_without_first_demand() {
2664 let producer = StreamRefProtoProducer::from_source(
2665 Source::repeat(1_u64),
2666 StreamRefId::from_u128(3),
2667 short_settings(),
2668 )
2669 .unwrap();
2670
2671 let error = producer.next_frame().unwrap().unwrap_err();
2672 assert!(matches!(error, StreamError::Failed(message) if message.contains("first demand")));
2673 }
2674
2675 #[test]
2676 fn demand_redelivery_is_not_required_by_reliable_carriers() {
2677 assert_eq!(
2682 StreamRefSettings::default().demand_redelivery_interval(),
2683 Duration::from_secs(1)
2684 );
2685 }
2686}