use clap::{Args, FromArgMatches};
use pin_project_lite::pin_project;
use std::error::Error;
use std::future::{Future, IntoFuture};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::ResourceDependencies;
use crate::assembly::sealed::{ResourceBase, TraitRegisterContext};
use crate::assembly::{ProduceContext, RegisterContext, ResourceFut};
use crate::shutdown::{
ShutdownSignalParticipant, ShutdownSignalParticipantCreator, TaskRunningSentinel,
};
/// Future returned by [`AssemblyRuntime::self_stop`]; it completes when this
/// resource is asked to stop.
///
/// A resource that calls `self_stop` is expected to await this signal in its
/// task and then wind down. Dropping it unawaited means the resource never
/// observes the stop request, hence `#[must_use]`.
#[must_use = "a StopSignal does nothing unless it is awaited or polled"]
pub struct StopSignal(ShutdownSignalParticipant);

impl Future for StopSignal {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // The participant resolves to a forwarder. It is discarded here rather
        // than `propagate()`d; forwarding to dependencies is done by the task
        // wrapper that owns the resource's other shutdown participant.
        Pin::new(&mut self.get_mut().0).poll(cx).map(|_| ())
    }
}
/// Runtime hooks handed to [`Resource::new`].
///
/// Through this handle a resource can register a background task
/// ([`set_task`](Self::set_task)) and/or take over its own stopping
/// ([`self_stop`](Self::self_stop)).
pub struct AssemblyRuntime<'a> {
    // Taken by `self_stop`. Whether it is still `Some` once `Resource::new`
    // returns is recorded as the resource's auto-stop flag.
    stoppers: Option<&'a mut ShutdownSignalParticipantCreator>,
    // Background task registered via `set_task`, if any.
    task: Option<Box<dyn Task>>,
}

impl AssemblyRuntime<'_> {
    /// Opts this resource into self-managed stopping and returns a future that
    /// resolves when the resource is asked to stop.
    ///
    /// Without this call, a registered task is dropped (cancelled) as soon as
    /// the stop request arrives. After this call, stopping this resource waits
    /// for its task to return. The task should therefore await the returned
    /// [`StopSignal`] and then finish.
    ///
    /// # Panics
    ///
    /// Panics if called more than once on the same runtime.
    pub fn self_stop(&mut self) -> StopSignal {
        let creator = self
            .stoppers
            .take()
            .expect("self_stop called more than once");
        // A fresh creator always has a participant to hand out; a missing one
        // indicates a bug in the assembly, not in the calling resource.
        StopSignal(
            creator
                .next()
                .expect("shutdown participant creator yielded no participant"),
        )
    }

    /// Registers `task` as this resource's background task, replacing any task
    /// registered earlier.
    ///
    /// An `Err` returned by the task becomes the result of this resource's
    /// task future.
    pub fn set_task<F>(&mut self, task: F)
    where
        F: IntoFuture<Output = Result<(), Box<dyn Error>>> + Send + 'static,
        F::IntoFuture: Send,
    {
        self.task = Some(Box::new(TaskImpl(task)));
    }
}
/// Produce-phase payload of `TraitInstaller`: where to publish trait objects
/// and what to build them from.
#[doc(hidden)]
pub struct TraitInstallerProduce<'a, 'b, 'c, R> {
    // Context that receives the produced trait objects.
    cx: &'a mut ProduceContext<'c>,
    // The resource instance that was just constructed.
    shared: &'b Arc<R>,
    // This resource's sentinel; used to decide whether a trait request is
    // relevant to this resource.
    resource: &'b TaskRunningSentinel,
}

/// Handed to [`Resource::provide_as_trait`], once while registering and once
/// after the resource has been constructed.
#[doc(hidden)]
pub enum TraitInstaller<'a, 'b, 'c, R> {
    /// Registration phase: only records which traits can be provided.
    Register(TraitRegisterContext<'b>),
    /// Production phase: actually builds and publishes trait objects.
    Produce(TraitInstallerProduce<'a, 'b, 'c, R>),
}

impl<R> TraitInstaller<'_, '_, '_, R> {
    /// Offers this resource as an `Arc<T>` (typically `T = dyn SomeTrait`).
    ///
    /// In the registration phase this only records that `T` can be provided.
    /// In the production phase, `factory` runs at most once. It runs only when
    /// the context knows about `T` and `is_dependent_of` reports this resource
    /// as relevant to it; otherwise nothing is built.
    pub fn offer<T, F>(&mut self, factory: F)
    where
        T: std::any::Any + ?Sized,
        F: FnOnce(&Arc<R>) -> Arc<T>,
    {
        let producer = match self {
            Self::Register(registrar) => {
                registrar.register_as_trait::<T>();
                return;
            }
            Self::Produce(producer) => producer,
        };
        let Some(trait_i) = producer.cx.get_trait_i::<T>() else {
            return;
        };
        if producer.resource.is_dependent_of(trait_i) {
            let provided = factory(producer.shared);
            producer.cx.provide_as_trait(trait_i, provided);
        }
    }
}
/// A unit of functionality that the assembly constructs once and shares.
///
/// Each resource declares its command-line arguments ([`Resource::Args`]) and
/// the other resources it needs ([`Resource::Dependencies`]). Both are
/// produced before [`Resource::new`] is called.
pub trait Resource: Send + Sync + Sized + 'static {
    /// Name of this resource.
    const NAME: &str;

    /// Command-line arguments contributed by this resource.
    type Args: clap::Args;

    /// Resources that must be produced before this one.
    type Dependencies: ResourceDependencies;

    /// Error returned by [`Resource::new`] on failure; the assembly converts
    /// it into a boxed [`Error`].
    type CreationError: Into<Box<dyn Error>>;

    /// Constructs the resource from its dependencies and parsed arguments.
    ///
    /// `api` can register a background task or take over stopping; see
    /// [`AssemblyRuntime`].
    fn new(
        deps: Self::Dependencies,
        args: Self::Args,
        api: &mut AssemblyRuntime<'_>,
    ) -> Result<Arc<Self>, Self::CreationError>;

    /// Declares the trait objects this resource can be provided as.
    ///
    /// Invoked once during registration and once after construction. The
    /// default implementation provides nothing.
    fn provide_as_trait<'a>(_: &'a mut TraitInstaller<'_, 'a, '_, Self>) {}
}
pub use comprehensive_macros::v1resource as resource;
pin_project! {
    // The user's task future together with the resource's keepalive sentinel.
    // The wrappers below drop the whole `TaskInner` (via `inner.set(None)`) as
    // soon as the future finishes or is cancelled, which also drops `keepalive`.
    struct TaskInner<F> {
        #[pin] fut: F,
        keepalive: TaskRunningSentinel,
    }
}
pin_project! {
    // Wrapper used when the resource did NOT call `self_stop`. The stop
    // request is checked first on every poll; when it arrives, the task is
    // cancelled by dropping `inner`.
    // NOTE(review): field order is drop order (`stopper` before `inner`).
    struct AutoStopTask<F> {
        #[pin] stopper: ShutdownSignalParticipant,
        #[pin] inner: Option<TaskInner<F>>,
    }
}
pin_project! {
    // Wrapper used when the resource called `self_stop`. The task is always
    // allowed to run to completion; `stopper` is only polled after `inner`
    // has finished successfully.
    // NOTE(review): field order is drop order (`stopper` before `inner`).
    struct SelfStopTask<F> {
        #[pin] stopper: ShutdownSignalParticipant,
        #[pin] inner: Option<TaskInner<F>>,
    }
}
impl<F> AutoStopTask<F> {
    /// Wraps `task` (converted to its future immediately) together with the
    /// shutdown participant and keepalive sentinel it runs under.
    fn new<T>(task: T, stopper: ShutdownSignalParticipant, keepalive: TaskRunningSentinel) -> Self
    where
        T: IntoFuture<IntoFuture = F>,
    {
        let fut = task.into_future();
        let inner = Some(TaskInner { fut, keepalive });
        Self { stopper, inner }
    }
}
impl<F> Future for AutoStopTask<F>
where
    F: Future<Output = Result<(), Box<dyn Error>>>,
{
    type Output = Result<(), Box<dyn Error>>;

    /// Completes with `Ok(())` once a stop request arrives, cancelling the task
    /// if it is still running. Completes early with `Err` if the task fails.
    /// A task that finishes with `Ok` only releases its keepalive; the wrapper
    /// then keeps waiting for the stop request.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut me = self.project();

        // The stop request takes priority: forward it, drop the task (and its
        // keepalive), and finish cleanly.
        let stop_requested = match me.stopper.poll(cx) {
            Poll::Ready(forwarder) => {
                forwarder.propagate();
                true
            }
            Poll::Pending => false,
        };
        if stop_requested {
            me.inner.set(None);
            return Poll::Ready(Ok(()));
        }

        // Drive the task if it has not finished yet.
        let outcome = me
            .inner
            .as_mut()
            .as_pin_mut()
            .map(|running| running.project().fut.poll(cx));
        if let Some(Poll::Ready(result)) = outcome {
            me.inner.set(None);
            if let Err(err) = result {
                return Poll::Ready(Err(err));
            }
        }
        Poll::Pending
    }
}
impl<F> SelfStopTask<F> {
    /// Wraps `task` (converted to its future immediately) together with the
    /// shutdown participant and keepalive sentinel it runs under.
    fn new<T>(task: T, stopper: ShutdownSignalParticipant, keepalive: TaskRunningSentinel) -> Self
    where
        T: IntoFuture<IntoFuture = F>,
    {
        let fut = task.into_future();
        let inner = Some(TaskInner { fut, keepalive });
        Self { stopper, inner }
    }
}
impl<F> Future for SelfStopTask<F>
where
    F: Future<Output = Result<(), Box<dyn Error>>>,
{
    type Output = Result<(), Box<dyn Error>>;

    /// Runs the task to completion first. An `Err` from the task is returned
    /// immediately. After a successful finish, waits for the stop request,
    /// forwards it, and completes with `Ok(())`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut me = self.project();

        if let Some(running) = me.inner.as_mut().as_pin_mut() {
            let result = match running.project().fut.poll(cx) {
                Poll::Ready(result) => result,
                // The task is still running; the stop signal is not consulted
                // until it returns.
                Poll::Pending => return Poll::Pending,
            };
            me.inner.set(None);
            if let Err(err) = result {
                return Poll::Ready(Err(err));
            }
        }

        // The task has finished (now or on an earlier poll).
        me.stopper.poll(cx).map(|forwarder| {
            forwarder.propagate();
            Ok(())
        })
    }
}
/// Type-erased background task stored by `AssemblyRuntime::set_task`.
trait Task: Send {
    /// Turns the stored task into the future the assembly drives. The wrapper
    /// is `AutoStopTask` when `auto_stop` is set, otherwise `SelfStopTask`.
    fn into_task(
        self: Box<Self>,
        stopper: ShutdownSignalParticipant,
        keepalive: TaskRunningSentinel,
        auto_stop: bool,
    ) -> ResourceFut;
}

/// Holds the caller's `IntoFuture` value until the assembly starts it.
struct TaskImpl<T>(T);

impl<T> Task for TaskImpl<T>
where
    T: IntoFuture<Output = Result<(), Box<dyn Error>>> + Send,
    T::IntoFuture: Send + 'static,
{
    fn into_task(
        self: Box<Self>,
        stopper: ShutdownSignalParticipant,
        keepalive: TaskRunningSentinel,
        auto_stop: bool,
    ) -> ResourceFut {
        let Self(work) = *self;
        if !auto_stop {
            return Box::pin(SelfStopTask::new(work, stopper, keepalive));
        }
        Box::pin(AutoStopTask::new(work, stopper, keepalive))
    }
}
// Private module: `ResourceProduction` can be named as the `Production`
// associated type below, but its fields are only reachable from the parent
// module (`pub(super)`).
mod private {
    // Everything `make` captured that is needed later to hand out the shared
    // instance and to build the resource's task future.
    // NOTE(review): field order is also drop order; keep in mind if a
    // production is ever dropped without calling `task`.
    pub struct ResourceProduction<T> {
        /// The constructed resource instance.
        pub(super) shared: std::sync::Arc<T>,
        /// Background task registered through `set_task`, if any.
        pub(super) task: Option<Box<dyn super::Task>>,
        /// Shutdown participant handed to the task wrapper (or awaited
        /// directly when there is no task).
        pub(super) stopper: super::ShutdownSignalParticipant,
        /// Sentinel moved into the task wrapper.
        pub(super) keepalive: super::TaskRunningSentinel,
        /// `true` when the resource did not call `self_stop`.
        pub(super) auto_stop: bool,
    }
}
// Adapts every public v1 `Resource` to the crate-internal `ResourceBase`
// machinery used by the assembly.
impl<T: Resource> ResourceBase<{ crate::ResourceVariety::V1 as usize }> for T {
    const NAME: &str = T::NAME;
    type Production = private::ResourceProduction<T>;

    // Registers this resource's dependencies with the assembly.
    fn register_recursive(cx: &mut RegisterContext<'_>) {
        T::Dependencies::register(cx);
    }

    // Adds this resource's command-line arguments to the command definition.
    fn augment_args(c: clap::Command) -> clap::Command {
        T::Args::augment_args(c)
    }

    // First call of `provide_as_trait`: records which traits the resource can
    // be provided as.
    fn register_as_traits(cx: TraitRegisterContext<'_>) {
        let mut installer = TraitInstaller::Register(cx);
        T::provide_as_trait(&mut installer);
    }

    // Produces dependencies, parses arguments, constructs the resource,
    // publishes its trait objects, and captures the runtime configuration
    // (task and stop mode) for `task` below.
    fn make(
        cx: &mut ProduceContext<'_>,
        arg_matches: &mut clap::ArgMatches,
        mut stoppers: ShutdownSignalParticipantCreator,
        keepalive: TaskRunningSentinel,
    ) -> Result<Self::Production, Box<dyn Error>> {
        let deps = T::Dependencies::produce(cx)?;
        let args = T::Args::from_arg_matches(arg_matches)?;
        // `api` mutably borrows `stoppers`; that borrow ends at the last use
        // of `api` (the `auto_stop` initialiser below).
        let mut api = AssemblyRuntime {
            stoppers: Some(&mut stoppers),
            task: None,
        };
        let shared = T::new(deps, args, &mut api).map_err(Into::into)?;
        // Second call of `provide_as_trait`, now able to build trait objects.
        let mut installer = TraitInstaller::Produce(TraitInstallerProduce {
            cx,
            shared: &shared,
            resource: &keepalive,
        });
        T::provide_as_trait(&mut installer);
        // Struct field initialisers are evaluated in the order written, so
        // `api` is fully used before `stoppers.into_inner()` runs.
        Ok(private::ResourceProduction {
            shared,
            task: api.task,
            // `self_stop` takes `stoppers`; still present => auto-stop mode.
            auto_stop: api.stoppers.is_some(),
            stopper: stoppers.into_inner().unwrap(),
            keepalive,
        })
    }

    // Hands out another reference to the constructed resource.
    fn shared(p: &Self::Production) -> Arc<T> {
        Arc::clone(&p.shared)
    }

    // Consumes the production and returns the future the assembly drives for
    // this resource.
    fn task(
        p: Self::Production,
    ) -> Pin<Box<dyn Future<Output = Result<(), Box<dyn Error>>> + Send>> {
        match p.task {
            Some(t) => t.into_task(p.stopper, p.keepalive, p.auto_stop),
            // No task: wait for the stop request and forward it.
            // NOTE(review): whether this async block captures all of `p`
            // (keeping `keepalive` alive until stop) or only `p.stopper`
            // depends on async-block capture rules — confirm the intended
            // lifetime of the sentinel here.
            None => Box::pin(async move {
                p.stopper.await.propagate();
                Ok(())
            }),
        }
    }
}
// Every v1 `Resource` is usable wherever the assembly accepts a resource,
// reporting the name the resource itself declares.
impl<T> crate::AnyResource<{ crate::ResourceVariety::V1 as usize }> for T
where
    T: Resource,
{
    const NAME: &str = <T as Resource>::NAME;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutil::TestExecutor;
    use crate::{Assembly, NoArgs, NoDependencies};
    use atomic_take::AtomicTake;
    use futures::TryFutureExt;
    use std::pin::pin;
    use std::sync::atomic::{AtomicBool, Ordering};
    use try_lock::TryLock;

    // Empty argument vector for `Assembly::new_from_argv`.
    const EMPTY: &[std::ffi::OsString] = &[];

    // Resource whose task fails immediately with the message "no good".
    struct Fails;
    #[resource]
    impl Resource for Fails {
        fn new(
            _: NoDependencies,
            _: NoArgs,
            api: &mut AssemblyRuntime<'_>,
        ) -> Result<Arc<Self>, std::convert::Infallible> {
            api.set_task(std::future::ready(Err("no good")).err_into());
            Ok(Arc::new(Self))
        }
    }

    #[derive(ResourceDependencies)]
    struct FailDependencies {
        _f: Arc<Fails>,
    }

    // A failing resource task makes the whole assembly resolve to that error,
    // even though no termination signal is ever sent.
    #[test]
    fn assembly_fails() {
        let mut r = pin!(
            Assembly::<FailDependencies>::new_from_argv(EMPTY)
                .unwrap()
                .run_with_termination_signal(futures::stream::pending())
        );
        let mut e = TestExecutor::default();
        match e.poll(&mut r) {
            Poll::Ready(Err(e)) => {
                assert_eq!(e.to_string(), "no good");
            }
            other => {
                panic!("assembly await result: want error, got {:?}", other);
            }
        }
    }

    // Leaf resource using `self_stop`: its flag is set once its task has
    // observed the stop signal.
    struct QuitMonitor(AtomicBool);
    #[resource]
    impl Resource for QuitMonitor {
        fn new(
            _: NoDependencies,
            _: NoArgs,
            api: &mut AssemblyRuntime<'_>,
        ) -> Result<Arc<Self>, std::convert::Infallible> {
            let shared = Arc::new(Self(AtomicBool::default()));
            let sentinel = Arc::clone(&shared);
            let stop = api.self_stop();
            api.set_task(async move {
                stop.await;
                sentinel.0.store(true, Ordering::Release);
                Ok(())
            });
            Ok(shared)
        }
    }

    // Depends on `QuitMonitor`. With `--skip-task` it registers no task;
    // otherwise it registers a task that never completes and relies on
    // automatic stopping (it never calls `self_stop`).
    struct TestAutoStop {
        skip_task: bool,
        leaf: Arc<QuitMonitor>,
    }

    #[derive(ResourceDependencies)]
    struct TestAutoStopDependencies(Arc<QuitMonitor>);

    #[derive(clap::Args)]
    #[group(skip)]
    struct TestAutoStopArgs {
        #[arg(long)]
        skip_task: bool,
    }

    #[resource]
    impl Resource for TestAutoStop {
        fn new(
            d: TestAutoStopDependencies,
            a: TestAutoStopArgs,
            api: &mut AssemblyRuntime<'_>,
        ) -> Result<Arc<Self>, std::convert::Infallible> {
            if !a.skip_task {
                api.set_task(std::future::pending());
            }
            Ok(Arc::new(Self {
                leaf: d.0,
                skip_task: a.skip_task,
            }))
        }
    }

    #[derive(ResourceDependencies)]
    struct TestAutoStopTopDependencies(Arc<TestAutoStop>);

    // A resource with no task still forwards the termination signal to its
    // dependency: `QuitMonitor`'s flag flips and the assembly completes.
    #[test]
    fn no_task() {
        let argv: Vec<std::ffi::OsString> = vec!["cmd".into(), "--skip-task".into()];
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        let assembly = Assembly::<TestAutoStopTopDependencies>::new_from_argv(argv).unwrap();
        let tas = Arc::clone(&assembly.top.0);
        assert!(tas.skip_task);
        let mut r = pin!(
            assembly.run_with_termination_signal(tokio_stream::wrappers::ReceiverStream::new(rx))
        );
        let mut e = TestExecutor::default();
        assert!(e.poll(&mut r).is_pending());
        assert!(!tas.leaf.0.load(Ordering::Acquire));
        let _ = tx.try_send(()).unwrap();
        assert!(e.poll(&mut r).is_ready());
        assert!(tas.leaf.0.load(Ordering::Acquire));
    }

    // A never-ending task without `self_stop` is stopped automatically on
    // termination, and the signal still reaches the dependency.
    #[test]
    fn auto_stop() {
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        let assembly = Assembly::<TestAutoStopTopDependencies>::new_from_argv(EMPTY).unwrap();
        let tas = Arc::clone(&assembly.top.0);
        assert!(!tas.skip_task);
        let mut r = pin!(
            assembly.run_with_termination_signal(tokio_stream::wrappers::ReceiverStream::new(rx))
        );
        let mut e = TestExecutor::default();
        assert!(e.poll(&mut r).is_pending());
        assert!(!tas.leaf.0.load(Ordering::Acquire));
        let _ = tx.try_send(()).unwrap();
        assert!(e.poll(&mut r).is_ready());
        assert!(tas.leaf.0.load(Ordering::Acquire));
    }

    // Self-stopping resource: after the stop signal, its task publishes a
    // oneshot sender in `quit_requested` and waits until that sender is
    // dropped before returning.
    struct TestSelfStop {
        quit_requested: TryLock<Option<tokio::sync::oneshot::Sender<()>>>,
        leaf: Arc<QuitMonitor>,
    }

    #[derive(ResourceDependencies)]
    struct TestSelfStopDependencies(Arc<QuitMonitor>);

    #[resource]
    impl Resource for TestSelfStop {
        fn new(
            d: TestSelfStopDependencies,
            _: NoArgs,
            api: &mut AssemblyRuntime<'_>,
        ) -> Result<Arc<Self>, std::convert::Infallible> {
            let shared = Arc::new(Self {
                quit_requested: TryLock::new(None),
                leaf: d.0,
            });
            let stop = api.self_stop();
            let shared2 = Arc::clone(&shared);
            api.set_task(async move {
                stop.await;
                let (tx, rx) = tokio::sync::oneshot::channel();
                *shared2.quit_requested.try_lock().unwrap() = Some(tx);
                let _ = rx.await;
                Ok(())
            });
            Ok(shared)
        }
    }

    #[derive(ResourceDependencies)]
    struct TestSelfStopTopDependencies(Arc<TestSelfStop>);

    // A self-stopping resource delays shutdown until its task returns, and
    // its dependency is only stopped afterwards.
    #[test]
    fn self_stop() {
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        let assembly = Assembly::<TestSelfStopTopDependencies>::new_from_argv(EMPTY).unwrap();
        let tss = Arc::clone(&assembly.top.0);
        let mut r = pin!(
            assembly.run_with_termination_signal(tokio_stream::wrappers::ReceiverStream::new(rx))
        );
        let mut e = TestExecutor::default();
        assert!(e.poll(&mut r).is_pending());
        assert!(tss.quit_requested.try_lock().unwrap().is_none());
        assert!(!tss.leaf.0.load(Ordering::Acquire));
        let _ = tx.try_send(()).unwrap();
        // The task has seen the stop signal but is still running.
        assert!(e.poll(&mut r).is_pending());
        let next_step = tss.quit_requested.try_lock().unwrap().take().unwrap();
        assert!(!tss.leaf.0.load(Ordering::Acquire));
        // Dropping the sender lets the task finish.
        std::mem::drop(next_step);
        assert!(e.poll(&mut r).is_ready());
        assert!(tss.leaf.0.load(Ordering::Acquire));
    }

    // Auto-stop resource whose task ends once the stored oneshot sender fires.
    struct RunUntilSignaled(AtomicTake<tokio::sync::oneshot::Sender<()>>);
    #[resource]
    impl Resource for RunUntilSignaled {
        fn new(
            _: NoDependencies,
            _: NoArgs,
            api: &mut AssemblyRuntime<'_>,
        ) -> Result<Arc<Self>, std::convert::Infallible> {
            let (tx, rx) = tokio::sync::oneshot::channel();
            api.set_task(async move {
                let _ = rx.await;
                Ok(())
            });
            Ok(Arc::new(Self(AtomicTake::new(tx))))
        }
    }

    #[derive(ResourceDependencies)]
    struct RunUntilSignaledTop(Arc<RunUntilSignaled>);

    // When a resource's task finishes on its own, the assembly completes
    // without any termination signal.
    #[test]
    fn runs_until_resource_quits() {
        let assembly = Assembly::<RunUntilSignaledTop>::new_from_argv(EMPTY).unwrap();
        let notify = assembly.top.0.0.take().unwrap();
        let mut r = pin!(assembly.run_with_termination_signal(futures::stream::pending()));
        let mut e = TestExecutor::default();
        assert!(e.poll(&mut r).is_pending());
        let _ = notify.send(());
        assert!(e.poll(&mut r).is_ready());
    }

    // Discards its `StopSignal` and runs a task that never finishes.
    struct RunStubbornly;
    #[resource]
    impl Resource for RunStubbornly {
        fn new(
            _: NoDependencies,
            _: NoArgs,
            api: &mut AssemblyRuntime<'_>,
        ) -> Result<Arc<Self>, std::convert::Infallible> {
            let _ = api.self_stop();
            api.set_task(std::future::pending());
            Ok(Arc::new(Self))
        }
    }

    #[derive(ResourceDependencies)]
    struct RunStubbornlyTop(#[allow(dead_code)] Arc<RunStubbornly>);

    // One termination signal is not enough for a stubborn resource; the
    // second one completes the assembly.
    #[test]
    fn needs_2_sigterms() {
        let assembly = Assembly::<RunStubbornlyTop>::new_from_argv(EMPTY).unwrap();
        let (tx, rx) = tokio::sync::mpsc::channel(2);
        let mut r = pin!(
            assembly.run_with_termination_signal(tokio_stream::wrappers::ReceiverStream::new(rx))
        );
        let mut e = TestExecutor::default();
        assert!(e.poll(&mut r).is_pending());
        let _ = tx.try_send(()).unwrap();
        assert!(e.poll(&mut r).is_pending());
        let _ = tx.try_send(()).unwrap();
        assert!(e.poll(&mut r).is_ready());
    }

    trait TestTrait1: Send + Sync {}
    trait TestTrait2: Send + Sync {}

    // Collects every provider of the two test traits.
    #[derive(ResourceDependencies)]
    struct RequiresDynDependencies(Vec<Arc<dyn TestTrait1>>, Vec<Arc<dyn TestTrait2>>);
    struct RequiresDyn(Vec<Arc<dyn TestTrait1>>, Vec<Arc<dyn TestTrait2>>);
    #[resource]
    impl Resource for RequiresDyn {
        fn new(
            d: RequiresDynDependencies,
            _: NoArgs,
            _: &mut AssemblyRuntime<'_>,
        ) -> Result<Arc<Self>, std::convert::Infallible> {
            Ok(Arc::new(Self(d.0, d.1)))
        }
    }

    // Exports itself as both `dyn TestTrait1` and `dyn TestTrait2`.
    struct ProvidesDyn;
    impl TestTrait1 for ProvidesDyn {}
    impl TestTrait2 for ProvidesDyn {}
    #[resource]
    #[export(dyn TestTrait1)]
    #[export(dyn TestTrait2)]
    impl Resource for ProvidesDyn {
        fn new(
            _: NoDependencies,
            _: NoArgs,
            _: &mut AssemblyRuntime<'_>,
        ) -> Result<Arc<Self>, std::convert::Infallible> {
            Ok(Arc::new(Self))
        }
    }

    #[derive(ResourceDependencies)]
    struct RequiresDynTop(Arc<RequiresDyn>, Arc<ProvidesDyn>);

    // `RequiresDyn` receives exactly one provider for each trait.
    #[test]
    fn dyn_resource() {
        let assembly = Assembly::<RequiresDynTop>::new_from_argv(EMPTY).unwrap();
        assert_eq!(assembly.top.0.0.len(), 1);
        assert_eq!(assembly.top.0.1.len(), 1);
        let _ = Arc::clone(&assembly.top.1);
    }
}