use std::ffi::{OsStr, OsString};
use std::fmt;
use std::io;
use std::process::Command as ProcessCommand;
use std::time::Duration;
use super::builder::RmuxBuilder;
use super::owned_session::OwnedSessionBuilder;
use crate::command::CommandRun;
use crate::diagnostics::FEATURE_PROTOCOL_CAPABILITIES;
use crate::transport::{DropGuard, TransportClient};
use crate::{
bootstrap::discovery, broadcast::BroadcastResult, ensure::EnsureSession, handles::Session,
Input, Pane, PaneId, PaneRef, Result, RmuxEndpoint, RmuxError, SessionName, Window, WindowRef,
};
use rmux_proto::{
HandshakeRequest, KillServerRequest, Request, Response, CAPABILITY_DAEMON_SHUTDOWN,
CAPABILITY_HANDSHAKE, RMUX_WIRE_VERSION,
};
#[path = "rmux/connect.rs"]
mod connect;
use connect::connect_transport;
pub(crate) use connect::{connect_or_start_transport, connect_transport_to_endpoint};
/// Client handle for an rmux daemon.
///
/// Holds the configured endpoint, an optional client-wide default timeout, an
/// optional cached transport that operations reuse instead of reconnecting, and
/// a drop guard that `shutdown` disarms.
///
/// NOTE: fields drop in declaration order, so keep `transport` before
/// `drop_guard` unless the guard's drop behavior is re-checked.
pub struct Rmux {
    /// Endpoint as configured; `resolved_endpoint` passes it through discovery
    /// before any connection is made.
    endpoint: RmuxEndpoint,
    /// Client-wide fallback timeout, combined with per-operation timeouts by
    /// `discovery::resolve_timeout`.
    default_timeout: Option<Duration>,
    /// Connected transport that operations clone and reuse when present; `None`
    /// means operations connect on demand.
    transport: Option<TransportClient>,
    /// Guard held for the client's lifetime and disarmed by `shutdown`.
    /// Presumably it issues a best-effort request when dropped (compare
    /// `DropGuard::best_effort` in `from_transport_for_test`) — TODO confirm.
    drop_guard: DropGuard,
}
impl Rmux {
    /// Creates a client with default configuration.
    ///
    /// Equivalent to `Rmux::default()`, which builds a client from a default
    /// `RmuxBuilder`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a new `RmuxBuilder` for configuring a client before use.
    #[must_use]
    pub fn builder() -> RmuxBuilder {
        RmuxBuilder::new()
    }

    /// Connects to `endpoint` through a builder configured with that endpoint.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `RmuxBuilder::connect`.
    pub async fn connect(endpoint: RmuxEndpoint) -> Result<Self> {
        RmuxBuilder::new().endpoint(endpoint).connect().await
    }

    /// Connects with default builder settings via `RmuxBuilder::connect_or_start`.
    /// That presumably starts a daemon when none is reachable.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `RmuxBuilder::connect_or_start`.
    pub async fn connect_or_start() -> Result<Self> {
        RmuxBuilder::new().connect_or_start().await
    }

    /// Same as `connect_or_start`, but targets `endpoint` instead of the
    /// builder's default endpoint.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `RmuxBuilder::connect_or_start`.
    pub async fn connect_or_start_at(endpoint: RmuxEndpoint) -> Result<Self> {
        RmuxBuilder::new()
            .endpoint(endpoint)
            .connect_or_start()
            .await
    }

    /// Returns the endpoint as configured, before discovery resolution
    /// (see `resolved_endpoint`).
    #[must_use]
    pub fn endpoint(&self) -> &RmuxEndpoint {
        &self.endpoint
    }

    /// Returns the client-wide default timeout, if one has been set.
    #[must_use]
    pub const fn configured_default_timeout(&self) -> Option<Duration> {
        self.default_timeout
    }

    /// Sets the client-wide default timeout.
    ///
    /// Handles created before this call received a copy of the previous value
    /// and are not affected.
    pub fn set_default_timeout(&mut self, timeout: Duration) {
        self.default_timeout = Some(timeout);
    }

    /// Resolves the configured endpoint through `discovery::resolve_endpoint`.
    ///
    /// # Errors
    ///
    /// Propagates any discovery failure.
    pub fn resolved_endpoint(&self) -> Result<RmuxEndpoint> {
        discovery::resolve_endpoint(&self.endpoint)
    }

    /// Combines `per_operation_timeout` with the client-wide default through
    /// `discovery::resolve_timeout`.
    #[must_use]
    pub fn resolved_timeout(&self, per_operation_timeout: Option<Duration>) -> Option<Duration> {
        discovery::resolve_timeout(per_operation_timeout, self.default_timeout)
    }

    /// Applies `ensure` against this client and returns the resulting session.
    ///
    /// # Errors
    ///
    /// Propagates any error from `EnsureSession::ensure`.
    pub async fn ensure_session(&self, ensure: EnsureSession) -> Result<Session> {
        ensure.ensure(self).await
    }

    /// Returns the existing session `session_name`, using a reuse-only
    /// `EnsureSession`. It presumably fails instead of creating the session
    /// when the session is absent.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ensure_session`.
    pub async fn session(&self, session_name: SessionName) -> Result<Session> {
        self.ensure_session(EnsureSession::named(session_name).reuse_only())
            .await
    }

    /// Starts building an owned session named `session_name` bound to this client.
    pub fn owned_session(&self, session_name: SessionName) -> OwnedSessionBuilder<'_> {
        OwnedSessionBuilder::new(self, session_name)
    }

    /// Returns a handle for the window `target`.
    ///
    /// The handle is bound to the resolved endpoint, the client-wide default
    /// timeout, and a transport. The transport is the cached one if this client
    /// holds one; otherwise it is newly connected.
    ///
    /// # Errors
    ///
    /// Fails if endpoint resolution or connecting fails.
    pub async fn window(&self, target: WindowRef) -> Result<Window> {
        let (endpoint, transport) = self.connect_handle_transport().await?;
        Ok(Window::new(
            target,
            endpoint,
            self.configured_default_timeout(),
            transport,
        ))
    }

    /// Returns a handle for the pane `target`.
    ///
    /// The handle is bound the same way as in `window`.
    ///
    /// # Errors
    ///
    /// Fails if endpoint resolution or connecting fails.
    pub async fn pane(&self, target: PaneRef) -> Result<Pane> {
        let (endpoint, transport) = self.connect_handle_transport().await?;
        Ok(Pane::new(
            target,
            endpoint,
            self.configured_default_timeout(),
            transport,
        ))
    }

    /// Resolves `pane_id` inside `session_name` to a pane reference and returns
    /// a handle that also records the id.
    ///
    /// # Errors
    ///
    /// Fails if endpoint resolution, connecting, or the id lookup fails. Returns
    /// a pane-not-found protocol error when no pane in the session matches.
    pub async fn pane_by_id(&self, session_name: SessionName, pane_id: PaneId) -> Result<Pane> {
        let (endpoint, transport) = self.connect_handle_transport().await?;
        let target = super::pane::resolve_pane_ref_for_id(&transport, &session_name, pane_id)
            .await?
            .ok_or_else(|| pane_not_found(&session_name, pane_id))?;
        Ok(Pane::new_by_id(
            target,
            pane_id,
            endpoint,
            self.configured_default_timeout(),
            transport,
        ))
    }

    /// Same as `pane_by_id`, but accepts the session name as a string and
    /// validates it first.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if `session_name` is not a valid `SessionName`.
    /// Otherwise propagates the errors of `pane_by_id`.
    pub async fn get_pane_by_id(
        &self,
        session_name: impl AsRef<str>,
        pane_id: PaneId,
    ) -> Result<Pane> {
        let session_name = SessionName::new(session_name.as_ref()).map_err(RmuxError::protocol)?;
        self.pane_by_id(session_name, pane_id).await
    }

    /// Starts a pane query bound to this client.
    pub fn find_panes(&self) -> crate::PaneFinder<'_> {
        crate::PaneFinder::new(self)
    }

    /// Starts a session query bound to this client.
    pub const fn find_sessions(&self) -> crate::SessionFinder<'_> {
        crate::SessionFinder::new(self)
    }

    /// Returns the pane whose title matches `title`, via
    /// `find_panes().title(..).one()`. That presumably errors unless exactly
    /// one pane matches.
    ///
    /// # Errors
    ///
    /// Propagates any error from the pane query.
    pub async fn get_pane_by_title(&self, title: impl AsRef<str>) -> Result<Pane> {
        self.find_panes().title(title.as_ref()).one().await
    }

    /// Returns a trace builder bound to this client.
    #[must_use]
    pub const fn tracing(&self) -> crate::RmuxTraceBuilder<'_> {
        crate::RmuxTraceBuilder::new(self)
    }

    /// Sends `input` to every pane in `panes` via `crate::broadcast::broadcast`.
    ///
    /// This client's transport and settings are not used. Each pane handle
    /// carries its own.
    ///
    /// # Errors
    ///
    /// Propagates any error from `crate::broadcast::broadcast`.
    pub async fn broadcast(&self, panes: &[Pane], input: Input<'_>) -> Result<BroadcastResult> {
        crate::broadcast::broadcast(panes, input).await
    }

    /// Reports whether the daemon has a session named `session_name`.
    ///
    /// # Errors
    ///
    /// Fails if connecting or the query fails.
    pub async fn has_session(&self, session_name: SessionName) -> Result<bool> {
        let client = self
            .connect_transport_for_operation(self.resolved_timeout(None))
            .await?;
        super::session::has_session(&client, session_name).await
    }

    /// Lists the names of the daemon's sessions.
    ///
    /// # Errors
    ///
    /// Fails if connecting or the query fails.
    pub async fn list_sessions(&self) -> Result<Vec<SessionName>> {
        let client = self
            .connect_transport_for_operation(self.resolved_timeout(None))
            .await?;
        super::session::list_session_names(&client).await
    }

    /// Returns the capability names negotiated with the daemon.
    ///
    /// # Errors
    ///
    /// Fails if connecting or capability negotiation fails.
    pub async fn capabilities(&self) -> Result<Vec<String>> {
        let client = self
            .connect_transport_for_operation(self.resolved_timeout(None))
            .await?;
        Ok(crate::capabilities::negotiated_capabilities(&client)
            .await?
            .as_ref()
            .to_vec())
    }

    /// Reports whether `capability` is among the negotiated capabilities.
    ///
    /// Each call performs a fresh `capabilities` query.
    ///
    /// # Errors
    ///
    /// Propagates any error from `capabilities`.
    pub async fn has_capability(&self, capability: &str) -> Result<bool> {
        Ok(self
            .capabilities()
            .await?
            .iter()
            .any(|advertised| advertised == capability))
    }

    /// Runs the rmux binary with `args`, targeting this client's resolved endpoint.
    ///
    /// The process runs on tokio's blocking pool via `spawn_blocking`. Its
    /// stdout, stderr, and exit code are captured in the returned `CommandRun`.
    /// The cached transport and the default timeout are not used.
    ///
    /// # Errors
    ///
    /// Fails if endpoint resolution fails, the process cannot be spawned, or the
    /// blocking task cannot be joined. A non-zero exit is not an error here.
    pub async fn cmd<I, S>(&self, args: I) -> Result<CommandRun>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<OsStr>,
    {
        let endpoint = self.resolved_endpoint()?;
        // Own the arguments so they can move into the 'static blocking task.
        let args = args
            .into_iter()
            .map(|arg| arg.as_ref().to_owned())
            .collect::<Vec<OsString>>();
        tokio::task::spawn_blocking(move || run_binary_command(endpoint, args))
            .await
            .map_err(|error| {
                RmuxError::transport(
                    "join rmux command task",
                    io::Error::other(error.to_string()),
                )
            })?
    }

    /// Asks the daemon to exit, consuming this client.
    ///
    /// The method first disarms the drop guard. It then reuses the cached
    /// transport (or opens a new one), negotiates the handshake and
    /// daemon-shutdown capabilities, and sends a `KillServer` request. After a
    /// `KillServer` response the transport is shut down. Close errors of kind
    /// EOF, connection reset, broken pipe, or not connected count as success.
    ///
    /// The drop guard stays disarmed even when this returns an error.
    ///
    /// # Errors
    ///
    /// Fails in these cases:
    /// - connecting or capability negotiation fails;
    /// - the daemon replies with an error or an unexpected response;
    /// - closing the transport fails with any other error.
    pub async fn shutdown(mut self) -> Result<()> {
        self.drop_guard.disarm();
        let client = match self.transport.take() {
            Some(client) => client,
            None => self.connect_transport().await?,
        };
        negotiate_shutdown_capability(&client).await?;
        let response = client
            .request(Request::KillServer(KillServerRequest))
            .await?;
        match response {
            Response::KillServer(_) => {
                if let Err(error) = client.shutdown().await {
                    // The daemon may already have closed its end while exiting.
                    if !is_clean_shutdown_close(&error) {
                        return Err(error);
                    }
                }
                Ok(())
            }
            Response::Error(error) => Err(error.into()),
            response => Err(RmuxError::protocol(rmux_proto::RmuxError::Server(format!(
                "rmux daemon sent `{}` response for shutdown request",
                response.command_name()
            )))),
        }
    }

    /// Builds an unconnected client: no cached transport and a no-op drop guard.
    pub(crate) fn from_config(endpoint: RmuxEndpoint, default_timeout: Option<Duration>) -> Self {
        Self {
            endpoint,
            default_timeout,
            transport: None,
            drop_guard: DropGuard::noop(),
        }
    }

    /// Builds a client that caches `transport` for later operations, with a
    /// no-op drop guard.
    pub(crate) fn from_connected_transport(
        endpoint: RmuxEndpoint,
        default_timeout: Option<Duration>,
        transport: TransportClient,
    ) -> Self {
        Self {
            endpoint,
            default_timeout,
            transport: Some(transport),
            drop_guard: DropGuard::noop(),
        }
    }

    /// Test-only constructor that caches `client`.
    ///
    /// When `drop_request` is given, the drop guard is built with
    /// `DropGuard::best_effort` over a clone of `client`.
    #[cfg(test)]
    pub(crate) fn from_transport_for_test(
        client: TransportClient,
        drop_request: Option<Request>,
    ) -> Self {
        let drop_guard = drop_request
            .map(|request| DropGuard::best_effort(client.clone(), request))
            .unwrap_or_else(DropGuard::noop);
        Self {
            endpoint: RmuxEndpoint::Default,
            default_timeout: None,
            transport: Some(client),
            drop_guard,
        }
    }

    /// Opens a new transport to the resolved endpoint with the default timeout.
    /// Any cached transport is ignored.
    async fn connect_transport(&self) -> Result<TransportClient> {
        let endpoint = self.resolved_endpoint()?;
        connect_transport(&endpoint, self.resolved_timeout(None)).await
    }

    /// Returns a clone of the cached transport if present; `timeout` is unused
    /// in that case. Otherwise connects to the resolved endpoint with `timeout`.
    pub(crate) async fn connect_transport_for_operation(
        &self,
        timeout: Option<Duration>,
    ) -> Result<TransportClient> {
        if let Some(client) = self.transport.as_ref() {
            return Ok(client.clone());
        }
        let endpoint = self.resolved_endpoint()?;
        connect_transport(&endpoint, timeout).await
    }

    /// Same as `connect_transport_for_operation`, but uses the already-resolved
    /// `endpoint` when it has to connect.
    pub(crate) async fn connect_resolved_transport_for_operation(
        &self,
        endpoint: &RmuxEndpoint,
        timeout: Option<Duration>,
    ) -> Result<TransportClient> {
        if let Some(client) = self.transport.as_ref() {
            return Ok(client.clone());
        }
        connect_transport(endpoint, timeout).await
    }

    /// Resolves the endpoint and obtains a transport for a new window or pane handle.
    ///
    /// Shared by `window`, `pane`, and `pane_by_id`. The endpoint is resolved
    /// first (it is fallible), then the transport is reused or connected with
    /// the default timeout.
    async fn connect_handle_transport(&self) -> Result<(RmuxEndpoint, TransportClient)> {
        let endpoint = self.resolved_endpoint()?;
        let timeout = self.resolved_timeout(None);
        let transport = self
            .connect_resolved_transport_for_operation(&endpoint, timeout)
            .await?;
        Ok((endpoint, transport))
    }
}
/// Builds the protocol error reported when no pane with `pane_id` exists in
/// `session_name`.
fn pane_not_found(session_name: &SessionName, pane_id: PaneId) -> RmuxError {
    let missing = rmux_proto::RmuxError::pane_not_found(session_name.clone(), pane_id);
    RmuxError::protocol(missing)
}
/// Runs the rmux binary synchronously: endpoint flags first, then `args`.
/// Captures stdout, stderr, and the exit code.
///
/// This blocks the calling thread; `Rmux::cmd` invokes it through
/// `spawn_blocking`. A spawn failure becomes a transport error. The child's exit
/// status is reported, not checked.
fn run_binary_command(endpoint: RmuxEndpoint, args: Vec<OsString>) -> Result<CommandRun> {
    let mut process = ProcessCommand::new(connect::daemon_binary());
    append_endpoint_args(&mut process, endpoint);
    let finished = process
        .args(args)
        .output()
        .map_err(|spawn_error| RmuxError::transport("run rmux command", spawn_error))?;
    let exit = finished.status.code();
    Ok(CommandRun {
        stdout: finished.stdout,
        stderr: finished.stderr,
        exit,
    })
}
/// Appends to `command` the CLI flags that select `endpoint`.
fn append_endpoint_args(command: &mut ProcessCommand, endpoint: RmuxEndpoint) {
    let flags = endpoint_args(endpoint);
    command.args(flags);
}
/// Translates `endpoint` into rmux CLI arguments.
///
/// The default endpoint adds no flags. A Unix socket path or a Windows pipe name
/// is passed as `-S <value>`.
fn endpoint_args(endpoint: RmuxEndpoint) -> Vec<OsString> {
    let socket: Option<OsString> = match endpoint {
        RmuxEndpoint::Default => None,
        RmuxEndpoint::UnixSocket(path) => Some(path.into_os_string()),
        RmuxEndpoint::WindowsPipe(pipe) => Some(pipe.into()),
    };
    socket.map_or_else(Vec::new, |value| vec![OsString::from("-S"), value])
}
fn is_clean_shutdown_close(error: &RmuxError) -> bool {
matches!(
error,
RmuxError::Transport { source, .. }
if matches!(
source.kind(),
io::ErrorKind::UnexpectedEof
| io::ErrorKind::ConnectionReset
| io::ErrorKind::BrokenPipe
| io::ErrorKind::NotConnected
)
)
}
impl Default for Rmux {
    /// Builds a client from a default-configured `RmuxBuilder`.
    fn default() -> Self {
        let builder = RmuxBuilder::default();
        builder.build()
    }
}
impl fmt::Debug for Rmux {
    /// Formats as `Rmux { .. }` and deliberately shows no fields.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut shape = f.debug_struct("Rmux");
        shape.finish_non_exhaustive()
    }
}
async fn negotiate_shutdown_capability(client: &TransportClient) -> Result<()> {
let response = client
.request(Request::Handshake(HandshakeRequest::requiring([
CAPABILITY_HANDSHAKE,
CAPABILITY_DAEMON_SHUTDOWN,
])))
.await
.map_err(normalize_handshake_error)?;
match response {
Response::Handshake(response) => {
ensure_selected_wire_version(response.wire_version)?;
ensure_capability(&response.capabilities, CAPABILITY_HANDSHAKE)?;
ensure_capability(&response.capabilities, CAPABILITY_DAEMON_SHUTDOWN)
}
Response::Error(error) => Err(error.into()),
response => Err(RmuxError::protocol(rmux_proto::RmuxError::Server(format!(
"rmux daemon sent `{}` response for capability handshake",
response.command_name()
)))),
}
}
/// Rewrites handshake failures that suggest the daemon cannot negotiate
/// capabilities into an actionable `unsupported` error. All other errors are
/// returned unchanged.
///
/// Two shapes are rewritten:
/// - an `Unsupported` error for the `handshake` command feature;
/// - a protocol `Decode` error, whose message is kept as detail.
fn normalize_handshake_error(error: RmuxError) -> RmuxError {
    match error {
        RmuxError::Unsupported { feature, .. }
            if feature == crate::diagnostics::command_feature_id("handshake") =>
        {
            unsupported_handshake_error("daemon did not recognize the handshake request")
        }
        RmuxError::Protocol {
            source: rmux_proto::RmuxError::Decode(message),
        } => unsupported_handshake_error(&message),
        passthrough => passthrough,
    }
}
/// Builds the `FEATURE_PROTOCOL_CAPABILITIES` unsupported error, consisting of
/// an upgrade hint followed by `detail`.
fn unsupported_handshake_error(detail: &str) -> RmuxError {
    let hint = format!(
        "upgrade the rmux daemon to one that advertises `{CAPABILITY_HANDSHAKE}` before using SDK daemon shutdown; {detail}"
    );
    RmuxError::unsupported(FEATURE_PROTOCOL_CAPABILITIES, hint)
}
/// Accepts only the exact wire version this SDK speaks, `RMUX_WIRE_VERSION`.
///
/// Any other value yields `UnsupportedWireVersion`, with both `minimum` and
/// `maximum` set to `RMUX_WIRE_VERSION`.
fn ensure_selected_wire_version(wire_version: u32) -> Result<()> {
    if wire_version != RMUX_WIRE_VERSION {
        let mismatch = rmux_proto::RmuxError::UnsupportedWireVersion {
            got: wire_version,
            minimum: RMUX_WIRE_VERSION,
            maximum: RMUX_WIRE_VERSION,
        };
        return Err(RmuxError::protocol(mismatch));
    }
    Ok(())
}
/// Succeeds when `feature` appears in `capabilities`.
///
/// Otherwise returns `UnsupportedCapability`, which carries the missing feature
/// and a copy of everything that was advertised.
fn ensure_capability(capabilities: &[String], feature: &str) -> Result<()> {
    let advertised = capabilities
        .iter()
        .map(String::as_str)
        .any(|name| name == feature);
    if advertised {
        Ok(())
    } else {
        Err(RmuxError::protocol(
            rmux_proto::RmuxError::UnsupportedCapability {
                feature: feature.to_owned(),
                supported: capabilities.to_vec(),
            },
        ))
    }
}
#[cfg(test)]
#[path = "rmux/tests.rs"]
mod tests;