use std::{
any::Any,
error::Error,
future::Future,
pin::Pin,
sync::{Arc, Mutex, RwLock},
thread::JoinHandle as ThreadJoinHandle,
};
use tokio::runtime;
use tokio::task::{self, JoinHandle};
use crate::observer::Observer;
/// Access to a value's fuse state as a `(fused, defused)` pair.
///
/// Both methods have inert defaults: `set_fused` discards its arguments and
/// `get_fused` reports `(false, false)`, so only implementors that actually
/// track the state need to override them.
#[doc(hidden)]
pub trait Fuse {
    /// Stores the fuse flags; the default implementation ignores them.
    fn set_fused(&mut self, _fused: bool, _defused: bool) {}

    /// Returns the fuse flags; the default implementation reports neither as set.
    fn get_fused(&self) -> (bool, bool) {
        <(bool, bool)>::default()
    }
}
/// A source that subscribers can attach to.
///
/// Has [`Fuse`] as a supertrait, so every implementor also exposes the
/// fuse-state accessors.
pub trait Subscribeable: Fuse {
    /// Type of the values delivered to subscribers.
    type ObsType;

    /// Attaches `s` to this source and returns the resulting [`Subscription`].
    fn subscribe(&mut self, s: Subscriber<Self::ObsType>) -> Subscription;

    /// Reports whether this source is a subject; the default answers `true`.
    fn is_subject(&self) -> bool {
        true
    }

    /// Internal setter for the subject indicator; the default ignores its argument.
    #[doc(hidden)]
    fn set_subject_indicator(&mut self, _is_subject: bool) {}
}
/// Types whose teardown can be triggered explicitly.
pub trait Unsubscribeable {
/// Runs the teardown. Takes `self` by value, so a given value can be
/// unsubscribed at most once.
fn unsubscribe(self);
}
/// Boxed callback that receives each value passed to `Observer::next`.
type NextFn<T> = Box<dyn FnMut(T) + Send>;
/// Boxed callback run by `Observer::complete`.
type CompleteFn = Box<dyn FnMut() + Send + Sync>;
/// Boxed callback that receives the error passed to `Observer::error`.
type ErrorFn = Box<dyn FnMut(Arc<dyn Error + Send + Sync>) + Send + Sync>;
/// Consumer of an observable's notifications.
///
/// Holds a required `next` callback, optional `error` and `complete`
/// callbacks, and the state flags consulted by its `Observer` implementation.
// The struct carries several boolean flags, which trips this pedantic lint.
#[allow(clippy::struct_excessive_bools)]
pub struct Subscriber<NextFnType> {
// Called with every value that passes the `Observer::next` checks.
next_fn: NextFn<NextFnType>,
// `None` unless supplied via `Subscriber::new` or `Subscriber::on_complete`.
complete_fn: Option<CompleteFn>,
// `None` unless supplied via `Subscriber::new` or `Subscriber::on_error`.
error_fn: Option<ErrorFn>,
// Written by `Observer::complete`; together with `fused` it suppresses later notifications.
completed: bool,
// Fuse flag; `Subscriber::is_fused` reports `fused && !defused`.
// NOTE(review): the `Observer` impl reads `fused` directly and does not consult `defused`.
pub(crate) fused: bool,
// Written by `Fuse::set_fused` together with `fused`.
pub(crate) defused: bool,
// When set, `Observer::next` still delivers values after an error or a fused completion.
pub(crate) take_wrapped: bool,
// Written by `Observer::error`; once set, later notifications are dropped
// (except `next` when `take_wrapped` is set).
errored: bool,
}
impl<NextFnType> Subscriber<NextFnType> {
pub fn new(
next_fn: impl FnMut(NextFnType) + 'static + Send,
error_fn: impl FnMut(Arc<dyn Error + Send + Sync>) + 'static + Send + Sync,
complete_fn: impl FnMut() + 'static + Send + Sync,
) -> Self {
Subscriber {
next_fn: Box::new(next_fn),
complete_fn: Some(Box::new(complete_fn)),
error_fn: Some(Box::new(error_fn)),
completed: false,
fused: false,
defused: false,
take_wrapped: false,
errored: false,
}
}
pub fn on_next(next_fn: impl FnMut(NextFnType) + 'static + Send) -> Self {
Subscriber {
next_fn: Box::new(next_fn),
complete_fn: None,
error_fn: None,
completed: false,
fused: false,
defused: false,
take_wrapped: false,
errored: false,
}
}
pub fn on_complete(&mut self, complete_fn: impl FnMut() + 'static + Send + Sync) {
self.complete_fn = Some(Box::new(complete_fn));
}
pub fn on_error(
&mut self,
error_fn: impl FnMut(Arc<dyn Error + Send + Sync>) + 'static + Send + Sync,
) {
self.error_fn = Some(Box::new(error_fn));
}
#[must_use]
pub fn is_fused(&self) -> bool {
self.fused && !self.defused
}
}
impl<T> Fuse for Subscriber<T> {
    /// Stores both flags verbatim.
    fn set_fused(&mut self, fused: bool, defused: bool) {
        self.defused = defused;
        self.fused = fused;
    }

    /// Reports the stored `(fused, defused)` pair.
    fn get_fused(&self) -> (bool, bool) {
        let Self { fused, defused, .. } = *self;
        (fused, defused)
    }
}
impl<T> Observer for Subscriber<T> {
    type NextFnType = T;

    /// Forwards `v` to the next callback.
    ///
    /// The value is dropped once the subscriber has errored, or once it is
    /// fused and has completed — unless `take_wrapped` is set, which bypasses
    /// both checks.
    fn next(&mut self, v: Self::NextFnType) {
        if !self.take_wrapped && (self.errored || (self.fused && self.completed)) {
            return;
        }
        (self.next_fn)(v);
    }

    /// Runs the completion callback (if any) and records completion.
    ///
    /// Ignored after an error, or when the subscriber is fused and has
    /// already completed.
    fn complete(&mut self) {
        if self.errored || (self.fused && self.completed) {
            return;
        }
        if let Some(cfn) = &mut self.complete_fn {
            (cfn)();
        }
        // Recorded even when no completion callback is installed, so fusing
        // also works for subscribers built with `Subscriber::on_next`.
        self.completed = true;
    }

    /// Runs the error callback (if any) and records the terminal error.
    ///
    /// Ignored after a previous error, or when the subscriber is fused and
    /// has already completed.
    fn error(&mut self, observable_error: Arc<dyn Error + Send + Sync>) {
        if self.errored || (self.fused && self.completed) {
            return;
        }
        if let Some(efn) = &mut self.error_fn {
            (efn)(observable_error);
        }
        // Recorded even when no error callback is installed, so no further
        // values or completion are delivered after a terminal error.
        self.errored = true;
    }
}
/// Result of joining a thread or task; the error side is a boxed payload
/// (for example a thread's panic payload or a boxed Tokio `JoinError`).
type AwaitResult<T> = Result<T, Box<dyn Any + Send>>;
/// A shared, joinable group of subscriptions.
pub struct SubscriptionCollection {
// Shared list of subscriptions; `join_all` / `join_all_async` take it and
// leave a sentinel subscription in its place.
subscriptions: Arc<Mutex<Option<Vec<Subscription>>>>,
// Raised once the sentinel subscription is unsubscribed; joiners then
// unsubscribe pending entries instead of waiting on them.
signal_sent: Arc<Mutex<bool>>,
// In `join_all_async`, selects a Tokio task (`true`) or an OS thread
// (`false`) as the listener for the sentinel's signal.
pub(crate) use_task: bool,
}
impl SubscriptionCollection {
    /// Wraps the shared subscription list `s`.
    ///
    /// `use_task` selects, in `join_all_async`, whether the unsubscribe
    /// signal is listened for on a Tokio task (`true`) or an OS thread (`false`).
    pub(crate) fn new(s: Arc<Mutex<Option<Vec<Subscription>>>>, use_task: bool) -> Self {
        SubscriptionCollection {
            subscriptions: s,
            signal_sent: Arc::new(Mutex::new(false)),
            use_task,
        }
    }

    /// Blocks the calling thread until every collected subscription finishes.
    ///
    /// The stored list is taken and replaced by a single sentinel subscription.
    /// Unsubscribing the sentinel raises `signal_sent`. Once that happens,
    /// pending subscriptions are unsubscribed instead of waited on.
    ///
    /// # Errors
    /// Returns the first join error encountered. If a helper thread itself
    /// panicked, the error is a boxed `&str`.
    ///
    /// # Panics
    /// Panics if a subscription holds a Tokio task handle
    /// (`SubscriptionHandle::JoinTask`), or if an internal lock is poisoned.
    pub(crate) fn join_all(self) -> AwaitResult<()> {
        use std::thread::{sleep, spawn};

        let mut subscriptionsh = self.subscriptions.lock().unwrap();
        let subscriptions = subscriptionsh.take();
        // Sentinel: unsubscribing it sends on `tx`, and the listener thread
        // below then raises `signal_sent`.
        let (tx, rx) = std::sync::mpsc::channel();
        *subscriptionsh = Some(vec![Subscription::new(
            UnsubscribeLogic::Logic(Box::new(move || {
                let _ = tx.send(true);
            })),
            SubscriptionHandle::Nil,
        )]);
        let cl = Arc::clone(&self.signal_sent);
        std::thread::spawn(move || {
            if let Ok(true) = rx.recv() {
                *cl.lock().unwrap() = true;
            }
        });
        drop(subscriptionsh);

        let subscriptions = match subscriptions {
            Some(subscriptions) => subscriptions,
            None => return Ok(()),
        };

        let mut stored = Vec::with_capacity(subscriptions.len());
        let mut stored_tasks = Vec::with_capacity(subscriptions.len());
        for mut s in subscriptions {
            if *self.signal_sent.lock().unwrap() {
                s.unsubscribe();
                continue;
            }
            match s.subscription_future {
                SubscriptionHandle::Nil => (),
                SubscriptionHandle::JoinThread(thread_handle) => {
                    s.subscription_future = SubscriptionHandle::Nil;
                    stored.push(s);
                    stored_tasks.push(spawn(move || thread_handle.join()));
                }
                SubscriptionHandle::JoinSubscriptions(collection) => {
                    s.subscription_future = SubscriptionHandle::Nil;
                    stored.push(s);
                    stored_tasks.push(spawn(move || collection.join_all()));
                }
                SubscriptionHandle::JoinTask(..) => {
                    panic!("Handle should be OS thread handle but it is Tokio task handle instead. When working with Tokio, use `join_concurrent().await` to await the completion of observables.");
                }
            }
        }

        // Watcher: after the signal is raised, unsubscribe one stored entry per
        // tick. It exits when none remain or when `local_signal` reports that
        // joining is over.
        let local_signal = Arc::new(RwLock::new(false));
        let local_signal_cloned = Arc::clone(&local_signal);
        let signal_sent = Arc::clone(&self.signal_sent);
        spawn(move || loop {
            if *signal_sent.lock().unwrap() {
                if let Some(s) = stored.pop() {
                    s.unsubscribe();
                }
            }
            if stored.is_empty() || *local_signal_cloned.read().unwrap() {
                break;
            }
            sleep(std::time::Duration::from_millis(5));
        });

        // Join every helper and stop at the first failure. The watcher must
        // still be released on that path, otherwise it would poll forever.
        let mut outcome: AwaitResult<()> = Ok(());
        for s in stored_tasks {
            let joined = s.join().unwrap_or_else(|_| {
                Err(Box::new("failed to await merged observable") as Box<dyn Any + Send>)
            });
            if let Err(e) = joined {
                outcome = Err(e);
                break;
            }
        }
        *local_signal.write().unwrap() = true;
        outcome
    }

    /// Async counterpart of `join_all`: awaits every collected subscription.
    ///
    /// Task handles and nested collections are awaited. OS-thread handles are
    /// joined afterwards with a blocking `join`. As in `join_all`, a sentinel
    /// subscription raises `signal_sent`, after which pending entries are
    /// unsubscribed.
    ///
    /// # Errors
    /// Returns the first join error encountered. If a helper task failed to
    /// join, the error is a boxed `&str`.
    ///
    /// # Panics
    /// Panics outside a Tokio runtime (`runtime::Handle::current`) or if an
    /// internal lock is poisoned.
    /// NOTE(review): nested collections are spawned with `task::spawn_local`,
    /// which per Tokio's docs requires a `LocalSet`.
    // The guard on `self.subscriptions` is released via `drop(subscriptionsh)`
    // before the first `.await`; the allowance is kept as in the original.
    #[allow(clippy::await_holding_lock)]
    // The body is long because of the two signal-listener flavours.
    #[allow(clippy::too_many_lines)]
    pub(crate) fn join_all_async(self) -> Pin<Box<dyn Future<Output = AwaitResult<()>> + 'static>> {
        Box::pin(async move {
            let mut subscriptionsh = self.subscriptions.lock().unwrap();
            let subscriptions = subscriptionsh.take();
            let signal_sent = Arc::clone(&self.signal_sent);
            let signal_sent_cl = Arc::clone(&signal_sent);
            let signal_sent_cl2 = Arc::clone(&signal_sent);
            // Sentinel subscription whose unsubscribe logic raises
            // `signal_sent`, listened for on a task or a thread.
            if self.use_task {
                let (tx, mut rx) = tokio::sync::mpsc::channel(5);
                *subscriptionsh = Some(vec![Subscription::new(
                    UnsubscribeLogic::Future(Box::pin(async move {
                        let _ = tx.send(true).await;
                    })),
                    SubscriptionHandle::Nil,
                )]);
                task::spawn(async move {
                    if let Some(true) = rx.recv().await {
                        *signal_sent.lock().unwrap() = true;
                    }
                });
            } else {
                let (tx, rx) = std::sync::mpsc::channel();
                *subscriptionsh = Some(vec![Subscription::new(
                    UnsubscribeLogic::Logic(Box::new(move || {
                        let _ = tx.send(true);
                    })),
                    SubscriptionHandle::Nil,
                )]);
                std::thread::spawn(move || {
                    if let Ok(true) = rx.recv() {
                        *signal_sent.lock().unwrap() = true;
                    }
                });
            }
            drop(subscriptionsh);

            let subscriptions = match subscriptions {
                Some(subscriptions) => subscriptions,
                None => return Ok(()),
            };

            let mut stored = Vec::with_capacity(subscriptions.len());
            let mut stored_tasks = Vec::with_capacity(subscriptions.len());
            let mut stored_threads = Vec::with_capacity(subscriptions.len());
            for mut s in subscriptions {
                if *self.signal_sent.lock().unwrap() {
                    s.unsubscribe();
                    continue;
                }
                match s.subscription_future {
                    SubscriptionHandle::Nil => (),
                    SubscriptionHandle::JoinTask(task_handle) => {
                        s.subscription_future = SubscriptionHandle::Nil;
                        stored.push(s);
                        let h = task::spawn(async move {
                            task_handle
                                .await
                                .map_err(|e| Box::new(e) as Box<dyn Any + Send>)
                        });
                        stored_tasks.push(h);
                    }
                    SubscriptionHandle::JoinThread(thread_handle) => {
                        s.subscription_future = SubscriptionHandle::Nil;
                        stored.push(s);
                        stored_threads.push(thread_handle);
                    }
                    SubscriptionHandle::JoinSubscriptions(collection) => {
                        s.subscription_future = SubscriptionHandle::Nil;
                        stored.push(s);
                        let h = task::spawn_local(async move { collection.join_all_async().await });
                        stored_tasks.push(h);
                    }
                }
            }

            let tokio_current_thread = matches!(
                runtime::Handle::current().runtime_flavor(),
                runtime::RuntimeFlavor::CurrentThread
            );
            let stop_thread_event_loop = Arc::new(RwLock::new(false));
            let stop_thread_event_loop_cl = Arc::clone(&stop_thread_event_loop);
            if tokio_current_thread {
                // On a current-thread runtime, synchronous (`Logic`) teardowns
                // are driven from a dedicated OS thread. Presumably this is so
                // they still run while this task blocks in `join()` below.
                let (mut logic_subscriptions, remaining): (Vec<Subscription>, Vec<Subscription>) =
                    stored.into_iter().partition(|s| {
                        matches!(s.unsubscribe_logic, UnsubscribeLogic::Logic(_))
                    });
                stored = remaining;
                std::thread::spawn(move || loop {
                    if *signal_sent_cl2.lock().unwrap() {
                        if let Some(s) = logic_subscriptions.pop() {
                            s.unsubscribe();
                        }
                    }
                    if logic_subscriptions.is_empty() || *stop_thread_event_loop_cl.read().unwrap()
                    {
                        break;
                    }
                    std::thread::sleep(std::time::Duration::from_millis(5));
                });
            }

            // Task-side watcher: after the signal is raised, it unsubscribes
            // the remaining entries. It is always stopped before returning.
            let h = task::spawn(async move {
                loop {
                    if *signal_sent_cl.lock().unwrap() {
                        if let Some(s) = stored.pop() {
                            s.unsubscribe();
                        }
                    }
                    if stored.is_empty() {
                        break;
                    }
                    tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
                }
            });

            // Await tasks first, then threads, stopping at the first failure.
            // The cleanup below runs on both paths so the watchers cannot leak.
            let mut outcome: AwaitResult<()> = Ok(());
            for s in stored_tasks {
                let joined = s.await.unwrap_or_else(|_| {
                    Err(Box::new("failed to await merged observable") as Box<dyn Any + Send>)
                });
                if let Err(e) = joined {
                    outcome = Err(e);
                    break;
                }
            }
            if outcome.is_ok() {
                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
                for s in stored_threads {
                    if let Err(e) = s.join() {
                        outcome = Err(e);
                        break;
                    }
                }
            }
            *stop_thread_event_loop.write().unwrap() = true;
            h.abort();
            outcome
        })
    }
}
/// What a `Subscription` can be joined on.
pub enum SubscriptionHandle {
/// Nothing to wait for.
Nil,
/// A Tokio task; joinable via `Subscription::join_concurrent` (`Subscription::join` panics on it).
JoinTask(JoinHandle<()>),
/// An OS thread.
JoinThread(ThreadJoinHandle<()>),
/// A nested group of subscriptions, joined through `SubscriptionCollection`.
JoinSubscriptions(SubscriptionCollection),
}
/// Handle to an active subscription: its teardown logic plus something to join on.
// `_is_subject` is underscore-prefixed yet read by `Subscription::_is_subject`,
// hence the lint allowance.
#[allow(clippy::used_underscore_binding)]
pub struct Subscription {
// Teardown run by `Unsubscribeable::unsubscribe`.
pub(crate) unsubscribe_logic: UnsubscribeLogic,
// Handle waited on by `join` / `join_concurrent`.
pub(crate) subscription_future: SubscriptionHandle,
// Runtime captured at construction via `Handle::try_current`; used to spawn
// `UnsubscribeLogic::Future` teardowns. `Err` when created outside a runtime.
pub(crate) runtime_handle: Result<runtime::Handle, runtime::TryCurrentError>,
// `true` when built by `subject_subscription`, `false` when built by `new`.
_is_subject: bool,
}
impl Subscription {
    /// Creates a subscription from its teardown logic and join handle.
    ///
    /// The current Tokio runtime handle is captured with
    /// `Handle::try_current`; outside a runtime the error is stored instead.
    #[must_use]
    pub fn new(
        unsubscribe_logic: UnsubscribeLogic,
        subscription_future: SubscriptionHandle,
    ) -> Self {
        Self::build_with_subject_flag(unsubscribe_logic, subscription_future, false)
    }

    /// Same as [`Subscription::new`], but marks the subscription as a subject's.
    pub(crate) fn subject_subscription(
        unsubscribe_logic: UnsubscribeLogic,
        subscription_future: SubscriptionHandle,
    ) -> Self {
        Self::build_with_subject_flag(unsubscribe_logic, subscription_future, true)
    }

    /// Shared constructor: captures the ambient runtime handle (if any) and
    /// records the subject flag.
    fn build_with_subject_flag(
        unsubscribe_logic: UnsubscribeLogic,
        subscription_future: SubscriptionHandle,
        is_subject: bool,
    ) -> Self {
        let runtime_handle = tokio::runtime::Handle::try_current();
        Subscription {
            unsubscribe_logic,
            subscription_future,
            runtime_handle,
            _is_subject: is_subject,
        }
    }

    /// Whether this subscription was created by `subject_subscription`.
    pub(crate) fn _is_subject(&self) -> bool {
        self._is_subject
    }

    /// Awaits completion of whatever this subscription is running.
    ///
    /// Tokio task handles and nested collections are awaited. An OS-thread
    /// handle is joined with the blocking `JoinHandle::join`.
    ///
    /// # Errors
    /// Returns the boxed join error or thread panic payload.
    pub async fn join_concurrent(self) -> Result<(), Box<dyn Any + Send>> {
        match self.subscription_future {
            SubscriptionHandle::Nil => Ok(()),
            SubscriptionHandle::JoinTask(task_handle) => task_handle
                .await
                .map_err(|join_error| Box::new(join_error) as Box<dyn Any + Send>),
            SubscriptionHandle::JoinThread(thread_handle) => thread_handle.join(),
            SubscriptionHandle::JoinSubscriptions(collection) => {
                collection.join_all_async().await
            }
        }
    }

    /// Blocks until whatever this subscription is running has finished.
    ///
    /// # Errors
    /// Returns the thread panic payload or the nested collection's error.
    ///
    /// # Panics
    /// Panics when the handle is a Tokio task; use `join_concurrent` there.
    pub fn join(self) -> Result<(), Box<dyn Any + Send>> {
        match self.subscription_future {
            SubscriptionHandle::JoinTask(_) => {
                panic!("Handle should be OS thread handle but it is Tokio task handle instead. When working with Tokio, use `join_concurrent().await` to await the completion of observables.")
            }
            SubscriptionHandle::JoinThread(thread_handle) => thread_handle.join(),
            SubscriptionHandle::JoinSubscriptions(collection) => collection.join_all(),
            SubscriptionHandle::Nil => Ok(()),
        }
    }
}
impl Unsubscribeable for Subscription {
fn unsubscribe(self) {
self.unsubscribe_logic.unsubscribe(self.runtime_handle);
}
}
/// Teardown action executed when a `Subscription` is unsubscribed.
pub enum UnsubscribeLogic {
/// No teardown.
Nil,
/// Unsubscribe another subscription.
Wrapped(Box<Subscription>),
/// Run a synchronous closure.
Logic(Box<dyn FnOnce() + Send>),
/// Spawn a future on the Tokio runtime captured by the owning subscription;
/// unsubscribing panics if no runtime was available at construction.
Future(Pin<Box<dyn Future<Output = ()> + Send>>),
}
impl UnsubscribeLogic {
    /// Executes this teardown and returns the spent logic, which is always
    /// `UnsubscribeLogic::Nil`.
    ///
    /// `Future` teardowns are spawned on `runtime_handle` without being awaited.
    ///
    /// # Panics
    /// Panics for a `Future` teardown when `runtime_handle` is an error,
    /// i.e. no Tokio runtime was available.
    fn unsubscribe(
        self,
        runtime_handle: Result<runtime::Handle, runtime::TryCurrentError>,
    ) -> Self {
        match self {
            UnsubscribeLogic::Nil => {}
            UnsubscribeLogic::Logic(teardown) => teardown(),
            UnsubscribeLogic::Wrapped(inner) => inner.unsubscribe(),
            UnsubscribeLogic::Future(teardown) => {
                let handle = runtime_handle.expect(
                    "Observable that uses Tokio tasks is called outside of Tokio runtime",
                );
                handle.spawn(teardown);
            }
        }
        UnsubscribeLogic::Nil
    }
}