#![allow(clippy::mutex_atomic)]
use std::future::Future;
use std::sync::{Arc, Condvar, Mutex};
use err_context::AnyError;
use log::trace;
use serde::de::DeserializeOwned;
use spirit::extension::Extensible;
use spirit::fragment::Installer;
use structopt::StructOpt;
use tokio::select;
use tokio::sync::oneshot::{self, Sender};
use crate::runtime::{self, ShutGuard, Tokio};
/// One-shot, thread-blocking notification built from a boolean flag and a condition variable.
///
/// The flag starts out `false` (via `Default`) and is never reset by the code in this file.
#[derive(Debug, Default)]
struct Wakeup {
    /// `true` once the wakeup has been signalled.
    wakeup: Mutex<bool>,
    /// Used to park and wake threads that are waiting for the flag to become `true`.
    condvar: Condvar,
}
impl Wakeup {
    /// Blocks the calling thread until the flag has been set by [`Wakeup::wakeup`].
    ///
    /// Returns right away if the flag is already `true`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn wait(&self) {
        trace!("Waiting on wakeup on {:p}/{:?}", self, self);
        let flag = self.wakeup.lock().unwrap();
        // `wait_while` re-checks the flag after every wakeup, so a spurious wakeup only
        // sends the thread back to sleep.
        let woken = self.condvar.wait_while(flag, |signalled| !*signalled).unwrap();
        drop(woken);
    }

    /// Sets the flag and wakes every thread currently blocked in [`Wakeup::wait`].
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn wakeup(&self) {
        trace!("Waking up {:p}/{:?}", self, self);
        {
            let mut signalled = self.wakeup.lock().unwrap();
            *signalled = true;
            // The guard goes out of scope here, so the lock is released before notifying.
        }
        self.condvar.notify_all();
    }
}
/// Uninstall handle returned by [`FutureInstaller`].
///
/// Dropping it sends a stop request to the spawned task and then blocks the current thread
/// until that task's future has been dropped.
pub struct RemoteDrop {
/// Label of the installed future; only used in trace messages.
name: &'static str,
/// Sending half of the stop-request channel. Wrapped in `Option` so the sender can be
/// moved out of `&mut self` inside `Drop::drop`.
request_drop: Option<Sender<()>>,
/// Shared with the `SendOnDrop` guard living inside the spawned task; signalled when that
/// guard is dropped.
wakeup: Arc<Wakeup>,
/// Never read; held only for its lifetime. Declared last, so it is dropped after the other
/// fields (and after the `Drop::drop` body has finished waiting).
// NOTE(review): presumably postpones runtime shutdown while this handle is alive — confirm
// against `runtime::shut_guard`.
_shut_guard: Option<ShutGuard>,
}
impl Drop for RemoteDrop {
    /// Asks the spawned task to stop and blocks until its future has actually been dropped.
    fn drop(&mut self) {
        let name = self.name;
        trace!("Requesting remote drop on {}", name);

        // The sender is placed at construction and only removed here, and `drop` runs at most
        // once, so the option is always populated at this point.
        let stop_request = self.request_drop.take().unwrap();
        // A failed send means the receiving half is already gone, i.e. the task's future has
        // already finished or been dropped; there is nothing left to ask for.
        let _ = stop_request.send(());

        // Block until the `SendOnDrop` guard held by the task has signalled.
        self.wakeup.wait();
        trace!("Remote drop done on {}", name);
    }
}
/// Guard that signals the wrapped [`Wakeup`] when it is dropped.
///
/// Because the signal is sent from `Drop`, it fires no matter how the owner goes away.
struct SendOnDrop(Arc<Wakeup>);

impl Drop for SendOnDrop {
    fn drop(&mut self) {
        let Self(signal) = self;
        signal.wakeup();
    }
}
/// Installer that runs futures as spawned Tokio tasks.
///
/// Each installed future is spawned with `tokio::spawn`; the returned [`RemoteDrop`] stops the
/// task and waits for it to be gone when dropped.
#[derive(Clone, Copy, Debug, Default)]
pub struct FutureInstaller;
impl<F, O, C> Installer<F, O, C> for FutureInstaller
where
F: Future<Output = ()> + Send + 'static,
{
type UninstallHandle = RemoteDrop;
fn install(&mut self, fut: F, name: &'static str) -> RemoteDrop {
let (request_send, request_recv) = oneshot::channel();
let wakeup = Default::default();
let guard = SendOnDrop(Arc::clone(&wakeup));
let cancellable_future = async move {
let _guard = guard;
select! {
_ = request_recv => trace!("Future {} requested to terminate", name),
_ = fut => trace!("Future {} terminated on its own", name),
};
};
trace!("Installing future {}", name);
tokio::spawn(cancellable_future);
RemoteDrop {
name,
request_drop: Some(request_send),
wakeup,
_shut_guard: runtime::shut_guard(),
}
}
fn init<E>(&mut self, ext: E, _name: &'static str) -> Result<E, AnyError>
where
E: Extensible<Opts = O, Config = C, Ok = E>,
E::Config: DeserializeOwned + Send + Sync + 'static,
E::Opts: StructOpt + Send + Sync + 'static,
{
#[cfg(feature = "multithreaded")]
{
ext.with_singleton(Tokio::Default)
}
#[cfg(not(feature = "multithreaded"))]
{
ext.with_singleton(Tokio::SingleThreaded)
}
}
}