#![cfg(unix)]
use std::error::Error;
use std::fmt;
use std::future::Future;
use std::io;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::net::UnixStream;
use tokio::time::sleep;
use crate::bootstrap::deadline::StartupDeadline;
use rmux_os::identity::real_user_id;
#[path = "startup_unix/filesystem.rs"]
mod filesystem;
#[path = "startup_unix/lock.rs"]
mod lock;
use filesystem::{
ensure_owner_only_directory, prepare_socket_path_safe, reject_socket_symlink, startup_lock_path,
};
use lock::StartupLock;
/// Permission bits `0o600`: read/write for the owner, nothing for group or others.
/// NOTE(review): not referenced in this file; presumably applied to the
/// startup lock file by the `lock` submodule — confirm.
pub const STARTUP_LOCK_MODE: u32 = 0o600;
/// Permission bits `0o700`: read/write/search for the owner, nothing for group or others.
/// NOTE(review): not referenced in this file; presumably applied to the
/// socket's parent directory by the `filesystem` submodule — confirm.
pub const SOCKET_DIRECTORY_MODE: u32 = 0o700;
/// Bit mask `0o077` selecting every group and other permission bit.
/// NOTE(review): not referenced in this file; presumably a mode with any of
/// these bits set is reported as `StartupError::UnsafePermissions` — confirm.
pub const UNSAFE_PERMISSION_MASK: u32 = 0o077;
/// Deadline (5 seconds) that `connect_or_start` passes to `connect_or_start_with`.
pub const DEFAULT_STARTUP_DEADLINE: Duration = Duration::from_secs(5);
/// Poll interval (25 milliseconds) that `connect_or_start` passes to `connect_or_start_with`.
pub const STARTUP_POLL_INTERVAL: Duration = Duration::from_millis(25);
/// Successful result of the `connect_or_start*` functions: a connected
/// stream tagged with how the connection came about.
#[derive(Debug)]
pub enum StartupOutcome {
/// No daemon answered, so this caller ran the launcher and then connected
/// to the daemon it started.
Started(UnixStream),
/// A daemon was already accepting connections on the socket; the launcher
/// was not run by this caller.
JoinedExisting(UnixStream),
}
impl StartupOutcome {
    /// Borrows the connected stream, whichever variant holds it.
    #[must_use]
    pub fn stream(&self) -> &UnixStream {
        // Both variants carry exactly one stream, so the pattern is irrefutable.
        let (Self::Started(connection) | Self::JoinedExisting(connection)) = self;
        connection
    }

    /// Consumes the outcome and hands back the connected stream.
    #[must_use]
    pub fn into_stream(self) -> UnixStream {
        let (Self::Started(connection) | Self::JoinedExisting(connection)) = self;
        connection
    }

    /// Returns `true` only for [`StartupOutcome::Started`], i.e. when this
    /// caller is the one that launched the daemon.
    #[must_use]
    pub const fn is_owner(&self) -> bool {
        match self {
            Self::Started(_) => true,
            Self::JoinedExisting(_) => false,
        }
    }
}
/// Every way the `connect_or_start*` functions can fail.
///
/// Variants that wrap an [`io::Error`] expose it through [`Error::source`];
/// the [`fmt::Display`] output is a single human-readable sentence.
#[derive(Debug)]
pub enum StartupError {
    /// The socket path is structurally unusable (for example it has no parent
    /// directory or no file-name component); `reason` explains which rule failed.
    InvalidPath { reason: String, path: PathBuf },
    /// Startup refused to follow a symlink found at `path`.
    SymlinkRejected { path: PathBuf },
    /// An I/O call on `path` failed. `operation` is the verb phrase inserted
    /// after "failed to" in the message (e.g. "connect to daemon socket").
    Filesystem {
        operation: &'static str,
        path: PathBuf,
        source: io::Error,
    },
    /// Working with the startup lock at `path` failed.
    Lock { path: PathBuf, source: io::Error },
    /// `path` is owned by `actual_uid` instead of `expected_uid`.
    UnsafeOwner {
        path: PathBuf,
        expected_uid: u32,
        actual_uid: u32,
    },
    /// `path` has permission bits (`mode`) that grant access beyond the owner.
    UnsafePermissions { path: PathBuf, mode: u32 },
    /// The caller-supplied launcher returned an error.
    Launcher { source: io::Error },
    /// The daemon did not answer on `socket_path` before the deadline;
    /// `waited` is the elapsed time reported by the deadline.
    StartupTimeout { socket_path: PathBuf, waited: Duration },
    /// The process listening on `socket_path` runs as `actual_uid`, not the
    /// expected `expected_uid`.
    PeerCredentialMismatch {
        expected_uid: u32,
        actual_uid: u32,
        socket_path: PathBuf,
    },
}

impl StartupError {
    /// Returns `true` for `Lock`, `Launcher`, `StartupTimeout` and
    /// `PeerCredentialMismatch`, and `false` for every other variant.
    ///
    /// NOTE(review): how callers act on this flag is not visible here.
    /// `PeerCredentialMismatch` sits in the recoverable group even though it
    /// reports a trust violation like `UnsafeOwner` — confirm that is intended.
    #[must_use]
    pub const fn is_recoverable(&self) -> bool {
        // Spelled out exhaustively so a new variant forces a decision here.
        match self {
            Self::Lock { .. }
            | Self::Launcher { .. }
            | Self::StartupTimeout { .. }
            | Self::PeerCredentialMismatch { .. } => true,
            Self::InvalidPath { .. }
            | Self::SymlinkRejected { .. }
            | Self::Filesystem { .. }
            | Self::UnsafeOwner { .. }
            | Self::UnsafePermissions { .. } => false,
        }
    }
}

impl fmt::Display for StartupError {
    /// Writes a one-line description of the failure, including the affected
    /// path and, where present, the underlying I/O error text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidPath { reason, path } => {
                write!(f, "rmux startup rejected '{}': {reason}", path.display())
            }
            Self::SymlinkRejected { path } => {
                write!(
                    f,
                    "rmux startup refused to follow symlink at '{}'",
                    path.display()
                )
            }
            Self::Filesystem {
                operation,
                path,
                source,
            } => {
                write!(
                    f,
                    "rmux startup failed to {operation} '{}': {source}",
                    path.display()
                )
            }
            Self::Lock { path, source } => {
                write!(
                    f,
                    "rmux startup lock '{}' failed: {source}",
                    path.display()
                )
            }
            Self::UnsafeOwner {
                path,
                expected_uid,
                actual_uid,
            } => {
                write!(
                    f,
                    "rmux startup refused '{}': owned by uid {actual_uid} but expected uid {expected_uid}",
                    path.display()
                )
            }
            Self::UnsafePermissions { path, mode } => {
                write!(
                    f,
                    "rmux startup refused '{}': permissions 0o{mode:04o} grant access beyond the owner",
                    path.display()
                )
            }
            Self::Launcher { source } => {
                write!(f, "rmux startup launcher failed: {source}")
            }
            Self::StartupTimeout {
                socket_path,
                waited,
            } => {
                write!(
                    f,
                    "rmux startup timed out after {}ms waiting for '{}' to answer",
                    waited.as_millis(),
                    socket_path.display()
                )
            }
            Self::PeerCredentialMismatch {
                expected_uid,
                actual_uid,
                socket_path,
            } => {
                write!(
                    f,
                    "rmux daemon at '{}' reported peer uid {actual_uid} but expected {expected_uid}",
                    socket_path.display()
                )
            }
        }
    }
}

impl Error for StartupError {
    /// Returns the wrapped [`io::Error`] for `Filesystem`, `Lock` and
    /// `Launcher`; all other variants have no underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &io::Error = match self {
            Self::Filesystem { source, .. }
            | Self::Lock { source, .. }
            | Self::Launcher { source } => source,
            Self::InvalidPath { .. }
            | Self::SymlinkRejected { .. }
            | Self::UnsafeOwner { .. }
            | Self::UnsafePermissions { .. }
            | Self::StartupTimeout { .. }
            | Self::PeerCredentialMismatch { .. } => return None,
        };
        Some(cause)
    }
}
/// Connects to the daemon at `socket_path`, running `launcher` to start one
/// if nothing answers, using the default timing.
///
/// This is [`connect_or_start_with`] called with
/// [`DEFAULT_STARTUP_DEADLINE`] and [`STARTUP_POLL_INTERVAL`].
///
/// # Errors
///
/// Propagates any [`StartupError`] produced by [`connect_or_start_with`].
pub async fn connect_or_start<L, F>(
    socket_path: &Path,
    launcher: L,
) -> Result<StartupOutcome, StartupError>
where
    L: FnOnce() -> F,
    F: Future<Output = io::Result<()>>,
{
    let attempt = connect_or_start_with(
        socket_path,
        launcher,
        DEFAULT_STARTUP_DEADLINE,
        STARTUP_POLL_INTERVAL,
    );
    attempt.await
}
/// Same as [`connect_or_start`], but with a caller-chosen `deadline` and
/// `poll_interval`.
///
/// The deadline is always enforced here; it is forwarded to
/// [`connect_or_start_with_timeout`] as `Some(deadline)`.
///
/// # Errors
///
/// Propagates any [`StartupError`] produced by
/// [`connect_or_start_with_timeout`].
pub async fn connect_or_start_with<L, F>(
    socket_path: &Path,
    launcher: L,
    deadline: Duration,
    poll_interval: Duration,
) -> Result<StartupOutcome, StartupError>
where
    L: FnOnce() -> F,
    F: Future<Output = io::Result<()>>,
{
    let timeout = Some(deadline);
    connect_or_start_with_timeout(socket_path, launcher, timeout, poll_interval).await
}
/// Connects to the daemon at `socket_path`, launching it first if necessary.
///
/// Steps, in order:
/// 1. Reject a socket path with no parent, an empty parent, or no file name.
/// 2. Try to connect; if a daemon answers and passes the peer-uid check,
///    return [`StartupOutcome::JoinedExisting`].
/// 3. Check the parent directory via `ensure_owner_only_directory` and
///    acquire the startup lock.
/// 4. Try to connect again while holding the lock; if that succeeds, return
///    [`StartupOutcome::JoinedExisting`].
/// 5. Otherwise call `prepare_socket_path_safe`, await `launcher`, poll the
///    socket until the daemon answers, and return [`StartupOutcome::Started`].
///
/// `deadline` is converted with `StartupDeadline::from_timeout` before any
/// other work and the resulting value is shared by lock acquisition and the
/// post-launch wait. NOTE(review): `None` presumably means "no deadline" —
/// confirm against `StartupDeadline`.
///
/// # Errors
///
/// Returns [`StartupError::InvalidPath`] for a malformed path,
/// [`StartupError::Launcher`] when `launcher` fails, and otherwise whatever
/// the connect, filesystem, lock, or wait helpers report.
pub async fn connect_or_start_with_timeout<L, F>(
    socket_path: &Path,
    launcher: L,
    deadline: Option<Duration>,
    poll_interval: Duration,
) -> Result<StartupOutcome, StartupError>
where
    L: FnOnce() -> F,
    F: Future<Output = io::Result<()>>,
{
    let deadline = StartupDeadline::from_timeout(deadline);
    let owner_uid = real_user_id();

    // All three path checks fail with the same variant; only the reason differs.
    let invalid_path = |reason: &str| StartupError::InvalidPath {
        reason: reason.to_owned(),
        path: socket_path.to_path_buf(),
    };
    let parent = match socket_path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => dir,
        Some(_) => return Err(invalid_path("socket path has an empty parent directory")),
        None => return Err(invalid_path("socket path has no parent directory")),
    };
    if socket_path.file_name().is_none() {
        return Err(invalid_path("socket path has no file name component"));
    }

    // First attempt, made before touching the directory or the lock.
    if let Some(stream) = try_connect_validated(socket_path, owner_uid).await? {
        return Ok(StartupOutcome::JoinedExisting(stream));
    }

    ensure_owner_only_directory(parent, owner_uid)?;
    let lock_path = startup_lock_path(socket_path);
    let lock_guard = StartupLock::acquire(&lock_path, owner_uid, deadline, poll_interval).await?;

    // Second attempt under the lock: a daemon may have come up while this
    // caller was waiting for the lock.
    let outcome = match try_connect_validated(socket_path, owner_uid).await? {
        Some(stream) => StartupOutcome::JoinedExisting(stream),
        None => {
            prepare_socket_path_safe(socket_path, owner_uid)?;
            launcher()
                .await
                .map_err(|source| StartupError::Launcher { source })?;
            let stream = wait_for_daemon(socket_path, owner_uid, deadline, poll_interval).await?;
            StartupOutcome::Started(stream)
        }
    };

    // The lock is held until a validated connection exists.
    drop(lock_guard);
    Ok(outcome)
}
/// Attempts one connection to `socket_path` and validates the peer.
///
/// Returns `Ok(Some(stream))` when the connection succeeds and the peer's uid
/// equals `owner_uid`, and `Ok(None)` when connecting fails with `NotFound`
/// or `ConnectionRefused` (nothing is listening).
///
/// # Errors
///
/// - Whatever `reject_socket_symlink` reports; it is called both before and
///   after connecting.
/// - [`StartupError::Filesystem`] for any other connect failure.
/// - Whatever [`validate_peer_credentials`] reports for the connected peer.
async fn try_connect_validated(
    socket_path: &Path,
    owner_uid: u32,
) -> Result<Option<UnixStream>, StartupError> {
    reject_socket_symlink(socket_path)?;

    let stream = match UnixStream::connect(socket_path).await {
        Ok(stream) => stream,
        Err(error) => {
            return match error.kind() {
                // No socket file, or a socket nobody is accepting on.
                io::ErrorKind::NotFound | io::ErrorKind::ConnectionRefused => Ok(None),
                _ => Err(StartupError::Filesystem {
                    operation: "connect to daemon socket",
                    path: socket_path.to_path_buf(),
                    source: error,
                }),
            };
        }
    };

    // Check the path again now that the connection exists.
    reject_socket_symlink(socket_path)?;
    validate_peer_credentials(&stream, owner_uid, socket_path)?;
    Ok(Some(stream))
}
/// Checks that the process on the other end of `stream` runs as `expected_uid`.
///
/// `socket_path` is used only to label errors.
///
/// # Errors
///
/// - [`StartupError::Filesystem`] when the peer credentials cannot be read.
/// - [`StartupError::PeerCredentialMismatch`] when the peer's uid differs
///   from `expected_uid`.
fn validate_peer_credentials(
    stream: &UnixStream,
    expected_uid: u32,
    socket_path: &Path,
) -> Result<(), StartupError> {
    let actual_uid = stream
        .peer_cred()
        .map_err(|source| StartupError::Filesystem {
            operation: "read daemon peer credentials",
            path: socket_path.to_path_buf(),
            source,
        })?
        .uid();

    if actual_uid != expected_uid {
        return Err(StartupError::PeerCredentialMismatch {
            expected_uid,
            actual_uid,
            socket_path: socket_path.to_path_buf(),
        });
    }
    Ok(())
}
/// Polls `socket_path` until a validated connection succeeds or `deadline`
/// has elapsed.
///
/// Each round makes one connection attempt, then checks the deadline, then
/// sleeps. The sleep length comes from `deadline.sleep_for`, given
/// `poll_interval` raised to at least one millisecond.
///
/// # Errors
///
/// - Any error from [`try_connect_validated`] ends the wait immediately.
/// - [`StartupError::StartupTimeout`] when the deadline has elapsed after a
///   failed attempt; `waited` is taken from `deadline.elapsed()`.
async fn wait_for_daemon(
    socket_path: &Path,
    owner_uid: u32,
    deadline: StartupDeadline,
    poll_interval: Duration,
) -> Result<UnixStream, StartupError> {
    const MIN_POLL_INTERVAL: Duration = Duration::from_millis(1);

    // Clamp from below so the requested pause is never shorter than 1 ms.
    let pause = if poll_interval < MIN_POLL_INTERVAL {
        MIN_POLL_INTERVAL
    } else {
        poll_interval
    };

    loop {
        if let Some(stream) = try_connect_validated(socket_path, owner_uid).await? {
            return Ok(stream);
        }
        if deadline.is_elapsed() {
            return Err(StartupError::StartupTimeout {
                socket_path: socket_path.to_path_buf(),
                waited: deadline.elapsed(),
            });
        }
        sleep(deadline.sleep_for(pause)).await;
    }
}
#[cfg(test)]
#[path = "startup_unix/tests.rs"]
mod tests;