use std::error::Error;
use std::fmt;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::Mutex;
use tokio_tungstenite::WebSocketStream;
use crate::connect_socket;
use crate::os::{detect_os_from_env, OsDetectionError, OsKind};
use crate::socket_async_read;
use crate::socket_async_write;
use crate::target::{detect_target_from_env, TargetDetectionError, TargetKind};
/// Errors reported by the socket API defined in this module.
#[derive(Debug)]
pub enum SocketError {
/// An underlying I/O operation failed.
Io(io::Error),
/// The WebSocket layer (tungstenite) reported an error.
Ws(tokio_tungstenite::tungstenite::Error),
/// The given target/OS pair is not supported on this platform.
UnsupportedCombination(TargetKind, OsKind),
/// The target or OS kind could not be detected from the environment.
Detection(DetectionError),
/// The recipient did not acknowledge a message within the time limit.
RecipientAckTimeout,
}
/// Failure to determine the connection parameters from the environment.
#[derive(Debug)]
pub enum DetectionError {
/// OS detection (`detect_os_from_env`) failed.
Os(OsDetectionError),
/// Target detection (`detect_target_from_env`) failed.
Target(TargetDetectionError),
}
impl fmt::Display for DetectionError {
    /// Formats the error transparently, i.e. exactly as the wrapped detection
    /// error formats itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fully qualified calls: `fmt` exists on both `Display` and `Debug`, and
        // neither trait is brought into scope by `use std::fmt;`, so plain
        // method-call syntax (`e.fmt(f)`) is either unresolved or ambiguous.
        match self {
            DetectionError::Os(e) => fmt::Display::fmt(e, f),
            DetectionError::Target(e) => fmt::Display::fmt(e, f),
        }
    }
}
impl Error for DetectionError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
DetectionError::Os(e) => Some(e),
DetectionError::Target(e) => Some(e),
}
}
}
impl fmt::Display for SocketError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SocketError::Io(e) => write!(f, "socket I/O error: {}", e),
SocketError::Ws(e) => write!(f, "WebSocket error: {}", e),
SocketError::UnsupportedCombination(t, o) => {
write!(
f,
"unsupported combination target={:?} os={:?} on this platform",
t, o
)
}
SocketError::Detection(e) => write!(f, "endpoint detection failed: {}", e),
SocketError::RecipientAckTimeout => {
write!(f, "recipient did not acknowledge message within the time limit")
}
}
}
}
impl Error for SocketError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
SocketError::Io(e) => Some(e),
SocketError::Ws(e) => Some(e),
SocketError::UnsupportedCombination(_, _) => None,
SocketError::Detection(e) => e.source(),
SocketError::RecipientAckTimeout => None,
}
}
}
impl From<io::Error> for SocketError {
fn from(e: io::Error) -> Self {
SocketError::Io(e)
}
}
impl From<tokio_tungstenite::tungstenite::Error> for SocketError {
fn from(e: tokio_tungstenite::tungstenite::Error) -> Self {
SocketError::Ws(e)
}
}
/// Largest frame length (in bytes) accepted from a peer: 4 MiB.
pub(crate) const MAX_MESSAGE_LEN: u32 = 4 << 20;
/// Builds the `io::Error` used to report that the peer went away
/// (`ConnectionReset` kind, message "connection lost").
pub(crate) fn connection_lost_error() -> io::Error {
    let kind = io::ErrorKind::ConnectionReset;
    io::Error::new(kind, "connection lost")
}
/// A Unix domain stream paired with the bookkeeping used by its liveness check.
#[cfg(unix)]
pub(crate) struct PolledUnixStream {
/// The underlying tokio Unix stream.
pub(crate) inner: tokio::net::UnixStream,
/// Time of the last liveness check that read a byte successfully.
pub(crate) last_check: Instant,
/// Presumably the minimum time between liveness checks; it is not read in this file — TODO confirm.
pub(crate) interval: Duration,
/// Set to `true` once a liveness check reads zero bytes (peer closed).
pub(crate) disconnected: bool,
/// Byte consumed from the stream by a liveness check; it is real payload data and
/// must be handed back to the reader before any further bytes.
pub(crate) peek: Option<u8>,
}
/// A Windows named-pipe client paired with the bookkeeping used by its liveness check.
#[cfg(windows)]
pub(crate) struct PolledNamedPipe {
/// The underlying tokio named-pipe client.
pub(crate) inner: tokio::net::windows::named_pipe::NamedPipeClient,
/// Time of the last liveness check that read a byte successfully.
pub(crate) last_check: Instant,
/// Presumably the minimum time between liveness checks; it is not read in this file — TODO confirm.
pub(crate) interval: Duration,
/// Set to `true` once a liveness check reads zero bytes (peer closed).
pub(crate) disconnected: bool,
/// Byte consumed from the pipe by a liveness check; it is real payload data and
/// must be handed back to the reader before any further bytes.
pub(crate) peek: Option<u8>,
}
#[cfg(unix)]
impl PolledUnixStream {
    /// Probes the stream for liveness by attempting to read a single byte.
    ///
    /// # Returns
    /// * `Poll::Ready(Ok(()))` — a byte was read and stored in `peek` (and
    ///   `last_check` was refreshed), or a previously probed byte is still
    ///   waiting in `peek`.
    /// * `Poll::Ready(Err(_))` — the read returned zero bytes (the stream is
    ///   marked `disconnected` and a "connection lost" error is returned) or the
    ///   underlying read failed.
    /// * `Poll::Pending` — no data is available; the waker in `cx` has been
    ///   registered by the underlying `poll_read`.
    pub(crate) fn run_liveness_check(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        // `peek` is a single slot holding real payload data. Probing again
        // before the reader has drained it would overwrite (and silently drop)
        // that byte, so skip the probe until the slot is empty again.
        if self.peek.is_some() {
            return Poll::Ready(Ok(()));
        }

        let mut one = [0u8; 1];
        let mut read_buf = ReadBuf::new(&mut one);
        match tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, &mut read_buf) {
            Poll::Ready(Ok(())) => match read_buf.filled().first().copied() {
                Some(byte) => {
                    // Keep the consumed byte so the reader can replay it.
                    self.peek = Some(byte);
                    self.last_check = Instant::now();
                    Poll::Ready(Ok(()))
                }
                None => {
                    // A successful zero-byte read means the peer closed the stream.
                    self.disconnected = true;
                    Poll::Ready(Err(connection_lost_error()))
                }
            },
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }
}
#[cfg(windows)]
impl PolledNamedPipe {
    /// Probes the pipe for liveness by attempting to read a single byte.
    ///
    /// # Returns
    /// * `Poll::Ready(Ok(()))` — a byte was read and stored in `peek` (and
    ///   `last_check` was refreshed), or a previously probed byte is still
    ///   waiting in `peek`.
    /// * `Poll::Ready(Err(_))` — the read returned zero bytes (the pipe is
    ///   marked `disconnected` and a "connection lost" error is returned) or the
    ///   underlying read failed.
    /// * `Poll::Pending` — no data is available; the waker in `cx` has been
    ///   registered by the underlying `poll_read`.
    pub(crate) fn run_liveness_check(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        // `peek` is a single slot holding real payload data. Probing again
        // before the reader has drained it would overwrite (and silently drop)
        // that byte, so skip the probe until the slot is empty again.
        if self.peek.is_some() {
            return Poll::Ready(Ok(()));
        }

        let mut one = [0u8; 1];
        let mut read_buf = ReadBuf::new(&mut one);
        match tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, &mut read_buf) {
            Poll::Ready(Ok(())) => match read_buf.filled().first().copied() {
                Some(byte) => {
                    // Keep the consumed byte so the reader can replay it.
                    self.peek = Some(byte);
                    self.last_check = Instant::now();
                    Poll::Ready(Ok(()))
                }
                None => {
                    // A successful zero-byte read means the peer closed the pipe.
                    self.disconnected = true;
                    Poll::Ready(Err(connection_lost_error()))
                }
            },
            Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Wraps a byte stream `T` with the state needed for length-prefixed framing
/// (a 4-byte big-endian length header, as decoded by `poll_read_fill_length`).
#[cfg(any(unix, windows))]
pub(crate) struct MessageFramed<T> {
/// The wrapped transport.
pub(crate) inner: T,
/// Accumulates the 4-byte big-endian frame length header.
pub(crate) length_buf: [u8; 4],
/// Number of header bytes currently stored in `length_buf` (0..=4).
pub(crate) length_filled: usize,
/// Presumably the payload buffer of the frame being read; not used in this file — TODO confirm.
pub(crate) read_buf: Option<Vec<u8>>,
/// Presumably the number of payload bytes received so far — TODO confirm.
pub(crate) payload_filled: usize,
/// Presumably the read offset into `read_buf` for bytes already handed out — TODO confirm.
pub(crate) read_cursor: usize,
/// Presumably bytes staged for writing; not used in this file — TODO confirm.
pub(crate) write_buf: Vec<u8>,
}
#[cfg(any(unix, windows))]
impl<T> MessageFramed<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps `inner` with empty framing state: no header bytes, no payload
    /// buffer and an empty write buffer.
    pub(crate) fn new(inner: T) -> Self {
        Self {
            inner,
            length_buf: [0u8; 4],
            length_filled: 0,
            read_buf: None,
            payload_filled: 0,
            read_cursor: 0,
            write_buf: Vec::new(),
        }
    }

    /// Reads from `inner` until the 4-byte big-endian length header is complete
    /// and returns the decoded frame length.
    ///
    /// Progress is kept in `length_buf` / `length_filled`, so the call can be
    /// resumed after `Poll::Pending`. `length_filled` is left at 4 on success;
    /// this function does not reset it.
    ///
    /// # Errors
    /// * `UnexpectedEof` if the stream ends before the header is complete.
    /// * `InvalidData` if the decoded length exceeds `MAX_MESSAGE_LEN`.
    /// * Any error returned by the underlying `poll_read`.
    pub(crate) fn poll_read_fill_length(
        inner: &mut Pin<&mut T>,
        cx: &mut Context<'_>,
        length_buf: &mut [u8; 4],
        length_filled: &mut usize,
    ) -> Poll<io::Result<u32>> {
        loop {
            if *length_filled >= 4 {
                break;
            }

            // Read into the still-missing tail of the header only.
            let received = {
                let mut window = ReadBuf::new(&mut length_buf[*length_filled..]);
                match inner.as_mut().poll_read(cx, &mut window) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Ready(Ok(())) => window.filled().len(),
                }
            };

            // A successful read of zero bytes signals end of stream.
            if received == 0 {
                let eof = io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed while reading frame length",
                );
                return Poll::Ready(Err(eof));
            }
            *length_filled += received;
        }

        let declared = u32::from_be_bytes(*length_buf);
        if declared <= MAX_MESSAGE_LEN {
            return Poll::Ready(Ok(declared));
        }
        Poll::Ready(Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("frame length {} exceeds max {}", declared, MAX_MESSAGE_LEN),
        )))
    }
}
/// The concrete transport behind a `Socket`.
// NOTE(review): the reason `private_interfaces` must be allowed is not visible from
// this file — presumably this type is reachable from a more visible item; confirm.
#[allow(private_interfaces)]
pub(crate) enum InnerSocket {
/// A WebSocket connection.
WebSocket(WebSocketAdapter),
/// A length-framed Unix domain stream; the `String` is presumably the endpoint address — TODO confirm.
#[cfg(unix)]
Unix(MessageFramed<PolledUnixStream>, String),
/// A length-framed Windows named pipe; the `String` is presumably the endpoint address — TODO confirm.
#[cfg(windows)]
NamedPipe(MessageFramed<PolledNamedPipe>, String),
/// The socket has been closed (the state set by `Socket::disconnect`).
Closed,
}
/// Public socket handle: a shared, async-mutex-protected `InnerSocket`.
pub struct Socket(pub(crate) Arc<Mutex<InnerSocket>>);
/// A pinned, boxed WebSocket stream together with read/write buffering state.
#[doc(hidden)]
pub struct WebSocketAdapter {
/// The WebSocket stream over a plain or TLS TCP connection.
pub(crate) stream: Pin<Box<WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>>>,
/// Presumably the message currently being handed out to the reader; not used in this file — TODO confirm.
pub(crate) read_buf: Option<Vec<u8>>,
/// Presumably the read offset into `read_buf` — TODO confirm.
pub(crate) read_cursor: usize,
/// Presumably bytes staged for the next outgoing message — TODO confirm.
pub(crate) write_buf: Vec<u8>,
}
impl WebSocketAdapter {
    /// Boxes and pins `stream` and starts with empty read/write buffering state.
    pub(crate) fn new(
        stream: WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
    ) -> Self {
        let pinned = Box::pin(stream);
        Self {
            stream: pinned,
            read_buf: None,
            read_cursor: 0,
            write_buf: Vec::new(),
        }
    }
}
/// Logs a user-facing explanation of a detection failure to stderr: what went
/// wrong with the `target` / `os` environment variables and how to fix it.
fn log_detection_error_to_stderr(de: &DetectionError) {
    // Each branch yields (description of the problem, remediation advice).
    let (problem, what_to_do): (String, &str) = match de {
        DetectionError::Target(target_err) => match target_err {
            TargetDetectionError::NotFound => (
                "the environment variable `target` is not set".to_string(),
                "For local connections set `target=local` and `os=windows` or `os=linux`. \
                 For remote connections use a WebSocket URL (e.g. ws://...) as the endpoint; \
                 then these variables are not needed.",
            ),
            TargetDetectionError::InvalidValue(v) => (
                format!("the environment variable `target` has an invalid value (\"{}\")", v),
                "Set `target` to either `local` or `remote`. For local IPC also set `os=windows` or \
                 `os=linux`, then try again.",
            ),
        },
        DetectionError::Os(os_err) => match os_err {
            OsDetectionError::NotFound => (
                "the environment variable `os` is not set".to_string(),
                "For local connections set `os=windows` or `os=linux` (depending on the peer's \
                 operating system) and `target=local`, then try again.",
            ),
            OsDetectionError::InvalidValue(v) => (
                format!("the environment variable `os` has an invalid value (\"{}\")", v),
                "Set `os` to either `linux` or `windows`, and ensure `target=local` for local \
                 connections, then try again.",
            ),
        },
    };

    let msg = format!(
        "ipcez: The connection could not be started because {}.\nWhat to do: {}",
        problem, what_to_do
    );
    crate::logger::log_to_stderr(&msg);
}
/// Determines the address, target kind and OS kind to connect with.
///
/// A `ws://` or `wss://` endpoint is always treated as a remote target and the
/// environment is not consulted; the OS kind returned in that case is fixed to
/// `OsKind::Linux`. Any other endpoint takes its target and OS from the
/// environment, target first.
///
/// # Errors
/// `SocketError::Detection` when the target or OS cannot be detected.
fn resolve_endpoint(endpoint: &str) -> Result<(String, TargetKind, OsKind), SocketError> {
    let is_websocket_url = ["ws://", "wss://"]
        .iter()
        .any(|scheme| endpoint.starts_with(*scheme));
    if is_websocket_url {
        return Ok((endpoint.to_string(), TargetKind::Remote, OsKind::Linux));
    }

    let target = detect_target_from_env()
        .map_err(|err| SocketError::Detection(DetectionError::Target(err)))?;
    let os = detect_os_from_env()
        .map_err(|err| SocketError::Detection(DetectionError::Os(err)))?;
    Ok((endpoint.to_string(), target, os))
}
/// Resolves `endpoint`, connects to it and returns a ready-to-use [`Socket`].
///
/// # Errors
/// * `SocketError::Detection` when the target/OS cannot be determined; a
///   user-facing explanation is logged to stderr before returning.
/// * Whatever `connect_socket::connect_socket` returns on failure; that error
///   is passed to `crate::logger::log_error` before returning.
pub async fn socket_init(endpoint: &str) -> Result<Socket, SocketError> {
    let (addr, target, os) = resolve_endpoint(endpoint).map_err(|err| {
        // Only detection failures come with remediation advice for the user.
        if let SocketError::Detection(de) = &err {
            log_detection_error_to_stderr(de);
        }
        err
    })?;

    let inner = connect_socket::connect_socket(&addr, target, os)
        .await
        .map_err(|err| {
            crate::logger::log_error(&err);
            err
        })?;

    Ok(Socket(Arc::new(Mutex::new(inner))))
}
impl Socket {
pub async fn send_message(&self, msg: &[u8]) -> Result<(), SocketError> {
socket_async_write::send_message(&self.0, msg).await
}
pub async fn disconnect(&self) {
let mut inner = self.0.lock().await;
*inner = InnerSocket::Closed;
}
pub fn message_handler<F, Fut>(&self, callback: F)
where
F: FnMut(Result<Vec<u8>, SocketError>) -> Fut + Send + 'static,
Fut: std::future::Future<Output = ()> + Send,
{
socket_async_read::spawn_message_handler(self.0.clone(), callback);
}
}