use std::{io::ErrorKind, pin::Pin};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::{UnixListener, UnixSocket, UnixStream},
};
#[cfg(not(target_os = "macos"))]
use tokio_vsock::{VsockListener, VsockStream};
use super::{IOError, SocketAddress};
/// One mebibyte (2^20 bytes).
const MIB: usize = 1 << 20;
/// Largest payload, in bytes, that `send` will write or `recv` will accept.
pub const MAX_PAYLOAD_SIZE: usize = 128 * MIB;
/// Transport-specific listener wrapped by [`Listener`].
///
/// The vsock variant only exists on non-macOS targets.
#[derive(Debug)]
enum InnerListener {
    /// Listener bound to a Unix-domain socket path.
    Unix(UnixListener),
    /// Listener bound to a vsock address.
    #[cfg(not(target_os = "macos"))]
    Vsock(VsockListener),
}
/// Transport-specific connection wrapped by [`Stream`].
///
/// The vsock variant only exists on non-macOS targets.
#[derive(Debug)]
enum InnerStream {
    /// Connected Unix-domain stream.
    Unix(UnixStream),
    /// Connected vsock stream.
    #[cfg(not(target_os = "macos"))]
    Vsock(VsockStream),
}
/// A socket stream over a Unix or (non-macOS) vsock transport.
///
/// A `Stream` built with [`Stream::new`] remembers its target address and
/// can connect on demand. A `Stream` produced by [`Listener::accept`] has no
/// address and only wraps the accepted connection.
#[derive(Debug)]
pub struct Stream {
    /// Target used by `connect`/`reconnect`; `None` for accepted streams.
    address: Option<SocketAddress>,
    /// The live connection, or `None` while disconnected.
    inner: Option<InnerStream>,
}
impl From<&Stream> for Stream {
fn from(other: &Stream) -> Self {
Self { address: other.address.clone(), inner: None }
}
}
impl Stream {
    /// Wraps a Unix stream returned by [`Listener::accept`]. Accepted streams
    /// carry no address, so they cannot `connect`/`reconnect`.
    fn unix_accepted(stream: UnixStream) -> Self {
        Self { address: None, inner: Some(InnerStream::Unix(stream)) }
    }

    /// Wraps a vsock stream returned by [`Listener::accept`]. Accepted streams
    /// carry no address, so they cannot `connect`/`reconnect`.
    #[cfg(not(target_os = "macos"))]
    fn vsock_accepted(stream: VsockStream) -> Self {
        Self { address: None, inner: Some(InnerStream::Vsock(stream)) }
    }

    /// Creates an unconnected stream targeting `address`. No I/O happens
    /// until [`Stream::connect`] or [`Stream::call`] is invoked.
    #[must_use]
    pub fn new(address: &SocketAddress) -> Self {
        Self { address: Some(address.clone()), inner: None }
    }

    /// Opens a new connection to the stored address. Any existing connection
    /// is dropped and replaced on success and left in place on failure.
    ///
    /// # Errors
    /// [`IOError::ConnectAddressInvalid`] if the stream has no address
    /// (accepted streams), otherwise any error from establishing the
    /// connection.
    pub async fn connect(&mut self) -> Result<(), IOError> {
        let addr = self.address()?;
        let inner = match addr {
            SocketAddress::Unix(_) => InnerStream::Unix(unix_connect(addr).await?),
            #[cfg(not(target_os = "macos"))]
            SocketAddress::Vsock(_) => {
                InnerStream::Vsock(vsock_connect(addr).await?)
            }
        };
        self.inner = Some(inner);
        Ok(())
    }

    /// Replaces the current connection with a freshly opened one of the same
    /// transport, using the stored address.
    ///
    /// # Errors
    /// [`IOError::ConnectAddressInvalid`] if the stream has no address,
    /// [`IOError::DisconnectedStream`] if it is not currently connected
    /// (never connected, or after [`Stream::reset`]), otherwise any error
    /// from establishing the connection.
    pub async fn reconnect(&mut self) -> Result<(), IOError> {
        // Clone so the address borrow does not overlap `inner_mut`'s `&mut`.
        let addr = self.address()?.clone();
        match self.inner_mut()? {
            InnerStream::Unix(s) => *s = unix_connect(&addr).await?,
            #[cfg(not(target_os = "macos"))]
            InnerStream::Vsock(s) => *s = vsock_connect(&addr).await?,
        }
        Ok(())
    }

    /// Sends `buf` as one length-prefixed frame.
    ///
    /// # Errors
    /// [`IOError::DisconnectedStream`] if not connected,
    /// [`IOError::OversizedPayload`] if `buf` exceeds [`MAX_PAYLOAD_SIZE`],
    /// or any underlying I/O error.
    pub async fn send(&mut self, buf: &[u8]) -> Result<(), IOError> {
        match self.inner_mut()? {
            InnerStream::Unix(s) => send(s, buf).await,
            #[cfg(not(target_os = "macos"))]
            InnerStream::Vsock(s) => send(s, buf).await,
        }
    }

    /// Receives one length-prefixed frame and returns its payload.
    ///
    /// # Errors
    /// [`IOError::DisconnectedStream`] if not connected,
    /// [`IOError::RecvConnectionClosed`] if the peer closes mid-frame,
    /// [`IOError::OversizedPayload`] if the announced length exceeds
    /// [`MAX_PAYLOAD_SIZE`], or any underlying I/O error.
    pub async fn recv(&mut self) -> Result<Vec<u8>, IOError> {
        match self.inner_mut()? {
            InnerStream::Unix(s) => recv(s).await,
            #[cfg(not(target_os = "macos"))]
            InnerStream::Vsock(s) => recv(s).await,
        }
    }

    /// Request/response round trip: connects first if needed, sends
    /// `req_buf` as one frame, then waits for and returns one response frame.
    ///
    /// # Errors
    /// Any error from [`Stream::connect`], [`Stream::send`] or
    /// [`Stream::recv`].
    pub async fn call(&mut self, req_buf: &[u8]) -> Result<Vec<u8>, IOError> {
        if !self.is_connected() {
            self.connect().await?;
        }
        self.send(req_buf).await?;
        self.recv().await
    }

    /// Returns the stored target address.
    ///
    /// # Errors
    /// [`IOError::ConnectAddressInvalid`] if the stream has no address
    /// (streams produced by [`Listener::accept`]).
    pub fn address(&self) -> Result<&SocketAddress, IOError> {
        self.address.as_ref().ok_or(IOError::ConnectAddressInvalid)
    }

    /// Borrows the live connection, or fails with
    /// [`IOError::DisconnectedStream`] when there is none.
    fn inner_mut(&mut self) -> Result<&mut InnerStream, IOError> {
        self.inner.as_mut().ok_or(IOError::DisconnectedStream)
    }

    /// Drops the current connection (if any), keeping the stored address.
    pub fn reset(&mut self) {
        self.inner = None;
    }

    /// Returns `true` while a connection is held.
    pub fn is_connected(&self) -> bool {
        self.inner.is_some()
    }
}
/// Writes one length-prefixed frame to `stream` and flushes it.
///
/// The frame layout is an 8-byte little-endian `u64` payload length followed
/// by the payload bytes. The payload is written in slices of at most
/// `MAX_WRITE_SIZE` bytes.
///
/// # Errors
/// - [`IOError::OversizedPayload`] if `buf` is longer than
///   [`MAX_PAYLOAD_SIZE`]. Nothing is written in that case.
/// - Any I/O error from the writer. This includes `ErrorKind::WriteZero` when
///   the writer stops accepting bytes.
async fn send<S: AsyncWriteExt + Unpin>(
    stream: &mut S,
    buf: &[u8],
) -> Result<(), IOError> {
    // Upper bound on the size of each individual payload write.
    // NOTE(review): the reason for this specific limit is not visible here.
    const MAX_WRITE_SIZE: usize = 31234;
    let length = buf.len();
    if length > MAX_PAYLOAD_SIZE {
        return Err(IOError::OversizedPayload(length));
    }
    // `length <= MAX_PAYLOAD_SIZE`, so widening to u64 cannot truncate.
    let len_buf: [u8; size_of::<u64>()] = (length as u64).to_le_bytes();
    stream.write_all(&len_buf).await?;
    // `write_all` fails with `WriteZero` when the writer returns Ok(0). A bare
    // `write` loop would instead spin forever without making progress.
    for chunk in buf.chunks(MAX_WRITE_SIZE) {
        stream.write_all(chunk).await?;
    }
    // Make sure a buffered writer actually delivers the frame before the
    // caller (e.g. a request/response round trip) waits for a reply.
    stream.flush().await?;
    Ok(())
}
async fn recv<S: AsyncReadExt + Unpin>(
stream: &mut S,
) -> Result<Vec<u8>, IOError> {
let length: usize = {
let mut buf = [0u8; size_of::<u64>()];
stream.read_exact(&mut buf).await.map_err(|e| match e.kind() {
ErrorKind::UnexpectedEof => IOError::RecvConnectionClosed,
_ => IOError::StdIoError(e),
})?;
u64::from_le_bytes(buf)
.try_into()
.map_err(|_| IOError::ArithmeticSaturation)?
};
if length > MAX_PAYLOAD_SIZE {
return Err(IOError::OversizedPayload(length));
}
let mut buf = vec![0; length];
stream.read_exact(&mut buf).await.map_err(|e| match e.kind() {
ErrorKind::UnexpectedEof => IOError::RecvConnectionClosed,
_ => IOError::StdIoError(e),
})?;
Ok(buf)
}
impl From<IOError> for std::io::Error {
fn from(value: IOError) -> Self {
match value {
IOError::DisconnectedStream => std::io::Error::new(
std::io::ErrorKind::NotFound,
"connection not found",
),
_ => std::io::Error::other("unknown error"),
}
}
}
impl AsyncRead for Stream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
match &mut self.inner_mut()? {
InnerStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(not(target_os = "macos"))]
InnerStream::Vsock(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl AsyncWrite for Stream {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match &mut self.inner_mut()? {
InnerStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(not(target_os = "macos"))]
InnerStream::Vsock(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.inner_mut()? {
InnerStream::Unix(s) => Pin::new(s).poll_flush(cx),
#[cfg(not(target_os = "macos"))]
InnerStream::Vsock(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match &mut self.inner_mut()? {
InnerStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(not(target_os = "macos"))]
InnerStream::Vsock(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
/// A bound Unix or (non-macOS) vsock listener producing [`Stream`]s.
///
/// For Unix listeners, the socket file is removed again when the listener
/// is dropped.
#[derive(Debug)]
pub struct Listener {
    /// The transport-specific listener.
    inner: InnerListener,
    /// The address this listener was bound with.
    addr: SocketAddress,
}
impl Listener {
    /// Binds a listener on `addr`.
    ///
    /// For Unix addresses, a leftover *socket* file at the path is removed
    /// before binding, e.g. one a previous process failed to clean up. Any
    /// other kind of file at the path is left untouched. The bind then fails
    /// and reports the conflict, instead of silently deleting unrelated data.
    ///
    /// # Errors
    /// [`IOError::ConnectAddressInvalid`] if the Unix address has no
    /// filesystem path, or any error from binding the listener.
    pub(crate) fn listen(addr: &SocketAddress) -> Result<Self, IOError> {
        let inner = match *addr {
            SocketAddress::Unix(uaddr) => {
                let path =
                    uaddr.path().ok_or(IOError::ConnectAddressInvalid)?;
                // `symlink_metadata` inspects the path itself (no symlink
                // following), so only a real socket file is ever removed.
                let is_stale_socket = std::fs::symlink_metadata(path)
                    .is_ok_and(|meta| {
                        std::os::unix::fs::FileTypeExt::is_socket(
                            &meta.file_type(),
                        )
                    });
                if is_stale_socket {
                    // Best effort: if removal fails, `bind` reports the error.
                    _ = std::fs::remove_file(path);
                }
                InnerListener::Unix(UnixListener::bind(path)?)
            }
            #[cfg(not(target_os = "macos"))]
            SocketAddress::Vsock(vaddr) => {
                InnerListener::Vsock(VsockListener::bind(vaddr)?)
            }
        };
        Ok(Self { inner, addr: addr.clone() })
    }

    /// Waits for the next incoming connection and wraps it in a [`Stream`].
    /// The returned stream has no stored address.
    ///
    /// # Errors
    /// Any error from the underlying `accept`.
    pub async fn accept(&self) -> Result<Stream, IOError> {
        let stream = match &self.inner {
            InnerListener::Unix(l) => {
                let (s, _) = l.accept().await?;
                Stream::unix_accepted(s)
            }
            #[cfg(not(target_os = "macos"))]
            InnerListener::Vsock(l) => {
                let (s, _) = l.accept().await?;
                Stream::vsock_accepted(s)
            }
        };
        Ok(stream)
    }

    /// The address this listener was bound with.
    pub fn addr(&self) -> &SocketAddress {
        &self.addr
    }
}
impl Drop for Listener {
    /// Removes the Unix socket file this listener is bound to, when it has a
    /// filesystem path. Failures are only reported on stderr. Vsock listeners
    /// have nothing to clean up.
    fn drop(&mut self) {
        let usock = match &self.inner {
            InnerListener::Unix(usock) => usock,
            #[cfg(not(target_os = "macos"))]
            InnerListener::Vsock(_) => return,
        };
        let local = match usock.local_addr() {
            Ok(local) => local,
            Err(e) => {
                eprintln!("{e}");
                return;
            }
        };
        match local.as_pathname() {
            Some(path) => {
                _ = std::fs::remove_file(path);
            }
            None => eprintln!("unable to path the usock"),
        }
    }
}
async fn unix_connect(
addr: &SocketAddress,
) -> Result<UnixStream, std::io::Error> {
let addr = addr.usock();
let path = addr.path().ok_or(IOError::ConnectAddressInvalid)?;
let socket = UnixSocket::new_stream()?;
socket.connect(path).await
}
/// Opens a vsock stream connection to the vsock address held in `addr`.
#[cfg(not(target_os = "macos"))]
async fn vsock_connect(
    addr: &SocketAddress,
) -> Result<VsockStream, std::io::Error> {
    VsockStream::connect(*addr.vsock()).await
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{client::SocketClient, io::StreamPool};
    use std::{str::from_utf8, time::Duration};

    /// Polls until something accepts connections on the Unix socket at
    /// `path`. Gives up silently after 50 attempts spaced 100ms apart.
    pub async fn wait_for_usock(path: &str) {
        let addr = SocketAddress::new_unix(path);
        let pool = StreamPool::new(addr, 1).unwrap().shared();
        let client = SocketClient::new(pool, Duration::from_millis(50));
        for _ in 0..50 {
            if std::fs::exists(path).unwrap()
                && client.try_connect().await.is_ok()
            {
                break;
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    }

    /// Minimal raw-socket server. It answers a single `PING` with `PONG` and
    /// removes its socket file on drop.
    pub struct HarakiriPongServer {
        path: String,
    }

    impl Drop for HarakiriPongServer {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.path);
        }
    }

    impl HarakiriPongServer {
        /// Creates the server, deleting any leftover file at `path`.
        pub fn new(path: String) -> Self {
            let _ = std::fs::remove_file(&path);
            Self { path }
        }

        /// Binds, discards the first accepted connection, then serves one
        /// PING/PONG exchange on the second.
        pub async fn start(&mut self) {
            let listener = UnixListener::bind(&self.path).unwrap();
            // Presumably the readiness probe from `wait_for_usock`; it is
            // accepted and held, but not served.
            let (_stream, _peer_addr) = listener.accept().await.unwrap();
            let (mut stream, _peer_addr) = listener.accept().await.unwrap();
            let mut buf = [0u8; 4];
            let r = stream.read_exact(&mut buf).await;
            eprintln!("BYTES: {buf:?}");
            r.unwrap();
            if from_utf8(&buf).unwrap() == "PING" {
                // `write_all` guarantees all four bytes go out; a single
                // `write` is allowed to be short.
                stream.write_all(b"PONG").await.unwrap();
            }
        }
    }

    #[tokio::test]
    async fn stream_integration_test() {
        let addr: SocketAddress =
            SocketAddress::new_unix("/tmp/stream_integration_test.sock");
        let listener: Listener = Listener::listen(&addr).unwrap();
        let mut client = Stream::new(&addr);
        client.connect().await.unwrap();
        let mut server = listener.accept().await.unwrap();
        let data = vec![1, 2, 3, 4, 5, 6, 6, 6];
        client.send(&data).await.unwrap();
        let resp = server.recv().await.unwrap();
        assert_eq!(data, resp);
    }

    #[tokio::test]
    async fn stream_implements_read_write_traits() {
        let socket_server_path =
            "/tmp/stream_implements_read_write_traits.sock";
        let mut server =
            HarakiriPongServer::new(socket_server_path.to_string());
        tokio::spawn(async move {
            server.start().await;
        });
        wait_for_usock(socket_server_path).await;
        let addr = SocketAddress::new_unix(socket_server_path);
        let mut pong_stream = Stream::new(&addr);
        pong_stream.connect().await.unwrap();
        let written = pong_stream.write(b"PING").await.unwrap();
        assert_eq!(written, 4);
        let mut resp = [0u8; 4];
        // `read_exact` waits for all four bytes; a single `read` may
        // legally return fewer and make this test flaky.
        let res = pong_stream.read_exact(&mut resp).await.unwrap();
        assert_eq!(res, 4);
        assert_eq!(from_utf8(&resp).unwrap(), "PONG");
    }

    #[tokio::test]
    async fn listener_accept_test() {
        let addr = SocketAddress::new_unix("./listener_iterator_test.sock");
        let listener = Listener::listen(&addr).unwrap();
        let handler = tokio::spawn(async move {
            if let Ok(mut stream) = listener.accept().await {
                let req = stream.recv().await.unwrap();
                stream.send(&req).await.unwrap();
            }
        });
        let mut client = Stream::new(&addr);
        client.connect().await.unwrap();
        let data = vec![1, 2, 3, 4, 5, 6, 6, 6];
        client.send(&data).await.unwrap();
        let resp = client.recv().await.unwrap();
        assert_eq!(data, resp);
        handler.await.unwrap();
    }

    #[tokio::test]
    async fn limit_sized_payload() {
        let unix_addr =
            nix::sys::socket::UnixAddr::new("./limit_sized_payload.sock")
                .unwrap();
        let addr = SocketAddress::Unix(unix_addr);
        let listener = Listener::listen(&addr).unwrap();
        let handler = tokio::spawn(async move {
            if let Ok(mut stream) = listener.accept().await {
                let req = stream.recv().await.unwrap();
                stream.send(&req.clone()).await.unwrap();
            }
        });
        let mut client = Stream::new(&addr);
        client.connect().await.unwrap();
        let req = vec![1u8; MAX_PAYLOAD_SIZE];
        client.send(&req).await.unwrap();
        let resp = client.recv().await.unwrap();
        assert_eq!(resp.len(), MAX_PAYLOAD_SIZE);
        handler.await.unwrap();
    }

    #[tokio::test]
    async fn oversized_payload() {
        let addr = SocketAddress::new_unix("./oversized_payload.sock");
        let listener = Listener::listen(&addr).unwrap();
        let handler = tokio::spawn(async move {
            if let Err(err) = listener.accept().await {
                panic!("{err:?}");
            }
        });
        let mut client = Stream::new(&addr);
        client.connect().await.unwrap();
        let req = vec![1u8; MAX_PAYLOAD_SIZE + 1];
        match client.send(&req).await.unwrap_err() {
            IOError::OversizedPayload(size) => {
                assert_eq!(size, MAX_PAYLOAD_SIZE + 1);
            }
            other => {
                panic!("test failed: unexpected error variant ({other:?})");
            }
        }
        handler.await.unwrap();
    }
}