use std::{
collections::{HashMap, VecDeque},
hint,
marker::PhantomData,
sync::{
Arc, Condvar, Mutex, MutexGuard,
atomic::{AtomicBool, AtomicU8, AtomicU64, AtomicUsize, Ordering, fence},
mpsc,
},
thread,
time::Duration,
};
use arc_swap::{ArcSwap, ArcSwapOption};
use ractor::{Actor, ActorProcessingErr, ActorRef};
use tokio::sync::Notify;
use crate::{
StreamError, StreamResult,
actor::block_on_ractor_runtime,
stream::{BoxStream, NotUsed, Source, current_stream_cancelled},
};
/// Longest a blocked producer or a parked synchronous consumer waits before re-checking
/// its wake condition, so a missed notification costs latency rather than liveness.
const SLOT_WAIT_BACKSTOP: Duration = Duration::from_micros(10_000);
/// Maximum number of ring entries a consumer pulls out of the ring in one drain pass.
const SUBSCRIPTION_DRAIN_BATCH: usize = 1 << 8;
// Lifecycle states stored in `SubscriptionShared::lifecycle`, in the order they are entered.
const STATE_OPEN: u8 = 0;
const STATE_CLOSING: u8 = STATE_OPEN + 1;
const STATE_CLOSED: u8 = STATE_CLOSING + 1;
/// Cursor value of a slot the actor has not seeded yet.
const UNSEEDED_CURSOR: u64 = u64::MAX;
/// `drop_from` value meaning "no pending dropped range".
const NO_DROP_FROM: u64 = u64::MAX;
/// `terminal_from` value meaning "no terminal event scheduled".
const NO_TERMINAL_FROM: u64 = u64::MAX;
type Ack = mpsc::Sender<StreamResult<()>>;
/// What a publish does when a subscriber already has `capacity` unconsumed changes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SubscriptionOverflow {
    /// Block the publishing call until the slowest active subscriber frees space.
    Backpressure,
    /// Publish anyway, but skip the new change for every subscriber that is full.
    DropNew,
    /// Publish anyway, fail every full subscriber's stream and return an error to the
    /// publisher.
    Fail,
}
/// Cloneable handle to a value that can be read synchronously and observed as a stream of
/// changes.
///
/// All clones share the same inner state (cloning only bumps an `Arc`).
pub struct Subscription<T: Send + Sync + 'static> {
    inner: Arc<SubscriptionInner<T>>,
}
/// State shared by every clone of a `Subscription`; dropping the last reference stops the
/// actor.
struct SubscriptionInner<T: Send + Sync + 'static> {
    /// Actor that serializes subscribe, unsubscribe and close requests.
    actor: ActorRef<SubscriptionMessage<T>>,
    /// Lock-free state touched by publishers and subscriber streams.
    shared: Arc<SubscriptionShared<T>>,
    /// Source of subscriber ids (starts at 1).
    next_subscriber_id: Arc<AtomicU64>,
}
/// State shared between publishers, subscriber streams and the actor.
struct SubscriptionShared<T: Send + Sync + 'static> {
    /// Latest value, returned by `get` / `get_cloned`.
    mirror: Arc<ArcSwap<T>>,
    /// Latest completed publish (value and sequence); used to seed new subscribers.
    published: Arc<ArcSwap<PublishedValue<T>>>,
    /// Snapshot of registered subscriber slots, replaced wholesale by the actor.
    subscribers: Arc<ArcSwap<SubscriptionSlotTable<T>>>,
    /// History of published values, indexed by sequence.
    ring: SubscriptionRing<T>,
    /// Policy applied when a subscriber is full.
    overflow: SubscriptionOverflow,
    /// One of `STATE_OPEN`, `STATE_CLOSING`, `STATE_CLOSED`.
    lifecycle: AtomicU8,
    /// Number of publishes currently holding a `WritePermit`; close waits for zero.
    active_writers: AtomicUsize,
    /// Last sequence handed out by `claim_sequence`.
    next_sequence: AtomicU64,
    /// Last sequence whose publish has finished; publishes finish in sequence order.
    published_sequence: AtomicU64,
    /// Number of parked subscriber slots; publishers skip the wake scan when it is zero.
    parked_slots: Arc<AtomicUsize>,
}
/// A published value together with the sequence it was published under.
struct PublishedValue<T: Send + Sync + 'static> {
    sequence: u64,
    value: Arc<T>,
}
/// Immutable snapshot of registered subscriber slots, swapped in by the actor and read
/// without locking by publishers.
struct SubscriptionSlotTable<T: Send + Sync + 'static> {
    slots: Vec<Arc<SubscriptionSlot<T>>>,
}
/// Fixed-size ring of published values addressed by sequence number.
struct SubscriptionRing<T: Send + Sync + 'static> {
    /// Per-subscriber capacity requested by the caller of `Subscription::new`.
    logical_capacity: u64,
    /// Number of slots actually allocated (a power of two, at least the logical capacity).
    physical_capacity: usize,
    slots: Vec<SubscriptionRingSlot<T>>,
    // The next three fields let `Backpressure` producers sleep until a consumer frees space.
    space_lock: Mutex<()>,
    space_available: Condvar,
    /// Number of producers currently waiting on `space_available`.
    space_waiters: AtomicUsize,
}
/// One ring entry: the sequence it currently holds and the value for that sequence.
struct SubscriptionRingSlot<T: Send + Sync + 'static> {
    sequence: AtomicU64,
    value: ArcSwapOption<T>,
}
impl<T: Send + Sync + 'static> Clone for Subscription<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl<T: Send + Sync + 'static> Subscription<T> {
    /// Creates a subscription holding `initial` and spawns the actor that tracks subscribers.
    ///
    /// `capacity` is how many unconsumed changes a single subscriber may have outstanding
    /// before `overflow` decides what happens to the next publish.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    ///
    /// # Errors
    ///
    /// Returns an error if `block_on_ractor_runtime` fails or the actor fails to spawn.
    pub fn new(initial: T, capacity: usize, overflow: SubscriptionOverflow) -> StreamResult<Self> {
        assert!(
            capacity > 0,
            "subscription capacity must be greater than zero"
        );
        let value = Arc::new(initial);
        // The initial value is sequence 0; the first publish claims sequence 1.
        let shared = Arc::new(SubscriptionShared {
            mirror: Arc::new(ArcSwap::from(Arc::clone(&value))),
            published: Arc::new(ArcSwap::from_pointee(PublishedValue {
                sequence: 0,
                value: Arc::clone(&value),
            })),
            subscribers: Arc::new(ArcSwap::from_pointee(SubscriptionSlotTable {
                slots: Vec::new(),
            })),
            ring: SubscriptionRing::new(capacity),
            overflow,
            lifecycle: AtomicU8::new(STATE_OPEN),
            active_writers: AtomicUsize::new(0),
            next_sequence: AtomicU64::new(0),
            published_sequence: AtomicU64::new(0),
            parked_slots: Arc::new(AtomicUsize::new(0)),
        });
        let state = SubscriptionActorState {
            shared: Arc::clone(&shared),
            subscribers: HashMap::new(),
            closed: false,
        };
        let (actor, _handle) =
            block_on_ractor_runtime(Actor::spawn(None, SubscriptionActor::<T>::default(), state))?
                .map_err(|error| {
                    StreamError::Failed(format!("subscription actor failed to spawn: {error}"))
                })?;
        Ok(Self {
            inner: Arc::new(SubscriptionInner {
                actor,
                shared,
                next_subscriber_id: Arc::new(AtomicU64::new(1)),
            }),
        })
    }
    /// Returns the latest value as a shared pointer, without cloning `T`.
    #[must_use]
    pub fn get(&self) -> Arc<T> {
        self.inner.shared.mirror.load_full()
    }
    /// Returns a clone of the latest value.
    #[must_use]
    pub fn get_cloned(&self) -> T
    where
        T: Clone,
    {
        self.inner.shared.mirror.load().as_ref().clone()
    }
    /// Publishes `value`. It becomes visible to [`get`](Self::get) and is offered to every
    /// active subscriber, subject to the overflow policy.
    ///
    /// With `SubscriptionOverflow::Backpressure` this blocks while the slowest active
    /// subscriber is full.
    ///
    /// # Errors
    ///
    /// Fails if the subscription is closed. With `SubscriptionOverflow::Fail` it also returns
    /// an overflow error when some subscriber was full; the value is still published then.
    pub fn set(&self, value: T) -> StreamResult<()> {
        self.publish_set(Arc::new(value))
    }
    /// Currently identical to [`set`](Self::set).
    pub fn set_eventually(&self, value: T) -> StreamResult<()> {
        self.publish_set(Arc::new(value))
    }
    /// Publishes `update(&current)` as the new value.
    ///
    /// `update` may run more than once if the mirrored value changes between reading it and
    /// swapping in the result. Blocking and error behaviour match [`set`](Self::set).
    pub fn update<F>(&self, update: F) -> StreamResult<()>
    where
        F: FnMut(&T) -> T + Send + 'static,
    {
        self.publish_update(update)
    }
    /// Currently identical to [`update`](Self::update).
    pub fn update_eventually<F>(&self, update: F) -> StreamResult<()>
    where
        F: FnMut(&T) -> T + Send + 'static,
    {
        self.publish_update(update)
    }
    /// Closes the subscription and blocks until the actor acknowledges.
    ///
    /// Subscribers receive the current value once more as their final element and then
    /// complete; later publishes fail. Closing an already-closed subscription succeeds.
    ///
    /// # Errors
    ///
    /// Fails if the close request cannot be sent or the actor goes away before replying.
    pub fn close(&self) -> StreamResult<()> {
        self.send_close(None)
    }
    /// Like [`close`](Self::close), but `final_value` becomes the current value and is the
    /// final element delivered to subscribers.
    pub fn close_with(&self, final_value: T) -> StreamResult<()> {
        self.send_close(Some(final_value))
    }
    /// Shared publish pipeline for `set` / `set_eventually`.
    fn publish_set(&self, value: Arc<T>) -> StreamResult<()> {
        // Held until return so that `close` waits for this publish to finish.
        let _permit = self.inner.shared.begin_write()?;
        let sequence = self.inner.shared.claim_sequence();
        // Publishes complete strictly in sequence order.
        self.inner.shared.wait_publish_turn(sequence);
        self.inner.shared.wait_for_ring_capacity(sequence);
        // Evaluated before the value is stored so full subscribers are marked first; the
        // value is published whatever the outcome, and the outcome is returned afterwards.
        let overflow = self.inner.shared.apply_overflow_policy(sequence);
        self.inner.shared.finish_publish(sequence, value);
        overflow
    }
    /// Shared publish pipeline for `update` / `update_eventually`.
    fn publish_update<F>(&self, mut update: F) -> StreamResult<()>
    where
        F: FnMut(&T) -> T + Send + 'static,
    {
        let _permit = self.inner.shared.begin_write()?;
        let sequence = self.inner.shared.claim_sequence();
        self.inner.shared.wait_publish_turn(sequence);
        self.inner.shared.wait_for_ring_capacity(sequence);
        // Compute-and-swap the mirror; retry if it changed between the load and the swap.
        let value = loop {
            let current = self.inner.shared.mirror.load();
            let next = Arc::new(update(current.as_ref()));
            let previous = self
                .inner
                .shared
                .mirror
                .compare_and_swap(&*current, Arc::clone(&next));
            if std::ptr::eq(current.as_ref(), previous.as_ref()) {
                break next;
            }
        };
        let overflow = self.inner.shared.apply_overflow_policy(sequence);
        // The mirror is already updated above, so only the ring entry and the published
        // sequence remain.
        self.inner.shared.ring.store(sequence, Arc::clone(&value));
        self.inner
            .shared
            .finish_publish_after_mirror(sequence, value);
        overflow
    }
    /// Sends a `Close` request and blocks for the reply; if the reply sender is dropped
    /// without answering, the result is `StreamError::ActorTerminated`.
    fn send_close(&self, final_value: Option<T>) -> StreamResult<()> {
        let (reply, receiver) = mpsc::channel();
        self.inner
            .actor
            .send_message(SubscriptionMessage::Close { final_value, reply })
            .map_err(|error| StreamError::ActorAskSendFailed {
                reason: error.to_string(),
            })?;
        receiver.recv().unwrap_or(Err(StreamError::ActorTerminated))
    }
    /// Registers `slot` under `id` with the actor and blocks until it is seeded (or
    /// completed, if the subscription is already closed).
    fn register_slot(&self, slot: Arc<SubscriptionSlot<T>>, id: u64) -> StreamResult<()> {
        let (reply, receiver) = mpsc::channel();
        self.inner
            .actor
            .send_message(SubscriptionMessage::Subscribe { id, slot, reply })
            .map_err(|error| StreamError::ActorAskSendFailed {
                reason: error.to_string(),
            })?;
        receiver.recv().unwrap_or(Err(StreamError::ActorTerminated))
    }
}
impl<T: Clone + Send + Sync + 'static> Subscription<T> {
#[must_use]
pub fn changes(&self) -> Source<T> {
let actor = self.inner.actor.clone();
let subscription = self.clone();
let shared = Arc::clone(&self.inner.shared);
let next_subscriber_id = Arc::clone(&self.inner.next_subscriber_id);
Source::from_materialized_factory(move |_materializer| {
let id = next_subscriber_id.fetch_add(1, Ordering::Relaxed);
let slot = SubscriptionSlot::new(id, actor.clone(), Arc::clone(&shared.parked_slots));
subscription.register_slot(Arc::clone(&slot), id)?;
let stream: BoxStream<T> = Box::new(SubscriptionChangesStream {
shared: Arc::clone(&shared),
slot,
pending: VecDeque::new(),
terminated: false,
});
Ok((stream, NotUsed))
})
}
#[doc(hidden)]
pub fn __benchmark_changes(&self) -> StreamResult<SubscriptionBenchmarkStream<T>> {
let id = self
.inner
.next_subscriber_id
.fetch_add(1, Ordering::Relaxed);
let slot = SubscriptionSlot::new(
id,
self.inner.actor.clone(),
Arc::clone(&self.inner.shared.parked_slots),
);
self.register_slot(Arc::clone(&slot), id)?;
Ok(SubscriptionBenchmarkStream {
shared: Arc::clone(&self.inner.shared),
slot,
pending: VecDeque::new(),
terminated: false,
})
}
}
impl<T: Send + Sync + 'static> SubscriptionShared<T> {
fn begin_write(&self) -> StreamResult<WritePermit<'_>> {
if self.lifecycle.load(Ordering::Acquire) != STATE_OPEN {
return Err(closed_error());
}
self.active_writers.fetch_add(1, Ordering::AcqRel);
if self.lifecycle.load(Ordering::Acquire) == STATE_OPEN {
Ok(WritePermit {
active_writers: &self.active_writers,
})
} else {
self.active_writers.fetch_sub(1, Ordering::AcqRel);
Err(closed_error())
}
}
fn claim_sequence(&self) -> u64 {
self.next_sequence.fetch_add(1, Ordering::AcqRel) + 1
}
fn wait_publish_turn(&self, sequence: u64) {
let mut spins = 0_u32;
while self.published_sequence.load(Ordering::Acquire) + 1 != sequence {
spins = spins.wrapping_add(1);
if spins < 64 {
hint::spin_loop();
} else {
thread::yield_now();
}
}
}
fn wait_for_ring_capacity(&self, sequence: u64) {
if self.overflow != SubscriptionOverflow::Backpressure {
return;
}
let mut guard = self
.ring
.space_lock
.lock()
.unwrap_or_else(|poison| poison.into_inner());
while self.sequence_would_overflow(sequence) {
self.ring.space_waiters.fetch_add(1, Ordering::AcqRel);
if !self.sequence_would_overflow(sequence) {
self.ring.space_waiters.fetch_sub(1, Ordering::AcqRel);
break;
}
guard = self
.ring
.space_available
.wait_timeout(guard, SLOT_WAIT_BACKSTOP)
.unwrap_or_else(|poison| poison.into_inner())
.0;
self.ring.space_waiters.fetch_sub(1, Ordering::AcqRel);
}
}
fn sequence_would_overflow(&self, sequence: u64) -> bool {
let Some(cursor) = self.min_active_cursor() else {
return false;
};
sequence >= cursor.saturating_add(self.ring.logical_capacity)
}
fn min_active_cursor(&self) -> Option<u64> {
let table = self.subscribers.load();
table
.slots
.iter()
.filter_map(|slot| slot.backpressure_cursor())
.min()
}
fn apply_overflow_policy(&self, sequence: u64) -> StreamResult<()> {
match self.overflow {
SubscriptionOverflow::Backpressure => Ok(()),
SubscriptionOverflow::DropNew => {
let table = self.subscribers.load();
for slot in &table.slots {
if slot.is_full_for(sequence, self.ring.logical_capacity) {
slot.drop_new(sequence, self.ring.logical_capacity);
}
}
Ok(())
}
SubscriptionOverflow::Fail => {
let table = self.subscribers.load();
let mut overflowed = false;
let error = overflow_error(self.ring.logical_capacity);
for slot in &table.slots {
if slot.is_full_for(sequence, self.ring.logical_capacity) {
overflowed = true;
slot.fail_after(sequence, error.clone());
}
}
if overflowed { Err(error) } else { Ok(()) }
}
}
}
fn finish_publish(&self, sequence: u64, value: Arc<T>) {
self.ring.store(sequence, Arc::clone(&value));
self.mirror.store(Arc::clone(&value));
self.finish_publish_after_mirror(sequence, value);
}
fn finish_publish_after_mirror(&self, sequence: u64, value: Arc<T>) {
self.published.store(Arc::new(PublishedValue {
sequence,
value: Arc::clone(&value),
}));
self.published_sequence.store(sequence, Ordering::Release);
if self.parked_slots.load(Ordering::Acquire) != 0 {
let table = self.subscribers.load();
for slot in &table.slots {
slot.wake_for_sequence(sequence);
}
}
}
fn wait_for_writers_to_drain(&self) {
while self.active_writers.load(Ordering::Acquire) != 0 {
thread::yield_now();
}
}
}
impl<T: Send + Sync + 'static> SubscriptionRing<T> {
fn new(logical_capacity: usize) -> Self {
let physical_capacity = logical_capacity.max(1_024).next_power_of_two();
let mut slots = Vec::with_capacity(physical_capacity);
for _ in 0..physical_capacity {
slots.push(SubscriptionRingSlot {
sequence: AtomicU64::new(0),
value: ArcSwapOption::empty(),
});
}
Self {
logical_capacity: logical_capacity as u64,
physical_capacity,
slots,
space_lock: Mutex::new(()),
space_available: Condvar::new(),
space_waiters: AtomicUsize::new(0),
}
}
fn store(&self, sequence: u64, value: Arc<T>) {
let slot = &self.slots[self.index(sequence)];
slot.value.store(Some(value));
slot.sequence.store(sequence, Ordering::Release);
}
fn load(&self, sequence: u64) -> Option<Arc<T>> {
let slot = &self.slots[self.index(sequence)];
if slot.sequence.load(Ordering::Acquire) == sequence {
slot.value.load_full()
} else {
None
}
}
fn has(&self, sequence: u64) -> bool {
let slot = &self.slots[self.index(sequence)];
slot.sequence.load(Ordering::Acquire) == sequence
}
fn oldest_available(&self, published_sequence: u64) -> u64 {
published_sequence
.saturating_sub(self.physical_capacity as u64)
.saturating_add(1)
.max(1)
}
fn notify_space(&self) {
if self.space_waiters.load(Ordering::Acquire) == 0 {
return;
}
let _guard = self
.space_lock
.lock()
.unwrap_or_else(|poison| poison.into_inner());
self.space_available.notify_all();
}
fn index(&self, sequence: u64) -> usize {
sequence as usize & (self.physical_capacity - 1)
}
}
/// Guard for one in-flight publish; releases its slot in the writer counter when dropped.
struct WritePermit<'a> {
    active_writers: &'a AtomicUsize,
}
impl Drop for WritePermit<'_> {
    fn drop(&mut self) {
        let writers = self.active_writers;
        writers.fetch_sub(1, Ordering::AcqRel);
    }
}
impl<T: Send + Sync + 'static> Drop for SubscriptionInner<T> {
    /// Stops the actor once the last reference to the shared handle state goes away.
    fn drop(&mut self) {
        let actor = &self.actor;
        actor.stop(None);
    }
}
/// Requests handled by `SubscriptionActor`.
enum SubscriptionMessage<T: Send + Sync + 'static> {
    /// Close the subscription, optionally replacing the value first; acknowledged on `reply`.
    Close {
        final_value: Option<T>,
        reply: Ack,
    },
    /// Register `slot` under `id` and seed it; acknowledged on `reply`.
    Subscribe {
        id: u64,
        slot: Arc<SubscriptionSlot<T>>,
        reply: Ack,
    },
    /// Remove the subscriber `id` (sent fire-and-forget when a stream is dropped).
    Unsubscribe {
        id: u64,
    },
}
// With the `cluster` feature enabled, actor messages must implement `ractor::Message`;
// this impl relies entirely on the trait's default methods.
#[cfg(feature = "cluster")]
impl<T: Send + Sync + 'static> ractor::Message for SubscriptionMessage<T> {}
/// Actor that serializes subscribe, unsubscribe and close requests. It carries no data;
/// its state lives in `SubscriptionActorState`.
struct SubscriptionActor<T> {
    _marker: PhantomData<fn() -> T>,
}
// Hand-written because `#[derive(Default)]` would require `T: Default`.
impl<T> Default for SubscriptionActor<T> {
    fn default() -> Self {
        let _marker = PhantomData;
        SubscriptionActor { _marker }
    }
}
/// State owned by `SubscriptionActor`.
struct SubscriptionActorState<T: Send + Sync + 'static> {
    shared: Arc<SubscriptionShared<T>>,
    /// Registered subscribers by id; mirrored into `shared.subscribers` after every change.
    subscribers: HashMap<u64, Arc<SubscriptionSlot<T>>>,
    /// Set once `close_subscription` has finished (or found the lifecycle already CLOSED).
    closed: bool,
}
impl<T: Send + Sync + 'static> Actor for SubscriptionActor<T> {
    type Msg = SubscriptionMessage<T>;
    type State = SubscriptionActorState<T>;
    type Arguments = SubscriptionActorState<T>;
    /// The state is built by `Subscription::new` and used as-is.
    async fn pre_start(
        &self,
        _myself: ActorRef<Self::Msg>,
        args: Self::Arguments,
    ) -> Result<Self::State, ActorProcessingErr> {
        Ok(args)
    }
    /// Handles one close, subscribe or unsubscribe request at a time.
    async fn handle(
        &self,
        _myself: ActorRef<Self::Msg>,
        message: Self::Msg,
        state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        match message {
            SubscriptionMessage::Close { final_value, reply } => {
                close_subscription(state, final_value);
                // A failed send only means the requester dropped its receiver.
                let _ = reply.send(Ok(()));
            }
            SubscriptionMessage::Subscribe { id, slot, reply } => {
                if state.closed || state.shared.lifecycle.load(Ordering::Acquire) == STATE_CLOSED {
                    // Late subscriber: deliver the last published value, then complete.
                    let published = state.shared.published.load_full();
                    slot.complete_post_close(Arc::clone(&published.value));
                } else {
                    // Make the slot visible to publishers first, then seed it with the latest
                    // completed publish; the stream reads the ring from the following sequence.
                    state.subscribers.insert(id, Arc::clone(&slot));
                    publish_subscription_slot_table(state);
                    let published = state.shared.published.load_full();
                    slot.seed(
                        published.sequence.saturating_add(1),
                        Arc::clone(&published.value),
                    );
                }
                let _ = reply.send(Ok(()));
            }
            SubscriptionMessage::Unsubscribe { id } => {
                state.subscribers.remove(&id);
                publish_subscription_slot_table(state);
                // Removing a slot may lift backpressure; let blocked producers re-check.
                state.shared.ring.notify_space();
            }
        }
        Ok(())
    }
    /// If the actor stops without a prior close, fail every remaining subscriber with
    /// `ActorTerminated` and release blocked producers.
    async fn post_stop(
        &self,
        _myself: ActorRef<Self::Msg>,
        state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        if !state.closed {
            for slot in state.subscribers.values() {
                slot.fail_now(StreamError::ActorTerminated);
            }
            state.subscribers.clear();
            publish_subscription_slot_table(state);
            state.shared.ring.notify_space();
        }
        Ok(())
    }
}
/// Closes the subscription from inside the actor.
///
/// Moves `lifecycle` to CLOSING so new publishes are refused, waits for in-flight ones,
/// then claims one last sequence that carries the final value (`final_value`, or the
/// current mirror value when `None`). That sequence is never written to the ring: each
/// subscriber receives the value through its terminal state once its cursor reaches it.
/// Calls after the first completed close return immediately.
fn close_subscription<T: Send + Sync + 'static>(
    state: &mut SubscriptionActorState<T>,
    final_value: Option<T>,
) {
    if state.closed {
        return;
    }
    match state.shared.lifecycle.compare_exchange(
        STATE_OPEN,
        STATE_CLOSING,
        Ordering::AcqRel,
        Ordering::Acquire,
    ) {
        Ok(_) => {}
        // Lifecycle already CLOSED: only record that locally.
        Err(STATE_CLOSED) => {
            state.closed = true;
            return;
        }
        // Already CLOSING (presumably an earlier close that did not finish): carry on.
        Err(_) => {}
    }
    state.shared.wait_for_writers_to_drain();
    let sequence = state.shared.claim_sequence();
    state.shared.wait_publish_turn(sequence);
    let value = final_value
        .map(Arc::new)
        .unwrap_or_else(|| state.shared.mirror.load_full());
    state.shared.mirror.store(Arc::clone(&value));
    state.shared.published.store(Arc::new(PublishedValue {
        sequence,
        value: Arc::clone(&value),
    }));
    state
        .shared
        .published_sequence
        .store(sequence, Ordering::Release);
    state
        .shared
        .lifecycle
        .store(STATE_CLOSED, Ordering::Release);
    for slot in state.subscribers.values() {
        slot.complete_with_final(sequence, Arc::clone(&value));
    }
    state.subscribers.clear();
    publish_subscription_slot_table(state);
    // The emptied table no longer constrains backpressure.
    state.shared.ring.notify_space();
    state.closed = true;
}
/// Rebuilds the subscriber snapshot that publishers read, from the actor's map.
fn publish_subscription_slot_table<T: Send + Sync + 'static>(state: &SubscriptionActorState<T>) {
    let snapshot = SubscriptionSlotTable {
        slots: state.subscribers.values().map(Arc::clone).collect(),
    };
    state.shared.subscribers.store(Arc::new(snapshot));
}
/// Error returned by publishes attempted once `close` has begun.
fn closed_error() -> StreamError {
    let message = String::from("subscription is closed");
    StreamError::Failed(message)
}
/// Error reported when a subscriber's buffer of `capacity` changes overflows.
fn overflow_error(capacity: u64) -> StreamError {
    let message = format!("subscription buffer overflow (max capacity was: {capacity})");
    StreamError::Failed(message)
}
/// Atomically lowers `target` to `value` when `value` is smaller; otherwise it is unchanged.
///
/// Delegates to [`AtomicU64::fetch_min`] instead of a hand-rolled CAS loop, matching how
/// the slot's `terminal_from` is already updated elsewhere in this module.
fn atomic_fetch_min(target: &AtomicU64, value: u64) {
    target.fetch_min(value, Ordering::AcqRel);
}
/// Atomically raises `target` to `value` when `value` is larger; otherwise it is unchanged.
///
/// Delegates to [`AtomicU64::fetch_max`] instead of a hand-rolled CAS loop.
fn atomic_fetch_max(target: &AtomicU64, value: u64) {
    target.fetch_max(value, Ordering::AcqRel);
}
/// Per-subscriber state shared between the subscriber's stream, publishers and the actor.
struct SubscriptionSlot<T: Send + Sync + 'static> {
    /// Key of this slot in the actor's subscriber map.
    id: u64,
    /// Used to send `Unsubscribe` when the stream goes away.
    actor: ActorRef<SubscriptionMessage<T>>,
    /// Shared count of parked slots (the subscription's `parked_slots`).
    parked_count: Arc<AtomicUsize>,
    /// Next sequence to read from the ring; `UNSEEDED_CURSOR` until the actor seeds it.
    cursor: AtomicU64,
    /// Cleared on unsubscribe or failure; inactive slots are ignored by overflow checks.
    active: AtomicBool,
    /// Whether the consumer is currently waiting for a wake.
    parked: AtomicBool,
    /// Inclusive range of sequences dropped for this subscriber (`NO_DROP_FROM` when none).
    drop_from: AtomicU64,
    drop_through: AtomicU64,
    /// First cursor at which the scheduled terminal event takes over from the ring.
    terminal_from: AtomicU64,
    state: Mutex<SubscriptionSlotState<T>>,
    /// Wakes synchronous consumers.
    available: Condvar,
    /// Wakes asynchronous consumers.
    async_available: Notify,
}
/// Lock-protected part of a slot.
struct SubscriptionSlotState<T: Send + Sync + 'static> {
    /// First element to emit (the value current at subscribe time); taken once.
    seed: Option<Arc<T>>,
    /// Scheduled completion or failure, if any.
    terminal: Option<SubscriptionSlotTerminal<T>>,
}
/// How a subscriber's stream ends once its cursor reaches the given sequence.
#[derive(Clone)]
enum SubscriptionSlotTerminal<T: Send + Sync + 'static> {
    /// Emit `final_value` (if still present), then end the stream.
    Complete {
        final_sequence: u64,
        final_value: Option<Arc<T>>,
    },
    /// Emit `error` and end the stream.
    Error {
        after_sequence: u64,
        error: StreamError,
    },
}
impl<T: Send + Sync + 'static> SubscriptionSlot<T> {
    /// Creates an unseeded, active, unparked slot for subscriber `id`.
    fn new(
        id: u64,
        actor: ActorRef<SubscriptionMessage<T>>,
        parked_count: Arc<AtomicUsize>,
    ) -> Arc<Self> {
        Arc::new(Self {
            id,
            actor,
            parked_count,
            cursor: AtomicU64::new(UNSEEDED_CURSOR),
            active: AtomicBool::new(true),
            parked: AtomicBool::new(false),
            drop_from: AtomicU64::new(NO_DROP_FROM),
            drop_through: AtomicU64::new(0),
            terminal_from: AtomicU64::new(NO_TERMINAL_FROM),
            state: Mutex::new(SubscriptionSlotState {
                seed: None,
                terminal: None,
            }),
            available: Condvar::new(),
            async_available: Notify::new(),
        })
    }
    /// Locks the slot state, recovering the guard if the mutex was poisoned.
    fn lock(&self) -> MutexGuard<'_, SubscriptionSlotState<T>> {
        self.state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner())
    }
    /// Points the cursor at `next_sequence` and queues `value` as the first element.
    fn seed(&self, next_sequence: u64, value: Arc<T>) {
        self.cursor.store(next_sequence, Ordering::Release);
        let mut state = self.lock();
        state.seed = Some(value);
        drop(state);
        self.wake();
    }
    /// For subscribers arriving after close: emit `value` once, then complete.
    /// Cursor 0 with `final_sequence` 0 makes the terminal immediately eligible.
    fn complete_post_close(&self, value: Arc<T>) {
        self.cursor.store(0, Ordering::Release);
        let mut state = self.lock();
        state.seed = Some(value);
        state.terminal = Some(SubscriptionSlotTerminal::Complete {
            final_sequence: 0,
            final_value: None,
        });
        drop(state);
        self.terminal_from.store(0, Ordering::Release);
        self.active.store(false, Ordering::Release);
        self.wake();
    }
    /// Schedules completion: once the cursor reaches `final_sequence`, emit `value` and end.
    /// An already-scheduled terminal is kept, but `terminal_from` is still lowered to
    /// `final_sequence` if that is smaller.
    fn complete_with_final(&self, final_sequence: u64, value: Arc<T>) {
        let mut state = self.lock();
        if state.terminal.is_none() {
            state.terminal = Some(SubscriptionSlotTerminal::Complete {
                final_sequence,
                final_value: Some(value),
            });
        }
        drop(state);
        self.terminal_from
            .fetch_min(final_sequence, Ordering::AcqRel);
        self.wake();
    }
    /// Schedules `error` for when the cursor reaches `after_sequence` and stops this slot
    /// from counting toward overflow and backpressure.
    fn fail_after(&self, after_sequence: u64, error: StreamError) {
        let mut state = self.lock();
        if state.terminal.is_none() {
            state.terminal = Some(SubscriptionSlotTerminal::Error {
                after_sequence,
                error,
            });
        }
        drop(state);
        self.terminal_from
            .fetch_min(after_sequence, Ordering::AcqRel);
        self.active.store(false, Ordering::Release);
        self.wake();
    }
    /// Fails at the current cursor (0 if the slot was never seeded).
    fn fail_now(&self, error: StreamError) {
        let cursor = self.cursor.load(Ordering::Acquire);
        let after_sequence = if cursor == UNSEEDED_CURSOR { 0 } else { cursor };
        self.fail_after(after_sequence, error);
    }
    /// Whether publishing `sequence` would exceed `capacity` unconsumed changes for this
    /// slot. Inactive or unseeded slots are never full.
    fn is_full_for(&self, sequence: u64, capacity: u64) -> bool {
        if !self.active.load(Ordering::Acquire) {
            return false;
        }
        let cursor = self.cursor.load(Ordering::Acquire);
        cursor != UNSEEDED_CURSOR && sequence >= cursor.saturating_add(capacity)
    }
    /// Cursor that constrains backpressure, or `None` for inactive or unseeded slots.
    fn backpressure_cursor(&self) -> Option<u64> {
        if !self.active.load(Ordering::Acquire) {
            return None;
        }
        let cursor = self.cursor.load(Ordering::Acquire);
        (cursor != UNSEEDED_CURSOR).then_some(cursor)
    }
    /// Widens the dropped range to cover `[cursor + capacity, sequence]`.
    fn drop_new(&self, sequence: u64, capacity: u64) {
        let cursor = self.cursor.load(Ordering::Acquire);
        if cursor == UNSEEDED_CURSOR {
            return;
        }
        let from = cursor.saturating_add(capacity);
        if sequence >= from {
            atomic_fetch_min(&self.drop_from, from);
            atomic_fetch_max(&self.drop_through, sequence);
            self.wake();
        }
    }
    /// If `cursor` lies in the dropped range, clears the range and returns the sequence just
    /// past it.
    ///
    /// NOTE(review): the range is read and reset with separate atomic operations. A
    /// concurrent `drop_new` that extends it between the loads and the reset stores appears
    /// to be lost, so that change would be delivered instead of dropped — confirm whether
    /// this matters.
    fn skip_dropped(&self, cursor: u64) -> Option<u64> {
        let from = self.drop_from.load(Ordering::Acquire);
        let through = self.drop_through.load(Ordering::Acquire);
        if from != NO_DROP_FROM && cursor >= from && cursor <= through {
            self.drop_from.store(NO_DROP_FROM, Ordering::Release);
            self.drop_through.store(0, Ordering::Release);
            Some(through.saturating_add(1))
        } else {
            None
        }
    }
    /// Whether `cursor` lies inside the pending dropped range.
    fn has_dropped(&self, cursor: u64) -> bool {
        let from = self.drop_from.load(Ordering::Acquire);
        let through = self.drop_through.load(Ordering::Acquire);
        from != NO_DROP_FROM && cursor >= from && cursor <= through
    }
    /// Whether a scheduled terminal event stops `cursor` from being read from the ring.
    fn terminal_blocks(&self, cursor: u64) -> bool {
        cursor >= self.terminal_from.load(Ordering::Acquire)
    }
    /// Wakes the consumer (condvar and `Notify`) if it is parked; otherwise does nothing.
    /// The condvar is notified without holding the state mutex. A notification issued
    /// between a synchronous waiter's check and its `wait_timeout` can therefore be missed;
    /// that waiter's timeout bounds the delay.
    fn wake(&self) {
        if self.parked.swap(false, Ordering::AcqRel) {
            self.parked_count.fetch_sub(1, Ordering::AcqRel);
            self.available.notify_all();
            self.async_available.notify_waiters();
        }
    }
    /// Marks the slot as parked and counts it so publishers scan for wakeups.
    fn park(&self) {
        if !self.parked.swap(true, Ordering::AcqRel) {
            self.parked_count.fetch_add(1, Ordering::AcqRel);
        }
    }
    /// Reverses `park` if no waker has already done so.
    fn unpark(&self) {
        if self.parked.swap(false, Ordering::AcqRel) {
            self.parked_count.fetch_sub(1, Ordering::AcqRel);
        }
    }
    /// Wakes the consumer only if it is waiting for exactly `sequence`.
    fn wake_for_sequence(&self, sequence: u64) {
        if self.cursor.load(Ordering::Acquire) == sequence {
            self.wake();
        }
    }
    /// Stops counting toward overflow immediately and asks the actor to drop the slot;
    /// a send failure (actor already gone) is ignored.
    fn unsubscribe(&self) {
        self.active.store(false, Ordering::Release);
        let _ = self
            .actor
            .send_message(SubscriptionMessage::Unsubscribe { id: self.id });
    }
}
/// Blocking iterator over one subscriber's changes, used by `Subscription::changes`.
struct SubscriptionChangesStream<T: Clone + Send + Sync + 'static> {
    shared: Arc<SubscriptionShared<T>>,
    slot: Arc<SubscriptionSlot<T>>,
    /// Values already pulled from the ring but not yet returned.
    pending: VecDeque<Arc<T>>,
    /// Set once the stream has completed or failed; `next` then returns `None`.
    terminated: bool,
}
/// Async counterpart of `SubscriptionChangesStream`, exposed only for benchmarks.
#[doc(hidden)]
pub struct SubscriptionBenchmarkStream<T: Clone + Send + Sync + 'static> {
    shared: Arc<SubscriptionShared<T>>,
    slot: Arc<SubscriptionSlot<T>>,
    /// Values already pulled from the ring but not yet returned.
    pending: VecDeque<Arc<T>>,
    /// Set once the stream has completed or failed.
    terminated: bool,
}
impl<T: Clone + Send + Sync + 'static> Iterator for SubscriptionChangesStream<T> {
    type Item = StreamResult<T>;
    /// Blocks until the next change, the terminal event, an overflow error, or cancellation.
    fn next(&mut self) -> Option<Self::Item> {
        if self.terminated {
            return None;
        }
        loop {
            // Values buffered by an earlier drain come first.
            if let Some(value) = self.pending.pop_front() {
                return Some(Ok(value.as_ref().clone()));
            }
            // Then the seed, and the terminal event once the cursor has reached it.
            if let Some(item) = self.poll_seed_or_terminal() {
                return item;
            }
            let cursor = self.slot.cursor.load(Ordering::Acquire);
            // Not seeded by the actor yet.
            if cursor == UNSEEDED_CURSOR {
                self.wait_for_wake();
                continue;
            }
            // Jump over a range dropped under `DropNew`.
            if let Some(next_cursor) = self.slot.skip_dropped(cursor) {
                self.slot.cursor.store(next_cursor, Ordering::Release);
                self.shared.ring.notify_space();
                continue;
            }
            if let Some(value) = self.drain_available(cursor) {
                return Some(Ok(value.as_ref().clone()));
            }
            // Nothing readable: check whether the ring has already overwritten `cursor`.
            let published = self.shared.published_sequence.load(Ordering::Acquire);
            if cursor <= published {
                let oldest = self.shared.ring.oldest_available(published);
                if cursor < oldest {
                    match self.shared.overflow {
                        SubscriptionOverflow::DropNew => {
                            self.slot.cursor.store(oldest, Ordering::Release);
                            self.shared.ring.notify_space();
                            continue;
                        }
                        SubscriptionOverflow::Backpressure | SubscriptionOverflow::Fail => {
                            self.terminated = true;
                            return Some(Err(overflow_error(self.shared.ring.logical_capacity)));
                        }
                    }
                }
            }
            // Honour cancellation of the running stream before going back to sleep.
            if current_stream_cancelled()
                .as_ref()
                .is_some_and(|cancelled| cancelled.load(Ordering::SeqCst))
            {
                self.terminated = true;
                return Some(Err(StreamError::Cancelled));
            }
            self.wait_for_wake();
        }
    }
}
impl<T: Clone + Send + Sync + 'static> SubscriptionChangesStream<T> {
    /// Reads the entry at `start_cursor` plus up to `SUBSCRIPTION_DRAIN_BATCH - 1` following
    /// ones into `pending`. Stops at a missing entry, a dropped range or the terminal
    /// sequence. Advances the cursor and returns the first value, or `None` if
    /// `start_cursor` is not in the ring.
    fn drain_available(&mut self, start_cursor: u64) -> Option<Arc<T>> {
        let mut cursor = start_cursor;
        let first = self.shared.ring.load(cursor)?;
        cursor = cursor.saturating_add(1);
        let mut drained = 1_usize;
        while drained < SUBSCRIPTION_DRAIN_BATCH {
            if self.slot.has_dropped(cursor) || self.slot.terminal_blocks(cursor) {
                break;
            }
            let Some(value) = self.shared.ring.load(cursor) else {
                break;
            };
            self.pending.push_back(value);
            cursor = cursor.saturating_add(1);
            drained += 1;
        }
        self.slot.cursor.store(cursor, Ordering::Release);
        // The cursor moved forward, which may free capacity for blocked producers.
        self.shared.ring.notify_space();
        Some(first)
    }
    /// Returns `Some(item)` when the seed or the terminal event is due. The inner `None`
    /// means the stream has ended. Returns `None` when the ring should be consulted instead.
    fn poll_seed_or_terminal(&mut self) -> Option<Option<StreamResult<T>>> {
        let mut state = self.slot.lock();
        if let Some(seed) = state.seed.take() {
            return Some(Some(Ok(seed.as_ref().clone())));
        }
        let cursor = self.slot.cursor.load(Ordering::Acquire);
        if let Some(terminal) = &mut state.terminal {
            match terminal {
                SubscriptionSlotTerminal::Complete {
                    final_sequence,
                    final_value,
                } => {
                    if cursor >= *final_sequence {
                        // Emit the final value once, then end on the next poll.
                        if let Some(value) = final_value.take() {
                            return Some(Some(Ok(value.as_ref().clone())));
                        }
                        self.terminated = true;
                        return Some(None);
                    }
                }
                SubscriptionSlotTerminal::Error {
                    after_sequence,
                    error,
                } => {
                    if cursor >= *after_sequence {
                        self.terminated = true;
                        return Some(Some(Err(error.clone())));
                    }
                }
            }
        }
        None
    }
    /// Parks until woken or `SLOT_WAIT_BACKSTOP` elapses, unless there is already something
    /// to do.
    fn wait_for_wake(&self) {
        let state = self.slot.lock();
        self.slot.park();
        // Orders the `parked` publication before the re-checks below, so a publisher either
        // sees this slot as parked or this check sees its ring entry.
        fence(Ordering::SeqCst);
        let cursor = self.slot.cursor.load(Ordering::Acquire);
        if state.seed.is_some()
            || state.terminal.is_some()
            || self.slot.has_dropped(cursor)
            || (cursor != UNSEEDED_CURSOR && self.shared.ring.has(cursor))
        {
            self.slot.unpark();
            return;
        }
        let _guard = self
            .slot
            .available
            .wait_timeout(state, SLOT_WAIT_BACKSTOP)
            .unwrap_or_else(|poison| poison.into_inner())
            .0;
        self.slot.unpark();
    }
}
impl<T: Clone + Send + Sync + 'static> Drop for SubscriptionChangesStream<T> {
    /// Deregisters this subscriber and lets producers blocked on its buffer re-check
    /// capacity.
    fn drop(&mut self) {
        let Self { shared, slot, .. } = self;
        slot.unsubscribe();
        shared.ring.notify_space();
    }
}
impl<T: Clone + Send + Sync + 'static> SubscriptionBenchmarkStream<T> {
#[doc(hidden)]
pub async fn next(&mut self) -> Option<StreamResult<T>> {
if self.terminated {
return None;
}
loop {
if let Some(value) = self.pending.pop_front() {
return Some(Ok(value.as_ref().clone()));
}
if let Some(item) = self.poll_seed_or_terminal() {
return item;
}
let cursor = self.slot.cursor.load(Ordering::Acquire);
if cursor == UNSEEDED_CURSOR {
self.wait_for_wake().await;
continue;
}
if let Some(next_cursor) = self.slot.skip_dropped(cursor) {
self.slot.cursor.store(next_cursor, Ordering::Release);
self.shared.ring.notify_space();
continue;
}
if let Some(value) = self.drain_available(cursor) {
return Some(Ok(value.as_ref().clone()));
}
let published = self.shared.published_sequence.load(Ordering::Acquire);
if cursor <= published {
let oldest = self.shared.ring.oldest_available(published);
if cursor < oldest {
match self.shared.overflow {
SubscriptionOverflow::DropNew => {
self.slot.cursor.store(oldest, Ordering::Release);
self.shared.ring.notify_space();
continue;
}
SubscriptionOverflow::Backpressure | SubscriptionOverflow::Fail => {
self.terminated = true;
return Some(Err(overflow_error(self.shared.ring.logical_capacity)));
}
}
}
}
self.wait_for_wake().await;
}
}
#[doc(hidden)]
pub async fn count_changes(&mut self, target: u64) -> StreamResult<u64> {
let mut count = 0_u64;
while count < target {
if self.terminated {
return Err(StreamError::Failed(
"subscription stream ended before requested count".into(),
));
}
if !self.pending.is_empty() {
let drained = self.pending.len().min((target - count) as usize);
self.pending.drain(..drained);
count += drained as u64;
continue;
}
if let Some(item) = self.poll_seed_or_terminal() {
match item {
Some(Ok(_)) => {
count += 1;
continue;
}
Some(Err(error)) => return Err(error),
None => {
return Err(StreamError::Failed(
"subscription stream completed before requested count".into(),
));
}
}
}
let cursor = self.slot.cursor.load(Ordering::Acquire);
if cursor == UNSEEDED_CURSOR {
self.wait_for_wake().await;
continue;
}
if let Some(next_cursor) = self.slot.skip_dropped(cursor) {
self.slot.cursor.store(next_cursor, Ordering::Release);
self.shared.ring.notify_space();
continue;
}
if let Some(drained) = self.drain_available_count(cursor, (target - count) as usize) {
count += drained as u64;
continue;
}
let published = self.shared.published_sequence.load(Ordering::Acquire);
if cursor <= published {
let oldest = self.shared.ring.oldest_available(published);
if cursor < oldest {
return Err(overflow_error(self.shared.ring.logical_capacity));
}
}
self.wait_for_wake().await;
}
Ok(count)
}
fn drain_available(&mut self, start_cursor: u64) -> Option<Arc<T>> {
let mut cursor = start_cursor;
let first = self.shared.ring.load(cursor)?;
cursor = cursor.saturating_add(1);
let mut drained = 1_usize;
while drained < SUBSCRIPTION_DRAIN_BATCH {
if self.slot.has_dropped(cursor) || self.slot.terminal_blocks(cursor) {
break;
}
let Some(value) = self.shared.ring.load(cursor) else {
break;
};
self.pending.push_back(value);
cursor = cursor.saturating_add(1);
drained += 1;
}
self.slot.cursor.store(cursor, Ordering::Release);
self.shared.ring.notify_space();
Some(first)
}
fn drain_available_count(&mut self, start_cursor: u64, limit: usize) -> Option<usize> {
if self.slot.has_dropped(start_cursor) || self.slot.terminal_blocks(start_cursor) {
return None;
}
let published = self.shared.published_sequence.load(Ordering::Acquire);
if start_cursor > published {
return None;
}
let oldest = self.shared.ring.oldest_available(published);
if start_cursor < oldest {
return None;
}
let available = published.saturating_sub(start_cursor).saturating_add(1) as usize;
let limit = limit.min(SUBSCRIPTION_DRAIN_BATCH);
let drained = available.min(limit);
if drained == 0 {
return None;
}
self.slot.cursor.store(
start_cursor.saturating_add(drained as u64),
Ordering::Release,
);
self.shared.ring.notify_space();
Some(drained)
}
fn poll_seed_or_terminal(&mut self) -> Option<Option<StreamResult<T>>> {
let mut state = self.slot.lock();
if let Some(seed) = state.seed.take() {
return Some(Some(Ok(seed.as_ref().clone())));
}
let cursor = self.slot.cursor.load(Ordering::Acquire);
if let Some(terminal) = &mut state.terminal {
match terminal {
SubscriptionSlotTerminal::Complete {
final_sequence,
final_value,
} => {
if cursor >= *final_sequence {
if let Some(value) = final_value.take() {
return Some(Some(Ok(value.as_ref().clone())));
}
self.terminated = true;
return Some(None);
}
}
SubscriptionSlotTerminal::Error {
after_sequence,
error,
} => {
if cursor >= *after_sequence {
self.terminated = true;
return Some(Some(Err(error.clone())));
}
}
}
}
None
}
async fn wait_for_wake(&self) {
let notified = self.slot.async_available.notified();
tokio::pin!(notified);
notified.as_mut().enable();
{
let state = self.slot.lock();
self.slot.park();
fence(Ordering::SeqCst);
let cursor = self.slot.cursor.load(Ordering::Acquire);
if state.seed.is_some()
|| state.terminal.is_some()
|| self.slot.has_dropped(cursor)
|| (cursor != UNSEEDED_CURSOR && self.shared.ring.has(cursor))
{
self.slot.unpark();
return;
}
}
notified.await;
self.slot.unpark();
}
}
impl<T: Clone + Send + Sync + 'static> Drop for SubscriptionBenchmarkStream<T> {
    /// Deregisters the slot, as the regular changes stream does on drop.
    ///
    /// Without this, a dropped benchmark stream stays in the subscriber table as an active
    /// slot whose cursor never advances. Under `Backpressure` every later publish would then
    /// wait indefinitely for space once that slot's buffer filled.
    fn drop(&mut self) {
        self.slot.unsubscribe();
        self.shared.ring.notify_space();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Sink, stream::Materializer};
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, AtomicUsize},
        },
        thread,
        time::{Duration, Instant},
    };

    /// Pause between checks in the polling helpers below.
    ///
    /// Short enough to keep the tests responsive. Long enough that polling
    /// threads do not monopolise a core while the stream and actor threads
    /// they are waiting on try to make progress.
    const POLL_INTERVAL: Duration = Duration::from_millis(1);

    /// Waits for a materialized stream to finish and unwraps its result.
    fn wait<T>(completion: crate::StreamCompletion<T>) -> T {
        completion.wait().unwrap()
    }

    /// Polls `condition` until it returns `true` or `timeout` elapses.
    ///
    /// The condition is evaluated one final time after the deadline. A
    /// condition that became true just as time ran out is therefore still
    /// reported as satisfied. Returns the last observed value.
    fn wait_until<F>(timeout: Duration, mut condition: F) -> bool
    where
        F: FnMut() -> bool,
    {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if condition() {
                return true;
            }
            // Sleep rather than spin on `yield_now`. Tests run in parallel, and
            // a hot loop competes with the very threads whose progress it is
            // waiting for.
            thread::sleep(POLL_INTERVAL);
        }
        condition()
    }

    /// Blocks the calling (sink) thread until `gate` has been opened.
    fn block_until_open(gate: &AtomicBool) {
        while !gate.load(Ordering::SeqCst) {
            thread::sleep(POLL_INTERVAL);
        }
    }

    /// Opens a sink gate when dropped.
    ///
    /// Gated tests open their gate explicitly on the success path. This guard
    /// also opens it while unwinding from a failed assertion. Without it, a
    /// failing test would leave its sink blocked forever, along with the
    /// subscription feeding it.
    struct OpenGateOnDrop(Arc<AtomicBool>);

    impl Drop for OpenGateOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    /// `set`/`update` are acknowledged only after the new value is visible to
    /// both `get` and `get_cloned`.
    #[test]
    fn get_snapshot_and_acked_set_read_your_writes() {
        let subscription = Subscription::new(1_u64, 8, SubscriptionOverflow::Backpressure).unwrap();
        assert_eq!(*subscription.get(), 1);
        assert_eq!(subscription.get_cloned(), 1);
        subscription.set(2).unwrap();
        assert_eq!(*subscription.get(), 2);
        assert_eq!(subscription.get_cloned(), 2);
        subscription.update(|value| *value + 1).unwrap();
        assert_eq!(*subscription.get(), 3);
        assert_eq!(subscription.get_cloned(), 3);
    }

    /// Every backpressured subscriber observes the initial value, every
    /// write, and the terminal value, in order and without gaps.
    #[test]
    fn lossless_backpressure_subscribers_see_all_changes() {
        const SUBSCRIBERS: usize = 4;
        const WRITES: u64 = 128;
        let subscription =
            Subscription::new(0_u64, 256, SubscriptionOverflow::Backpressure).unwrap();
        let completions = (0..SUBSCRIBERS)
            .map(|_| subscription.changes().run_with(Sink::collect()).unwrap())
            .collect::<Vec<_>>();
        for value in 1..=WRITES {
            subscription.set(value).unwrap();
        }
        subscription.close_with(WRITES + 1).unwrap();
        // Initial value 0, writes 1..=WRITES, then the terminal WRITES + 1.
        let expected = (0..=WRITES + 1).collect::<Vec<_>>();
        for completion in completions {
            assert_eq!(wait(completion), expected);
        }
    }

    /// With a full buffer, a backpressured `set` does not return (and does
    /// not publish) until the slow subscriber frees capacity.
    #[test]
    fn backpressure_parks_producer_ack_until_capacity_returns() {
        let subscription = Subscription::new(0_u64, 1, SubscriptionOverflow::Backpressure).unwrap();
        let seen = Arc::new(Mutex::new(Vec::new()));
        let gate = Arc::new(AtomicBool::new(false));
        let sink_seen = Arc::clone(&seen);
        let sink_gate = Arc::clone(&gate);
        let completion = subscription
            .changes()
            .run_with(Sink::foreach(move |item| {
                sink_seen.lock().unwrap().push(item);
                block_until_open(&sink_gate);
            }))
            .unwrap();
        // Declared after `completion` so it is dropped first when unwinding.
        let _release = OpenGateOnDrop(Arc::clone(&gate));
        assert!(wait_until(Duration::from_secs(1), || {
            seen.lock().unwrap().as_slice() == [0]
        }));
        // The sink is parked on 0; value 1 fills the single-slot buffer.
        subscription.set(1).unwrap();
        let producer_subscription = subscription.clone();
        let completed = Arc::new(AtomicBool::new(false));
        let producer_completed = Arc::clone(&completed);
        let producer = thread::spawn(move || {
            producer_subscription.set(2).unwrap();
            producer_completed.store(true, Ordering::SeqCst);
        });
        assert!(
            !wait_until(Duration::from_millis(25), || completed.load(Ordering::SeqCst)),
            "set(2) was acknowledged while the subscriber buffer was full"
        );
        assert_eq!(*subscription.get(), 1);
        gate.store(true, Ordering::SeqCst);
        assert!(
            wait_until(Duration::from_secs(1), || completed.load(Ordering::SeqCst)),
            "set(2) was not acknowledged after the subscriber drained"
        );
        producer.join().unwrap();
        assert_eq!(*subscription.get(), 2);
        subscription.close_with(3).unwrap();
        wait(completion);
        assert_eq!(seen.lock().unwrap().as_slice(), [0, 1, 2, 3]);
    }

    /// Under `DropNew`, a full subscriber misses new values but still gets the
    /// terminal value. The producer and the mirror are unaffected.
    #[test]
    fn drop_new_policy_drops_only_full_subscribers() {
        let subscription = Subscription::new(0_u64, 1, SubscriptionOverflow::DropNew).unwrap();
        let seen = Arc::new(Mutex::new(Vec::new()));
        let gate = Arc::new(AtomicBool::new(false));
        let sink_seen = Arc::clone(&seen);
        let sink_gate = Arc::clone(&gate);
        let completion = subscription
            .changes()
            .run_with(Sink::foreach(move |item| {
                sink_seen.lock().unwrap().push(item);
                block_until_open(&sink_gate);
            }))
            .unwrap();
        // Declared after `completion` so it is dropped first when unwinding.
        let _release = OpenGateOnDrop(Arc::clone(&gate));
        assert!(wait_until(Duration::from_secs(1), || {
            seen.lock().unwrap().as_slice() == [0]
        }));
        subscription.set(1).unwrap();
        subscription.set(2).unwrap();
        subscription.close_with(3).unwrap();
        gate.store(true, Ordering::SeqCst);
        wait(completion);
        assert_eq!(seen.lock().unwrap().as_slice(), [0, 1, 3]);
        assert_eq!(*subscription.get(), 3);
    }

    /// Under `Fail`, overflowing a subscriber fails both the producer's `set`
    /// and that subscriber's stream with a buffer-overflow error.
    #[test]
    fn fail_policy_fails_full_subscriber_and_reports_overflow() {
        let subscription = Subscription::new(0_u64, 1, SubscriptionOverflow::Fail).unwrap();
        let seen = Arc::new(Mutex::new(Vec::new()));
        let gate = Arc::new(AtomicBool::new(false));
        let sink_seen = Arc::clone(&seen);
        let sink_gate = Arc::clone(&gate);
        let completion = subscription
            .changes()
            .run_with(Sink::foreach(move |item| {
                sink_seen.lock().unwrap().push(item);
                block_until_open(&sink_gate);
            }))
            .unwrap();
        // Declared after `completion` so it is dropped first when unwinding.
        let _release = OpenGateOnDrop(Arc::clone(&gate));
        assert!(wait_until(Duration::from_secs(1), || {
            seen.lock().unwrap().as_slice() == [0]
        }));
        subscription.set(1).unwrap();
        assert!(matches!(
            subscription.set(2),
            Err(StreamError::Failed(message)) if message.contains("subscription buffer overflow")
        ));
        gate.store(true, Ordering::SeqCst);
        assert!(matches!(
            completion.wait(),
            Err(StreamError::Failed(message)) if message.contains("subscription buffer overflow")
        ));
        assert_eq!(seen.lock().unwrap().as_slice(), [0, 1]);
        assert_eq!(*subscription.get(), 2);
    }

    /// The terminal value is delivered last. Subscribing after close yields
    /// only the terminal value.
    #[test]
    fn terminal_ordering_and_post_close_subscribe() {
        let subscription = Subscription::new(0_u64, 8, SubscriptionOverflow::Backpressure).unwrap();
        let completion = subscription.changes().run_with(Sink::collect()).unwrap();
        subscription.set(1).unwrap();
        subscription.close_with(9).unwrap();
        assert_eq!(wait(completion), vec![0, 1, 9]);
        let post_close = subscription.changes().run_collect().unwrap();
        assert_eq!(post_close, vec![9]);
    }

    /// After the completion of a running feed is dropped, a later
    /// backpressured `set` must still succeed rather than wedge the producer.
    ///
    /// NOTE(review): the sink never blocks and the ring has room for one
    /// value. So `set(1)` would presumably also succeed if the subscriber were
    /// still attached; this does not by itself prove the subscriber was
    /// removed.
    #[test]
    fn dropping_feed_source_cancels_and_unsubscribes() {
        let subscription = Subscription::new(0_u64, 1, SubscriptionOverflow::Backpressure).unwrap();
        let pulled = Arc::new(AtomicUsize::new(0));
        let sink_pulled = Arc::clone(&pulled);
        let completion = subscription
            .changes()
            .run_with(Sink::foreach(move |_| {
                sink_pulled.fetch_add(1, Ordering::SeqCst);
            }))
            .unwrap();
        assert!(wait_until(Duration::from_secs(1), || {
            pulled.load(Ordering::SeqCst) == 1
        }));
        drop(completion);
        assert!(wait_until(Duration::from_secs(1), || subscription
            .set(1)
            .is_ok()));
    }

    /// Dropping the last `Subscription` handle terminates its actor. A feed
    /// still waiting for values then fails with `ActorTerminated`.
    #[test]
    fn actor_death_fails_feed() {
        let subscription = Subscription::new(0_u64, 8, SubscriptionOverflow::Backpressure).unwrap();
        let materializer = Materializer::new();
        let completion = subscription
            .changes()
            .drop(1)
            .run_with_materializer(Sink::head(), &materializer)
            .unwrap();
        drop(subscription);
        match completion.wait() {
            Err(StreamError::ActorTerminated) => {}
            other => panic!("expected actor termination, got {other:?}"),
        }
    }
}