Skip to main content

datum/actor/
stream_ref_proto.rs

1//! Protobuf StreamRefs protocol and transport-agnostic endpoint drivers.
2//!
3//! This module mirrors Akka's StreamRefs protocol without embedding an actor
4//! transport. A carrier feeds inbound protobuf frames into one endpoint and
5//! drains outbound frames from the same endpoint. `datum-net` uses this seam to
6//! carry StreamRefs over a reliable, ordered QUIC bidirectional stream.
7
8use std::{
9    any::Any,
10    collections::VecDeque,
11    fmt,
12    sync::{
13        Arc, Condvar, Mutex, MutexGuard,
14        atomic::{AtomicU64, Ordering},
15    },
16    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
17};
18
19use crate::stream::{
20    BoxStream, Materializer, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult,
21    TerminalSinkConsumerDyn, TerminalSourceHookDyn, TerminalSourceStatus,
22};
23use futures::channel::oneshot;
24use prost::Message as ProstMessage;
25use tokio::sync::mpsc as tokio_mpsc;
26
27use super::{SourceRef, StreamRefSettings};
28
/// Process-wide counter supplying the sequence component of
/// [`StreamRefId::new`]. Starts at 1.
static STREAM_REF_PROTO_ID: AtomicU64 = AtomicU64::new(1);
30
31/// Element payload codec used by the protobuf StreamRefs transport seam.
32///
33/// The built-in impls use the same big-endian primitive encodings as Ractor's
34/// `BytesConvertable`, but the trait is owned by `datum-core` so the protobuf
35/// seam does not depend on the `cluster` feature.
36pub trait StreamRefPayload: Send + 'static {
37    fn encode_stream_ref_payload(self) -> Vec<u8>;
38
39    fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>)
40    where
41        Self: Sized,
42    {
43        bytes.extend(self.encode_stream_ref_payload());
44    }
45
46    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self>
47    where
48        Self: Sized;
49
50    fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self>
51    where
52        Self: Sized,
53    {
54        Self::decode_stream_ref_payload(bytes.to_vec())
55    }
56}
57
58macro_rules! impl_stream_ref_payload_numeric {
59    ($($ty:ty),* $(,)?) => {
60        $(
61            impl StreamRefPayload for $ty {
62                fn encode_stream_ref_payload(self) -> Vec<u8> {
63                    self.to_be_bytes().to_vec()
64                }
65
66                fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>) {
67                    bytes.extend(self.to_be_bytes());
68                }
69
70                fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
71                    let data: [u8; std::mem::size_of::<Self>()] =
72                        bytes.as_slice().try_into().map_err(|_| {
73                            StreamError::Failed(format!(
74                                "invalid {} stream ref payload length: {}",
75                                stringify!($ty),
76                                bytes.len()
77                            ))
78                        })?;
79                    Ok(Self::from_be_bytes(data))
80                }
81
82                fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self> {
83                    let data: [u8; std::mem::size_of::<Self>()] =
84                        bytes.try_into().map_err(|_| {
85                            StreamError::Failed(format!(
86                                "invalid {} stream ref payload length: {}",
87                                stringify!($ty),
88                                bytes.len()
89                            ))
90                        })?;
91                    Ok(Self::from_be_bytes(data))
92                }
93            }
94        )*
95    };
96}
97
98impl_stream_ref_payload_numeric!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, f32, f64);
99
100impl StreamRefPayload for bool {
101    fn encode_stream_ref_payload(self) -> Vec<u8> {
102        vec![u8::from(self)]
103    }
104
105    fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>) {
106        bytes.push(u8::from(self));
107    }
108
109    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
110        match bytes.as_slice() {
111            [0] => Ok(false),
112            [1] => Ok(true),
113            _ => Err(StreamError::Failed(
114                "invalid bool stream ref payload".to_owned(),
115            )),
116        }
117    }
118
119    fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self> {
120        match bytes {
121            [0] => Ok(false),
122            [1] => Ok(true),
123            _ => Err(StreamError::Failed(
124                "invalid bool stream ref payload".to_owned(),
125            )),
126        }
127    }
128}
129
130impl StreamRefPayload for String {
131    fn encode_stream_ref_payload(self) -> Vec<u8> {
132        self.into_bytes()
133    }
134
135    fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>) {
136        bytes.extend(self.into_bytes());
137    }
138
139    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
140        String::from_utf8(bytes)
141            .map_err(|error| StreamError::Failed(format!("invalid UTF-8 payload: {error}")))
142    }
143
144    fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self> {
145        String::from_utf8(bytes.to_vec())
146            .map_err(|error| StreamError::Failed(format!("invalid UTF-8 payload: {error}")))
147    }
148}
149
150impl StreamRefPayload for Vec<u8> {
151    fn encode_stream_ref_payload(self) -> Vec<u8> {
152        self
153    }
154
155    fn encode_stream_ref_payload_into(self, bytes: &mut Vec<u8>) {
156        bytes.extend(self);
157    }
158
159    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
160        Ok(bytes)
161    }
162
163    fn decode_stream_ref_payload_slice(bytes: &[u8]) -> StreamResult<Self> {
164        Ok(bytes.to_vec())
165    }
166}
167
168/// Stream-ref identifier scoped to one transport connection.
169#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
170pub struct StreamRefId(u128);
171
172impl StreamRefId {
173    /// Generates a process-local id suitable for a new connection-scoped ref.
174    #[must_use]
175    pub fn new() -> Self {
176        let sequence = STREAM_REF_PROTO_ID.fetch_add(1, Ordering::Relaxed) as u128;
177        let timestamp = SystemTime::now()
178            .duration_since(UNIX_EPOCH)
179            .map(|duration| duration.as_nanos())
180            .unwrap_or_default();
181        let pid = std::process::id() as u128;
182        Self(timestamp ^ (pid << 32) ^ sequence)
183    }
184
185    /// Constructs an id from a stable numeric value, primarily for tests and
186    /// single-stream carriers that reserve a well-known id.
187    #[must_use]
188    pub const fn from_u128(value: u128) -> Self {
189        Self(value)
190    }
191
192    #[must_use]
193    pub const fn as_u128(self) -> u128 {
194        self.0
195    }
196
197    #[must_use]
198    pub fn to_bytes(self) -> [u8; 16] {
199        self.0.to_be_bytes()
200    }
201
202    pub fn from_bytes(bytes: &[u8]) -> StreamResult<Self> {
203        let value: [u8; 16] = bytes.try_into().map_err(|_| {
204            StreamError::Failed("stream ref id must be exactly 16 bytes".to_owned())
205        })?;
206        Ok(Self(u128::from_be_bytes(value)))
207    }
208}
209
210impl Default for StreamRefId {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216impl fmt::Display for StreamRefId {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        write!(f, "{:032x}", self.0)
219    }
220}
221
/// Protobuf payload wrapper matching Akka's `Payload`, minus Akka's serializer
/// ids and manifests. It carries only the encoded element bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamRefPayloadBytes {
    /// One element as produced by a [`StreamRefPayload`] encoder.
    pub bytes: Vec<u8>,
}

/// Transport-agnostic StreamRefs protocol messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamRefMessage {
    /// Opens the subscription between the two endpoints.
    OnSubscribeHandshake,
    /// Total demand signalled by the consumer so far.
    ///
    /// The producer compares its sent count against this value.
    CumulativeDemand {
        seq_nr: u64,
    },
    /// One data element together with its sequence number.
    SequencedOnNext {
        seq_nr: u64,
        payload: StreamRefPayloadBytes,
    },
    /// The remote stream completed normally.
    RemoteStreamCompleted {
        seq_nr: u64,
    },
    /// The remote stream failed. `cause` holds the description as bytes.
    RemoteStreamFailure {
        cause: Vec<u8>,
    },
    /// Acknowledges a terminal message.
    Ack,
}

impl StreamRefMessage {
    /// Returns the failure description of a `RemoteStreamFailure`.
    ///
    /// Invalid UTF-8 is replaced with U+FFFD. Every other variant yields `None`.
    #[must_use]
    pub fn failure_text(&self) -> Option<String> {
        let Self::RemoteStreamFailure { cause } = self else {
            return None;
        };
        Some(String::from_utf8_lossy(cause).into_owned())
    }

    /// Whether this is the unit `Ack` message.
    fn is_ack(&self) -> bool {
        *self == Self::Ack
    }
}
264
265/// One protobuf frame tagged with a connection-scoped stream-ref id.
266#[derive(Debug, Clone, PartialEq, Eq)]
267pub struct StreamRefFrame {
268    pub stream_ref_id: StreamRefId,
269    pub message: StreamRefMessage,
270}
271
272impl StreamRefFrame {
273    #[must_use]
274    pub fn new(stream_ref_id: StreamRefId, message: StreamRefMessage) -> Self {
275        Self {
276            stream_ref_id,
277            message,
278        }
279    }
280
281    #[must_use]
282    pub fn encode_to_vec(&self) -> Vec<u8> {
283        self.to_wire().encode_to_vec()
284    }
285
286    pub fn decode(bytes: &[u8]) -> StreamResult<Self> {
287        Self::from_wire(WireStreamRefFrame::decode(bytes).map_err(|error| {
288            StreamError::Failed(format!("invalid stream ref protobuf frame: {error}"))
289        })?)
290    }
291
292    fn to_wire(&self) -> WireStreamRefFrame {
293        WireStreamRefFrame {
294            stream_ref_id: self.stream_ref_id.to_bytes().to_vec(),
295            message: Some(match &self.message {
296                StreamRefMessage::OnSubscribeHandshake => {
297                    wire_stream_ref_frame::Message::OnSubscribeHandshake(
298                        WireOnSubscribeHandshake {},
299                    )
300                }
301                StreamRefMessage::CumulativeDemand { seq_nr } => {
302                    wire_stream_ref_frame::Message::CumulativeDemand(WireCumulativeDemand {
303                        seq_nr: *seq_nr,
304                    })
305                }
306                StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
307                    wire_stream_ref_frame::Message::SequencedOnNext(WireSequencedOnNext {
308                        seq_nr: *seq_nr,
309                        payload: Some(WirePayload {
310                            enclosed_message: payload.bytes.clone(),
311                        }),
312                    })
313                }
314                StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
315                    wire_stream_ref_frame::Message::RemoteStreamCompleted(
316                        WireRemoteStreamCompleted { seq_nr: *seq_nr },
317                    )
318                }
319                StreamRefMessage::RemoteStreamFailure { cause } => {
320                    wire_stream_ref_frame::Message::RemoteStreamFailure(WireRemoteStreamFailure {
321                        cause: cause.clone(),
322                    })
323                }
324                StreamRefMessage::Ack => wire_stream_ref_frame::Message::Ack(WireAck {}),
325            }),
326        }
327    }
328
329    fn from_wire(wire: WireStreamRefFrame) -> StreamResult<Self> {
330        let stream_ref_id = StreamRefId::from_bytes(&wire.stream_ref_id)?;
331        let message = match wire.message.ok_or_else(|| {
332            StreamError::Failed("stream ref protobuf frame has no message".to_owned())
333        })? {
334            wire_stream_ref_frame::Message::OnSubscribeHandshake(_) => {
335                StreamRefMessage::OnSubscribeHandshake
336            }
337            wire_stream_ref_frame::Message::CumulativeDemand(message) => {
338                StreamRefMessage::CumulativeDemand {
339                    seq_nr: message.seq_nr,
340                }
341            }
342            wire_stream_ref_frame::Message::SequencedOnNext(message) => {
343                let payload = message.payload.ok_or_else(|| {
344                    StreamError::Failed("SequencedOnNext missing payload".to_owned())
345                })?;
346                StreamRefMessage::SequencedOnNext {
347                    seq_nr: message.seq_nr,
348                    payload: StreamRefPayloadBytes {
349                        bytes: payload.enclosed_message,
350                    },
351                }
352            }
353            wire_stream_ref_frame::Message::RemoteStreamCompleted(message) => {
354                StreamRefMessage::RemoteStreamCompleted {
355                    seq_nr: message.seq_nr,
356                }
357            }
358            wire_stream_ref_frame::Message::RemoteStreamFailure(message) => {
359                StreamRefMessage::RemoteStreamFailure {
360                    cause: message.cause,
361                }
362            }
363            wire_stream_ref_frame::Message::Ack(_) => StreamRefMessage::Ack,
364        };
365        Ok(Self {
366            stream_ref_id,
367            message,
368        })
369    }
370}
371
/// Prost wire form of a `StreamRefFrame`: the 16-byte big-endian stream-ref id
/// (tag 1) plus exactly one protocol message carried in a oneof (tags 2-7).
#[derive(Clone, PartialEq, ProstMessage)]
struct WireStreamRefFrame {
    #[prost(bytes = "vec", tag = "1")]
    stream_ref_id: Vec<u8>,
    // `None` here is rejected when decoding ("stream ref protobuf frame has no message").
    #[prost(oneof = "wire_stream_ref_frame::Message", tags = "2, 3, 4, 5, 6, 7")]
    message: Option<wire_stream_ref_frame::Message>,
}

/// Oneof payload of `WireStreamRefFrame`; one variant per `StreamRefMessage` variant.
mod wire_stream_ref_frame {
    #[derive(Clone, PartialEq, prost::Oneof)]
    pub enum Message {
        #[prost(message, tag = "2")]
        OnSubscribeHandshake(super::WireOnSubscribeHandshake),
        #[prost(message, tag = "3")]
        CumulativeDemand(super::WireCumulativeDemand),
        #[prost(message, tag = "4")]
        SequencedOnNext(super::WireSequencedOnNext),
        #[prost(message, tag = "5")]
        RemoteStreamCompleted(super::WireRemoteStreamCompleted),
        #[prost(message, tag = "6")]
        RemoteStreamFailure(super::WireRemoteStreamFailure),
        #[prost(message, tag = "7")]
        Ack(super::WireAck),
    }
}

/// Counterpart of Akka's `Payload` message, reduced to the encoded element bytes.
#[derive(Clone, PartialEq, ProstMessage)]
struct WirePayload {
    #[prost(bytes = "vec", tag = "1")]
    enclosed_message: Vec<u8>,
}

/// Empty marker message for `StreamRefMessage::OnSubscribeHandshake`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireOnSubscribeHandshake {}

/// Wire form of `StreamRefMessage::CumulativeDemand`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireCumulativeDemand {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
}

/// Wire form of `StreamRefMessage::SequencedOnNext`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireSequencedOnNext {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
    // Optional on the wire, but a missing payload is rejected when decoding.
    #[prost(message, optional, tag = "2")]
    payload: Option<WirePayload>,
}

/// Wire form of `StreamRefMessage::RemoteStreamFailure`; `cause` is raw bytes.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireRemoteStreamFailure {
    #[prost(bytes = "vec", tag = "1")]
    cause: Vec<u8>,
}

/// Wire form of `StreamRefMessage::RemoteStreamCompleted`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireRemoteStreamCompleted {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
}

/// Empty marker message for `StreamRefMessage::Ack`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireAck {}
435
436#[derive(Debug, Clone)]
437struct StreamRefPayloadSegment {
438    offset: usize,
439    len: usize,
440}
441
442/// A contiguous batch of already-encoded `SequencedOnNext` payloads.
443#[derive(Debug, Clone)]
444pub struct StreamRefPayloadBatch {
445    stream_ref_id: StreamRefId,
446    first_seq_nr: u64,
447    payloads: Vec<u8>,
448    segments: Vec<StreamRefPayloadSegment>,
449}
450
451impl StreamRefPayloadBatch {
452    #[must_use]
453    pub fn new(stream_ref_id: StreamRefId, first_seq_nr: u64) -> Self {
454        Self {
455            stream_ref_id,
456            first_seq_nr,
457            payloads: Vec::new(),
458            segments: Vec::new(),
459        }
460    }
461
462    #[must_use]
463    pub fn stream_ref_id(&self) -> StreamRefId {
464        self.stream_ref_id
465    }
466
467    #[must_use]
468    pub fn first_seq_nr(&self) -> u64 {
469        self.first_seq_nr
470    }
471
472    #[must_use]
473    pub fn count(&self) -> usize {
474        self.segments.len()
475    }
476
477    #[must_use]
478    pub fn is_empty(&self) -> bool {
479        self.segments.is_empty()
480    }
481
482    #[must_use]
483    pub fn payload_len(&self, index: usize) -> usize {
484        self.segments[index].len
485    }
486
487    #[must_use]
488    pub fn payload(&self, index: usize) -> &[u8] {
489        let segment = &self.segments[index];
490        &self.payloads[segment.offset..segment.offset + segment.len]
491    }
492
493    pub fn push_payload<T>(&mut self, item: T) -> StreamResult<()>
494    where
495        T: StreamRefPayload,
496    {
497        let offset = self.payloads.len();
498        item.encode_stream_ref_payload_into(&mut self.payloads);
499        let len = self.payloads.len().saturating_sub(offset);
500        if len > u32::MAX as usize {
501            return Err(StreamError::LimitExceeded {
502                max: u32::MAX as u64,
503            });
504        }
505        self.segments.push(StreamRefPayloadSegment { offset, len });
506        Ok(())
507    }
508
509    fn into_single_frame(self) -> StreamResult<StreamRefFrame> {
510        if self.count() != 1 {
511            return Err(StreamError::Failed(
512                "stream ref batch cannot be converted into a single frame".to_owned(),
513            ));
514        }
515        Ok(StreamRefFrame::new(
516            self.stream_ref_id,
517            StreamRefMessage::SequencedOnNext {
518                seq_nr: self.first_seq_nr,
519                payload: StreamRefPayloadBytes {
520                    bytes: self.payload(0).to_vec(),
521                },
522            },
523        ))
524    }
525}
526
527/// An outbound protocol item. Control remains a protobuf frame; steady data can
528/// move as one compact batch without first inflating one frame per element.
529#[derive(Debug, Clone)]
530pub enum StreamRefOutbound {
531    Frame(StreamRefFrame),
532    SequencedBatch(StreamRefPayloadBatch),
533}
534
535impl StreamRefOutbound {
536    fn into_single_frame(self) -> StreamResult<StreamRefFrame> {
537        match self {
538            Self::Frame(frame) => Ok(frame),
539            Self::SequencedBatch(batch) => batch.into_single_frame(),
540        }
541    }
542}
543
544/// Common interface for external carriers that pump protobuf frames.
545pub trait StreamRefProtoEndpoint: Clone + Send + Sync + 'static {
546    fn stream_ref_id(&self) -> StreamRefId;
547    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>>;
548    fn next_outbound(
549        &self,
550        max_data_elements: usize,
551        _max_data_bytes: usize,
552    ) -> Option<StreamResult<StreamRefOutbound>> {
553        let _ = max_data_elements;
554        self.next_frame()
555            .map(|frame| frame.map(StreamRefOutbound::Frame))
556    }
557    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()>;
558    fn handle_sequenced_on_next_batch(
559        &self,
560        stream_ref_id: StreamRefId,
561        first_seq_nr: u64,
562        payloads: &[&[u8]],
563    ) -> StreamResult<()> {
564        for (index, payload) in payloads.iter().enumerate() {
565            let seq_nr = first_seq_nr.checked_add(index as u64).ok_or_else(|| {
566                StreamError::Failed("stream ref batch sequence overflow".to_owned())
567            })?;
568            self.handle_frame(StreamRefFrame::new(
569                stream_ref_id,
570                StreamRefMessage::SequencedOnNext {
571                    seq_nr,
572                    payload: StreamRefPayloadBytes {
573                        bytes: payload.to_vec(),
574                    },
575                },
576            ))?;
577        }
578        Ok(())
579    }
580    fn fail_connection(&self, error: StreamError);
581}
582
/// Immediate outbound poll result for non-blocking carriers.
#[doc(hidden)]
pub enum StreamRefOutboundPoll {
    /// An outbound item, or the error that ended the endpoint, is available now.
    Ready(StreamResult<StreamRefOutbound>),
    /// Nothing is ready yet. The carrier presumably waits for a nudge on the
    /// channel registered with `install_outbound_wake` before polling again.
    Pending,
    /// The endpoint has nothing more to send. This presumably mirrors
    /// `next_outbound` returning `None`; confirm against the implementors.
    Closed,
}

/// Internal bridge used by single-owner async carriers.
#[doc(hidden)]
pub trait StreamRefProtoEndpointWake: StreamRefProtoEndpoint {
    /// Registers a channel that receives a `()` nudge whenever the endpoint's
    /// outbound state may have changed.
    fn install_outbound_wake(&self, sender: tokio_mpsc::Sender<()>);
    /// Unregisters the wake channel so later state changes no longer signal it.
    fn clear_outbound_wake(&self);
    /// Non-blocking counterpart of `next_outbound`, with the same batching limits.
    fn try_next_outbound(
        &self,
        max_data_elements: usize,
        max_data_bytes: usize,
    ) -> StreamRefOutboundPoll;
}
602
603#[derive(Default)]
604struct OutboundWake {
605    sender: Mutex<Option<tokio_mpsc::Sender<()>>>,
606}
607
608impl OutboundWake {
609    fn install(&self, sender: tokio_mpsc::Sender<()>) {
610        *self
611            .sender
612            .lock()
613            .unwrap_or_else(|poison| poison.into_inner()) = Some(sender);
614    }
615
616    fn clear(&self) {
617        *self
618            .sender
619            .lock()
620            .unwrap_or_else(|poison| poison.into_inner()) = None;
621    }
622
623    fn wake(&self) {
624        let sender = self
625            .sender
626            .lock()
627            .unwrap_or_else(|poison| poison.into_inner())
628            .clone();
629        if let Some(sender) = sender {
630            let _ = sender.try_send(());
631        }
632    }
633}
634
635/// Producer-side endpoint for a local `SourceRef` or `Source`.
636///
637/// Feed inbound `OnSubscribeHandshake`, `CumulativeDemand`, cancellation, and
638/// Ack frames with [`StreamRefProtoEndpoint::handle_frame`]. Drain outbound
639/// `SequencedOnNext`, completion, failure, and Ack frames with
640/// [`StreamRefProtoEndpoint::next_frame`].
641pub struct StreamRefProtoProducer<T>
642where
643    T: StreamRefPayload,
644{
645    shared: Arc<ProducerShared<T>>,
646}
647
648impl<T> Clone for StreamRefProtoProducer<T>
649where
650    T: StreamRefPayload,
651{
652    fn clone(&self) -> Self {
653        Self {
654            shared: Arc::clone(&self.shared),
655        }
656    }
657}
658
659impl<T> StreamRefProtoProducer<T>
660where
661    T: StreamRefPayload,
662{
663    pub fn from_source_ref(
664        source_ref: SourceRef<T>,
665        stream_ref_id: StreamRefId,
666        settings: StreamRefSettings,
667    ) -> StreamResult<Self> {
668        Self::from_source(
669            super::stream_ref::proto_source(&source_ref),
670            stream_ref_id,
671            settings,
672        )
673    }
674
675    pub fn from_source<Mat>(
676        source: Source<T, Mat>,
677        stream_ref_id: StreamRefId,
678        settings: StreamRefSettings,
679    ) -> StreamResult<Self>
680    where
681        Mat: Send + 'static,
682    {
683        let materializer = Materializer::new();
684        let (input, materialized) = Arc::clone(&source.factory).create(&materializer)?;
685        Ok(Self {
686            shared: Arc::new(ProducerShared {
687                stream_ref_id,
688                settings,
689                input: Mutex::new(Some(input)),
690                state: Mutex::new(ProducerState {
691                    partner_seen: false,
692                    cumulative_demand: 0,
693                    first_demand_deadline: None,
694                    sent: 0,
695                    terminal_sent: false,
696                    waiting_for_ack: false,
697                    ack_deadline: None,
698                    stopped: None,
699                    ack_queued: false,
700                    pending_terminal: None,
701                    done: false,
702                    input_attached: true,
703                    terminal_result: None,
704                }),
705                changed: Condvar::new(),
706                outbound_wake: OutboundWake::default(),
707                completion: Mutex::new(None),
708                _materializer: materializer,
709                _materialized: Mutex::new(Some(Box::new(materialized))),
710            }),
711        })
712    }
713
714    /// Creates a producer with no attached input, for the SinkRef sender side.
715    ///
716    /// The input stream is attached later by materializing the [`Sink`]
717    /// returned from [`StreamRefProtoProducer::sink`]. Until then, the
718    /// producer's `next_frame` waits on the condvar instead of spinning, so an
719    /// idle lazy producer does not busy-loop while the remote consumer has not
720    /// yet subscribed or attached demand.
721    #[must_use]
722    pub fn new_lazy(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
723        Self {
724            shared: Arc::new(ProducerShared {
725                stream_ref_id,
726                settings,
727                input: Mutex::new(None),
728                state: Mutex::new(ProducerState {
729                    partner_seen: false,
730                    cumulative_demand: 0,
731                    first_demand_deadline: None,
732                    sent: 0,
733                    terminal_sent: false,
734                    waiting_for_ack: false,
735                    ack_deadline: None,
736                    stopped: None,
737                    ack_queued: false,
738                    pending_terminal: None,
739                    done: false,
740                    input_attached: false,
741                    terminal_result: None,
742                }),
743                changed: Condvar::new(),
744                outbound_wake: OutboundWake::default(),
745                completion: Mutex::new(None),
746                _materializer: Materializer::new(),
747                _materialized: Mutex::new(None),
748            }),
749        }
750    }
751
752    /// Returns a [`Sink`] whose incoming elements are framed and sent as this
753    /// producer's outbound `SequencedOnNext` frames.
754    ///
755    /// Materializing the sink attaches the input stream to the lazy producer
756    /// and returns a [`StreamCompletion`] that resolves when the producer
757    /// reaches its terminal state (all elements sent and acknowledged, the
758    /// remote cancelled/failed, or the carrier failed).
759    #[must_use]
760    pub fn sink(&self) -> Sink<T, StreamCompletion<NotUsed>> {
761        let shared = Arc::clone(&self.shared);
762        Sink::from_runner(move |input, _materializer| {
763            let (sender, receiver) = oneshot::channel();
764            *shared
765                .completion
766                .lock()
767                .unwrap_or_else(|poison| poison.into_inner()) = Some(sender);
768            shared.attach_input(input);
769            Ok(StreamCompletion::from_receiver(receiver, None))
770        })
771    }
772}
773
774impl<T> StreamRefProtoEndpoint for StreamRefProtoProducer<T>
775where
776    T: StreamRefPayload,
777{
778    fn stream_ref_id(&self) -> StreamRefId {
779        self.shared.stream_ref_id
780    }
781
782    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
783        self.shared.next_frame()
784    }
785
786    fn next_outbound(
787        &self,
788        max_data_elements: usize,
789        max_data_bytes: usize,
790    ) -> Option<StreamResult<StreamRefOutbound>> {
791        self.shared.next_outbound(max_data_elements, max_data_bytes)
792    }
793
794    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
795        self.shared.handle_frame(frame)
796    }
797
798    fn fail_connection(&self, error: StreamError) {
799        self.shared.fail_connection(error);
800    }
801}
802
803impl<T> StreamRefProtoEndpointWake for StreamRefProtoProducer<T>
804where
805    T: StreamRefPayload,
806{
807    fn install_outbound_wake(&self, sender: tokio_mpsc::Sender<()>) {
808        self.shared.outbound_wake.install(sender);
809    }
810
811    fn clear_outbound_wake(&self) {
812        self.shared.outbound_wake.clear();
813    }
814
815    fn try_next_outbound(
816        &self,
817        max_data_elements: usize,
818        max_data_bytes: usize,
819    ) -> StreamRefOutboundPoll {
820        self.shared
821            .try_next_outbound(max_data_elements, max_data_bytes)
822    }
823}
824
/// State shared by every clone of a `StreamRefProtoProducer`.
struct ProducerShared<T>
where
    T: StreamRefPayload,
{
    /// Id stamped on every outbound frame.
    stream_ref_id: StreamRefId,
    /// Protocol settings; `subscription_timeout()` bounds the wait for first demand.
    settings: StreamRefSettings,
    /// Local element stream. `None` for a lazy producer until its sink is
    /// materialized (and presumably again after `drop_input`).
    input: Mutex<Option<BoxStream<T>>>,
    /// Protocol bookkeeping; waiters block on `changed` while holding this lock.
    state: Mutex<ProducerState>,
    /// Notified (together with `outbound_wake`) by `notify_changed`.
    changed: Condvar,
    /// Optional async wake channel for single-owner carriers.
    outbound_wake: OutboundWake,
    /// Completion sender installed when the producer's sink is materialized.
    completion: Mutex<Option<oneshot::Sender<StreamResult<NotUsed>>>>,
    /// Underscore-prefixed: held only to keep the materializer alive.
    _materializer: Materializer,
    /// Underscore-prefixed: keeps the source's materialized value alive.
    _materialized: Mutex<Option<Box<dyn Any + Send>>>,
}

/// Mutable protocol state of a producer, guarded by `ProducerShared::state`.
struct ProducerState {
    /// NOTE(review): not read in the visible code; presumably set once the
    /// remote partner has been seen (e.g. handshake received).
    partner_seen: bool,
    /// Total demand signalled by the consumer; data flows while `sent` is below it.
    cumulative_demand: u64,
    /// NOTE(review): not read in the visible code; presumably the deadline for
    /// the first `CumulativeDemand`.
    first_demand_deadline: Option<Instant>,
    /// Number of data elements emitted so far.
    sent: u64,
    /// NOTE(review): not read in the visible code; presumably set once a
    /// terminal frame has been handed to the carrier.
    terminal_sent: bool,
    /// While set, `next_outbound` waits (up to `ack_deadline`) for the remote Ack.
    waiting_for_ack: bool,
    /// Deadline after which a missing terminal Ack fails the producer.
    ack_deadline: Option<Instant>,
    /// Error that stops the producer; it also becomes the terminal result once an Ack is sent.
    stopped: Option<StreamError>,
    /// When set, the next outbound item is an Ack frame that finishes the producer.
    ack_queued: bool,
    /// Terminal message (completion or failure) to emit on the next poll.
    pending_terminal: Option<StreamRefMessage>,
    /// Once set, `next_outbound` returns `None`.
    done: bool,
    /// `true` when an input stream was supplied at construction or later attached.
    input_attached: bool,
    /// Final outcome recorded when the producer finishes.
    terminal_result: Option<StreamResult<NotUsed>>,
}

/// Outcome of trying to pull one outbound data batch.
/// NOTE(review): `StateChanged` presumably means "re-examine the state and retry".
enum ProducerBatchPoll {
    Ready(StreamResult<StreamRefOutbound>),
    Pending,
    StateChanged,
}

/// Outcome of polling the input stream for a single element.
/// NOTE(review): `TerminalQueued` presumably signals that the input ended and a
/// terminal message was queued instead of an element.
enum InputItemPoll<T> {
    Ready(StreamResult<T>),
    Pending,
    TerminalQueued,
}
867
868impl<T> ProducerShared<T>
869where
870    T: StreamRefPayload,
871{
872    fn lock_state(&self) -> MutexGuard<'_, ProducerState> {
873        self.state
874            .lock()
875            .unwrap_or_else(|poison| poison.into_inner())
876    }
877
878    fn lock_input(&self) -> MutexGuard<'_, Option<BoxStream<T>>> {
879        self.input
880            .lock()
881            .unwrap_or_else(|poison| poison.into_inner())
882    }
883
884    fn notify_changed(&self) {
885        self.changed.notify_all();
886        self.outbound_wake.wake();
887    }
888
889    fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
890        StreamRefFrame::new(self.stream_ref_id, message)
891    }
892
893    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
894        self.next_outbound(1, usize::MAX)
895            .map(|outbound| outbound.and_then(StreamRefOutbound::into_single_frame))
896    }
897
898    fn next_outbound(
899        &self,
900        max_data_elements: usize,
901        max_data_bytes: usize,
902    ) -> Option<StreamResult<StreamRefOutbound>> {
903        let subscription_deadline = deadline_from_now(self.settings.subscription_timeout());
904        loop {
905            let mut state = self.lock_state();
906            if state.done {
907                return None;
908            }
909
910            if state.ack_queued {
911                state.ack_queued = false;
912                state.done = true;
913                state.terminal_result = Some(match state.stopped.clone() {
914                    Some(error) => Err(error),
915                    None => Ok(NotUsed),
916                });
917                self.notify_changed();
918                drop(state);
919                self.drop_input();
920                self.settle();
921                return Some(Ok(StreamRefOutbound::Frame(
922                    self.frame(StreamRefMessage::Ack),
923                )));
924            }
925
926            if let Some(message) = state.pending_terminal.take() {
927                drop(state);
928                return Some(Ok(StreamRefOutbound::Frame(self.frame(message))));
929            }
930
931            if state.waiting_for_ack {
932                if state
933                    .ack_deadline
934                    .is_some_and(|deadline| Instant::now() >= deadline)
935                {
936                    let timeout_error =
937                        subscription_timeout_error("stream ref producer terminal ack");
938                    state.done = true;
939                    state.terminal_result = Some(Err(timeout_error.clone()));
940                    self.notify_changed();
941                    drop(state);
942                    self.drop_input();
943                    self.settle();
944                    return Some(Err(timeout_error));
945                }
946                if let Some(remaining) = state
947                    .ack_deadline
948                    .and_then(|deadline| deadline.checked_duration_since(Instant::now()))
949                {
950                    let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
951                    drop(next);
952                } else {
953                    drop(state);
954                }
955                continue;
956            }
957
958            if let Some(error) = state.stopped.clone() {
959                state.done = true;
960                state.terminal_result = Some(Err(error.clone()));
961                self.notify_changed();
962                drop(state);
963                self.drop_input();
964                self.settle();
965                return Some(Err(error));
966            }
967
968            if state.cumulative_demand > 0 && state.sent < state.cumulative_demand {
969                drop(state);
970                if let Some(outbound) =
971                    self.pull_next_outbound_batch(max_data_elements.max(1), max_data_bytes.max(1))
972                {
973                    return Some(outbound);
974                }
975                continue;
976            }
977
978            if state.cumulative_demand == 0 && Instant::now() >= subscription_deadline {
979                let timeout_error = subscription_timeout_error("stream ref producer first demand");
980                state.done = true;
981                state.terminal_result = Some(Err(timeout_error.clone()));
982                self.notify_changed();
983                drop(state);
984                self.drop_input();
985                self.settle();
986                return Some(Err(timeout_error));
987            }
988
989            if state.cumulative_demand == 0 {
990                let remaining = subscription_deadline.saturating_duration_since(Instant::now());
991                if remaining.is_zero() {
992                    drop(state);
993                    continue;
994                }
995                let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
996                drop(next);
997            } else {
998                let next = wait_unpoison(&self.changed, state);
999                drop(next);
1000            }
1001        }
1002    }
1003
    /// Non-blocking counterpart of `next_outbound`.
    ///
    /// Uses the same priority order (queued `Ack`, pending terminal frame,
    /// terminal-ack wait, stop error, data batch, first-demand timeout). It
    /// returns `Pending` where `next_outbound` would block, and `Closed` once
    /// the endpoint is done.
    ///
    /// The first-demand deadline is stored in `state.first_demand_deadline` the
    /// first time zero demand is observed, so it persists across polls.
    ///
    /// NOTE(review): no timer is armed for `ack_deadline` or the first-demand
    /// deadline here; a timeout is only noticed when the carrier polls again.
    /// Confirm the carrier re-polls periodically.
    fn try_next_outbound(
        &self,
        max_data_elements: usize,
        max_data_bytes: usize,
    ) -> StreamRefOutboundPoll {
        loop {
            let mut state = self.lock_state();
            if state.done {
                return StreamRefOutboundPoll::Closed;
            }

            // The consumer cancelled or failed (`stop_from_consumer`): finish
            // the endpoint and acknowledge with an `Ack` frame.
            if state.ack_queued {
                state.ack_queued = false;
                state.done = true;
                state.terminal_result = Some(match state.stopped.clone() {
                    Some(error) => Err(error),
                    None => Ok(NotUsed),
                });
                self.notify_changed();
                drop(state);
                self.drop_input();
                self.settle();
                return StreamRefOutboundPoll::Ready(Ok(StreamRefOutbound::Frame(
                    self.frame(StreamRefMessage::Ack),
                )));
            }

            // A terminal frame recorded by `note_terminal` goes out next.
            if let Some(message) = state.pending_terminal.take() {
                drop(state);
                return StreamRefOutboundPoll::Ready(Ok(StreamRefOutbound::Frame(
                    self.frame(message),
                )));
            }

            // Terminal frame already sent: only the consumer's Ack (handled in
            // `handle_frame`) or the ack deadline can move us forward.
            if state.waiting_for_ack {
                if state
                    .ack_deadline
                    .is_some_and(|deadline| Instant::now() >= deadline)
                {
                    let timeout_error =
                        subscription_timeout_error("stream ref producer terminal ack");
                    state.done = true;
                    state.terminal_result = Some(Err(timeout_error.clone()));
                    self.notify_changed();
                    drop(state);
                    self.drop_input();
                    self.settle();
                    return StreamRefOutboundPoll::Ready(Err(timeout_error));
                }
                return StreamRefOutboundPoll::Pending;
            }

            if let Some(error) = state.stopped.clone() {
                state.done = true;
                state.terminal_result = Some(Err(error.clone()));
                self.notify_changed();
                drop(state);
                self.drop_input();
                self.settle();
                return StreamRefOutboundPoll::Ready(Err(error));
            }

            if state.cumulative_demand > 0 && state.sent < state.cumulative_demand {
                // Release the state lock before touching the input stream.
                drop(state);
                match self
                    .try_pull_next_outbound_batch(max_data_elements.max(1), max_data_bytes.max(1))
                {
                    ProducerBatchPoll::Ready(outbound) => {
                        return StreamRefOutboundPoll::Ready(outbound);
                    }
                    // A terminal was recorded with nothing to flush; re-run the
                    // state machine so the pending terminal frame is emitted.
                    ProducerBatchPoll::StateChanged => continue,
                    ProducerBatchPoll::Pending => return StreamRefOutboundPoll::Pending,
                }
            }

            if state.cumulative_demand == 0 {
                // Start the first-demand clock lazily and keep it across polls.
                let deadline = *state
                    .first_demand_deadline
                    .get_or_insert_with(|| deadline_from_now(self.settings.subscription_timeout()));
                if Instant::now() >= deadline {
                    let timeout_error =
                        subscription_timeout_error("stream ref producer first demand");
                    state.done = true;
                    state.terminal_result = Some(Err(timeout_error.clone()));
                    self.notify_changed();
                    drop(state);
                    self.drop_input();
                    self.settle();
                    return StreamRefOutboundPoll::Ready(Err(timeout_error));
                }
            }

            return StreamRefOutboundPoll::Pending;
        }
    }
1099
1100    fn pull_next_outbound_batch(
1101        &self,
1102        max_data_elements: usize,
1103        max_data_bytes: usize,
1104    ) -> Option<StreamResult<StreamRefOutbound>> {
1105        let mut batch: Option<StreamRefPayloadBatch> = None;
1106        while batch
1107            .as_ref()
1108            .is_none_or(|batch| batch.count() < max_data_elements)
1109        {
1110            if batch
1111                .as_ref()
1112                .is_some_and(|batch| batch.payloads.len() >= max_data_bytes)
1113            {
1114                break;
1115            }
1116
1117            let seq_nr = {
1118                let state = self.lock_state();
1119                if state.done
1120                    || state.stopped.is_some()
1121                    || state.waiting_for_ack
1122                    || state.sent >= state.cumulative_demand
1123                {
1124                    break;
1125                }
1126                state.sent
1127            };
1128
1129            let item = match self.next_input_item() {
1130                Some(item) => item,
1131                None => break,
1132            };
1133
1134            match item {
1135                Ok(item) => {
1136                    let mut state = self.lock_state();
1137                    if state.done
1138                        || state.stopped.is_some()
1139                        || state.waiting_for_ack
1140                        || state.sent != seq_nr
1141                    {
1142                        break;
1143                    }
1144                    state.sent = state.sent.saturating_add(1);
1145                    drop(state);
1146
1147                    let batch = batch.get_or_insert_with(|| {
1148                        StreamRefPayloadBatch::new(self.stream_ref_id, seq_nr)
1149                    });
1150                    if let Err(error) = batch.push_payload(item) {
1151                        return Some(Err(error));
1152                    }
1153                }
1154                Err(error) => {
1155                    let terminal = StreamRefMessage::RemoteStreamFailure {
1156                        cause: failure_cause(&error),
1157                    };
1158                    self.note_terminal(Err(error), terminal.clone());
1159                    return match batch {
1160                        Some(batch) if !batch.is_empty() => {
1161                            Some(Ok(StreamRefOutbound::SequencedBatch(batch)))
1162                        }
1163                        _ => None,
1164                    };
1165                }
1166            }
1167        }
1168
1169        batch
1170            .filter(|batch| !batch.is_empty())
1171            .map(|batch| Ok(StreamRefOutbound::SequencedBatch(batch)))
1172    }
1173
    /// Non-blocking variant of `pull_next_outbound_batch`.
    ///
    /// Collects elements under the same limits and sequence-number rules. It
    /// returns:
    /// - `Ready(Ok(batch))` when at least one element was collected;
    /// - `Ready(Err(_))` if appending a payload fails;
    /// - `StateChanged` when a terminal was recorded (input exhausted or
    ///   failed) and nothing was collected, so the caller should re-run its
    ///   state machine;
    /// - `Pending` when no input is attached or no element is available yet.
    ///
    /// If the state changes between pulling an element and recording it, that
    /// element is dropped.
    fn try_pull_next_outbound_batch(
        &self,
        max_data_elements: usize,
        max_data_bytes: usize,
    ) -> ProducerBatchPoll {
        let mut batch: Option<StreamRefPayloadBatch> = None;
        while batch
            .as_ref()
            .is_none_or(|batch| batch.count() < max_data_elements)
        {
            if batch
                .as_ref()
                .is_some_and(|batch| batch.payloads.len() >= max_data_bytes)
            {
                break;
            }

            // Reserve the next sequence number; bail out early if there is
            // no input to pull from yet.
            let seq_nr = {
                let state = self.lock_state();
                if state.done
                    || state.stopped.is_some()
                    || state.waiting_for_ack
                    || state.sent >= state.cumulative_demand
                {
                    break;
                }
                if !state.input_attached {
                    return match batch {
                        Some(batch) if !batch.is_empty() => {
                            ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
                        }
                        _ => ProducerBatchPoll::Pending,
                    };
                }
                state.sent
            };

            let item = match self.try_next_input_item() {
                InputItemPoll::Ready(item) => item,
                InputItemPoll::Pending => {
                    return match batch {
                        Some(batch) if !batch.is_empty() => {
                            ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
                        }
                        _ => ProducerBatchPoll::Pending,
                    };
                }
                // Input exhausted: a completion terminal is now pending.
                InputItemPoll::TerminalQueued => {
                    return match batch {
                        Some(batch) if !batch.is_empty() => {
                            ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
                        }
                        _ => ProducerBatchPoll::StateChanged,
                    };
                }
            };

            match item {
                Ok(item) => {
                    let mut state = self.lock_state();
                    // Re-validate: the state may have changed while pulling.
                    if state.done
                        || state.stopped.is_some()
                        || state.waiting_for_ack
                        || state.sent != seq_nr
                    {
                        break;
                    }
                    state.sent = state.sent.saturating_add(1);
                    drop(state);

                    let batch = batch.get_or_insert_with(|| {
                        StreamRefPayloadBatch::new(self.stream_ref_id, seq_nr)
                    });
                    if let Err(error) = batch.push_payload(item) {
                        return ProducerBatchPoll::Ready(Err(error));
                    }
                }
                Err(error) => {
                    let terminal = StreamRefMessage::RemoteStreamFailure {
                        cause: failure_cause(&error),
                    };
                    self.note_terminal(Err(error), terminal);
                    // Flush collected elements first; the failure frame follows.
                    return match batch {
                        Some(batch) if !batch.is_empty() => {
                            ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
                        }
                        _ => ProducerBatchPoll::StateChanged,
                    };
                }
            }
        }

        match batch {
            Some(batch) if !batch.is_empty() => {
                ProducerBatchPoll::Ready(Ok(StreamRefOutbound::SequencedBatch(batch)))
            }
            _ => ProducerBatchPoll::Pending,
        }
    }
1273
    /// Blocking pull of the next upstream element.
    ///
    /// Returns `Some(item)` for an element or upstream error. Returns `None` in
    /// two cases:
    /// - no input was attached: returns after waiting until one is attached or
    ///   the endpoint is done, stopped, or has sent its terminal;
    /// - the input is exhausted: a `RemoteStreamCompleted` terminal carrying
    ///   the current `sent` count is recorded via `note_terminal` first.
    ///
    /// The input mutex is held while calling `next()` on the input, so a
    /// concurrent `drop_input` waits until that call returns.
    fn next_input_item(&self) -> Option<StreamResult<T>> {
        let mut input_guard = self.lock_input();
        if input_guard.is_none() {
            drop(input_guard);
            // Lazy producer: no input attached yet. Block on the condvar
            // (notify-driven, no fixed poll) until `attach_input` or a
            // terminal state wakes us, then let the caller re-check demand and
            // terminal predicates.
            let mut state = self.lock_state();
            while !state.input_attached
                && !state.done
                && state.stopped.is_none()
                && !state.terminal_sent
            {
                state = wait_unpoison(&self.changed, state);
            }
            drop(state);
            return None;
        }

        match input_guard.as_mut().expect("input attached").next() {
            Some(item) => Some(item),
            None => {
                // Release the input lock before `note_terminal`, which calls
                // `drop_input` and locks it again.
                drop(input_guard);
                let seq_nr = self.lock_state().sent;
                self.note_terminal(
                    Ok(NotUsed),
                    StreamRefMessage::RemoteStreamCompleted { seq_nr },
                );
                None
            }
        }
    }
1307
1308    fn try_next_input_item(&self) -> InputItemPoll<T> {
1309        let mut input_guard = self.lock_input();
1310        let Some(input) = input_guard.as_mut() else {
1311            return InputItemPoll::Pending;
1312        };
1313
1314        match input.next() {
1315            Some(item) => InputItemPoll::Ready(item),
1316            None => {
1317                drop(input_guard);
1318                let seq_nr = self.lock_state().sent;
1319                self.note_terminal(
1320                    Ok(NotUsed),
1321                    StreamRefMessage::RemoteStreamCompleted { seq_nr },
1322                );
1323                InputItemPoll::TerminalQueued
1324            }
1325        }
1326    }
1327
1328    fn note_terminal(&self, result: StreamResult<NotUsed>, terminal_message: StreamRefMessage) {
1329        self.drop_input();
1330        let mut state = self.lock_state();
1331        if state.done || state.terminal_sent {
1332            return;
1333        }
1334        state.terminal_sent = true;
1335        state.waiting_for_ack = true;
1336        state.terminal_result = Some(result);
1337        state.ack_deadline = Some(deadline_from_now(self.settings.subscription_timeout()));
1338        state.pending_terminal = Some(terminal_message);
1339        self.notify_changed();
1340        drop(state);
1341    }
1342
1343    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
1344        self.validate_frame_id(frame.stream_ref_id)?;
1345        match frame.message {
1346            StreamRefMessage::OnSubscribeHandshake => {
1347                let mut state = self.lock_state();
1348                state.partner_seen = true;
1349                self.notify_changed();
1350                drop(state);
1351                Ok(())
1352            }
1353            StreamRefMessage::CumulativeDemand { seq_nr } => {
1354                if seq_nr == 0 {
1355                    return Err(StreamError::Failed(
1356                        "CumulativeDemand seq_nr must be positive".to_owned(),
1357                    ));
1358                }
1359                let mut state = self.lock_state();
1360                state.partner_seen = true;
1361                if seq_nr > state.cumulative_demand {
1362                    state.cumulative_demand = seq_nr;
1363                }
1364                self.notify_changed();
1365                drop(state);
1366                Ok(())
1367            }
1368            StreamRefMessage::RemoteStreamCompleted { .. } => {
1369                self.stop_from_consumer(StreamError::Cancelled);
1370                Ok(())
1371            }
1372            StreamRefMessage::RemoteStreamFailure { cause } => {
1373                self.stop_from_consumer(StreamError::Failed(
1374                    String::from_utf8_lossy(&cause).into_owned(),
1375                ));
1376                Ok(())
1377            }
1378            StreamRefMessage::Ack => {
1379                let mut state = self.lock_state();
1380                if state.waiting_for_ack {
1381                    state.waiting_for_ack = false;
1382                    state.done = true;
1383                    if state.terminal_result.is_none() {
1384                        state.terminal_result = Some(Ok(NotUsed));
1385                    }
1386                    self.notify_changed();
1387                    drop(state);
1388                    self.drop_input();
1389                    self.settle();
1390                } else {
1391                    drop(state);
1392                }
1393                Ok(())
1394            }
1395            StreamRefMessage::SequencedOnNext { .. } => Err(StreamError::Failed(
1396                "producer endpoint cannot receive SequencedOnNext".to_owned(),
1397            )),
1398        }
1399    }
1400
1401    fn stop_from_consumer(&self, error: StreamError) {
1402        let mut state = self.lock_state();
1403        if !state.done {
1404            state.stopped = Some(error.clone());
1405            state.ack_queued = true;
1406            state.terminal_result = Some(Err(error));
1407        }
1408        self.notify_changed();
1409        drop(state);
1410        self.drop_input();
1411    }
1412
1413    fn fail_connection(&self, error: StreamError) {
1414        let mut state = self.lock_state();
1415        if !state.done {
1416            state.stopped = Some(error.clone());
1417            state.done = true;
1418            state.terminal_result = Some(Err(error));
1419        }
1420        self.notify_changed();
1421        drop(state);
1422        self.drop_input();
1423        self.settle();
1424    }
1425
1426    fn attach_input(&self, input: BoxStream<T>) {
1427        *self.lock_input() = Some(input);
1428        let mut state = self.lock_state();
1429        state.input_attached = true;
1430        self.notify_changed();
1431        drop(state);
1432    }
1433
1434    fn settle(&self) {
1435        let result = self.lock_state().terminal_result.clone();
1436        let sender = self
1437            .completion
1438            .lock()
1439            .unwrap_or_else(|poison| poison.into_inner())
1440            .take();
1441        if let (Some(sender), Some(result)) = (sender, result) {
1442            let _ = sender.send(result);
1443        }
1444    }
1445
1446    fn drop_input(&self) {
1447        let input = self.lock_input().take();
1448        drop(input);
1449    }
1450
1451    fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
1452        if stream_ref_id == self.stream_ref_id {
1453            Ok(())
1454        } else {
1455            Err(StreamError::Failed(format!(
1456                "stream ref id mismatch: expected {}, got {}",
1457                self.stream_ref_id, stream_ref_id
1458            )))
1459        }
1460    }
1461}
1462
/// Consumer-side endpoint that exposes inbound remote elements as a local
/// [`Source`].
///
/// The source sends `OnSubscribeHandshake` and a cumulative demand ceiling when
/// it is materialized. It then refills demand at half the configured buffer.
///
/// Clones share the same endpoint state. The source can be materialized only
/// once, either as a stream or as a direct terminal; later attempts fail with
/// "stream ref source has already been materialized".
pub struct StreamRefProtoConsumer<T>
where
    T: StreamRefPayload,
{
    // State shared by the carrier-facing endpoint and the local source.
    shared: Arc<ConsumerShared<T>>,
}
1474
1475impl<T> Clone for StreamRefProtoConsumer<T>
1476where
1477    T: StreamRefPayload,
1478{
1479    fn clone(&self) -> Self {
1480        Self {
1481            shared: Arc::clone(&self.shared),
1482        }
1483    }
1484}
1485
1486impl<T> StreamRefProtoConsumer<T>
1487where
1488    T: StreamRefPayload,
1489{
1490    #[must_use]
1491    pub fn new(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
1492        Self {
1493            shared: Arc::new(ConsumerShared {
1494                stream_ref_id,
1495                settings,
1496                state: Mutex::new(ConsumerState {
1497                    source_taken: false,
1498                    subscribed: false,
1499                    queue: VecDeque::new(),
1500                    direct_terminal: false,
1501                    direct_consumer: None,
1502                    direct_cancelled: None,
1503                    terminal: None,
1504                    expected_seq: 0,
1505                    delivered: 0,
1506                    cumulative_demand: 0,
1507                    outbound: VecDeque::new(),
1508                    finish_after_outbound_ack: false,
1509                    waiting_cancel_ack: false,
1510                    done: false,
1511                }),
1512                changed: Condvar::new(),
1513                outbound_wake: OutboundWake::default(),
1514            }),
1515        }
1516    }
1517
1518    #[must_use]
1519    pub fn source(&self) -> Source<T, NotUsed> {
1520        let shared_for_stream = Arc::clone(&self.shared);
1521        let shared_for_terminal = Arc::clone(&self.shared);
1522        Source::from_terminal_direct_materialized_factory(
1523            move |_materializer| {
1524                shared_for_stream
1525                    .start_stream()
1526                    .map(|stream| (Box::new(stream) as BoxStream<T>, NotUsed))
1527            },
1528            move |_materializer| {
1529                Ok((
1530                    Arc::new(StreamRefProtoTerminalHook {
1531                        shared: Arc::clone(&shared_for_terminal),
1532                        stream: Mutex::new(None),
1533                    }) as Arc<dyn TerminalSourceHookDyn<T>>,
1534                    NotUsed,
1535                ))
1536            },
1537        )
1538    }
1539}
1540
1541impl<T> StreamRefProtoEndpoint for StreamRefProtoConsumer<T>
1542where
1543    T: StreamRefPayload,
1544{
1545    fn stream_ref_id(&self) -> StreamRefId {
1546        self.shared.stream_ref_id
1547    }
1548
1549    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
1550        self.shared.next_frame()
1551    }
1552
1553    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
1554        self.shared.handle_frame(frame)
1555    }
1556
1557    fn handle_sequenced_on_next_batch(
1558        &self,
1559        stream_ref_id: StreamRefId,
1560        first_seq_nr: u64,
1561        payloads: &[&[u8]],
1562    ) -> StreamResult<()> {
1563        self.shared
1564            .handle_sequenced_on_next_batch(stream_ref_id, first_seq_nr, payloads)
1565    }
1566
1567    fn fail_connection(&self, error: StreamError) {
1568        self.shared.fail_connection(error);
1569    }
1570}
1571
1572impl<T> StreamRefProtoEndpointWake for StreamRefProtoConsumer<T>
1573where
1574    T: StreamRefPayload,
1575{
1576    fn install_outbound_wake(&self, sender: tokio_mpsc::Sender<()>) {
1577        self.shared.outbound_wake.install(sender);
1578    }
1579
1580    fn clear_outbound_wake(&self) {
1581        self.shared.outbound_wake.clear();
1582    }
1583
1584    fn try_next_outbound(
1585        &self,
1586        _max_data_elements: usize,
1587        _max_data_bytes: usize,
1588    ) -> StreamRefOutboundPoll {
1589        self.shared.try_next_frame()
1590    }
1591}
1592
/// State shared by every handle of one consumer endpoint.
struct ConsumerShared<T>
where
    T: StreamRefPayload,
{
    // Stamped on outbound frames and checked against inbound ones.
    stream_ref_id: StreamRefId,
    // StreamRef settings; e.g. `buffer_capacity()` bounds terminal drains and
    // the settings are passed to `next_demand`.
    settings: StreamRefSettings,
    // Mutable endpoint state.
    state: Mutex<ConsumerState<T>>,
    // Notified by `notify_changed`; the blocking `next_frame` waits on it.
    changed: Condvar,
    // Async wake-up for a non-blocking carrier, signalled by `notify_changed`.
    outbound_wake: OutboundWake,
}
1603
/// Mutable consumer-endpoint state, guarded by `ConsumerShared::state`.
struct ConsumerState<T> {
    // Set once the source is materialized (stream or direct terminal); a second
    // materialization is rejected.
    source_taken: bool,
    // Set when `OnSubscribeHandshake` (and any initial demand) was queued.
    subscribed: bool,
    // Buffered inbound elements. NOTE(review): presumably decoded elements
    // awaiting delivery to the local stream; confirm against `ConsumerStream`.
    queue: VecDeque<T>,
    // Set when the source was materialized as a direct terminal.
    direct_terminal: bool,
    // Registered direct-terminal consumer; taken when it is finished.
    direct_consumer: Option<Box<dyn TerminalSinkConsumerDyn<T>>>,
    // Cancellation flag supplied with the direct-terminal registration.
    direct_cancelled: Option<Arc<std::sync::atomic::AtomicBool>>,
    // Recorded terminal outcome, if any.
    terminal: Option<ConsumerTerminal>,
    // NOTE(review): presumably the next inbound `seq_nr` expected; confirm.
    expected_seq: u64,
    // NOTE(review): presumably the number of elements delivered downstream;
    // confirm.
    delivered: u64,
    // Demand ceiling; passed as `demand_ceiling` when handling
    // `SequencedOnNext`.
    cumulative_demand: u64,
    // Protocol messages queued for the carrier, drained by `next_frame` /
    // `try_next_frame`.
    outbound: VecDeque<StreamRefMessage>,
    // When set, dequeuing an outbound `Ack` marks the endpoint done.
    finish_after_outbound_ack: bool,
    // NOTE(review): presumably set while awaiting the producer's Ack of a
    // cancellation; confirm.
    waiting_cancel_ack: bool,
    // Endpoint finished. Once `outbound` is also empty, `next_frame` returns
    // `None` and `try_next_frame` returns `Closed`.
    done: bool,
}
1620
/// Terminal outcome recorded by the consumer endpoint.
#[derive(Clone)]
enum ConsumerTerminal {
    // Normal completion.
    Complete,
    // Termination with the carried error.
    Error(StreamError),
}
1626
/// [`TerminalSourceHookDyn`] adapter that lets a terminal sink drain the
/// consumer endpoint, either in batches or by direct registration.
struct StreamRefProtoTerminalHook<T>
where
    T: StreamRefPayload,
{
    // Endpoint state shared with the owning consumer.
    shared: Arc<ConsumerShared<T>>,
    // Consumer stream, started lazily by the first `drain_terminal_batch` call.
    stream: Mutex<Option<ConsumerStream<T>>>,
}
1634
1635impl<T> TerminalSourceHookDyn<T> for StreamRefProtoTerminalHook<T>
1636where
1637    T: StreamRefPayload,
1638{
1639    fn drain_terminal_batch(
1640        &self,
1641        materializer: &Materializer,
1642        cancelled: &Arc<std::sync::atomic::AtomicBool>,
1643        batch: &mut Vec<T>,
1644    ) -> StreamResult<TerminalSourceStatus> {
1645        batch.clear();
1646        if materializer.is_shutdown() {
1647            self.shared.cancel_from_downstream();
1648            return Err(StreamError::AbruptTermination);
1649        }
1650        if cancelled.load(Ordering::SeqCst) {
1651            self.shared.cancel_from_downstream();
1652            return Err(StreamError::Cancelled);
1653        }
1654
1655        let mut stream = self
1656            .stream
1657            .lock()
1658            .unwrap_or_else(|poison| poison.into_inner());
1659        if stream.is_none() {
1660            *stream = Some(self.shared.start_stream()?);
1661        }
1662        let stream = stream.as_mut().expect("terminal stream present");
1663        for _ in 0..self.shared.settings.buffer_capacity().max(1) {
1664            match stream.next_item()? {
1665                Some(item) => batch.push(item),
1666                None => return Ok(TerminalSourceStatus::Completed),
1667            }
1668            if batch.len() >= 64 {
1669                break;
1670            }
1671        }
1672        Ok(TerminalSourceStatus::Active)
1673    }
1674
1675    fn supports_direct_terminal(&self) -> bool {
1676        true
1677    }
1678
1679    fn try_register_direct_terminal(
1680        &self,
1681        consumer: Box<dyn TerminalSinkConsumerDyn<T>>,
1682        cancelled: Arc<std::sync::atomic::AtomicBool>,
1683    ) -> Option<StreamResult<()>> {
1684        Some(self.shared.start_direct_terminal(consumer, cancelled))
1685    }
1686
1687    fn cancel_terminal(&self) {
1688        self.shared.cancel_from_downstream();
1689    }
1690}
1691
1692impl<T> ConsumerShared<T>
1693where
1694    T: StreamRefPayload,
1695{
1696    fn lock_state(&self) -> MutexGuard<'_, ConsumerState<T>> {
1697        self.state
1698            .lock()
1699            .unwrap_or_else(|poison| poison.into_inner())
1700    }
1701
1702    fn notify_changed(&self) {
1703        self.changed.notify_all();
1704        self.outbound_wake.wake();
1705    }
1706
1707    fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
1708        StreamRefFrame::new(self.stream_ref_id, message)
1709    }
1710
1711    fn start_stream(self: &Arc<Self>) -> StreamResult<ConsumerStream<T>> {
1712        self.start_subscription()?;
1713        Ok(ConsumerStream {
1714            shared: Arc::clone(self),
1715            terminated: false,
1716        })
1717    }
1718
1719    fn start_subscription(self: &Arc<Self>) -> StreamResult<()> {
1720        {
1721            let mut state = self.lock_state();
1722            if state.source_taken {
1723                return Err(StreamError::Failed(
1724                    "stream ref source has already been materialized".to_owned(),
1725                ));
1726            }
1727            state.source_taken = true;
1728            if !state.subscribed {
1729                state.subscribed = true;
1730                state
1731                    .outbound
1732                    .push_back(StreamRefMessage::OnSubscribeHandshake);
1733                if let Some(demand) = next_demand(&mut state, self.settings) {
1734                    state
1735                        .outbound
1736                        .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
1737                }
1738            }
1739            self.notify_changed();
1740        }
1741        Ok(())
1742    }
1743
1744    fn start_direct_terminal(
1745        self: &Arc<Self>,
1746        consumer: Box<dyn TerminalSinkConsumerDyn<T>>,
1747        cancelled: Arc<std::sync::atomic::AtomicBool>,
1748    ) -> StreamResult<()> {
1749        let mut finish = None;
1750        {
1751            let mut state = self.lock_state();
1752            if state.source_taken {
1753                return Err(StreamError::Failed(
1754                    "stream ref source has already been materialized".to_owned(),
1755                ));
1756            }
1757            state.source_taken = true;
1758            state.direct_terminal = true;
1759            state.direct_cancelled = Some(cancelled);
1760            state.direct_consumer = Some(consumer);
1761            if let Some(terminal) = state.terminal.clone() {
1762                finish = state
1763                    .direct_consumer
1764                    .take()
1765                    .map(|consumer| (consumer, terminal_result(terminal)));
1766            } else if state.done {
1767                finish = state
1768                    .direct_consumer
1769                    .take()
1770                    .map(|consumer| (consumer, Err(StreamError::AbruptTermination)));
1771            } else if !state.subscribed {
1772                state.subscribed = true;
1773                state
1774                    .outbound
1775                    .push_back(StreamRefMessage::OnSubscribeHandshake);
1776                if let Some(demand) = next_demand(&mut state, self.settings) {
1777                    state
1778                        .outbound
1779                        .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
1780                }
1781            }
1782            self.notify_changed();
1783        }
1784        if let Some((consumer, result)) = finish {
1785            consumer.finish(result);
1786        }
1787        Ok(())
1788    }
1789
1790    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
1791        loop {
1792            let mut state = self.lock_state();
1793            if let Some(message) = state.outbound.pop_front() {
1794                let finish_after_ack = message.is_ack() && state.finish_after_outbound_ack;
1795                if finish_after_ack {
1796                    state.done = true;
1797                }
1798                drop(state);
1799                return Some(Ok(self.frame(message)));
1800            }
1801            if state.done {
1802                return None;
1803            }
1804            let next = wait_unpoison(&self.changed, state);
1805            drop(next);
1806        }
1807    }
1808
1809    fn try_next_frame(&self) -> StreamRefOutboundPoll {
1810        let mut state = self.lock_state();
1811        if let Some(message) = state.outbound.pop_front() {
1812            let finish_after_ack = message.is_ack() && state.finish_after_outbound_ack;
1813            if finish_after_ack {
1814                state.done = true;
1815            }
1816            drop(state);
1817            StreamRefOutboundPoll::Ready(Ok(StreamRefOutbound::Frame(self.frame(message))))
1818        } else if state.done {
1819            StreamRefOutboundPoll::Closed
1820        } else {
1821            StreamRefOutboundPoll::Pending
1822        }
1823    }
1824
1825    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
1826        self.validate_frame_id(frame.stream_ref_id)?;
1827        match frame.message {
1828            StreamRefMessage::OnSubscribeHandshake => Ok(()),
1829            StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
1830                let item = T::decode_stream_ref_payload(payload.bytes)?;
1831                let demand_ceiling = self.lock_state().cumulative_demand;
1832                self.handle_decoded_on_next(seq_nr, item, demand_ceiling)
1833            }
1834            StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
1835                let finish = self.handle_remote_completed(seq_nr);
1836                if let Some((consumer, result)) = finish {
1837                    consumer.finish(result);
1838                }
1839                Ok(())
1840            }
1841            StreamRefMessage::RemoteStreamFailure { cause } => {
1842                let error = StreamError::Failed(String::from_utf8_lossy(&cause).into_owned());
1843                let finish = self.handle_remote_failure(error);
1844                if let Some((consumer, result)) = finish {
1845                    consumer.finish(result);
1846                }
1847                Ok(())
1848            }
1849            StreamRefMessage::Ack => {
1850                let mut state = self.lock_state();
1851                if state.waiting_cancel_ack {
1852                    state.waiting_cancel_ack = false;
1853                    state.done = true;
1854                }
1855                self.notify_changed();
1856                drop(state);
1857                Ok(())
1858            }
1859            StreamRefMessage::CumulativeDemand { .. } => Err(StreamError::Failed(
1860                "consumer endpoint cannot receive CumulativeDemand".to_owned(),
1861            )),
1862        }
1863    }
1864
1865    fn handle_sequenced_on_next_batch(
1866        &self,
1867        stream_ref_id: StreamRefId,
1868        first_seq_nr: u64,
1869        payloads: &[&[u8]],
1870    ) -> StreamResult<()> {
1871        self.validate_frame_id(stream_ref_id)?;
1872        if self.lock_state().direct_terminal {
1873            return self.handle_direct_sequenced_on_next_batch(first_seq_nr, payloads);
1874        }
1875        let mut state = self.lock_state();
1876        for (index, payload) in payloads.iter().enumerate() {
1877            if state.terminal.is_some() || state.done {
1878                break;
1879            }
1880            let seq_nr = first_seq_nr.checked_add(index as u64).ok_or_else(|| {
1881                StreamError::Failed("stream ref batch sequence overflow".to_owned())
1882            })?;
1883            let item = T::decode_stream_ref_payload_slice(payload)?;
1884            self.handle_decoded_on_next_locked(&mut state, seq_nr, item);
1885        }
1886        self.notify_changed();
1887        drop(state);
1888        Ok(())
1889    }
1890
    /// Delivers a batch of consecutive elements directly to the installed terminal consumer.
    ///
    /// `payloads[i]` carries sequence number `first_seq_nr + i`. The batch is validated as a
    /// whole under the state lock. Any of these fails the stream via `fail_consumer_locked`
    /// and finishes the consumer with the error:
    /// - the `direct_cancelled` flag is set;
    /// - the first sequence number differs from `expected_seq`;
    /// - the last sequence number is at or beyond `cumulative_demand`.
    ///
    /// Otherwise the consumer is taken out of the shared state, and each payload is decoded
    /// and passed to `on_item` with the lock released. The state is then reconciled under
    /// the lock again.
    ///
    /// Returns `Err` only when `first_seq_nr + payloads.len() - 1` overflows `u64`. Decode
    /// and consumer failures are reported to the consumer (and the remote side) instead of
    /// to the caller.
    fn handle_direct_sequenced_on_next_batch(
        &self,
        first_seq_nr: u64,
        payloads: &[&[u8]],
    ) -> StreamResult<()> {
        if payloads.is_empty() {
            return Ok(());
        }
        // Non-empty here, so `count - 1` below cannot underflow.
        let count = payloads.len() as u64;
        let last_seq_nr = first_seq_nr
            .checked_add(count - 1)
            .ok_or_else(|| StreamError::Failed("stream ref batch sequence overflow".to_owned()))?;

        // Validate the batch and take the consumer out of the shared state. The lock is
        // released when this block ends, so `on_item` below runs unlocked.
        let mut consumer = {
            let mut state = self.lock_state();
            // Elements arriving after termination are ignored.
            if state.terminal.is_some() || state.done {
                return Ok(());
            }
            // Downstream cancellation is observed through the shared flag. Fail the stream,
            // then finish the consumer after releasing the lock.
            if state
                .direct_cancelled
                .as_ref()
                .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
            {
                let error = StreamError::Cancelled;
                self.fail_consumer_locked(&mut state, error.clone());
                let consumer = state.direct_consumer.take();
                self.notify_changed();
                drop(state);
                if let Some(consumer) = consumer {
                    consumer.finish(Err(error));
                }
                return Ok(());
            }

            // The batch must start at the expected sequence number and end strictly below the
            // granted cumulative demand.
            let error = if first_seq_nr != state.expected_seq {
                Some(invalid_sequence_error(
                    state.expected_seq,
                    first_seq_nr,
                    "stream ref element",
                ))
            } else if last_seq_nr >= state.cumulative_demand {
                Some(StreamError::Failed(
                    "stream ref receive buffer overflowed demand window".to_owned(),
                ))
            } else {
                None
            };
            if let Some(error) = error {
                self.fail_consumer_locked(&mut state, error.clone());
                let consumer = state.direct_consumer.take();
                self.notify_changed();
                drop(state);
                if let Some(consumer) = consumer {
                    consumer.finish(Err(error));
                }
                return Ok(());
            }

            // Advance the expected sequence for the whole batch before delivering it.
            state.expected_seq = state.expected_seq.saturating_add(count);
            // NOTE(review): if no consumer is installed here, the batch is dropped even though
            // `expected_seq` has already been advanced. Confirm this cannot happen.
            match state.direct_consumer.take() {
                Some(consumer) => consumer,
                None => return Ok(()),
            }
        };

        // Decode and deliver in order, stopping at the first decode or consumer error.
        let mut consumed = 0_u64;
        let mut consume_result = Ok(());
        for payload in payloads {
            match T::decode_stream_ref_payload_slice(payload)
                .and_then(|item| consumer.on_item(item))
            {
                Ok(()) => consumed = consumed.saturating_add(1),
                Err(error) => {
                    consume_result = Err(error);
                    break;
                }
            }
        }

        // Re-lock and reconcile. A terminal signal recorded while the consumer ran unlocked
        // takes precedence over `consume_result`.
        let finish = {
            let mut state = self.lock_state();
            if let Some(terminal) = state.terminal.clone() {
                Some((consumer, terminal_result(terminal)))
            } else if state.done {
                Some((consumer, Err(StreamError::AbruptTermination)))
            } else {
                match consume_result {
                    Ok(()) => {
                        debug_assert_eq!(consumed, count);
                        // Account deliveries one at a time, so `next_demand` can emit a
                        // `CumulativeDemand` as soon as remaining credit reaches its threshold.
                        for _ in 0..count {
                            state.delivered = state.delivered.saturating_add(1);
                            if let Some(demand) = next_demand(&mut state, self.settings) {
                                state
                                    .outbound
                                    .push_back(StreamRefMessage::CumulativeDemand {
                                        seq_nr: demand,
                                    });
                            }
                        }
                        // Reinstall the consumer for the next delivery.
                        state.direct_consumer = Some(consumer);
                        None
                    }
                    Err(error) => {
                        self.fail_consumer_locked(&mut state, error.clone());
                        Some((consumer, Err(error)))
                    }
                }
            }
        };
        self.notify_changed();
        // Finish outside the state lock.
        if let Some((consumer, result)) = finish {
            consumer.finish(result);
        }
        Ok(())
    }
2006
2007    fn handle_decoded_on_next(
2008        &self,
2009        seq_nr: u64,
2010        item: T,
2011        demand_ceiling: u64,
2012    ) -> StreamResult<()> {
2013        if self.lock_state().direct_terminal {
2014            self.handle_direct_decoded_on_next(seq_nr, item, demand_ceiling);
2015        } else {
2016            let mut state = self.lock_state();
2017            self.handle_decoded_on_next_locked(&mut state, seq_nr, item);
2018            self.notify_changed();
2019            drop(state);
2020        }
2021        Ok(())
2022    }
2023
2024    fn handle_decoded_on_next_locked(&self, state: &mut ConsumerState<T>, seq_nr: u64, item: T) {
2025        if state.terminal.is_some() || state.done {
2026            return;
2027        }
2028        if seq_nr != state.expected_seq {
2029            let error = invalid_sequence_error(state.expected_seq, seq_nr, "stream ref element");
2030            self.fail_consumer_locked(state, error);
2031        } else if state.queue.len() >= self.settings.buffer_capacity() {
2032            self.fail_consumer_locked(
2033                state,
2034                StreamError::Failed(
2035                    "stream ref receive buffer overflowed demand window".to_owned(),
2036                ),
2037            );
2038        } else {
2039            state.expected_seq = state.expected_seq.saturating_add(1);
2040            state.queue.push_back(item);
2041        }
2042    }
2043
    /// Delivers one decoded element directly to the installed terminal consumer.
    ///
    /// Under the state lock it checks three things: the `direct_cancelled` flag, that
    /// `seq_nr == expected_seq`, and that `seq_nr` is below `demand_ceiling`. Any violation
    /// fails the stream via `fail_consumer_locked` and finishes the consumer with the error.
    /// On success the consumer is taken out of the shared state, `on_item` runs unlocked,
    /// and the state is reconciled under the lock again.
    ///
    /// NOTE(review): `demand_ceiling` comes from the caller and is not re-read from
    /// `state.cumulative_demand` under this lock. Callers presumably pass a snapshot of
    /// `cumulative_demand`; confirm it cannot be stale.
    fn handle_direct_decoded_on_next(&self, seq_nr: u64, item: T, demand_ceiling: u64) {
        // Validate and take the consumer out of the shared state. The lock is released when
        // this block ends, so `on_item` below runs unlocked.
        let mut consumer = {
            let mut state = self.lock_state();
            // Elements arriving after termination are ignored.
            if state.terminal.is_some() || state.done {
                return;
            }
            // Downstream cancellation is observed through the shared flag. Fail the stream,
            // then finish the consumer after releasing the lock.
            if state
                .direct_cancelled
                .as_ref()
                .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
            {
                let error = StreamError::Cancelled;
                self.fail_consumer_locked(&mut state, error.clone());
                let consumer = state.direct_consumer.take();
                self.notify_changed();
                drop(state);
                if let Some(consumer) = consumer {
                    consumer.finish(Err(error));
                }
                return;
            }
            let error = if seq_nr != state.expected_seq {
                Some(invalid_sequence_error(
                    state.expected_seq,
                    seq_nr,
                    "stream ref element",
                ))
            } else if seq_nr >= demand_ceiling {
                Some(StreamError::Failed(
                    "stream ref receive buffer overflowed demand window".to_owned(),
                ))
            } else {
                None
            };
            if let Some(error) = error {
                self.fail_consumer_locked(&mut state, error.clone());
                let consumer = state.direct_consumer.take();
                self.notify_changed();
                drop(state);
                if let Some(consumer) = consumer {
                    consumer.finish(Err(error));
                }
                return;
            }
            state.expected_seq = state.expected_seq.saturating_add(1);
            // NOTE(review): if no consumer is installed here, the element is dropped even
            // though `expected_seq` has already been advanced. Confirm this cannot happen.
            match state.direct_consumer.take() {
                Some(consumer) => consumer,
                None => return,
            }
        };

        let consume_result = consumer.on_item(item);
        // Re-lock and reconcile. A terminal signal recorded while the consumer ran unlocked
        // takes precedence over `consume_result`.
        let finish = {
            let mut state = self.lock_state();
            if let Some(terminal) = state.terminal.clone() {
                Some((consumer, terminal_result(terminal)))
            } else if state.done {
                Some((consumer, Err(StreamError::AbruptTermination)))
            } else {
                match consume_result {
                    Ok(()) => {
                        state.delivered = state.delivered.saturating_add(1);
                        if let Some(demand) = next_demand(&mut state, self.settings) {
                            state
                                .outbound
                                .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
                        }
                        // Reinstall the consumer for the next delivery.
                        state.direct_consumer = Some(consumer);
                        None
                    }
                    Err(error) => {
                        self.fail_consumer_locked(&mut state, error.clone());
                        Some((consumer, Err(error)))
                    }
                }
            }
        };
        self.notify_changed();
        // Finish outside the state lock.
        if let Some((consumer, result)) = finish {
            consumer.finish(result);
        }
    }
2126
2127    fn handle_remote_completed(
2128        &self,
2129        seq_nr: u64,
2130    ) -> Option<(Box<dyn TerminalSinkConsumerDyn<T>>, StreamResult<()>)> {
2131        let mut finish = None;
2132        let mut state = self.lock_state();
2133        if state.terminal.is_none() && !state.done {
2134            if seq_nr != state.expected_seq {
2135                state.queue.clear();
2136                let error =
2137                    invalid_sequence_error(state.expected_seq, seq_nr, "stream ref completion");
2138                state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2139                finish = state
2140                    .direct_consumer
2141                    .take()
2142                    .map(|consumer| (consumer, Err(error)));
2143            } else {
2144                state.terminal = Some(ConsumerTerminal::Complete);
2145                finish = state
2146                    .direct_consumer
2147                    .take()
2148                    .map(|consumer| (consumer, Ok(())));
2149            }
2150            state.outbound.push_back(StreamRefMessage::Ack);
2151            state.finish_after_outbound_ack = true;
2152        }
2153        self.notify_changed();
2154        drop(state);
2155        finish
2156    }
2157
2158    fn handle_remote_failure(
2159        &self,
2160        error: StreamError,
2161    ) -> Option<(Box<dyn TerminalSinkConsumerDyn<T>>, StreamResult<()>)> {
2162        let mut state = self.lock_state();
2163        let mut finish = None;
2164        if state.terminal.is_none() && !state.done {
2165            state.queue.clear();
2166            state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2167            finish = state
2168                .direct_consumer
2169                .take()
2170                .map(|consumer| (consumer, Err(error)));
2171            state.outbound.push_back(StreamRefMessage::Ack);
2172            state.finish_after_outbound_ack = true;
2173        }
2174        self.notify_changed();
2175        drop(state);
2176        finish
2177    }
2178
2179    fn fail_consumer_locked(&self, state: &mut ConsumerState<T>, error: StreamError) {
2180        state.queue.clear();
2181        state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2182        state
2183            .outbound
2184            .push_back(StreamRefMessage::RemoteStreamFailure {
2185                cause: failure_cause(&error),
2186            });
2187        state.waiting_cancel_ack = true;
2188    }
2189
2190    fn cancel_from_downstream(&self) {
2191        let mut finish = None;
2192        let mut state = self.lock_state();
2193        if state.terminal.is_none() && !state.done {
2194            let seq_nr = state.expected_seq;
2195            let error = StreamError::Cancelled;
2196            state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2197            finish = state
2198                .direct_consumer
2199                .take()
2200                .map(|consumer| (consumer, Err(error)));
2201            state
2202                .outbound
2203                .push_back(StreamRefMessage::RemoteStreamCompleted { seq_nr });
2204            state.waiting_cancel_ack = true;
2205        }
2206        self.notify_changed();
2207        drop(state);
2208        if let Some((consumer, result)) = finish {
2209            consumer.finish(result);
2210        }
2211    }
2212
2213    fn fail_connection(&self, error: StreamError) {
2214        let mut finish = None;
2215        let mut state = self.lock_state();
2216        if state.terminal.is_none() {
2217            state.queue.clear();
2218            state.terminal = Some(ConsumerTerminal::Error(error.clone()));
2219            finish = state
2220                .direct_consumer
2221                .take()
2222                .map(|consumer| (consumer, Err(error.clone())));
2223        }
2224        state.done = true;
2225        self.notify_changed();
2226        drop(state);
2227        if let Some((consumer, result)) = finish {
2228            consumer.finish(result);
2229        }
2230    }
2231
2232    fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
2233        if stream_ref_id == self.stream_ref_id {
2234            Ok(())
2235        } else {
2236            Err(StreamError::Failed(format!(
2237                "stream ref id mismatch: expected {}, got {}",
2238                self.stream_ref_id, stream_ref_id
2239            )))
2240        }
2241    }
2242}
2243
2244struct ConsumerStream<T>
2245where
2246    T: StreamRefPayload,
2247{
2248    shared: Arc<ConsumerShared<T>>,
2249    terminated: bool,
2250}
2251
2252impl<T> ConsumerStream<T>
2253where
2254    T: StreamRefPayload,
2255{
2256    fn next_item(&mut self) -> StreamResult<Option<T>> {
2257        if self.terminated {
2258            return Ok(None);
2259        }
2260        loop {
2261            let mut state = self.shared.lock_state();
2262            if let Some(item) = state.queue.pop_front() {
2263                state.delivered = state.delivered.saturating_add(1);
2264                if let Some(demand) = next_demand(&mut state, self.shared.settings) {
2265                    state
2266                        .outbound
2267                        .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
2268                    self.shared.notify_changed();
2269                }
2270                return Ok(Some(item));
2271            }
2272
2273            if let Some(terminal) = state.terminal.clone() {
2274                self.terminated = true;
2275                return match terminal {
2276                    ConsumerTerminal::Complete => Ok(None),
2277                    ConsumerTerminal::Error(error) => Err(error),
2278                };
2279            }
2280
2281            let next = wait_unpoison(&self.shared.changed, state);
2282            drop(next);
2283        }
2284    }
2285
2286    fn close(&mut self) {
2287        if !self.terminated {
2288            self.shared.cancel_from_downstream();
2289            self.terminated = true;
2290        }
2291    }
2292}
2293
2294impl<T> Iterator for ConsumerStream<T>
2295where
2296    T: StreamRefPayload,
2297{
2298    type Item = StreamResult<T>;
2299
2300    fn next(&mut self) -> Option<Self::Item> {
2301        match self.next_item() {
2302            Ok(Some(item)) => Some(Ok(item)),
2303            Ok(None) => None,
2304            Err(error) => Some(Err(error)),
2305        }
2306    }
2307}
2308
2309impl<T> Drop for ConsumerStream<T>
2310where
2311    T: StreamRefPayload,
2312{
2313    fn drop(&mut self) {
2314        self.close();
2315    }
2316}
2317
2318fn next_demand<T>(state: &mut ConsumerState<T>, settings: StreamRefSettings) -> Option<u64> {
2319    // Akka redelivers CumulativeDemand because some remoting carriers are
2320    // lossy. This seam is driven by reliable ordered carriers such as one QUIC
2321    // bidirectional stream, so a larger cumulative ceiling is sent once.
2322    if state.terminal.is_some() {
2323        return None;
2324    }
2325    let remaining_credit = state.cumulative_demand.saturating_sub(state.delivered);
2326    if state.cumulative_demand != 0 && remaining_credit > demand_replenish_threshold(settings) {
2327        return None;
2328    }
2329    let target = state
2330        .delivered
2331        .saturating_add(settings.buffer_capacity() as u64);
2332    if state.cumulative_demand >= target {
2333        return None;
2334    }
2335    state.cumulative_demand = target;
2336    Some(target)
2337}
2338
2339fn demand_replenish_threshold(settings: StreamRefSettings) -> u64 {
2340    (settings.buffer_capacity() as u64) / 2
2341}
2342
2343fn terminal_result(terminal: ConsumerTerminal) -> StreamResult<()> {
2344    match terminal {
2345        ConsumerTerminal::Complete => Ok(()),
2346        ConsumerTerminal::Error(error) => Err(error),
2347    }
2348}
2349
2350fn failure_cause(error: &StreamError) -> Vec<u8> {
2351    match error {
2352        StreamError::Failed(message) => message.clone().into_bytes(),
2353        other => other.to_string().into_bytes(),
2354    }
2355}
2356
2357fn subscription_timeout_error(side: &str) -> StreamError {
2358    StreamError::Failed(format!(
2359        "{side} remote side did not subscribe within subscription timeout"
2360    ))
2361}
2362
2363fn invalid_sequence_error(expected: u64, got: u64, context: &str) -> StreamError {
2364    StreamError::Failed(format!(
2365        "{context} sequence gap: expected sequence {expected}, got {got}"
2366    ))
2367}
2368
2369fn deadline_from_now(timeout: Duration) -> Instant {
2370    Instant::now()
2371        .checked_add(timeout)
2372        .unwrap_or_else(far_future)
2373}
2374
/// A deadline roughly one (365-day) year from now, used as an effectively unbounded timeout.
fn far_future() -> Instant {
    const ONE_YEAR: Duration = Duration::from_secs(60 * 60 * 24 * 365);
    Instant::now() + ONE_YEAR
}
2378
/// Waits on `condvar` for at most `timeout`.
///
/// If the mutex was poisoned by a panicking holder, the guard is recovered instead of
/// propagating the poison.
fn wait_timeout_unpoison<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    match condvar.wait_timeout(guard, timeout) {
        Ok(woken) => woken,
        Err(poisoned) => poisoned.into_inner(),
    }
}
2388
/// Waits on `condvar` until notified (or spuriously woken).
///
/// If the mutex was poisoned by a panicking holder, the guard is recovered instead of
/// propagating the poison.
fn wait_unpoison<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
2394
2395#[cfg(test)]
2396mod tests {
2397    use std::time::Duration;
2398
2399    use super::*;
2400    use crate::{Sink, Source, StreamRefs};
2401
2402    fn short_settings() -> StreamRefSettings {
2403        StreamRefSettings::default()
2404            .with_buffer_capacity(4)
2405            .with_subscription_timeout(Duration::from_millis(50))
2406    }
2407
2408    #[test]
2409    fn protobuf_frame_round_trip() {
2410        let frame = StreamRefFrame::new(
2411            StreamRefId::from_u128(42),
2412            StreamRefMessage::SequencedOnNext {
2413                seq_nr: 7,
2414                payload: StreamRefPayloadBytes {
2415                    bytes: 99_u64.encode_stream_ref_payload(),
2416                },
2417            },
2418        );
2419
2420        let decoded = StreamRefFrame::decode(&frame.encode_to_vec()).unwrap();
2421        assert_eq!(decoded, frame);
2422    }
2423
2424    #[test]
2425    fn producer_consumer_seam_streams_with_low_watermark_demand() {
2426        let id = StreamRefId::from_u128(1);
2427        let settings = short_settings();
2428        let source_ref = Source::from_iter(0_u64..10)
2429            .run_with(StreamRefs::source_ref_with_settings(settings))
2430            .unwrap();
2431        let producer = StreamRefProtoProducer::from_source_ref(source_ref, id, settings).unwrap();
2432        let consumer = StreamRefProtoConsumer::<u64>::new(id, settings);
2433        let consumer_source = consumer.source();
2434
2435        let producer_thread = std::thread::spawn({
2436            let producer = producer.clone();
2437            let consumer = consumer.clone();
2438            move || {
2439                while let Some(frame) = producer.next_frame() {
2440                    consumer.handle_frame(frame?)?;
2441                }
2442                Ok::<_, StreamError>(())
2443            }
2444        });
2445        let consumer_thread = std::thread::spawn({
2446            let producer = producer.clone();
2447            let consumer = consumer.clone();
2448            move || {
2449                while let Some(frame) = consumer.next_frame() {
2450                    producer.handle_frame(frame?)?;
2451                }
2452                Ok::<_, StreamError>(())
2453            }
2454        });
2455
2456        assert_eq!(
2457            consumer_source.run_collect().unwrap(),
2458            (0_u64..10).collect::<Vec<_>>()
2459        );
2460        producer_thread.join().unwrap().unwrap();
2461        consumer_thread.join().unwrap().unwrap();
2462    }
2463
2464    #[test]
2465    fn strict_sequence_gap_fails_consumer_and_sends_failure() {
2466        let id = StreamRefId::from_u128(2);
2467        let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
2468        let source = consumer
2469            .source()
2470            .run_with(crate::testkit::TestSink::probe())
2471            .unwrap();
2472        source.request(1);
2473        consumer.next_frame().unwrap().unwrap();
2474        consumer.next_frame().unwrap().unwrap();
2475
2476        consumer
2477            .handle_frame(StreamRefFrame::new(
2478                id,
2479                StreamRefMessage::SequencedOnNext {
2480                    seq_nr: 1,
2481                    payload: StreamRefPayloadBytes {
2482                        bytes: 1_u64.encode_stream_ref_payload(),
2483                    },
2484                },
2485            ))
2486            .unwrap();
2487
2488        let outbound = consumer.next_frame().unwrap().unwrap();
2489        assert!(matches!(
2490            outbound.message,
2491            StreamRefMessage::RemoteStreamFailure { .. }
2492        ));
2493        assert!(matches!(source.expect_error(), StreamError::Failed(_)));
2494    }
2495
2496    #[test]
2497    fn direct_terminal_sequence_gap_uses_shared_failure_path() {
2498        let id = StreamRefId::from_u128(23);
2499        let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
2500        let completion = consumer
2501            .source()
2502            .run_with(Sink::fold(0_u64, |acc, item| acc + item))
2503            .unwrap();
2504        consumer.next_frame().unwrap().unwrap();
2505        consumer.next_frame().unwrap().unwrap();
2506
2507        consumer
2508            .handle_frame(StreamRefFrame::new(
2509                id,
2510                StreamRefMessage::SequencedOnNext {
2511                    seq_nr: 1,
2512                    payload: StreamRefPayloadBytes {
2513                        bytes: 1_u64.encode_stream_ref_payload(),
2514                    },
2515                },
2516            ))
2517            .unwrap();
2518
2519        let outbound = consumer.next_frame().unwrap().unwrap();
2520        assert!(matches!(
2521            outbound.message,
2522            StreamRefMessage::RemoteStreamFailure { .. }
2523        ));
2524        assert!(
2525            matches!(completion.wait(), Err(StreamError::Failed(message)) if message.contains("sequence gap"))
2526        );
2527    }
2528
2529    #[test]
2530    fn direct_terminal_batch_over_demand_ceiling_fails_consumer() {
2531        let id = StreamRefId::from_u128(24);
2532        let settings = short_settings();
2533        let consumer = StreamRefProtoConsumer::<u64>::new(id, settings);
2534        let completion = consumer
2535            .source()
2536            .run_with(Sink::fold(0_u64, |acc, item| acc + item))
2537            .unwrap();
2538        consumer.next_frame().unwrap().unwrap();
2539        consumer.next_frame().unwrap().unwrap();
2540
2541        let payloads = (0_u64..=settings.buffer_capacity() as u64)
2542            .map(u64::encode_stream_ref_payload)
2543            .collect::<Vec<_>>();
2544        let payload_slices = payloads.iter().map(Vec::as_slice).collect::<Vec<&[u8]>>();
2545        consumer
2546            .handle_sequenced_on_next_batch(id, 0, &payload_slices)
2547            .unwrap();
2548
2549        let mut saw_failure = false;
2550        for _ in 0..4 {
2551            let Some(outbound) = consumer.next_frame() else {
2552                break;
2553            };
2554            if matches!(
2555                outbound.unwrap().message,
2556                StreamRefMessage::RemoteStreamFailure { .. }
2557            ) {
2558                saw_failure = true;
2559                break;
2560            }
2561        }
2562        assert!(saw_failure);
2563        assert!(
2564            matches!(completion.wait(), Err(StreamError::Failed(message)) if message.contains("demand window"))
2565        );
2566    }
2567
2568    #[test]
2569    fn producer_batches_ready_elements_and_preserves_completion_order() {
2570        let id = StreamRefId::from_u128(20);
2571        let settings = StreamRefSettings::default().with_buffer_capacity(8);
2572        let producer =
2573            StreamRefProtoProducer::from_source(Source::from_iter(0_u64..6), id, settings).unwrap();
2574        producer
2575            .handle_frame(StreamRefFrame::new(
2576                id,
2577                StreamRefMessage::CumulativeDemand { seq_nr: 8 },
2578            ))
2579            .unwrap();
2580
2581        let first = producer.next_outbound(4, usize::MAX).unwrap().unwrap();
2582        let StreamRefOutbound::SequencedBatch(first) = first else {
2583            panic!("expected first data batch");
2584        };
2585        assert_eq!(first.first_seq_nr(), 0);
2586        assert_eq!(first.count(), 4);
2587        for index in 0..first.count() {
2588            assert_eq!(
2589                u64::decode_stream_ref_payload_slice(first.payload(index)).unwrap(),
2590                index as u64
2591            );
2592        }
2593
2594        let second = producer.next_outbound(4, usize::MAX).unwrap().unwrap();
2595        let StreamRefOutbound::SequencedBatch(second) = second else {
2596            panic!("expected second data batch");
2597        };
2598        assert_eq!(second.first_seq_nr(), 4);
2599        assert_eq!(second.count(), 2);
2600
2601        let completion = producer.next_outbound(4, usize::MAX).unwrap().unwrap();
2602        assert!(matches!(
2603            completion,
2604            StreamRefOutbound::Frame(StreamRefFrame {
2605                message: StreamRefMessage::RemoteStreamCompleted { seq_nr: 6 },
2606                ..
2607            })
2608        ));
2609    }
2610
2611    #[test]
2612    fn consumer_batch_ingress_preserves_order() {
2613        let id = StreamRefId::from_u128(21);
2614        let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
2615        let probe = consumer
2616            .source()
2617            .run_with(crate::testkit::TestSink::probe())
2618            .unwrap();
2619        probe.request(3);
2620        consumer.next_frame().unwrap().unwrap();
2621        consumer.next_frame().unwrap().unwrap();
2622
2623        let payloads = [10_u64, 11, 12]
2624            .into_iter()
2625            .map(u64::encode_stream_ref_payload)
2626            .collect::<Vec<_>>();
2627        let payload_slices = payloads.iter().map(Vec::as_slice).collect::<Vec<&[u8]>>();
2628        consumer
2629            .handle_sequenced_on_next_batch(id, 0, &payload_slices)
2630            .unwrap();
2631
2632        assert_eq!(probe.expect_next(), 10);
2633        assert_eq!(probe.expect_next(), 11);
2634        assert_eq!(probe.expect_next(), 12);
2635    }
2636
2637    #[test]
2638    fn consumer_batch_sequence_gap_uses_shared_failure_path() {
2639        let id = StreamRefId::from_u128(22);
2640        let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
2641        let source = consumer
2642            .source()
2643            .run_with(crate::testkit::TestSink::probe())
2644            .unwrap();
2645        source.request(1);
2646        consumer.next_frame().unwrap().unwrap();
2647        consumer.next_frame().unwrap().unwrap();
2648
2649        let payload = 1_u64.encode_stream_ref_payload();
2650        consumer
2651            .handle_sequenced_on_next_batch(id, 1, &[payload.as_slice()])
2652            .unwrap();
2653
2654        let outbound = consumer.next_frame().unwrap().unwrap();
2655        assert!(matches!(
2656            outbound.message,
2657            StreamRefMessage::RemoteStreamFailure { .. }
2658        ));
2659        assert!(matches!(source.expect_error(), StreamError::Failed(_)));
2660    }
2661
2662    #[test]
2663    fn producer_times_out_without_first_demand() {
2664        let producer = StreamRefProtoProducer::from_source(
2665            Source::repeat(1_u64),
2666            StreamRefId::from_u128(3),
2667            short_settings(),
2668        )
2669        .unwrap();
2670
2671        let error = producer.next_frame().unwrap().unwrap_err();
2672        assert!(matches!(error, StreamError::Failed(message) if message.contains("first demand")));
2673    }
2674
2675    #[test]
2676    fn demand_redelivery_is_not_required_by_reliable_carriers() {
2677        // Akka redelivers CumulativeDemand because Artery/Aeron may lose messages.
2678        // The protobuf seam is intended for reliable ordered carriers such as a
2679        // single QUIC bidirectional stream, so each cumulative ceiling is sent
2680        // once and remains valid until a larger ceiling replaces it.
2681        assert_eq!(
2682            StreamRefSettings::default().demand_redelivery_interval(),
2683            Duration::from_secs(1)
2684        );
2685    }
2686}