use std::{
collections::VecDeque,
fmt,
sync::{
Arc, Condvar, Mutex, MutexGuard,
atomic::{AtomicBool, Ordering},
},
time::{Duration, Instant},
};
use crate::stream::{
BoxStream, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult,
};
use super::{Actor, ActorProcessingErr, ActorRef, ActorResult, Message, block_on_ractor_runtime};
/// Default number of elements the consumer requests beyond those it has already delivered.
const DEFAULT_STREAM_REF_BUFFER_CAPACITY: usize = 32;
/// Default time an endpoint waits for its remote peer to subscribe.
const DEFAULT_STREAM_REF_SUBSCRIPTION_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default interval after which an idle consumer re-sends its cumulative demand.
const DEFAULT_STREAM_REF_DEMAND_REDELIVERY: Duration = Duration::from_millis(1_000);
/// Upper bound on a single condition-variable wait, so waiting loops re-check flags and deadlines.
const STREAM_REF_WAIT_POLL: Duration = Duration::from_micros(1_000);
/// Tunables shared by both endpoints of a stream reference.
///
/// Start from [`Default`] and adjust individual values with the `with_*` builder methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamRefSettings {
    buffer_capacity: usize,
    subscription_timeout: Duration,
    demand_redelivery_interval: Duration,
}

impl Default for StreamRefSettings {
    fn default() -> Self {
        StreamRefSettings {
            subscription_timeout: DEFAULT_STREAM_REF_SUBSCRIPTION_TIMEOUT,
            demand_redelivery_interval: DEFAULT_STREAM_REF_DEMAND_REDELIVERY,
            buffer_capacity: DEFAULT_STREAM_REF_BUFFER_CAPACITY,
        }
    }
}

impl StreamRefSettings {
    /// Number of elements the consumer requests beyond those it has already delivered.
    #[must_use]
    pub fn buffer_capacity(&self) -> usize {
        self.buffer_capacity
    }

    /// How long an endpoint waits for its remote peer to subscribe before failing.
    #[must_use]
    pub fn subscription_timeout(&self) -> Duration {
        self.subscription_timeout
    }

    /// How often an idle consumer re-sends its cumulative demand to the producer.
    #[must_use]
    pub fn demand_redelivery_interval(&self) -> Duration {
        self.demand_redelivery_interval
    }

    /// Returns a copy with the buffer capacity replaced.
    ///
    /// # Panics
    /// Panics if `capacity` is zero.
    #[must_use]
    pub fn with_buffer_capacity(self, capacity: usize) -> Self {
        assert!(
            capacity > 0,
            "StreamRef buffer capacity must be greater than zero"
        );
        Self {
            buffer_capacity: capacity,
            ..self
        }
    }

    /// Returns a copy with the subscription timeout replaced.
    #[must_use]
    pub fn with_subscription_timeout(self, timeout: Duration) -> Self {
        Self {
            subscription_timeout: timeout,
            ..self
        }
    }

    /// Returns a copy with the demand redelivery interval replaced.
    #[must_use]
    pub fn with_demand_redelivery_interval(self, interval: Duration) -> Self {
        Self {
            demand_redelivery_interval: interval,
            ..self
        }
    }
}
/// Entry point for creating one-shot stream references.
pub struct StreamRefs;

impl StreamRefs {
    /// Sink that, when run, materializes a [`SourceRef`] from which another stream can
    /// consume the elements, using default [`StreamRefSettings`].
    #[must_use]
    pub fn source_ref<T>() -> Sink<T, SourceRef<T>>
    where
        T: Send + 'static,
    {
        stream_ref_source_sink(StreamRefSettings::default())
    }

    /// Like [`StreamRefs::source_ref`] but with explicit `settings`.
    #[must_use]
    pub fn source_ref_with_settings<T>(settings: StreamRefSettings) -> Sink<T, SourceRef<T>>
    where
        T: Send + 'static,
    {
        stream_ref_source_sink(settings)
    }

    /// Source that, when run, materializes a [`SinkRef`] through which another stream can
    /// push elements into it, using default [`StreamRefSettings`].
    #[must_use]
    pub fn sink_ref<T>() -> Source<T, SinkRef<T>>
    where
        T: Send + 'static,
    {
        stream_ref_sink_source(StreamRefSettings::default())
    }

    /// Like [`StreamRefs::sink_ref`] but with explicit `settings`.
    #[must_use]
    pub fn sink_ref_with_settings<T>(settings: StreamRefSettings) -> Source<T, SinkRef<T>>
    where
        T: Send + 'static,
    {
        stream_ref_sink_source(settings)
    }
}
/// Handle to the producing side created by [`StreamRefs::source_ref`].
///
/// Clones share state; only the first [`SourceRef::source`] materialization consumes the
/// elements, and later ones fail.
pub struct SourceRef<T> {
    inner: Arc<SourceRefInner<T>>,
}

/// State shared by every clone of a [`SourceRef`].
struct SourceRefInner<T> {
    /// Producer actor a consumer subscribes to.
    producer: ActorRef<ProducerCommand<T>>,
    settings: StreamRefSettings,
    /// Flipped by the first materialization; enforces one-shot use.
    subscribed: AtomicBool,
    /// Completion of the producing stream, held for as long as this state lives.
    _keep_alive: Mutex<Option<StreamCompletion<NotUsed>>>,
}

impl<T> Clone for SourceRef<T> {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}

impl<T> fmt::Debug for SourceRef<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SourceRef").finish_non_exhaustive()
    }
}

impl<T> SourceRef<T>
where
    T: Send + 'static,
{
    /// Source that subscribes to the remote producer and emits its elements.
    ///
    /// Materializing it a second time (through any clone) yields a single
    /// "already materialized" error.
    #[must_use]
    pub fn source(&self) -> Source<T, NotUsed> {
        let shared_ref = Arc::clone(&self.inner);
        Source::from_materialized_factory(move |_materializer| {
            let already_taken = shared_ref.subscribed.swap(true, Ordering::SeqCst);
            if already_taken {
                let stream: BoxStream<T> =
                    failed_once("source ref has already been materialized");
                return Ok((stream, NotUsed));
            }
            let settings = shared_ref.settings;
            let state = Arc::new(ConsumerShared::new(settings));
            let (consumer_ref, _handle) =
                spawn_consumer_actor(Some(shared_ref.producer.clone()), Arc::clone(&state))?;
            let stream: BoxStream<T> = Box::new(ConsumerStream {
                shared: state,
                actor_ref: Some(consumer_ref),
                settings,
                terminated: false,
                source_ref_keep_alive: Some(Arc::clone(&shared_ref)),
            });
            Ok((stream, NotUsed))
        })
    }
}
pub struct SinkRef<T> {
inner: Arc<SinkRefInner<T>>,
}
struct SinkRefInner<T> {
consumer: ActorRef<ConsumerCommand<T>>,
settings: StreamRefSettings,
subscribed: AtomicBool,
}
impl<T> Clone for SinkRef<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl<T> fmt::Debug for SinkRef<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SinkRef").finish_non_exhaustive()
}
}
impl<T> SinkRef<T>
where
T: Send + 'static,
{
#[must_use]
pub fn sink(&self) -> Sink<T, StreamCompletion<NotUsed>> {
let inner = Arc::clone(&self.inner);
Sink::from_runner(move |input, materializer| {
if inner.subscribed.swap(true, Ordering::SeqCst) {
return Ok(StreamCompletion::ready(Err(StreamError::Failed(
"sink ref has already been materialized".to_owned(),
))));
}
let shared = Arc::new(ProducerShared::new());
let (producer_ref, _handle) =
spawn_producer_actor(Some(inner.consumer.clone()), Arc::clone(&shared))?;
let settings = inner.settings;
Ok(materializer.spawn_stream(move |cancelled| {
run_producer_endpoint(input, shared, producer_ref, settings, cancelled)
}))
})
}
}
/// Messages handled by the producer actor; they are sent by the consumer side.
enum ProducerCommand<T> {
    /// A consumer asks to be bound to this producer. The producer answers with
    /// `OnSubscribe`, or with `Failure` if a different consumer is already bound.
    Subscribe {
        consumer: ActorRef<ConsumerCommand<T>>,
    },
    /// Total number of elements the consumer has requested so far (cumulative, not an
    /// increment). The producer ignores values that do not increase its recorded demand.
    Demand {
        consumer: ActorRef<ConsumerCommand<T>>,
        cumulative: u64,
    },
    /// The consumer stream was dropped before terminating; the producer records
    /// `StreamError::Cancelled` as its stop reason.
    Cancel {
        consumer: ActorRef<ConsumerCommand<T>>,
    },
    /// The consumer rejects this producer (it is already bound to another one); the
    /// producer records `error` as its stop reason.
    RemoteFailure {
        consumer: ActorRef<ConsumerCommand<T>>,
        error: StreamError,
    },
    /// Acknowledges a `Complete`/`Failure`; the producer actor stops itself on receipt.
    Ack,
}
// NOTE(review): presumably ractor supplies a blanket `Message` impl unless its `cluster`
// feature is enabled, in which case message types must opt in explicitly — verify against
// the pinned ractor version.
#[cfg(feature = "cluster")]
impl<T: Send + 'static> Message for ProducerCommand<T> {}
/// Messages handled by the consumer actor; they are sent by the producer side.
enum ConsumerCommand<T> {
    /// The producer accepts the subscription and identifies itself.
    OnSubscribe {
        producer: ActorRef<ProducerCommand<T>>,
    },
    /// One element. `seq` starts at 0 and must arrive without gaps.
    Element {
        producer: ActorRef<ProducerCommand<T>>,
        seq: u64,
        item: T,
    },
    /// Normal end of stream. `seq` is the number of elements the producer sent.
    Complete {
        producer: ActorRef<ProducerCommand<T>>,
        seq: u64,
    },
    /// The producing stream failed, or this consumer's subscription was rejected.
    Failure {
        producer: ActorRef<ProducerCommand<T>>,
        error: StreamError,
    },
}
// Same opt-in as for `ProducerCommand` when the `cluster` feature is on.
#[cfg(feature = "cluster")]
impl<T: Send + 'static> Message for ConsumerCommand<T> {}
/// State shared between the producer actor, which writes consumer messages into it, and the
/// producing stream task, which waits on it.
struct ProducerShared<T> {
    inner: Mutex<ProducerInner<T>>,
    changed: Condvar,
}

/// Producer-side bookkeeping guarded by [`ProducerShared::inner`].
struct ProducerInner<T> {
    /// Consumer bound to this producer, once subscribed.
    consumer: Option<ActorRef<ConsumerCommand<T>>>,
    /// Highest cumulative demand reported by the bound consumer.
    cumulative_demand: u64,
    /// Number of elements handed to the consumer so far.
    sent: u64,
    /// Reason the producer must stop, once known.
    stopped: Option<StreamError>,
}

impl<T> ProducerShared<T>
where
    T: Send + 'static,
{
    fn new() -> Self {
        let inner = ProducerInner {
            consumer: None,
            cumulative_demand: 0,
            sent: 0,
            stopped: None,
        };
        Self {
            inner: Mutex::new(inner),
            changed: Condvar::new(),
        }
    }

    /// Locks the state, recovering the data if a previous holder panicked.
    fn lock(&self) -> MutexGuard<'_, ProducerInner<T>> {
        match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Binds `consumer` if none is bound yet.
    ///
    /// Returns `Ok(true)` when newly bound, `Ok(false)` when the same consumer was already
    /// bound, and `Err(consumer)` (handing the ref back) when a different consumer holds it.
    fn set_consumer(
        &self,
        consumer: ActorRef<ConsumerCommand<T>>,
    ) -> Result<bool, ActorRef<ConsumerCommand<T>>> {
        let mut inner = self.lock();
        if let Some(existing) = &inner.consumer {
            return if same_actor(existing, &consumer) {
                Ok(false)
            } else {
                Err(consumer)
            };
        }
        inner.consumer = Some(consumer);
        drop(inner);
        self.changed.notify_all();
        Ok(true)
    }

    /// Raises the recorded demand to `cumulative` if it comes from the bound consumer and
    /// exceeds the current value; other reports are ignored.
    fn update_demand(&self, consumer: &ActorRef<ConsumerCommand<T>>, cumulative: u64) {
        let mut inner = self.lock();
        let from_bound_consumer =
            matches!(&inner.consumer, Some(existing) if same_actor(existing, consumer));
        if !from_bound_consumer || cumulative <= inner.cumulative_demand {
            return;
        }
        inner.cumulative_demand = cumulative;
        drop(inner);
        self.changed.notify_all();
    }

    /// Records `error` as the stop reason when it comes from the bound consumer (or no
    /// consumer is bound yet) and no stop reason has been recorded.
    fn stop_from_consumer(&self, consumer: &ActorRef<ConsumerCommand<T>>, error: StreamError) {
        let mut inner = self.lock();
        let accepted = match &inner.consumer {
            None => true,
            Some(existing) => same_actor(existing, consumer),
        };
        if accepted && inner.stopped.is_none() {
            inner.stopped = Some(error);
            drop(inner);
            self.changed.notify_all();
        }
    }

    /// Records `error` as the stop reason unless one is already set.
    fn stop_unless_finished(&self, error: StreamError) {
        let mut inner = self.lock();
        if inner.stopped.is_some() {
            return;
        }
        inner.stopped = Some(error);
        drop(inner);
        self.changed.notify_all();
    }
}
struct ConsumerShared<T> {
inner: Mutex<ConsumerInner<T>>,
changed: Condvar,
}
struct ConsumerInner<T> {
producer: Option<ActorRef<ProducerCommand<T>>>,
queue: VecDeque<T>,
terminal: Option<ConsumerTerminal>,
expected_seq: u64,
delivered: u64,
cumulative_demand: u64,
}
#[derive(Clone)]
enum ConsumerTerminal {
Complete,
Error(StreamError),
}
impl<T> ConsumerShared<T>
where
T: Send + 'static,
{
fn new(_settings: StreamRefSettings) -> Self {
Self {
inner: Mutex::new(ConsumerInner {
producer: None,
queue: VecDeque::new(),
terminal: None,
expected_seq: 0,
delivered: 0,
cumulative_demand: 0,
}),
changed: Condvar::new(),
}
}
fn lock(&self) -> MutexGuard<'_, ConsumerInner<T>> {
self.inner
.lock()
.unwrap_or_else(|poison| poison.into_inner())
}
fn set_producer(&self, producer: ActorRef<ProducerCommand<T>>) -> bool {
let mut inner = self.lock();
match &inner.producer {
Some(existing) => same_actor(existing, &producer),
None => {
inner.producer = Some(producer);
drop(inner);
self.changed.notify_all();
true
}
}
}
fn push(&self, producer: &ActorRef<ProducerCommand<T>>, seq: u64, item: T) {
let mut inner = self.lock();
if inner.terminal.is_some() || !producer_matches(&inner, producer) {
return;
}
if seq != inner.expected_seq {
inner.queue.clear();
inner.terminal = Some(ConsumerTerminal::Error(invalid_sequence_error(
inner.expected_seq,
seq,
"stream ref element sequence gap",
)));
} else {
inner.expected_seq += 1;
inner.queue.push_back(item);
}
drop(inner);
self.changed.notify_all();
}
fn complete(&self, producer: &ActorRef<ProducerCommand<T>>, seq: u64) {
let mut inner = self.lock();
if inner.terminal.is_some() || !producer_matches(&inner, producer) {
return;
}
if seq != inner.expected_seq {
inner.queue.clear();
inner.terminal = Some(ConsumerTerminal::Error(invalid_sequence_error(
inner.expected_seq,
seq,
"stream ref completion sequence gap",
)));
} else {
inner.terminal = Some(ConsumerTerminal::Complete);
}
drop(inner);
self.changed.notify_all();
}
fn fail(&self, producer: &ActorRef<ProducerCommand<T>>, error: StreamError) {
let mut inner = self.lock();
if inner.terminal.is_some() || !producer_matches(&inner, producer) {
return;
}
inner.queue.clear();
inner.terminal = Some(ConsumerTerminal::Error(error));
drop(inner);
self.changed.notify_all();
}
fn fail_local(&self, error: StreamError) {
let mut inner = self.lock();
if inner.terminal.is_none() {
inner.queue.clear();
inner.terminal = Some(ConsumerTerminal::Error(error));
drop(inner);
self.changed.notify_all();
}
}
}
/// True when `producer` is the producer bound in `inner`; false if none is bound.
fn producer_matches<T>(inner: &ConsumerInner<T>, producer: &ActorRef<ProducerCommand<T>>) -> bool
where
    T: Send + 'static,
{
    match &inner.producer {
        Some(existing) => same_actor(existing, producer),
        None => false,
    }
}
/// Actor that receives the consumer's control messages for a producing endpoint.
///
/// It only records those messages in `shared`; the producing stream task acts on them.
struct ProducerActor<T> {
    shared: Arc<ProducerShared<T>>,
    /// Consumer to bind immediately on start (set when spawned from `SinkRef::sink`).
    initial_consumer: Option<ActorRef<ConsumerCommand<T>>>,
}

impl<T> Actor for ProducerActor<T>
where
    T: Send + 'static,
{
    type Msg = ProducerCommand<T>;
    type State = ();
    type Arguments = ();

    async fn pre_start(
        &self,
        myself: ActorRef<Self::Msg>,
        _args: Self::Arguments,
    ) -> Result<Self::State, ActorProcessingErr> {
        if let Some(consumer) = self.initial_consumer.clone() {
            register_producer_consumer(&self.shared, myself, consumer);
        }
        Ok(())
    }

    async fn handle(
        &self,
        myself: ActorRef<Self::Msg>,
        message: Self::Msg,
        _state: &mut Self::State,
    ) -> ActorResult {
        let shared = &self.shared;
        match message {
            ProducerCommand::Subscribe { consumer } => {
                register_producer_consumer(shared, myself, consumer);
            }
            ProducerCommand::Demand {
                consumer,
                cumulative,
            } => {
                shared.update_demand(&consumer, cumulative);
            }
            ProducerCommand::Cancel { consumer } => {
                shared.stop_from_consumer(&consumer, StreamError::Cancelled);
            }
            ProducerCommand::RemoteFailure { consumer, error } => {
                shared.stop_from_consumer(&consumer, error);
            }
            // The consumer has seen the terminal message, so nothing is left to do.
            ProducerCommand::Ack => myself.stop(None),
        }
        Ok(())
    }

    async fn post_stop(
        &self,
        _myself: ActorRef<Self::Msg>,
        _state: &mut Self::State,
    ) -> ActorResult {
        // Give a producing task that is still waiting for demand a stop reason.
        self.shared
            .stop_unless_finished(StreamError::ActorTerminated);
        Ok(())
    }
}
/// Tries to bind `consumer` to the producer and tells it the outcome.
///
/// On success (or if it was already bound) the consumer receives `OnSubscribe`; if another
/// consumer is already bound, it receives `Failure` instead. Delivery is best-effort.
fn register_producer_consumer<T>(
    shared: &Arc<ProducerShared<T>>,
    producer: ActorRef<ProducerCommand<T>>,
    consumer: ActorRef<ConsumerCommand<T>>,
) where
    T: Send + 'static,
{
    let (recipient, reply) = match shared.set_consumer(consumer.clone()) {
        Ok(_) => (consumer, ConsumerCommand::OnSubscribe { producer }),
        Err(duplicate) => (
            duplicate,
            ConsumerCommand::Failure {
                producer,
                error: StreamError::Failed(
                    "stream ref was already subscribed by another endpoint".to_owned(),
                ),
            },
        ),
    };
    let _ = cast_actor(&recipient, reply);
}
struct ConsumerActor<T> {
shared: Arc<ConsumerShared<T>>,
initial_producer: Option<ActorRef<ProducerCommand<T>>>,
}
impl<T> Actor for ConsumerActor<T>
where
T: Send + 'static,
{
type Msg = ConsumerCommand<T>;
type State = ();
type Arguments = ();
async fn pre_start(
&self,
myself: ActorRef<Self::Msg>,
_args: Self::Arguments,
) -> Result<Self::State, ActorProcessingErr> {
if let Some(producer) = &self.initial_producer {
self.shared.set_producer(producer.clone());
if let Err(error) = cast_actor(
producer,
ProducerCommand::Subscribe {
consumer: myself.clone(),
},
) {
self.shared.fail_local(error);
myself.stop(None);
}
}
Ok(())
}
async fn handle(
&self,
_myself: ActorRef<Self::Msg>,
message: Self::Msg,
_state: &mut Self::State,
) -> ActorResult {
match message {
ConsumerCommand::OnSubscribe { producer } => {
if !self.shared.set_producer(producer.clone()) {
let _ = cast_actor(
&producer,
ProducerCommand::RemoteFailure {
consumer: _myself.clone(),
error: StreamError::Failed(
"stream ref was already subscribed by another endpoint".to_owned(),
),
},
);
}
}
ConsumerCommand::Element {
producer,
seq,
item,
} => self.shared.push(&producer, seq, item),
ConsumerCommand::Complete { producer, seq } => {
self.shared.complete(&producer, seq);
let _ = cast_actor(&producer, ProducerCommand::Ack);
}
ConsumerCommand::Failure { producer, error } => {
self.shared.fail(&producer, error);
let _ = cast_actor(&producer, ProducerCommand::Ack);
}
}
Ok(())
}
async fn post_stop(
&self,
_myself: ActorRef<Self::Msg>,
_state: &mut Self::State,
) -> ActorResult {
self.shared.fail_local(StreamError::ActorTerminated);
Ok(())
}
}
/// Local iterator end of a stream reference's consuming side.
///
/// It drains elements that the consumer actor places in `shared`, and sends demand to the
/// bound producer on that actor's behalf.
struct ConsumerStream<T>
where
    T: Send + 'static,
{
    /// State shared with the consumer actor (queue, terminal signal, demand counters).
    shared: Arc<ConsumerShared<T>>,
    /// The consumer actor; taken (left `None`) when the stream stops it via `stop_actor`.
    actor_ref: Option<ActorRef<ConsumerCommand<T>>>,
    settings: StreamRefSettings,
    /// Set once the stream has yielded its last item (an error) or observed completion.
    terminated: bool,
    /// Keeps the originating `SourceRef` state alive while this stream exists; `None` for
    /// streams created by `stream_ref_sink_source`.
    source_ref_keep_alive: Option<Arc<SourceRefInner<T>>>,
}
impl<T> Iterator for ConsumerStream<T>
where
T: Send + 'static,
{
type Item = StreamResult<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.terminated {
return None;
}
if let Err(error) = self.wait_for_subscription() {
self.terminated = true;
return Some(Err(error));
}
if let Err(error) = self.redeliver_or_extend_demand() {
self.terminated = true;
return Some(Err(error));
}
let mut next_redelivery = next_redelivery_deadline(self.settings);
loop {
let demand_after_pop = {
let mut inner = self.shared.lock();
if let Some(item) = inner.queue.pop_front() {
inner.delivered = inner.delivered.saturating_add(1);
let demand = next_demand(&mut inner, self.settings);
drop(inner);
if let Some((producer, cumulative)) = demand
&& let Err(error) = send_demand(&self.actor_ref, &producer, cumulative)
{
self.terminated = true;
return Some(Err(error));
}
return Some(Ok(item));
}
if let Some(terminal) = inner.terminal.clone() {
drop(inner);
match terminal {
ConsumerTerminal::Complete => {
self.terminated = true;
self.stop_actor();
return None;
}
ConsumerTerminal::Error(error) => {
self.terminated = true;
self.stop_actor();
return Some(Err(error));
}
}
}
None::<(ActorRef<ProducerCommand<T>>, u64)>
};
debug_assert!(demand_after_pop.is_none());
let now = Instant::now();
let timeout = next_redelivery.saturating_duration_since(now);
let mut inner = self.shared.lock();
if !inner.queue.is_empty() || inner.terminal.is_some() {
continue;
}
let (next_inner, result) = wait_timeout_unpoison(
&self.shared.changed,
inner,
timeout.min(STREAM_REF_WAIT_POLL),
);
inner = next_inner;
drop(inner);
if result.timed_out() && Instant::now() >= next_redelivery {
if let Err(error) = self.redeliver_demand() {
self.terminated = true;
return Some(Err(error));
}
next_redelivery = next_redelivery_deadline(self.settings);
}
}
}
}
impl<T> ConsumerStream<T>
where
T: Send + 'static,
{
fn wait_for_subscription(&self) -> StreamResult<()> {
let deadline = Instant::now()
.checked_add(self.settings.subscription_timeout)
.unwrap_or_else(far_future);
let mut inner = self.shared.lock();
loop {
if inner.producer.is_some() {
return Ok(());
}
if let Some(terminal) = inner.terminal.clone() {
return match terminal {
ConsumerTerminal::Complete => Ok(()),
ConsumerTerminal::Error(error) => Err(error),
};
}
let now = Instant::now();
if now >= deadline {
drop(inner);
let error = subscription_timeout_error("stream ref source");
self.shared.fail_local(error.clone());
self.stop_actor_ref();
return Err(error);
}
let remaining = deadline.saturating_duration_since(now);
let (next, _) = wait_timeout_unpoison(
&self.shared.changed,
inner,
remaining.min(STREAM_REF_WAIT_POLL),
);
inner = next;
}
}
fn redeliver_or_extend_demand(&self) -> StreamResult<()> {
let demand = {
let mut inner = self.shared.lock();
next_demand(&mut inner, self.settings)
};
if let Some((producer, cumulative)) = demand {
send_demand(&self.actor_ref, &producer, cumulative)?;
}
Ok(())
}
fn redeliver_demand(&self) -> StreamResult<()> {
let (producer, cumulative) = {
let inner = self.shared.lock();
let Some(producer) = inner.producer.clone() else {
return Ok(());
};
if inner.cumulative_demand == 0 {
return Ok(());
}
(producer, inner.cumulative_demand)
};
send_demand(&self.actor_ref, &producer, cumulative)
}
fn stop_actor(&mut self) {
if let Some(actor_ref) = self.actor_ref.take() {
actor_ref.stop(None);
}
}
fn stop_actor_ref(&self) {
if let Some(actor_ref) = &self.actor_ref {
actor_ref.stop(None);
}
}
}
impl<T> Drop for ConsumerStream<T>
where
    T: Send + 'static,
{
    fn drop(&mut self) {
        // Dropped before reaching a terminal state: ask the producer to stop (best-effort).
        if !self.terminated {
            let producer = self.shared.lock().producer.clone();
            if let (Some(producer), Some(consumer)) = (producer, self.actor_ref.as_ref()) {
                let cancel = ProducerCommand::Cancel {
                    consumer: consumer.clone(),
                };
                let _ = cast_actor(&producer, cancel);
            }
        }
        self.stop_actor();
        // Release the `SourceRef` state only after the consumer actor has been stopped.
        self.source_ref_keep_alive = None;
    }
}
/// Raises the recorded cumulative demand to `delivered + buffer_capacity` when it is lower.
///
/// Returns the producer and the new cumulative value to send. Returns `None` when the stream
/// has terminated, when no increase is needed, or when no producer is bound; in the last case
/// the raised value is still recorded.
fn next_demand<T>(
    inner: &mut ConsumerInner<T>,
    settings: StreamRefSettings,
) -> Option<(ActorRef<ProducerCommand<T>>, u64)> {
    if inner.terminal.is_some() {
        return None;
    }
    let target = inner
        .delivered
        .saturating_add(settings.buffer_capacity as u64);
    if target <= inner.cumulative_demand {
        return None;
    }
    inner.cumulative_demand = target;
    let producer = inner.producer.clone()?;
    Some((producer, target))
}
/// Sends a `Demand` for `cumulative` elements to `producer` on behalf of `consumer`.
///
/// Fails with `ActorTerminated` when the consumer actor is already gone (`None`).
fn send_demand<T>(
    consumer: &Option<ActorRef<ConsumerCommand<T>>>,
    producer: &ActorRef<ProducerCommand<T>>,
    cumulative: u64,
) -> StreamResult<()>
where
    T: Send + 'static,
{
    match consumer {
        Some(consumer) => cast_actor(
            producer,
            ProducerCommand::Demand {
                consumer: consumer.clone(),
                cumulative,
            },
        ),
        None => Err(StreamError::ActorTerminated),
    }
}
/// Builds the sink behind [`StreamRefs::source_ref`].
///
/// Running it spawns a producer actor and a producing task over the input, and materializes
/// a [`SourceRef`] that keeps the task's completion alive.
fn stream_ref_source_sink<T>(settings: StreamRefSettings) -> Sink<T, SourceRef<T>>
where
    T: Send + 'static,
{
    Sink::from_runner(move |input, materializer| {
        let state = Arc::new(ProducerShared::new());
        let (producer_ref, _handle) = spawn_producer_actor(None, Arc::clone(&state))?;
        let task_producer = producer_ref.clone();
        let completion = materializer.spawn_stream(move |cancelled| {
            run_producer_endpoint(input, state, task_producer, settings, cancelled)
        });
        let ref_state = SourceRefInner {
            producer: producer_ref,
            settings,
            subscribed: AtomicBool::new(false),
            _keep_alive: Mutex::new(Some(completion)),
        };
        Ok(SourceRef {
            inner: Arc::new(ref_state),
        })
    })
}
/// Builds the source behind [`StreamRefs::sink_ref`].
///
/// Running it spawns a consumer actor and materializes a [`SinkRef`] that remote producers
/// use to reach that actor.
fn stream_ref_sink_source<T>(settings: StreamRefSettings) -> Source<T, SinkRef<T>>
where
    T: Send + 'static,
{
    Source::from_materialized_factory(move |_materializer| {
        let state = Arc::new(ConsumerShared::new(settings));
        let (consumer_ref, _handle) = spawn_consumer_actor(None, Arc::clone(&state))?;
        let ref_state = SinkRefInner {
            consumer: consumer_ref.clone(),
            settings,
            subscribed: AtomicBool::new(false),
        };
        let stream: BoxStream<T> = Box::new(ConsumerStream {
            shared: state,
            actor_ref: Some(consumer_ref),
            settings,
            terminated: false,
            source_ref_keep_alive: None,
        });
        Ok((
            stream,
            SinkRef {
                inner: Arc::new(ref_state),
            },
        ))
    })
}
fn run_producer_endpoint<T>(
mut input: BoxStream<T>,
shared: Arc<ProducerShared<T>>,
producer_ref: ActorRef<ProducerCommand<T>>,
settings: StreamRefSettings,
cancelled: Arc<AtomicBool>,
) -> StreamResult<NotUsed>
where
T: Send + 'static,
{
let deadline = Instant::now()
.checked_add(settings.subscription_timeout)
.unwrap_or_else(far_future);
loop {
let consumer = match wait_for_remote_demand(&shared, deadline, &cancelled) {
Ok(consumer) => consumer,
Err(error) => {
producer_ref.stop(None);
return Err(error);
}
};
if cancelled.load(Ordering::SeqCst) {
return Err(StreamError::Cancelled);
}
match input.next() {
Some(Ok(item)) => {
let seq = {
let mut inner = shared.lock();
let seq = inner.sent;
inner.sent = inner.sent.saturating_add(1);
seq
};
if let Err(error) = cast_actor(
&consumer,
ConsumerCommand::Element {
producer: producer_ref.clone(),
seq,
item,
},
) {
return Err(match error {
StreamError::ActorTerminated => StreamError::Cancelled,
other => other,
});
}
}
Some(Err(error)) => {
let _ = cast_actor(
&consumer,
ConsumerCommand::Failure {
producer: producer_ref.clone(),
error: error.clone(),
},
);
return Err(error);
}
None => {
let seq = shared.lock().sent;
let _ = cast_actor(
&consumer,
ConsumerCommand::Complete {
producer: producer_ref.clone(),
seq,
},
);
return Ok(NotUsed);
}
}
}
}
/// Blocks until the bound consumer has demand for more than `sent` elements, and returns
/// that consumer.
///
/// Fails with `Cancelled` when `cancelled` is set, or with the recorded stop reason. It also
/// fails with a subscription-timeout error when no consumer has subscribed by `deadline`.
/// Once a consumer is bound, the deadline no longer applies.
fn wait_for_remote_demand<T>(
    shared: &Arc<ProducerShared<T>>,
    deadline: Instant,
    cancelled: &Arc<AtomicBool>,
) -> StreamResult<ActorRef<ConsumerCommand<T>>>
where
    T: Send + 'static,
{
    let mut guard = shared.lock();
    loop {
        if cancelled.load(Ordering::SeqCst) {
            return Err(StreamError::Cancelled);
        }
        if let Some(error) = &guard.stopped {
            return Err(error.clone());
        }
        let slice = match &guard.consumer {
            Some(consumer) if guard.sent < guard.cumulative_demand => {
                return Ok(consumer.clone());
            }
            Some(_) => STREAM_REF_WAIT_POLL,
            None => {
                let now = Instant::now();
                if now >= deadline {
                    return Err(subscription_timeout_error("stream ref sink"));
                }
                deadline
                    .saturating_duration_since(now)
                    .min(STREAM_REF_WAIT_POLL)
            }
        };
        guard = wait_timeout_unpoison(&shared.changed, guard, slice).0;
    }
}
/// Spawns an unnamed [`ProducerActor`] on the ractor runtime.
///
/// # Errors
/// Propagates runtime errors, and maps a spawn failure to `StreamError::Failed`.
fn spawn_producer_actor<T>(
    initial_consumer: Option<ActorRef<ConsumerCommand<T>>>,
    shared: Arc<ProducerShared<T>>,
) -> StreamResult<(
    ActorRef<ProducerCommand<T>>,
    ractor::concurrency::JoinHandle<()>,
)>
where
    T: Send + 'static,
{
    let actor = ProducerActor {
        shared,
        initial_consumer,
    };
    let spawned = block_on_ractor_runtime(Actor::spawn(None, actor, ()))?;
    spawned.map_err(|error| {
        StreamError::Failed(format!(
            "stream ref producer actor failed to spawn: {error}"
        ))
    })
}
/// Spawns an unnamed [`ConsumerActor`] on the ractor runtime.
///
/// # Errors
/// Propagates runtime errors, and maps a spawn failure to `StreamError::Failed`.
fn spawn_consumer_actor<T>(
    initial_producer: Option<ActorRef<ProducerCommand<T>>>,
    shared: Arc<ConsumerShared<T>>,
) -> StreamResult<(
    ActorRef<ConsumerCommand<T>>,
    ractor::concurrency::JoinHandle<()>,
)>
where
    T: Send + 'static,
{
    let actor = ConsumerActor {
        shared,
        initial_producer,
    };
    let spawned = block_on_ractor_runtime(Actor::spawn(None, actor, ()))?;
    spawned.map_err(|error| {
        StreamError::Failed(format!(
            "stream ref consumer actor failed to spawn: {error}"
        ))
    })
}
/// Sends `message` to `actor_ref` without waiting for a reply.
///
/// A closed or unreachable mailbox maps to `StreamError::ActorTerminated`; any other
/// messaging error maps to `ActorAskSendFailed` with its text.
fn cast_actor<Msg>(actor_ref: &ActorRef<Msg>, message: Msg) -> StreamResult<()>
where
    Msg: Message,
{
    actor_ref.cast(message).map_err(|error| match error {
        ractor::MessagingErr::SendErr(_) | ractor::MessagingErr::ChannelClosed => {
            StreamError::ActorTerminated
        }
        other => StreamError::ActorAskSendFailed {
            reason: other.to_string(),
        },
    })
}
/// True when both refs point at the same actor, by comparing actor ids.
fn same_actor<MsgA, MsgB>(left: &ActorRef<MsgA>, right: &ActorRef<MsgB>) -> bool
where
    MsgA: Message,
    MsgB: Message,
{
    let left_id = left.get_cell().get_id();
    let right_id = right.get_cell().get_id();
    left_id == right_id
}
/// Stream that yields a single `StreamError::Failed(reason)` and then ends.
fn failed_once<T>(reason: &str) -> BoxStream<T>
where
    T: Send + 'static,
{
    Box::new(std::iter::once(Err(StreamError::Failed(String::from(
        reason,
    )))))
}
/// Error reported when the remote peer of `side` never subscribed in time.
fn subscription_timeout_error(side: &str) -> StreamError {
    let message = format!("{side} remote side did not subscribe within subscription timeout");
    StreamError::Failed(message)
}
/// Error reported when a message arrives with sequence number `got` instead of `expected`.
fn invalid_sequence_error(expected: u64, got: u64, context: &str) -> StreamError {
    let message = format!("{context}: expected sequence {expected}, got {got}");
    StreamError::Failed(message)
}
fn next_redelivery_deadline(settings: StreamRefSettings) -> Instant {
Instant::now()
.checked_add(settings.demand_redelivery_interval)
.unwrap_or_else(far_future)
}
/// Stand-in for "no deadline": 365 days from now, used when a deadline overflows `Instant`.
fn far_future() -> Instant {
    const ONE_YEAR: Duration = Duration::from_secs(365 * 24 * 60 * 60);
    Instant::now() + ONE_YEAR
}
/// `Condvar::wait_timeout` that ignores mutex poisoning and always returns the guard.
///
/// The caller must still handle spurious wakeups and check its own condition.
fn wait_timeout_unpoison<'a, T>(
    condvar: &Condvar,
    guard: MutexGuard<'a, T>,
    timeout: Duration,
) -> (MutexGuard<'a, T>, std::sync::WaitTimeoutResult) {
    match condvar.wait_timeout(guard, timeout) {
        Ok(pair) => pair,
        Err(poisoned) => poisoned.into_inner(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        stream::{Keep, Source},
        testkit::TestSink,
    };
    use std::sync::{
        Arc as StdArc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    };
    /// Re-evaluates `condition` until it returns `true` or `timeout` elapses, parking for up
    /// to 1 ms between checks. Returns the final evaluation.
    fn wait_until(timeout: Duration, mut condition: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if condition() {
                return true;
            }
            std::thread::park_timeout(Duration::from_millis(1));
        }
        condition()
    }
    /// Asserts that `condition` stays `true` throughout the `timeout` window, sampling it
    /// with up to 1 ms of parking between checks.
    fn assert_condition_holds(timeout: Duration, mut condition: impl FnMut() -> bool) {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            assert!(condition());
            std::thread::park_timeout(Duration::from_millis(1));
        }
        assert!(condition());
    }
    /// Settings with a one-element buffer and short timeouts, so tests can observe
    /// backpressure and timeouts quickly.
    fn short_settings() -> StreamRefSettings {
        StreamRefSettings::default()
            .with_buffer_capacity(1)
            .with_subscription_timeout(Duration::from_millis(50))
            .with_demand_redelivery_interval(Duration::from_millis(10))
    }
    #[test]
    fn source_ref_streams_elements_and_completion() {
        let source_ref = Source::from_iter(1_u64..=3)
            .run_with(StreamRefs::source_ref())
            .unwrap();
        assert_eq!(source_ref.source().run_collect().unwrap(), vec![1, 2, 3]);
    }
    #[test]
    fn sink_ref_streams_elements_and_completion() {
        let (sink_ref, completion) = StreamRefs::sink_ref::<u64>()
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .unwrap();
        Source::from_iter(1_u64..=3)
            .run_with(sink_ref.sink())
            .unwrap()
            .wait()
            .unwrap();
        assert_eq!(completion.wait().unwrap(), vec![1, 2, 3]);
    }
    #[test]
    fn source_ref_propagates_upstream_failure() {
        let source_ref = Source::<u64>::failed(StreamError::Failed("boom".to_owned()))
            .run_with(StreamRefs::source_ref())
            .unwrap();
        assert_eq!(
            source_ref.source().run_collect(),
            Err(StreamError::Failed("boom".to_owned()))
        );
    }
    #[test]
    fn sink_ref_propagates_upstream_failure() {
        let (sink_ref, completion) = StreamRefs::sink_ref::<u64>()
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .unwrap();
        // The producing side reports its own upstream error...
        let failure = Source::<u64>::failed(StreamError::Failed("remote boom".to_owned()))
            .run_with(sink_ref.sink())
            .unwrap()
            .wait();
        assert_eq!(failure, Err(StreamError::Failed("remote boom".to_owned())));
        // ...and the consuming side receives the same error through the ref.
        assert_eq!(
            completion.wait(),
            Err(StreamError::Failed("remote boom".to_owned()))
        );
    }
    #[test]
    fn source_ref_cancellation_reaches_origin() {
        let closed = StdArc::new(AtomicBool::new(false));
        let close_flag = StdArc::clone(&closed);
        // Infinite origin whose close callback records that it was shut down.
        let source = Source::unfold_resource(
            || Ok(()),
            |_state| Ok(Some(1_u64)),
            move |_state| {
                close_flag.store(true, Ordering::SeqCst);
                Ok(())
            },
        );
        let source_ref = source.run_with(StreamRefs::source_ref()).unwrap();
        // Taking one element drops the consumer early, which must cancel the origin.
        assert_eq!(source_ref.source().take(1).run_collect().unwrap(), vec![1]);
        assert!(wait_until(Duration::from_secs(1), || {
            closed.load(Ordering::SeqCst)
        }));
    }
    #[test]
    fn sink_ref_cancellation_stops_remote_producer() {
        // The consuming side stops after one element.
        let (sink_ref, completion) = StreamRefs::sink_ref::<u64>()
            .take(1)
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .unwrap();
        let producer = Source::repeat(1_u64)
            .run_with(sink_ref.sink())
            .unwrap()
            .wait();
        assert_eq!(completion.wait().unwrap(), vec![1]);
        // The otherwise infinite producer ends as cancelled.
        assert_eq!(producer, Err(StreamError::Cancelled));
    }
    #[test]
    fn source_ref_backpressures_across_ref() {
        let pulled = StdArc::new(AtomicUsize::new(0));
        let pulled_for_source = StdArc::clone(&pulled);
        // Counts how many elements the origin has been asked for.
        let source = Source::unfold(0_u64, move |next| {
            pulled_for_source.fetch_add(1, Ordering::SeqCst);
            Some((next + 1, next))
        });
        let source_ref = source
            .run_with(StreamRefs::source_ref_with_settings(short_settings()))
            .unwrap();
        let mut probe = source_ref.source().run_with(TestSink::probe()).unwrap();
        // With a buffer of one, the consumer grants demand for one element beyond those
        // delivered, so the origin may run at most one element ahead of the probe.
        probe.request(1);
        probe.assert_next(0);
        assert!(wait_until(Duration::from_secs(1), || {
            pulled.load(Ordering::SeqCst) >= 2
        }));
        assert_condition_holds(Duration::from_millis(50), || {
            pulled.load(Ordering::SeqCst) <= 2
        });
        probe.request(1);
        probe.assert_next(1);
        assert!(wait_until(Duration::from_secs(1), || {
            pulled.load(Ordering::SeqCst) >= 3
        }));
        assert_condition_holds(Duration::from_millis(50), || {
            pulled.load(Ordering::SeqCst) <= 3
        });
        probe.cancel();
    }
    #[test]
    fn source_ref_late_subscription_observes_timeout() {
        let source_ref = Source::repeat(1_u64)
            .run_with(StreamRefs::source_ref_with_settings(short_settings()))
            .unwrap();
        // Nobody subscribes within the 50 ms timeout, so the producer actor shuts down.
        assert!(wait_until(Duration::from_secs(1), || {
            source_ref.inner.producer.get_status() == ractor::ActorStatus::Stopped
        }));
        // A subscription that arrives afterwards must fail rather than hang.
        let result = source_ref.source().run_collect();
        assert!(matches!(
            result,
            Err(StreamError::ActorTerminated) | Err(StreamError::Failed(_))
        ));
    }
    #[test]
    fn sink_ref_subscription_timeout_fails_local_source() {
        // The SinkRef is never used, so no producer ever subscribes.
        let (_sink_ref, probe) = StreamRefs::sink_ref_with_settings::<u64>(short_settings())
            .to_mat(TestSink::probe(), Keep::both)
            .run()
            .unwrap();
        probe.request(1);
        let error = probe.expect_error();
        assert!(
            matches!(error, StreamError::Failed(message) if message.contains("did not subscribe"))
        );
    }
    #[test]
    fn stream_refs_are_one_shot() {
        // A second `source()` materialization of the same SourceRef fails.
        let source_ref = Source::from_iter([1_u64])
            .run_with(StreamRefs::source_ref())
            .unwrap();
        assert_eq!(source_ref.source().run_collect().unwrap(), vec![1]);
        assert!(matches!(
            source_ref.source().run_collect(),
            Err(StreamError::Failed(message)) if message.contains("already")
        ));
        // A second `sink()` materialization of the same SinkRef fails, and its element
        // never reaches the consumer.
        let (sink_ref, completion) = StreamRefs::sink_ref::<u64>()
            .to_mat(Sink::collect(), Keep::both)
            .run()
            .unwrap();
        Source::single(1_u64)
            .run_with(sink_ref.sink())
            .unwrap()
            .wait()
            .unwrap();
        assert!(matches!(
            Source::single(2_u64).run_with(sink_ref.sink()).unwrap().wait(),
            Err(StreamError::Failed(message)) if message.contains("already")
        ));
        assert_eq!(completion.wait().unwrap(), vec![1]);
    }
}