use std::{
net::SocketAddr,
path::{Path, PathBuf},
};
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(windows)]
use tokio::net::windows::named_pipe::{
ClientOptions, NamedPipeClient, NamedPipeServer, ServerOptions,
};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
use crate::error::Error;
/// Client side of a transport: something that can open a stream of type `S`.
///
/// `S` must be usable for asynchronous reading and writing.
#[async_trait::async_trait]
pub trait Connect<S>
where
S: AsyncRead + AsyncWrite,
{
/// Opens a new connection and returns the resulting stream, or an
/// [`Error`] if the connection could not be established.
async fn connect(&self) -> Result<S, Error>;
}
/// Connector for the local transport (Unix domain socket or Windows named
/// pipe, depending on the target platform).
///
/// The two identifiers below are combined into the socket / pipe name.
pub struct LocalConnector {
// Session identifier; `new()` fills it from `default_session_id()`.
session_id: String,
// Service identifier; empty until `with_service_id` is called.
service_id: String,
}
impl LocalConnector {
    /// Creates a connector that uses the default session ID and an empty
    /// service ID.
    pub fn new() -> Self {
        let session_id = default_session_id();
        let service_id = String::new();
        Self {
            session_id,
            service_id,
        }
    }

    /// Consumes the connector and returns it with `service_id` as its
    /// service ID (builder style).
    pub fn with_service_id<S: Into<String>>(mut self, service_id: S) -> Self {
        let replacement: String = service_id.into();
        self.service_id = replacement;
        self
    }
}
#[cfg(unix)]
#[async_trait::async_trait]
impl Connect<UnixStream> for LocalConnector {
async fn connect(&self) -> Result<UnixStream, Error> {
let path = get_unix_socket_path(&self.session_id, &self.service_id);
let stream = UnixStream::connect(path).await?;
Ok(stream)
}
}
#[cfg(windows)]
#[async_trait::async_trait]
impl Connect<NamedPipeClient> for LocalConnector {
    /// Opens the named pipe whose path is derived from this connector's
    /// session and service IDs.
    ///
    /// While the open attempt fails with `ERROR_PIPE_BUSY`, the attempt is
    /// repeated after a 50 ms pause. There is no upper bound on the number of
    /// retries.
    ///
    /// # Errors
    ///
    /// Any other open error is returned immediately, converted into
    /// [`Error`].
    async fn connect(&self) -> Result<NamedPipeClient, Error> {
        use winapi::shared::winerror::ERROR_PIPE_BUSY;

        let pipe_path = get_windows_named_pipe_path(&self.session_id, &self.service_id);
        let busy_code = Some(ERROR_PIPE_BUSY as i32);

        loop {
            match ClientOptions::new().open(&pipe_path) {
                Ok(client) => break Ok(client),
                Err(error) if error.raw_os_error() != busy_code => break Err(error.into()),
                // Pipe is busy: fall through to the pause and try again.
                Err(_) => {}
            }
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        }
    }
}
impl Default for LocalConnector {
fn default() -> Self {
Self::new()
}
}
/// Server side of a transport: something that can accept streams of type `S`.
///
/// `S` must be usable for asynchronous reading and writing.
#[async_trait::async_trait]
pub trait Listen<S>
where
S: AsyncRead + AsyncWrite,
{
/// Starts listening. On success returns an optional socket address; the
/// local implementations in this file always return `None`.
fn listen(&mut self) -> Result<Option<SocketAddr>, Error>;
/// Waits for one incoming connection and returns its stream together with
/// an optional socket address (`None` for the local implementations in this
/// file).
async fn accept(&mut self) -> Result<(S, Option<SocketAddr>), Error>;
}
/// Listener for the local transport (Unix domain socket or Windows named
/// pipe, depending on the target platform).
pub struct LocalListener {
// Session identifier; `new()` fills it from `default_session_id()`.
session_id: String,
// Service identifier; empty until `with_service_id` is called.
service_id: String,
// Unix: the bound listener, set by `listen()`.
#[cfg(unix)]
listener: Option<UnixListener>,
// Unix: path of the bound socket file, set by `listen()` and removed on drop.
#[cfg(unix)]
path: Option<PathBuf>,
// Windows: the pipe instance waiting for the next client, set by `listen()`.
#[cfg(windows)]
server: Option<NamedPipeServer>,
}
impl LocalListener {
    /// Creates a listener that uses the default session ID and an empty
    /// service ID. Nothing is bound until `listen` is called.
    pub fn new() -> Self {
        let session_id = default_session_id();
        let service_id = String::new();
        Self {
            session_id,
            service_id,
            #[cfg(unix)]
            listener: None,
            #[cfg(unix)]
            path: None,
            #[cfg(windows)]
            server: None,
        }
    }

    /// Consumes the listener and returns it with `service_id` as its
    /// service ID (builder style).
    pub fn with_service_id<S: Into<String>>(mut self, service_id: S) -> Self {
        let replacement: String = service_id.into();
        self.service_id = replacement;
        self
    }
}
#[cfg(unix)]
#[async_trait::async_trait]
impl Listen<UnixStream> for LocalListener {
    /// Binds a Unix domain socket at the path derived from the session and
    /// service IDs.
    ///
    /// Any file already present at that path is removed first (best effort),
    /// so that a leftover socket file does not block the bind. Always returns
    /// `Ok(None)` on success.
    ///
    /// # Errors
    ///
    /// Returns the bind error converted into [`Error`]. In that case the
    /// listener state is left unchanged.
    fn listen(&mut self) -> Result<Option<SocketAddr>, Error> {
        let path = get_unix_socket_path(&self.session_id, &self.service_id);
        // Deliberately ignored: if removal fails for a reason that matters,
        // the bind below reports the problem.
        let _ = std::fs::remove_file(&path);
        let listener = UnixListener::bind(&path)?;
        self.listener = Some(listener);
        self.path = Some(path);
        Ok(None)
    }

    /// Waits for the next incoming connection.
    ///
    /// Accept errors that `is_fatal_accept` classifies as non-fatal are
    /// skipped and the wait continues. The returned address is always `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if `listen` has not been called successfully
    /// beforehand (previously this panicked), or if accepting fails with an
    /// error classified as fatal.
    async fn accept(&mut self) -> Result<(UnixStream, Option<SocketAddr>), Error> {
        let listener = self.listener.as_ref().ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::NotConnected,
                "listen() must be called before accept()",
            )
        })?;
        loop {
            match listener.accept().await {
                Ok((stream, _addr)) => return Ok((stream, None)),
                Err(error) if is_fatal_accept(&error) => return Err(error.into()),
                Err(_) => continue,
            }
        }
    }
}
#[cfg(unix)]
impl Drop for LocalListener {
    /// Removes the socket file recorded by `listen`, if any. A failure to
    /// remove it is ignored.
    fn drop(&mut self) {
        if let Some(socket_path) = self.path.take() {
            std::fs::remove_file(socket_path).ok();
        }
    }
}
#[cfg(windows)]
#[async_trait::async_trait]
impl Listen<NamedPipeServer> for LocalListener {
    /// Creates the first instance of the named pipe whose path is derived
    /// from the session and service IDs. Always returns `Ok(None)` on success.
    ///
    /// # Errors
    ///
    /// Returns the creation error converted into [`Error`]. Because
    /// `first_pipe_instance(true)` is requested, creation is expected to fail
    /// when another instance of the pipe already exists.
    fn listen(&mut self) -> Result<Option<SocketAddr>, Error> {
        let path = get_windows_named_pipe_path(&self.session_id, &self.service_id);
        let server = ServerOptions::new()
            .first_pipe_instance(true)
            .create(path)?;
        self.server = Some(server);
        Ok(None)
    }

    /// Waits for a client to connect to the pending pipe instance, then
    /// hands that instance to the caller and stores a freshly created one
    /// for the next call. The returned address is always `None`.
    ///
    /// The pending instance stays stored in `self` until its replacement has
    /// been created. If waiting or creating the replacement fails, or this
    /// future is dropped while waiting, the listener keeps an instance and
    /// `accept` can be called again instead of panicking.
    ///
    /// # Errors
    ///
    /// Returns an error if `listen` has not been called successfully
    /// beforehand (previously this panicked), if waiting for the client
    /// fails, or if the replacement pipe instance cannot be created.
    async fn accept(&mut self) -> Result<(NamedPipeServer, Option<SocketAddr>), Error> {
        let path = get_windows_named_pipe_path(&self.session_id, &self.service_id);
        let pending = self.server.as_mut().ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::NotConnected,
                "listen() must be called before accept()",
            )
        })?;
        pending.connect().await?;
        // Create the next instance before giving the connected one away so
        // the slot is never left empty on an error path.
        let next_server = ServerOptions::new().create(path)?;
        let connected = std::mem::replace(pending, next_server);
        Ok((connected, None))
    }
}
impl Default for LocalListener {
fn default() -> Self {
Self::new()
}
}
/// Returns the default session ID: a hash of the current working directory,
/// or the current username when the working directory cannot be determined.
fn default_session_id() -> String {
    std::env::current_dir()
        .map(|cwd| path_to_session_id(&cwd))
        .unwrap_or_else(|_| whoami::username())
}
/// Hashes the raw bytes of `path` with `mx3::v3::hash` (second argument `1`)
/// and formats the result as 16 zero-padded lowercase hex digits.
#[cfg(unix)]
fn path_to_session_id(path: &Path) -> String {
    use std::os::unix::ffi::OsStrExt;

    let raw_bytes = path.as_os_str().as_bytes();
    format!("{:016x}", mx3::v3::hash(raw_bytes, 1))
}
/// Hashes the wide-character encoding of `path` with `mx3::v3::hash` (second
/// argument `1`) and formats the result as 16 zero-padded lowercase hex
/// digits.
///
/// Each 16-bit unit contributes two bytes, high byte first.
#[cfg(windows)]
fn path_to_session_id(path: &Path) -> String {
    use std::os::windows::ffi::OsStrExt;

    let os_str = path.as_os_str();
    let mut bytes = Vec::with_capacity(os_str.len() * 2);
    bytes.extend(os_str.encode_wide().flat_map(|unit| unit.to_be_bytes()));
    format!("{:016x}", mx3::v3::hash(&bytes, 1))
}
/// Returns the directory reported by `dirs::runtime_dir()`, or the system
/// temporary directory when none is reported.
// Within this file the only caller is the Unix socket path helper, so the
// function is unused on other targets.
#[allow(dead_code)]
fn get_runtime_dir() -> PathBuf {
    dirs::runtime_dir().unwrap_or_else(std::env::temp_dir)
}
fn get_filename(session_id: &str, service_id: &str) -> String {
let username = whoami::username();
let username =
percent_encoding::utf8_percent_encode(&username, percent_encoding::NON_ALPHANUMERIC);
let session_id =
percent_encoding::utf8_percent_encode(session_id, percent_encoding::NON_ALPHANUMERIC);
let service_id =
percent_encoding::utf8_percent_encode(service_id, percent_encoding::NON_ALPHANUMERIC);
format!("webaves-{}-{}-{}", username, session_id, service_id)
}
/// Returns the Unix socket path for the given IDs: the runtime directory
/// joined with the generated file name, with the extension set to `sock`.
#[cfg(unix)]
fn get_unix_socket_path(session_id: &str, service_id: &str) -> PathBuf {
    get_runtime_dir()
        .join(get_filename(session_id, service_id))
        .with_extension("sock")
}
/// Returns the named pipe path for the given IDs: the generated file name
/// placed under the `\\.\pipe\` prefix.
#[cfg(windows)]
fn get_windows_named_pipe_path(session_id: &str, service_id: &str) -> PathBuf {
    let file_name = get_filename(session_id, service_id);
    Path::new(r"\\.\pipe\").join(file_name)
}
/// Reports whether an accept error should end the accept loop.
///
/// Errors of kind `ConnectionReset`, `ConnectionAborted` and `BrokenPipe`
/// are treated as non-fatal; every other kind is fatal.
// Within this file the only caller is the Unix listener, so the function is
// unused on other targets.
#[allow(dead_code)]
fn is_fatal_accept(error: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    match error.kind() {
        ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted | ErrorKind::BrokenPipe => false,
        _ => true,
    }
}