#![allow(missing_docs)]
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use crate::{ChildInfo as SupervisorChildInfo, ChildType, RestartPolicy, SupervisorHandle, Worker};
/// Location at which a remote supervisor can be reached.
///
/// The contained string is handed unchanged to the transport's `connect` call.
// Variant order is part of the serialized (serde/rkyv) representation; do not reorder.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
pub enum SupervisorAddress {
    /// Address passed to `TcpStream::connect` (e.g. `"127.0.0.1:7000"`).
    Tcp(String),
    /// Filesystem path passed to `UnixStream::connect`.
    Unix(String),
}
/// A request sent by a remote client to a `SupervisorServer`.
///
/// Each connection carries exactly one command, answered by one `RemoteResponse`.
// Variant order is part of the wire format; append new variants at the end.
#[allow(missing_docs)]
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
#[rkyv(attr(allow(missing_docs)))]
pub enum RemoteCommand {
    /// Ask the supervisor to shut down.
    Shutdown,
    /// List the supervisor's children.
    WhichChildren,
    /// Terminate the child identified by `id`.
    TerminateChild { id: String },
    /// Report summary status.
    Status,
}
/// The server's reply to a single `RemoteCommand`.
// Variant order is part of the wire format; append new variants at the end.
#[allow(missing_docs)]
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
#[rkyv(attr(allow(missing_docs)))]
pub enum RemoteResponse {
    /// Success with no payload (sent for `Shutdown` and `TerminateChild`).
    Ok,
    /// Answer to `WhichChildren`.
    Children(Vec<ChildInfo>),
    /// Answer to `Status`.
    Status(SupervisorStatus),
    /// The command failed on the server; holds the failure's `Display` text.
    Error(String),
}
/// One supervised child as reported to remote clients.
///
/// Built from the supervisor's own child record via `From<SupervisorChildInfo>`.
// Field order is part of the wire format; do not reorder.
#[allow(missing_docs)]
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
#[rkyv(attr(allow(missing_docs)))]
pub struct ChildInfo {
    /// Identifier of the child.
    pub id: String,
    /// Kind of child.
    pub child_type: ChildType,
    /// Restart policy, when one is set.
    pub restart_policy: Option<RestartPolicy>,
}
impl From<SupervisorChildInfo> for ChildInfo {
fn from(info: SupervisorChildInfo) -> Self {
Self {
id: info.id,
child_type: info.child_type,
restart_policy: info.restart_policy,
}
}
}
/// Summary returned for `RemoteCommand::Status`.
// Field order is part of the wire format; do not reorder.
#[allow(missing_docs)]
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[derive(Archive, RkyvSerialize, RkyvDeserialize)]
#[rkyv(derive(Debug))]
#[rkyv(attr(allow(missing_docs)))]
pub struct SupervisorStatus {
    /// Supervisor name.
    pub name: String,
    /// Number of children (the server in this module reports 0 if listing fails).
    pub children_count: usize,
    /// `Debug` rendering of the restart strategy, or `"Unknown"` if it could not be read.
    pub restart_strategy: String,
    /// Uptime in seconds (the server in this module reports 0 if unavailable).
    pub uptime_secs: u64,
}
/// Client for a supervisor served by `SupervisorServer`.
///
/// The handle stores only the target address; connections are opened per command.
// `Debug` is derived so the public handle can be logged and embedded in `Debug` types.
#[derive(Debug, Clone)]
pub struct RemoteSupervisorHandle {
    address: SupervisorAddress,
}
impl RemoteSupervisorHandle {
#[must_use]
pub fn new(address: SupervisorAddress) -> Self {
Self { address }
}
#[allow(clippy::unused_async)]
pub async fn connect_tcp(addr: impl Into<String>) -> Result<Self, DistributedError> {
let address = SupervisorAddress::Tcp(addr.into());
Ok(Self { address })
}
#[allow(clippy::unused_async)]
pub async fn connect_unix(path: impl Into<String>) -> Result<Self, DistributedError> {
let address = SupervisorAddress::Unix(path.into());
Ok(Self { address })
}
pub async fn send_command(
&self,
cmd: RemoteCommand,
) -> Result<RemoteResponse, DistributedError> {
match &self.address {
SupervisorAddress::Tcp(addr) => {
let mut stream = TcpStream::connect(addr).await?;
send_message(&mut stream, &cmd).await?;
receive_message(&mut stream).await
}
#[cfg(unix)]
SupervisorAddress::Unix(path) => {
let mut stream = UnixStream::connect(path).await?;
send_message(&mut stream, &cmd).await?;
receive_message(&mut stream).await
}
#[cfg(not(unix))]
SupervisorAddress::Unix(_) => Err(DistributedError::Io(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"Unix sockets are not supported on this platform",
))),
}
}
pub async fn shutdown(&self) -> Result<(), DistributedError> {
self.send_command(RemoteCommand::Shutdown).await?;
Ok(())
}
pub async fn which_children(&self) -> Result<Vec<ChildInfo>, DistributedError> {
match self.send_command(RemoteCommand::WhichChildren).await? {
RemoteResponse::Children(children) => Ok(children),
RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
_ => Err(DistributedError::UnexpectedResponse),
}
}
pub async fn terminate_child(&self, id: &str) -> Result<(), DistributedError> {
match self
.send_command(RemoteCommand::TerminateChild { id: id.to_owned() })
.await?
{
RemoteResponse::Ok => Ok(()),
RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
_ => Err(DistributedError::UnexpectedResponse),
}
}
pub async fn status(&self) -> Result<SupervisorStatus, DistributedError> {
match self.send_command(RemoteCommand::Status).await? {
RemoteResponse::Status(status) => Ok(status),
RemoteResponse::Error(e) => Err(DistributedError::RemoteError(e)),
_ => Err(DistributedError::UnexpectedResponse),
}
}
}
/// Exposes a local `SupervisorHandle` to remote clients over TCP or Unix sockets.
///
/// The handle sits behind an `Arc` so each per-connection task can hold its own reference.
pub struct SupervisorServer<W>
where
    W: Worker,
{
    handle: Arc<SupervisorHandle<W>>,
}
impl<W: Worker> SupervisorServer<W> {
    /// Creates a server that answers remote commands using `handle`.
    #[must_use]
    pub fn new(handle: SupervisorHandle<W>) -> Self {
        Self {
            handle: Arc::new(handle),
        }
    }

    /// Accepts connections on a Unix domain socket at `path` until an error occurs.
    ///
    /// If a *socket* already exists at `path` (for example, left behind by a previous
    /// run), it is removed before binding. Any other kind of file at `path` is left
    /// untouched, so binding fails instead of unrelated data being deleted.
    ///
    /// Each connection is served on its own spawned task. Per-connection failures are
    /// logged and do not stop the listener.
    ///
    /// # Errors
    ///
    /// Returns an error if an existing socket cannot be removed, if binding fails, or if
    /// `accept` fails.
    #[cfg(unix)]
    pub async fn listen_unix(
        self,
        path: impl AsRef<std::path::Path>,
    ) -> Result<(), DistributedError> {
        let socket_path = path.as_ref();
        Self::remove_stale_socket(socket_path)?;
        let listener = UnixListener::bind(socket_path)?;
        tracing::info!(path = %socket_path.display(), "server listening on unix socket");
        loop {
            let (mut stream, _) = listener.accept().await?;
            let handle = Arc::clone(&self.handle);
            tokio::spawn(async move {
                if let Err(e) = Self::handle_connection(&mut stream, handle).await {
                    tracing::error!(error = %e, "connection error");
                }
            });
        }
    }

    /// Removes a leftover Unix socket file at `socket_path`, if there is one.
    ///
    /// Only entries whose file type is "socket" are removed. A missing path, a
    /// non-socket file, or unreadable metadata is left for `bind` to report. A socket
    /// that vanishes between the check and the removal is not treated as an error.
    #[cfg(unix)]
    fn remove_stale_socket(socket_path: &std::path::Path) -> std::io::Result<()> {
        use std::os::unix::fs::FileTypeExt;

        // `symlink_metadata` does not follow symlinks, so a link is never treated as a socket.
        let is_socket = matches!(
            std::fs::symlink_metadata(socket_path),
            Ok(meta) if meta.file_type().is_socket()
        );
        if !is_socket {
            return Ok(());
        }
        match std::fs::remove_file(socket_path) {
            Err(e) if e.kind() != std::io::ErrorKind::NotFound => Err(e),
            _ => Ok(()),
        }
    }

    /// Accepts TCP connections on `addr` until an error occurs.
    ///
    /// Each connection is served on its own spawned task. Per-connection failures are
    /// logged and do not stop the listener.
    ///
    /// # Errors
    ///
    /// Returns an error if binding or `accept` fails.
    pub async fn listen_tcp(self, addr: impl AsRef<str>) -> Result<(), DistributedError> {
        let listener = TcpListener::bind(addr.as_ref()).await?;
        tracing::info!(address = addr.as_ref(), "server listening on tcp");
        loop {
            let (mut stream, peer) = listener.accept().await?;
            tracing::debug!(peer = ?peer, "new connection");
            let handle = Arc::clone(&self.handle);
            tokio::spawn(async move {
                if let Err(e) = Self::handle_connection(&mut stream, handle).await {
                    tracing::error!(error = %e, "Connection error");
                }
            });
        }
    }

    /// Serves one request/response exchange on `stream`.
    ///
    /// Reads a single `RemoteCommand`, executes it against `handle`, and writes back a
    /// single `RemoteResponse`.
    // NOTE(review): there is no read timeout, so a peer that connects and never sends
    // keeps this task alive indefinitely.
    async fn handle_connection<S>(
        stream: &mut S,
        handle: Arc<SupervisorHandle<W>>,
    ) -> Result<(), DistributedError>
    where
        S: AsyncReadExt + AsyncWriteExt + Unpin,
    {
        let command: RemoteCommand = receive_message(stream).await?;
        let response = Self::process_command(command, &handle).await;
        send_message(stream, &response).await?;
        Ok(())
    }

    /// Executes `command` against the local supervisor and builds the reply.
    ///
    /// Supervisor failures become `RemoteResponse::Error` carrying the error's `Display`
    /// text. For `Status`, individual query failures fall back to defaults (`"Unknown"`
    /// and 0) instead of producing an error reply.
    async fn process_command(
        command: RemoteCommand,
        handle: &SupervisorHandle<W>,
    ) -> RemoteResponse {
        match command {
            RemoteCommand::Shutdown => match handle.shutdown().await {
                Ok(()) => RemoteResponse::Ok,
                Err(e) => RemoteResponse::Error(e.to_string()),
            },
            RemoteCommand::WhichChildren => match handle.which_children().await {
                Ok(children) => {
                    let child_list: Vec<ChildInfo> = children.into_iter().map(Into::into).collect();
                    RemoteResponse::Children(child_list)
                }
                Err(e) => RemoteResponse::Error(e.to_string()),
            },
            RemoteCommand::TerminateChild { id } => match handle.terminate_child(&id).await {
                Ok(()) => RemoteResponse::Ok,
                Err(e) => RemoteResponse::Error(e.to_string()),
            },
            RemoteCommand::Status => {
                let restart_strategy = handle
                    .restart_strategy()
                    .await
                    .map_or_else(|_| "Unknown".to_owned(), |s| format!("{s:?}"));
                let uptime_secs = handle.uptime().await.unwrap_or(0);
                RemoteResponse::Status(SupervisorStatus {
                    name: handle.name().to_owned(),
                    children_count: handle.which_children().await.map(|c| c.len()).unwrap_or(0),
                    restart_strategy,
                    uptime_secs,
                })
            }
        }
    }
}
async fn send_message<S, T>(stream: &mut S, msg: &T) -> Result<(), DistributedError>
where
S: AsyncWriteExt + Unpin,
T: Serialize,
for<'a> T: RkyvSerialize<
rkyv::api::high::HighSerializer<
rkyv::util::AlignedVec,
rkyv::ser::allocator::ArenaHandle<'a>,
rkyv::rancor::Error,
>,
>,
{
let encoded = rkyv::to_bytes::<rkyv::rancor::Error>(msg)?;
let len = u32::try_from(encoded.len())
.map_err(|_| DistributedError::MessageTooLarge(encoded.len()))?;
stream.write_all(&len.to_be_bytes()).await?;
stream.write_all(&encoded).await?;
stream.flush().await?;
Ok(())
}
#[allow(clippy::as_conversions)]
async fn receive_message<S, T>(stream: &mut S) -> Result<T, DistributedError>
where
S: AsyncReadExt + Unpin,
T: Archive,
for<'a> T::Archived: RkyvDeserialize<T, rkyv::api::high::HighDeserializer<rkyv::rancor::Error>>,
{
let mut len_bytes = [0u8; 4];
stream.read_exact(&mut len_bytes).await?;
let len = u32::from_be_bytes(len_bytes) as usize;
if len > 10_000_000 {
return Err(DistributedError::MessageTooLarge(len));
}
let mut buffer = vec![0u8; len];
stream.read_exact(&mut buffer).await?;
let decoded: T = unsafe { rkyv::from_bytes_unchecked::<T, rkyv::rancor::Error>(&buffer)? };
Ok(decoded)
}
/// Failures raised by the remote supervisor client and server in this module.
#[derive(Debug)]
pub enum DistributedError {
    /// A socket operation (connect, bind, accept, read, write, flush) failed.
    Io(std::io::Error),
    /// rkyv serialization failed. `From<rkyv::rancor::Error>` always produces this variant.
    Encode(rkyv::rancor::Error),
    /// A received message could not be deserialized.
    Decode(rkyv::rancor::Error),
    /// The server replied with `RemoteResponse::Error`; holds its message.
    RemoteError(String),
    /// The server's reply was not the variant expected for the command sent.
    UnexpectedResponse,
    /// A message length exceeded what the framing accepts; holds the size in bytes.
    MessageTooLarge(usize),
}
impl fmt::Display for DistributedError {
    /// Renders a one-line description, embedding the inner error or message where present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(err) => write!(f, "IO error: {err}"),
            Self::Encode(err) => write!(f, "Encode error: {err}"),
            Self::Decode(err) => write!(f, "Decode error: {err}"),
            Self::RemoteError(message) => write!(f, "Remote error: {message}"),
            Self::UnexpectedResponse => f.write_str("Unexpected response from remote"),
            Self::MessageTooLarge(bytes) => write!(f, "Message too large: {bytes} bytes"),
        }
    }
}
impl std::error::Error for DistributedError {}
impl From<std::io::Error> for DistributedError {
fn from(e: std::io::Error) -> Self {
DistributedError::Io(e)
}
}
impl From<rkyv::rancor::Error> for DistributedError {
fn from(e: rkyv::rancor::Error) -> Self {
DistributedError::Encode(e)
}
}