1use std::{
9 any::Any,
10 collections::VecDeque,
11 fmt,
12 sync::{
13 Arc, Condvar, Mutex, MutexGuard,
14 atomic::{AtomicU64, Ordering},
15 },
16 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
17};
18
19use crate::stream::{
20 BoxStream, Materializer, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult,
21};
22use futures::channel::oneshot;
23use prost::Message as ProstMessage;
24
25use super::{SourceRef, StreamRefSettings};
26
/// Process-wide sequence counter consumed by [`StreamRefId::new`]; every call takes the next
/// value (the first id generated in a process uses 1).
static STREAM_REF_PROTO_ID: AtomicU64 = AtomicU64::new(1);
28
/// Element types that can be carried inside a stream ref `SequencedOnNext` frame.
///
/// An implementation turns one element into an opaque byte payload and rebuilds it on the
/// receiving side.
pub trait StreamRefPayload: Send + 'static {
    /// Consumes the element and returns its wire bytes.
    fn encode_stream_ref_payload(self) -> Vec<u8>;

    /// Rebuilds an element from bytes produced by [`Self::encode_stream_ref_payload`].
    ///
    /// # Errors
    ///
    /// Implementations in this module return `StreamError::Failed` when `bytes` is not a
    /// valid encoding for `Self`.
    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self>
    where
        Self: Sized;
}
41
/// Implements [`StreamRefPayload`] for fixed-width numeric primitives.
///
/// Values are encoded as their big-endian byte representation. Decoding rejects any payload
/// whose length is not exactly `size_of::<$ty>()`.
macro_rules! impl_stream_ref_payload_numeric {
    ($($ty:ty),* $(,)?) => {
        $(
            impl StreamRefPayload for $ty {
                fn encode_stream_ref_payload(self) -> Vec<u8> {
                    Vec::from(self.to_be_bytes())
                }

                fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
                    const WIDTH: usize = std::mem::size_of::<$ty>();
                    match <[u8; WIDTH]>::try_from(bytes.as_slice()) {
                        Ok(raw) => Ok(<$ty>::from_be_bytes(raw)),
                        Err(_) => Err(StreamError::Failed(format!(
                            "invalid {} stream ref payload length: {}",
                            stringify!($ty),
                            bytes.len()
                        ))),
                    }
                }
            }
        )*
    };
}
65
// Every fixed-width integer and float primitive uses the big-endian encoding generated by
// `impl_stream_ref_payload_numeric!`.
impl_stream_ref_payload_numeric!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, f32, f64);
67
68impl StreamRefPayload for bool {
69 fn encode_stream_ref_payload(self) -> Vec<u8> {
70 vec![u8::from(self)]
71 }
72
73 fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
74 match bytes.as_slice() {
75 [0] => Ok(false),
76 [1] => Ok(true),
77 _ => Err(StreamError::Failed(
78 "invalid bool stream ref payload".to_owned(),
79 )),
80 }
81 }
82}
83
84impl StreamRefPayload for String {
85 fn encode_stream_ref_payload(self) -> Vec<u8> {
86 self.into_bytes()
87 }
88
89 fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
90 String::from_utf8(bytes)
91 .map_err(|error| StreamError::Failed(format!("invalid UTF-8 payload: {error}")))
92 }
93}
94
95impl StreamRefPayload for Vec<u8> {
96 fn encode_stream_ref_payload(self) -> Vec<u8> {
97 self
98 }
99
100 fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
101 Ok(bytes)
102 }
103}
104
/// Identifier carried by every frame of one stream ref, stored as a 128-bit value.
///
/// On the wire it is encoded as 16 big-endian bytes (see `StreamRefId::to_bytes`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamRefId(u128);
108
109impl StreamRefId {
110 #[must_use]
112 pub fn new() -> Self {
113 let sequence = STREAM_REF_PROTO_ID.fetch_add(1, Ordering::Relaxed) as u128;
114 let timestamp = SystemTime::now()
115 .duration_since(UNIX_EPOCH)
116 .map(|duration| duration.as_nanos())
117 .unwrap_or_default();
118 let pid = std::process::id() as u128;
119 Self(timestamp ^ (pid << 32) ^ sequence)
120 }
121
122 #[must_use]
125 pub const fn from_u128(value: u128) -> Self {
126 Self(value)
127 }
128
129 #[must_use]
130 pub const fn as_u128(self) -> u128 {
131 self.0
132 }
133
134 #[must_use]
135 pub fn to_bytes(self) -> [u8; 16] {
136 self.0.to_be_bytes()
137 }
138
139 pub fn from_bytes(bytes: &[u8]) -> StreamResult<Self> {
140 let value: [u8; 16] = bytes.try_into().map_err(|_| {
141 StreamError::Failed("stream ref id must be exactly 16 bytes".to_owned())
142 })?;
143 Ok(Self(u128::from_be_bytes(value)))
144 }
145}
146
147impl Default for StreamRefId {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153impl fmt::Display for StreamRefId {
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 write!(f, "{:032x}", self.0)
156 }
157}
158
/// Already-encoded element bytes carried by a `SequencedOnNext` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamRefPayloadBytes {
    /// Bytes produced by `StreamRefPayload::encode_stream_ref_payload` for one element.
    pub bytes: Vec<u8>,
}
165
/// Protocol messages exchanged between a stream ref producer and consumer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamRefMessage {
    /// Sent by the consumer when its source is first materialized.
    OnSubscribeHandshake,
    /// Consumer → producer: total number of elements requested so far. The producer rejects
    /// a value of 0.
    CumulativeDemand {
        seq_nr: u64,
    },
    /// Producer → consumer: one element; `seq_nr` starts at 0 and increases by one per element.
    SequencedOnNext {
        seq_nr: u64,
        payload: StreamRefPayloadBytes,
    },
    /// Producer → consumer: normal completion; `seq_nr` is the number of elements sent.
    /// Consumer → producer: downstream cancellation (the producer ignores `seq_nr`).
    RemoteStreamCompleted {
        seq_nr: u64,
    },
    /// Either direction: failure, with the error text as UTF-8 bytes.
    RemoteStreamFailure {
        cause: Vec<u8>,
    },
    /// Acknowledges a terminal message (completion, failure or cancellation).
    Ack,
}
185
186impl StreamRefMessage {
187 #[must_use]
188 pub fn failure_text(&self) -> Option<String> {
189 match self {
190 Self::RemoteStreamFailure { cause } => {
191 Some(String::from_utf8_lossy(cause).into_owned())
192 }
193 _ => None,
194 }
195 }
196
197 fn is_ack(&self) -> bool {
198 matches!(self, Self::Ack)
199 }
200}
201
/// A protocol message addressed to one stream ref.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamRefFrame {
    /// Stream ref this frame belongs to; receivers reject frames whose id does not match.
    pub stream_ref_id: StreamRefId,
    /// The protocol message itself.
    pub message: StreamRefMessage,
}
208
209impl StreamRefFrame {
210 #[must_use]
211 pub fn new(stream_ref_id: StreamRefId, message: StreamRefMessage) -> Self {
212 Self {
213 stream_ref_id,
214 message,
215 }
216 }
217
218 #[must_use]
219 pub fn encode_to_vec(&self) -> Vec<u8> {
220 self.to_wire().encode_to_vec()
221 }
222
223 pub fn decode(bytes: &[u8]) -> StreamResult<Self> {
224 Self::from_wire(WireStreamRefFrame::decode(bytes).map_err(|error| {
225 StreamError::Failed(format!("invalid stream ref protobuf frame: {error}"))
226 })?)
227 }
228
229 fn to_wire(&self) -> WireStreamRefFrame {
230 WireStreamRefFrame {
231 stream_ref_id: self.stream_ref_id.to_bytes().to_vec(),
232 message: Some(match &self.message {
233 StreamRefMessage::OnSubscribeHandshake => {
234 wire_stream_ref_frame::Message::OnSubscribeHandshake(
235 WireOnSubscribeHandshake {},
236 )
237 }
238 StreamRefMessage::CumulativeDemand { seq_nr } => {
239 wire_stream_ref_frame::Message::CumulativeDemand(WireCumulativeDemand {
240 seq_nr: *seq_nr,
241 })
242 }
243 StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
244 wire_stream_ref_frame::Message::SequencedOnNext(WireSequencedOnNext {
245 seq_nr: *seq_nr,
246 payload: Some(WirePayload {
247 enclosed_message: payload.bytes.clone(),
248 }),
249 })
250 }
251 StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
252 wire_stream_ref_frame::Message::RemoteStreamCompleted(
253 WireRemoteStreamCompleted { seq_nr: *seq_nr },
254 )
255 }
256 StreamRefMessage::RemoteStreamFailure { cause } => {
257 wire_stream_ref_frame::Message::RemoteStreamFailure(WireRemoteStreamFailure {
258 cause: cause.clone(),
259 })
260 }
261 StreamRefMessage::Ack => wire_stream_ref_frame::Message::Ack(WireAck {}),
262 }),
263 }
264 }
265
266 fn from_wire(wire: WireStreamRefFrame) -> StreamResult<Self> {
267 let stream_ref_id = StreamRefId::from_bytes(&wire.stream_ref_id)?;
268 let message = match wire.message.ok_or_else(|| {
269 StreamError::Failed("stream ref protobuf frame has no message".to_owned())
270 })? {
271 wire_stream_ref_frame::Message::OnSubscribeHandshake(_) => {
272 StreamRefMessage::OnSubscribeHandshake
273 }
274 wire_stream_ref_frame::Message::CumulativeDemand(message) => {
275 StreamRefMessage::CumulativeDemand {
276 seq_nr: message.seq_nr,
277 }
278 }
279 wire_stream_ref_frame::Message::SequencedOnNext(message) => {
280 let payload = message.payload.ok_or_else(|| {
281 StreamError::Failed("SequencedOnNext missing payload".to_owned())
282 })?;
283 StreamRefMessage::SequencedOnNext {
284 seq_nr: message.seq_nr,
285 payload: StreamRefPayloadBytes {
286 bytes: payload.enclosed_message,
287 },
288 }
289 }
290 wire_stream_ref_frame::Message::RemoteStreamCompleted(message) => {
291 StreamRefMessage::RemoteStreamCompleted {
292 seq_nr: message.seq_nr,
293 }
294 }
295 wire_stream_ref_frame::Message::RemoteStreamFailure(message) => {
296 StreamRefMessage::RemoteStreamFailure {
297 cause: message.cause,
298 }
299 }
300 wire_stream_ref_frame::Message::Ack(_) => StreamRefMessage::Ack,
301 };
302 Ok(Self {
303 stream_ref_id,
304 message,
305 })
306 }
307}
308
/// Protobuf envelope for a [`StreamRefFrame`].
#[derive(Clone, PartialEq, ProstMessage)]
struct WireStreamRefFrame {
    /// 16-byte big-endian stream ref id.
    #[prost(bytes = "vec", tag = "1")]
    stream_ref_id: Vec<u8>,
    /// Exactly one protocol message (oneof over tags 2-7); `None` is rejected on decode.
    #[prost(oneof = "wire_stream_ref_frame::Message", tags = "2, 3, 4, 5, 6, 7")]
    message: Option<wire_stream_ref_frame::Message>,
}
316
/// Oneof payload of [`WireStreamRefFrame`]; the tags must stay in sync with the
/// `tags = "2, 3, 4, 5, 6, 7"` declaration on the envelope.
mod wire_stream_ref_frame {
    /// One variant per `StreamRefMessage` variant.
    #[derive(Clone, PartialEq, prost::Oneof)]
    pub enum Message {
        #[prost(message, tag = "2")]
        OnSubscribeHandshake(super::WireOnSubscribeHandshake),
        #[prost(message, tag = "3")]
        CumulativeDemand(super::WireCumulativeDemand),
        #[prost(message, tag = "4")]
        SequencedOnNext(super::WireSequencedOnNext),
        #[prost(message, tag = "5")]
        RemoteStreamCompleted(super::WireRemoteStreamCompleted),
        #[prost(message, tag = "6")]
        RemoteStreamFailure(super::WireRemoteStreamFailure),
        #[prost(message, tag = "7")]
        Ack(super::WireAck),
    }
}
334
/// Wire form of [`StreamRefPayloadBytes`].
#[derive(Clone, PartialEq, ProstMessage)]
struct WirePayload {
    #[prost(bytes = "vec", tag = "1")]
    enclosed_message: Vec<u8>,
}

/// Wire form of `StreamRefMessage::OnSubscribeHandshake` (no fields).
#[derive(Clone, PartialEq, ProstMessage)]
struct WireOnSubscribeHandshake {}

/// Wire form of `StreamRefMessage::CumulativeDemand`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireCumulativeDemand {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
}

/// Wire form of `StreamRefMessage::SequencedOnNext`; `payload` is required on decode.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireSequencedOnNext {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
    #[prost(message, optional, tag = "2")]
    payload: Option<WirePayload>,
}

/// Wire form of `StreamRefMessage::RemoteStreamFailure`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireRemoteStreamFailure {
    #[prost(bytes = "vec", tag = "1")]
    cause: Vec<u8>,
}

/// Wire form of `StreamRefMessage::RemoteStreamCompleted`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireRemoteStreamCompleted {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
}

/// Wire form of `StreamRefMessage::Ack` (no fields).
#[derive(Clone, PartialEq, ProstMessage)]
struct WireAck {}
372
/// Transport-facing side of a stream ref endpoint (producer or consumer).
///
/// NOTE(review): in both implementations in this module `next_frame` blocks on a condition
/// variable until `handle_frame` (or local activity) changes state. The transport therefore
/// presumably drives the two from different threads; confirm against the caller.
pub trait StreamRefProtoEndpoint: Clone + Send + Sync + 'static {
    /// Id every inbound and outbound frame of this endpoint carries.
    fn stream_ref_id(&self) -> StreamRefId;
    /// Next outbound frame; `None` once the endpoint has nothing more to send, `Some(Err)`
    /// when the protocol ends with an error.
    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>>;
    /// Applies an inbound frame; errors on a foreign id or a message this side cannot accept.
    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()>;
    /// Terminates the endpoint because the underlying connection failed.
    fn fail_connection(&self, error: StreamError);
}
380
/// Producing end of a stream ref. It forwards elements of a local upstream to a remote
/// consumer according to the consumer's cumulative demand.
///
/// Clones share one protocol state.
pub struct StreamRefProtoProducer<T>
where
    T: StreamRefPayload,
{
    shared: Arc<ProducerShared<T>>,
}
393
394impl<T> Clone for StreamRefProtoProducer<T>
395where
396 T: StreamRefPayload,
397{
398 fn clone(&self) -> Self {
399 Self {
400 shared: Arc::clone(&self.shared),
401 }
402 }
403}
404
405impl<T> StreamRefProtoProducer<T>
406where
407 T: StreamRefPayload,
408{
409 pub fn from_source_ref(
410 source_ref: SourceRef<T>,
411 stream_ref_id: StreamRefId,
412 settings: StreamRefSettings,
413 ) -> StreamResult<Self> {
414 Self::from_source(
415 super::stream_ref::proto_source(&source_ref),
416 stream_ref_id,
417 settings,
418 )
419 }
420
421 pub fn from_source<Mat>(
422 source: Source<T, Mat>,
423 stream_ref_id: StreamRefId,
424 settings: StreamRefSettings,
425 ) -> StreamResult<Self>
426 where
427 Mat: Send + 'static,
428 {
429 let materializer = Materializer::new();
430 let (input, materialized) = Arc::clone(&source.factory).create(&materializer)?;
431 Ok(Self {
432 shared: Arc::new(ProducerShared {
433 stream_ref_id,
434 settings,
435 input: Mutex::new(Some(input)),
436 state: Mutex::new(ProducerState {
437 partner_seen: false,
438 cumulative_demand: 0,
439 sent: 0,
440 terminal_sent: false,
441 waiting_for_ack: false,
442 ack_deadline: None,
443 stopped: None,
444 ack_queued: false,
445 done: false,
446 input_attached: true,
447 terminal_result: None,
448 }),
449 changed: Condvar::new(),
450 completion: Mutex::new(None),
451 _materializer: materializer,
452 _materialized: Mutex::new(Some(Box::new(materialized))),
453 }),
454 })
455 }
456
457 #[must_use]
465 pub fn new_lazy(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
466 Self {
467 shared: Arc::new(ProducerShared {
468 stream_ref_id,
469 settings,
470 input: Mutex::new(None),
471 state: Mutex::new(ProducerState {
472 partner_seen: false,
473 cumulative_demand: 0,
474 sent: 0,
475 terminal_sent: false,
476 waiting_for_ack: false,
477 ack_deadline: None,
478 stopped: None,
479 ack_queued: false,
480 done: false,
481 input_attached: false,
482 terminal_result: None,
483 }),
484 changed: Condvar::new(),
485 completion: Mutex::new(None),
486 _materializer: Materializer::new(),
487 _materialized: Mutex::new(None),
488 }),
489 }
490 }
491
492 #[must_use]
500 pub fn sink(&self) -> Sink<T, StreamCompletion<NotUsed>> {
501 let shared = Arc::clone(&self.shared);
502 Sink::from_runner(move |input, _materializer| {
503 let (sender, receiver) = oneshot::channel();
504 *shared
505 .completion
506 .lock()
507 .unwrap_or_else(|poison| poison.into_inner()) = Some(sender);
508 shared.attach_input(input);
509 Ok(StreamCompletion::from_receiver(receiver, None))
510 })
511 }
512}
513
514impl<T> StreamRefProtoEndpoint for StreamRefProtoProducer<T>
515where
516 T: StreamRefPayload,
517{
518 fn stream_ref_id(&self) -> StreamRefId {
519 self.shared.stream_ref_id
520 }
521
522 fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
523 self.shared.next_frame()
524 }
525
526 fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
527 self.shared.handle_frame(frame)
528 }
529
530 fn fail_connection(&self, error: StreamError) {
531 self.shared.fail_connection(error);
532 }
533}
534
/// State shared by every handle of one producer endpoint.
struct ProducerShared<T>
where
    T: StreamRefPayload,
{
    /// Id every inbound and outbound frame must carry.
    stream_ref_id: StreamRefId,
    /// Buffer and timeout settings for this stream ref.
    settings: StreamRefSettings,
    /// Upstream elements to forward. `None` until attached (lazy producer) and again after
    /// it is dropped at termination.
    input: Mutex<Option<BoxStream<T>>>,
    /// Protocol bookkeeping, locked separately from `input`.
    state: Mutex<ProducerState>,
    /// Notified on every state change so blocked waiters re-check `state`.
    changed: Condvar,
    /// Resolves the local sink's completion once the protocol settles.
    completion: Mutex<Option<oneshot::Sender<StreamResult<NotUsed>>>>,
    /// Held for the producer's lifetime (presumably the upstream depends on it).
    _materializer: Materializer,
    /// Materialized value of an eagerly supplied source, held alive; `None` when lazy.
    _materialized: Mutex<Option<Box<dyn Any + Send>>>,
}
548
/// Mutable protocol bookkeeping of a producer endpoint.
struct ProducerState {
    /// Set when the consumer sends a handshake or demand.
    /// NOTE(review): not read anywhere in the visible code.
    partner_seen: bool,
    /// Highest `CumulativeDemand` received so far.
    cumulative_demand: u64,
    /// Number of `SequencedOnNext` frames emitted; also the next element's `seq_nr`.
    sent: u64,
    /// A `RemoteStreamCompleted`/`RemoteStreamFailure` has been emitted.
    terminal_sent: bool,
    /// A terminal frame was sent and the consumer's `Ack` has not arrived yet.
    waiting_for_ack: bool,
    /// Deadline for that `Ack`; set together with `waiting_for_ack`.
    ack_deadline: Option<Instant>,
    /// Error recorded when the consumer cancelled/failed or the connection failed.
    stopped: Option<StreamError>,
    /// The consumer terminated the stream; the next outbound frame must be an `Ack`.
    ack_queued: bool,
    /// Protocol finished; `next_frame` returns `None` from now on.
    done: bool,
    /// An upstream has been supplied (eagerly or via the sink).
    input_attached: bool,
    /// Result delivered to the sink's completion when the protocol settles.
    terminal_result: Option<StreamResult<NotUsed>>,
}
562
/// Producer-side protocol state machine shared by the transport (`next_frame` /
/// `handle_frame`), the local sink and all producer handles.
impl<T> ProducerShared<T>
where
    T: StreamRefPayload,
{
    /// Locks the protocol state, recovering the guard if a previous holder panicked.
    fn lock_state(&self) -> MutexGuard<'_, ProducerState> {
        self.state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
    }

    /// Locks the upstream slot, recovering the guard if a previous holder panicked.
    fn lock_input(&self) -> MutexGuard<'_, Option<BoxStream<T>>> {
        self.input
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
    }

    /// Wraps `message` in a frame tagged with this endpoint's id.
    fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
        StreamRefFrame::new(self.stream_ref_id, message)
    }

    /// Blocks until the next outbound frame is ready.
    ///
    /// Each pass checks, in order:
    /// 1. already done → `None`;
    /// 2. a consumer-initiated stop is pending → emit `Ack`;
    /// 3. waiting for the consumer's `Ack` → time out or keep waiting;
    /// 4. a stop error is recorded → return it;
    /// 5. outstanding demand → pull an element;
    /// 6. no demand yet and the deadline has passed → subscription timeout;
    /// 7. otherwise wait for a state change.
    ///
    /// Every terminating branch drops the input and settles the sink completion.
    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        // Deadline for the first CumulativeDemand; note it is recomputed on every call.
        let subscription_deadline = deadline_from_now(self.settings.subscription_timeout());
        loop {
            let mut state = self.lock_state();
            if state.done {
                return None;
            }

            // Consumer cancelled/failed: acknowledge it and finish with the recorded error.
            if state.ack_queued {
                state.ack_queued = false;
                state.done = true;
                state.terminal_result = Some(match state.stopped.clone() {
                    Some(error) => Err(error),
                    None => Ok(NotUsed),
                });
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Ok(self.frame(StreamRefMessage::Ack)));
            }

            // Our terminal frame is out; wait (bounded by ack_deadline) for the consumer's Ack.
            if state.waiting_for_ack {
                if state
                    .ack_deadline
                    .is_some_and(|deadline| Instant::now() >= deadline)
                {
                    let timeout_error =
                        subscription_timeout_error("stream ref producer terminal ack");
                    state.done = true;
                    state.terminal_result = Some(Err(timeout_error.clone()));
                    self.changed.notify_all();
                    drop(state);
                    self.drop_input();
                    self.settle();
                    return Some(Err(timeout_error));
                }
                if let Some(remaining) = state
                    .ack_deadline
                    .and_then(|deadline| deadline.checked_duration_since(Instant::now()))
                {
                    let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
                    drop(next);
                } else {
                    // Deadline passed between the two checks; the next pass reports the timeout.
                    drop(state);
                }
                continue;
            }

            if let Some(error) = state.stopped.clone() {
                state.done = true;
                state.terminal_result = Some(Err(error.clone()));
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Err(error));
            }

            // Demand outstanding: release the state lock before touching the upstream.
            if state.cumulative_demand > 0 && state.sent < state.cumulative_demand {
                drop(state);
                if let Some(frame) = self.pull_next_frame() {
                    return Some(frame);
                }
                // Nothing emitted (input not attached yet, or state changed); re-evaluate.
                continue;
            }

            if state.cumulative_demand == 0 && Instant::now() >= subscription_deadline {
                let timeout_error = subscription_timeout_error("stream ref producer first demand");
                state.done = true;
                state.terminal_result = Some(Err(timeout_error.clone()));
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Err(timeout_error));
            }

            if state.cumulative_demand == 0 {
                // No demand yet: wait at most until the subscription deadline.
                let remaining = subscription_deadline.saturating_duration_since(Instant::now());
                if remaining.is_zero() {
                    drop(state);
                    continue;
                }
                let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
                drop(next);
            } else {
                // All granted demand is used up: wait (unbounded) for more demand or a stop.
                let next = wait_unpoison(&self.changed, state);
                drop(next);
            }
        }
    }

    /// Pulls one element from the upstream and converts it into the next frame.
    ///
    /// Returns `None` when nothing should be emitted: no input yet (after waiting for it to
    /// be attached or for the protocol to end), or the state moved on while pulling. In both
    /// cases `next_frame` re-evaluates.
    ///
    /// NOTE(review): the input lock is held across `next()`. If that call blocks, a
    /// concurrent `drop_input` (e.g. from `stop_from_consumer`) waits until upstream yields.
    fn pull_next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        let item = {
            let mut input_guard = self.lock_input();
            if input_guard.is_none() {
                drop(input_guard);
                let mut state = self.lock_state();
                // Lazy producer: wait until the sink attaches an input or the protocol ends.
                while !state.input_attached
                    && !state.done
                    && state.stopped.is_none()
                    && !state.terminal_sent
                {
                    state = wait_unpoison(&self.changed, state);
                }
                drop(state);
                return None;
            }
            input_guard.as_mut().expect("input attached").next()
        };

        match item {
            Some(Ok(item)) => {
                let mut state = self.lock_state();
                // Stopped or finished while we were pulling: discard the element.
                if state.done || state.stopped.is_some() || state.waiting_for_ack {
                    return None;
                }
                let seq_nr = state.sent;
                state.sent = state.sent.saturating_add(1);
                Some(Ok(self.frame(StreamRefMessage::SequencedOnNext {
                    seq_nr,
                    payload: StreamRefPayloadBytes {
                        bytes: item.encode_stream_ref_payload(),
                    },
                })))
            }
            Some(Err(error)) => {
                // Upstream failed: forward the failure and wait for the consumer's Ack.
                self.drop_input();
                let mut state = self.lock_state();
                if state.done || state.terminal_sent {
                    return None;
                }
                state.terminal_sent = true;
                state.waiting_for_ack = true;
                state.terminal_result = Some(Err(error.clone()));
                state.ack_deadline = Some(deadline_from_now(self.settings.subscription_timeout()));
                self.changed.notify_all();
                drop(state);
                Some(Ok(self.frame(StreamRefMessage::RemoteStreamFailure {
                    cause: failure_cause(&error),
                })))
            }
            None => {
                // Upstream completed: report how many elements were sent and await the Ack.
                self.drop_input();
                let mut state = self.lock_state();
                if state.done || state.terminal_sent {
                    return None;
                }
                let seq_nr = state.sent;
                state.terminal_sent = true;
                state.waiting_for_ack = true;
                state.terminal_result = Some(Ok(NotUsed));
                state.ack_deadline = Some(deadline_from_now(self.settings.subscription_timeout()));
                self.changed.notify_all();
                drop(state);
                Some(Ok(
                    self.frame(StreamRefMessage::RemoteStreamCompleted { seq_nr })
                ))
            }
        }
    }

    /// Applies one inbound frame from the consumer.
    ///
    /// # Errors
    ///
    /// Rejects a foreign stream ref id, a `CumulativeDemand` of 0, and `SequencedOnNext`
    /// (which only producers send).
    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
        self.validate_frame_id(frame.stream_ref_id)?;
        match frame.message {
            StreamRefMessage::OnSubscribeHandshake => {
                let mut state = self.lock_state();
                state.partner_seen = true;
                self.changed.notify_all();
                drop(state);
                Ok(())
            }
            StreamRefMessage::CumulativeDemand { seq_nr } => {
                if seq_nr == 0 {
                    return Err(StreamError::Failed(
                        "CumulativeDemand seq_nr must be positive".to_owned(),
                    ));
                }
                let mut state = self.lock_state();
                state.partner_seen = true;
                // Demand is cumulative; stale (smaller) values are ignored.
                if seq_nr > state.cumulative_demand {
                    state.cumulative_demand = seq_nr;
                }
                self.changed.notify_all();
                drop(state);
                Ok(())
            }
            // From the consumer, "completed" means downstream cancellation.
            StreamRefMessage::RemoteStreamCompleted { .. } => {
                self.stop_from_consumer(StreamError::Cancelled);
                Ok(())
            }
            StreamRefMessage::RemoteStreamFailure { cause } => {
                self.stop_from_consumer(StreamError::Failed(
                    String::from_utf8_lossy(&cause).into_owned(),
                ));
                Ok(())
            }
            StreamRefMessage::Ack => {
                let mut state = self.lock_state();
                // Only meaningful as the answer to our terminal frame; otherwise ignored.
                if state.waiting_for_ack {
                    state.waiting_for_ack = false;
                    state.done = true;
                    if state.terminal_result.is_none() {
                        state.terminal_result = Some(Ok(NotUsed));
                    }
                    self.changed.notify_all();
                    drop(state);
                    self.drop_input();
                    self.settle();
                } else {
                    drop(state);
                }
                Ok(())
            }
            StreamRefMessage::SequencedOnNext { .. } => Err(StreamError::Failed(
                "producer endpoint cannot receive SequencedOnNext".to_owned(),
            )),
        }
    }

    /// Records a consumer-initiated stop and queues the `Ack` that `next_frame` will send.
    fn stop_from_consumer(&self, error: StreamError) {
        let mut state = self.lock_state();
        if !state.done {
            state.stopped = Some(error.clone());
            state.ack_queued = true;
            state.terminal_result = Some(Err(error));
        }
        self.changed.notify_all();
        drop(state);
        self.drop_input();
    }

    /// Ends the protocol immediately with `error` (no frames are exchanged).
    fn fail_connection(&self, error: StreamError) {
        let mut state = self.lock_state();
        if !state.done {
            state.stopped = Some(error.clone());
            state.done = true;
            state.terminal_result = Some(Err(error));
        }
        self.changed.notify_all();
        drop(state);
        self.drop_input();
        self.settle();
    }

    /// Installs the upstream supplied by the materialized sink and wakes waiting pullers.
    fn attach_input(&self, input: BoxStream<T>) {
        *self.lock_input() = Some(input);
        let mut state = self.lock_state();
        state.input_attached = true;
        self.changed.notify_all();
        drop(state);
    }

    /// Sends the terminal result to the sink completion. This happens at most once because
    /// the sender is taken; nothing is sent if there is no sender or no result yet.
    fn settle(&self) {
        let result = self.lock_state().terminal_result.clone();
        let sender = self
            .completion
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
            .take();
        if let (Some(sender), Some(result)) = (sender, result) {
            let _ = sender.send(result);
        }
    }

    /// Drops the upstream (if any) outside the input lock.
    fn drop_input(&self) {
        let input = self.lock_input().take();
        drop(input);
    }

    /// Rejects frames addressed to a different stream ref.
    fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
        if stream_ref_id == self.stream_ref_id {
            Ok(())
        } else {
            Err(StreamError::Failed(format!(
                "stream ref id mismatch: expected {}, got {}",
                self.stream_ref_id, stream_ref_id
            )))
        }
    }
}
871
/// Consuming end of a stream ref. It receives elements from a remote producer and exposes
/// them as a local [`Source`].
///
/// Clones share one protocol state.
pub struct StreamRefProtoConsumer<T>
where
    T: StreamRefPayload,
{
    shared: Arc<ConsumerShared<T>>,
}
883
884impl<T> Clone for StreamRefProtoConsumer<T>
885where
886 T: StreamRefPayload,
887{
888 fn clone(&self) -> Self {
889 Self {
890 shared: Arc::clone(&self.shared),
891 }
892 }
893}
894
895impl<T> StreamRefProtoConsumer<T>
896where
897 T: StreamRefPayload,
898{
899 #[must_use]
900 pub fn new(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
901 Self {
902 shared: Arc::new(ConsumerShared {
903 stream_ref_id,
904 settings,
905 state: Mutex::new(ConsumerState {
906 source_taken: false,
907 subscribed: false,
908 queue: VecDeque::new(),
909 terminal: None,
910 expected_seq: 0,
911 delivered: 0,
912 cumulative_demand: 0,
913 outbound: VecDeque::new(),
914 finish_after_outbound_ack: false,
915 waiting_cancel_ack: false,
916 done: false,
917 }),
918 changed: Condvar::new(),
919 }),
920 }
921 }
922
923 #[must_use]
924 pub fn source(&self) -> Source<T, NotUsed> {
925 let shared = Arc::clone(&self.shared);
926 Source::unfold_resource(
927 move || shared.start_stream(),
928 |stream| stream.next_item(),
929 |mut stream| {
930 stream.close();
931 Ok(())
932 },
933 )
934 }
935}
936
937impl<T> StreamRefProtoEndpoint for StreamRefProtoConsumer<T>
938where
939 T: StreamRefPayload,
940{
941 fn stream_ref_id(&self) -> StreamRefId {
942 self.shared.stream_ref_id
943 }
944
945 fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
946 self.shared.next_frame()
947 }
948
949 fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
950 self.shared.handle_frame(frame)
951 }
952
953 fn fail_connection(&self, error: StreamError) {
954 self.shared.fail_connection(error);
955 }
956}
957
/// State shared by every handle of one consumer endpoint and its materialized stream.
struct ConsumerShared<T>
where
    T: StreamRefPayload,
{
    /// Id every inbound and outbound frame must carry.
    stream_ref_id: StreamRefId,
    /// Buffer and timeout settings for this stream ref.
    settings: StreamRefSettings,
    /// Protocol bookkeeping and buffered elements.
    state: Mutex<ConsumerState<T>>,
    /// Notified on every state change so blocked readers/writers re-check `state`.
    changed: Condvar,
}
967
/// Mutable protocol bookkeeping of a consumer endpoint.
struct ConsumerState<T> {
    /// The local source has been materialized (allowed once).
    source_taken: bool,
    /// The handshake has been queued.
    subscribed: bool,
    /// Received elements not yet delivered downstream.
    queue: VecDeque<T>,
    /// How the local stream ends once `queue` is drained.
    terminal: Option<ConsumerTerminal>,
    /// `seq_nr` the next `SequencedOnNext` must carry.
    expected_seq: u64,
    /// Elements handed to downstream so far.
    delivered: u64,
    /// Highest `CumulativeDemand` granted to the producer.
    cumulative_demand: u64,
    /// Frames waiting to be returned by `next_frame`.
    outbound: VecDeque<StreamRefMessage>,
    /// Sending the queued `Ack` finishes the protocol.
    finish_after_outbound_ack: bool,
    /// We sent a cancellation/failure and await the producer's `Ack`.
    waiting_cancel_ack: bool,
    /// Protocol finished; `next_frame` returns `None` once `outbound` is empty.
    done: bool,
}
981
/// Final outcome reported to the local stream after its buffered elements are drained.
#[derive(Clone)]
enum ConsumerTerminal {
    /// The producer completed normally.
    Complete,
    /// The stream failed, was cancelled, or the connection broke.
    Error(StreamError),
}
987
988impl<T> ConsumerShared<T>
989where
990 T: StreamRefPayload,
991{
992 fn lock_state(&self) -> MutexGuard<'_, ConsumerState<T>> {
993 self.state
994 .lock()
995 .unwrap_or_else(|poison| poison.into_inner())
996 }
997
998 fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
999 StreamRefFrame::new(self.stream_ref_id, message)
1000 }
1001
1002 fn start_stream(self: &Arc<Self>) -> StreamResult<ConsumerStream<T>> {
1003 {
1004 let mut state = self.lock_state();
1005 if state.source_taken {
1006 return Err(StreamError::Failed(
1007 "stream ref source has already been materialized".to_owned(),
1008 ));
1009 }
1010 state.source_taken = true;
1011 if !state.subscribed {
1012 state.subscribed = true;
1013 state
1014 .outbound
1015 .push_back(StreamRefMessage::OnSubscribeHandshake);
1016 if let Some(demand) = next_demand(&mut state, self.settings) {
1017 state
1018 .outbound
1019 .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
1020 }
1021 }
1022 self.changed.notify_all();
1023 }
1024 Ok(ConsumerStream {
1025 shared: Arc::clone(self),
1026 terminated: false,
1027 })
1028 }
1029
1030 fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
1031 loop {
1032 let mut state = self.lock_state();
1033 if let Some(message) = state.outbound.pop_front() {
1034 let finish_after_ack = message.is_ack() && state.finish_after_outbound_ack;
1035 if finish_after_ack {
1036 state.done = true;
1037 }
1038 drop(state);
1039 return Some(Ok(self.frame(message)));
1040 }
1041 if state.done {
1042 return None;
1043 }
1044 let next = wait_unpoison(&self.changed, state);
1045 drop(next);
1046 }
1047 }
1048
1049 fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
1050 self.validate_frame_id(frame.stream_ref_id)?;
1051 match frame.message {
1052 StreamRefMessage::OnSubscribeHandshake => Ok(()),
1053 StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
1054 let item = T::decode_stream_ref_payload(payload.bytes)?;
1055 let mut state = self.lock_state();
1056 if state.terminal.is_some() || state.done {
1057 return Ok(());
1058 }
1059 if seq_nr != state.expected_seq {
1060 let error =
1061 invalid_sequence_error(state.expected_seq, seq_nr, "stream ref element");
1062 state.queue.clear();
1063 state.terminal = Some(ConsumerTerminal::Error(error.clone()));
1064 state
1065 .outbound
1066 .push_back(StreamRefMessage::RemoteStreamFailure {
1067 cause: failure_cause(&error),
1068 });
1069 state.waiting_cancel_ack = true;
1070 } else if state.queue.len() >= self.settings.buffer_capacity() {
1071 let error = StreamError::Failed(
1072 "stream ref receive buffer overflowed demand window".to_owned(),
1073 );
1074 state.queue.clear();
1075 state.terminal = Some(ConsumerTerminal::Error(error.clone()));
1076 state
1077 .outbound
1078 .push_back(StreamRefMessage::RemoteStreamFailure {
1079 cause: failure_cause(&error),
1080 });
1081 state.waiting_cancel_ack = true;
1082 } else {
1083 state.expected_seq = state.expected_seq.saturating_add(1);
1084 state.queue.push_back(item);
1085 }
1086 self.changed.notify_all();
1087 drop(state);
1088 Ok(())
1089 }
1090 StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
1091 let mut state = self.lock_state();
1092 if state.terminal.is_none() && !state.done {
1093 if seq_nr != state.expected_seq {
1094 state.queue.clear();
1095 state.terminal = Some(ConsumerTerminal::Error(invalid_sequence_error(
1096 state.expected_seq,
1097 seq_nr,
1098 "stream ref completion",
1099 )));
1100 } else {
1101 state.terminal = Some(ConsumerTerminal::Complete);
1102 }
1103 state.outbound.push_back(StreamRefMessage::Ack);
1104 state.finish_after_outbound_ack = true;
1105 }
1106 self.changed.notify_all();
1107 drop(state);
1108 Ok(())
1109 }
1110 StreamRefMessage::RemoteStreamFailure { cause } => {
1111 let mut state = self.lock_state();
1112 if state.terminal.is_none() && !state.done {
1113 state.queue.clear();
1114 state.terminal = Some(ConsumerTerminal::Error(StreamError::Failed(
1115 String::from_utf8_lossy(&cause).into_owned(),
1116 )));
1117 state.outbound.push_back(StreamRefMessage::Ack);
1118 state.finish_after_outbound_ack = true;
1119 }
1120 self.changed.notify_all();
1121 drop(state);
1122 Ok(())
1123 }
1124 StreamRefMessage::Ack => {
1125 let mut state = self.lock_state();
1126 if state.waiting_cancel_ack {
1127 state.waiting_cancel_ack = false;
1128 state.done = true;
1129 }
1130 self.changed.notify_all();
1131 drop(state);
1132 Ok(())
1133 }
1134 StreamRefMessage::CumulativeDemand { .. } => Err(StreamError::Failed(
1135 "consumer endpoint cannot receive CumulativeDemand".to_owned(),
1136 )),
1137 }
1138 }
1139
1140 fn cancel_from_downstream(&self) {
1141 let mut state = self.lock_state();
1142 if state.terminal.is_none() && !state.done {
1143 let seq_nr = state.expected_seq;
1144 state.terminal = Some(ConsumerTerminal::Error(StreamError::Cancelled));
1145 state
1146 .outbound
1147 .push_back(StreamRefMessage::RemoteStreamCompleted { seq_nr });
1148 state.waiting_cancel_ack = true;
1149 }
1150 self.changed.notify_all();
1151 drop(state);
1152 }
1153
1154 fn fail_connection(&self, error: StreamError) {
1155 let mut state = self.lock_state();
1156 if state.terminal.is_none() {
1157 state.queue.clear();
1158 state.terminal = Some(ConsumerTerminal::Error(error));
1159 }
1160 state.done = true;
1161 self.changed.notify_all();
1162 drop(state);
1163 }
1164
1165 fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
1166 if stream_ref_id == self.stream_ref_id {
1167 Ok(())
1168 } else {
1169 Err(StreamError::Failed(format!(
1170 "stream ref id mismatch: expected {}, got {}",
1171 self.stream_ref_id, stream_ref_id
1172 )))
1173 }
1174 }
1175}
1176
/// The single materialized reader of a consumer endpoint.
struct ConsumerStream<T>
where
    T: StreamRefPayload,
{
    /// Protocol state shared with the endpoint.
    shared: Arc<ConsumerShared<T>>,
    /// Set once a terminal outcome has been reported or the stream was closed.
    terminated: bool,
}
1184
1185impl<T> ConsumerStream<T>
1186where
1187 T: StreamRefPayload,
1188{
1189 fn next_item(&mut self) -> StreamResult<Option<T>> {
1190 if self.terminated {
1191 return Ok(None);
1192 }
1193 loop {
1194 let mut state = self.shared.lock_state();
1195 if let Some(item) = state.queue.pop_front() {
1196 state.delivered = state.delivered.saturating_add(1);
1197 if let Some(demand) = next_demand(&mut state, self.shared.settings) {
1198 state
1199 .outbound
1200 .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
1201 self.shared.changed.notify_all();
1202 }
1203 return Ok(Some(item));
1204 }
1205
1206 if let Some(terminal) = state.terminal.clone() {
1207 self.terminated = true;
1208 return match terminal {
1209 ConsumerTerminal::Complete => Ok(None),
1210 ConsumerTerminal::Error(error) => Err(error),
1211 };
1212 }
1213
1214 let next = wait_unpoison(&self.shared.changed, state);
1215 drop(next);
1216 }
1217 }
1218
1219 fn close(&mut self) {
1220 if !self.terminated {
1221 self.shared.cancel_from_downstream();
1222 self.terminated = true;
1223 }
1224 }
1225}
1226
1227fn next_demand<T>(state: &mut ConsumerState<T>, settings: StreamRefSettings) -> Option<u64> {
1228 if state.terminal.is_some() {
1232 return None;
1233 }
1234 let remaining_credit = state.cumulative_demand.saturating_sub(state.delivered);
1235 if state.cumulative_demand != 0 && remaining_credit > demand_replenish_threshold(settings) {
1236 return None;
1237 }
1238 let target = state
1239 .delivered
1240 .saturating_add(settings.buffer_capacity() as u64);
1241 if state.cumulative_demand >= target {
1242 return None;
1243 }
1244 state.cumulative_demand = target;
1245 Some(target)
1246}
1247
1248fn demand_replenish_threshold(settings: StreamRefSettings) -> u64 {
1249 (settings.buffer_capacity() as u64) / 2
1250}
1251
1252fn failure_cause(error: &StreamError) -> Vec<u8> {
1253 match error {
1254 StreamError::Failed(message) => message.clone().into_bytes(),
1255 other => other.to_string().into_bytes(),
1256 }
1257}
1258
1259fn subscription_timeout_error(side: &str) -> StreamError {
1260 StreamError::Failed(format!(
1261 "{side} remote side did not subscribe within subscription timeout"
1262 ))
1263}
1264
1265fn invalid_sequence_error(expected: u64, got: u64, context: &str) -> StreamError {
1266 StreamError::Failed(format!(
1267 "{context} sequence gap: expected sequence {expected}, got {got}"
1268 ))
1269}
1270
1271fn deadline_from_now(timeout: Duration) -> Instant {
1272 Instant::now()
1273 .checked_add(timeout)
1274 .unwrap_or_else(far_future)
1275}
1276
/// Returns an instant 365 days after the current time.
///
/// `deadline_from_now` uses this as its fallback when a requested timeout
/// overflows `Instant`, so it serves as a practical stand-in for "no deadline".
fn far_future() -> Instant {
    const ONE_YEAR: Duration = Duration::from_secs(365 * 24 * 60 * 60);
    Instant::now() + ONE_YEAR
}
1280
/// Waits on `condvar` for at most `timeout`, ignoring mutex poisoning.
///
/// If the mutex was poisoned by a panicking thread, the guard and timeout
/// result are recovered from the `PoisonError` instead of propagating it.
fn wait_timeout_unpoison<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    match condvar.wait_timeout(guard, timeout) {
        Ok(woken) => woken,
        Err(poisoned) => poisoned.into_inner(),
    }
}
1290
/// Blocks on `condvar` until woken, ignoring mutex poisoning.
///
/// If the mutex was poisoned by a panicking thread, the reacquired guard is
/// recovered from the `PoisonError` instead of propagating it.
fn wait_unpoison<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(reacquired) => reacquired,
        Err(poisoned) => poisoned.into_inner(),
    }
}
1296
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::{Source, StreamRefs};

    /// Settings with a 4-element buffer and a 50 ms subscription timeout, so
    /// demand replenishment and subscription timeouts are exercised quickly.
    fn short_settings() -> StreamRefSettings {
        StreamRefSettings::default()
            .with_buffer_capacity(4)
            .with_subscription_timeout(Duration::from_millis(50))
    }

    /// A `SequencedOnNext` frame carrying an encoded `u64` survives a protobuf
    /// encode/decode round trip unchanged.
    #[test]
    fn protobuf_frame_round_trip() {
        let frame = StreamRefFrame::new(
            StreamRefId::from_u128(42),
            StreamRefMessage::SequencedOnNext {
                seq_nr: 7,
                payload: StreamRefPayloadBytes {
                    bytes: 99_u64.encode_stream_ref_payload(),
                },
            },
        );

        let decoded = StreamRefFrame::decode(&frame.encode_to_vec()).unwrap();
        assert_eq!(decoded, frame);
    }

    /// End-to-end test of the producer/consumer seam.
    ///
    /// Ten elements are streamed through a 4-element buffer, so the consumer
    /// has to issue several rounds of cumulative demand before everything
    /// arrives.
    #[test]
    fn producer_consumer_seam_streams_with_low_watermark_demand() {
        let id = StreamRefId::from_u128(1);
        let settings = short_settings();
        let source_ref = Source::from_iter(0_u64..10)
            .run_with(StreamRefs::source_ref_with_settings(settings))
            .unwrap();
        let producer = StreamRefProtoProducer::from_source_ref(source_ref, id, settings).unwrap();
        let consumer = StreamRefProtoConsumer::<u64>::new(id, settings);
        let consumer_source = consumer.source();

        // Carrier thread for the producer -> consumer direction.
        let producer_thread = std::thread::spawn({
            let producer = producer.clone();
            let consumer = consumer.clone();
            move || {
                while let Some(frame) = producer.next_frame() {
                    consumer.handle_frame(frame?)?;
                }
                Ok::<_, StreamError>(())
            }
        });
        // Carrier thread for the consumer -> producer direction (demand and
        // control frames).
        let consumer_thread = std::thread::spawn({
            let producer = producer.clone();
            let consumer = consumer.clone();
            move || {
                while let Some(frame) = consumer.next_frame() {
                    producer.handle_frame(frame?)?;
                }
                Ok::<_, StreamError>(())
            }
        });

        assert_eq!(
            consumer_source.run_collect().unwrap(),
            (0_u64..10).collect::<Vec<_>>()
        );
        producer_thread.join().unwrap().unwrap();
        consumer_thread.join().unwrap().unwrap();
    }

    /// An out-of-order first element makes the consumer emit a
    /// `RemoteStreamFailure` frame and fail its downstream.
    #[test]
    fn strict_sequence_gap_fails_consumer_and_sends_failure() {
        let id = StreamRefId::from_u128(2);
        let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
        let source = consumer
            .source()
            .run_with(crate::testkit::TestSink::probe())
            .unwrap();
        source.request(1);
        // Drain the consumer's first two outbound frames. NOTE(review):
        // presumably the subscription handshake and the initial cumulative
        // demand; confirm.
        consumer.next_frame().unwrap().unwrap();
        consumer.next_frame().unwrap().unwrap();

        // Deliver seq_nr 1 before anything else has been received. This is
        // treated as a gap (presumably sequence numbers start at 0; confirm).
        consumer
            .handle_frame(StreamRefFrame::new(
                id,
                StreamRefMessage::SequencedOnNext {
                    seq_nr: 1,
                    payload: StreamRefPayloadBytes {
                        bytes: 1_u64.encode_stream_ref_payload(),
                    },
                },
            ))
            .unwrap();

        let outbound = consumer.next_frame().unwrap().unwrap();
        assert!(matches!(
            outbound.message,
            StreamRefMessage::RemoteStreamFailure { .. }
        ));
        assert!(matches!(source.expect_error(), StreamError::Failed(_)));
    }

    /// A producer that never receives any frame reports an error mentioning
    /// "first demand" once the 50 ms subscription timeout elapses.
    #[test]
    fn producer_times_out_without_first_demand() {
        let producer = StreamRefProtoProducer::from_source(
            Source::repeat(1_u64),
            StreamRefId::from_u128(3),
            short_settings(),
        )
        .unwrap();

        let error = producer.next_frame().unwrap().unwrap_err();
        assert!(matches!(error, StreamError::Failed(message) if message.contains("first demand")));
    }

    /// Pins the default demand redelivery interval to one second.
    #[test]
    fn demand_redelivery_is_not_required_by_reliable_carriers() {
        // NOTE(review): despite the name, this only checks the default
        // interval value. Confirm whether a behavioral check was intended.
        assert_eq!(
            StreamRefSettings::default().demand_redelivery_interval(),
            Duration::from_secs(1)
        );
    }
}