#![allow(unused)]
use {
super::{
endpoint::Endpoint,
memory::{
MemoryListener,
MemoryReadHalf,
MemoryStream,
MemoryWriteHalf,
},
},
futures_lite::io::{
AsyncRead,
AsyncWrite,
ReadHalf,
WriteHalf,
split,
},
pin_project::pin_project,
std::{
io,
pin::Pin,
task::{
Context,
Poll,
},
},
};
/// A connected, bidirectional IPC byte stream.
///
/// Implements [`AsyncRead`] and [`AsyncWrite`] by forwarding every call to the
/// wrapped transport. Values are produced by [`IpcStream::connect`] (client
/// side) and [`IpcListener::accept`] (server side), and can be turned into
/// independently owned halves with [`IpcStream::into_split`].
#[pin_project(project = IpcStreamProj)]
pub enum IpcStream {
    /// A Unix-domain socket stream; used for filesystem-path endpoints and,
    /// on Linux, for abstract-namespace endpoints.
    ///
    /// `#[pin]` makes the field structurally pinned so the generated
    /// `IpcStreamProj` projection can hand out a pinned inner stream.
    #[cfg(unix)]
    Unix(#[pin] smol::net::unix::UnixStream),
    /// An in-process stream obtained from the memory transport.
    Memory(#[pin] MemoryStream),
}
impl IpcStream {
    /// Connects to `endpoint` as a client.
    ///
    /// The transport is selected by the endpoint variant:
    /// - `UnixSocket` (Unix only): a filesystem-path Unix-domain socket.
    /// - `AbstractSocket` (Linux only): an abstract-namespace Unix socket whose
    ///   name is the bytes of `name`.
    /// - `NamedPipe` (Windows only): not implemented yet.
    /// - `Memory`: an in-process connection through the given transport.
    ///
    /// # Errors
    ///
    /// Propagates failures from address construction and from the underlying
    /// connect call. `NamedPipe` always yields [`io::ErrorKind::Unsupported`].
    pub async fn connect(endpoint: &Endpoint) -> io::Result<Self> {
        match endpoint {
            #[cfg(unix)]
            | Endpoint::UnixSocket(path) => {
                Ok(Self::Unix(smol::net::unix::UnixStream::connect(path).await?))
            },
            #[cfg(target_os = "linux")]
            | Endpoint::AbstractSocket(name) => {
                use std::os::linux::net::SocketAddrExt;
                let abstract_addr =
                    std::os::unix::net::SocketAddr::from_abstract_name(name.as_bytes())?;
                let connected =
                    smol::net::unix::UnixStream::connect_addr(abstract_addr.into()).await?;
                Ok(Self::Unix(connected))
            },
            #[cfg(windows)]
            | Endpoint::NamedPipe(_name) => Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "Windows named pipes not yet implemented",
            )),
            | Endpoint::Memory { transport, name } => {
                Ok(Self::Memory(transport.connect(name).await?))
            },
        }
    }

    /// Consumes the stream and returns separately owned read and write halves.
    ///
    /// Unix sockets are split with `futures_lite::io::split`; memory streams
    /// use their own `into_split`.
    pub fn into_split(self) -> (IpcReadHalf, IpcWriteHalf) {
        match self {
            #[cfg(unix)]
            | Self::Unix(socket) => {
                let halves = split(socket);
                (IpcReadHalf::Unix(halves.0), IpcWriteHalf::Unix(halves.1))
            },
            | Self::Memory(mem) => {
                let (rx, tx) = mem.into_split();
                (IpcReadHalf::Memory(rx), IpcWriteHalf::Memory(tx))
            },
        }
    }
}
impl AsyncRead for IpcStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.project() {
#[cfg(unix)]
| IpcStreamProj::Unix(stream) => stream.poll_read(cx, buf),
| IpcStreamProj::Memory(stream) => stream.poll_read(cx, buf),
}
}
}
impl AsyncWrite for IpcStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.project() {
#[cfg(unix)]
| IpcStreamProj::Unix(stream) => stream.poll_write(cx, buf),
| IpcStreamProj::Memory(stream) => stream.poll_write(cx, buf),
}
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
match self.project() {
#[cfg(unix)]
| IpcStreamProj::Unix(stream) => stream.poll_flush(cx),
| IpcStreamProj::Memory(stream) => stream.poll_flush(cx),
}
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
match self.project() {
#[cfg(unix)]
| IpcStreamProj::Unix(stream) => stream.poll_close(cx),
| IpcStreamProj::Memory(stream) => stream.poll_close(cx),
}
}
}
impl std::fmt::Debug for IpcStream {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(unix)]
| IpcStream::Unix(_) => {
f.debug_struct("IpcStream::Unix").finish_non_exhaustive()
},
| IpcStream::Memory(s) => {
f.debug_tuple("IpcStream::Memory").field(s).finish()
},
}
}
}
/// A server-side listener that yields [`IpcStream`]s from
/// [`IpcListener::accept`]. Created with [`IpcListener::bind`].
pub enum IpcListener {
    /// A Unix-domain socket listener, bound either to a filesystem path or,
    /// on Linux, to an abstract-namespace name.
    #[cfg(unix)]
    Unix(smol::net::unix::UnixListener),
    /// A listener registered on an in-process memory transport.
    Memory(MemoryListener),
}
impl IpcListener {
    /// Binds a listener on `endpoint`.
    ///
    /// For `Endpoint::UnixSocket`, a socket file left behind at the path by a
    /// previous (now dead) listener is removed before binding. Nothing else at
    /// the path is ever deleted:
    /// - a regular file, directory, or symlink is left untouched, so the bind
    ///   itself fails with the OS error;
    /// - a socket that still accepts connections belongs to a live listener and
    ///   is reported as [`io::ErrorKind::AddrInUse`] instead of being unlinked
    ///   out from under it.
    ///
    /// # Errors
    ///
    /// - `AddrInUse` when a live listener answers on the Unix socket path.
    /// - Any error from inspecting or removing a stale socket file, from
    ///   address construction, or from the underlying bind call.
    /// - On Windows, `Endpoint::NamedPipe` always fails with
    ///   [`io::ErrorKind::Unsupported`].
    pub async fn bind(endpoint: &Endpoint) -> io::Result<Self> {
        match endpoint {
            #[cfg(unix)]
            | Endpoint::UnixSocket(path) => {
                // Clear out a dead socket file from a previous run; anything else
                // at the path is left for `bind` to reject.
                Self::remove_stale_unix_socket(path).await?;
                let listener = smol::net::unix::UnixListener::bind(path)?;
                Ok(IpcListener::Unix(listener))
            },
            #[cfg(target_os = "linux")]
            | Endpoint::AbstractSocket(name) => {
                use std::os::linux::net::SocketAddrExt;
                let addr =
                    std::os::unix::net::SocketAddr::from_abstract_name(name.as_bytes())?;
                let listener = smol::net::unix::UnixListener::bind_addr(&addr.into())?;
                Ok(IpcListener::Unix(listener))
            },
            #[cfg(windows)]
            | Endpoint::NamedPipe(_name) => {
                Err(io::Error::new(
                    io::ErrorKind::Unsupported,
                    "Windows named pipes not yet implemented",
                ))
            },
            | Endpoint::Memory { transport, name } => {
                let listener = transport.bind(name).await?;
                Ok(IpcListener::Memory(listener))
            },
        }
    }

    /// Waits for and accepts the next incoming connection.
    ///
    /// The returned stream uses the same transport as this listener. The peer
    /// address reported by Unix listeners is discarded.
    pub async fn accept(&mut self) -> io::Result<IpcStream> {
        match self {
            #[cfg(unix)]
            | IpcListener::Unix(listener) => {
                let (stream, _addr) = listener.accept().await?;
                Ok(IpcStream::Unix(stream))
            },
            | IpcListener::Memory(listener) => {
                let stream = listener.accept().await?;
                Ok(IpcStream::Memory(stream))
            },
        }
    }

    /// Deletes the entry at `path` only if it is provably a dead Unix socket.
    ///
    /// Returns `Ok(())` without touching the filesystem in these cases, and the
    /// subsequent bind then reports whatever conflict exists:
    /// - nothing exists at `path`;
    /// - the entry is not a socket;
    /// - a probe connection fails for any reason other than "connection refused".
    ///
    /// If the probe connects, a live listener owns the socket and `AddrInUse` is
    /// returned. That listener briefly sees a connection that closes at once.
    #[cfg(unix)]
    async fn remove_stale_unix_socket(path: &std::path::Path) -> io::Result<()> {
        use std::os::unix::fs::FileTypeExt;

        // `symlink_metadata` does not follow symlinks, so a link is never
        // mistaken for (and never deleted as) the socket it points to.
        let file_type = match std::fs::symlink_metadata(path) {
            | Ok(metadata) => metadata.file_type(),
            | Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(()),
            | Err(err) => return Err(err),
        };
        if !file_type.is_socket() {
            return Ok(());
        }
        match smol::net::unix::UnixStream::connect(path).await {
            | Ok(_live) => {
                return Err(io::Error::new(
                    io::ErrorKind::AddrInUse,
                    format!("a listener is already accepting on {}", path.display()),
                ));
            },
            // The socket file exists but nobody is listening on it: stale.
            | Err(err) if err.kind() == io::ErrorKind::ConnectionRefused => {},
            // Staleness could not be proven; leave the file for bind to reject.
            | Err(_) => return Ok(()),
        }
        match std::fs::remove_file(path) {
            | Ok(()) => Ok(()),
            // Already gone (e.g. removed concurrently): nothing left to do.
            | Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()),
            | Err(err) => Err(err),
        }
    }
}
impl std::fmt::Debug for IpcListener {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(unix)]
| IpcListener::Unix(_) => {
f.debug_struct("IpcListener::Unix").finish_non_exhaustive()
},
| IpcListener::Memory(l) => {
f.debug_tuple("IpcListener::Memory").field(l).finish()
},
}
}
}
/// The read half of an [`IpcStream`], produced by [`IpcStream::into_split`].
pub enum IpcReadHalf {
    /// Read half of a Unix socket, created by `futures_lite::io::split`.
    #[cfg(unix)]
    Unix(ReadHalf<smol::net::unix::UnixStream>),
    /// Read half of a memory stream, created by `MemoryStream::into_split`.
    Memory(MemoryReadHalf),
}
impl AsyncRead for IpcReadHalf {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
#[cfg(unix)]
| IpcReadHalf::Unix(reader) => Pin::new(reader).poll_read(cx, buf),
| IpcReadHalf::Memory(reader) => Pin::new(reader).poll_read(cx, buf),
}
}
}
impl std::fmt::Debug for IpcReadHalf {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(unix)]
| IpcReadHalf::Unix(_) => {
f.debug_struct("IpcReadHalf::Unix").finish_non_exhaustive()
},
| IpcReadHalf::Memory(r) => {
f.debug_tuple("IpcReadHalf::Memory").field(r).finish()
},
}
}
}
/// The write half of an [`IpcStream`], produced by [`IpcStream::into_split`].
pub enum IpcWriteHalf {
    /// Write half of a Unix socket, created by `futures_lite::io::split`.
    #[cfg(unix)]
    Unix(WriteHalf<smol::net::unix::UnixStream>),
    /// Write half of a memory stream, created by `MemoryStream::into_split`.
    Memory(MemoryWriteHalf),
}
impl AsyncWrite for IpcWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
#[cfg(unix)]
| IpcWriteHalf::Unix(writer) => Pin::new(writer).poll_write(cx, buf),
| IpcWriteHalf::Memory(writer) => Pin::new(writer).poll_write(cx, buf),
}
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
#[cfg(unix)]
| IpcWriteHalf::Unix(writer) => Pin::new(writer).poll_flush(cx),
| IpcWriteHalf::Memory(writer) => Pin::new(writer).poll_flush(cx),
}
}
fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
#[cfg(unix)]
| IpcWriteHalf::Unix(writer) => Pin::new(writer).poll_close(cx),
| IpcWriteHalf::Memory(writer) => Pin::new(writer).poll_close(cx),
}
}
}
impl std::fmt::Debug for IpcWriteHalf {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(unix)]
| IpcWriteHalf::Unix(_) => {
f.debug_struct("IpcWriteHalf::Unix").finish_non_exhaustive()
},
| IpcWriteHalf::Memory(w) => {
f.debug_tuple("IpcWriteHalf::Memory").field(w).finish()
},
}
}
}
// Unit tests for the IPC wrappers. Most cases run over the in-process memory
// transport; the `cfg(unix)` cases at the end use real Unix sockets created
// inside a temporary directory.
#[cfg(test)]
mod tests {
    use {
        super::{
            super::memory::MemoryTransport,
            *,
        },
        futures_lite::io::{
            AsyncReadExt,
            AsyncWriteExt,
        },
    };
    // Client connect and server accept over a memory endpoint both yield the
    // `Memory` variant.
    #[test]
    fn test_ipc_stream_connect_memory() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-connect");
            // Bind before spawning the client so the name is registered first.
            let mut listener = IpcListener::bind(&endpoint).await.unwrap();
            let t = transport.clone();
            let client_handle = smol::spawn(async move {
                let ep = Endpoint::memory(t, "test-connect");
                let stream = IpcStream::connect(&ep).await.unwrap();
                assert!(matches!(stream, IpcStream::Memory(_)));
            });
            let stream = listener.accept().await.unwrap();
            assert!(matches!(stream, IpcStream::Memory(_)));
            client_handle.await;
        });
    }
    // Round-trips a 17-byte message in each direction through the
    // `AsyncRead`/`AsyncWrite` impls of `IpcStream`.
    #[test]
    fn test_ipc_stream_async_read_write() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-rw");
            let mut listener = IpcListener::bind(&endpoint).await.unwrap();
            let t = transport.clone();
            let client_handle = smol::spawn(async move {
                let ep = Endpoint::memory(t, "test-rw");
                let mut stream = IpcStream::connect(&ep).await.unwrap();
                stream.write_all(b"hello from client").await.unwrap();
                let mut buf = [0u8; 17];
                stream.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"hello from server");
            });
            let mut server = listener.accept().await.unwrap();
            let mut buf = [0u8; 17];
            server.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"hello from client");
            server.write_all(b"hello from server").await.unwrap();
            client_handle.await;
        });
    }
    // Both ends are split with `into_split`, and the halves exchange a
    // ping/pong.
    #[test]
    fn test_ipc_stream_into_split() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-split");
            let mut listener = IpcListener::bind(&endpoint).await.unwrap();
            let t = transport.clone();
            let client_handle = smol::spawn(async move {
                let ep = Endpoint::memory(t, "test-split");
                let stream = IpcStream::connect(&ep).await.unwrap();
                let (mut reader, mut writer) = stream.into_split();
                writer.write_all(b"ping").await.unwrap();
                let mut buf = [0u8; 4];
                reader.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"pong");
            });
            let server = listener.accept().await.unwrap();
            let (mut reader, mut writer) = server.into_split();
            let mut buf = [0u8; 4];
            reader.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
            writer.write_all(b"pong").await.unwrap();
            client_handle.await;
        });
    }
    // Binding a memory endpoint produces the `Memory` listener variant.
    #[test]
    fn test_ipc_listener_bind_memory() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-bind");
            let listener = IpcListener::bind(&endpoint).await.unwrap();
            assert!(matches!(listener, IpcListener::Memory(_)));
        });
    }
    // `accept` returns a memory stream once a client connects. The client task
    // is detached; only the server side is asserted here.
    #[test]
    fn test_ipc_listener_accept() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-accept");
            let mut listener = IpcListener::bind(&endpoint).await.unwrap();
            let t = transport.clone();
            smol::spawn(async move {
                let ep = Endpoint::memory(t, "test-accept");
                let _stream = IpcStream::connect(&ep).await.unwrap();
            })
            .detach();
            let stream = listener.accept().await.unwrap();
            assert!(matches!(stream, IpcStream::Memory(_)));
        });
    }
    // After the peer drops its stream, a read on the server side returns 0
    // bytes (EOF).
    #[test]
    fn test_ipc_stream_read_after_close() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-close-read");
            let mut listener = IpcListener::bind(&endpoint).await.unwrap();
            // Signals that the client has dropped its stream, so the read
            // below cannot race ahead of the close.
            let (done_tx, done_rx) = async_channel::bounded::<()>(1);
            let t = transport.clone();
            smol::spawn(async move {
                let ep = Endpoint::memory(t, "test-close-read");
                let stream = IpcStream::connect(&ep).await.unwrap();
                drop(stream);
                let _ = done_tx.send(()).await;
            })
            .detach();
            let mut server = listener.accept().await.unwrap();
            done_rx.recv().await.unwrap();
            let mut buf = [0u8; 10];
            let n = server.read(&mut buf).await.unwrap();
            assert_eq!(n, 0);
        });
    }
    // After the peer drops its stream, a write on the server side fails with
    // `BrokenPipe`.
    #[test]
    fn test_ipc_stream_write_after_close() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-close-write");
            let mut listener = IpcListener::bind(&endpoint).await.unwrap();
            // Signals that the client has dropped its stream before writing.
            let (done_tx, done_rx) = async_channel::bounded::<()>(1);
            let t = transport.clone();
            smol::spawn(async move {
                let ep = Endpoint::memory(t, "test-close-write");
                let stream = IpcStream::connect(&ep).await.unwrap();
                drop(stream);
                let _ = done_tx.send(()).await;
            })
            .detach();
            let mut server = listener.accept().await.unwrap();
            done_rx.recv().await.unwrap();
            let result = server.write_all(b"data").await;
            assert!(result.is_err());
            assert_eq!(result.unwrap_err().kind(), io::ErrorKind::BrokenPipe);
        });
    }
    // Five concurrent clients are each accepted. Every client sends
    // "client-N" (8 bytes), and the server reads a non-empty chunk from each.
    #[test]
    fn test_ipc_multiple_connections() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-multi");
            let mut listener = IpcListener::bind(&endpoint).await.unwrap();
            let num_clients = 5;
            let handles: Vec<_> = (0..num_clients)
                .map(|i| {
                    let t = transport.clone();
                    smol::spawn(async move {
                        let ep = Endpoint::memory(t, "test-multi");
                        let mut stream = IpcStream::connect(&ep).await.unwrap();
                        let msg = format!("client-{}", i);
                        stream.write_all(msg.as_bytes()).await.unwrap();
                    })
                })
                .collect();
            for _ in 0..num_clients {
                let mut server = listener.accept().await.unwrap();
                let mut buf = [0u8; 8];
                let n = server.read(&mut buf).await.unwrap();
                assert!(n > 0);
            }
            for handle in handles {
                handle.await;
            }
        });
    }
    // Round trip over a real filesystem Unix socket in a temp directory.
    // Message lengths match the buffers: "unix hello" is 10 bytes and
    // "unix goodbye" is 12.
    #[cfg(unix)]
    #[test]
    fn test_unix_socket_connect() {
        smol::block_on(async {
            let dir = tempfile::tempdir().unwrap();
            let socket_path = dir.path().join("test.sock");
            let endpoint = Endpoint::UnixSocket(socket_path.clone());
            let mut listener = IpcListener::bind(&endpoint).await.unwrap();
            let path = socket_path.clone();
            let client_handle = smol::spawn(async move {
                let ep = Endpoint::UnixSocket(path);
                let mut stream = IpcStream::connect(&ep).await.unwrap();
                assert!(matches!(stream, IpcStream::Unix(_)));
                stream.write_all(b"unix hello").await.unwrap();
                let mut buf = [0u8; 12];
                stream.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"unix goodbye");
            });
            let mut server = listener.accept().await.unwrap();
            assert!(matches!(server, IpcStream::Unix(_)));
            let mut buf = [0u8; 10];
            server.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"unix hello");
            server.write_all(b"unix goodbye").await.unwrap();
            client_handle.await;
        });
    }
    // Rebinding the same path after the first listener has been dropped must
    // succeed. Presumably the dropped listener leaves its socket file behind,
    // so this exercises `bind`'s removal of the old file.
    #[cfg(unix)]
    #[test]
    fn test_unix_socket_cleanup() {
        smol::block_on(async {
            let dir = tempfile::tempdir().unwrap();
            let socket_path = dir.path().join("cleanup.sock");
            let endpoint = Endpoint::UnixSocket(socket_path.clone());
            {
                let _listener = IpcListener::bind(&endpoint).await.unwrap();
                assert!(socket_path.exists());
            }
            let _listener2 = IpcListener::bind(&endpoint).await.unwrap();
            assert!(socket_path.exists());
        });
    }
}