use std::{
collections::HashMap,
fmt::Debug,
marker::PhantomData,
path::{Path, PathBuf},
pin::Pin,
task::{Context, Poll},
};
use futures::Stream;
use tokio::{
net::{UnixListener, UnixStream},
select,
sync::{
broadcast::Sender as BroadcastSender,
mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
oneshot::{Sender as OneshotSender, channel as oneshot_channel},
},
task::JoinHandle,
};
use tracing::{debug, warn};
use crate::{codecs::Codec, error::Error, handle::Handle, unix::UnixSocketId};
/// A server listening on a Unix domain socket and talking to many clients.
///
/// Inbound messages are yielded through the `Stream` implementation as
/// `(message, UnixSocketId)` pairs (presumably the id of the originating
/// client — the tagging happens inside `Handle`). Outbound messages are
/// addressed to one client with [`UnixServer::send`].
pub struct UnixServer<C, OUT, IN>
where
    C: Codec<OUT, IN>,
    OUT: Debug + Send + 'static,
    IN: Debug + Send + 'static,
{
    /// Filesystem path the listening socket was bound to.
    local_addr: PathBuf,
    /// Requests for the router task: message, target client, reply channel.
    out_tx: UnboundedSender<OutgoingMessage<OUT>>,
    /// Messages received from connected clients, tagged with a socket id.
    in_rx: UnboundedReceiver<(IN, UnixSocketId)>,
    /// Broadcasts the stop signal to the background tasks.
    exit_tx: BroadcastSender<()>,
    /// Join handle of the connection-accepting task (never awaited here).
    _listener_handle: JoinHandle<()>,
    /// Join handle of the message-routing task (never awaited here).
    _router_handle: JoinHandle<()>,
    /// Ties the codec type parameter to the server without storing a codec.
    _c: PhantomData<C>,
}
type OutgoingMessage<OUT> = (OUT, UnixSocketId, OneshotSender<Result<(), Error>>);
/// Connection lifecycle events sent from the listener task to the router task.
enum UnixConnection<C, OUT, IN>
where
    C: Codec<OUT, IN>,
    OUT: Debug + Send + 'static,
    IN: Debug + Send + 'static,
{
    /// A freshly accepted connection that the router should start tracking.
    Connected {
        /// Sending side of the channel whose receiver was given to `handle`.
        out_tx: UnboundedSender<OUT>,
        /// The connection's handle; the router keys it by `handle.addr()`.
        handle: Handle<C, OUT, IN, UnixSocketId>,
    },
    /// The connection with this id has been closed.
    Disconnected {
        id: UnixSocketId,
    },
}
// The only field whose `Unpin`-ness depends on a type parameter is
// `PhantomData<C>`; the server never stores a `C`, so it is declared `Unpin`
// regardless of `C`. This lets `poll_next` reach the fields through
// `Pin<&mut Self>`.
impl<C, OUT, IN> Unpin for UnixServer<C, OUT, IN>
where
    C: Codec<OUT, IN>,
    OUT: Debug + Send + 'static,
    IN: Debug + Send + 'static,
{
}
impl<C: Codec<OUT, IN>, OUT: Debug + Send + 'static, IN: Debug + Send + 'static>
UnixServer<C, OUT, IN>
{
pub async fn bind(path: &Path) -> Result<Self, Error> {
std::fs::remove_file(&path).ok();
let listener = UnixListener::bind(&path)?;
debug!("Unix server listening on {}", path.display());
let (exit_tx, _) = tokio::sync::broadcast::channel::<()>(1);
let (in_tx, in_rx) = unbounded_channel();
let (out_tx, mut out_rx) = unbounded_channel::<OutgoingMessage<OUT>>();
let (conn_tx, mut conn_rx) = unbounded_channel();
let mut socket_id_count = 0u64;
let exit_tx_ = exit_tx.clone();
let _listener_handle = tokio::task::spawn(async move {
let mut exit_rx = exit_tx_.subscribe();
loop {
select! {
Ok((socket, _addr)) = listener.accept() => {
let socket_id = UnixSocketId(socket_id_count);
socket_id_count += 1;
match Self::handle_connection(socket_id, socket, in_tx.clone()).await {
Ok((conn_out_tx, handle)) => {
let conn_tx_ = conn_tx.clone();
handle.on_closed(move || {
conn_tx_.send(UnixConnection::Disconnected { id: socket_id }).ok();
});
conn_tx.send(UnixConnection::Connected { out_tx: conn_out_tx, handle }).unwrap_or_else(|e| {
warn!("Failed to forward connection from {} to router {e:?}", socket_id);
});
}
Err(e) => {
warn!("Failed to handle connection from {}: {e:?}", socket_id);
}
}
}
_ = exit_rx.recv() => {
break;
}
}
}
debug!("Shutting down Unix server listener");
});
let mut exit_rx = exit_tx.subscribe();
let _router_handle = tokio::task::spawn(async move {
let mut clients: HashMap<
UnixSocketId,
(UnboundedSender<OUT>, Handle<C, OUT, IN, UnixSocketId>),
> = HashMap::new();
loop {
select! {
Some((msg, target, resp_tx)) = out_rx.recv() => {
if let Some((out_tx, _handle)) = clients.get(&target) {
debug!("Routing message to {}: {:?}", target, msg);
match out_tx.send(msg) {
Ok(_) => {
let _ = resp_tx.send(Ok(()));
}
Err(e) => {
warn!("Failed to send message to {target}: {e:?}");
let _ = resp_tx.send(Err(Error::Send));
if let Some((_out_rx, handle)) = clients.remove(&target) {
debug!("Shutting down client {}", target);
let _ = handle.close();
}
}
}
} else {
warn!("No client found for target {target}");
let _ = resp_tx.send(Err(Error::Send));
}
},
Some(evt) = conn_rx.recv() => match evt {
UnixConnection::Connected { out_tx, handle } => {
debug!("Client connected: {}", handle.addr());
clients.insert(handle.addr(), (out_tx, handle));
}
UnixConnection::Disconnected { id } => {
debug!("Client disconnected: {}", id);
if let Some((_out_tx, handle)) = clients.remove(&id) {
let _ = handle.close();
}
}
},
_ = exit_rx.recv() => {
break;
}
}
}
debug!("Shutting down Unix server router");
for (_addr, (_out_tx, handle)) in clients.drain() {
let _ = handle.close();
}
});
Ok(Self {
local_addr: path.to_path_buf(),
out_tx,
in_rx,
exit_tx,
_listener_handle,
_router_handle,
_c: PhantomData,
})
}
pub fn local_path(&self) -> &Path {
&self.local_addr
}
pub async fn send(&mut self, msg: OUT, target: UnixSocketId) -> Result<(), Error> {
debug!("Sending message to {target}: {msg:?}");
let (resp_tx, resp_rx) = oneshot_channel();
self.out_tx.send((msg, target, resp_tx)).map_err(|_e| Error::Send)?;
resp_rx.await.map_err(|_e| Error::Send)?
}
pub async fn shutdown(&self) {
let _ = self.exit_tx.send(());
}
async fn handle_connection(
id: UnixSocketId,
socket: UnixStream,
in_rx: UnboundedSender<(IN, UnixSocketId)>,
) -> Result<(UnboundedSender<OUT>, Handle<C, OUT, IN, UnixSocketId>), Error> {
debug!("New connection from {id}");
let (out_tx, out_rx) = unbounded_channel();
let (stream_rx, stream_tx) = socket.into_split();
let handle = Handle::new(id, stream_rx, stream_tx, out_rx, in_rx).await?;
Ok((out_tx, handle))
}
}
/// Yields messages received from connected clients.
impl<C, OUT, IN> Stream for UnixServer<C, OUT, IN>
where
    C: Codec<OUT, IN>,
    OUT: Debug + Send + 'static,
    IN: Debug + Send + 'static,
{
    /// A received message paired with the socket id it was tagged with.
    type Item = (IN, UnixSocketId);

    /// Polls the inbound channel for the next message.
    ///
    /// Yields `None` once the channel is closed and drained.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `UnixServer` is `Unpin`, so the pin can be unwrapped to a plain `&mut`.
        let server = self.get_mut();
        server.in_rx.poll_recv(cx)
    }
}