use std::{
any::Any,
collections::VecDeque,
fmt,
sync::{
Arc, Condvar, Mutex, MutexGuard,
atomic::{AtomicU64, Ordering},
},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
use crate::stream::{
BoxStream, Materializer, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult,
};
use futures::channel::oneshot;
use prost::Message as ProstMessage;
use super::{SourceRef, StreamRefSettings};
/// Process-wide sequence counter consumed by [`StreamRefId::new`] so that ids generated in the
/// same process differ. Starts at 1 and is only ever incremented.
static STREAM_REF_PROTO_ID: AtomicU64 = AtomicU64::new(1_u64);
/// Conversion between an element type and the opaque byte payload carried inside
/// `SequencedOnNext` frames.
pub trait StreamRefPayload: Send + 'static {
    /// Serializes `self` into the bytes placed in the frame payload.
    fn encode_stream_ref_payload(self) -> Vec<u8>;

    /// Rebuilds a value from the payload bytes of a received frame.
    ///
    /// # Errors
    ///
    /// Implementations return an error when `bytes` is not a valid encoding of `Self`.
    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self>
    where
        Self: Sized;
}
/// Implements [`StreamRefPayload`] for fixed-width numeric types using their big-endian byte
/// representation. Decoding rejects payloads whose length differs from the type's size.
macro_rules! impl_stream_ref_payload_numeric {
    ($($ty:ty),* $(,)?) => {
        $(
            impl StreamRefPayload for $ty {
                fn encode_stream_ref_payload(self) -> Vec<u8> {
                    Vec::from(self.to_be_bytes())
                }

                fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
                    match <[u8; std::mem::size_of::<$ty>()]>::try_from(bytes.as_slice()) {
                        Ok(raw) => Ok(<$ty>::from_be_bytes(raw)),
                        Err(_) => Err(StreamError::Failed(format!(
                            "invalid {} stream ref payload length: {}",
                            stringify!($ty),
                            bytes.len()
                        ))),
                    }
                }
            }
        )*
    };
}
impl_stream_ref_payload_numeric!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, f32, f64);
/// `bool` travels as a single byte: `0` for `false` and `1` for `true`. Any other payload,
/// including an empty or multi-byte one, is rejected.
impl StreamRefPayload for bool {
    fn encode_stream_ref_payload(self) -> Vec<u8> {
        Vec::from([u8::from(self)])
    }

    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
        if let [flag @ (0 | 1)] = bytes.as_slice() {
            return Ok(*flag == 1);
        }
        Err(StreamError::Failed(
            "invalid bool stream ref payload".to_owned(),
        ))
    }
}
impl StreamRefPayload for String {
fn encode_stream_ref_payload(self) -> Vec<u8> {
self.into_bytes()
}
fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
String::from_utf8(bytes)
.map_err(|error| StreamError::Failed(format!("invalid UTF-8 payload: {error}")))
}
}
/// Raw byte vectors are already in wire form and pass through unchanged in both directions.
/// Decoding never fails.
impl StreamRefPayload for Vec<u8> {
    fn encode_stream_ref_payload(self) -> Vec<u8> {
        // No transformation: the payload is the vector itself.
        self
    }

    fn decode_stream_ref_payload(bytes: Vec<u8>) -> StreamResult<Self> {
        // Any byte sequence is a valid `Vec<u8>`.
        Ok(bytes)
    }
}
/// 128-bit identifier shared by both ends of a stream ref. Every frame carries it, and
/// endpoints reject frames whose id does not match their own.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct StreamRefId(u128);
impl StreamRefId {
    /// Generates a fresh identifier for a new stream ref.
    ///
    /// The 128 bits are packed from three non-overlapping fields:
    ///
    /// * bits 64..128 — wall-clock nanoseconds since the Unix epoch. The value is 0 if the
    ///   clock reads before the epoch and saturates at `u64::MAX`.
    /// * bits 32..64 — the current process id.
    /// * bits 0..32 — the low 32 bits of a process-wide sequence counter.
    ///
    /// The fields are disjoint rather than XOR-ed together. This guarantees that two ids
    /// created by the same process differ whenever their sequence numbers differ modulo
    /// 2^32, regardless of clock resolution.
    #[must_use]
    pub fn new() -> Self {
        // Truncation is intentional: only the low 32 bits of the counter fit in its field.
        let sequence = u128::from(STREAM_REF_PROTO_ID.fetch_add(1, Ordering::Relaxed) as u32);
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|duration| u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX))
            .unwrap_or_default();
        let pid = u128::from(std::process::id());
        Self((u128::from(timestamp) << 64) | (pid << 32) | sequence)
    }

    /// Wraps a raw 128-bit value, e.g. an id agreed on out of band.
    #[must_use]
    pub const fn from_u128(value: u128) -> Self {
        Self(value)
    }

    /// Returns the raw 128-bit value.
    #[must_use]
    pub const fn as_u128(self) -> u128 {
        self.0
    }

    /// Big-endian 16-byte encoding, as placed in the protobuf `stream_ref_id` field.
    #[must_use]
    pub fn to_bytes(self) -> [u8; 16] {
        self.0.to_be_bytes()
    }

    /// Parses the big-endian 16-byte encoding produced by [`Self::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns `StreamError::Failed` when `bytes` is not exactly 16 bytes long.
    pub fn from_bytes(bytes: &[u8]) -> StreamResult<Self> {
        let value: [u8; 16] = bytes.try_into().map_err(|_| {
            StreamError::Failed("stream ref id must be exactly 16 bytes".to_owned())
        })?;
        Ok(Self(u128::from_be_bytes(value)))
    }
}
impl Default for StreamRefId {
fn default() -> Self {
Self::new()
}
}
/// Renders the id as exactly 32 zero-padded lowercase hex digits (all 128 bits).
impl fmt::Display for StreamRefId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = self.0;
        write!(f, "{raw:032x}")
    }
}
/// Encoded element bytes carried by a `SequencedOnNext` message. The bytes are produced by
/// `StreamRefPayload::encode_stream_ref_payload` on the producer side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StreamRefPayloadBytes {
    /// Opaque encoded element.
    pub bytes: Vec<u8>,
}
/// Protocol messages exchanged between the producer and consumer ends of a stream ref.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StreamRefMessage {
    /// Sent by the consumer when its source is materialized.
    OnSubscribeHandshake,
    /// Consumer's running total of elements it is willing to receive. The producer may send
    /// elements while its sent count is below this total.
    CumulativeDemand { seq_nr: u64 },
    /// One element. `seq_nr` counts from 0 and must arrive without gaps.
    SequencedOnNext { seq_nr: u64, payload: StreamRefPayloadBytes },
    /// Terminal completion signal. The producer sends its sent count, and the consumer uses
    /// this message (with its expected sequence) to cancel.
    RemoteStreamCompleted { seq_nr: u64 },
    /// Terminal failure signal carrying the error message as bytes.
    RemoteStreamFailure { cause: Vec<u8> },
    /// Acknowledges a terminal message from the peer.
    Ack,
}
impl StreamRefMessage {
    /// For `RemoteStreamFailure`, returns the cause decoded as UTF-8 (invalid sequences are
    /// replaced with U+FFFD). Returns `None` for every other message.
    #[must_use]
    pub fn failure_text(&self) -> Option<String> {
        if let Self::RemoteStreamFailure { cause } = self {
            Some(String::from_utf8_lossy(cause).into_owned())
        } else {
            None
        }
    }

    /// `true` only for the terminal acknowledgement message.
    fn is_ack(&self) -> bool {
        matches!(*self, Self::Ack)
    }
}
/// A protocol message addressed to a specific stream ref. This is the unit exchanged by
/// `StreamRefProtoEndpoint` implementations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StreamRefFrame {
    /// Stream ref the message belongs to.
    pub stream_ref_id: StreamRefId,
    /// The protocol message itself.
    pub message: StreamRefMessage,
}
impl StreamRefFrame {
    /// Pairs `message` with the stream ref it belongs to.
    #[must_use]
    pub fn new(stream_ref_id: StreamRefId, message: StreamRefMessage) -> Self {
        Self { stream_ref_id, message }
    }

    /// Serializes the frame into its protobuf wire form.
    #[must_use]
    pub fn encode_to_vec(&self) -> Vec<u8> {
        let wire = self.to_wire();
        wire.encode_to_vec()
    }

    /// Parses a protobuf-encoded frame.
    ///
    /// # Errors
    ///
    /// Returns `StreamError::Failed` in any of these cases:
    /// * the bytes are not a valid protobuf frame;
    /// * the id is not exactly 16 bytes;
    /// * the frame carries no message;
    /// * a `SequencedOnNext` is missing its payload.
    pub fn decode(bytes: &[u8]) -> StreamResult<Self> {
        let wire = WireStreamRefFrame::decode(bytes).map_err(|error| {
            StreamError::Failed(format!("invalid stream ref protobuf frame: {error}"))
        })?;
        Self::from_wire(wire)
    }

    /// Converts to the prost representation. Each message variant maps 1:1 onto a oneof case.
    fn to_wire(&self) -> WireStreamRefFrame {
        let message = match &self.message {
            StreamRefMessage::OnSubscribeHandshake => {
                wire_stream_ref_frame::Message::OnSubscribeHandshake(WireOnSubscribeHandshake {})
            }
            StreamRefMessage::CumulativeDemand { seq_nr } => {
                wire_stream_ref_frame::Message::CumulativeDemand(WireCumulativeDemand {
                    seq_nr: *seq_nr,
                })
            }
            StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
                let enclosed = WirePayload {
                    enclosed_message: payload.bytes.clone(),
                };
                wire_stream_ref_frame::Message::SequencedOnNext(WireSequencedOnNext {
                    seq_nr: *seq_nr,
                    payload: Some(enclosed),
                })
            }
            StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
                wire_stream_ref_frame::Message::RemoteStreamCompleted(WireRemoteStreamCompleted {
                    seq_nr: *seq_nr,
                })
            }
            StreamRefMessage::RemoteStreamFailure { cause } => {
                wire_stream_ref_frame::Message::RemoteStreamFailure(WireRemoteStreamFailure {
                    cause: cause.clone(),
                })
            }
            StreamRefMessage::Ack => wire_stream_ref_frame::Message::Ack(WireAck {}),
        };
        WireStreamRefFrame {
            stream_ref_id: self.stream_ref_id.to_bytes().to_vec(),
            message: Some(message),
        }
    }

    /// Converts from the prost representation. The id is validated before the message.
    fn from_wire(wire: WireStreamRefFrame) -> StreamResult<Self> {
        let stream_ref_id = StreamRefId::from_bytes(&wire.stream_ref_id)?;
        let Some(wire_message) = wire.message else {
            return Err(StreamError::Failed(
                "stream ref protobuf frame has no message".to_owned(),
            ));
        };
        let message = match wire_message {
            wire_stream_ref_frame::Message::OnSubscribeHandshake(_) => {
                StreamRefMessage::OnSubscribeHandshake
            }
            wire_stream_ref_frame::Message::CumulativeDemand(demand) => {
                StreamRefMessage::CumulativeDemand {
                    seq_nr: demand.seq_nr,
                }
            }
            wire_stream_ref_frame::Message::SequencedOnNext(next) => {
                let Some(enclosed) = next.payload else {
                    return Err(StreamError::Failed(
                        "SequencedOnNext missing payload".to_owned(),
                    ));
                };
                StreamRefMessage::SequencedOnNext {
                    seq_nr: next.seq_nr,
                    payload: StreamRefPayloadBytes {
                        bytes: enclosed.enclosed_message,
                    },
                }
            }
            wire_stream_ref_frame::Message::RemoteStreamCompleted(completed) => {
                StreamRefMessage::RemoteStreamCompleted {
                    seq_nr: completed.seq_nr,
                }
            }
            wire_stream_ref_frame::Message::RemoteStreamFailure(failure) => {
                StreamRefMessage::RemoteStreamFailure {
                    cause: failure.cause,
                }
            }
            wire_stream_ref_frame::Message::Ack(_) => StreamRefMessage::Ack,
        };
        Ok(Self { stream_ref_id, message })
    }
}
/// Protobuf envelope for every stream ref frame. The tag numbers below define the wire
/// format and must stay stable.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireStreamRefFrame {
    /// Big-endian 16-byte `StreamRefId`.
    #[prost(bytes = "vec", tag = "1")]
    stream_ref_id: Vec<u8>,
    /// Exactly one protocol message; `None` on the wire is rejected by `from_wire`.
    #[prost(oneof = "wire_stream_ref_frame::Message", tags = "2, 3, 4, 5, 6, 7")]
    message: Option<wire_stream_ref_frame::Message>,
}
/// Oneof payload of `WireStreamRefFrame`. Each variant mirrors a `StreamRefMessage`
/// variant, and its tag must match the `tags` list on the envelope field.
mod wire_stream_ref_frame {
    #[derive(Clone, PartialEq, prost::Oneof)]
    pub enum Message {
        #[prost(message, tag = "2")]
        OnSubscribeHandshake(super::WireOnSubscribeHandshake),
        #[prost(message, tag = "3")]
        CumulativeDemand(super::WireCumulativeDemand),
        #[prost(message, tag = "4")]
        SequencedOnNext(super::WireSequencedOnNext),
        #[prost(message, tag = "5")]
        RemoteStreamCompleted(super::WireRemoteStreamCompleted),
        #[prost(message, tag = "6")]
        RemoteStreamFailure(super::WireRemoteStreamFailure),
        #[prost(message, tag = "7")]
        Ack(super::WireAck),
    }
}
/// Element bytes wrapped inside `WireSequencedOnNext`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WirePayload {
    #[prost(bytes = "vec", tag = "1")]
    enclosed_message: Vec<u8>,
}
/// Empty body for the subscription handshake.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireOnSubscribeHandshake {}
/// Wire form of `StreamRefMessage::CumulativeDemand`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireCumulativeDemand {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
}
/// Wire form of `StreamRefMessage::SequencedOnNext`. A missing `payload` is rejected when
/// decoding.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireSequencedOnNext {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
    #[prost(message, optional, tag = "2")]
    payload: Option<WirePayload>,
}
/// Wire form of `StreamRefMessage::RemoteStreamFailure` (cause as raw bytes).
#[derive(Clone, PartialEq, ProstMessage)]
struct WireRemoteStreamFailure {
    #[prost(bytes = "vec", tag = "1")]
    cause: Vec<u8>,
}
/// Wire form of `StreamRefMessage::RemoteStreamCompleted`.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireRemoteStreamCompleted {
    #[prost(uint64, tag = "1")]
    seq_nr: u64,
}
/// Empty body for the terminal acknowledgement.
#[derive(Clone, PartialEq, ProstMessage)]
struct WireAck {}
/// One end (producer or consumer) of the stream ref protocol, driven by a transport.
///
/// A transport repeatedly takes outbound frames from [`next_frame`](Self::next_frame) and
/// passes frames received from the peer to [`handle_frame`](Self::handle_frame).
pub trait StreamRefProtoEndpoint: Clone + Send + Sync + 'static {
    /// Id of the stream ref this endpoint serves.
    fn stream_ref_id(&self) -> StreamRefId;

    /// Next frame to send to the peer, or `None` once the endpoint has nothing more to send.
    /// The implementations in this module block the calling thread until a frame is ready.
    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>>;

    /// Processes a frame received from the peer.
    ///
    /// # Errors
    ///
    /// The implementations in this module reject, among other things, frames carrying a
    /// different stream ref id and messages the endpoint's role cannot receive.
    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()>;

    /// Informs the endpoint that the underlying connection failed with `error`.
    fn fail_connection(&self, error: StreamError);
}
/// Producer end of a stream ref: pulls elements from a local stream and turns them into
/// frames for a remote consumer, honouring that consumer's cumulative demand.
///
/// Clones share the same underlying state.
pub struct StreamRefProtoProducer<T>
where
    T: StreamRefPayload,
{
    shared: Arc<ProducerShared<T>>,
}
// Written by hand so that cloning does not require `T: Clone`; only the `Arc` is shared.
impl<T> Clone for StreamRefProtoProducer<T>
where
    T: StreamRefPayload,
{
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.shared);
        Self { shared }
    }
}
impl<T> StreamRefProtoProducer<T>
where
T: StreamRefPayload,
{
pub fn from_source_ref(
source_ref: SourceRef<T>,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<Self> {
Self::from_source(
super::stream_ref::proto_source(&source_ref),
stream_ref_id,
settings,
)
}
pub fn from_source<Mat>(
source: Source<T, Mat>,
stream_ref_id: StreamRefId,
settings: StreamRefSettings,
) -> StreamResult<Self>
where
Mat: Send + 'static,
{
let materializer = Materializer::new();
let (input, materialized) = Arc::clone(&source.factory).create(&materializer)?;
Ok(Self {
shared: Arc::new(ProducerShared {
stream_ref_id,
settings,
input: Mutex::new(Some(input)),
state: Mutex::new(ProducerState {
partner_seen: false,
cumulative_demand: 0,
sent: 0,
terminal_sent: false,
waiting_for_ack: false,
ack_deadline: None,
stopped: None,
ack_queued: false,
done: false,
input_attached: true,
terminal_result: None,
}),
changed: Condvar::new(),
completion: Mutex::new(None),
_materializer: materializer,
_materialized: Mutex::new(Some(Box::new(materialized))),
}),
})
}
#[must_use]
pub fn new_lazy(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
Self {
shared: Arc::new(ProducerShared {
stream_ref_id,
settings,
input: Mutex::new(None),
state: Mutex::new(ProducerState {
partner_seen: false,
cumulative_demand: 0,
sent: 0,
terminal_sent: false,
waiting_for_ack: false,
ack_deadline: None,
stopped: None,
ack_queued: false,
done: false,
input_attached: false,
terminal_result: None,
}),
changed: Condvar::new(),
completion: Mutex::new(None),
_materializer: Materializer::new(),
_materialized: Mutex::new(None),
}),
}
}
#[must_use]
pub fn sink(&self) -> Sink<T, StreamCompletion<NotUsed>> {
let shared = Arc::clone(&self.shared);
Sink::from_runner(move |input, _materializer| {
let (sender, receiver) = oneshot::channel();
*shared
.completion
.lock()
.unwrap_or_else(|poison| poison.into_inner()) = Some(sender);
shared.attach_input(input);
Ok(StreamCompletion::from_receiver(receiver, None))
})
}
}
// Thin delegation to the shared producer state machine.
impl<T> StreamRefProtoEndpoint for StreamRefProtoProducer<T>
where
    T: StreamRefPayload,
{
    fn stream_ref_id(&self) -> StreamRefId {
        let shared = &self.shared;
        shared.stream_ref_id
    }

    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        let shared = &self.shared;
        shared.next_frame()
    }

    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
        let shared = &self.shared;
        shared.handle_frame(frame)
    }

    fn fail_connection(&self, error: StreamError) {
        let shared = &self.shared;
        shared.fail_connection(error);
    }
}
/// State shared by all clones of a [`StreamRefProtoProducer`].
struct ProducerShared<T>
where
    T: StreamRefPayload,
{
    /// Id stamped on outgoing frames and checked on incoming ones.
    stream_ref_id: StreamRefId,
    /// Supplies the timeout used both for the first demand and for the terminal ack.
    settings: StreamRefSettings,
    /// Upstream element stream. `None` until attached, and again after it has been dropped.
    input: Mutex<Option<BoxStream<T>>>,
    /// Protocol bookkeeping guarded by a mutex and signalled through `changed`.
    state: Mutex<ProducerState>,
    /// Notified whenever `state` changes.
    changed: Condvar,
    /// Sender for the sink's completion. Filled when the sink is materialized and taken
    /// when the producer settles.
    completion: Mutex<Option<oneshot::Sender<StreamResult<NotUsed>>>>,
    /// Held for the producer's lifetime; never read here.
    _materializer: Materializer,
    /// Materialized value of an eagerly materialized source; held but never read.
    _materialized: Mutex<Option<Box<dyn Any + Send>>>,
}
/// Mutable protocol state of a producer.
struct ProducerState {
    /// Set when a handshake or demand arrives from the consumer (not read in this module).
    partner_seen: bool,
    /// Highest cumulative demand received so far.
    cumulative_demand: u64,
    /// Number of elements sent; also the next `seq_nr`.
    sent: u64,
    /// A completion or failure frame has been emitted.
    terminal_sent: bool,
    /// Our terminal frame is awaiting the consumer's `Ack`.
    waiting_for_ack: bool,
    /// When waiting for the terminal ack gives up.
    ack_deadline: Option<Instant>,
    /// Error recorded when the consumer cancels/fails or the connection fails.
    stopped: Option<StreamError>,
    /// The consumer's terminal message must be acknowledged by the next outbound frame.
    ack_queued: bool,
    /// The producer has finished; `next_frame` returns `None`.
    done: bool,
    /// An input stream has been supplied (at construction or via the sink).
    input_attached: bool,
    /// Outcome delivered to the sink's completion when the producer settles.
    terminal_result: Option<StreamResult<NotUsed>>,
}
impl<T> ProducerShared<T>
where
    T: StreamRefPayload,
{
    /// Locks the protocol state, recovering the guard if a previous holder panicked.
    fn lock_state(&self) -> MutexGuard<'_, ProducerState> {
        self.state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
    }
    /// Locks the upstream element stream slot, recovering the guard if poisoned.
    fn lock_input(&self) -> MutexGuard<'_, Option<BoxStream<T>>> {
        self.input
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
    }
    /// Wraps `message` in a frame tagged with this producer's stream ref id.
    fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
        StreamRefFrame::new(self.stream_ref_id, message)
    }
    /// Blocks until the next outbound frame is ready.
    ///
    /// Returns:
    /// * `None` once the producer is done;
    /// * `Some(Ok(frame))` for an element, a terminal message, or an `Ack` of the consumer's
    ///   terminal message;
    /// * `Some(Err(_))` when the producer gives up. That happens when no first demand arrives
    ///   within the subscription timeout, when no terminal ack arrives within that same
    ///   timeout, or when a stop error has been recorded.
    ///
    /// Every terminating path marks the state done, drops the input, and settles the sink
    /// completion.
    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        // The first-demand deadline is measured from the start of this call.
        let subscription_deadline = deadline_from_now(self.settings.subscription_timeout());
        loop {
            let mut state = self.lock_state();
            if state.done {
                return None;
            }
            // The consumer sent a terminal message: acknowledge it and finish.
            if state.ack_queued {
                state.ack_queued = false;
                state.done = true;
                state.terminal_result = Some(match state.stopped.clone() {
                    Some(error) => Err(error),
                    None => Ok(NotUsed),
                });
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Ok(self.frame(StreamRefMessage::Ack)));
            }
            // Our own terminal frame is out; wait (bounded) for the consumer's Ack.
            if state.waiting_for_ack {
                if state
                    .ack_deadline
                    .is_some_and(|deadline| Instant::now() >= deadline)
                {
                    let timeout_error =
                        subscription_timeout_error("stream ref producer terminal ack");
                    state.done = true;
                    state.terminal_result = Some(Err(timeout_error.clone()));
                    self.changed.notify_all();
                    drop(state);
                    self.drop_input();
                    self.settle();
                    return Some(Err(timeout_error));
                }
                // `checked_duration_since` yields `None` once the deadline has passed; the
                // next iteration then takes the timeout branch above.
                if let Some(remaining) = state
                    .ack_deadline
                    .and_then(|deadline| deadline.checked_duration_since(Instant::now()))
                {
                    let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
                    drop(next);
                } else {
                    drop(state);
                }
                continue;
            }
            // NOTE(review): every writer of `stopped` in this module also sets `ack_queued` or
            // `done`, so this branch appears to be defensive only — confirm.
            if let Some(error) = state.stopped.clone() {
                state.done = true;
                state.terminal_result = Some(Err(error.clone()));
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Err(error));
            }
            // Outstanding demand: pull one element. The state lock is released first
            // because pulling may block on the input.
            if state.cumulative_demand > 0 && state.sent < state.cumulative_demand {
                drop(state);
                if let Some(frame) = self.pull_next_frame() {
                    return Some(frame);
                }
                continue;
            }
            if state.cumulative_demand == 0 && Instant::now() >= subscription_deadline {
                let timeout_error = subscription_timeout_error("stream ref producer first demand");
                state.done = true;
                state.terminal_result = Some(Err(timeout_error.clone()));
                self.changed.notify_all();
                drop(state);
                self.drop_input();
                self.settle();
                return Some(Err(timeout_error));
            }
            // No demand yet: wait until the first-demand deadline. Demand exhausted: wait
            // without a timeout for more demand or another state change.
            if state.cumulative_demand == 0 {
                let remaining = subscription_deadline.saturating_duration_since(Instant::now());
                if remaining.is_zero() {
                    drop(state);
                    continue;
                }
                let (next, _) = wait_timeout_unpoison(&self.changed, state, remaining);
                drop(next);
            } else {
                let next = wait_unpoison(&self.changed, state);
                drop(next);
            }
        }
    }
    /// Pulls one element from the input and converts it into a frame.
    ///
    /// Returns `None` when there is nothing to send right now and the caller should
    /// re-evaluate state. That is the case when:
    /// * the input is not attached (after waiting until it is, or until the producer stops
    ///   or terminates), or
    /// * the producer stopped or finished while the element was being pulled.
    ///
    /// End of input and input errors produce the terminal frame and start waiting for the
    /// consumer's `Ack`. The input lock is held while calling `next()`, which may block.
    fn pull_next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        let item = {
            let mut input_guard = self.lock_input();
            if input_guard.is_none() {
                drop(input_guard);
                let mut state = self.lock_state();
                while !state.input_attached
                    && !state.done
                    && state.stopped.is_none()
                    && !state.terminal_sent
                {
                    state = wait_unpoison(&self.changed, state);
                }
                drop(state);
                return None;
            }
            input_guard.as_mut().expect("input attached").next()
        };
        match item {
            Some(Ok(item)) => {
                let mut state = self.lock_state();
                // The element is discarded if the stream ended while it was being pulled.
                if state.done || state.stopped.is_some() || state.waiting_for_ack {
                    return None;
                }
                let seq_nr = state.sent;
                state.sent = state.sent.saturating_add(1);
                Some(Ok(self.frame(StreamRefMessage::SequencedOnNext {
                    seq_nr,
                    payload: StreamRefPayloadBytes {
                        bytes: item.encode_stream_ref_payload(),
                    },
                })))
            }
            Some(Err(error)) => {
                self.drop_input();
                let mut state = self.lock_state();
                if state.done || state.terminal_sent {
                    return None;
                }
                state.terminal_sent = true;
                state.waiting_for_ack = true;
                state.terminal_result = Some(Err(error.clone()));
                state.ack_deadline = Some(deadline_from_now(self.settings.subscription_timeout()));
                self.changed.notify_all();
                drop(state);
                Some(Ok(self.frame(StreamRefMessage::RemoteStreamFailure {
                    cause: failure_cause(&error),
                })))
            }
            None => {
                self.drop_input();
                let mut state = self.lock_state();
                if state.done || state.terminal_sent {
                    return None;
                }
                // Completion carries the number of elements sent, which the consumer
                // checks against its expected sequence.
                let seq_nr = state.sent;
                state.terminal_sent = true;
                state.waiting_for_ack = true;
                state.terminal_result = Some(Ok(NotUsed));
                state.ack_deadline = Some(deadline_from_now(self.settings.subscription_timeout()));
                self.changed.notify_all();
                drop(state);
                Some(Ok(
                    self.frame(StreamRefMessage::RemoteStreamCompleted { seq_nr })
                ))
            }
        }
    }
    /// Applies a frame received from the consumer.
    ///
    /// Demand only ever grows. A consumer completion is treated as cancellation and a
    /// consumer failure as a failure; both queue an `Ack`.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// * the id does not match;
    /// * `CumulativeDemand` has `seq_nr == 0`;
    /// * the frame is a `SequencedOnNext`, which a producer never receives.
    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
        self.validate_frame_id(frame.stream_ref_id)?;
        match frame.message {
            StreamRefMessage::OnSubscribeHandshake => {
                let mut state = self.lock_state();
                state.partner_seen = true;
                self.changed.notify_all();
                drop(state);
                Ok(())
            }
            StreamRefMessage::CumulativeDemand { seq_nr } => {
                if seq_nr == 0 {
                    return Err(StreamError::Failed(
                        "CumulativeDemand seq_nr must be positive".to_owned(),
                    ));
                }
                let mut state = self.lock_state();
                state.partner_seen = true;
                // Out-of-order (smaller) demand totals are ignored.
                if seq_nr > state.cumulative_demand {
                    state.cumulative_demand = seq_nr;
                }
                self.changed.notify_all();
                drop(state);
                Ok(())
            }
            StreamRefMessage::RemoteStreamCompleted { .. } => {
                self.stop_from_consumer(StreamError::Cancelled);
                Ok(())
            }
            StreamRefMessage::RemoteStreamFailure { cause } => {
                self.stop_from_consumer(StreamError::Failed(
                    String::from_utf8_lossy(&cause).into_owned(),
                ));
                Ok(())
            }
            StreamRefMessage::Ack => {
                let mut state = self.lock_state();
                // An Ack is only meaningful after our own terminal frame; otherwise ignore it.
                if state.waiting_for_ack {
                    state.waiting_for_ack = false;
                    state.done = true;
                    if state.terminal_result.is_none() {
                        state.terminal_result = Some(Ok(NotUsed));
                    }
                    self.changed.notify_all();
                    drop(state);
                    self.drop_input();
                    self.settle();
                } else {
                    drop(state);
                }
                Ok(())
            }
            StreamRefMessage::SequencedOnNext { .. } => Err(StreamError::Failed(
                "producer endpoint cannot receive SequencedOnNext".to_owned(),
            )),
        }
    }
    /// Records a consumer-initiated stop (unless already done) and queues the `Ack` that
    /// `next_frame` will send. Drops the input in all cases.
    fn stop_from_consumer(&self, error: StreamError) {
        let mut state = self.lock_state();
        if !state.done {
            state.stopped = Some(error.clone());
            state.ack_queued = true;
            state.terminal_result = Some(Err(error));
        }
        self.changed.notify_all();
        drop(state);
        self.drop_input();
    }
    /// Terminates immediately because the transport failed: marks the producer done (keeping
    /// an earlier result if already done), drops the input, and settles the completion.
    fn fail_connection(&self, error: StreamError) {
        let mut state = self.lock_state();
        if !state.done {
            state.stopped = Some(error.clone());
            state.done = true;
            state.terminal_result = Some(Err(error));
        }
        self.changed.notify_all();
        drop(state);
        self.drop_input();
        self.settle();
    }
    /// Installs the sink's input stream (replacing any previous one) and wakes waiters.
    fn attach_input(&self, input: BoxStream<T>) {
        *self.lock_input() = Some(input);
        let mut state = self.lock_state();
        state.input_attached = true;
        self.changed.notify_all();
        drop(state);
    }
    /// Sends the terminal result to the sink's completion, at most once. Nothing is sent if
    /// no sender or no result is present.
    fn settle(&self) {
        let result = self.lock_state().terminal_result.clone();
        let sender = self
            .completion
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
            .take();
        if let (Some(sender), Some(result)) = (sender, result) {
            // The receiver may already be gone; nothing to do in that case.
            let _ = sender.send(result);
        }
    }
    /// Takes the input out of its slot and drops it after the lock has been released.
    fn drop_input(&self) {
        let input = self.lock_input().take();
        drop(input);
    }
    /// Rejects frames addressed to a different stream ref.
    fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
        if stream_ref_id == self.stream_ref_id {
            Ok(())
        } else {
            Err(StreamError::Failed(format!(
                "stream ref id mismatch: expected {}, got {}",
                self.stream_ref_id, stream_ref_id
            )))
        }
    }
}
/// Consumer end of a stream ref: receives frames from a remote producer and exposes the
/// elements as a local [`Source`], issuing demand as elements are delivered downstream.
///
/// Clones share the same underlying state.
pub struct StreamRefProtoConsumer<T>
where
    T: StreamRefPayload,
{
    shared: Arc<ConsumerShared<T>>,
}
// Written by hand so that cloning does not require `T: Clone`; only the `Arc` is shared.
impl<T> Clone for StreamRefProtoConsumer<T>
where
    T: StreamRefPayload,
{
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.shared);
        Self { shared }
    }
}
impl<T> StreamRefProtoConsumer<T>
where
    T: StreamRefPayload,
{
    /// Creates an idle consumer.
    ///
    /// The subscription handshake and first demand are queued only when the source
    /// returned by [`Self::source`] is materialized.
    #[must_use]
    pub fn new(stream_ref_id: StreamRefId, settings: StreamRefSettings) -> Self {
        let initial_state = ConsumerState {
            source_taken: false,
            subscribed: false,
            queue: VecDeque::new(),
            terminal: None,
            expected_seq: 0,
            delivered: 0,
            cumulative_demand: 0,
            outbound: VecDeque::new(),
            finish_after_outbound_ack: false,
            waiting_cancel_ack: false,
            done: false,
        };
        let shared = ConsumerShared {
            stream_ref_id,
            settings,
            state: Mutex::new(initial_state),
            changed: Condvar::new(),
        };
        Self {
            shared: Arc::new(shared),
        }
    }

    /// Returns a source that emits the received elements.
    ///
    /// It may be materialized at most once; a second materialization fails. If the resource
    /// is closed before a terminal has been reported, a `RemoteStreamCompleted` is queued to
    /// cancel the producer.
    #[must_use]
    pub fn source(&self) -> Source<T, NotUsed> {
        let shared = Arc::clone(&self.shared);
        Source::unfold_resource(
            move || shared.start_stream(),
            |stream| stream.next_item(),
            |mut stream| {
                stream.close();
                Ok(())
            },
        )
    }
}
// Thin delegation to the shared consumer state machine.
impl<T> StreamRefProtoEndpoint for StreamRefProtoConsumer<T>
where
    T: StreamRefPayload,
{
    fn stream_ref_id(&self) -> StreamRefId {
        let shared = &self.shared;
        shared.stream_ref_id
    }

    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        let shared = &self.shared;
        shared.next_frame()
    }

    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
        let shared = &self.shared;
        shared.handle_frame(frame)
    }

    fn fail_connection(&self, error: StreamError) {
        let shared = &self.shared;
        shared.fail_connection(error);
    }
}
/// State shared by all clones of a [`StreamRefProtoConsumer`] and its materialized stream.
struct ConsumerShared<T>
where
    T: StreamRefPayload,
{
    /// Id stamped on outgoing frames and checked on incoming ones.
    stream_ref_id: StreamRefId,
    /// Supplies the buffer capacity that sizes the demand window.
    settings: StreamRefSettings,
    /// Protocol state guarded by a mutex and signalled through `changed`.
    state: Mutex<ConsumerState<T>>,
    /// Notified whenever `state` changes.
    changed: Condvar,
}
/// Mutable protocol state of a consumer.
struct ConsumerState<T> {
    /// The source has been materialized; a second materialization fails.
    source_taken: bool,
    /// The handshake (and initial demand) has been queued.
    subscribed: bool,
    /// Received elements not yet delivered downstream.
    queue: VecDeque<T>,
    /// Terminal outcome, reported downstream after `queue` drains.
    terminal: Option<ConsumerTerminal>,
    /// Next `seq_nr` expected from the producer (= number of accepted elements).
    expected_seq: u64,
    /// Number of elements handed downstream.
    delivered: u64,
    /// Highest cumulative demand announced to the producer.
    cumulative_demand: u64,
    /// Messages waiting to be returned by `next_frame`.
    outbound: VecDeque<StreamRefMessage>,
    /// Sending the queued `Ack` finishes the consumer.
    finish_after_outbound_ack: bool,
    /// Our own terminal message is out; an incoming `Ack` finishes the consumer.
    waiting_cancel_ack: bool,
    /// No further frames; `next_frame` returns `None` once `outbound` is empty.
    done: bool,
}
/// How the consumer's stream ends once its queue has drained.
#[derive(Clone)]
enum ConsumerTerminal {
    /// Normal completion: downstream sees end-of-stream.
    Complete,
    /// Downstream sees this error.
    Error(StreamError),
}
impl<T> ConsumerShared<T>
where
    T: StreamRefPayload,
{
    /// Locks the protocol state, recovering the guard if a previous holder panicked.
    fn lock_state(&self) -> MutexGuard<'_, ConsumerState<T>> {
        self.state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
    }
    /// Wraps `message` in a frame tagged with this consumer's stream ref id.
    fn frame(&self, message: StreamRefMessage) -> StreamRefFrame {
        StreamRefFrame::new(self.stream_ref_id, message)
    }
    /// Called when the source is materialized.
    ///
    /// Queues the subscription handshake and the first demand, then returns the per-stream
    /// handle that reads elements.
    ///
    /// # Errors
    ///
    /// Fails if the source has already been materialized.
    fn start_stream(self: &Arc<Self>) -> StreamResult<ConsumerStream<T>> {
        {
            let mut state = self.lock_state();
            if state.source_taken {
                return Err(StreamError::Failed(
                    "stream ref source has already been materialized".to_owned(),
                ));
            }
            state.source_taken = true;
            if !state.subscribed {
                state.subscribed = true;
                state
                    .outbound
                    .push_back(StreamRefMessage::OnSubscribeHandshake);
                if let Some(demand) = next_demand(&mut state, self.settings) {
                    state
                        .outbound
                        .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
                }
            }
            self.changed.notify_all();
        }
        Ok(ConsumerStream {
            shared: Arc::clone(self),
            terminated: false,
        })
    }
    /// Blocks until an outbound message is queued and returns it as a frame.
    ///
    /// Queued messages are drained even after `done` is set; `None` is returned only once
    /// the consumer is done and nothing is left. Sending the `Ack` for a producer terminal
    /// marks the consumer done.
    fn next_frame(&self) -> Option<StreamResult<StreamRefFrame>> {
        loop {
            let mut state = self.lock_state();
            if let Some(message) = state.outbound.pop_front() {
                let finish_after_ack = message.is_ack() && state.finish_after_outbound_ack;
                if finish_after_ack {
                    state.done = true;
                }
                drop(state);
                return Some(Ok(self.frame(message)));
            }
            if state.done {
                return None;
            }
            let next = wait_unpoison(&self.changed, state);
            drop(next);
        }
    }
    /// Applies a frame received from the producer.
    ///
    /// Elements must arrive with consecutive sequence numbers and within the buffer
    /// capacity. A gap or an overflow fails the stream and queues a `RemoteStreamFailure`
    /// for the producer. Producer terminals are recorded and acknowledged. Frames arriving
    /// after a terminal has been recorded are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error in these cases:
    /// * the id does not match;
    /// * an element payload fails to decode (the consumer state is left unchanged);
    /// * the frame is a `CumulativeDemand`, which a consumer never receives.
    fn handle_frame(&self, frame: StreamRefFrame) -> StreamResult<()> {
        self.validate_frame_id(frame.stream_ref_id)?;
        match frame.message {
            StreamRefMessage::OnSubscribeHandshake => Ok(()),
            StreamRefMessage::SequencedOnNext { seq_nr, payload } => {
                // Decode outside the lock; a decode error is returned to the caller.
                let item = T::decode_stream_ref_payload(payload.bytes)?;
                let mut state = self.lock_state();
                if state.terminal.is_some() || state.done {
                    return Ok(());
                }
                if seq_nr != state.expected_seq {
                    let error =
                        invalid_sequence_error(state.expected_seq, seq_nr, "stream ref element");
                    state.queue.clear();
                    state.terminal = Some(ConsumerTerminal::Error(error.clone()));
                    state
                        .outbound
                        .push_back(StreamRefMessage::RemoteStreamFailure {
                            cause: failure_cause(&error),
                        });
                    state.waiting_cancel_ack = true;
                } else if state.queue.len() >= self.settings.buffer_capacity() {
                    // Demand never exceeds `delivered + buffer_capacity`, so a full queue
                    // means the producer sent more than it was granted.
                    let error = StreamError::Failed(
                        "stream ref receive buffer overflowed demand window".to_owned(),
                    );
                    state.queue.clear();
                    state.terminal = Some(ConsumerTerminal::Error(error.clone()));
                    state
                        .outbound
                        .push_back(StreamRefMessage::RemoteStreamFailure {
                            cause: failure_cause(&error),
                        });
                    state.waiting_cancel_ack = true;
                } else {
                    state.expected_seq = state.expected_seq.saturating_add(1);
                    state.queue.push_back(item);
                }
                self.changed.notify_all();
                drop(state);
                Ok(())
            }
            StreamRefMessage::RemoteStreamCompleted { seq_nr } => {
                let mut state = self.lock_state();
                if state.terminal.is_none() && !state.done {
                    // The completion's `seq_nr` must equal the number of elements received;
                    // a mismatch means elements were lost.
                    if seq_nr != state.expected_seq {
                        state.queue.clear();
                        state.terminal = Some(ConsumerTerminal::Error(invalid_sequence_error(
                            state.expected_seq,
                            seq_nr,
                            "stream ref completion",
                        )));
                    } else {
                        state.terminal = Some(ConsumerTerminal::Complete);
                    }
                    state.outbound.push_back(StreamRefMessage::Ack);
                    state.finish_after_outbound_ack = true;
                }
                self.changed.notify_all();
                drop(state);
                Ok(())
            }
            StreamRefMessage::RemoteStreamFailure { cause } => {
                let mut state = self.lock_state();
                if state.terminal.is_none() && !state.done {
                    state.queue.clear();
                    state.terminal = Some(ConsumerTerminal::Error(StreamError::Failed(
                        String::from_utf8_lossy(&cause).into_owned(),
                    )));
                    state.outbound.push_back(StreamRefMessage::Ack);
                    state.finish_after_outbound_ack = true;
                }
                self.changed.notify_all();
                drop(state);
                Ok(())
            }
            StreamRefMessage::Ack => {
                let mut state = self.lock_state();
                // Only meaningful after our own terminal message; otherwise ignored.
                if state.waiting_cancel_ack {
                    state.waiting_cancel_ack = false;
                    state.done = true;
                }
                self.changed.notify_all();
                drop(state);
                Ok(())
            }
            StreamRefMessage::CumulativeDemand { .. } => Err(StreamError::Failed(
                "consumer endpoint cannot receive CumulativeDemand".to_owned(),
            )),
        }
    }
    /// Downstream stopped reading before a terminal was reached. Records cancellation and
    /// queues a `RemoteStreamCompleted` (with the next expected sequence) for the producer,
    /// then waits for its `Ack`.
    fn cancel_from_downstream(&self) {
        let mut state = self.lock_state();
        if state.terminal.is_none() && !state.done {
            let seq_nr = state.expected_seq;
            state.terminal = Some(ConsumerTerminal::Error(StreamError::Cancelled));
            state
                .outbound
                .push_back(StreamRefMessage::RemoteStreamCompleted { seq_nr });
            state.waiting_cancel_ack = true;
        }
        self.changed.notify_all();
        drop(state);
    }
    /// Terminates because the transport failed. Records `error` unless a terminal already
    /// exists, and marks the consumer done.
    /// NOTE(review): already-queued outbound messages are still drained by `next_frame`.
    fn fail_connection(&self, error: StreamError) {
        let mut state = self.lock_state();
        if state.terminal.is_none() {
            state.queue.clear();
            state.terminal = Some(ConsumerTerminal::Error(error));
        }
        state.done = true;
        self.changed.notify_all();
        drop(state);
    }
    /// Rejects frames addressed to a different stream ref.
    fn validate_frame_id(&self, stream_ref_id: StreamRefId) -> StreamResult<()> {
        if stream_ref_id == self.stream_ref_id {
            Ok(())
        } else {
            Err(StreamError::Failed(format!(
                "stream ref id mismatch: expected {}, got {}",
                self.stream_ref_id, stream_ref_id
            )))
        }
    }
}
/// Per-materialization handle behind the source returned by
/// `StreamRefProtoConsumer::source`.
struct ConsumerStream<T>
where
    T: StreamRefPayload,
{
    /// State shared with the protocol endpoint.
    shared: Arc<ConsumerShared<T>>,
    /// Set once a terminal has been reported or the stream was closed. Later reads return
    /// `Ok(None)`, and `close` does not cancel again.
    terminated: bool,
}
impl<T> ConsumerStream<T>
where
    T: StreamRefPayload,
{
    /// Blocks until the next element or the terminal signal is available.
    ///
    /// Queued elements are always delivered before the terminal is reported. Each delivery
    /// may queue a demand top-up for the producer. Once the terminal has been reported,
    /// every further call returns `Ok(None)`.
    fn next_item(&mut self) -> StreamResult<Option<T>> {
        if self.terminated {
            return Ok(None);
        }
        let mut state = self.shared.lock_state();
        loop {
            if let Some(item) = state.queue.pop_front() {
                state.delivered = state.delivered.saturating_add(1);
                if let Some(demand) = next_demand(&mut state, self.shared.settings) {
                    state
                        .outbound
                        .push_back(StreamRefMessage::CumulativeDemand { seq_nr: demand });
                    self.shared.changed.notify_all();
                }
                return Ok(Some(item));
            }
            if let Some(terminal) = state.terminal.clone() {
                self.terminated = true;
                return match terminal {
                    ConsumerTerminal::Complete => Ok(None),
                    ConsumerTerminal::Error(error) => Err(error),
                };
            }
            // Nothing to deliver yet: sleep until the endpoint changes state.
            state = wait_unpoison(&self.shared.changed, state);
        }
    }

    /// Releases the stream. If no terminal has been reported yet, this cancels the
    /// producer.
    fn close(&mut self) {
        if self.terminated {
            return;
        }
        self.shared.cancel_from_downstream();
        self.terminated = true;
    }
}
/// Computes a new cumulative demand to announce, or `None` if no top-up is needed.
///
/// Demand is replenished once the outstanding credit (demand not yet delivered downstream)
/// falls to half the buffer capacity or below; it is then raised to `delivered + capacity`.
/// No demand is issued after a terminal has been recorded. When `Some` is returned, the new
/// total has also been stored in `state.cumulative_demand`.
fn next_demand<T>(state: &mut ConsumerState<T>, settings: StreamRefSettings) -> Option<u64> {
    if state.terminal.is_some() {
        return None;
    }
    let outstanding = state.cumulative_demand.saturating_sub(state.delivered);
    let has_demanded = state.cumulative_demand != 0;
    if has_demanded && outstanding > demand_replenish_threshold(settings) {
        return None;
    }
    let capacity = settings.buffer_capacity() as u64;
    let target = state.delivered.saturating_add(capacity);
    (target > state.cumulative_demand).then(|| {
        state.cumulative_demand = target;
        target
    })
}
/// Outstanding credit at or below which the consumer tops up its demand: half the buffer
/// capacity, rounded down.
fn demand_replenish_threshold(settings: StreamRefSettings) -> u64 {
    let capacity = settings.buffer_capacity() as u64;
    capacity / 2
}
fn failure_cause(error: &StreamError) -> Vec<u8> {
match error {
StreamError::Failed(message) => message.clone().into_bytes(),
other => other.to_string().into_bytes(),
}
}
/// Error reported when the peer does not respond within the subscription timeout. `side`
/// names the wait that timed out.
fn subscription_timeout_error(side: &str) -> StreamError {
    let message = format!("{side} remote side did not subscribe within subscription timeout");
    StreamError::Failed(message)
}
/// Error reported when a received `seq_nr` differs from the expected one. `context` names
/// what carried the bad sequence number.
fn invalid_sequence_error(expected: u64, got: u64, context: &str) -> StreamError {
    let message = format!("{context} sequence gap: expected sequence {expected}, got {got}");
    StreamError::Failed(message)
}
fn deadline_from_now(timeout: Duration) -> Instant {
Instant::now()
.checked_add(timeout)
.unwrap_or_else(far_future)
}
/// Stand-in deadline about one year (365 days) ahead, used when a real deadline would
/// overflow.
fn far_future() -> Instant {
    const ONE_YEAR: Duration = Duration::from_secs(365 * 24 * 60 * 60);
    Instant::now() + ONE_YEAR
}
/// `Condvar::wait_timeout` that ignores mutex poisoning and hands back the guard and
/// timeout result either way. Spurious wakeups are possible, so callers re-check state.
fn wait_timeout_unpoison<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    match condvar.wait_timeout(guard, timeout) {
        Ok(woken) => woken,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// `Condvar::wait` that ignores mutex poisoning and always hands back the guard. Spurious
/// wakeups are possible, so callers re-check state.
fn wait_unpoison<'a, T>(condvar: &Condvar, guard: MutexGuard<'a, T>) -> MutexGuard<'a, T> {
    match condvar.wait(guard) {
        Ok(woken) => woken,
        Err(poisoned) => poisoned.into_inner(),
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use super::*;
    use crate::{Source, StreamRefs};
    /// Small buffer (4) and short subscription timeout (50 ms) so tests exercise demand
    /// top-ups and timeouts quickly.
    fn short_settings() -> StreamRefSettings {
        StreamRefSettings::default()
            .with_buffer_capacity(4)
            .with_subscription_timeout(Duration::from_millis(50))
    }
    /// A frame survives protobuf encode/decode unchanged.
    #[test]
    fn protobuf_frame_round_trip() {
        let frame = StreamRefFrame::new(
            StreamRefId::from_u128(42),
            StreamRefMessage::SequencedOnNext {
                seq_nr: 7,
                payload: StreamRefPayloadBytes {
                    bytes: 99_u64.encode_stream_ref_payload(),
                },
            },
        );
        let decoded = StreamRefFrame::decode(&frame.encode_to_vec()).unwrap();
        assert_eq!(decoded, frame);
    }
    /// End to end: two threads shuttle frames between producer and consumer in each
    /// direction. With a buffer of 4, ten elements require several demand top-ups.
    #[test]
    fn producer_consumer_seam_streams_with_low_watermark_demand() {
        let id = StreamRefId::from_u128(1);
        let settings = short_settings();
        let source_ref = Source::from_iter(0_u64..10)
            .run_with(StreamRefs::source_ref_with_settings(settings))
            .unwrap();
        let producer = StreamRefProtoProducer::from_source_ref(source_ref, id, settings).unwrap();
        let consumer = StreamRefProtoConsumer::<u64>::new(id, settings);
        let consumer_source = consumer.source();
        // Producer -> consumer direction.
        let producer_thread = std::thread::spawn({
            let producer = producer.clone();
            let consumer = consumer.clone();
            move || {
                while let Some(frame) = producer.next_frame() {
                    consumer.handle_frame(frame?)?;
                }
                Ok::<_, StreamError>(())
            }
        });
        // Consumer -> producer direction.
        let consumer_thread = std::thread::spawn({
            let producer = producer.clone();
            let consumer = consumer.clone();
            move || {
                while let Some(frame) = consumer.next_frame() {
                    producer.handle_frame(frame?)?;
                }
                Ok::<_, StreamError>(())
            }
        });
        assert_eq!(
            consumer_source.run_collect().unwrap(),
            (0_u64..10).collect::<Vec<_>>()
        );
        producer_thread.join().unwrap().unwrap();
        consumer_thread.join().unwrap().unwrap();
    }
    /// An element arriving with seq_nr 1 when 0 is expected fails the stream downstream and
    /// queues a RemoteStreamFailure for the producer.
    #[test]
    fn strict_sequence_gap_fails_consumer_and_sends_failure() {
        let id = StreamRefId::from_u128(2);
        let consumer = StreamRefProtoConsumer::<u64>::new(id, short_settings());
        let source = consumer
            .source()
            .run_with(crate::testkit::TestSink::probe())
            .unwrap();
        source.request(1);
        // Drain the handshake and the initial demand queued at materialization.
        consumer.next_frame().unwrap().unwrap();
        consumer.next_frame().unwrap().unwrap();
        consumer
            .handle_frame(StreamRefFrame::new(
                id,
                StreamRefMessage::SequencedOnNext {
                    seq_nr: 1,
                    payload: StreamRefPayloadBytes {
                        bytes: 1_u64.encode_stream_ref_payload(),
                    },
                },
            ))
            .unwrap();
        let outbound = consumer.next_frame().unwrap().unwrap();
        assert!(matches!(
            outbound.message,
            StreamRefMessage::RemoteStreamFailure { .. }
        ));
        assert!(matches!(source.expect_error(), StreamError::Failed(_)));
    }
    /// Without any CumulativeDemand, next_frame gives up after the 50 ms subscription
    /// timeout with the "first demand" error.
    #[test]
    fn producer_times_out_without_first_demand() {
        let producer = StreamRefProtoProducer::from_source(
            Source::repeat(1_u64),
            StreamRefId::from_u128(3),
            short_settings(),
        )
        .unwrap();
        let error = producer.next_frame().unwrap().unwrap_err();
        assert!(matches!(error, StreamError::Failed(message) if message.contains("first demand")));
    }
    /// Only pins the default `demand_redelivery_interval` (1 s). Nothing in this module
    /// re-sends demand on that interval.
    /// NOTE(review): the test name suggests intent about reliable carriers — confirm.
    #[test]
    fn demand_redelivery_is_not_required_by_reliable_carriers() {
        assert_eq!(
            StreamRefSettings::default().demand_redelivery_interval(),
            Duration::from_secs(1)
        );
    }
}