use futures::{
stream::{Fuse, FusedStream},
Stream, StreamExt as _,
};
use pin_project::pin_project;
use std::{
collections::VecDeque,
pin::Pin,
sync::{Arc, Mutex},
task::{Poll, Waker},
};
#[cfg(test)]
mod tests;
/// Extension trait that adds [`fork`](StreamExt::fork) to every [`Stream`].
pub trait StreamExt: Stream {
    /// Turns this stream into a [`Forked`] handle that can be cloned into several forks.
    ///
    /// Every item the source produces is buffered for each fork that exists at that
    /// moment. The last fork to read an item receives the buffered value itself; the
    /// others receive clones. Items stay buffered until every fork has read them, so a
    /// fork that is never polled makes the buffer grow.
    fn fork(self) -> Forked<Self>
    where
        Self: Sized,
        Self::Item: Clone,
    {
        ForkedInner::<Self>::init(self)
    }
}

impl<T: Stream> StreamExt for T {}
/// One branch of a forked stream, created by [`StreamExt::fork`].
///
/// A fork created by `fork` or by `Weak::upgrade` starts at the current end of the
/// shared buffer. A fork created by `clone` starts at the read position of the fork it
/// was cloned from. Buffered items are kept until every fork has read them.
#[must_use = "streams do nothing unless polled"]
pub struct Forked<S: Stream>
where
S::Item: Clone,
{
/// This fork's read position in the shared buffer. Only `None` while the fork is being
/// dropped (the `Drop` impl takes it out).
subscriber: Option<BufferSubscriber>,
/// State shared by all forks of the same source: the source itself, the buffer and
/// the wakers of waiting forks.
inner: Arc<Mutex<Pin<Box<ForkedInner<S>>>>>,
}
impl<S: Stream> Clone for Forked<S>
where
S::Item: Clone,
{
fn clone(&self) -> Self {
ForkedInner::create_fork(
self.inner.clone(),
Some(self.subscriber.as_ref().expect("only unset during drop")),
)
}
}
impl<S: Stream> Stream for Forked<S>
where
S::Item: Clone,
{
type Item = S::Item;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let this = Pin::get_mut(self);
let mut inner = this.inner.lock().unwrap();
inner.as_mut().poll(
this.subscriber.as_mut().expect("only unset during drop"),
cx.waker(),
)
}
fn size_hint(&self) -> (usize, Option<usize>) {
let inner = self.inner.lock().unwrap();
inner.size_hint(self.subscriber.as_ref().expect("only unset during drop"))
}
}
impl<S: Stream> FusedStream for Forked<S>
where
    S::Item: Clone,
{
    /// True once the source has ended and this fork has read all of its buffered items.
    fn is_terminated(&self) -> bool {
        self.inner
            .lock()
            .unwrap()
            .is_terminated(self.subscriber.as_ref().expect("only unset during drop"))
    }
}
impl<S: Stream> Drop for Forked<S>
where
    S::Item: Clone,
{
    /// Releases this fork's claim on the shared buffer and wakes every parked fork.
    ///
    /// Waking the parked forks matters when this fork was the one that last polled the
    /// source. The source may then hold only this fork's waker, so without a wake-up no
    /// remaining fork would poll the source again. The wakers are invoked after the lock
    /// is released.
    ///
    /// A poisoned mutex is recovered instead of unwrapped. `drop` typically runs while
    /// unwinding from the very panic that poisoned the lock (for example the source
    /// panicking inside `poll_next`), and a second panic there would abort the process.
    fn drop(&mut self) {
        let mut inner = self
            .inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let this = inner.as_mut().project();
        this.buffer
            .dispose_of_subscriber(self.subscriber.take().expect("only unset during drop"));
        // Parked forks wake up, find nothing new in the buffer and poll the source
        // themselves, so one of them takes over from this fork.
        let parked = std::mem::take(this.waiting_for_buffer);
        drop(inner);
        for waker in parked {
            waker.wake();
        }
    }
}
/// State shared by all forks of one source, guarded by the mutex held in [`Forked`].
#[pin_project]
struct ForkedInner<S: Stream>
where
S::Item: Clone,
{
/// The wrapped source, fused (`futures::stream::Fuse`) so it is safe to keep polling
/// after it has ended.
#[pin]
source: Fuse<S>,
/// Items produced by the source that at least one fork has not read yet.
buffer: Buffer<S::Item>,
/// Waker of the fork that most recently polled `source`.
waiting_for_source: Option<Waker>,
/// Wakers displaced from `waiting_for_source` by another fork's poll. They stay parked
/// here until they are woken.
waiting_for_buffer: VecDeque<Waker>,
}
impl<S: Stream> ForkedInner<S>
where
    S::Item: Clone,
{
    /// Wraps `source` in freshly allocated shared state and returns its first fork.
    ///
    /// The first fork starts at the (still empty) end of the buffer, so it observes every
    /// item the source produces.
    fn init(source: S) -> Forked<S> {
        let inner = Self {
            source: source.fuse(),
            buffer: Buffer::new(),
            waiting_for_source: None,
            waiting_for_buffer: VecDeque::new(),
        };
        let arc = Arc::new(Mutex::new(Box::pin(inner)));
        Self::create_fork(arc, None)
    }

    /// Polls for the next item of the fork whose read position is `subscriber`.
    ///
    /// If an item is already buffered for `subscriber`, it is returned immediately.
    /// Otherwise this fork polls the source itself:
    /// - `waker` becomes `waiting_for_source`.
    /// - The waker registered there before is parked in `waiting_for_buffer`, unless an
    ///   equivalent waker is already parked.
    /// - As soon as the source yields an item or ends, *all* parked wakers are woken, so
    ///   the other forks can read the new item from the buffer or observe the end.
    ///
    /// Waking every parked fork, instead of chaining wake-ups one fork at a time, means
    /// that a fork dropped right after being woken cannot strand the forks behind it.
    fn poll(
        self: Pin<&mut Self>,
        subscriber: &mut BufferSubscriber,
        waker: &Waker,
    ) -> Poll<Option<S::Item>> {
        let this = self.project();
        if let Some(item) = this.buffer.read(subscriber) {
            return Poll::Ready(Some(item));
        }

        if let Some(prev_waker) = this.waiting_for_source.replace(waker.clone()) {
            // Without this check every `Pending` re-poll would park another copy of the
            // same waker, growing `waiting_for_buffer` without bound.
            let already_parked = this
                .waiting_for_buffer
                .iter()
                .any(|parked| parked.will_wake(&prev_waker));
            if !already_parked {
                this.waiting_for_buffer.push_back(prev_waker);
            }
        }

        let result = match Stream::poll_next(
            this.source,
            &mut std::task::Context::from_waker(waker),
        ) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Some(item)) => {
                this.buffer.push(item);
                let item = this
                    .buffer
                    .read(subscriber)
                    .expect("the item was just pushed into the buffer");
                Some(item)
            }
            Poll::Ready(None) => None,
        };

        // The source answered this poll, so the current task no longer needs to be
        // registered with it. Every other waiting fork is parked and gets woken here.
        *this.waiting_for_source = None;
        for parked in this.waiting_for_buffer.drain(..) {
            parked.wake();
        }
        Poll::Ready(result)
    }

    /// Whether the fork at `subscriber` can never yield another item. That is the case
    /// when the source has ended and nothing is left in the buffer for this fork.
    fn is_terminated(&self, subscriber: &BufferSubscriber) -> bool {
        self.source.is_terminated() && !self.buffer.has_items_for(subscriber)
    }

    /// Size hint for the fork at `subscriber`: its buffered items plus the source's hint.
    ///
    /// Saturating/checked addition keeps this from overflowing (and panicking in debug
    /// builds) when the source reports very large bounds, e.g. an endless stream whose
    /// lower bound is `usize::MAX`.
    fn size_hint(&self, subscriber: &BufferSubscriber) -> (usize, Option<usize>) {
        let (source_lower, source_upper) = self.source.size_hint();
        let buffered = self.buffer.num_items_for(subscriber);
        (
            buffered.saturating_add(source_lower),
            source_upper.and_then(|upper| upper.checked_add(buffered)),
        )
    }

    /// Registers a new subscriber in the shared buffer and wraps it in a [`Forked`].
    ///
    /// With `Some(s)`, the new fork starts at `s`'s read position, so it also yields
    /// everything `s` has not read yet. With `None`, it starts at the current end of the
    /// buffer.
    fn create_fork(
        this: Arc<Mutex<Pin<Box<Self>>>>,
        with_offset_of: Option<&BufferSubscriber>,
    ) -> Forked<S> {
        let subscriber = this
            .lock()
            .unwrap()
            .as_mut()
            .project()
            .buffer
            .create_subscriber(with_offset_of);
        Forked {
            subscriber: Some(subscriber),
            inner: this,
        }
    }
}
impl<S: Stream> Forked<S>
where
    S::Item: Clone,
{
    /// Returns a [`Weak`] handle to the state shared by this fork and its siblings.
    ///
    /// The handle does not keep that state alive. [`Weak::upgrade`] can create a new
    /// fork only while at least one `Forked` still exists.
    pub fn downgrade(&self) -> Weak<S> {
        let shared = Arc::downgrade(&self.inner);
        Weak(shared)
    }
}
/// A non-owning handle to the state shared by a set of forks, created by
/// [`Forked::downgrade`]. Holding it does not keep the source or its buffer alive.
pub struct Weak<S: Stream>(std::sync::Weak<Mutex<Pin<Box<ForkedInner<S>>>>>)
where
S::Item: Clone;
impl<S: Stream> Weak<S>
where
    S::Item: Clone,
{
    /// Creates a new fork if the shared state is still alive.
    ///
    /// The new fork starts at the current end of the buffer, so it only yields items the
    /// source produces from now on. Returns `None` once every `Forked` has been dropped.
    pub fn upgrade(&self) -> Option<Forked<S>> {
        let shared = self.0.upgrade()?;
        Some(ForkedInner::create_fork(shared, None))
    }
}
/// A queue of items shared by several subscribers, with per-entry reference counts.
///
/// Each entry records how many subscribers still have to read it. Once that count
/// drops to zero the entry is removed from the front. Positions are absolute:
/// `buffer_start_offset` is the position of `buf[0]`, and a subscriber's `offset` is the
/// position of the next entry it will read.
struct Buffer<T: Clone> {
    buf: VecDeque<BufferEntry<T>>,
    buffer_start_offset: usize,
    num_subscribers: usize,
}

/// A subscriber's absolute read position in a [`Buffer`].
struct BufferSubscriber {
    offset: usize,
}

/// One buffered item plus the number of subscribers that have not read it yet.
#[cfg_attr(test, derive(Debug))]
struct BufferEntry<T> {
    item: T,
    refs: usize,
}

impl<T: Clone> Buffer<T> {
    /// Creates an empty buffer with no subscribers.
    fn new() -> Self {
        Buffer {
            buf: VecDeque::new(),
            buffer_start_offset: 0,
            num_subscribers: 0,
        }
    }

    /// Adds a subscriber and returns its read position.
    ///
    /// With `Some(template)`, the subscriber starts where `template` is and takes a
    /// reference on every entry `template` has not read yet. With `None`, it starts at
    /// the end and only sees entries pushed afterwards.
    fn create_subscriber(&mut self, with_offset_of: Option<&BufferSubscriber>) -> BufferSubscriber {
        self.num_subscribers += 1;
        let offset = match with_offset_of {
            Some(template) => {
                let first_unread = template.offset - self.buffer_start_offset;
                self.buf
                    .iter_mut()
                    .skip(first_unread)
                    .for_each(|entry| entry.refs += 1);
                template.offset
            }
            None => self.buffer_start_offset + self.buf.len(),
        };
        BufferSubscriber { offset }
    }

    /// Appends `item`, to be read once by every current subscriber.
    fn push(&mut self, item: T) {
        let refs = self.num_subscribers;
        self.buf.push_back(BufferEntry { item, refs });
    }

    /// Returns the next item for `subscriber` and advances it, or `None` if it has
    /// already read everything buffered.
    ///
    /// The last subscriber to read an entry receives the item itself (the entry is
    /// removed); everyone else receives a clone.
    fn read(&mut self, subscriber: &mut BufferSubscriber) -> Option<T> {
        let index = subscriber.offset - self.buffer_start_offset;
        let entry = self.buf.get_mut(index)?;
        entry.refs -= 1;
        let item = if entry.refs > 0 {
            entry.item.clone()
        } else {
            // Entries earlier in `buf` are needed by no more subscribers than this one,
            // so an entry reaching zero references is always the front entry.
            let popped = self
                .buf
                .pop_front()
                .expect("there is at least one item in the buffer");
            self.buffer_start_offset += 1;
            popped.item
        };
        subscriber.offset += 1;
        Some(item)
    }

    /// Whether at least one buffered entry is still unread by `subscriber`.
    fn has_items_for(&self, subscriber: &BufferSubscriber) -> bool {
        let already_read = subscriber.offset - self.buffer_start_offset;
        self.buf.len() > already_read
    }

    /// Number of buffered entries `subscriber` has not read yet.
    fn num_items_for(&self, subscriber: &BufferSubscriber) -> usize {
        let already_read = subscriber.offset - self.buffer_start_offset;
        self.buf.len() - already_read
    }

    /// Removes `subscriber`, releasing its reference on every entry it had not read and
    /// discarding entries nobody needs any more.
    fn dispose_of_subscriber(&mut self, subscriber: BufferSubscriber) {
        self.num_subscribers -= 1;
        let first_unread = subscriber.offset - self.buffer_start_offset;
        let released = self
            .buf
            .iter_mut()
            .skip(first_unread)
            .fold(0usize, |released, entry| {
                entry.refs -= 1;
                released + usize::from(entry.refs == 0)
            });
        // Zero-reference entries always form a prefix of `buf`.
        self.buf.drain(..released);
        self.buffer_start_offset += released;
    }
}