use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{Arc, Mutex},
};
use futures::{
channel::oneshot,
stream::{AbortHandle, Abortable},
Stream, StreamExt,
};
use log::{error, info};
use super::diff::DiffError;
use crate::utils::BoolUtils;
/// Why a connection, or a `CoTerminatingSet` of cooperating futures, was terminated.
///
/// `ConnectionTerminationReason::log` reports `Shutdown` at info level and `SeriousError`
/// at error level.
#[derive(Debug, Clone)]
pub enum ConnectionTerminationReason {
/// Ordinary termination (logged at info level); the string describes the cause.
Shutdown(String),
/// Termination caused by a failure (logged at error level); the string describes it.
SeriousError(String),
}
/// Bookkeeping violations in observe/expose handling (only compiled with the `downstream`
/// feature).
///
/// Judging by the variant names, these cover doubled observe/expose requests and
/// unobserve/unexpose without a matching earlier request — NOTE(review): confirm against the
/// code that constructs them. Converted into `ConnectionTerminationReason::SeriousError` via
/// `From`, with the variant's `Debug` name embedded in the message.
#[cfg(feature = "downstream")]
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum DownstreamBehaviourError {
DoubleObjectObserve,
DoubleObjectExpose,
ObjectUnobserveWithoutMatchingObserve,
ObjectUnexposeWithoutMatchingExpose,
DoubleTagObserve,
DoubleTagExpose,
TagUnobserveWithoutMatchingObserve,
// NOTE(review): "Uexpose" looks like a typo for "Unexpose" (cf. `ObjectUnexposeWithoutMatchingExpose`).
// Renaming requires updating every use in the crate and changes the `Debug` text that ends up in
// error messages.
TagUexposeWithoutMatchingExpose,
ChangeExposeCapacityWhileNotExposing,
}
impl ConnectionTerminationReason {
pub fn log(&self, context: &str) {
match self {
ConnectionTerminationReason::Shutdown(error) => info!("{}: {}", context, error),
ConnectionTerminationReason::SeriousError(error) => error!("{}: {}", context, error),
}
}
}
impl From<serde_json::Error> for ConnectionTerminationReason {
    /// Treats any `serde_json` error as a `SeriousError`, keeping the error's `Debug` text.
    fn from(value: serde_json::Error) -> Self {
        let description = format!("JSON parsing failed: {:?}", value);
        Self::SeriousError(description)
    }
}
#[cfg(feature = "downstream")]
impl From<DownstreamBehaviourError> for ConnectionTerminationReason {
    /// Reports a `DownstreamBehaviourError` as a `SeriousError`, embedding the variant's
    /// `Debug` name.
    fn from(value: DownstreamBehaviourError) -> Self {
        let description = format!("Connection logic issue: {:?}", value);
        Self::SeriousError(description)
    }
}
impl From<DiffError> for ConnectionTerminationReason {
fn from(value: DiffError) -> Self {
match value {
DiffError::InvalidScanShift => ConnectionTerminationReason::SeriousError("Diffing error: InvalidScanShift".to_string()),
DiffError::TooBigInput => ConnectionTerminationReason::SeriousError("Diffing error: TooBigInput".to_string()),
DiffError::TooBigOutput => ConnectionTerminationReason::SeriousError("Diffing error: TooBigOutput".to_string()),
DiffError::TooManyForks => ConnectionTerminationReason::SeriousError("Diffing error: TooManyForks".to_string()),
DiffError::TooManyNodes => ConnectionTerminationReason::SeriousError("Diffing error: TooManyNodes".to_string()),
}
}
}
impl From<futures::channel::mpsc::SendError> for ConnectionTerminationReason {
    /// Treats a failed send on a `futures` MPSC channel as a `SeriousError`, keeping the
    /// error's `Debug` text.
    fn from(value: futures::channel::mpsc::SendError) -> Self {
        let description = format!("MPSC error {:?}", value);
        Self::SeriousError(description)
    }
}
impl std::fmt::Display for ConnectionTerminationReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConnectionTerminationReason::Shutdown(text) => write!(f, "Connection shutdown: {:?}", text),
ConnectionTerminationReason::SeriousError(text) => write!(f, "Serious error: {:?}", text),
}
}
}
/// A cloneable handle to a group of futures and waiters that terminate together.
///
/// All clones share one state. The set is terminated by any of:
/// - `terminate`,
/// - a `CoTerminatingSetDropGuard` being dropped,
/// - a future passed to `abort_on_termination` returning `Err`.
///
/// On termination, every still-registered `abort_on_termination` future is aborted, and every
/// pending `reason()` future resolves with the same `ConnectionTerminationReason`. Only the
/// first termination reason is kept.
#[derive(Clone)]
pub struct CoTerminatingSet {
inner: Arc<Mutex<CoTerminatingSetInner>>,
}
/// State shared by all clones of a `CoTerminatingSet`, accessed only through its mutex.
struct CoTerminatingSetInner {
/// Current lifecycle state, including the registries of live registrations.
state: CoTerminatingSetState,
/// Last registration id handed out; incremented before each use, so ids start at 1 and never repeat.
sequence: u64,
}
/// Lifecycle of a `CoTerminatingSet`.
enum CoTerminatingSetState {
/// Still running. The maps are keyed by registration id.
/// - `abort_handles` holds the handles of futures started via `abort_on_termination`.
/// - `notify_senders` holds the senders that complete pending `reason()` futures.
///
/// Entries are removed by their own drop guards if the corresponding future is dropped first.
NotTerminated { abort_handles: HashMap<u64, AbortHandle>, notify_senders: HashMap<u64, oneshot::Sender<ConnectionTerminationReason>> },
/// Terminated. The stored reason is returned to every later `reason()` / `terminate` caller.
Terminated(ConnectionTerminationReason),
}
impl CoTerminatingSet {
pub fn new() -> CoTerminatingSet {
CoTerminatingSet {
inner: Arc::new(Mutex::new(CoTerminatingSetInner {
state: CoTerminatingSetState::NotTerminated { abort_handles: HashMap::new(), notify_senders: HashMap::new() },
sequence: 0,
})),
}
}
pub fn is_terminated(&self) -> bool {
matches!(self.inner.lock().unwrap().state, CoTerminatingSetState::Terminated(_))
}
pub fn reason(&self) -> impl Future<Output = ConnectionTerminationReason> {
let inner = self.inner.clone();
async move {
let (notify_sender, notify_receiver) = oneshot::channel();
let id = {
let mut locked_inner = inner.lock().unwrap();
locked_inner.sequence += 1;
let id = locked_inner.sequence;
match &mut locked_inner.state {
CoTerminatingSetState::NotTerminated { abort_handles: _, notify_senders } => {
notify_senders.insert(id, notify_sender).is_none().assert_true()
}
CoTerminatingSetState::Terminated(connection_termination_reason) => return connection_termination_reason.clone(),
}
id
};
let _release = drop_guard::guard(id, |id| {
let mut locked_inner = inner.lock().unwrap();
if let CoTerminatingSetState::NotTerminated { abort_handles: _, notify_senders } = &mut locked_inner.state {
notify_senders.remove(&id).is_some().assert_true();
};
});
notify_receiver.await.unwrap()
}
}
pub fn terminate(&self, termination: impl Into<ConnectionTerminationReason>) -> ConnectionTerminationReason {
CoTerminatingSetInner::terminate(&self.inner, termination)
}
pub fn terminate_on_drop<T>(self, reason: ConnectionTerminationReason, guarded: T) -> CoTerminatingSetDropGuard<T> {
CoTerminatingSetDropGuard { guarded, reason: Some(reason), termination: self }
}
pub fn abort_on_termination(&self, f: impl Future<Output = Result<(), ConnectionTerminationReason>> + Send + Sync + 'static) -> impl Future<Output = ()> {
let inner = self.inner.clone();
async move {
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let id = {
let mut locked_inner = inner.lock().unwrap();
locked_inner.sequence += 1;
let id = locked_inner.sequence;
match &mut locked_inner.state {
CoTerminatingSetState::NotTerminated { abort_handles, notify_senders: _ } => abort_handles.insert(id, abort_handle).is_none().assert_true(),
CoTerminatingSetState::Terminated(_connection_termination_reason) => return,
}
id
};
let _release = drop_guard::guard(id, |id| {
let mut locked_inner = inner.lock().unwrap();
if let CoTerminatingSetState::NotTerminated { abort_handles, notify_senders: _ } = &mut locked_inner.state {
abort_handles.remove(&id).is_some().assert_true();
};
});
if let Ok(Err(error)) = Abortable::new(f, abort_registration).await {
CoTerminatingSetInner::terminate(&inner, error);
};
}
}
}
impl CoTerminatingSetInner {
    /// Moves the shared state from `NotTerminated` to `Terminated(termination)`, at most once.
    ///
    /// If the state is already `Terminated`, the stored reason is returned and `termination` is
    /// dropped without being converted. Otherwise the steps are:
    /// 1. Under the lock, take out both registries and store the new reason.
    /// 2. Release the lock.
    /// 3. Trigger every registered abort handle.
    /// 4. Send every registered waiter a clone of the reason. Waiters whose receiver is already
    ///    gone are skipped.
    ///
    /// Returns the reason now in effect.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn terminate(this: &Arc<Mutex<CoTerminatingSetInner>>, termination: impl Into<ConnectionTerminationReason>) -> ConnectionTerminationReason {
        let mut locked = this.lock().unwrap();
        // A single match both decides "already terminated?" and extracts the registries, so there
        // is no second, supposedly-impossible state check that could panic.
        let (abort_handles, notify_senders, termination) = match &mut locked.state {
            CoTerminatingSetState::Terminated(existing) => return existing.clone(),
            CoTerminatingSetState::NotTerminated { abort_handles, notify_senders } => {
                // Convert before mutating anything, so a panicking conversion leaves the
                // registries untouched.
                let termination: ConnectionTerminationReason = termination.into();
                (std::mem::take(abort_handles), std::mem::take(notify_senders), termination)
            }
        };
        locked.state = CoTerminatingSetState::Terminated(termination.clone());
        // Release the lock before aborting/notifying, so no waker or drop-guard code runs while
        // it is held. The guards installed by `reason` and `abort_on_termination` lock this
        // same mutex.
        drop(locked);
        for abort_handle in abort_handles.into_values() {
            abort_handle.abort();
        }
        for notify_sender in notify_senders.into_values() {
            // A closed receiver just means that waiter was dropped; nothing to deliver.
            let _ = notify_sender.send(termination.clone());
        }
        termination
    }
}
/// Wrapper returned by `CoTerminatingSet::terminate_on_drop`.
///
/// When `T` implements `Stream` or `Iterator`, the wrapper does too, by delegating to
/// `guarded`. Dropping the wrapper terminates `termination` with `reason`. Exhausting the inner
/// stream/iterator does not, by itself, terminate the set.
pub struct CoTerminatingSetDropGuard<T> {
/// The wrapped value.
guarded: T,
/// Reason used on drop; `Some` until `Drop::drop` takes it.
reason: Option<ConnectionTerminationReason>,
/// The set to terminate when this guard is dropped.
termination: CoTerminatingSet,
}
impl<T> Drop for CoTerminatingSetDropGuard<T> {
fn drop(&mut self) {
self.termination.terminate(self.reason.take().unwrap());
}
}
/// Polling the guard polls the wrapped stream. The set is terminated only when the guard is
/// dropped, not when the stream ends.
impl<I, T: Stream<Item = I> + Unpin> Stream for CoTerminatingSetDropGuard<T> {
    type Item = I;

    /// Delegates to the wrapped stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
        self.guarded.poll_next_unpin(cx)
    }

    /// Forwards the wrapped stream's bounds. The guard yields exactly the items of `guarded`,
    /// so they stay accurate; without this override the default `(0, None)` would hide them.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.guarded.size_hint()
    }
}
/// Iterating the guard iterates the wrapped value. The set is terminated only when the guard is
/// dropped, not when iteration ends.
impl<I, T: Iterator<Item = I>> Iterator for CoTerminatingSetDropGuard<T> {
    type Item = I;

    /// Yields the next item of the wrapped iterator.
    fn next(&mut self) -> Option<Self::Item> {
        self.guarded.next()
    }

    /// Forwards the wrapped iterator's bounds. The guard yields exactly the items of `guarded`,
    /// so they stay accurate and let consumers such as `collect` pre-allocate (the default
    /// would be `(0, None)`).
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.guarded.size_hint()
    }
}