use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use futures::channel::mpsc;
use futures::prelude::*;
use futures::select;
use futures::task::{Context, Poll};
use async_std::os::unix::net::{UnixListener, UnixStream};
use async_std::task::{self, JoinHandle};
use tracing::{span, Level};
use tracing_futures::Instrument;
use bytes::Bytes;
/// Size in bytes (10 KiB) of the buffers allocated for socket reads.
pub const UNIX_BUFF_LEN: usize = 10_240;
/// Errors reported by the UNIX socket interface.
#[derive(Debug, PartialEq)]
pub enum UnixError {
    /// An I/O operation failed; only the error's [`io::ErrorKind`] is retained.
    Io(io::ErrorKind),
    /// Sending on one of the internal `mpsc` channels failed.
    Sender(mpsc::SendError),
    /// No connection is registered under the requested connection id.
    NoMatchingConnection,
}
impl From<io::Error> for UnixError {
fn from(e: io::Error) -> Self {
Self::Io(e.kind())
}
}
impl From<mpsc::SendError> for UnixError {
fn from(e: mpsc::SendError) -> Self {
Self::Sender(e)
}
}
/// A chunk of data associated with one UNIX socket connection.
#[derive(Debug, Clone)]
pub struct UnixMessage {
    /// Identifier of the connection this message belongs to.
    pub connection_id: u32,
    /// Raw payload bytes.
    pub data: Bytes,
    /// Channel into the owning connection's outgoing queue, when one is attached.
    sink: Option<mpsc::Sender<UnixMessage>>,
    /// Channel used to ask the owning connection task to exit, when one is attached.
    exit: Option<mpsc::Sender<()>>,
}
impl UnixMessage {
fn new(connection_id: u32, data: Bytes) -> Self {
Self {
connection_id,
data,
sink: None,
exit: None,
}
}
pub fn response(&self, data: Bytes) -> Self {
Self {
connection_id: self.connection_id,
data,
sink: self.sink.clone(),
exit: self.exit.clone(),
}
}
pub(crate) async fn send(&self) -> Result<(), mpsc::SendError> {
let mut ch = self.sink.as_ref().unwrap().clone();
ch.send(self.clone()).await?;
Ok(())
}
pub(crate) async fn close(&self) -> Result<(), mpsc::SendError> {
let mut ch = self.exit.as_ref().unwrap().clone();
ch.send(()).await?;
Ok(())
}
}
/// Two messages are equal when they target the same connection and carry the
/// same payload; the attached channels are not compared.
impl PartialEq for UnixMessage {
    fn eq(&self, other: &Self) -> bool {
        if self.connection_id != other.connection_id {
            return false;
        }
        self.data == other.data
    }
}
/// UNIX socket listener that surfaces received data as a stream of
/// [`UnixMessage`]s and allows sending data back per connection.
pub struct Unix {
    /// Accepted connections keyed by connection id; shared with the accept task.
    connections: Arc<Mutex<HashMap<u32, Connection>>>,
    /// Receiving end of the channel that connection tasks push inbound messages into.
    rx_stream: mpsc::Receiver<UnixMessage>,
    /// Handle of the task accepting incoming connections.
    handle: JoinHandle<Result<(), UnixError>>,
}
/// Bookkeeping for one accepted socket connection and the task servicing it.
struct Connection {
    /// Connection id assigned by the listener.
    index: u32,
    /// Sender used to queue outgoing messages for the connection task.
    sink: mpsc::Sender<UnixMessage>,
    /// Sender used to ask the connection task to exit.
    exit_sink: mpsc::Sender<()>,
    /// Handle of the spawned connection task.
    handle: JoinHandle<Result<(), UnixError>>,
}
impl Unix {
pub async fn new(path: &str) -> Result<Self, UnixError> {
debug!("Creating UnixActor with path: {}", path);
let _ = std::fs::remove_file(&path);
let listener = UnixListener::bind(&path).await?;
let mut index = 0;
let connections = Arc::new(Mutex::new(HashMap::new()));
let c = connections.clone();
let (rx_sink, rx_stream) = mpsc::channel::<UnixMessage>(0);
let handle = task::spawn(
async move {
let mut incoming = listener.incoming();
while let Some(stream) = incoming.next().await {
let stream = stream?;
let conn = Connection::new(stream, index, rx_sink.clone());
trace!("connections lock");
c.lock().unwrap().insert(index, conn);
index += 1;
}
Ok(())
}
.instrument(span!(Level::TRACE, "UNIX", path)),
);
Ok(Self {
connections,
rx_stream,
handle,
})
}
pub async fn send(&mut self, msg: UnixMessage, close: bool) -> Result<(), UnixError> {
let connection_id = msg.connection_id;
let (mut tx_sink, mut exit_sink) = {
trace!("response lock");
let mut connections = self.connections.lock().unwrap();
let interface = match connections.get_mut(&connection_id) {
Some(v) => v,
None => return Err(UnixError::NoMatchingConnection),
};
debug!("send on interface: {}", interface.index);
(interface.sink.clone(), interface.exit_sink.clone())
};
tx_sink.send(msg).await?;
if close {
exit_sink.send(()).await?;
}
Ok(())
}
}
/// Yields the messages that connection tasks push into the shared receive channel.
impl Stream for Unix {
    type Item = UnixMessage;

    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Option<Self::Item>> {
        // Unwrap the pinned reference, then re-pin just the receiver to poll it.
        let this = self.get_mut();
        let receiver = Pin::new(&mut this.rx_stream);
        receiver.poll_next(ctx)
    }
}
/// Destructor hook for [`Connection`]; it currently performs no explicit cleanup.
impl Drop for Connection {
    fn drop(&mut self) {}
}
impl Connection {
    /// Spawns the task that services one accepted socket and returns the
    /// bookkeeping record for it.
    ///
    /// The task loops over three event sources:
    /// * messages queued on the returned `sink` are written to the socket;
    /// * bytes read from the socket are wrapped in a [`UnixMessage`] tagged
    ///   with `index` (with reply and exit channels attached) and forwarded
    ///   to `rx_sink`;
    /// * an item on the exit channel ends the loop.
    ///
    /// The loop also ends when the peer closes the socket (a zero-length read)
    /// or a read fails. A write failure or a failure to forward to `rx_sink`
    /// ends the task with an error.
    fn new(unix_stream: UnixStream, index: u32, rx_sink: mpsc::Sender<UnixMessage>) -> Connection {
        let mut rx_sink = rx_sink;

        let (tx_sink, tx_stream) = mpsc::channel::<UnixMessage>(0);
        let tx = Some(tx_sink.clone());

        let (exit_sink, mut exit_stream) = mpsc::channel::<()>(0);
        let exit = Some(exit_sink.clone());

        let (mut unix_rx, mut unix_tx) = unix_stream.split();

        let handle: JoinHandle<Result<(), UnixError>> = task::spawn(
            async move {
                debug!("new UNIX task {}", index);

                let mut buff = vec![0u8; UNIX_BUFF_LEN];
                let mut tx_stream = tx_stream.fuse();

                loop {
                    #[cfg(feature = "profile")]
                    let _fg = ::flame::start_guard("unix::tick");

                    select! {
                        outgoing = tx_stream.next() => {
                            if let Some(outgoing) = outgoing {
                                trace!("unix tx: {:?}", outgoing.data);
                                // `write` may accept only part of the buffer;
                                // `write_all` keeps going until every byte is sent.
                                unix_tx.write_all(&outgoing.data).await?;
                            }
                        },
                        res = unix_rx.read(&mut buff).fuse() => {
                            match res {
                                // A zero-length read means the peer closed the socket.
                                Ok(0) => break,
                                Ok(n) => {
                                    let mut u = UnixMessage::new(index, Bytes::copy_from_slice(&buff[..n]));
                                    // Attach the channels so the receiver can reply to,
                                    // or close, this connection directly from the message.
                                    u.sink = tx.clone();
                                    u.exit = exit.clone();
                                    trace!("unix rx: {:?}", &u.data);
                                    rx_sink.send(u).await?;
                                },
                                Err(e) => {
                                    error!("rx error: {:?}", e);
                                    break;
                                },
                            }
                        },
                        res = exit_stream.next() => {
                            if res.is_some() {
                                debug!("Received exit");
                                break;
                            }
                        },
                    }
                }

                debug!("task UNIX closed {}", index);
                Ok(())
            }
            .instrument(span!(Level::TRACE, "UNIX", index)),
        );

        Connection {
            index,
            handle,
            sink: tx_sink,
            exit_sink,
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use tracing_subscriber::FmtSubscriber;

    /// Round-trips a small payload: client -> listener stream -> `Unix::send` -> client.
    #[test]
    fn test_unix() {
        // `try_init` fails if a subscriber is already installed; that is fine here.
        let _ = FmtSubscriber::builder()
            .with_max_level(Level::DEBUG)
            .try_init();

        task::block_on(async {
            let mut unix = Unix::new("/tmp/dsf-unix-test")
                .await
                .expect("Error creating unix socket listener");
            let mut stream = UnixStream::connect("/tmp/dsf-unix-test")
                .await
                .expect("Error connecting to unix socket");

            let data = Bytes::copy_from_slice(&[0x11, 0x22, 0x33, 0x44]);
            let mut buff = vec![0u8; UNIX_BUFF_LEN];

            // Client to server. `write_all` ensures the whole payload is sent
            // rather than ignoring the byte count returned by `write`.
            stream.write_all(&data).await.expect("Error writing data");
            let res = unix.next().await.expect("Error awaiting unix message");
            assert_eq!(res, UnixMessage::new(0, data.clone()));

            // Server to client, addressed to the first connection (id 0).
            unix.send(UnixMessage::new(0, data.clone()), false)
                .await
                .expect("Error sending message to client");
            let n = stream
                .read(&mut buff)
                .await
                .expect("Error reading from client");
            assert_eq!(&buff[..n], &data);
        })
    }
}