#[cfg(feature = "systemd_notify")]
use std::time::Duration;
use std::{
future::Future,
io, mem,
net::SocketAddr,
ops::Deref,
path::{Path, PathBuf},
sync::Arc,
task::{Context, Poll},
thread,
};
#[cfg(feature = "systemd_sockets")]
pub(super) mod systemd_sockets;
#[cfg(feature = "systemd_sockets")]
use {
crate::{listener::TryFromRawFd, registry::SockInfo},
systemd_sockets::{SystemdSocketError, SystemdSockets, SystemdSocketsReadError},
};
#[cfg(target_os = "linux")]
#[cfg(feature = "systemd_sockets")]
use tokio_seqpacket::UnixSeqpacketListener;
use bytes::Bytes;
use futures::{stream::SplitStream, StreamExt};
use parking_lot::RwLock;
use socket2::Socket;
use tokio::{
net::{TcpListener, UdpSocket, UnixDatagram, UnixListener},
signal::unix::signal,
sync::oneshot::channel,
};
use tokio_stream::wrappers::{SignalStream, TcpListenerStream, UnixListenerStream};
use tokio_util::{codec::BytesCodec, udp::UdpFramed};
#[cfg(target_os = "linux")]
use crate::seqpacket::UnixSeqpacketListenerStream;
use super::{executioner::upgrade, Ecdysis, UpgradeFinished};
use supervisor::Supervisor;
#[cfg(feature = "systemd_notify")]
use systemd_notify::{SystemdNotifier, SystemdNotifierError};
use trigger::{Trigger, TriggerReason};
pub mod supervisor;
#[cfg(feature = "systemd_notify")]
pub(super) mod systemd_notify;
mod trigger;
/// Receiving half of a [`UdpFramed`] socket using [`BytesCodec`]: the stream of
/// `(Bytes, SocketAddr)` datagrams returned by the UDP stream helpers.
pub type UdpStream = SplitStream<UdpFramed<BytesCodec>>;
pub use supervisor::{StopOnShutdown, Stoppable, StoppableStream};
pub use tokio::signal::unix::SignalKind;
/// How the monitor future returned by [`TokioEcdysisBuilder::ready`] finished.
#[derive(Debug)]
pub enum ExitMode {
    /// An upgrade completed successfully and supervised listeners were stopped.
    Upgrade,
    /// A stop trigger fired: supervised listeners were stopped, triggers were cleaned up
    /// and, when enabled, systemd was told the service is stopping.
    FullStop,
    /// A partial-stop trigger fired: supervised listeners were stopped and triggers were
    /// cleaned up, but no "stopping" notification was sent.
    PartialStop,
    /// Same as [`ExitMode::PartialStop`], but the systemd notifier is handed back to the
    /// caller (presumably so it can send the remaining notifications itself — confirm).
    #[cfg(feature = "systemd_notify")]
    PartialStopWithSystemd(SystemdNotifier),
}
/// What caused the monitor future to finish.
#[derive(Clone, Debug)]
pub enum ExitReason {
    /// The given signal was received.
    Signal(SignalKind),
    /// A connection arrived on the trigger Unix socket bound at this path.
    UnixListener(PathBuf),
}
/// Action associated with a trigger; decides what the monitor does when it fires.
#[derive(Clone, Copy, Debug)]
pub(crate) enum ExitCondition {
    /// Spawn the upgraded process and hand over the registered sockets.
    Upgrade,
    /// Stop supervised listeners and shut down completely.
    Stop,
    /// Stop supervised listeners but leave final shutdown to the caller.
    PartialStop,
}
/// Builder for [`TokioEcdysis`].
///
/// Collects the triggers that end monitoring (signals and Unix sockets) and, optionally,
/// a systemd notifier. [`TokioEcdysisBuilder::ready`] then turns it into the shared
/// instance plus the monitor future.
pub struct TokioEcdysisBuilder {
    /// Instance being configured; listeners can be opened on it through `Deref`.
    tokio_ecdysis: TokioEcdysis,
    /// Every trigger paired with the exit condition it produces when it fires.
    triggers: Vec<(Trigger, ExitCondition)>,
    /// Set by `enable_systemd_notifications`; `None` means no systemd notifications.
    #[cfg(feature = "systemd_notify")]
    systemd_notifier: Option<SystemdNotifier>,
}
impl Deref for TokioEcdysisBuilder {
type Target = TokioEcdysis;
fn deref(&self) -> &Self::Target {
&self.tokio_ecdysis
}
}
impl TokioEcdysisBuilder {
    /// Starts a builder whose first trigger upgrades the process on `upgrade_signal_kind`.
    ///
    /// # Errors
    ///
    /// Fails when a handler for `upgrade_signal_kind` cannot be installed.
    pub fn new(upgrade_signal_kind: SignalKind) -> io::Result<Self> {
        // The signal handler is installed before `TokioEcdysis::new` runs, so a failure
        // here returns without constructing it.
        let upgrade_signals = SignalStream::new(signal(upgrade_signal_kind)?);
        let upgrade_trigger = (
            Trigger::Signal(upgrade_signal_kind, upgrade_signals),
            ExitCondition::Upgrade,
        );
        Ok(Self {
            tokio_ecdysis: TokioEcdysis::new(),
            triggers: vec![upgrade_trigger],
            #[cfg(feature = "systemd_notify")]
            systemd_notifier: None,
        })
    }

    /// Installs a handler for `signal_kind` whose firing ends monitoring with `trigger_action`.
    fn trigger_on_signal(
        &mut self,
        signal_kind: SignalKind,
        trigger_action: ExitCondition,
    ) -> io::Result<()> {
        let signals = SignalStream::new(signal(signal_kind)?);
        let entry = (Trigger::Signal(signal_kind, signals), trigger_action);
        self.triggers.push(entry);
        Ok(())
    }

    /// Performs a full stop ([`ExitMode::FullStop`]) when `signal_kind` is received.
    ///
    /// # Errors
    ///
    /// Fails when the signal handler cannot be installed.
    pub fn stop_on_signal(&mut self, signal_kind: SignalKind) -> io::Result<()> {
        let action = ExitCondition::Stop;
        self.trigger_on_signal(signal_kind, action)
    }

    /// Performs a partial stop ([`ExitMode::PartialStop`]) when `signal_kind` is received.
    ///
    /// # Errors
    ///
    /// Fails when the signal handler cannot be installed.
    pub fn partial_stop_on_signal(&mut self, signal_kind: SignalKind) -> io::Result<()> {
        let action = ExitCondition::PartialStop;
        self.trigger_on_signal(signal_kind, action)
    }

    /// Opens a Unix listener at `listen_path` (stopped on shutdown) and ends monitoring
    /// with `trigger_action` when a connection arrives on it.
    fn trigger_on_socket<P>(
        &mut self,
        listen_path: P,
        trigger_action: ExitCondition,
    ) -> io::Result<()>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let owned_path = listen_path.as_ref().to_path_buf();
        let connections = self
            .tokio_ecdysis
            .listen_unix(StopOnShutdown::Yes, listen_path)?;
        self.triggers
            .push((Trigger::Uds(owned_path, connections), trigger_action));
        Ok(())
    }

    /// Upgrades the process when a connection arrives on the Unix socket at `listen_path`.
    ///
    /// # Errors
    ///
    /// Fails when the listener cannot be opened.
    pub fn upgrade_on_socket<P>(&mut self, listen_path: P) -> io::Result<()>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let action = ExitCondition::Upgrade;
        self.trigger_on_socket(listen_path, action)
    }

    /// Performs a full stop when a connection arrives on the Unix socket at `listen_path`.
    ///
    /// # Errors
    ///
    /// Fails when the listener cannot be opened.
    pub fn stop_on_socket<P>(&mut self, listen_path: P) -> io::Result<()>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let action = ExitCondition::Stop;
        self.trigger_on_socket(listen_path, action)
    }

    /// Performs a partial stop when a connection arrives on the Unix socket at `listen_path`.
    ///
    /// # Errors
    ///
    /// Fails when the listener cannot be opened.
    pub fn partial_stop_on_socket<P>(&mut self, listen_path: P) -> io::Result<()>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let action = ExitCondition::PartialStop;
        self.trigger_on_socket(listen_path, action)
    }

    /// Sets the pid file path used by the underlying [`Ecdysis`].
    pub fn set_pid_file<P: AsRef<Path>>(&mut self, pid_file: P) {
        let ecdysis = &mut self.tokio_ecdysis.inner;
        ecdysis.set_pid_file(pid_file)
    }

    /// Creates the [`SystemdNotifier`] used to report readiness, reloading and stopping.
    ///
    /// # Errors
    ///
    /// Propagates the notifier's construction error.
    #[cfg(feature = "systemd_notify")]
    pub fn enable_systemd_notifications(&mut self) -> Result<(), SystemdNotifierError> {
        let notifier = SystemdNotifier::new()?;
        self.systemd_notifier = Some(notifier);
        Ok(())
    }

    /// Asks systemd to extend its timeouts by `extension`; does nothing when
    /// notifications are not enabled.
    ///
    /// # Errors
    ///
    /// Propagates the notifier's I/O error.
    #[cfg(feature = "systemd_notify")]
    pub async fn extend_systemd_timeouts(&mut self, extension: Duration) -> io::Result<()> {
        if let Some(notifier) = self.systemd_notifier.as_mut() {
            notifier.notify_extend_timeouts(extension).await?;
        }
        Ok(())
    }

    /// Marks the underlying [`Ecdysis`] ready, then splits the builder into the shared
    /// [`TokioEcdysis`] and the monitor future.
    ///
    /// The future does nothing until it is awaited. It resolves once an upgrade succeeds,
    /// a stop trigger fires, or an error occurs.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Ecdysis::ready`.
    pub fn ready(
        self,
    ) -> io::Result<(
        Arc<TokioEcdysis>,
        impl Future<Output = TokioEcdysisUpgradeResult>,
    )> {
        let Self {
            mut tokio_ecdysis,
            triggers,
            #[cfg(feature = "systemd_notify")]
            systemd_notifier,
        } = self;
        tokio_ecdysis.inner.ready()?;
        let shared = Arc::new(tokio_ecdysis);
        let monitor = TokioEcdysisUpgrader {
            tokio_ecdysis: Arc::clone(&shared),
            triggers,
            #[cfg(feature = "systemd_notify")]
            systemd_notifier,
        }
        .monitor();
        Ok((shared, monitor))
    }

    /// Reads the sockets passed in by systemd. See `TokioEcdysis::read_systemd_sockets`.
    ///
    /// # Errors
    ///
    /// Propagates the read error from the wrapped instance.
    #[cfg(feature = "systemd_sockets")]
    pub fn read_systemd_sockets(&mut self) -> Result<(), SystemdSocketsReadError> {
        let inner = &mut self.tokio_ecdysis;
        inner.read_systemd_sockets()
    }
}
/// Tokio front-end to [`Ecdysis`].
///
/// Converts the sockets it opens into tokio types and wraps them through a shared
/// [`Supervisor`], so they can be stopped together later.
pub struct TokioEcdysis {
    /// Wraps every supervised stream/socket handed out; `stop_all` is invoked on it at quit.
    supervisor: RwLock<Supervisor>,
    /// The synchronous ecdysis instance (owns the socket registry and pid-file handling).
    inner: Ecdysis,
    /// Sockets passed in by systemd; populated by `read_systemd_sockets` (parent only).
    #[cfg(feature = "systemd_sockets")]
    systemd_sockets: Option<SystemdSockets>,
}
impl TokioEcdysis {
    /// Creates an instance with an empty [`Supervisor`] wrapping a fresh [`Ecdysis`].
    fn new() -> Self {
        Self {
            supervisor: RwLock::new(Supervisor::new()),
            inner: Ecdysis::new(),
            #[cfg(feature = "systemd_sockets")]
            systemd_sockets: None,
        }
    }

    /// Returns whether the underlying [`Ecdysis`] reports this process as a child.
    ///
    /// Children take systemd-provided sockets from the inherited registry instead of
    /// reading them from systemd.
    pub fn is_child(&self) -> bool {
        self.inner.is_child()
    }

    /// Returns the underlying synchronous [`Ecdysis`].
    pub fn std_ecdysis(&self) -> &Ecdysis {
        &self.inner
    }

    /// Opens a Unix stream listener at `path` via `Ecdysis::listen_unix`.
    ///
    /// Returns it as a supervised tokio stream of incoming connections.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening the listener, making it non-blocking, or
    /// converting it to a tokio listener.
    pub fn listen_unix<P>(
        &self,
        stop_on_shutdown: StopOnShutdown,
        path: P,
    ) -> io::Result<StoppableStream<UnixListenerStream>>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let listener = self.inner.listen_unix(path)?;
        // tokio's `from_std` requires the socket to already be in non-blocking mode.
        listener.set_nonblocking(true)?;
        let listener = UnixListener::from_std(listener)?;
        let listener = UnixListenerStream::new(listener);
        Ok(self
            .supervisor
            .read()
            .supervise_stream(listener, stop_on_shutdown))
    }

    /// Opens a Unix seqpacket listener at `path` via `Ecdysis::listen_unix_seqpacket`.
    ///
    /// Returns it as a supervised stream of incoming connections (Linux only).
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening the listener.
    #[cfg(target_os = "linux")]
    pub fn listen_unix_seqpacket<P>(
        &self,
        stop_on_shutdown: StopOnShutdown,
        path: P,
    ) -> io::Result<StoppableStream<UnixSeqpacketListenerStream>>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let listener = self.inner.listen_unix_seqpacket(path)?;
        let listener = UnixSeqpacketListenerStream::new(listener);
        Ok(self
            .supervisor
            .read()
            .supervise_stream(listener, stop_on_shutdown))
    }

    /// Binds a TCP listener on `addr` with a listen backlog of 128.
    ///
    /// Returns it as a supervised stream of incoming connections. Use
    /// [`Self::build_listen_tcp`] to customise socket setup.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from binding, listening or converting the listener.
    pub fn listen_tcp(
        &self,
        stop_on_shutdown: StopOnShutdown,
        addr: SocketAddr,
    ) -> io::Result<StoppableStream<TcpListenerStream>> {
        self.build_listen_tcp(stop_on_shutdown, addr, |b, addr| {
            b.bind(&addr.into())?;
            b.listen(128)?;
            Ok(b.into())
        })
    }

    /// Creates a TCP listener through `Ecdysis::build_listen_tcp`, using `sock_build` to
    /// configure it.
    ///
    /// `sock_build` receives a [`Socket`] and `addr`. It is expected to bind/listen and
    /// return the std listener, as [`Self::listen_tcp`] does. The result is made
    /// non-blocking, converted to tokio and supervised.
    ///
    /// # Errors
    ///
    /// Returns any error from `sock_build` or from converting the listener.
    pub fn build_listen_tcp<F>(
        &self,
        stop_on_shutdown: StopOnShutdown,
        addr: SocketAddr,
        sock_build: F,
    ) -> io::Result<StoppableStream<TcpListenerStream>>
    where
        F: FnOnce(Socket, SocketAddr) -> io::Result<std::net::TcpListener>,
    {
        let listener = self.inner.build_listen_tcp(addr, sock_build)?;
        listener.set_nonblocking(true)?;
        let listener = TcpListener::from_std(listener)?;
        let listener = TcpListenerStream::new(listener);
        Ok(self
            .supervisor
            .read()
            .supervise_stream(listener, stop_on_shutdown))
    }

    /// Creates a UDP socket (see [`Self::build_stoppable_socket_udp`] for `sock_build`).
    ///
    /// Frames it with [`BytesCodec`] and returns the supervised receiving half only. The
    /// sink half is dropped immediately.
    ///
    /// # Errors
    ///
    /// Returns any error from creating or converting the socket.
    pub fn build_stream_udp<F>(
        &self,
        stop_on_shutdown: StopOnShutdown,
        addr: SocketAddr,
        sock_build: F,
    ) -> io::Result<StoppableStream<UdpStream>>
    where
        F: FnOnce(Socket, SocketAddr) -> io::Result<std::net::UdpSocket>,
    {
        let socket = self.build_socket_udp(addr, sock_build)?;
        let (_s, reader) = UdpFramed::new(socket, BytesCodec::new()).split::<(Bytes, _)>();
        Ok(self
            .supervisor
            .read()
            .supervise_stream(reader, stop_on_shutdown))
    }

    /// Creates a UDP socket through `Ecdysis::build_socket_udp`, using `sock_build` to
    /// configure it, and returns it supervised.
    ///
    /// # Errors
    ///
    /// Returns any error from `sock_build` or from converting the socket.
    pub fn build_stoppable_socket_udp<F>(
        &self,
        stop_on_shutdown: StopOnShutdown,
        addr: SocketAddr,
        sock_build: F,
    ) -> io::Result<Stoppable<UdpSocket>>
    where
        F: FnOnce(Socket, SocketAddr) -> io::Result<std::net::UdpSocket>,
    {
        let socket = self.build_socket_udp(addr, sock_build)?;
        Ok(self.supervisor.read().supervise(socket, stop_on_shutdown))
    }

    /// Shared helper: builds the std UDP socket, makes it non-blocking, and converts it
    /// to tokio (unsupervised).
    fn build_socket_udp<F>(&self, addr: SocketAddr, sock_build: F) -> io::Result<UdpSocket>
    where
        F: FnOnce(Socket, SocketAddr) -> io::Result<std::net::UdpSocket>,
    {
        let socket = self.inner.build_socket_udp(addr, sock_build)?;
        socket.set_nonblocking(true)?;
        UdpSocket::from_std(socket)
    }

    /// Obtains the datagram pair named `name` from the underlying [`Ecdysis`].
    ///
    /// Converts each end to a non-blocking tokio [`UnixDatagram`]. The first element is
    /// the end towards the parent (`None` when `Ecdysis` provides none). The second is
    /// the end towards the child. Each end carries its own I/O error.
    pub fn unix_datagram_pair(
        &self,
        name: String,
    ) -> (Option<io::Result<UnixDatagram>>, io::Result<UnixDatagram>) {
        let (unix_datagram_to_parent_option, unix_datagram_to_child_result) =
            self.inner.unix_datagram_pair(name);
        let from_std = |s: std::os::unix::net::UnixDatagram| {
            s.set_nonblocking(true)?;
            UnixDatagram::from_std(s)
        };
        (
            unix_datagram_to_parent_option.map(from_std),
            unix_datagram_to_child_result.and_then(from_std),
        )
    }

    /// Reads the sockets passed in by systemd (parent only; a no-op in a child).
    ///
    /// # Errors
    ///
    /// Returns `DuplicateSystemdSocketsRead` if called twice, or the error from
    /// reading the systemd sockets.
    #[cfg(feature = "systemd_sockets")]
    pub fn read_systemd_sockets(&mut self) -> Result<(), SystemdSocketsReadError> {
        if self.is_child() {
            return Ok(());
        }
        if self.systemd_sockets.is_some() {
            return Err(SystemdSocketsReadError::DuplicateSystemdSocketsRead);
        }
        let systemd_sockets = SystemdSockets::new()?;
        self.systemd_sockets = Some(systemd_sockets);
        Ok(())
    }

    /// Obtains a systemd-provided socket of type `P`.
    ///
    /// In a parent, the socket is looked up by `name` among the sockets read by
    /// [`Self::read_systemd_sockets`]. It is then checked against `sock_info` and
    /// recorded in the registry. In a child, it is taken from the inherited registry by
    /// `sock_info`, and `name` is unused.
    ///
    /// # Errors
    ///
    /// Returns `SystemdSocketsNotInitialized` if systemd sockets were not read, or
    /// `SockInfoIncorrect` if the socket does not match `sock_info`. Lookup, conversion
    /// and registry errors are propagated as well.
    #[cfg(feature = "systemd_sockets")]
    async fn systemd_sock_of_proto<P>(
        &self,
        name: String,
        sock_info: SockInfo,
    ) -> Result<P, SystemdSocketError>
    where
        P: TryFromRawFd + crate::listener::Listener,
    {
        if self.is_child() {
            return self.read_sock_from_registry_of_proto(sock_info);
        }
        log::debug!(
            "parent: find systemd sock with name {:?} in systemd_sockets",
            name
        );
        let listener = match &self.systemd_sockets {
            None => {
                return Err(SystemdSocketError::SystemdSocketsNotInitialized);
            }
            Some(s) => {
                let fd = s.find(name).await?;
                // SAFETY: NOTE(review): assumes `find` hands over an fd that nothing else
                // owns or will close — confirm against `SystemdSockets::find`.
                unsafe { P::try_from_raw_fd(fd)? }
            }
        };
        let listener_info = listener.info()?;
        if listener_info.sock_info != sock_info {
            return Err(SystemdSocketError::SockInfoIncorrect(format!(
                "systemd sock info {:?} does not match expected sock info {:?}",
                listener_info.sock_info, sock_info
            )));
        }
        // Register the info that was just validated instead of querying the socket again.
        self.inner.registry.add(listener_info)?;
        Ok(listener)
    }

    /// Child-side lookup: takes the socket matching `sock_info` from the inherited registry.
    ///
    /// # Errors
    ///
    /// Returns `SocketNotFoundInChildRegistry` if no such socket was inherited, or the
    /// fd conversion error.
    #[cfg(feature = "systemd_sockets")]
    fn read_sock_from_registry_of_proto<P>(
        &self,
        sock_info: SockInfo,
    ) -> Result<P, SystemdSocketError>
    where
        P: TryFromRawFd + crate::listener::Listener,
    {
        log::debug!(
            "child: find systemd sock with SockInfo {:?} in registry",
            sock_info
        );
        debug_assert!(self.is_child());
        match self.inner.registry.inherit(sock_info.clone()) {
            Some(fd) => {
                log::debug!("Found existing fd in registry");
                // SAFETY: NOTE(review): assumes `registry.inherit` yields an fd that is not
                // owned elsewhere — confirm against the registry implementation.
                let sock = unsafe { P::try_from_raw_fd(fd)? };
                Ok(sock)
            }
            None => {
                log::debug!("fd does not exist");
                Err(SystemdSocketError::SocketNotFoundInChildRegistry(format!(
                    "socket not found with SockInfo {:?}",
                    sock_info
                )))
            }
        }
    }

    /// Returns the systemd-provided Unix stream listener `name` (expected at `path`) as
    /// a supervised, non-blocking tokio stream.
    ///
    /// # Errors
    ///
    /// See `systemd_sock_of_proto`. I/O errors from conversion are propagated as well.
    #[cfg(feature = "systemd_sockets")]
    pub async fn systemd_listen_unix<P>(
        &self,
        stop_on_shutdown: StopOnShutdown,
        name: String,
        path: P,
    ) -> Result<StoppableStream<UnixListenerStream>, SystemdSocketError>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let std_listener: std::os::unix::net::UnixListener = self
            .systemd_sock_of_proto(name, SockInfo::Unix(Some(path.as_ref().into())))
            .await?;
        std_listener.set_nonblocking(true)?;
        let tokio_listener = UnixListener::from_std(std_listener)?;
        let listener = UnixListenerStream::new(tokio_listener);
        Ok(self
            .supervisor
            .read()
            .supervise_stream(listener, stop_on_shutdown))
    }

    /// Returns the systemd-provided Unix seqpacket listener `name` (expected at `path`)
    /// as a supervised, non-blocking stream (Linux only).
    ///
    /// # Errors
    ///
    /// See `systemd_sock_of_proto`. I/O errors from setting non-blocking mode are
    /// propagated as well.
    #[cfg(target_os = "linux")]
    #[cfg(feature = "systemd_sockets")]
    pub async fn systemd_listen_unix_seqpacket<P>(
        &self,
        stop_on_shutdown: StopOnShutdown,
        name: String,
        path: P,
    ) -> Result<StoppableStream<UnixSeqpacketListenerStream>, SystemdSocketError>
    where
        P: AsRef<Path> + std::fmt::Debug,
    {
        let listener: UnixSeqpacketListener = self
            .systemd_sock_of_proto(name, SockInfo::UnixSeqpacket(Some(path.as_ref().into())))
            .await?;
        crate::seqpacket::set_nonblocking(listener.as_raw_fd(), true)?;
        let stream = UnixSeqpacketListenerStream::new(listener);
        Ok(self
            .supervisor
            .read()
            .supervise_stream(stream, stop_on_shutdown))
    }

    /// Returns the systemd-provided TCP listener `name` (expected on `addr`) as a
    /// supervised, non-blocking tokio stream.
    ///
    /// # Errors
    ///
    /// See `systemd_sock_of_proto`. I/O errors from conversion are propagated as well.
    #[cfg(feature = "systemd_sockets")]
    pub async fn systemd_listen_tcp(
        &self,
        stop_on_shutdown: StopOnShutdown,
        name: String,
        addr: SocketAddr,
    ) -> Result<StoppableStream<TcpListenerStream>, SystemdSocketError> {
        let listener: std::net::TcpListener = self
            .systemd_sock_of_proto(name, SockInfo::Tcp(addr))
            .await?;
        listener.set_nonblocking(true)?;
        let listener = TcpListener::from_std(listener)?;
        let listener = TcpListenerStream::new(listener);
        Ok(self
            .supervisor
            .read()
            .supervise_stream(listener, stop_on_shutdown))
    }

    /// Returns the systemd-provided UDP socket `name` (expected on `addr`) as a
    /// non-blocking tokio socket. The socket is NOT supervised.
    ///
    /// # Errors
    ///
    /// See `systemd_sock_of_proto`. I/O errors from conversion are propagated as well.
    #[cfg(feature = "systemd_sockets")]
    pub async fn systemd_socket_udp(
        &self,
        name: String,
        addr: SocketAddr,
    ) -> Result<UdpSocket, SystemdSocketError> {
        let socket: std::net::UdpSocket = self
            .systemd_sock_of_proto(name, SockInfo::Udp(addr))
            .await?;
        socket.set_nonblocking(true)?;
        Ok(UdpSocket::from_std(socket)?)
    }

    /// Returns the supervised receiving half of the systemd-provided UDP socket `name`.
    ///
    /// The socket is framed with [`BytesCodec`].
    ///
    /// # Errors
    ///
    /// See [`Self::systemd_socket_udp`].
    #[cfg(feature = "systemd_sockets")]
    pub async fn systemd_stream_udp(
        &self,
        stop_on_shutdown: StopOnShutdown,
        name: String,
        addr: SocketAddr,
    ) -> Result<StoppableStream<UdpStream>, SystemdSocketError> {
        let socket = self.systemd_socket_udp(name, addr).await?;
        let (_s, reader) = UdpFramed::new(socket, BytesCodec::new()).split::<(Bytes, _)>();
        Ok(self
            .supervisor
            .read()
            .supervise_stream(reader, stop_on_shutdown))
    }
}
/// Result of the monitor future: how it exited and what triggered the exit, or an error
/// message.
pub type TokioEcdysisUpgradeResult = Result<(ExitMode, ExitReason), String>;
/// State driving the monitor loop: waits for triggers and then upgrades or stops.
struct TokioEcdysisUpgrader {
    /// Shared with the caller of `TokioEcdysisBuilder::ready`.
    tokio_ecdysis: Arc<TokioEcdysis>,
    /// Triggers moved out of the builder, each paired with its exit condition.
    triggers: Vec<(Trigger, ExitCondition)>,
    /// Notifier moved out of the builder; `None` when notifications are disabled.
    #[cfg(feature = "systemd_notify")]
    systemd_notifier: Option<SystemdNotifier>,
}
impl TokioEcdysisUpgrader {
async fn initialize(&mut self) -> Result<(), String> {
#[cfg(feature = "systemd_notify")]
if let Some(systemd_notifier) = &mut self.systemd_notifier {
systemd_notifier
.notify_ready()
.await
.map_err(|e| e.to_string())?;
}
Ok(())
}
async fn upgrade(&mut self) -> Result<UpgradeFinished, String> {
#[cfg(feature = "systemd_notify")]
if let Some(systemd_notifier) = &mut self.systemd_notifier {
systemd_notifier
.notify_reloading()
.await
.map_err(|e| e.to_string())?;
}
let (tx, rx) = channel();
let fds = self.tokio_ecdysis.inner.registry.get_fds_for_child();
log::warn!("Ecdysis starting upgrade");
thread::spawn(move || {
if let Err(e) = tx.send(upgrade(fds)) {
panic!(
"Cannot send upgrade result{}",
e.map_or_else(|e| format!(": {e}"), |()| "!".into())
);
}
});
rx.await.map_err(|e| e.to_string())
}
async fn monitor_triggers(&mut self) -> io::Result<(TriggerReason, ExitCondition)> {
fn poll_triggers(
triggers: &mut [(Trigger, ExitCondition)],
cx: &mut Context,
) -> Poll<io::Result<(TriggerReason, ExitCondition)>> {
for (trigger, ec) in triggers {
match trigger.poll_next_unpin(cx) {
Poll::Ready(Some(result)) => return Poll::Ready(result.map(|o| (o, *ec))),
Poll::Ready(None) => unreachable!(), Poll::Pending => (),
}
}
Poll::Pending
}
let polling_fut = std::future::poll_fn(|cx| poll_triggers(&mut self.triggers, cx));
polling_fut.await
}
async fn clean_up_triggers(&mut self) -> Result<(), String> {
for (trigger, _ec) in self.triggers.drain(..) {
trigger.cleanup().await?
}
Ok(())
}
async fn on_shutdown(mut self) -> Result<(), String> {
self.clean_up_triggers().await?;
#[cfg(feature = "systemd_notify")]
{
if let Some(systemd_notifier) = &mut self.systemd_notifier {
systemd_notifier
.notify_stopping()
.await
.map_err(|e| e.to_string())?;
}
}
Ok(())
}
fn quit(&mut self, ec: ExitCondition) -> Result<(), String> {
self.tokio_ecdysis.inner.quit();
self.tokio_ecdysis
.supervisor
.write()
.stop_all(ec)
.map_err(|_| "Cannot stop supervised listeners!".into())
}
async fn monitor(mut self) -> TokioEcdysisUpgradeResult {
loop {
self.initialize().await?;
let reason_condition = match self.monitor_triggers().await {
Ok(reason_condition) => reason_condition,
Err(e) => {
self.on_shutdown().await?;
return Err(format!("Encountered error while polling triggers: {e}"));
}
};
let upgrade_reason = match reason_condition {
(reason, ExitCondition::Upgrade) => reason,
(reason, not_upgrade_condition) => {
log::warn!("Ecdysis stopping (reason: {reason:?})");
let exit_reason = match reason {
TriggerReason::Signal(kind) => ExitReason::Signal(kind),
TriggerReason::UnixStream(path, stream) => {
mem::forget(stream);
ExitReason::UnixListener(path)
}
};
self.quit(not_upgrade_condition)?;
if let ExitCondition::PartialStop = not_upgrade_condition {
self.clean_up_triggers().await?;
#[cfg(feature = "systemd_notify")]
if let Some(systemd_notifier) = self.systemd_notifier {
return Ok((
ExitMode::PartialStopWithSystemd(systemd_notifier),
exit_reason,
));
}
return Ok((ExitMode::PartialStop, exit_reason));
}
self.on_shutdown().await?;
return Ok((ExitMode::FullStop, exit_reason));
}
};
log::warn!("Ecdysis starting upgrade (reason: {upgrade_reason:?})");
return match self.upgrade().await {
Ok(upgrade_finished) => match upgrade_finished {
Ok(()) => {
log::info!("Upgrade successful");
self.quit(ExitCondition::Upgrade)?;
Ok((
ExitMode::Upgrade,
match upgrade_reason {
TriggerReason::Signal(kind) => ExitReason::Signal(kind),
TriggerReason::UnixStream(path, _stream) => {
ExitReason::UnixListener(path)
}
},
))
}
Err(e) => {
log::warn!("Upgrade failed: {e}");
log::warn!("Ecdysis returning to listening state!");
let _ = self.tokio_ecdysis.inner.write_pidfile();
continue;
}
},
Err(err_str) => {
self.on_shutdown().await?;
Err(format!("Encountered a problem during upgrade: {err_str}"))
}
};
}
}
}