Skip to main content

datum/actor/
stream_ref_proto.rs

1//! Protobuf StreamRefs protocol and transport-agnostic endpoint drivers.
2//!
3//! This module mirrors Akka's StreamRefs protocol without embedding an actor
4//! transport. A carrier feeds inbound protobuf frames into one endpoint and
5//! drains outbound frames from the same endpoint. `datum-net` uses this seam to
6//! carry StreamRefs over a reliable, ordered QUIC bidirectional stream.
7
8use std::{
9    any::Any,
10    collections::VecDeque,
11    fmt,
12    sync::{
13        Arc, Condvar, Mutex, MutexGuard,
14        atomic::{AtomicU64, Ordering},
15    },
16    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
17};
18
19use crate::stream::{
20    BoxStream, Materializer, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult,
21};
22use futures::channel::oneshot;
23use prost::Message as ProstMessage;
24
25use super::{SourceRef, StreamRefSettings};
26
/// Process-wide counter, starting at 1, that supplies the sequence component
/// of ids minted by `StreamRefId::new`.
static STREAM_REF_PROTO_ID: AtomicU64 = AtomicU64::new(1);
28
29/// Element payload codec used by the protobuf StreamRefs transport seam.
30///
31/// The built-in impls use the same big-endian primitive encodings as Ractor's
32/// `BytesConvertable`, but the trait is owned by `datum-core` so the protobuf
33/// seam does not depend on the `cluster` feature.
34pub trait StreamRefPayload: Send + 'static {
35    fn encode_stream_ref_payload(self) -> Vec<u8>;
36
37    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self>
38    where
39        Self: Sized;
40}
41
42macro_rules! impl_stream_ref_payload_numeric {
43    ($($ty:ty),* $(,)?) => {
44        $(
45            impl StreamRefPayload for $ty {
46                fn encode_stream_ref_payload(self) -> Vec<u8> {
47                    self.to_be_bytes().to_vec()
48                }
49
50                fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
51                    let data: [u8; std::mem::size_of::<Self>()] =
52                        bytes.as_slice().try_into().map_err(|_| {
53                            StreamError::Failed(format!(
54                                "invalid {} stream ref payload length: {}",
55                                stringify!($ty),
56                                bytes.len()
57                            ))
58                        })?;
59                    Ok(Self::from_be_bytes(data))
60                }
61            }
62        )*
63    };
64}
65
66impl_stream_ref_payload_numeric!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, f32, f64);
67
68impl StreamRefPayload for bool {
69    fn encode_stream_ref_payload(self) -> Vec<u8> {
70        vec![u8::from(self)]
71    }
72
73    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
74        match bytes.as_slice() {
75            [0] => Ok(false),
76            [1] => Ok(true),
77            _ => Err(StreamError::Failed(
78                "invalid bool stream ref payload".to_owned(),
79            )),
80        }
81    }
82}
83
84impl StreamRefPayload for String {
85    fn encode_stream_ref_payload(self) -> Vec<u8> {
86        self.into_bytes()
87    }
88
89    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
90        String::from_utf8(bytes)
91            .map_err(|error| StreamError::Failed(format!("invalid UTF-8 payload: {error}")))
92    }
93}
94
95impl StreamRefPayload for Vec<u8> {
96    fn encode_stream_ref_payload(self) -> Vec<u8> {
97        self
98    }
99
100    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
101        Ok(bytes)
102    }
103}
104
/// Identifier of one stream ref, scoped to a single transport connection.
///
/// Wraps a 128-bit value that is exchanged on the wire as 16 big-endian bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamRefId(u128);
108
109impl StreamRefId {
110    /// Generates a process-local id suitable for a new connection-scoped ref.
111    #[must_use]
112    pub fn new() -> Self {
113        let sequence = STREAM_REF_PROTO_ID.fetch_add(1, Ordering::Relaxed) as u128;
114        let timestamp = SystemTime::now()
115            .duration_since(UNIX_EPOCH)
116            .map(|duration| duration.as_nanos())
117            .unwrap_or_default();
118        let pid = std::process::id() as u128;
119        Self(timestamp ^ (pid << 32) ^ sequence)
120    }
121
122    /// Constructs an id from a stable numeric value, primarily for tests and
123    /// single-stream carriers that reserve a well-known id.
124    #[must_use]
125    pub const fn from_u128(value: u128) -> Self {
126        Self(value)
127    }
128
129    #[must_use]
130    pub const fn as_u128(self) -> u128 {
131        self.0
132    }
133
134    #[must_use]
135    pub fn to_bytes(self) -> [u8; 16] {
136        self.0.to_be_bytes()
137    }
138
139    pub fn from_bytes(bytes: &[u8]) -> StreamResult<Self> {
140        let value: [u8; 16] = bytes.try_into().map_err(|_| {
141            StreamError::Failed("stream ref id must be exactly 16 bytes".to_owned())
142        })?;
143        Ok(Self(u128::from_be_bytes(value)))
144    }
145}
146
147impl Default for StreamRefId {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153impl fmt::Display for StreamRefId {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        write!(f, "{:032x}", self.0)
156    }
157}
158
/// Opaque element payload, mirroring Akka's `Payload` message but without
/// Akka serializer ids or manifests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamRefPayloadBytes {
    /// Encoded element bytes, carried as the wire payload's
    /// `enclosed_message` field.
    pub bytes: Vec<u8>,
}
165
166/// Transport-agnostic StreamRefs protocol messages.
167#[derive(Debug, Clone, PartialEq, Eq)]
168pub enum StreamRefMessage {
169    OnSubscribeHandshake,
170    CumulativeDemand {
171        seq_nr: u64,
172    },
173    SequencedOnNext {
174        seq_nr: u64,
175        payload: StreamRefPayloadBytes,
176    },
177    RemoteStreamCompleted {
178        seq_nr: u64,
179    },
180    RemoteStreamFailure {
181        cause: Vec<u8>,
182    },
183    Ack,
184}
185
186impl StreamRefMessage {
187    #[must_use]
188    pub fn failure_text(&self) -> Option<String> {
189        match self {
190            Self::RemoteStreamFailure { cause } => {
191                Some(String::from_utf8_lossy(cause).into_owned())
192            }
193            _ => None,
194        }
195    }
196
197    fn is_ack(&self) -> bool {
198        matches!(self, Self::Ack)
199    }
200}
201
202/// One protobuf frame tagged with a connection-scoped stream-ref id.
203#[derive(Debug, Clone, PartialEq, Eq)]
204pub struct StreamRefFrame {
205    pub stream_ref_id: StreamRefId,
206    pub message: StreamRefMessage,
207}
208
209impl StreamRefFrame {
210    #[must_use]
211    pub fn new(stream_ref_id: StreamRefId, message: StreamRefMessage) -> Self {
212        Self {
213            stream_ref_id,
214            message,
215        }
216    }
217
218    #[must_use]
219    pub fn encode_to_vec(&self) -> Vec<u8> {
220        self.to_wire().encode_to_vec()
221    }
222
223    pub fn decode(bytes: &[u8]) -> StreamResult<Self> {
224        Self::from_wire(WireStreamRefFrame::decode(bytes).map_err(|error| {
225            StreamError::Failed(format!("invalid stream ref protobuf frame: {error}"))
226        })?)
227    }
228
229    fn to_wire(&self) -> WireStreamRefFrame {
230        WireStreamRefFrame {
231            stream_ref_id: self.stream_ref_id.to_bytes().to_vec(),
232            message: Some(match &self.message {
233                StreamRefMessage::OnSubscribeHandshake => {
234                    wire_stream_ref_frame::Message::OnSubscribeHandshake(
235                        WireOnSubscribeHandshake {},
236                    )
237                }
238                StreamRefMessage::CumulativeDemand { seq_nr } => {
239                    wire_stream_ref_frame::Message::CumulativeDemand(WireCumulativeDemand {
240                        seq_nr: *seq_nr,
241                    })
242                }
243                StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
244                    wire_stream_ref_frame::Message::SequencedOnNext(WireSequencedOnNext {
245                        seq_nr: *seq_nr,
246                        payload: Some(WirePayload {
247                            enclosed_message: payload.bytes.clone(),
248                        }),
249                    })
250                }
251                StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
252                    wire_stream_ref_frame::Message::RemoteStreamCompleted(
253                        WireRemoteStreamCompleted { seq_nr: *seq_nr },
254                    )
255                }
256                StreamRefMessage::RemoteStreamFailure { cause } => {
257                    wire_stream_ref_frame::Message::RemoteStreamFailure(WireRemoteStreamFailure {
258                        cause: cause.clone(),
259                    })
260                }
261                StreamRefMessage::Ack => wire_stream_ref_frame::Message::Ack(WireAck {}),
262            }),
263        }
264    }
265
266    fn from_wire(wire: WireStreamRefFrame) -> StreamResult<Self> {
267        let stream_ref_id = StreamRefId::from_bytes(&wire.stream_ref_id)?;
268        let message = match wire.message.ok_or_else(|| {
269            StreamError::Failed("stream ref protobuf frame has no message".to_owned())
270        })? {
271            wire_stream_ref_frame::Message::OnSubscribeHandshake(_) => {
272                StreamRefMessage::OnSubscribeHandshake
273            }
274            wire_stream_ref_frame::Message::CumulativeDemand(message) => {
275                StreamRefMessage::CumulativeDemand {
276                    seq_nr: message.seq_nr,
277                }
278            }
279            wire_stream_ref_frame::Message::SequencedOnNext(message) => {
280                let payload = message.payload.ok_or_else(|| {
281                    StreamError::Failed("SequencedOnNext missing payload".to_owned())
282                })?;
283                StreamRefMessage::SequencedOnNext {
284                    seq_nr: message.seq_nr,
285                    payload: StreamRefPayloadBytes {
286                        bytes: payload.enclosed_message,
287                    },
288                }
289            }
290            wire_stream_ref_frame::Message::RemoteStreamCompleted(message) => {
291                StreamRefMessage::RemoteStreamCompleted {
292                    seq_nr: message.seq_nr,
293                }
294            }
295            wire_stream_ref_frame::Message::RemoteStreamFailure(message) => {
296                StreamRefMessage::RemoteStreamFailure {
297                    cause: message.cause,
298                }
299            }
300            wire_stream_ref_frame::Message::Ack(_) => StreamRefMessage::Ack,
301        };
302        Ok(Self {
303            stream_ref_id,
304            message,
305        })
306    }
307}
308
309#[derive(Clone, PartialEq, ProstMessage)]
310struct WireStreamRefFrame {
311    #[prost(bytes = "vec", tag = "1")]
312    stream_ref_id: Vec<u8>,
313    #[prost(oneof = "wire_stream_ref_frame::Message", tags = "2, 3, 4, 5, 6, 7")]
314    message: Option<wire_stream_ref_frame::Message>,
315}
316
317mod wire_stream_ref_frame {
318    #[derive(Clone, PartialEq, prost::Oneof)]
319    pub enum Message {
320        #[prost(message, tag = "2")]
321        OnSubscribeHandshake(super::WireOnSubscribeHandshake),
322        #[prost(message, tag = "3")]
323        CumulativeDemand(super::WireCumulativeDemand),
324        #[prost(message, tag = "4")]
325        SequencedOnNext(super::WireSequencedOnNext),
326        #[prost(message, tag = "5")]
327        RemoteStreamCompleted(super::WireRemoteStreamCompleted),
328        #[prost(message, tag = "6")]
329        RemoteStreamFailure(super::WireRemoteStreamFailure),
330        #[prost(message, tag = "7")]
331        Ack(super::WireAck),
332    }
333}
334
335#[derive(Clone, PartialEq, ProstMessage)]
336struct WirePayload {
337    #[prost(bytes = "vec", tag = "1")]
338    enclosed_message: Vec<u8>,
339}
340
341#[derive(Clone, PartialEq, ProstMessage)]
342struct WireOnSubscribeHandshake {}
343
344#[derive(Clone, PartialEq, ProstMessage)]
345struct WireCumulativeDemand {
346    #[prost(uint64, tag = "1")]
347    seq_nr: u64,
348}
349
350#[derive(Clone, PartialEq, ProstMessage)]
351struct WireSequencedOnNext {
352    #[prost(uint64, tag = "1")]
353    seq_nr: u64,
354    #[prost(message, optional, tag = "2")]
355    payload: Option<WirePayload>,
356}
357
358#[derive(Clone, PartialEq, ProstMessage)]
359struct WireRemoteStreamFailure {
360    #[prost(bytes = "vec", tag = "1")]
361    cause: Vec<u8>,
362}
363
364#[derive(Clone, PartialEq, ProstMessage)]
365struct WireRemoteStreamCompleted {
366    #[prost(uint64, tag = "1")]
367    seq_nr: u64,
368}
369
370#[derive(Clone, PartialEq, ProstMessage)]
371struct WireAck {}
372
373/// Common interface for external carriers that pump protobuf frames.
374pub trait StreamRefProtoEndpoint: Clone + Send + Sync + 'static {
375    fn stream_ref_id(&self) -> StreamRefId;
376    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>>;
377    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()>;
378    fn fail_connection(&self, error: StreamError);
379}
380
381/// Producer-side endpoint for a local `SourceRef` or `Source`.
382///
383/// Feed inbound `OnSubscribeHandshake`, `CumulativeDemand`, cancellation, and
384/// Ack frames with [`StreamRefProtoEndpoint::handle_frame`]. Drain outbound
385/// `SequencedOnNext`, completion, failure, and Ack frames with
386/// [`StreamRefProtoEndpoint::next_frame`].
387pub struct StreamRefProtoProducer<T>
388where
389    T: StreamRefPayload,
390{
391    shared: Arc<ProducerShared<T>>,
392}
393
394impl<T> Clone for StreamRefProtoProducer<T>
395where
396    T: StreamRefPayload,
397{
398    fn clone(&self) -> Self {
399        Self {
400            shared: Arc::clone(&self.shared),
401        }
402    }
403}
404
405impl<T> StreamRefProtoProducer<T>
406where
407    T: StreamRefPayload,
408{
409    pub fn from_source_ref(
410        source_ref: SourceRef<T>,
411        stream_ref_id: StreamRefId,
412        settings: StreamRefSettings,
413    ) -> StreamResult<Self> {
414        Self::from_source(
415            super::stream_ref::proto_source(&source_ref),
416            stream_ref_id,
417            settings,
418        )
419    }
420
421    pub fn from_source<Mat>(
422        source: Source<T, Mat>,
423        stream_ref_id: StreamRefId,
424        settings: StreamRefSettings,
425    ) -> StreamResult<Self>
426    where
427        Mat: Send + 'static,
428    {
429        let materializer = Materializer::new();
430        let (input, materialized) = Arc::clone(&source.factory).create(&materializer)?;
431        Ok(Self {
432            shared: Arc::new(ProducerShared {
433                stream_ref_id,
434                settings,
435                input: Mutex::new(Some(input)),
436                state: Mutex::new(ProducerState {
437                    partner_seen: false,
438                    cumulative_demand: 0,
439                    sent: 0,
440                    terminal_sent: false,
441                    waiting_for_ack: false,
442                    ack_deadline: None,
443                    stopped: None,
444                    ack_queued: false,
445                    done: false,
446                    input_attached: true,
447                    terminal_result: None,
448                }),
449                changed: Condvar::new(),
450                completion: Mutex::new(None),
451                _materializer: materializer,
452                _materialized: Mutex::new(Some(Box::new(materialized))),
453            }),
454        })
455    }
456
457    /// Creates a producer with no attached input, for the SinkRef sender side.
458    ///
459    /// The input stream is attached later by materializing the [`Sink`]
460    /// returned from [`StreamRefProtoProducer::sink`]. Until then, the
461    /// producer's `next_frame` waits on the condvar instead of spinning, so an
462    /// idle lazy producer does not busy-loop while the remote consumer has not
463    /// yet subscribed or attached demand.
464    #[must_use]
465    pub fn new_lazy(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
466        Self {
467            shared: Arc::new(ProducerShared {
468                stream_ref_id,
469                settings,
470                input: Mutex::new(None),
471                state: Mutex::new(ProducerState {
472                    partner_seen: false,
473                    cumulative_demand: 0,
474                    sent: 0,
475                    terminal_sent: false,
476                    waiting_for_ack: false,
477                    ack_deadline: None,
478                    stopped: None,
479                    ack_queued: false,
480                    done: false,
481                    input_attached: false,
482                    terminal_result: None,
483                }),
484                changed: Condvar::new(),
485                completion: Mutex::new(None),
486                _materializer: Materializer::new(),
487                _materialized: Mutex::new(None),
488            }),
489        }
490    }
491
492    /// Returns a [`Sink`] whose incoming elements are framed and sent as this
493    /// producer's outbound `SequencedOnNext` frames.
494    ///
495    /// Materializing the sink attaches the input stream to the lazy producer
496    /// and returns a [`StreamCompletion`] that resolves when the producer
497    /// reaches its terminal state (all elements sent and acknowledged, the
498    /// remote cancelled/failed, or the carrier failed).
499    #[must_use]
500    pub fn sink(&self) -> Sink<T, StreamCompletion<NotUsed>> {
501        let shared = Arc::clone(&self.shared);
502        Sink::from_runner(move |input, _materializer| {
503            let (sender, receiver) = oneshot::channel();
504            *shared
505                .completion
506                .lock()
507                .unwrap_or_else(|poison| poison.into_inner()) = Some(sender);
508            shared.attach_input(input);
509            Ok(StreamCompletion::from_receiver(receiver, None))
510        })
511    }
512}
513
514impl<T> StreamRefProtoEndpoint for StreamRefProtoProducer<T>
515where
516    T: StreamRefPayload,
517{
518    fn stream_ref_id(&self) -> StreamRefId {
519        self.shared.stream_ref_id
520    }
521
522    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
523        self.shared.next_frame()
524    }
525
526    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
527        self.shared.handle_frame(frame)
528    }
529
530    fn fail_connection(&self, error: StreamError) {
531        self.shared.fail_connection(error);
532    }
533}
534
535struct ProducerShared<T>
536where
537    T: StreamRefPayload,
538{
539    stream_ref_id: StreamRefId,
540    settings: StreamRefSettings,
541    input: Mutex<Option<BoxStream<T>>>,
542    state: Mutex<ProducerState>,
543    changed: Condvar,
544    completion: Mutex<Option<oneshot::Sender<StreamResult<NotUsed>>>>,
545    _materializer: Materializer,
546    _materialized: Mutex<Option<Box<dyn Any + Send>>>,
547}
548
549struct ProducerState {
550    partner_seen: bool,
551    cumulative_demand: u64,
552    sent: u64,
553    terminal_sent: bool,
554    waiting_for_ack: bool,
555    ack_deadline: Option<Instant>,
556    stopped: Option<StreamError>,
557    ack_queued: bool,
558    done: bool,
559    input_attached: bool,
560    terminal_result: Option<StreamResult<NotUsed>>,
561}
562
563impl<T> ProducerShared<T>
564where
565    T: StreamRefPayload,
566{
567    fn lock_state(&self) -> MutexGuard<'_, ProducerState> {
568        self.state
569            .lock()
570            .unwrap_or_else(|poison| poison.into_inner())
571    }
572
573    fn lock_input(&self) -> MutexGuard<'_, Option<BoxStream<T>>> {
574        self.input
575            .lock()
576            .unwrap_or_else(|poison| poison.into_inner())
577    }
578
579    fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
580        StreamRefFrame::new(self.stream_ref_id, message)
581    }
582
    /// Blocks until the producer has something to hand to the carrier.
    ///
    /// Returns `None` once the endpoint is done, `Some(Ok(frame))` for the
    /// next outbound frame, and `Some(Err(_))` when a timeout or a recorded
    /// stop error ends the endpoint. Every terminal path marks the state
    /// `done`, records `terminal_result`, wakes waiters, drops the input and
    /// settles the sink completion.
    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        // Computed once per call and only consulted while no demand has been
        // received. `deadline_from_now` is defined outside this view —
        // presumably "now + timeout"; confirm.
        let subscription_deadline = deadline_from_now(self.settings.subscription_timeout());
        loop {
            let mut state = self.lock_state();
            if state.done {
                return None;
            }

            // The consumer cancelled or failed: acknowledge it and finish.
            if state.ack_queued {
                state.ack_queued = false;
                state.done = true;
                state.terminal_result = Some(match state.stopped.clone() {
                    Some(error) => Err(error),
                    None => Ok(NotUsed),
                });
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Ok(self.frame(StreamRefMessage::Ack)));
            }

            // A terminal frame was emitted: wait for the consumer's Ack until
            // the ack deadline passes.
            if state.waiting_for_ack {
                if state
                    .ack_deadline
                    .is_some_and(|deadline| Instant::now() >= deadline)
                {
                    let timeout_error =
                        subscription_timeout_error("stream ref producer terminal ack");
                    state.done = true;
                    state.terminal_result = Some(Err(timeout_error.clone()));
                    self.changed.notify_all();
                    drop(state);
                    self.drop_input();
                    self.settle();
                    return Some(Err(timeout_error));
                }
                if let Some(remaining) = state
                    .ack_deadline
                    .and_then(|deadline| deadline.checked_duration_since(Instant::now()))
                {
                    // Sleep until notified or the deadline elapses, then
                    // re-evaluate every predicate from the top of the loop.
                    let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
                    drop(next);
                } else {
                    drop(state);
                }
                continue;
            }

            // A stop error without a queued Ack (e.g. recorded by
            // `fail_connection`) is surfaced to the carrier as an error.
            if let Some(error) = state.stopped.clone() {
                state.done = true;
                state.terminal_result = Some(Err(error.clone()));
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Err(error));
            }

            // Outstanding demand: pull from the input without holding the
            // state lock. A `None` from the pull means "re-check state".
            if state.cumulative_demand > 0 && state.sent < state.cumulative_demand {
                drop(state);
                if let Some(frame) = self.pull_next_frame() {
                    return Some(frame);
                }
                continue;
            }

            // No demand ever arrived within the subscription timeout.
            if state.cumulative_demand == 0 && Instant::now() >= subscription_deadline {
                let timeout_error = subscription_timeout_error("stream ref producer first demand");
                state.done = true;
                state.terminal_result = Some(Err(timeout_error.clone()));
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Err(timeout_error));
            }

            if state.cumulative_demand == 0 {
                // Still waiting for the first demand: bounded wait.
                let remaining = subscription_deadline.saturating_duration_since(Instant::now());
                if remaining.is_zero() {
                    drop(state);
                    continue;
                }
                let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
                drop(next);
            } else {
                // Demand exhausted: wait without a timeout for the next
                // state change.
                let next = wait_unpoison(&self.changed, state);
                drop(next);
            }
        }
    }
675
    /// Pulls one item from the input stream and converts it to an outbound
    /// frame.
    ///
    /// Returns `None` when no frame should be emitted right now and the
    /// caller must re-evaluate the state: no input attached yet, or the
    /// endpoint reached a stop/terminal condition in the meantime. The input
    /// lock is held only while reading from the stream and is released before
    /// the state lock is taken.
    fn pull_next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        let item = {
            let mut input_guard = self.lock_input();
            if input_guard.is_none() {
                drop(input_guard);
                // Lazy producer: no input attached yet. Block on the condvar
                // (notify-driven, no fixed poll) until `attach_input` or a
                // terminal state wakes us, then return None so `next_frame`
                // re-enters its loop and re-checks demand/terminal predicates.
                // The `input_attached` flag is checked under the state lock to
                // avoid a missed-wakeup race with `attach_input`.
                let mut state = self.lock_state();
                while !state.input_attached
                    && !state.done
                    && state.stopped.is_none()
                    && !state.terminal_sent
                {
                    state = wait_unpoison(&self.changed, state);
                }
                drop(state);
                return None;
            }
            // May block inside the upstream stream while the input lock is
            // held.
            input_guard.as_mut().expect("input attached").next()
        };

        match item {
            Some(Ok(item)) => {
                let mut state = self.lock_state();
                // The endpoint stopped while the item was being pulled; the
                // item is discarded and the caller re-checks state.
                if state.done || state.stopped.is_some() || state.waiting_for_ack {
                    return None;
                }
                let seq_nr = state.sent;
                state.sent = state.sent.saturating_add(1);
                Some(Ok(self.frame(StreamRefMessage::SequencedOnNext {
                    seq_nr,
                    payload: StreamRefPayloadBytes {
                        bytes: item.encode_stream_ref_payload(),
                    },
                })))
            }
            Some(Err(error)) => {
                // Upstream failed: emit the failure once, then wait for the
                // consumer's Ack.
                self.drop_input();
                let mut state = self.lock_state();
                if state.done || state.terminal_sent {
                    return None;
                }
                state.terminal_sent = true;
                state.waiting_for_ack = true;
                state.terminal_result = Some(Err(error.clone()));
                state.ack_deadline = Some(deadline_from_now(self.settings.subscription_timeout()));
                self.changed.notify_all();
                drop(state);
                Some(Ok(self.frame(StreamRefMessage::RemoteStreamFailure {
                    cause: failure_cause(&error),
                })))
            }
            None => {
                // Upstream finished: emit completion carrying the number of
                // elements sent, then wait for the consumer's Ack.
                self.drop_input();
                let mut state = self.lock_state();
                if state.done || state.terminal_sent {
                    return None;
                }
                let seq_nr = state.sent;
                state.terminal_sent = true;
                state.waiting_for_ack = true;
                state.terminal_result = Some(Ok(NotUsed));
                state.ack_deadline = Some(deadline_from_now(self.settings.subscription_timeout()));
                self.changed.notify_all();
                drop(state);
                Some(Ok(
                    self.frame(StreamRefMessage::RemoteStreamCompleted { seq_nr })
                ))
            }
        }
    }
751
752    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
753        self.validate_frame_id(frame.stream_ref_id)?;
754        match frame.message {
755            StreamRefMessage::OnSubscribeHandshake => {
756                let mut state = self.lock_state();
757                state.partner_seen = true;
758                self.changed.notify_all();
759                drop(state);
760                Ok(())
761            }
762            StreamRefMessage::CumulativeDemand { seq_nr } => {
763                if seq_nr == 0 {
764                    return Err(StreamError::Failed(
765                        "CumulativeDemand seq_nr must be positive".to_owned(),
766                    ));
767                }
768                let mut state = self.lock_state();
769                state.partner_seen = true;
770                if seq_nr > state.cumulative_demand {
771                    state.cumulative_demand = seq_nr;
772                }
773                self.changed.notify_all();
774                drop(state);
775                Ok(())
776            }
777            StreamRefMessage::RemoteStreamCompleted { .. } => {
778                self.stop_from_consumer(StreamError::Cancelled);
779                Ok(())
780            }
781            StreamRefMessage::RemoteStreamFailure { cause } => {
782                self.stop_from_consumer(StreamError::Failed(
783                    String::from_utf8_lossy(&cause).into_owned(),
784                ));
785                Ok(())
786            }
787            StreamRefMessage::Ack => {
788                let mut state = self.lock_state();
789                if state.waiting_for_ack {
790                    state.waiting_for_ack = false;
791                    state.done = true;
792                    if state.terminal_result.is_none() {
793                        state.terminal_result = Some(Ok(NotUsed));
794                    }
795                    self.changed.notify_all();
796                    drop(state);
797                    self.drop_input();
798                    self.settle();
799                } else {
800                    drop(state);
801                }
802                Ok(())
803            }
804            StreamRefMessage::SequencedOnNext { .. } => Err(StreamError::Failed(
805                "producer endpoint cannot receive SequencedOnNext".to_owned(),
806            )),
807        }
808    }
809
810    fn stop_from_consumer(&self, error: StreamError) {
811        let mut state = self.lock_state();
812        if !state.done {
813            state.stopped = Some(error.clone());
814            state.ack_queued = true;
815            state.terminal_result = Some(Err(error));
816        }
817        self.changed.notify_all();
818        drop(state);
819        self.drop_input();
820    }
821
822    fn fail_connection(&self, error: StreamError) {
823        let mut state = self.lock_state();
824        if !state.done {
825            state.stopped = Some(error.clone());
826            state.done = true;
827            state.terminal_result = Some(Err(error));
828        }
829        self.changed.notify_all();
830        drop(state);
831        self.drop_input();
832        self.settle();
833    }
834
835    fn attach_input(&self, input: BoxStream<T>) {
836        *self.lock_input() = Some(input);
837        let mut state = self.lock_state();
838        state.input_attached = true;
839        self.changed.notify_all();
840        drop(state);
841    }
842
843    fn settle(&self) {
844        let result = self.lock_state().terminal_result.clone();
845        let sender = self
846            .completion
847            .lock()
848            .unwrap_or_else(|poison| poison.into_inner())
849            .take();
850        if let (Some(sender), Some(result)) = (sender, result) {
851            let _ = sender.send(result);
852        }
853    }
854
855    fn drop_input(&self) {
856        let input = self.lock_input().take();
857        drop(input);
858    }
859
860    fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
861        if stream_ref_id == self.stream_ref_id {
862            Ok(())
863        } else {
864            Err(StreamError::Failed(format!(
865                "stream ref id mismatch: expected {}, got {}",
866                self.stream_ref_id, stream_ref_id
867            )))
868        }
869    }
870}
871
/// Consumer-side endpoint that exposes inbound remote elements as a local
/// [`Source`].
///
/// The source sends `OnSubscribeHandshake` and a cumulative demand ceiling when
/// it is materialized. It then refills demand at half the configured buffer.
///
/// Cloning the endpoint yields another handle to the same shared state, so a
/// carrier can drive frames through one clone while the source runs from
/// another.
pub struct StreamRefProtoConsumer<T>
where
    T: StreamRefPayload,
{
    /// State shared by every clone of this handle and by the materialized stream.
    shared: Arc<ConsumerShared<T>>,
}
883
884impl<T> Clone for StreamRefProtoConsumer<T>
885where
886    T: StreamRefPayload,
887{
888    fn clone(&self) -> Self {
889        Self {
890            shared: Arc::clone(&self.shared),
891        }
892    }
893}
894
895impl<T> StreamRefProtoConsumer<T>
896where
897    T: StreamRefPayload,
898{
899    #[must_use]
900    pub fn new(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
901        Self {
902            shared: Arc::new(ConsumerShared {
903                stream_ref_id,
904                settings,
905                state: Mutex::new(ConsumerState {
906                    source_taken: false,
907                    subscribed: false,
908                    queue: VecDeque::new(),
909                    terminal: None,
910                    expected_seq: 0,
911                    delivered: 0,
912                    cumulative_demand: 0,
913                    outbound: VecDeque::new(),
914                    finish_after_outbound_ack: false,
915                    waiting_cancel_ack: false,
916                    done: false,
917                }),
918                changed: Condvar::new(),
919            }),
920        }
921    }
922
923    #[must_use]
924    pub fn source(&self) -> Source<T, NotUsed> {
925        let shared = Arc::clone(&self.shared);
926        Source::unfold_resource(
927            move || shared.start_stream(),
928            |stream| stream.next_item(),
929            |mut stream| {
930                stream.close();
931                Ok(())
932            },
933        )
934    }
935}
936
// Endpoint driver interface for the consumer; every method delegates to the
// shared state so that any clone of the handle can be used by a carrier.
impl<T> StreamRefProtoEndpoint for StreamRefProtoConsumer<T>
where
    T: StreamRefPayload,
{
    /// Returns the id that frames for this endpoint must carry.
    fn stream_ref_id(&self) -> StreamRefId {
        self.shared.stream_ref_id
    }

    /// Blocks until an outbound frame is available and returns it, or returns
    /// `None` once the endpoint is done and nothing is left to send.
    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        self.shared.next_frame()
    }

    /// Applies one inbound frame to the consumer state.
    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
        self.shared.handle_frame(frame)
    }

    /// Marks the endpoint as done because the carrier connection failed.
    fn fail_connection(&self, error: StreamError) {
        self.shared.fail_connection(error);
    }
}
957
/// State shared between all [`StreamRefProtoConsumer`] handles and the
/// materialized `ConsumerStream`.
struct ConsumerShared<T>
where
    T: StreamRefPayload,
{
    /// Id every inbound frame is validated against and every outbound frame carries.
    stream_ref_id: StreamRefId,
    /// Supplies the buffer capacity used for the demand window.
    settings: StreamRefSettings,
    /// Mutable protocol state, guarded by a mutex.
    state: Mutex<ConsumerState<T>>,
    /// Notified after state changes; `next_frame` and `next_item` wait on it.
    changed: Condvar,
}
967
/// Mutable protocol state of a consumer endpoint.
struct ConsumerState<T> {
    /// Set once the source has been materialized; a second attempt fails.
    source_taken: bool,
    /// Set once the subscribe handshake has been queued for the remote side.
    subscribed: bool,
    /// Decoded elements received from the remote side and not yet handed downstream.
    queue: VecDeque<T>,
    /// Terminal outcome for the local stream, once known.
    terminal: Option<ConsumerTerminal>,
    /// Sequence number the next inbound element or completion must carry.
    expected_seq: u64,
    /// Number of elements handed downstream so far.
    delivered: u64,
    /// Highest cumulative demand ceiling queued for the remote side (0 = none yet).
    cumulative_demand: u64,
    /// Messages waiting to be drained by `next_frame`.
    outbound: VecDeque<StreamRefMessage>,
    /// When set, emitting a queued `Ack` from `next_frame` marks the endpoint done.
    finish_after_outbound_ack: bool,
    /// When set, the next inbound `Ack` marks the endpoint done.
    waiting_cancel_ack: bool,
    /// Once set, `next_frame` returns `None` as soon as `outbound` is empty.
    done: bool,
}
981
/// Terminal outcome reported to the local stream once buffered elements are
/// drained.
#[derive(Clone)]
enum ConsumerTerminal {
    /// The stream ends normally.
    Complete,
    /// The stream ends with this error.
    Error(StreamError),
}
987
988impl<T> ConsumerShared<T>
989where
990    T: StreamRefPayload,
991{
992    fn lock_state(&self) -> MutexGuard<'_, ConsumerState<T>> {
993        self.state
994            .lock()
995            .unwrap_or_else(|poison| poison.into_inner())
996    }
997
    /// Wraps `message` in a frame addressed with this endpoint's stream ref id.
    fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
        StreamRefFrame::new(self.stream_ref_id, message)
    }
1001
1002    fn start_stream(self: &Arc<Self>) -> StreamResult<ConsumerStream<T>> {
1003        {
1004            let mut state = self.lock_state();
1005            if state.source_taken {
1006                return Err(StreamError::Failed(
1007                    "stream ref source has already been materialized".to_owned(),
1008                ));
1009            }
1010            state.source_taken = true;
1011            if !state.subscribed {
1012                state.subscribed = true;
1013                state
1014                    .outbound
1015                    .push_back(StreamRefMessage::OnSubscribeHandshake);
1016                if let Some(demand) = next_demand(&mut state, self.settings) {
1017                    state
1018                        .outbound
1019                        .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
1020                }
1021            }
1022            self.changed.notify_all();
1023        }
1024        Ok(ConsumerStream {
1025            shared: Arc::clone(self),
1026            terminated: false,
1027        })
1028    }
1029
1030    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
1031        loop {
1032            let mut state = self.lock_state();
1033            if let Some(message) = state.outbound.pop_front() {
1034                let finish_after_ack = message.is_ack() && state.finish_after_outbound_ack;
1035                if finish_after_ack {
1036                    state.done = true;
1037                }
1038                drop(state);
1039                return Some(Ok(self.frame(message)));
1040            }
1041            if state.done {
1042                return None;
1043            }
1044            let next = wait_unpoison(&self.changed, state);
1045            drop(next);
1046        }
1047    }
1048
1049    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
1050        self.validate_frame_id(frame.stream_ref_id)?;
1051        match frame.message {
1052            StreamRefMessage::OnSubscribeHandshake => Ok(()),
1053            StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
1054                let item = T::decode_stream_ref_payload(payload.bytes)?;
1055                let mut state = self.lock_state();
1056                if state.terminal.is_some() || state.done {
1057                    return Ok(());
1058                }
1059                if seq_nr != state.expected_seq {
1060                    let error =
1061                        invalid_sequence_error(state.expected_seq, seq_nr, "stream ref element");
1062                    state.queue.clear();
1063                    state.terminal = Some(ConsumerTerminal::Error(error.clone()));
1064                    state
1065                        .outbound
1066                        .push_back(StreamRefMessage::RemoteStreamFailure {
1067                            cause: failure_cause(&error),
1068                        });
1069                    state.waiting_cancel_ack = true;
1070                } else if state.queue.len() >= self.settings.buffer_capacity() {
1071                    let error = StreamError::Failed(
1072                        "stream ref receive buffer overflowed demand window".to_owned(),
1073                    );
1074                    state.queue.clear();
1075                    state.terminal = Some(ConsumerTerminal::Error(error.clone()));
1076                    state
1077                        .outbound
1078                        .push_back(StreamRefMessage::RemoteStreamFailure {
1079                            cause: failure_cause(&error),
1080                        });
1081                    state.waiting_cancel_ack = true;
1082                } else {
1083                    state.expected_seq = state.expected_seq.saturating_add(1);
1084                    state.queue.push_back(item);
1085                }
1086                self.changed.notify_all();
1087                drop(state);
1088                Ok(())
1089            }
1090            StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
1091                let mut state = self.lock_state();
1092                if state.terminal.is_none() && !state.done {
1093                    if seq_nr != state.expected_seq {
1094                        state.queue.clear();
1095                        state.terminal = Some(ConsumerTerminal::Error(invalid_sequence_error(
1096                            state.expected_seq,
1097                            seq_nr,
1098                            "stream ref completion",
1099                        )));
1100                    } else {
1101                        state.terminal = Some(ConsumerTerminal::Complete);
1102                    }
1103                    state.outbound.push_back(StreamRefMessage::Ack);
1104                    state.finish_after_outbound_ack = true;
1105                }
1106                self.changed.notify_all();
1107                drop(state);
1108                Ok(())
1109            }
1110            StreamRefMessage::RemoteStreamFailure { cause } => {
1111                let mut state = self.lock_state();
1112                if state.terminal.is_none() && !state.done {
1113                    state.queue.clear();
1114                    state.terminal = Some(ConsumerTerminal::Error(StreamError::Failed(
1115                        String::from_utf8_lossy(&cause).into_owned(),
1116                    )));
1117                    state.outbound.push_back(StreamRefMessage::Ack);
1118                    state.finish_after_outbound_ack = true;
1119                }
1120                self.changed.notify_all();
1121                drop(state);
1122                Ok(())
1123            }
1124            StreamRefMessage::Ack => {
1125                let mut state = self.lock_state();
1126                if state.waiting_cancel_ack {
1127                    state.waiting_cancel_ack = false;
1128                    state.done = true;
1129                }
1130                self.changed.notify_all();
1131                drop(state);
1132                Ok(())
1133            }
1134            StreamRefMessage::CumulativeDemand { .. } => Err(StreamError::Failed(
1135                "consumer endpoint cannot receive CumulativeDemand".to_owned(),
1136            )),
1137        }
1138    }
1139
1140    fn cancel_from_downstream(&self) {
1141        let mut state = self.lock_state();
1142        if state.terminal.is_none() && !state.done {
1143            let seq_nr = state.expected_seq;
1144            state.terminal = Some(ConsumerTerminal::Error(StreamError::Cancelled));
1145            state
1146                .outbound
1147                .push_back(StreamRefMessage::RemoteStreamCompleted { seq_nr });
1148            state.waiting_cancel_ack = true;
1149        }
1150        self.changed.notify_all();
1151        drop(state);
1152    }
1153
1154    fn fail_connection(&self, error: StreamError) {
1155        let mut state = self.lock_state();
1156        if state.terminal.is_none() {
1157            state.queue.clear();
1158            state.terminal = Some(ConsumerTerminal::Error(error));
1159        }
1160        state.done = true;
1161        self.changed.notify_all();
1162        drop(state);
1163    }
1164
1165    fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
1166        if stream_ref_id == self.stream_ref_id {
1167            Ok(())
1168        } else {
1169            Err(StreamError::Failed(format!(
1170                "stream ref id mismatch: expected {}, got {}",
1171                self.stream_ref_id, stream_ref_id
1172            )))
1173        }
1174    }
1175}
1176
/// Resource handle behind the consumer's [`Source`]: pulls buffered elements
/// out of the shared state.
struct ConsumerStream<T>
where
    T: StreamRefPayload,
{
    /// Endpoint state this stream reads from.
    shared: Arc<ConsumerShared<T>>,
    /// Set once a terminal outcome was returned or the stream was closed.
    terminated: bool,
}
1184
1185impl<T> ConsumerStream<T>
1186where
1187    T: StreamRefPayload,
1188{
1189    fn next_item(&mut self) -> StreamResult<Option<T>> {
1190        if self.terminated {
1191            return Ok(None);
1192        }
1193        loop {
1194            let mut state = self.shared.lock_state();
1195            if let Some(item) = state.queue.pop_front() {
1196                state.delivered = state.delivered.saturating_add(1);
1197                if let Some(demand) = next_demand(&mut state, self.shared.settings) {
1198                    state
1199                        .outbound
1200                        .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
1201                    self.shared.changed.notify_all();
1202                }
1203                return Ok(Some(item));
1204            }
1205
1206            if let Some(terminal) = state.terminal.clone() {
1207                self.terminated = true;
1208                return match terminal {
1209                    ConsumerTerminal::Complete => Ok(None),
1210                    ConsumerTerminal::Error(error) => Err(error),
1211                };
1212            }
1213
1214            let next = wait_unpoison(&self.shared.changed, state);
1215            drop(next);
1216        }
1217    }
1218
1219    fn close(&mut self) {
1220        if !self.terminated {
1221            self.shared.cancel_from_downstream();
1222            self.terminated = true;
1223        }
1224    }
1225}
1226
1227fn next_demand<T>(state: &mut ConsumerState<T>, settings: StreamRefSettings) -> Option<u64> {
1228    // Akka redelivers CumulativeDemand because some remoting carriers are
1229    // lossy. This seam is driven by reliable ordered carriers such as one QUIC
1230    // bidirectional stream, so a larger cumulative ceiling is sent once.
1231    if state.terminal.is_some() {
1232        return None;
1233    }
1234    let remaining_credit = state.cumulative_demand.saturating_sub(state.delivered);
1235    if state.cumulative_demand != 0 && remaining_credit > demand_replenish_threshold(settings) {
1236        return None;
1237    }
1238    let target = state
1239        .delivered
1240        .saturating_add(settings.buffer_capacity() as u64);
1241    if state.cumulative_demand >= target {
1242        return None;
1243    }
1244    state.cumulative_demand = target;
1245    Some(target)
1246}
1247
/// Returns half the configured buffer capacity (rounded down): the amount of
/// outstanding credit at or below which `next_demand` raises the ceiling.
fn demand_replenish_threshold(settings: StreamRefSettings) -> u64 {
    (settings.buffer_capacity() as u64) / 2
}
1251
1252fn failure_cause(error: &StreamError) -> Vec<u8> {
1253    match error {
1254        StreamError::Failed(message) => message.clone().into_bytes(),
1255        other => other.to_string().into_bytes(),
1256    }
1257}
1258
1259fn subscription_timeout_error(side: &str) -> StreamError {
1260    StreamError::Failed(format!(
1261        "{side} remote side did not subscribe within subscription timeout"
1262    ))
1263}
1264
1265fn invalid_sequence_error(expected: u64, got: u64, context: &str) -> StreamError {
1266    StreamError::Failed(format!(
1267        "{context} sequence gap: expected sequence {expected}, got {got}"
1268    ))
1269}
1270
1271fn deadline_from_now(timeout: Duration) -> Instant {
1272    Instant::now()
1273        .checked_add(timeout)
1274        .unwrap_or_else(far_future)
1275}
1276
/// Returns an instant 365 days from now, used as a stand-in for "no deadline".
fn far_future() -> Instant {
    let one_year = Duration::from_secs(60 * 60 * 24 * 365);
    Instant::now() + one_year
}
1280
/// Waits on `condvar` for at most `timeout`, returning the re-acquired guard
/// and the timeout result even when the mutex was poisoned.
fn wait_timeout_unpoison<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    match condvar.wait_timeout(guard, timeout) {
        Ok(outcome) => outcome,
        Err(poisoned) => poisoned.into_inner(),
    }
}
1290
/// Waits on `condvar` and returns the re-acquired guard even when the mutex
/// was poisoned.
fn wait_unpoison<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
1296
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::{Source, StreamRefs};

    /// Settings with a small buffer and a short subscription timeout so the
    /// tests exercise demand refills and timeouts quickly.
    fn short_settings() -> StreamRefSettings {
        StreamRefSettings::default()
            .with_buffer_capacity(4)
            .with_subscription_timeout(Duration::from_millis(50))
    }

    /// A frame survives an encode/decode round trip unchanged.
    #[test]
    fn protobuf_frame_round_trip() {
        let frame = StreamRefFrame::new(
            StreamRefId::from_u128(42),
            StreamRefMessage::SequencedOnNext {
                seq_nr: 7,
                payload: StreamRefPayloadBytes {
                    bytes: 99_u64.encode_stream_ref_payload(),
                },
            },
        );

        let decoded = StreamRefFrame::decode(&frame.encode_to_vec()).unwrap();
        assert_eq!(decoded, frame);
    }

    /// Wires a producer and a consumer back to back with two pump threads and
    /// checks that all ten elements arrive in order through a buffer of four.
    #[test]
    fn producer_consumer_seam_streams_with_low_watermark_demand() {
        let id = StreamRefId::from_u128(1);
        let settings = short_settings();
        let source_ref = Source::from_iter(0_u64..10)
            .run_with(StreamRefs::source_ref_with_settings(settings))
            .unwrap();
        let producer = StreamRefProtoProducer::from_source_ref(source_ref, id, settings).unwrap();
        let consumer = StreamRefProtoConsumer::<u64>::new(id, settings);
        let consumer_source = consumer.source();

        // Pump producer -> consumer frames.
        let producer_thread = std::thread::spawn({
            let producer = producer.clone();
            let consumer = consumer.clone();
            move || {
                while let Some(frame) = producer.next_frame() {
                    consumer.handle_frame(frame?)?;
                }
                Ok::<_, StreamError>(())
            }
        });
        // Pump consumer -> producer frames.
        let consumer_thread = std::thread::spawn({
            let producer = producer.clone();
            let consumer = consumer.clone();
            move || {
                while let Some(frame) = consumer.next_frame() {
                    producer.handle_frame(frame?)?;
                }
                Ok::<_, StreamError>(())
            }
        });

        assert_eq!(
            consumer_source.run_collect().unwrap(),
            (0_u64..10).collect::<Vec<_>>()
        );
        producer_thread.join().unwrap().unwrap();
        consumer_thread.join().unwrap().unwrap();
    }

    /// An element whose sequence number skips ahead fails the local stream and
    /// queues a `RemoteStreamFailure` for the remote side.
    #[test]
    fn strict_sequence_gap_fails_consumer_and_sends_failure() {
        let id = StreamRefId::from_u128(2);
        let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
        let source = consumer
            .source()
            .run_with(crate::testkit::TestSink::probe())
            .unwrap();
        source.request(1);
        // Drain the two frames queued at materialization (handshake and
        // initial demand).
        consumer.next_frame().unwrap().unwrap();
        consumer.next_frame().unwrap().unwrap();

        // The consumer expects sequence 0; sequence 1 is a gap.
        consumer
            .handle_frame(StreamRefFrame::new(
                id,
                StreamRefMessage::SequencedOnNext {
                    seq_nr: 1,
                    payload: StreamRefPayloadBytes {
                        bytes: 1_u64.encode_stream_ref_payload(),
                    },
                },
            ))
            .unwrap();

        let outbound = consumer.next_frame().unwrap().unwrap();
        assert!(matches!(
            outbound.message,
            StreamRefMessage::RemoteStreamFailure { .. }
        ));
        assert!(matches!(source.expect_error(), StreamError::Failed(_)));
    }

    /// A producer that never receives demand reports an error mentioning
    /// "first demand" from `next_frame`.
    #[test]
    fn producer_times_out_without_first_demand() {
        let producer = StreamRefProtoProducer::from_source(
            Source::repeat(1_u64),
            StreamRefId::from_u128(3),
            short_settings(),
        )
        .unwrap();

        let error = producer.next_frame().unwrap().unwrap_err();
        assert!(matches!(error, StreamError::Failed(message) if message.contains("first demand")));
    }

    /// Pins the default demand redelivery interval at one second.
    #[test]
    fn demand_redelivery_is_not_required_by_reliable_carriers() {
        // Akka redelivers CumulativeDemand because Artery/Aeron may lose messages.
        // The protobuf seam is intended for reliable ordered carriers such as a
        // single QUIC bidirectional stream, so each cumulative ceiling is sent
        // once and remains valid until a larger ceiling replaces it.
        assert_eq!(
            StreamRefSettings::default().demand_redelivery_interval(),
            Duration::from_secs(1)
        );
    }
}