mod nosvc;
mod rttype;
mod signals;
#[cfg(all(target_os = "linux", feature = "systemd"))]
#[cfg_attr(docsrs, doc(cfg(feature = "systemd")))]
mod systemd;
#[cfg(windows)]
pub mod winsvc;
use std::{
any::{Any, TypeId},
sync::{
Arc, OnceLock,
atomic::{AtomicU32, Ordering}
}
};
use hashbrown::HashMap;
#[cfg(any(feature = "tokio", feature = "rocket"))]
use async_trait::async_trait;
#[cfg(feature = "tokio")]
use tokio::runtime;
use tokio::sync::broadcast;
#[cfg(all(target_os = "linux", feature = "systemd"))]
use sd_notify::NotifyState;
use crate::{err::CbErr, lumberjack::LumberJack};
/// How the current process is being run.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RunAs {
  /// Running as a regular foreground process.
  Foreground,
  /// Running under the platform's service subsystem.
  SvcSubsys
}

/// Process-wide record of the run mode; it is set at most once by
/// [`RunCtx::run`].
static RUNAS: OnceLock<RunAs> = OnceLock::new();

/// Return the run mode recorded for this process, or `None` if it has not
/// been determined yet.
pub fn runas() -> Option<RunAs> {
  let mode = RUNAS.get()?;
  Some(*mode)
}
/// Environment the application callbacks are executing in.
#[derive(Debug, Clone)]
pub enum RunEnv {
  /// Running as a foreground process.
  Foreground,
  /// Running as a service; carries the service name when one is known.
  Service(Option<String>)
}
/// Backend hook that informs the host environment (service manager or
/// nothing, in the foreground case) about lifecycle transitions.
pub(crate) trait StateReporter {
  /// Startup progress: `checkpoint` is a sequence number, `msg` a
  /// human-readable description.
  fn starting(&self, checkpoint: u32, msg: &str);

  /// Startup has completed and the application is running.
  fn started(&self);

  /// Shutdown progress: `checkpoint` is a sequence number, `msg` a
  /// human-readable description.
  fn stopping(&self, checkpoint: u32, msg: &str);

  /// Shutdown has completed.
  fn stopped(&self);
}
/// Cheaply clonable, type-erased handle to a [`StateReporter`] backend.
///
/// Every transition is first forwarded to the wrapped backend and then
/// echoed to the `log` facade at debug level.
#[derive(Clone)]
#[repr(transparent)]
pub(crate) struct ServiceReporter(Arc<dyn StateReporter + Send + Sync>);

impl ServiceReporter {
  /// Wrap a concrete state reporter backend.
  pub(crate) fn new<SR>(sr: SR) -> Self
  where
    SR: StateReporter + Send + Sync + 'static
  {
    Self(Arc::new(sr))
  }

  /// Build `"<phase>[<checkpoint>]"`, followed by `" <msg>"` when a status
  /// message was supplied.
  fn progress_text(phase: &str, checkpoint: u32, status: Option<&str>) -> String {
    status.map_or_else(
      || format!("{phase}[{checkpoint}]"),
      |msg| format!("{phase}[{checkpoint}] {msg}")
    )
  }

  /// Report startup progress checkpoint `checkpoint` with an optional
  /// status message.
  pub(crate) fn starting(&self, checkpoint: u32, status: Option<&str>) {
    let text = Self::progress_text("Starting", checkpoint, status);
    self.0.starting(checkpoint, &text);
    log::debug!("{text}");
  }

  /// Report that initialization finished and the service is now running.
  pub(crate) fn started(&self) {
    self.0.started();
    log::debug!(
      "Service initialization has finished and is entering running state"
    );
  }

  /// Report shutdown progress checkpoint `checkpoint` with an optional
  /// status message.
  pub(crate) fn stopping(&self, checkpoint: u32, status: Option<&str>) {
    let text = Self::progress_text("Stopping", checkpoint, status);
    self.0.stopping(checkpoint, &text);
    log::debug!("{text}");
  }

  /// Report that termination has finished.
  pub(crate) fn stopped(&self) {
    self.0.stopped();
    log::debug!("Service termination has finished");
  }
}
/// Context handed to the application's initialization callback.
///
/// It lets the application report startup progress, access type-keyed
/// passthrough values, and queue values for the termination phase. When it
/// is dropped, it reports one final "Initialization phase finished"
/// checkpoint.
pub struct InitCtx {
  /// Environment the application is running in.
  re: RunEnv,
  /// Backend used to publish startup progress.
  sr: ServiceReporter,
  /// Checkpoint sequence counter. It is reference counted, presumably so it
  /// can be shared with another phase context (TODO: confirm at the
  /// construction site).
  cnt: Arc<AtomicU32>,
  /// Type-keyed values readable through `get`/`take`.
  passthrough: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
  /// Type-keyed values queued through `term_passthrough`.
  passthrough_term: HashMap<TypeId, Box<dyn Any + Send + Sync>>
}

impl InitCtx {
  /// Return a copy of the run environment.
  #[must_use]
  pub fn runenv(&self) -> RunEnv {
    self.re.clone()
  }

  /// Advance the init checkpoint counter and report the new checkpoint,
  /// optionally with a status message.
  pub fn report(&self, status: Option<&str>) {
    let checkpoint = self.cnt.fetch_add(1, Ordering::SeqCst);
    match status {
      Some(msg) => {
        tracing::trace!("Reached init checkpoint {checkpoint}; {msg}");
      }
      None => {
        tracing::trace!("Reached init checkpoint {checkpoint}");
      }
    }
    self.sr.starting(checkpoint, status);
  }

  /// Borrow the passthrough value of type `T`, if one is present.
  #[must_use]
  pub fn get<T>(&self) -> Option<&T>
  where
    T: Send + 'static
  {
    let key = TypeId::of::<T>();
    self.passthrough.get(&key)?.downcast_ref::<T>()
  }

  /// Remove and return the passthrough value of type `T`, if one is
  /// present.
  pub fn take<T>(&mut self) -> Option<T>
  where
    T: Send + 'static
  {
    let key = TypeId::of::<T>();
    let boxed = self.passthrough.remove(&key)?;
    boxed.downcast::<T>().ok().map(|v| *v)
  }

  /// Queue `value`, keyed by its type, in the termination passthrough map.
  /// A later call with the same `T` replaces the earlier value.
  pub fn term_passthrough<T>(&mut self, value: T)
  where
    T: Send + Sync + 'static
  {
    let key = TypeId::of::<T>();
    self.passthrough_term.insert(key, Box::new(value));
  }
}

impl Drop for InitCtx {
  /// Emit one last startup checkpoint marking the end of initialization.
  fn drop(&mut self) {
    let next = self.cnt.fetch_add(1, Ordering::SeqCst);
    let note = Some("Initialization phase finished");
    self.sr.starting(next, note);
  }
}
/// Context handed to the application's shutdown callback.
///
/// It lets the application report shutdown progress and access type-keyed
/// passthrough values. When it is dropped, it reports one final
/// "Termination phase finished" checkpoint.
pub struct TermCtx {
  /// Environment the application is running in.
  re: RunEnv,
  /// Backend used to publish shutdown progress.
  sr: ServiceReporter,
  /// Checkpoint sequence counter. It is reference counted, presumably so it
  /// can be shared with another phase context (TODO: confirm at the
  /// construction site).
  cnt: Arc<AtomicU32>,
  /// Type-keyed values readable through `get`/`take`.
  passthrough: HashMap<TypeId, Box<dyn Any + Send + Sync>>
}

impl TermCtx {
  /// Return a copy of the run environment.
  #[must_use]
  pub fn runenv(&self) -> RunEnv {
    self.re.clone()
  }

  /// Advance the checkpoint counter and report the new shutdown
  /// checkpoint, optionally with a status message.
  pub fn report(&self, status: Option<&str>) {
    let checkpoint = self.cnt.fetch_add(1, Ordering::SeqCst);
    match status {
      Some(msg) => {
        tracing::trace!("Reached term checkpoint {checkpoint}; {msg}");
      }
      None => {
        tracing::trace!("Reached term checkpoint {checkpoint}");
      }
    }
    self.sr.stopping(checkpoint, status);
  }

  /// Borrow the passthrough value of type `T`, if one is present.
  #[must_use]
  pub fn get<T: 'static>(&self) -> Option<&T> {
    let key = TypeId::of::<T>();
    self.passthrough.get(&key)?.downcast_ref::<T>()
  }

  /// Remove and return the passthrough value of type `T`, if one is
  /// present.
  pub fn take<T: 'static>(&mut self) -> Option<T> {
    let key = TypeId::of::<T>();
    let boxed = self.passthrough.remove(&key)?;
    boxed.downcast::<T>().ok().map(|v| *v)
  }
}

impl Drop for TermCtx {
  /// Emit one last shutdown checkpoint marking the end of termination.
  fn drop(&mut self) {
    let next = self.cnt.fetch_add(1, Ordering::SeqCst);
    let note = Some("Termination phase finished");
    self.sr.stopping(next, note);
  }
}
/// Application callbacks for the synchronous (non-async) runtime.
///
/// The three phases are presumably invoked in the order `init`, `run`,
/// `shutdown` by the runtime driver (TODO: confirm in `rttype`).
pub trait ServiceHandler {
  /// Application-defined error type returned by the callbacks.
  type AppErr;

  /// Initialization phase; progress can be reported through `ictx`.
  fn init(&mut self, ictx: &mut InitCtx) -> Result<(), Self::AppErr>;

  /// Main body of the application.
  fn run(&mut self, re: &RunEnv) -> Result<(), Self::AppErr>;

  /// Termination phase; progress can be reported through `tctx`.
  fn shutdown(&mut self, tctx: &mut TermCtx) -> Result<(), Self::AppErr>;
}
/// Application callbacks for the tokio-based runtime.
///
/// This is the async counterpart of [`ServiceHandler`]; the methods are made
/// object-safe through `async_trait`.
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
#[async_trait]
pub trait TokioServiceHandler {
  /// Application-defined error type returned by the callbacks.
  type AppErr;

  /// Initialization phase; progress can be reported through `ictx`.
  async fn init(&mut self, ictx: &mut InitCtx) -> Result<(), Self::AppErr>;

  /// Main body of the application.
  async fn run(&mut self, re: &RunEnv) -> Result<(), Self::AppErr>;

  /// Termination phase; progress can be reported through `tctx`.
  async fn shutdown(
    &mut self,
    tctx: &mut TermCtx
  ) -> Result<(), Self::AppErr>;
}
/// Application callbacks for the Rocket-based runtime.
///
/// `init` builds the Rocket instances. `run` then receives them in ignited
/// form (presumably ignited by the runtime driver; TODO: confirm in `rttype`).
#[cfg(feature = "rocket")]
#[cfg_attr(docsrs, doc(cfg(feature = "rocket")))]
#[async_trait]
pub trait RocketServiceHandler {
  /// Application-defined error type returned by the callbacks.
  type AppErr;

  /// Initialization phase; returns the Rocket instances to launch.
  async fn init(
    &mut self,
    ictx: &mut InitCtx
  ) -> Result<Vec<rocket::Rocket<rocket::Build>>, Self::AppErr>;

  /// Main body of the application, given the ignited Rocket instances.
  async fn run(
    &mut self,
    rockets: Vec<rocket::Rocket<rocket::Ignite>>,
    re: &RunEnv
  ) -> Result<(), Self::AppErr>;

  /// Termination phase; progress can be reported through `tctx`.
  async fn shutdown(
    &mut self,
    tctx: &mut TermCtx
  ) -> Result<(), Self::AppErr>;
}
/// Reason carried by a shutdown event.
///
/// The variant meanings below are inferred from their names (TODO: confirm
/// against the signal/service backends that emit them).
#[derive(Copy, Clone, Debug)]
pub enum Demise {
  /// Presumably an interrupt request (e.g. Ctrl+C).
  Interrupted,
  /// Presumably a termination request from the OS or service manager.
  Terminated,
  /// Presumably the application's run phase completed on its own.
  ReachedEnd
}

/// Application-defined user signals.
#[derive(Copy, Clone, Debug)]
pub enum UserSig {
  /// First user signal.
  Sig1,
  /// Second user signal.
  Sig2
}

/// Events delivered to the application's service event handler.
///
/// A `Shutdown` event ends the event-dispatch loop once it has been handled.
#[derive(Copy, Clone, Debug)]
pub enum SvcEvt {
  /// A user signal.
  User(UserSig),
  /// Request to pause.
  Pause,
  /// Request to resume after a pause.
  Resume,
  /// Request to reload configuration.
  ReloadConf,
  /// Request to shut down, with its cause.
  Shutdown(Demise)
}
/// Selects the runtime flavor and carries the application's handlers.
///
/// Every variant pairs a service-event callback (`svcevt_handler`) with the
/// application's lifecycle handler (`rt_handler`).
// `large_enum_variant`: the `Tokio` variant stores a `runtime::Builder`
// inline. This value is built once per process, so boxing buys nothing.
// `module_name_repetitions`: the type name presumably repeats this module's
// name; it is kept for API stability.
#[allow(clippy::large_enum_variant, clippy::module_name_repetitions)]
pub enum SrvAppRt<ApEr> {
  /// Plain synchronous runtime.
  Sync {
    svcevt_handler: Box<dyn FnMut(SvcEvt) + Send>,
    rt_handler: Box<dyn ServiceHandler<AppErr = ApEr> + Send>
  },

  /// Tokio runtime, optionally built from a caller-supplied builder.
  #[cfg(feature = "tokio")]
  #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
  Tokio {
    rtbldr: Option<runtime::Builder>,
    svcevt_handler: Box<dyn FnMut(SvcEvt) + Send>,
    rt_handler: Box<dyn TokioServiceHandler<AppErr = ApEr> + Send>
  },

  /// Rocket-based runtime.
  #[cfg(feature = "rocket")]
  #[cfg_attr(docsrs, doc(cfg(feature = "rocket")))]
  Rocket {
    svcevt_handler: Box<dyn FnMut(SvcEvt) + Send>,
    rt_handler: Box<dyn RocketServiceHandler<AppErr = ApEr> + Send>
  }
}
/// Builder and entry point for running an application in the foreground
/// or as a platform service.
pub struct RunCtx {
  /// Run under the platform service subsystem rather than the foreground.
  service: bool,
  /// Service/application name.
  svcname: String,
  /// Whether logging should be initialized before running.
  log_init: bool,
  /// Type-keyed values forwarded to the initialization phase.
  passthrough_init: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
  /// Type-keyed values forwarded to the termination phase.
  passthrough_term: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
  /// Internal test switch; forwarded only to the synchronous runtime.
  test_mode: bool
}
impl RunCtx {
  /// Run the application under systemd.
  ///
  /// Initializes logging in service mode, then dispatches to the runtime
  /// driver selected by `st`. The driver receives a systemd-backed state
  /// reporter and a `RunEnv::Service` carrying the service name.
  #[cfg(all(target_os = "linux", feature = "systemd"))]
  fn systemd<ApEr>(self, st: SrvAppRt<ApEr>) -> Result<(), CbErr<ApEr>>
  where
    ApEr: Send + std::fmt::Debug
  {
    LumberJack::default()
      .set_init(self.log_init)
      .service()
      .init()?;

    tracing::debug!("Running service '{}'", self.svcname);

    let sr = ServiceReporter::new(systemd::ServiceReporter {});
    // The name is not needed past this point, so move it instead of cloning.
    let re = RunEnv::Service(Some(self.svcname));

    match st {
      SrvAppRt::Sync {
        svcevt_handler,
        rt_handler
      } => rttype::sync_main(rttype::SyncMainParams {
        re,
        svcevt_handler,
        rt_handler,
        sr,
        svcevt_ch: None,
        passthrough_init: self.passthrough_init,
        passthrough_term: self.passthrough_term,
        test_mode: self.test_mode
      }),
      // The `Tokio` variant only exists with the `tokio` feature. Without
      // this gate, building with `systemd` but without `tokio` fails.
      #[cfg(feature = "tokio")]
      SrvAppRt::Tokio {
        rtbldr,
        svcevt_handler,
        rt_handler
      } => rttype::tokio_main(
        rtbldr,
        rttype::TokioMainParams {
          re,
          svcevt_handler,
          rt_handler,
          sr,
          svcevt_ch: None,
          passthrough_init: self.passthrough_init,
          passthrough_term: self.passthrough_term
        }
      ),
      #[cfg(feature = "rocket")]
      SrvAppRt::Rocket {
        svcevt_handler,
        rt_handler
      } => rttype::rocket_main(rttype::RocketMainParams {
        re,
        svcevt_handler,
        rt_handler,
        sr,
        svcevt_ch: None,
        passthrough_init: self.passthrough_init,
        passthrough_term: self.passthrough_term
      })
    }
  }

  /// Run the application under the Windows service subsystem.
  ///
  /// Delegates to `winsvc::run` with the service name, the runtime
  /// selection and both passthrough maps.
  #[cfg(windows)]
  fn winsvc<ApEr>(self, st: SrvAppRt<ApEr>) -> Result<(), CbErr<ApEr>>
  where
    ApEr: Send + 'static + std::fmt::Debug
  {
    winsvc::run(
      &self.svcname,
      st,
      self.passthrough_init,
      self.passthrough_term
    )?;
    Ok(())
  }

  /// Run the application as a regular foreground process.
  ///
  /// Initializes logging (not in service mode), then dispatches to the
  /// runtime driver selected by `st`. The driver receives a no-op service
  /// reporter and `RunEnv::Foreground`.
  fn foreground<ApEr>(self, st: SrvAppRt<ApEr>) -> Result<(), CbErr<ApEr>>
  where
    ApEr: Send + std::fmt::Debug
  {
    LumberJack::default().set_init(self.log_init).init()?;

    tracing::debug!("Running service '{}'", self.svcname);

    let sr = ServiceReporter::new(nosvc::ServiceReporter {});

    match st {
      SrvAppRt::Sync {
        svcevt_handler,
        rt_handler
      } => rttype::sync_main(rttype::SyncMainParams {
        re: RunEnv::Foreground,
        svcevt_handler,
        rt_handler,
        sr,
        svcevt_ch: None,
        passthrough_init: self.passthrough_init,
        passthrough_term: self.passthrough_term,
        test_mode: self.test_mode
      }),
      #[cfg(feature = "tokio")]
      SrvAppRt::Tokio {
        rtbldr,
        svcevt_handler,
        rt_handler
      } => rttype::tokio_main(
        rtbldr,
        rttype::TokioMainParams {
          re: RunEnv::Foreground,
          svcevt_handler,
          rt_handler,
          sr,
          svcevt_ch: None,
          passthrough_init: self.passthrough_init,
          passthrough_term: self.passthrough_term
        }
      ),
      #[cfg(feature = "rocket")]
      SrvAppRt::Rocket {
        svcevt_handler,
        rt_handler
      } => rttype::rocket_main(rttype::RocketMainParams {
        re: RunEnv::Foreground,
        svcevt_handler,
        rt_handler,
        sr,
        svcevt_ch: None,
        passthrough_init: self.passthrough_init,
        passthrough_term: self.passthrough_term
      })
    }
  }
}
impl RunCtx {
#[must_use]
pub fn new(name: &str) -> Self {
Self {
service: false,
svcname: name.into(),
log_init: true,
passthrough_init: HashMap::new(),
passthrough_term: HashMap::new(),
test_mode: false
}
}
#[must_use]
pub fn init_passthrough<T>(mut self, data: T) -> Self
where
T: Send + Sync + 'static
{
self.init_passthrough_r(data);
self
}
pub fn init_passthrough_r<T>(&mut self, data: T) -> &mut Self
where
T: Send + Sync + 'static
{
self
.passthrough_init
.insert(TypeId::of::<T>(), Box::new(data));
self
}
#[must_use]
pub fn term_passthrough<T>(mut self, data: T) -> Self
where
T: Send + Sync + 'static
{
self.term_passthrough_r(data);
self
}
pub fn term_passthrough_r<T>(&mut self, data: T) -> &mut Self
where
T: Send + Sync + 'static
{
self
.passthrough_term
.insert(TypeId::of::<T>(), Box::new(data));
self
}
#[doc(hidden)]
#[must_use]
pub const fn test_mode(mut self) -> Self {
self.log_init = false;
self.test_mode = true;
self
}
#[doc(hidden)]
#[must_use]
pub const fn log_init(mut self, flag: bool) -> Self {
self.log_init = flag;
self
}
#[doc(hidden)]
pub const fn log_init_ref(&mut self, flag: bool) -> &mut Self {
self.log_init = flag;
self
}
#[must_use]
pub const fn service(mut self) -> Self {
self.service = true;
self
}
pub const fn service_ref(&mut self) -> &mut Self {
self.service = true;
self
}
#[must_use]
pub const fn is_service(&self) -> bool {
self.service
}
pub fn run<ApEr>(self, st: SrvAppRt<ApEr>) -> Result<(), CbErr<ApEr>>
where
ApEr: Send + 'static + std::fmt::Debug
{
if self.service {
let _ = RUNAS.set(RunAs::SvcSubsys);
#[cfg(all(target_os = "linux", feature = "systemd"))]
self.systemd(st)?;
#[cfg(windows)]
self.winsvc(st)?;
} else {
let _ = RUNAS.set(RunAs::Foreground);
self.foreground(st)?;
}
Ok(())
}
#[allow(clippy::missing_errors_doc)]
pub fn run_sync<ApEr>(
self,
svcevt_handler: Box<dyn FnMut(SvcEvt) + Send>,
rt_handler: Box<dyn ServiceHandler<AppErr = ApEr> + Send>
) -> Result<(), CbErr<ApEr>>
where
ApEr: Send + 'static + std::fmt::Debug
{
self.run(SrvAppRt::Sync {
svcevt_handler,
rt_handler
})
}
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
#[allow(clippy::missing_errors_doc)]
pub fn run_tokio<ApEr>(
self,
rtbldr: Option<runtime::Builder>,
svcevt_handler: Box<dyn FnMut(SvcEvt) + Send>,
rt_handler: Box<dyn TokioServiceHandler<AppErr = ApEr> + Send>
) -> Result<(), CbErr<ApEr>>
where
ApEr: Send + 'static + std::fmt::Debug
{
self.run(SrvAppRt::Tokio {
rtbldr,
svcevt_handler,
rt_handler
})
}
#[cfg(feature = "rocket")]
#[cfg_attr(docsrs, doc(cfg(feature = "rocket")))]
#[allow(clippy::missing_errors_doc)]
pub fn run_rocket<ApEr>(
self,
svcevt_handler: Box<dyn FnMut(SvcEvt) + Send>,
rt_handler: Box<dyn RocketServiceHandler<AppErr = ApEr> + Send>
) -> Result<(), CbErr<ApEr>>
where
ApEr: Send + 'static + std::fmt::Debug
{
self.run(SrvAppRt::Rocket {
svcevt_handler,
rt_handler
})
}
}
fn svcevt_thread(
mut rx: broadcast::Receiver<SvcEvt>,
mut evt_handler: Box<dyn FnMut(SvcEvt) + Send>
) {
while let Ok(msg) = rx.blocking_recv() {
tracing::debug!("Received {:?}", msg);
#[cfg(all(target_os = "linux", feature = "systemd"))]
if matches!(msg, SvcEvt::ReloadConf) {
let ts =
nix::time::clock_gettime(nix::time::ClockId::CLOCK_MONOTONIC).unwrap();
let s = format!(
"RELOADING=1\nMONOTONIC_USEC={}{:06}",
ts.tv_sec(),
ts.tv_nsec() / 1000
);
tracing::trace!("Sending notification to systemd: {}", s);
let custom = NotifyState::Custom(&s);
if let Err(e) = sd_notify::notify(&[custom]) {
log::error!("Unable to send RELOADING=1 notification to systemd; {e}");
}
}
evt_handler(msg);
#[cfg(all(target_os = "linux", feature = "systemd"))]
if matches!(msg, SvcEvt::ReloadConf) {
tracing::trace!("Sending notification to systemd: READY=1");
if let Err(e) = sd_notify::notify(&[NotifyState::Ready]) {
log::error!("Unable to send READY=1 notification to systemd; {e}");
}
}
if let SvcEvt::Shutdown(_) = msg {
tracing::debug!("Terminating thread");
break;
}
}
}