use std::fmt;
use std::io;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::handles::{session, Rmux, Session};
use crate::transport::TransportClient;
use crate::{CollectError, ProcessSpec, Result, RmuxError, SessionName, TerminalSizeSpec};
use rmux_proto::{NewSessionExtRequest, Request, Response};
const REDACTED_ENVIRONMENT: &str = "[redacted process environment]";
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EnsureSessionPolicy {
CreateOnly,
#[default]
CreateOrReuse,
ReuseOnly,
}
#[derive(Clone, PartialEq, Eq)]
pub struct EnsureSession {
session_name: Option<SessionName>,
working_directory: Option<String>,
detached: bool,
size: Option<TerminalSizeSpec>,
process: ProcessSpec,
group_target: Option<SessionName>,
policy: EnsureSessionPolicy,
window_name: Option<String>,
creation_tags: Option<Vec<String>>,
timeout: Option<Duration>,
}
impl EnsureSession {
#[must_use]
pub fn named(session_name: SessionName) -> Self {
Self {
session_name: Some(session_name),
..Self::default()
}
}
pub fn try_named(session_name: impl AsRef<str>) -> Result<Self> {
SessionName::new(session_name.as_ref())
.map(Self::named)
.map_err(RmuxError::protocol)
}
#[must_use]
pub fn auto_named() -> Self {
Self::default()
}
#[must_use]
pub fn session_name(mut self, session_name: SessionName) -> Self {
self.session_name = Some(session_name);
self
}
#[must_use]
pub fn automatic_name(mut self) -> Self {
self.session_name = None;
self
}
#[must_use]
pub const fn configured_session_name(&self) -> Option<&SessionName> {
self.session_name.as_ref()
}
#[must_use]
pub fn policy(mut self, policy: EnsureSessionPolicy) -> Self {
self.policy = policy;
self
}
#[must_use]
pub fn create_only(self) -> Self {
self.policy(EnsureSessionPolicy::CreateOnly)
}
#[must_use]
pub fn create_or_reuse(self) -> Self {
self.policy(EnsureSessionPolicy::CreateOrReuse)
}
#[must_use]
pub fn reuse_only(self) -> Self {
self.policy(EnsureSessionPolicy::ReuseOnly)
}
#[must_use]
pub const fn configured_policy(&self) -> EnsureSessionPolicy {
self.policy
}
#[must_use]
pub fn working_directory(mut self, working_directory: impl Into<String>) -> Self {
self.working_directory = Some(working_directory.into());
self
}
#[must_use]
pub fn detached(mut self, detached: bool) -> Self {
self.detached = detached;
self
}
#[must_use]
pub fn size(mut self, size: TerminalSizeSpec) -> Self {
self.size = Some(size);
self
}
#[must_use]
pub fn command<I, S>(mut self, command: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.process.command = Some(command.into_iter().map(Into::into).collect());
self
}
#[must_use]
pub fn environment<I, S>(mut self, environment: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.process.environment = Some(environment.into_iter().map(Into::into).collect());
self
}
#[must_use]
pub fn process(mut self, process: ProcessSpec) -> Self {
self.process = process;
self
}
#[must_use]
pub fn group_target(mut self, group_target: SessionName) -> Self {
self.group_target = Some(group_target);
self
}
#[must_use]
pub fn window_name(mut self, window_name: impl Into<String>) -> Self {
self.window_name = Some(window_name.into());
self
}
#[must_use]
pub fn tags<I, S>(mut self, tags: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.creation_tags = Some(tags.into_iter().map(Into::into).collect());
self
}
#[must_use]
pub fn empty_tags(self) -> Self {
self.tags(Vec::<String>::new())
}
#[must_use]
pub fn tag(mut self, tag: impl Into<String>) -> Self {
self.creation_tags
.get_or_insert_with(Vec::new)
.push(tag.into());
self
}
#[must_use]
pub fn configured_tags(&self) -> Option<&[String]> {
self.creation_tags.as_deref()
}
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
#[must_use]
pub const fn configured_timeout(&self) -> Option<Duration> {
self.timeout
}
#[must_use]
pub fn resolved_timeout(&self, rmux: &Rmux) -> Option<Duration> {
rmux.resolved_timeout(self.timeout)
}
pub async fn ensure(self, rmux: &Rmux) -> Result<Session> {
ensure_session(rmux, self).await
}
pub(crate) fn to_new_session_request(&self, attach_if_exists: bool) -> NewSessionExtRequest {
NewSessionExtRequest {
session_name: self.session_name.clone(),
working_directory: self.working_directory.clone(),
detached: self.detached,
size: self.size.map(Into::into),
environment: self.process.environment.clone(),
group_target: self.group_target.clone(),
attach_if_exists,
detach_other_clients: false,
kill_other_clients: false,
flags: None,
window_name: self.window_name.clone(),
print_session_info: false,
print_format: None,
command: self.process.command.clone(),
}
}
}
impl Default for EnsureSession {
fn default() -> Self {
Self {
session_name: None,
working_directory: None,
detached: true,
size: None,
process: ProcessSpec::default(),
group_target: None,
policy: EnsureSessionPolicy::default(),
window_name: None,
creation_tags: None,
timeout: None,
}
}
}
impl fmt::Debug for EnsureSession {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter
.debug_struct("EnsureSession")
.field("session_name", &self.session_name)
.field("working_directory", &self.working_directory)
.field("detached", &self.detached)
.field("size", &self.size)
.field("process", &self.process)
.field("group_target", &self.group_target)
.field("policy", &self.policy)
.field("window_name", &self.window_name)
.field("creation_tags", &self.creation_tags)
.field("timeout", &self.timeout)
.finish_non_exhaustive()
}
}
async fn ensure_session(rmux: &Rmux, builder: EnsureSession) -> Result<Session> {
let endpoint = rmux.resolved_endpoint()?;
let timeout = builder.resolved_timeout(rmux);
let transport = rmux
.connect_resolved_transport_for_operation(&endpoint, timeout)
.await?;
let (session_name, created) = match builder.policy {
EnsureSessionPolicy::CreateOnly => {
(create_session(&transport, &builder, false).await?, true)
}
EnsureSessionPolicy::CreateOrReuse => create_or_reuse_session(&transport, &builder).await?,
EnsureSessionPolicy::ReuseOnly => reuse_session(&transport, &builder).await?,
};
Ok(Session::new(
session_name,
endpoint,
rmux.configured_default_timeout(),
transport,
created,
builder.creation_tags,
))
}
async fn create_or_reuse_session(
transport: &TransportClient,
builder: &EnsureSession,
) -> Result<(SessionName, bool)> {
let Some(configured_name) = builder.session_name.as_ref() else {
let session_name = create_session(transport, builder, true).await?;
return Ok((session_name, true));
};
let existed_before = session::has_session(transport, configured_name.clone())
.await
.map_err(|error| redact_builder_environment_error(error, builder))?;
if existed_before {
let session_name = create_session(transport, builder, true).await?;
return Ok((session_name, false));
}
match create_session(transport, builder, false).await {
Ok(session_name) => Ok((session_name, true)),
Err(error) if builder.group_target.is_none() && is_duplicate_session_error(&error) => {
let session_name = create_session(transport, builder, true).await?;
Ok((session_name, false))
}
Err(error) => Err(error),
}
}
async fn create_session(
transport: &TransportClient,
builder: &EnsureSession,
attach_if_exists: bool,
) -> Result<SessionName> {
let request = builder.to_new_session_request(attach_if_exists);
match transport
.request(Request::NewSessionExt(request))
.await
.map_err(|error| redact_builder_environment_error(error, builder))?
{
Response::NewSession(response) => Ok(response.session_name),
response => Err(session::unexpected_response("new-session", response)),
}
}
async fn reuse_session(
transport: &TransportClient,
builder: &EnsureSession,
) -> Result<(SessionName, bool)> {
let Some(session_name) = builder.session_name.clone() else {
return Err(RmuxError::protocol(rmux_proto::RmuxError::Server(
"reuse-only ensure-session requires an explicit session name".to_owned(),
)));
};
if session::has_session(transport, session_name.clone())
.await
.map_err(|error| redact_builder_environment_error(error, builder))?
{
Ok((session_name, false))
} else {
Err(RmuxError::protocol(rmux_proto::RmuxError::SessionNotFound(
session_name.to_string(),
)))
}
}
fn redact_builder_environment_error(error: RmuxError, builder: &EnsureSession) -> RmuxError {
redact_environment_error(error, builder.process.environment.as_deref())
}
fn is_duplicate_session_error(error: &RmuxError) -> bool {
matches!(
error,
RmuxError::Protocol {
source: rmux_proto::RmuxError::DuplicateSession(_),
}
)
}
fn redact_environment_error(error: RmuxError, environment: Option<&[String]>) -> RmuxError {
let Some(environment) = environment.filter(|environment| !environment.is_empty()) else {
return error;
};
match error {
RmuxError::Unsupported { feature, hint } => RmuxError::Unsupported {
feature,
hint: redact_environment_message(&hint, environment),
},
RmuxError::Protocol { source } => {
RmuxError::protocol(redact_environment_protocol_error(source, environment))
}
RmuxError::Transport { operation, source } => {
let kind = source.kind();
RmuxError::transport(
operation,
io::Error::new(
kind,
redact_environment_message(&source.to_string(), environment),
),
)
}
RmuxError::Collect { source } => RmuxError::collect(
source
.into_errors()
.into_iter()
.map(|error| redact_environment_error(error, Some(environment)))
.collect::<CollectError>(),
),
}
}
fn redact_environment_protocol_error(
error: rmux_proto::RmuxError,
environment: &[String],
) -> rmux_proto::RmuxError {
match error {
rmux_proto::RmuxError::InvalidTarget { value, reason } => {
rmux_proto::RmuxError::InvalidTarget {
value: redact_environment_message(&value, environment),
reason: redact_environment_message(&reason, environment),
}
}
rmux_proto::RmuxError::UnknownCommand(command) => {
rmux_proto::RmuxError::UnknownCommand(redact_environment_message(&command, environment))
}
rmux_proto::RmuxError::DuplicateSession(session_name) => {
rmux_proto::RmuxError::DuplicateSession(session_name)
}
rmux_proto::RmuxError::SessionNotFound(session_name) => {
rmux_proto::RmuxError::SessionNotFound(session_name)
}
rmux_proto::RmuxError::Server(message) => {
rmux_proto::RmuxError::Server(redact_environment_message(&message, environment))
}
rmux_proto::RmuxError::Message(message) => {
rmux_proto::RmuxError::Message(redact_environment_message(&message, environment))
}
rmux_proto::RmuxError::InvalidSetOption(message) => {
rmux_proto::RmuxError::InvalidSetOption(redact_environment_message(
&message,
environment,
))
}
rmux_proto::RmuxError::UnsupportedCapability { feature, supported } => {
rmux_proto::RmuxError::UnsupportedCapability {
feature: redact_environment_message(&feature, environment),
supported: supported
.into_iter()
.map(|value| redact_environment_message(&value, environment))
.collect(),
}
}
rmux_proto::RmuxError::Encode(message) => {
rmux_proto::RmuxError::Encode(redact_environment_message(&message, environment))
}
rmux_proto::RmuxError::Decode(message) => {
rmux_proto::RmuxError::Decode(redact_environment_message(&message, environment))
}
error => error,
}
}
fn redact_environment_message(message: &str, environment: &[String]) -> String {
let mut redacted = message.to_owned();
for binding in environment {
replace_environment_secret(&mut redacted, binding);
if let Some((name, value)) = binding.split_once('=') {
replace_environment_name(&mut redacted, name);
if value.len() >= 4 {
replace_environment_secret(&mut redacted, value);
}
} else {
replace_environment_name(&mut redacted, binding);
}
}
redacted
}
fn replace_environment_secret(message: &mut String, secret: &str) {
if !secret.is_empty() && message.contains(secret) {
*message = message.replace(secret, REDACTED_ENVIRONMENT);
}
}
fn replace_environment_name(message: &mut String, name: &str) {
if !is_environment_name(name) {
return;
}
let mut redacted = String::with_capacity(message.len());
let mut copied_until = 0;
for (start, _) in message.match_indices(name) {
let end = start + name.len();
if is_environment_name_match(message.as_bytes(), start, end) {
redacted.push_str(&message[copied_until..start]);
redacted.push_str(REDACTED_ENVIRONMENT);
copied_until = end;
}
}
if copied_until != 0 {
redacted.push_str(&message[copied_until..]);
*message = redacted;
}
}
fn is_environment_name(name: &str) -> bool {
let mut bytes = name.bytes();
let Some(first) = bytes.next() else {
return false;
};
(first.is_ascii_alphabetic() || first == b'_')
&& bytes.all(|byte| byte.is_ascii_alphanumeric() || byte == b'_')
}
fn is_environment_name_match(message: &[u8], start: usize, end: usize) -> bool {
!start
.checked_sub(1)
.and_then(|index| message.get(index))
.is_some_and(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
&& !message
.get(end)
.is_some_and(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
}
#[cfg(test)]
mod tests {
use super::{redact_environment_message, REDACTED_ENVIRONMENT};
#[test]
fn environment_redaction_scrubs_binding_value_and_name() {
let environment = [String::from("SDK_SECRET_NAME=hidden-value")];
let rendered = redact_environment_message(
"SDK_SECRET_NAME failed after SDK_SECRET_NAME=hidden-value exposed hidden-value",
&environment,
);
assert!(!rendered.contains("SDK_SECRET_NAME"), "{rendered}");
assert!(!rendered.contains("hidden-value"), "{rendered}");
assert!(rendered.contains(REDACTED_ENVIRONMENT), "{rendered}");
}
#[test]
fn environment_name_redaction_respects_identifier_boundaries() {
let environment = [String::from("SDK_SECRET_NAME=hidden")];
let rendered = redact_environment_message("prefix_SDK_SECRET_NAME_suffix", &environment);
assert_eq!(rendered, "prefix_SDK_SECRET_NAME_suffix");
}
}