use std::{
fmt, ops,
pin::Pin,
task::{ready, Context, Poll},
};
use futures_core::Stream;
use imbl::Vector;
use tokio::sync::broadcast::{
self,
error::{RecvError, TryRecvError},
Receiver, Sender,
};
use tokio_util::sync::ReusableBoxFuture;
#[cfg(feature = "tracing")]
use tracing::info;
mod entry;
pub use self::entry::{ObservableVectorEntries, ObservableVectorEntry};
/// An ordered list of elements that broadcasts every change made through its
/// mutating methods to all current subscribers.
pub struct ObservableVector<T> {
    /// The current contents, readable through `Deref`.
    values: Vector<T>,
    /// Sends one `BroadcastMessage` (the diff plus a snapshot of the contents
    /// taken after the change) per mutation.
    sender: Sender<BroadcastMessage<T>>,
}
impl<T: Clone + Send + Sync + 'static> ObservableVector<T> {
    /// Create an empty `ObservableVector` with a broadcast capacity of 16.
    ///
    /// See [`with_capacity`](Self::with_capacity) for what the capacity controls.
    pub fn new() -> Self {
        Self::with_capacity(16)
    }

    /// Create an empty `ObservableVector` with the given broadcast capacity.
    ///
    /// The capacity bounds how many diffs are buffered per subscriber; a
    /// subscriber that falls further behind receives a [`VectorDiff::Reset`]
    /// instead of the individual diffs it missed.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is 0, because `tokio::sync::broadcast::channel`
    /// rejects a zero capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        // The receiver created alongside the sender is not needed: subscribers
        // obtain their own receivers through `subscribe`.
        let (sender, _) = broadcast::channel(capacity);
        let values = Vector::new();
        Self { values, sender }
    }

    /// Consume the `ObservableVector`, returning its current contents.
    pub fn into_inner(self) -> Vector<T> {
        let Self { values, .. } = self;
        values
    }

    /// Create a subscriber that yields a [`VectorDiff`] for every change made
    /// after this call.
    ///
    /// The subscriber does not receive the current contents; read them through
    /// `Deref` beforehand if needed.
    pub fn subscribe(&self) -> VectorSubscriber<T> {
        VectorSubscriber::new(self.sender.subscribe())
    }

    /// Append `values` at the end and broadcast a [`VectorDiff::Append`].
    pub fn append(&mut self, values: Vector<T>) {
        let appended = values.clone();
        self.values.append(appended);
        self.broadcast_diff(VectorDiff::Append { values });
    }

    /// Remove every element and broadcast a [`VectorDiff::Clear`].
    ///
    /// The diff is broadcast even if the vector was already empty.
    pub fn clear(&mut self) {
        self.values.clear();
        self.broadcast_diff(VectorDiff::Clear);
    }

    /// Insert `value` at the front and broadcast a [`VectorDiff::PushFront`].
    pub fn push_front(&mut self, value: T) {
        let stored = value.clone();
        self.values.push_front(stored);
        self.broadcast_diff(VectorDiff::PushFront { value });
    }

    /// Insert `value` at the back and broadcast a [`VectorDiff::PushBack`].
    pub fn push_back(&mut self, value: T) {
        let stored = value.clone();
        self.values.push_back(stored);
        self.broadcast_diff(VectorDiff::PushBack { value });
    }

    /// Remove and return the first element.
    ///
    /// A [`VectorDiff::PopFront`] is broadcast only if an element was removed;
    /// an empty vector returns `None` without broadcasting anything.
    pub fn pop_front(&mut self) -> Option<T> {
        let popped = self.values.pop_front()?;
        self.broadcast_diff(VectorDiff::PopFront);
        Some(popped)
    }

    /// Remove and return the last element.
    ///
    /// A [`VectorDiff::PopBack`] is broadcast only if an element was removed;
    /// an empty vector returns `None` without broadcasting anything.
    pub fn pop_back(&mut self) -> Option<T> {
        let popped = self.values.pop_back()?;
        self.broadcast_diff(VectorDiff::PopBack);
        Some(popped)
    }

    /// Insert `value` at position `index` and broadcast a [`VectorDiff::Insert`].
    ///
    /// # Panics
    ///
    /// Panics if `index > len`.
    #[track_caller]
    pub fn insert(&mut self, index: usize, value: T) {
        let len = self.values.len();
        if index > len {
            panic!("index out of bounds: the length is {len} but the index is {index}");
        }
        self.values.insert(index, value.clone());
        self.broadcast_diff(VectorDiff::Insert { index, value });
    }

    /// Replace the element at `index` with `value`, broadcast a
    /// [`VectorDiff::Set`], and return the previous element.
    ///
    /// # Panics
    ///
    /// Panics if `index >= len`.
    #[track_caller]
    pub fn set(&mut self, index: usize, value: T) -> T {
        let len = self.values.len();
        if index >= len {
            panic!("index out of bounds: the length is {len} but the index is {index}");
        }
        let previous = self.values.set(index, value.clone());
        self.broadcast_diff(VectorDiff::Set { index, value });
        previous
    }

    /// Remove and return the element at `index`, broadcasting a
    /// [`VectorDiff::Remove`].
    ///
    /// # Panics
    ///
    /// Panics if `index >= len`.
    #[track_caller]
    pub fn remove(&mut self, index: usize) -> T {
        let len = self.values.len();
        if index >= len {
            panic!("index out of bounds: the length is {len} but the index is {index}");
        }
        let removed = self.values.remove(index);
        self.broadcast_diff(VectorDiff::Remove { index });
        removed
    }

    /// Get an [`ObservableVectorEntry`] for the element at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= len`.
    #[track_caller]
    pub fn entry(&mut self, index: usize) -> ObservableVectorEntry<'_, T> {
        let len = self.values.len();
        if index >= len {
            panic!("index out of bounds: the length is {len} but the index is {index}");
        }
        ObservableVectorEntry::new(self, index)
    }

    /// Call `f` with each entry produced by [`entries`](Self::entries), in turn.
    pub fn for_each(&mut self, mut f: impl FnMut(ObservableVectorEntry<'_, T>)) {
        // `entries` is driven via `next()` explicitly rather than a `for` loop,
        // since each yielded entry borrows from it.
        let mut cursor = self.entries();
        while let Some(entry) = cursor.next() {
            f(entry);
        }
    }

    /// Get an [`ObservableVectorEntries`] for walking the elements with
    /// mutable, change-broadcasting access.
    pub fn entries(&mut self) -> ObservableVectorEntries<'_, T> {
        ObservableVectorEntries::new(self)
    }

    /// Send `diff` together with a snapshot of the current contents to every
    /// subscriber.
    fn broadcast_diff(&self, diff: VectorDiff<T>) {
        // Nobody is listening: skip cloning the snapshot altogether.
        if self.sender.receiver_count() == 0 {
            return;
        }

        let state = self.values.clone();
        // Receivers may have been dropped since the check above, in which case
        // `send` fails; that is not an error for us.
        let _num_receivers = self.sender.send(BroadcastMessage { diff, state }).unwrap_or(0);
        #[cfg(feature = "tracing")]
        tracing::debug!("New observable value broadcast to {_num_receivers} receivers");
    }
}
impl<T: Clone + Send + Sync + 'static> Default for ObservableVector<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: fmt::Debug> fmt::Debug for ObservableVector<T> {
    /// Formats the current values; the broadcast sender is left out, which is
    /// signalled by the non-exhaustive `..` marker.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ObservableVector");
        builder.field("values", &self.values);
        builder.finish_non_exhaustive()
    }
}
impl<T> ops::Deref for ObservableVector<T> {
    type Target = Vector<T>;

    /// Shared (read-only) access to the current contents.
    fn deref(&self) -> &Vector<T> {
        let Self { values, .. } = self;
        values
    }
}
impl<T: Clone + Send + Sync + 'static> From<Vector<T>> for ObservableVector<T> {
    /// Wrap an existing `Vector` in an `ObservableVector` with the default
    /// broadcast capacity (see [`ObservableVector::new`]).
    ///
    /// No diff is broadcast: the new value cannot have any subscribers yet.
    fn from(values: Vector<T>) -> Self {
        // Move `values` in directly instead of appending a clone of it to an
        // empty vector; only the sender is taken from the fresh instance.
        Self { values, ..Self::new() }
    }
}
/// Payload sent through the broadcast channel for every change.
#[derive(Clone)]
struct BroadcastMessage<T> {
    /// The change that was applied.
    diff: VectorDiff<T>,
    /// Snapshot of the vector taken after `diff` was applied; a subscriber that
    /// lagged behind uses the newest snapshot to emit a `VectorDiff::Reset`.
    state: Vector<T>,
}
/// A stream of [`VectorDiff`]s describing changes to an [`ObservableVector`].
///
/// Created by [`ObservableVector::subscribe`]. If the subscriber falls behind
/// the channel capacity, it yields a [`VectorDiff::Reset`] carrying a full
/// snapshot instead of the individual diffs it missed.
#[derive(Debug)]
pub struct VectorSubscriber<T> {
    /// The pending `recv` future. It owns the broadcast receiver and returns it
    /// on completion so the next future can be built from it.
    inner: ReusableBoxFuture<'static, SubscriberFutureReturn<BroadcastMessage<T>>>,
}
impl<T: Clone + Send + Sync + 'static> VectorSubscriber<T> {
    /// Wrap a broadcast receiver, boxing the first `recv` future for it
    /// (nothing is received until the stream is polled).
    fn new(rx: Receiver<BroadcastMessage<T>>) -> Self {
        let inner = ReusableBoxFuture::new(make_future(rx));
        Self { inner }
    }
}
impl<T: Clone + Send + Sync + 'static> Stream for VectorSubscriber<T> {
type Item = VectorDiff<T>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (result, mut rx) = ready!(self.inner.poll(cx));
let poll = match result {
Ok(msg) => Poll::Ready(Some(msg.diff)),
Err(RecvError::Closed) => Poll::Ready(None),
Err(RecvError::Lagged(_)) => {
let mut msg = None;
loop {
match rx.try_recv() {
Ok(m) => {
msg = Some(m);
}
Err(TryRecvError::Closed) => {
#[cfg(feature = "tracing")]
info!("Channel closed after lag, can't return last state");
break Poll::Ready(None);
}
Err(TryRecvError::Lagged(_)) => {}
Err(TryRecvError::Empty) => match msg {
Some(msg) => {
break Poll::Ready(Some(VectorDiff::Reset { values: msg.state }));
}
None => unreachable!("got no new message via try_recv after lag"),
},
}
}
}
};
self.inner.set(make_future(rx));
poll
}
}
/// Output of `make_future`: the result of one `recv` call, paired with the
/// receiver it was read from so that the receiver can be reused.
type SubscriberFutureReturn<T> = (Result<T, RecvError>, Receiver<T>);
/// Await the next message on `rx`, then hand the receiver back together with
/// the outcome.
///
/// Taking the receiver by value (instead of borrowing it) keeps the future
/// `'static`, as required for storing it in a `ReusableBoxFuture<'static, _>`.
async fn make_future<T: Clone>(mut rx: Receiver<T>) -> SubscriberFutureReturn<T> {
    let outcome = rx.recv().await;
    (outcome, rx)
}
/// A single change to an [`ObservableVector`], as yielded by a
/// [`VectorSubscriber`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum VectorDiff<T> {
    /// `values` were appended at the end.
    Append { values: Vector<T> },
    /// All elements were removed.
    Clear,
    /// `value` was inserted at the front.
    PushFront { value: T },
    /// `value` was inserted at the back.
    PushBack { value: T },
    /// The first element was removed.
    PopFront,
    /// The last element was removed.
    PopBack,
    /// `value` was inserted at position `index`.
    Insert { index: usize, value: T },
    /// The element at position `index` was replaced by `value`.
    Set { index: usize, value: T },
    /// The element at position `index` was removed.
    Remove { index: usize },
    /// The subscriber missed updates; `values` is a full snapshot to
    /// resynchronize from.
    Reset { values: Vector<T> },
}
impl<T: Clone> VectorDiff<T> {
    /// Transform every element carried by this diff with `f`.
    ///
    /// The variant and any `index` are preserved; variants without elements
    /// are carried over unchanged and never call `f`.
    pub fn map<U: Clone>(self, mut f: impl FnMut(T) -> U) -> VectorDiff<U> {
        match self {
            // Variants carrying a whole vector of elements.
            Self::Append { values } => VectorDiff::Append { values: vector_map(values, f) },
            Self::Reset { values } => VectorDiff::Reset { values: vector_map(values, f) },
            // Variants carrying a single element.
            Self::PushFront { value } => VectorDiff::PushFront { value: f(value) },
            Self::PushBack { value } => VectorDiff::PushBack { value: f(value) },
            Self::Insert { index, value } => VectorDiff::Insert { index, value: f(value) },
            Self::Set { index, value } => VectorDiff::Set { index, value: f(value) },
            // Variants carrying no element.
            Self::Clear => VectorDiff::Clear,
            Self::PopFront => VectorDiff::PopFront,
            Self::PopBack => VectorDiff::PopBack,
            Self::Remove { index } => VectorDiff::Remove { index },
        }
    }
}
/// Apply `f` to each element of `v`, front to back, collecting the results
/// into a new `Vector`.
fn vector_map<T: Clone, U: Clone>(v: Vector<T>, f: impl FnMut(T) -> U) -> Vector<U> {
    let mapped = v.into_iter().map(f);
    mapped.collect::<Vector<U>>()
}