use std::{net::Shutdown, path::Path};
use super::IoResult;
use async_trait::async_trait;
use blocking::{unblock, Unblock};
/// Backend-agnostic set of Unix domain socket operations.
///
/// Each implementor picks the concrete stream, listener and address types of one
/// I/O backend and forwards every operation to it. All methods are associated
/// functions that receive the socket explicitly, so implementors serve only as
/// type-level selectors. `#[async_trait(?Send)]` means the returned futures are
/// not required to be `Send`.
#[async_trait(?Send)]
pub trait UnixSocketInterface {
    /// Connected stream handle for this backend.
    type UnixStream;
    /// Listening socket handle for this backend.
    type UnixListener;
    /// Peer address type reported by `unix_listener_accept`.
    type SocketAddr;

    /// Opens a stream connection to the Unix socket at `socket_path`.
    async fn unix_stream_connect(socket_path: impl AsRef<Path>) -> IoResult<Self::UnixStream>;

    /// Shuts the stream down. The existing backends differ in which halves they close.
    async fn unix_stream_shutdown(s: &mut Self::UnixStream) -> IoResult<()>;

    /// Writes some prefix of `buf` and returns how many bytes were accepted.
    async fn unix_stream_write(s: &mut Self::UnixStream, buf: &[u8]) -> IoResult<usize>;

    /// Writes the whole of `buf`, or fails.
    async fn unix_stream_write_all(s: &mut Self::UnixStream, buf: &[u8]) -> IoResult<()>;

    /// Reads into `buf` and returns the number of bytes stored there.
    async fn unix_stream_read(s: &mut Self::UnixStream, buf: &mut [u8]) -> IoResult<usize>;

    /// Fills `buf` completely, or fails.
    async fn unix_stream_read_exact(s: &mut Self::UnixStream, buf: &mut [u8]) -> IoResult<()>;

    /// Creates a listener bound to `path`.
    async fn unix_listener_bind(path: impl AsRef<Path>) -> IoResult<Self::UnixListener>;

    /// Waits for the next incoming connection and returns it with the peer's address.
    async fn unix_listener_accept(
        s: &mut Self::UnixListener,
    ) -> IoResult<(Self::UnixStream, Self::SocketAddr)>;
}
/// Marker type selecting the async-std implementation of [`UnixSocketInterface`].
#[cfg(feature = "async-std")]
#[derive(Clone, Copy, Debug)]
pub struct AsyncStdUSocks;
#[cfg(feature = "async-std")]
use async_std::os::unix::net as async_std_us;
#[cfg(feature = "async-std")]
#[async_trait(?Send)]
impl UnixSocketInterface for AsyncStdUSocks {
    type UnixStream = async_std_us::UnixStream;
    type UnixListener = async_std_us::UnixListener;
    type SocketAddr = async_std_us::SocketAddr;

    /// Connects using async-std's native Unix stream.
    async fn unix_stream_connect(socket_path: impl AsRef<Path>) -> IoResult<Self::UnixStream> {
        let target: &Path = socket_path.as_ref();
        async_std_us::UnixStream::connect(target).await
    }

    /// Closes both the read and the write half; async-std exposes this synchronously.
    async fn unix_stream_shutdown(s: &mut Self::UnixStream) -> IoResult<()> {
        s.shutdown(Shutdown::Both)
    }

    async fn unix_stream_write(s: &mut Self::UnixStream, buf: &[u8]) -> IoResult<usize> {
        async_std::io::WriteExt::write(s, buf).await
    }

    async fn unix_stream_write_all(s: &mut Self::UnixStream, buf: &[u8]) -> IoResult<()> {
        async_std::io::WriteExt::write_all(s, buf).await
    }

    async fn unix_stream_read(s: &mut Self::UnixStream, buf: &mut [u8]) -> IoResult<usize> {
        async_std::io::ReadExt::read(s, buf).await
    }

    async fn unix_stream_read_exact(s: &mut Self::UnixStream, buf: &mut [u8]) -> IoResult<()> {
        async_std::io::ReadExt::read_exact(s, buf).await
    }

    async fn unix_listener_bind(path: impl AsRef<Path>) -> IoResult<Self::UnixListener> {
        let target: &Path = path.as_ref();
        async_std_us::UnixListener::bind(target).await
    }

    async fn unix_listener_accept(
        s: &mut Self::UnixListener,
    ) -> IoResult<(Self::UnixStream, Self::SocketAddr)> {
        s.accept().await
    }
}
/// Marker type selecting the tokio implementation of [`UnixSocketInterface`].
///
/// Derives the same traits as the other backend markers so any of them can be
/// debug-printed and freely copied when used as a type-level selector.
#[cfg(feature = "tokio")]
#[derive(Debug, Clone, Copy)]
pub struct TokioUSocks;
#[cfg(feature = "tokio")]
use tokio::net as tokio_us;
#[cfg(feature = "tokio")]
#[async_trait(?Send)]
impl UnixSocketInterface for TokioUSocks {
    type UnixStream = tokio_us::UnixStream;
    type UnixListener = tokio_us::UnixListener;
    type SocketAddr = tokio_us::unix::SocketAddr;

    /// Connects using tokio's native Unix stream.
    async fn unix_stream_connect(socket_path: impl AsRef<Path>) -> IoResult<Self::UnixStream> {
        let target: &Path = socket_path.as_ref();
        tokio_us::UnixStream::connect(target).await
    }

    // NOTE(review): per tokio's docs, `AsyncWriteExt::shutdown` on a UnixStream shuts
    // down the write half only, whereas the async-std and thread-pool backends use
    // `Shutdown::Both`. Confirm callers do not rely on the read half being closed too.
    async fn unix_stream_shutdown(s: &mut Self::UnixStream) -> IoResult<()> {
        tokio::io::AsyncWriteExt::shutdown(s).await
    }

    async fn unix_stream_write(s: &mut Self::UnixStream, buf: &[u8]) -> IoResult<usize> {
        tokio::io::AsyncWriteExt::write(s, buf).await
    }

    async fn unix_stream_write_all(s: &mut Self::UnixStream, buf: &[u8]) -> IoResult<()> {
        tokio::io::AsyncWriteExt::write_all(s, buf).await
    }

    async fn unix_stream_read(s: &mut Self::UnixStream, buf: &mut [u8]) -> IoResult<usize> {
        tokio::io::AsyncReadExt::read(s, buf).await
    }

    /// tokio reports the byte count on success; the trait only needs success/failure.
    async fn unix_stream_read_exact(s: &mut Self::UnixStream, buf: &mut [u8]) -> IoResult<()> {
        tokio::io::AsyncReadExt::read_exact(s, buf).await?;
        Ok(())
    }

    /// tokio's `bind` is synchronous, so there is nothing to await here.
    async fn unix_listener_bind(path: impl AsRef<Path>) -> IoResult<Self::UnixListener> {
        let target: &Path = path.as_ref();
        tokio_us::UnixListener::bind(target)
    }

    async fn unix_listener_accept(
        s: &mut Self::UnixListener,
    ) -> IoResult<(Self::UnixStream, Self::SocketAddr)> {
        s.accept().await
    }
}
/// Marker type selecting the fallback implementation of [`UnixSocketInterface`].
///
/// This backend runs blocking `std::os::unix::net` socket calls on the `blocking`
/// crate's thread pool. It derives the same traits as the other backend markers.
#[derive(Debug, Clone, Copy)]
pub struct StdThreadpoolUSocks;
use std::os::unix::net as std_us;
#[async_trait(?Send)]
impl UnixSocketInterface for StdThreadpoolUSocks {
    type UnixStream = Unblock<std_us::UnixStream>;
    type UnixListener = Unblock<std_us::UnixListener>;
    type SocketAddr = std_us::SocketAddr;

    // Design note: every stream operation goes through `Unblock::with_mut`, which runs
    // exactly one blocking call on the thread pool and then leaves the handle idle.
    //
    // `Unblock`'s own `AsyncRead`/`AsyncWrite` impls are deliberately NOT used. Reading
    // through them starts a background task that keeps calling blocking `read` on the
    // socket to fill an internal pipe. Any later write, `with_mut` or shutdown must
    // first stop that task, which can only happen once its blocking `read` returns,
    // i.e. when the peer sends more data or closes. In a request/response exchange
    // that never happens, so the next write would hang, and any bytes the task had
    // already read ahead would be discarded.
    //
    // The price of this approach is one copy through an owned buffer per call. The
    // copy is needed because the closure handed to the thread pool must be `'static`.

    /// Connects on the thread pool; the closure gets its own copy of the path.
    async fn unix_stream_connect(socket_path: impl AsRef<Path>) -> IoResult<Self::UnixStream> {
        let owned_path = socket_path.as_ref().to_owned();
        unblock(move || std_us::UnixStream::connect(owned_path))
            .await
            .map(Unblock::new)
    }

    /// Shuts down both halves of the stream on the thread pool.
    async fn unix_stream_shutdown(s: &mut Self::UnixStream) -> IoResult<()> {
        s.with_mut(|inner_sock| inner_sock.shutdown(Shutdown::Both))
            .await
    }

    /// Performs a single blocking `write` of a copy of `buf`; returns the bytes accepted.
    async fn unix_stream_write(s: &mut Self::UnixStream, buf: &[u8]) -> IoResult<usize> {
        use std::io::Write;
        let owned = buf.to_vec();
        s.with_mut(move |inner_sock| inner_sock.write(&owned)).await
    }

    /// Writes a copy of the whole of `buf` before returning, so errors surface here
    /// rather than on a later call.
    async fn unix_stream_write_all(s: &mut Self::UnixStream, buf: &[u8]) -> IoResult<()> {
        use std::io::Write;
        let owned = buf.to_vec();
        s.with_mut(move |inner_sock| inner_sock.write_all(&owned)).await
    }

    /// Performs a single blocking `read` into a scratch buffer, then copies the bytes
    /// received into `buf`. A result of `0` for a non-empty `buf` means end of stream.
    async fn unix_stream_read(s: &mut Self::UnixStream, buf: &mut [u8]) -> IoResult<usize> {
        use std::io::Read;
        let mut scratch = vec![0u8; buf.len()];
        let (result, scratch) = s
            .with_mut(move |inner_sock| {
                let result = inner_sock.read(&mut scratch);
                (result, scratch)
            })
            .await;
        let n = result?;
        buf[..n].copy_from_slice(&scratch[..n]);
        Ok(n)
    }

    /// Fills a scratch buffer of `buf.len()` bytes with blocking `read_exact`, then
    /// copies it into `buf`. On error, `buf` is left untouched.
    async fn unix_stream_read_exact(s: &mut Self::UnixStream, buf: &mut [u8]) -> IoResult<()> {
        use std::io::Read;
        let mut scratch = vec![0u8; buf.len()];
        let (result, scratch) = s
            .with_mut(move |inner_sock| {
                let result = inner_sock.read_exact(&mut scratch);
                (result, scratch)
            })
            .await;
        result?;
        buf.copy_from_slice(&scratch);
        Ok(())
    }

    /// Binds on the thread pool; the closure gets its own copy of the path.
    async fn unix_listener_bind(path: impl AsRef<Path>) -> IoResult<Self::UnixListener> {
        let owned_path = path.as_ref().to_owned();
        unblock(move || std_us::UnixListener::bind(owned_path))
            .await
            .map(Unblock::new)
    }

    /// Runs a blocking `accept` on the thread pool and wraps the new connection in `Unblock`.
    async fn unix_listener_accept(
        s: &mut Self::UnixListener,
    ) -> IoResult<(Self::UnixStream, Self::SocketAddr)> {
        s.with_mut(|listener| listener.accept())
            .await
            .map(|(connection, addr)| (Unblock::new(connection), addr))
    }
}
// Backend precedence: async-std whenever its feature is on (even if tokio is also
// enabled), otherwise tokio, otherwise the thread-pool fallback.

/// Unix socket backend selected by cargo features: async-std.
#[cfg(feature = "async-std")]
pub type DefaultUnixSocks = AsyncStdUSocks;
/// Unix socket backend selected by cargo features: tokio (async-std disabled).
#[cfg(all(not(feature = "async-std"), feature = "tokio"))]
pub type DefaultUnixSocks = TokioUSocks;
/// Unix socket backend used when no async runtime feature is enabled.
#[cfg(not(any(feature = "async-std", feature = "tokio")))]
pub type DefaultUnixSocks = StdThreadpoolUSocks;