use std::ffi::OsStr;
#[cfg(all(test, unix))]
use std::ffi::OsString;
#[cfg(all(test, unix))]
use std::fs;
use std::io::{self, Read, Write};
#[cfg(all(test, unix))]
use std::os::unix::ffi::{OsStrExt, OsStringExt};
use std::path::{Path, PathBuf};
use std::time::Duration;
use crate::ClientError;
use rmux_ipc::{connect_blocking, BlockingLocalStream, LocalEndpoint};
use rmux_proto::{
encode_frame, AttachSessionResponse, ControlMode, ControlModeResponse, FrameDecoder, Request,
Response,
};
const READ_BUFFER_SIZE: usize = 8192;
const SOCKET_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
const SOCKET_WRITE_TIMEOUT: Duration = Duration::from_secs(5);
const SOCKET_RESPONSE_TIMEOUT: Duration = Duration::from_secs(15);
#[cfg(all(test, unix))]
const FALLBACK_SOCKET_ROOT: &str = "/tmp";
#[cfg(all(test, unix))]
const SOCKET_DIR_PREFIX: &str = "rmux";
pub fn default_socket_path() -> Result<PathBuf, ClientError> {
rmux_ipc::default_endpoint()
.map(LocalEndpoint::into_path)
.map_err(ClientError::Io)
}
pub fn socket_path_for_label(label: impl AsRef<OsStr>) -> Result<PathBuf, ClientError> {
rmux_ipc::endpoint_for_label(label)
.map(LocalEndpoint::into_path)
.map_err(ClientError::Io)
}
pub fn resolve_socket_path(
socket_name: Option<&OsStr>,
socket_path: Option<&Path>,
) -> Result<PathBuf, ClientError> {
rmux_ipc::resolve_endpoint(socket_name, socket_path)
.map(LocalEndpoint::into_path)
.map_err(ClientError::Io)
}
#[derive(Debug)]
pub enum ConnectResult {
Connected(Connection),
Absent,
}
pub fn connect_or_absent(socket_path: &Path) -> Result<ConnectResult, ClientError> {
connect_or_absent_with_timeout_using(
socket_path,
SOCKET_CONNECT_TIMEOUT,
connect_stream_with_timeout,
)
}
pub fn connect(socket_path: &Path) -> Result<Connection, ClientError> {
connect_with_timeout_using(
socket_path,
SOCKET_CONNECT_TIMEOUT,
connect_stream_with_timeout,
)
}
#[derive(Debug)]
pub struct Connection {
stream: BlockingLocalStream,
decoder: FrameDecoder,
}
#[derive(Debug)]
pub enum AttachTransition {
Upgraded(AttachSessionUpgrade),
Rejected(Response),
}
#[derive(Debug)]
pub enum ControlTransition {
Upgraded(ControlModeUpgrade),
Rejected(Response),
}
#[derive(Debug)]
pub struct AttachSessionUpgrade {
response: AttachSessionResponse,
stream: BlockingLocalStream,
initial_bytes: Vec<u8>,
}
#[derive(Debug)]
pub struct ControlModeUpgrade {
pub(crate) response: ControlModeResponse,
pub(crate) stream: BlockingLocalStream,
}
impl AttachSessionUpgrade {
#[must_use]
pub const fn response(&self) -> &AttachSessionResponse {
&self.response
}
#[must_use]
pub fn into_stream(self) -> BlockingLocalStream {
self.stream
}
#[must_use]
pub fn into_parts(self) -> (BlockingLocalStream, Vec<u8>) {
(self.stream, self.initial_bytes)
}
}
impl ControlModeUpgrade {
#[must_use]
pub const fn response(&self) -> &ControlModeResponse {
&self.response
}
#[must_use]
pub const fn mode(&self) -> ControlMode {
self.response.mode
}
#[must_use]
pub fn into_stream(self) -> BlockingLocalStream {
self.stream
}
}
impl Connection {
pub(crate) fn new(stream: BlockingLocalStream) -> Result<Self, ClientError> {
set_read_timeout(&stream, Some(SOCKET_RESPONSE_TIMEOUT)).map_err(ClientError::Io)?;
set_write_timeout(&stream, Some(SOCKET_WRITE_TIMEOUT)).map_err(ClientError::Io)?;
Ok(Self {
stream,
decoder: FrameDecoder::new(),
})
}
pub fn roundtrip(&mut self, request: &Request) -> Result<Response, ClientError> {
self.write_request(request)?;
self.read_response()
}
pub(crate) fn roundtrip_without_read_timeout(
&mut self,
request: &Request,
) -> Result<Response, ClientError> {
let previous_timeout = read_timeout(&self.stream).map_err(ClientError::Io)?;
set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
let result = self.roundtrip(request);
set_read_timeout(&self.stream, previous_timeout).map_err(ClientError::Io)?;
result
}
pub(crate) fn write_request(&mut self, request: &Request) -> Result<(), ClientError> {
let frame = encode_frame(request).map_err(ClientError::Protocol)?;
self.stream.write_all(&frame).map_err(ClientError::Io)
}
pub(crate) fn read_response(&mut self) -> Result<Response, ClientError> {
let mut buffer = [0u8; READ_BUFFER_SIZE];
loop {
match self.decoder.next_frame::<Response>() {
Ok(Some(response)) => return Ok(response),
Ok(None) => {}
Err(error) => return Err(ClientError::Protocol(error)),
}
let bytes_read = match self.stream.read(&mut buffer) {
Ok(bytes_read) => bytes_read,
Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
Err(error) => return Err(ClientError::Io(error)),
};
if bytes_read == 0 {
return Err(ClientError::UnexpectedEof);
}
self.decoder.push_bytes(&buffer[..bytes_read]);
}
}
pub(crate) fn stream_mut(&mut self) -> &mut BlockingLocalStream {
&mut self.stream
}
pub(crate) fn into_attach_upgrade(
self,
response: AttachSessionResponse,
) -> Result<AttachSessionUpgrade, ClientError> {
set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
let initial_bytes = self.decoder.remaining_bytes().to_vec();
Ok(AttachSessionUpgrade {
response,
stream: self.stream,
initial_bytes,
})
}
pub(crate) fn into_control_upgrade(
self,
response: ControlModeResponse,
) -> Result<ControlModeUpgrade, ClientError> {
set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
Ok(ControlModeUpgrade {
response,
stream: self.stream,
})
}
}
pub(crate) fn read_response_frame_exact(
stream: &mut BlockingLocalStream,
) -> Result<Response, ClientError> {
let mut decoder = FrameDecoder::new();
let mut byte = [0_u8; 1];
loop {
match decoder.next_frame::<Response>() {
Ok(Some(response)) => return Ok(response),
Ok(None) => {}
Err(error) => return Err(ClientError::Protocol(error)),
}
read_exact_or_eof(stream, &mut byte)?;
decoder.push_bytes(&byte);
}
}
fn read_exact_or_eof(
stream: &mut BlockingLocalStream,
buffer: &mut [u8],
) -> Result<(), ClientError> {
match stream.read_exact(buffer) {
Ok(()) => Ok(()),
Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => {
Err(ClientError::UnexpectedEof)
}
Err(error) => Err(ClientError::Io(error)),
}
}
#[cfg(all(test, unix))]
fn socket_path_from_parts(
rmux_tmpdir: Option<&OsStr>,
user_id: u32,
label: &OsStr,
) -> io::Result<PathBuf> {
let root = socket_root_from_parts(rmux_tmpdir)?;
let base = root.join(format!("{SOCKET_DIR_PREFIX}-{user_id}"));
let mut path = base.into_os_string().into_vec();
path.push(b'/');
path.extend_from_slice(label.as_bytes());
Ok(PathBuf::from(OsString::from_vec(path)))
}
#[cfg(all(test, unix))]
fn socket_root_from_parts(rmux_tmpdir: Option<&OsStr>) -> io::Result<PathBuf> {
let rmux_tmpdir = rmux_tmpdir
.filter(|value| !value.is_empty())
.map(PathBuf::from);
let candidates = rmux_tmpdir
.into_iter()
.chain(std::iter::once(PathBuf::from(FALLBACK_SOCKET_ROOT)));
for candidate in candidates {
if let Ok(resolved) = fs::canonicalize(&candidate) {
return Ok(resolved);
}
}
Err(io::Error::new(
io::ErrorKind::NotFound,
"no suitable rmux socket directory",
))
}
fn connect_or_absent_with_timeout_using<F>(
socket_path: &Path,
timeout: Duration,
connect_stream: F,
) -> Result<ConnectResult, ClientError>
where
F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
{
match connect_stream(socket_path, timeout) {
Ok(stream) => Ok(ConnectResult::Connected(Connection::new(stream)?)),
Err(error) if is_absent_error(&error) => Ok(ConnectResult::Absent),
Err(error) => Err(ClientError::Io(error)),
}
}
fn connect_with_timeout_using<F>(
socket_path: &Path,
timeout: Duration,
connect_stream: F,
) -> Result<Connection, ClientError>
where
F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
{
let stream = connect_stream(socket_path, timeout).map_err(ClientError::Io)?;
Connection::new(stream)
}
fn connect_stream_with_timeout(
socket_path: &Path,
timeout: Duration,
) -> io::Result<BlockingLocalStream> {
connect_blocking(
&LocalEndpoint::from_path(socket_path.to_path_buf()),
timeout,
)
}
fn read_timeout(stream: &BlockingLocalStream) -> io::Result<Option<Duration>> {
stream.read_timeout()
}
fn set_read_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
stream.set_read_timeout(timeout)
}
fn set_write_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
stream.set_write_timeout(timeout)
}
fn is_absent_error(error: &io::Error) -> bool {
matches!(
error.kind(),
io::ErrorKind::NotFound | io::ErrorKind::ConnectionRefused
)
}
#[cfg(all(test, unix))]
mod tests {
include!("connection/tests.rs");
}