hakuban 0.8.5

Data-object sharing library
Documentation
use std::{
	collections::HashMap,
	future::Future,
	pin::Pin,
	sync::{Arc, Mutex},
};

use futures::{
	channel::oneshot,
	stream::{AbortHandle, Abortable},
	Stream, StreamExt,
};
use log::{error, info};

use super::diff::DiffError;
use crate::utils::BoolUtils;

/// Why a connection (or a [`CoTerminatingSet`]) stopped, with a human-readable description.
///
/// The variant selects how [`ConnectionTerminationReason::log`] reports it: `Shutdown` at info
/// level, `SeriousError` at error level.
#[derive(Clone, Debug)]
pub enum ConnectionTerminationReason {
	/// Non-error termination (logged at info level).
	Shutdown(String),
	/// Termination caused by a failure (logged at error level).
	SeriousError(String),
}

/// Misbehaviour of the downstream side detected by connection logic.
///
/// Judging by the variant names, these cover doubled or unmatched observe/expose requests for
/// objects and tags, plus a capacity change while not exposing (TODO confirm against callers).
///
/// Converted into [`ConnectionTerminationReason::SeriousError`] with the variant's `Debug` name
/// embedded in the message. Variant names, including the misspelt
/// `TagUexposeWithoutMatchingExpose`, are therefore part of the emitted error text.
#[cfg(feature = "downstream")]
#[derive(PartialEq, Eq, Debug)]
pub(crate) enum DownstreamBehaviourError {
	DoubleObjectObserve,
	DoubleObjectExpose,
	ObjectUnobserveWithoutMatchingObserve,
	ObjectUnexposeWithoutMatchingExpose,
	DoubleTagObserve,
	DoubleTagExpose,
	TagUnobserveWithoutMatchingObserve,
	TagUexposeWithoutMatchingExpose,
	ChangeExposeCapacityWhileNotExposing,
}

impl ConnectionTerminationReason {
	pub fn log(&self, context: &str) {
		match self {
			ConnectionTerminationReason::Shutdown(error) => info!("{}: {}", context, error),
			ConnectionTerminationReason::SeriousError(error) => error!("{}: {}", context, error),
		}
	}
}

impl From<serde_json::Error> for ConnectionTerminationReason {
	/// Reports a JSON failure as a `SeriousError` embedding the error's `Debug` form.
	fn from(value: serde_json::Error) -> Self {
		let description = format!("JSON parsing failed: {:?}", value);
		Self::SeriousError(description)
	}
}

/// Only available with the `downstream` feature.
#[cfg(feature = "downstream")]
impl From<DownstreamBehaviourError> for ConnectionTerminationReason {
	/// Reports a downstream behaviour violation as a `SeriousError` embedding the variant's `Debug` form.
	fn from(value: DownstreamBehaviourError) -> Self {
		let description = format!("Connection logic issue: {:?}", value);
		Self::SeriousError(description)
	}
}

impl From<DiffError> for ConnectionTerminationReason {
	fn from(value: DiffError) -> Self {
		match value {
			DiffError::InvalidScanShift => ConnectionTerminationReason::SeriousError("Diffing error: InvalidScanShift".to_string()),
			DiffError::TooBigInput => ConnectionTerminationReason::SeriousError("Diffing error: TooBigInput".to_string()),
			DiffError::TooBigOutput => ConnectionTerminationReason::SeriousError("Diffing error: TooBigOutput".to_string()),
			DiffError::TooManyForks => ConnectionTerminationReason::SeriousError("Diffing error: TooManyForks".to_string()),
			DiffError::TooManyNodes => ConnectionTerminationReason::SeriousError("Diffing error: TooManyNodes".to_string()),
		}
	}
}

impl From<futures::channel::mpsc::SendError> for ConnectionTerminationReason {
	/// Reports a failed MPSC send as a `SeriousError` embedding the error's `Debug` form.
	fn from(value: futures::channel::mpsc::SendError) -> Self {
		let description = format!("MPSC error {:?}", value);
		Self::SeriousError(description)
	}
}

impl std::fmt::Display for ConnectionTerminationReason {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		match self {
			ConnectionTerminationReason::Shutdown(text) => write!(f, "Connection shutdown: {:?}", text),
			ConnectionTerminationReason::SeriousError(text) => write!(f, "Serious error: {:?}", text),
		}
	}
}

/// Shared handle to a group of futures that terminate together.
///
/// Every clone points at the same mutex-protected state: terminating through any clone terminates
/// the set for all of them.
#[derive(Clone)]
pub struct CoTerminatingSet {
	/// State shared by all clones.
	inner: Arc<Mutex<CoTerminatingSetInner>>,
}

/// Mutex-protected state behind a [`CoTerminatingSet`].
struct CoTerminatingSetInner {
	/// Last registration id handed out; incremented before each use, so ids start at 1.
	/// Shared by abort handles and notify senders.
	sequence: u64,
	/// Current lifecycle state: live registrations, or the final termination reason.
	state: CoTerminatingSetState,
}

/// Lifecycle of a [`CoTerminatingSet`]; moves from `NotTerminated` to `Terminated` exactly once.
enum CoTerminatingSetState {
	/// Still running. Both maps are keyed by ids taken from `CoTerminatingSetInner::sequence`.
	NotTerminated {
		/// Handles used to abort futures registered via `abort_on_termination`.
		abort_handles: HashMap<u64, AbortHandle>,
		/// Senders that wake pending `reason()` futures.
		notify_senders: HashMap<u64, oneshot::Sender<ConnectionTerminationReason>>,
	},
	/// Terminated; holds the reason given to the first `terminate` call.
	Terminated(ConnectionTerminationReason),
}

impl CoTerminatingSet {
	/// Creates a set that is not terminated and has no registered futures or listeners.
	pub fn new() -> CoTerminatingSet {
		CoTerminatingSet {
			inner: Arc::new(Mutex::new(CoTerminatingSetInner {
				state: CoTerminatingSetState::NotTerminated { abort_handles: HashMap::new(), notify_senders: HashMap::new() },
				sequence: 0,
			})),
		}
	}

	/// Returns `true` once the set has been terminated.
	pub fn is_terminated(&self) -> bool {
		matches!(self.inner.lock().unwrap().state, CoTerminatingSetState::Terminated(_))
	}

	/// Returns a future resolving to the reason the set was terminated with.
	///
	/// Nothing happens until the future is first polled. At that point it either:
	/// - resolves immediately with the stored reason, if the set is already terminated; or
	/// - registers a listener and waits for [`CoTerminatingSet::terminate`].
	///
	/// Dropping the future before termination deregisters the listener.
	pub fn reason(&self) -> impl Future<Output = ConnectionTerminationReason> {
		let inner = self.inner.clone();
		async move {
			let (notify_sender, notify_receiver) = oneshot::channel();
			let id = {
				let mut locked_inner = inner.lock().unwrap();
				locked_inner.sequence += 1;
				let id = locked_inner.sequence;
				match &mut locked_inner.state {
					CoTerminatingSetState::NotTerminated { abort_handles: _, notify_senders } => {
						notify_senders.insert(id, notify_sender).is_none().assert_true()
					}
					CoTerminatingSetState::Terminated(connection_termination_reason) => return connection_termination_reason.clone(),
				}
				id
			};

			// Remove our sender if this future is dropped while the set is still running. Once
			// terminated, `terminate` has already moved the whole map out of the state, so there is
			// nothing left to remove.
			let _release = drop_guard::guard(id, |id| {
				let mut locked_inner = inner.lock().unwrap();
				if let CoTerminatingSetState::NotTerminated { abort_handles: _, notify_senders } = &mut locked_inner.state {
					notify_senders.remove(&id).is_some().assert_true();
				};
			});

			// Our sender stays registered until `terminate` takes it out and sends on it, and `inner`
			// keeps the state alive meanwhile, so the channel is not expected to be cancelled.
			notify_receiver.await.expect("CoTerminatingSet dropped a termination listener without notifying it")
		}
	}

	/// Terminates the set with `termination` unless it is already terminated.
	///
	/// Returns the reason the set is terminated with. Only the first call has an effect: it aborts
	/// every future running under [`CoTerminatingSet::abort_on_termination`] and wakes every pending
	/// [`CoTerminatingSet::reason`] future. Later calls return the original reason and discard
	/// `termination`.
	pub fn terminate(&self, termination: impl Into<ConnectionTerminationReason>) -> ConnectionTerminationReason {
		CoTerminatingSetInner::terminate(&self.inner, termination)
	}

	/// Wraps `guarded` so that dropping the returned guard terminates this set with `reason`.
	pub fn terminate_on_drop<T>(self, reason: ConnectionTerminationReason, guarded: T) -> CoTerminatingSetDropGuard<T> {
		CoTerminatingSetDropGuard { guarded, reason: Some(reason), termination: self }
	}

	/// Runs `f` until it completes or the set is terminated, whichever comes first.
	///
	/// - If the set is already terminated when the returned future is first polled, `f` is never run.
	/// - If `f` resolves to `Err`, the set is terminated with that error.
	/// - If the set is terminated while `f` is running, `f` is aborted and the returned future
	///   resolves.
	pub fn abort_on_termination(&self, f: impl Future<Output = Result<(), ConnectionTerminationReason>> + Send + Sync + 'static) -> impl Future<Output = ()> {
		let inner = self.inner.clone();
		async move {
			let (abort_handle, abort_registration) = AbortHandle::new_pair();
			let id = {
				let mut locked_inner = inner.lock().unwrap();
				locked_inner.sequence += 1;
				let id = locked_inner.sequence;
				match &mut locked_inner.state {
					CoTerminatingSetState::NotTerminated { abort_handles, notify_senders: _ } => abort_handles.insert(id, abort_handle).is_none().assert_true(),
					CoTerminatingSetState::Terminated(_connection_termination_reason) => return,
				}
				id
			};

			// Deregister the abort handle when this future finishes or is dropped, unless
			// `terminate` has already taken ownership of all handles.
			let _release = drop_guard::guard(id, |id| {
				let mut locked_inner = inner.lock().unwrap();
				if let CoTerminatingSetState::NotTerminated { abort_handles, notify_senders: _ } = &mut locked_inner.state {
					abort_handles.remove(&id).is_some().assert_true();
				};
			});

			// `Ok(Ok(()))` is normal completion and `Err(Aborted)` means the set was terminated;
			// only an error returned by `f` terminates the set from here.
			if let Ok(Err(error)) = Abortable::new(f, abort_registration).await {
				CoTerminatingSetInner::terminate(&inner, error);
			};
		}
	}
}

impl Default for CoTerminatingSet {
	/// Same as [`CoTerminatingSet::new`].
	fn default() -> Self {
		Self::new()
	}
}

impl CoTerminatingSetInner {
	/// Terminates the set behind `this` with `termination`, unless it is already terminated.
	///
	/// On the first call this:
	/// - stores the reason;
	/// - aborts every registered future;
	/// - sends the reason to every registered listener;
	/// - returns the new reason.
	///
	/// The abort and send steps happen after the mutex is released. Every later call returns the
	/// stored reason; `termination` is dropped without being converted.
	pub fn terminate(this: &Arc<Mutex<CoTerminatingSetInner>>, termination: impl Into<ConnectionTerminationReason>) -> ConnectionTerminationReason {
		let mut locked = this.lock().unwrap();
		if let CoTerminatingSetState::Terminated(ref connection_termination_reason) = locked.state {
			return connection_termination_reason.clone();
		}
		let termination = termination.into();
		let CoTerminatingSetState::NotTerminated { abort_handles, notify_senders } =
			std::mem::replace(&mut locked.state, CoTerminatingSetState::Terminated(termination.clone()))
		else {
			unreachable!("state was checked to be NotTerminated while holding the lock")
		};
		// Release the lock before aborting/notifying; the registration drop guards in
		// `CoTerminatingSet` take this same mutex.
		drop(locked);
		for abort_handle in abort_handles.into_values() {
			abort_handle.abort();
		}
		for notify_sender in notify_senders.into_values() {
			// Failure only means the listening future was already dropped; nobody is left to notify.
			notify_sender.send(termination.clone()).ok();
		}
		termination
	}
}

/// Guard created by [`CoTerminatingSet::terminate_on_drop`].
///
/// Owns `guarded` and terminates the set with `reason` when dropped. Discarding it right away
/// terminates the set immediately, hence `#[must_use]`.
#[must_use = "the set is terminated as soon as this guard is dropped"]
pub struct CoTerminatingSetDropGuard<T> {
	/// The wrapped value.
	guarded: T,
	/// Reason passed to `terminate` on drop; `Some` until `drop` takes it.
	reason: Option<ConnectionTerminationReason>,
	/// The set terminated when this guard is dropped.
	termination: CoTerminatingSet,
}

impl<T> Drop for CoTerminatingSetDropGuard<T> {
	/// Terminates the associated set with the stored reason (a no-op if it is already terminated).
	fn drop(&mut self) {
		// `reason` is set to `Some` at construction and only taken here; `drop` runs at most once.
		let reason = std::mem::take(&mut self.reason).unwrap();
		self.termination.terminate(reason);
	}
}

impl<I, T: Stream<Item = I> + Unpin> Stream for CoTerminatingSetDropGuard<T> {
	type Item = I;

	fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
		self.guarded.poll_next_unpin(cx)
	}
}

/// Delegates to the guarded iterator, so the guard can be used wherever the iterator was.
impl<I, T: Iterator<Item = I>> Iterator for CoTerminatingSetDropGuard<T> {
	type Item = I;

	fn next(&mut self) -> Option<Self::Item> {
		self.guarded.next()
	}

	/// Forwards the inner iterator's hint so consumers such as `collect` can preallocate; the
	/// default `(0, None)` would discard it.
	fn size_hint(&self) -> (usize, Option<usize>) {
		Iterator::size_hint(&self.guarded)
	}
}

/*
impl<T: Future + Unpin> Future for CoTerminatingSetDropGuard<T> {
	type Output = T::Output;

	fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
		self.guarded.poll_unpin(cx)
	}
}
*/