use std::{
collections::{HashMap, VecDeque},
fmt::Debug,
ops::Deref,
os::fd::{FromRawFd, IntoRawFd, OwnedFd, RawFd},
sync::atomic::AtomicU64,
};
use sendfd::{RecvWithFd, SendWithFd};
use serde::{de::DeserializeOwned, Serialize};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{UnixListener, UnixStream},
sync::{mpsc, oneshot},
};
use super::{framed::frame, types::{FdId, ChannelId}};
/// Which end of the bootstrap socket this process sits on.
#[derive(Eq, PartialEq)]
enum ProcessType {
    /// The side that bound the listening socket (created by `Comms::bind`).
    Primary,
    /// The side that connected to it (created by `Comms::connect`).
    Secondary,
}
/// Handle to the fd-exchange machinery on one end of a bootstrap Unix socket.
///
/// All socket I/O happens in background tasks spawned by `Comms::connect` /
/// `Comms::bind`; this handle only talks to those tasks through the two channels below.
pub(crate) struct Comms {
    /// Which end of the connection this process is; selects the parity of the ids
    /// that `send_fd` hands out.
    process_type: ProcessType,
    /// Counter used to mint fresh [`FdId`]s in `send_fd`.
    // Only read by `send_fd`, which is itself marked `allow(dead_code)` and may be unused
    // in the crate, so the field would otherwise trip the dead-code lint.
    #[allow(dead_code)]
    fd_id_gen: AtomicU64,
    /// Queue of `(id, fd)` pairs for the background writer task to send to the peer.
    outgoing_tx: mpsc::Sender<(FdId, RawFd)>,
    /// Registers interest in an incoming fd id; the dispatcher replies on the oneshot.
    register_callbacks_tx: mpsc::Sender<(FdId, oneshot::Sender<RawFd>)>,
}
impl Comms {
    /// Connects to a primary process listening at `path` and starts the fd-exchange
    /// background tasks on the current Tokio runtime.
    ///
    /// The returned handle is the [`ProcessType::Secondary`] end.
    ///
    /// # Panics
    /// Panics if the connection to `path` cannot be established.
    pub async fn connect(path: &std::path::Path) -> Self {
        let stream = UnixStream::connect(path).await.unwrap();
        let (outgoing_tx, outgoing_rx) = mpsc::channel(32);
        let (register_callbacks_tx, register_callbacks_rx) = mpsc::channel(32);
        tokio::spawn(Comms::handle_stream(
            stream,
            outgoing_rx,
            register_callbacks_rx,
        ));
        Self {
            process_type: ProcessType::Secondary,
            fd_id_gen: (ChannelId::NUM_RESERVED_IDS as u64).into(),
            outgoing_tx,
            register_callbacks_tx,
        }
    }

    /// Binds a listening socket at `bind_path` and returns immediately.
    ///
    /// A background task accepts a single connection and then services it. Requests made
    /// through the returned handle before the peer connects stay queued in the internal
    /// channels until then. The returned handle is the [`ProcessType::Primary`] end.
    ///
    /// # Panics
    /// Panics if `bind_path` cannot be bound. The background task panics if `accept` fails.
    pub async fn bind(bind_path: &std::path::Path) -> Self {
        let listener = UnixListener::bind(bind_path).unwrap();
        let (outgoing_tx, outgoing_rx) = mpsc::channel(32);
        let (register_callbacks_tx, register_callbacks_rx) = mpsc::channel(32);
        tokio::spawn(async move {
            let stream = match listener.accept().await {
                Ok((stream, _)) => stream,
                Err(e) => panic!("Error when connecting: {}", e),
            };
            Comms::handle_stream(stream, outgoing_rx, register_callbacks_rx).await;
        });
        Self {
            process_type: ProcessType::Primary,
            fd_id_gen: (ChannelId::NUM_RESERVED_IDS as u64).into(),
            outgoing_tx,
            register_callbacks_tx,
        }
    }

    /// Services one connected socket until either side shuts down.
    ///
    /// Three cooperating pieces run here:
    /// * A writer task drains `outgoing_rx`. It sends each `(id, fd)` pair as an 8-byte
    ///   little-endian id with the fd attached, then closes its own copy of the fd.
    /// * A reader task receives those messages and pairs ids with fds in arrival order.
    ///   It forwards the pairs to the dispatcher.
    /// * The dispatcher loop (this future itself) matches received fds with callers of
    ///   [`Comms::wait_for_fd`], whichever of the two arrives first.
    async fn handle_stream(
        stream: UnixStream,
        mut outgoing_rx: mpsc::Receiver<(FdId, RawFd)>,
        mut register_callbacks_rx: mpsc::Receiver<(FdId, oneshot::Sender<RawFd>)>,
    ) {
        let (read_stream, write_stream) = stream.into_split();
        tokio::spawn(async move {
            let ws: &UnixStream = write_stream.as_ref();
            while let Some((id, fd)) = outgoing_rx.recv().await {
                // SAFETY: the outgoing queue owns every fd placed on it. `send_fd` queues
                // a private duplicate, and `create_bidi_stream` gives up its end of the
                // pair. This task may therefore take ownership. Holding the fd as an
                // `OwnedFd` closes our copy once the kernel has duplicated it into the
                // peer, even if the send below panics.
                let owned = unsafe { OwnedFd::from_raw_fd(fd) };
                let fds = [fd];
                let bytes = id.0.to_le_bytes();
                loop {
                    ws.writable().await.unwrap();
                    match ws.send_with_fd(&bytes, &fds) {
                        Ok(n) => {
                            // The reader expects exactly 8 bytes per fd message.
                            assert_eq!(
                                n,
                                bytes.len(),
                                "Sent an unexpected number of bytes in FD send code: {n}"
                            );
                            break;
                        }
                        // `writable()` may report readiness spuriously; wait again.
                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
                        Err(e) => panic!("Got an error: {:#?}", e),
                    }
                }
                drop(owned);
            }
        });

        let (incoming_tx, mut incoming_rx) = mpsc::channel(32);
        tokio::spawn(async move {
            let rs: &UnixStream = read_stream.as_ref();
            // Ids and fds are queued separately and paired in arrival order.
            let mut fd_queue: VecDeque<RawFd> = VecDeque::new();
            let mut id_queue: VecDeque<FdId> = VecDeque::new();
            loop {
                let mut bytes = [0u8; 8];
                let mut fds = [0; 1];
                rs.readable().await.unwrap();
                let (num_bytes, num_fds) = match rs.recv_with_fd(&mut bytes, &mut fds) {
                    Ok(v) => v,
                    // `readable()` may report readiness spuriously; wait again.
                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
                    Err(e) => panic!("Got an error: {:#?}", e),
                };
                if num_bytes != 0 {
                    assert_eq!(
                        num_bytes, 8,
                        "Got an unexpected number of bytes in FD recv code: {num_bytes}"
                    );
                    id_queue.push_back(FdId(u64::from_le_bytes(bytes)));
                }
                if num_fds != 0 {
                    assert_eq!(
                        num_fds, 1,
                        "Got an unexpected number of fds in FD recv code: {num_fds}"
                    );
                    fd_queue.push_back(fds[0]);
                }
                if !fd_queue.is_empty() && !id_queue.is_empty() {
                    if incoming_tx
                        .send((id_queue.pop_front().unwrap(), fd_queue.pop_front().unwrap()))
                        .await
                        .is_err()
                    {
                        // The dispatcher has exited; nobody is left to hand fds to.
                        break;
                    }
                } else if num_bytes == 0 && num_fds == 0 {
                    // A zero-length read with no fds means the peer closed the socket.
                    log::trace!("Got empty fd message");
                    break;
                }
            }
        });

        // Callers that asked for an id before its fd arrived, and fds that arrived before
        // anyone asked for them. An id is never in both maps at once.
        let mut waiting: HashMap<FdId, oneshot::Sender<RawFd>> = HashMap::new();
        let mut received: HashMap<FdId, RawFd> = HashMap::new();
        loop {
            tokio::select! {
                request = register_callbacks_rx.recv() => match request {
                    Some((requested_id, callback)) => match received.remove(&requested_id) {
                        Some(fd) => Comms::deliver_fd(callback, fd),
                        None => {
                            waiting.insert(requested_id, callback);
                        }
                    },
                    // Every `Comms` handle is gone; nobody can ask for fds any more.
                    None => break,
                },
                message = incoming_rx.recv() => match message {
                    Some((fd_id, fd)) => match waiting.remove(&fd_id) {
                        Some(callback) => Comms::deliver_fd(callback, fd),
                        None => {
                            received.insert(fd_id, fd);
                        }
                    },
                    // The reader task stopped, so no more fds will arrive.
                    None => break,
                },
            }
        }
        // Fds that were received but never claimed would otherwise leak.
        for fd in received.into_values() {
            Comms::close_received_fd(fd);
        }
    }

    /// Hands `fd` to a caller waiting in [`Comms::wait_for_fd`].
    ///
    /// The caller may have gone away, for example because its `wait_for_fd` future was
    /// dropped. In that case the fd is closed rather than panicking the dispatcher, since
    /// a panic there would break every other pending and future transfer on this
    /// connection.
    fn deliver_fd(callback: oneshot::Sender<RawFd>, fd: RawFd) {
        if let Err(fd) = callback.send(fd) {
            log::trace!("Dropping fd {fd}: its waiter went away");
            Comms::close_received_fd(fd);
        }
    }

    /// Closes an fd received from the peer that no caller will take ownership of.
    fn close_received_fd(fd: RawFd) {
        // SAFETY: fds reaching the dispatcher were received via `recv_with_fd`. They have
        // not been handed to anyone else, so this process owns them exclusively.
        drop(unsafe { OwnedFd::from_raw_fd(fd) });
    }

    /// Sends a duplicate of `fd` to the peer. Returns the id under which the peer can
    /// claim it with [`Comms::wait_for_fd`].
    ///
    /// Ids are interleaved between the two ends: even for the primary, odd for the
    /// secondary. Each side can therefore allocate ids without coordinating. Allocation
    /// starts from `ChannelId::NUM_RESERVED_IDS`, presumably so these ids never collide
    /// with the fixed ids used by `get_channel`.
    ///
    /// `fd` is only borrowed. The caller keeps ownership and remains responsible for
    /// closing it. It must be an open descriptor for the duration of the call.
    ///
    /// # Panics
    /// Panics if `fd` cannot be duplicated or if the background writer has shut down.
    // May be unused within the crate for now; kept as part of the fd-exchange API.
    #[allow(dead_code)]
    pub(crate) async fn send_fd(&self, fd: RawFd) -> FdId {
        let mut id = self
            .fd_id_gen
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            * 2;
        if self.process_type == ProcessType::Secondary {
            id += 1;
        }
        let id = FdId(id);
        // SAFETY: the caller promises `fd` is open for the duration of this call, which
        // covers the lifetime of this temporary borrow.
        let duplicate = unsafe { std::os::fd::BorrowedFd::borrow_raw(fd) }
            .try_clone_to_owned()
            .expect("failed to duplicate fd for sending");
        // The writer task takes ownership of, and closes, every fd it sends.
        self.outgoing_tx
            .send((id, duplicate.into_raw_fd()))
            .await
            .unwrap();
        id
    }

    /// Waits until the peer has sent an fd under `fd_id` and returns it.
    ///
    /// Ownership of the returned fd passes to the caller.
    ///
    /// # Panics
    /// Panics if the dispatcher has shut down before the fd is delivered.
    pub(crate) async fn wait_for_fd(&self, fd_id: FdId) -> RawFd {
        let (tx, rx) = oneshot::channel();
        self.register_callbacks_tx.send((fd_id, tx)).await.unwrap();
        rx.await.unwrap()
    }

    /// Creates a connected socket pair, sends one end to the peer under `id`, and returns
    /// the other end.
    ///
    /// Ownership of the sent end moves into the outgoing queue, and the writer task closes
    /// this process's copy after sending. The returned stream therefore sees EOF once the
    /// peer closes its end.
    async fn create_bidi_stream(&self, id: FdId) -> UnixStream {
        let (one, two) = UnixStream::pair().unwrap();
        let fd = two.into_std().unwrap().into_raw_fd();
        self.outgoing_tx.send((id, fd)).await.unwrap();
        one
    }

    /// Establishes the raw byte stream for `channel_id` and splits it into halves.
    ///
    /// The primary creates the socket pair and sends one end. The secondary waits for that
    /// end to arrive.
    async fn get_raw_channel(
        &self,
        channel_id: ChannelId,
    ) -> (impl AsyncRead, impl AsyncWrite) {
        // Channel ids double as fd ids.
        let id = FdId(channel_id as u64);
        let stream = if self.process_type == ProcessType::Primary {
            self.create_bidi_stream(id).await
        } else {
            let fd = self.wait_for_fd(id).await;
            // SAFETY: `wait_for_fd` transfers ownership of an fd received from the peer
            // that nothing else in this process holds.
            let owned = unsafe { OwnedFd::from_raw_fd(fd) };
            let std_stream = std::os::unix::net::UnixStream::from(owned);
            // Tokio requires streams passed to `from_std` to be non-blocking. Set the
            // mode explicitly instead of relying on flags inherited from the sender.
            std_stream.set_nonblocking(true).unwrap();
            UnixStream::from_std(std_stream).unwrap()
        };
        stream.into_split()
    }

    /// Opens the typed message channel identified by `channel_id`.
    ///
    /// Both ends must call this with the same `channel_id`. Values of type `T` are sent to
    /// the peer, and values of type `U` are received from it. The encoding is delegated to
    /// `frame`.
    pub async fn get_channel<T, U>(
        &self,
        channel_id: ChannelId,
    ) -> (mpsc::Sender<T>, mpsc::Receiver<U>)
    where
        T: Debug + Serialize + Send + 'static,
        U: Debug + DeserializeOwned + Send + 'static,
    {
        let (read_stream, write_stream) = self.get_raw_channel(channel_id).await;
        frame(read_stream, write_stream).await
    }
}
/// A [`Comms`] bundled with the temporary directory that holds its bootstrap socket.
// Field order is significant: struct fields drop in declaration order, so `_tempdir`
// is dropped (and `TempDir` removes the directory) before `comms` is dropped.
pub(crate) struct OwnedComms {
    // Held only to keep the temporary directory alive as long as this value.
    _tempdir: tempfile::TempDir,
    comms: Comms,
}
impl OwnedComms {
    /// Creates a fresh temporary directory and binds a socket named `bootstrap` inside it.
    ///
    /// Returns the primary-side handle and the socket path. The path is what a secondary
    /// process should pass to `Comms::connect`.
    ///
    /// # Panics
    /// Panics if the temporary directory cannot be created or the socket cannot be bound.
    pub(crate) async fn new() -> (Self, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let socket_path = dir.path().join("bootstrap");
        let comms = Comms::bind(socket_path.as_path()).await;
        let owned = Self {
            _tempdir: dir,
            comms,
        };
        (owned, socket_path)
    }
}
/// Lets an [`OwnedComms`] be used anywhere a `&Comms` is expected.
impl Deref for OwnedComms {
    type Target = Comms;

    /// Borrows the wrapped [`Comms`] handle.
    fn deref(&self) -> &Comms {
        let Self { comms, .. } = self;
        comms
    }
}