#[cfg(all(test, unix))]
use std::fs;
use std::io;
#[cfg(windows)]
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex as StdMutex};
#[cfg(windows)]
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use rmux_core::events::SubscriptionLimits;
#[cfg(windows)]
use rmux_ipc::connect_blocking;
use rmux_ipc::LocalEndpoint;
#[cfg(windows)]
use rmux_ipc::LocalListener;
#[cfg(windows)]
use rmux_proto::{
encode_frame, FrameDecoder, HasSessionRequest, Request, Response, RmuxError, SessionName,
};
use crate::listener;
use crate::listener_options::ServeOptions;
#[cfg(windows)]
use crate::server_access::current_owner_uid;
#[cfg(unix)]
use crate::unix_socket::bind_unix_listener_at;
#[cfg(unix)]
use crate::unix_socket::real_user_id;
#[cfg(all(test, unix))]
use crate::unix_socket::{
ensure_parent_directory, indicates_stale_socket, remove_stale_socket_if_needed,
};
#[cfg(all(test, unix))]
const FALLBACK_SOCKET_ROOT: &str = "/tmp";
pub fn default_socket_path() -> io::Result<PathBuf> {
rmux_ipc::default_endpoint().map(LocalEndpoint::into_path)
}
#[cfg(all(test, unix))]
fn socket_root_from_env(tmpdir: Option<&std::ffi::OsStr>) -> io::Result<PathBuf> {
let tmpdir = tmpdir
.filter(|value| !value.is_empty())
.map(PathBuf::from)
.into_iter();
let candidates = tmpdir.chain(std::iter::once(PathBuf::from(FALLBACK_SOCKET_ROOT)));
for candidate in candidates {
if let Ok(resolved) = fs::canonicalize(&candidate) {
return Ok(resolved);
}
}
Err(io::Error::new(
io::ErrorKind::NotFound,
"no suitable rmux socket directory",
))
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DaemonConfig {
socket_path: PathBuf,
config_load: ConfigLoadOptions,
subscription_limits: SubscriptionLimits,
}
impl DaemonConfig {
#[must_use]
pub fn new(socket_path: PathBuf) -> Self {
Self {
socket_path,
config_load: ConfigLoadOptions::disabled(),
subscription_limits: SubscriptionLimits::default(),
}
}
pub fn with_default_socket_path() -> io::Result<Self> {
Ok(Self::new(default_socket_path()?))
}
#[must_use]
pub fn socket_path(&self) -> &Path {
&self.socket_path
}
#[must_use]
pub const fn config_load(&self) -> &ConfigLoadOptions {
&self.config_load
}
#[must_use]
pub fn subscription_limits(&self) -> SubscriptionLimits {
self.subscription_limits
}
#[must_use]
pub fn with_default_config_load(mut self, quiet: bool, cwd: Option<PathBuf>) -> Self {
self.config_load = ConfigLoadOptions {
selection: ConfigFileSelection::Default,
quiet,
cwd,
};
self
}
#[must_use]
pub fn with_subscription_limits(mut self, subscription_limits: SubscriptionLimits) -> Self {
self.subscription_limits = subscription_limits;
self
}
#[must_use]
pub fn with_config_files(
mut self,
files: Vec<PathBuf>,
quiet: bool,
cwd: Option<PathBuf>,
) -> Self {
self.config_load = ConfigLoadOptions {
selection: ConfigFileSelection::Files(files),
quiet,
cwd,
};
self
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigLoadOptions {
selection: ConfigFileSelection,
quiet: bool,
cwd: Option<PathBuf>,
}
impl ConfigLoadOptions {
#[must_use]
pub const fn disabled() -> Self {
Self {
selection: ConfigFileSelection::Disabled,
quiet: true,
cwd: None,
}
}
#[must_use]
pub const fn selection(&self) -> &ConfigFileSelection {
&self.selection
}
#[must_use]
pub const fn quiet(&self) -> bool {
self.quiet
}
#[must_use]
pub fn cwd(&self) -> Option<&Path> {
self.cwd.as_deref()
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigFileSelection {
Disabled,
Default,
Files(Vec<PathBuf>),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerDaemon {
config: DaemonConfig,
}
#[derive(Debug, Clone)]
pub(crate) struct ShutdownHandle {
sender: Arc<StdMutex<Option<oneshot::Sender<()>>>>,
}
impl ShutdownHandle {
pub(crate) fn new() -> (Self, oneshot::Receiver<()>) {
let (sender, receiver) = oneshot::channel();
(
Self {
sender: Arc::new(StdMutex::new(Some(sender))),
},
receiver,
)
}
pub(crate) fn request_shutdown(&self) {
if let Some(sender) = self.sender.lock().expect("shutdown sender").take() {
let _ = sender.send(());
}
}
}
impl ServerDaemon {
#[must_use]
pub fn new(config: DaemonConfig) -> Self {
Self { config }
}
pub async fn bind(self) -> io::Result<ServerHandle> {
#[cfg(unix)]
{
let bound_listener = bind_unix_listener_at(self.config.socket_path())?;
let (shutdown_handle, shutdown_receiver) = ShutdownHandle::new();
let (server_signal_tx, server_signal_rx) = tokio::sync::mpsc::unbounded_channel();
let signal_watcher =
crate::signals::SignalWatcher::install(shutdown_handle.clone(), server_signal_tx)?;
let socket_path = self.config.socket_path().to_path_buf();
let owner_uid = real_user_id()?;
let serve_options = ServeOptions::new(
self.config.config_load().clone(),
self.config.subscription_limits(),
owner_uid,
)
.with_socket_identity(bound_listener.identity)
.with_server_signals(server_signal_rx);
let task = tokio::spawn(listener::serve(
bound_listener.listener,
socket_path.clone(),
shutdown_handle.clone(),
shutdown_receiver,
serve_options,
));
Ok(ServerHandle {
socket_path,
shutdown_handle,
task: Some(task),
signal_watcher: Some(signal_watcher),
})
}
#[cfg(windows)]
{
let endpoint = LocalEndpoint::from_path(self.config.socket_path().to_path_buf());
let listener = bind_windows_listener(&endpoint)?;
let (shutdown_handle, shutdown_receiver) = ShutdownHandle::new();
let socket_path = self.config.socket_path().to_path_buf();
let owner_uid = current_owner_uid();
let serve_options = ServeOptions::new(
self.config.config_load().clone(),
self.config.subscription_limits(),
owner_uid,
);
let task = tokio::spawn(listener::serve(
listener,
socket_path.clone(),
shutdown_handle.clone(),
shutdown_receiver,
serve_options,
));
Ok(ServerHandle {
socket_path,
shutdown_handle,
task: Some(task),
})
}
}
}
#[cfg(windows)]
fn bind_windows_listener(endpoint: &LocalEndpoint) -> io::Result<LocalListener> {
match LocalListener::bind(endpoint) {
Ok(listener) => Ok(listener),
Err(bind_error) => Err(windows_bind_error(endpoint, bind_error)),
}
}
#[cfg(windows)]
fn windows_bind_error(endpoint: &LocalEndpoint, bind_error: io::Error) -> io::Error {
if windows_pipe_responds(endpoint) {
return io::Error::new(
io::ErrorKind::AddrInUse,
format!(
"Windows named pipe '{}' is already held by a responsive rmux-compatible server",
endpoint.as_path().display()
),
);
}
io::Error::new(
bind_error.kind(),
format!(
"failed to bind Windows named pipe '{}': {bind_error}. Another process may still be holding this endpoint",
endpoint.as_path().display()
),
)
}
#[cfg(windows)]
fn windows_pipe_responds(endpoint: &LocalEndpoint) -> bool {
let endpoint = endpoint.clone();
std::thread::spawn(move || windows_protocol_probe(&endpoint).unwrap_or(false))
.join()
.unwrap_or(false)
}
#[cfg(windows)]
fn windows_protocol_probe(endpoint: &LocalEndpoint) -> io::Result<bool> {
let mut stream = connect_blocking(endpoint, Duration::from_millis(100))?;
stream.set_write_timeout(Some(Duration::from_millis(100)))?;
stream.set_read_timeout(Some(Duration::from_millis(100)))?;
let request = Request::HasSession(HasSessionRequest {
target: SessionName::new("__rmux_probe__").map_err(io::Error::other)?,
});
let frame = encode_frame(&request).map_err(io::Error::other)?;
stream.write_all(&frame)?;
stream.flush()?;
let mut decoder = FrameDecoder::new();
let mut buffer = [0_u8; 512];
loop {
let bytes_read = match stream.read(&mut buffer) {
Ok(0) => return Ok(false),
Ok(bytes_read) => bytes_read,
Err(error) if error.kind() == io::ErrorKind::TimedOut => return Ok(false),
Err(error) => return Err(error),
};
decoder.push_bytes(&buffer[..bytes_read]);
match decoder.next_frame::<Response>() {
Ok(Some(Response::HasSession(_))) => return Ok(true),
Ok(Some(_response)) => return Ok(false),
Ok(None) => continue,
Err(RmuxError::IncompleteFrame { .. }) => continue,
Err(_error) => return Ok(false),
}
}
}
#[derive(Debug)]
pub struct ServerHandle {
socket_path: PathBuf,
shutdown_handle: ShutdownHandle,
task: Option<JoinHandle<io::Result<()>>>,
#[cfg(unix)]
signal_watcher: Option<crate::signals::SignalWatcher>,
}
impl ServerHandle {
#[must_use]
pub fn socket_path(&self) -> &Path {
&self.socket_path
}
pub async fn wait(mut self) -> io::Result<()> {
if let Some(task) = self.task.take() {
return task.await.map_err(io::Error::other)?;
}
Ok(())
}
pub async fn shutdown(mut self) -> io::Result<()> {
self.request_shutdown();
if let Some(task) = self.task.take() {
return task.await.map_err(io::Error::other)?;
}
Ok(())
}
fn request_shutdown(&mut self) {
#[cfg(unix)]
{
let _ = self.signal_watcher.take();
}
self.shutdown_handle.request_shutdown();
}
}
impl Drop for ServerHandle {
fn drop(&mut self) {
self.request_shutdown();
}
}
#[cfg(all(test, unix))]
#[path = "daemon_tests/unix.rs"]
mod tests;
#[cfg(all(test, windows))]
#[path = "daemon_tests/windows.rs"]
mod tests;