use std::{env, fmt, io, path::PathBuf, thread, time};
use anyhow::{anyhow, bail, Context};
use shpool_protocol::{
AttachHeader, AttachReplyHeader, ConnectHeader, DetachReply, DetachRequest, ResizeReply,
ResizeRequest, SessionMessageReply, SessionMessageRequest, SessionMessageRequestPayload,
TtySize,
};
use tracing::{debug, error, info, warn};
use super::{config, duration, protocol, protocol::ClientResult, test_hooks, tty::TtySizeExt as _};
/// Limit on how many times a forced attach in `run` re-tries after the daemon
/// keeps reporting the session as busy; once the limit is passed `run` gives
/// up and returns an error.
const MAX_FORCE_RETRIES: usize = 20;
#[allow(clippy::too_many_arguments)]
pub fn run(
config_manager: config::Manager,
name: String,
force: bool,
background: bool,
ttl: Option<String>,
cmd: Option<String>,
dir: Option<String>,
socket: PathBuf,
) -> anyhow::Result<()> {
info!("\n\n======================== STARTING ATTACH ============================\n\n");
test_hooks::emit("attach-startup");
if name.is_empty() {
eprintln!("blank session names are not allowed");
return Ok(());
}
if name.contains(char::is_whitespace) {
eprintln!("whitespace is not allowed in session names");
return Ok(());
}
if !background {
SignalHandler::new(name.clone(), socket.clone()).spawn()?;
}
let ttl = match &ttl {
Some(src) => match duration::parse(src.as_str()) {
Ok(d) => Some(d),
Err(e) => {
bail!("could not parse ttl: {:?}", e);
}
},
None => None,
};
let mut detached = false;
let mut tries = 0;
let attach_client = loop {
match do_attach(&config_manager, name.as_str(), background, &ttl, &cmd, &dir, &socket) {
Ok(client) => break client,
Err(err) => match err.downcast() {
Ok(BusyError) if !force => {
eprintln!("session '{name}' already has a terminal attached");
return Ok(());
}
Ok(BusyError) => {
if !detached {
let mut client = dial_client(&socket, background)?;
client
.write_connect_header(ConnectHeader::Detach(DetachRequest {
sessions: vec![name.clone()],
}))
.context("writing detach request header")?;
let detach_reply: DetachReply =
client.read_reply().context("reading reply")?;
if !detach_reply.not_found_sessions.is_empty() {
warn!("could not find session '{}' to detach it", name);
}
detached = true;
}
thread::sleep(time::Duration::from_millis(100));
if tries > MAX_FORCE_RETRIES {
eprintln!("session '{name}' already has a terminal which remains attached even after attempting to detach it");
return Err(anyhow!("could not detach session, forced attach failed"));
}
tries += 1;
}
Err(err) => return Err(err),
},
}
};
if background {
drop(attach_client);
let mut client = dial_client(&socket, true)?;
client
.write_connect_header(ConnectHeader::Detach(DetachRequest {
sessions: vec![name.clone()],
}))
.context("writing detach request header")?;
let detach_reply: DetachReply = client.read_reply().context("reading reply")?;
if !detach_reply.not_found_sessions.is_empty() {
warn!("could not find session '{}' to detach it", name);
}
if !detach_reply.not_attached_sessions.is_empty() {
debug!(
"session '{}' was already detached while processing background detach request (expected)",
name
);
}
return Ok(());
}
match attach_client.pipe_bytes() {
Ok(exit_status) => std::process::exit(exit_status),
Err(e) => Err(e),
}
}
/// Marker error produced by `do_attach` when the daemon answers that the
/// session already has a terminal attached. `run` downcasts to this type to
/// tell "busy" apart from every other failure.
#[derive(Debug)]
struct BusyError;

impl std::error::Error for BusyError {}

impl fmt::Display for BusyError {
    /// Renders the fixed text `BusyError`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("BusyError")
    }
}
fn do_attach(
config: &config::Manager,
name: &str,
background: bool,
ttl: &Option<time::Duration>,
cmd: &Option<String>,
dir: &Option<String>,
socket: &PathBuf,
) -> anyhow::Result<protocol::Client> {
let mut client = dial_client(socket, background)?;
let tty_size = match TtySize::from_fd(0) {
Ok(s) => s,
Err(e) => {
warn!("stdin is not a tty, using default size (err: {e:?})");
TtySize { rows: 24, cols: 80, xpixel: 0, ypixel: 0 }
}
};
let forward_env = config.get().forward_env.clone();
let mut local_env_keys = vec!["TERM", "DISPLAY", "LANG", "SSH_AUTH_SOCK"];
if let Some(fenv) = &forward_env {
for var in fenv.iter() {
local_env_keys.push(var);
}
}
info!("local env keys: {local_env_keys:?}");
let cwd = String::from(env::current_dir().context("getting cwd")?.to_string_lossy());
let default_dir = config.get().default_dir.clone().unwrap_or(String::from("$HOME"));
let start_dir = match (default_dir.as_str(), dir.as_deref()) {
(".", None) => Some(cwd),
("$HOME", None) => None,
(d, None) => Some(String::from(d)),
(_, Some(".")) => Some(cwd),
(_, Some(d)) => Some(String::from(d)),
};
client
.write_connect_header(ConnectHeader::Attach(AttachHeader {
name: String::from(name),
local_tty_size: tty_size,
local_env: local_env_keys
.into_iter()
.filter_map(|var| {
let val = env::var(var).context("resolving var").ok()?;
Some((String::from(var), val))
})
.collect::<Vec<_>>(),
ttl_secs: ttl.map(|d| d.as_secs()),
cmd: cmd.clone(),
dir: start_dir,
}))
.context("writing attach header")?;
let attach_resp: AttachReplyHeader = client.read_reply().context("reading attach reply")?;
info!("attach_resp.status={:?}", attach_resp.status);
{
use shpool_protocol::AttachStatus::*;
match attach_resp.status {
Busy => {
return Err(BusyError.into());
}
Forbidden(reason) => {
eprintln!("forbidden: {reason}");
return Err(anyhow!("forbidden: {reason}"));
}
Attached { warnings } => {
for warning in warnings.into_iter() {
eprintln!("shpool: warn: {warning}");
}
info!("attached to an existing session: '{}'", name);
}
Created { warnings } => {
for warning in warnings.into_iter() {
eprintln!("shpool: warn: {warning}");
}
info!("created a new session: '{}'", name);
}
UnexpectedError(err) => {
return Err(anyhow!("BUG: unexpected error attaching to '{}': {}", name, err));
}
}
}
Ok(client)
}
/// Connects to the daemon at `socket` and returns the resulting client.
///
/// When the connection reports a version mismatch, a warning is printed to
/// stderr. In `background` mode the client is used regardless; otherwise the
/// user is asked to hit enter and one line is awaited on stdin first. If
/// stdin yields no line at all an error is returned, while the outcome of
/// reading that line is ignored.
///
/// # Errors
///
/// A connection failure that is not an `io::Error` is returned unchanged. An
/// `io::Error` is returned with the context `connecting to daemon`, and a
/// `NotFound` kind additionally prints `could not connect to daemon`.
fn dial_client(socket: &PathBuf, background: bool) -> anyhow::Result<protocol::Client> {
    let dialed = match protocol::Client::new(socket) {
        Ok(dialed) => dialed,
        Err(err) => {
            let io_err = err.downcast::<io::Error>()?;
            if matches!(io_err.kind(), io::ErrorKind::NotFound) {
                eprintln!("could not connect to daemon");
            }
            return Err(io_err).context("connecting to daemon");
        }
    };

    let (warning, client) = match dialed {
        ClientResult::JustClient(c) => return Ok(c),
        ClientResult::VersionMismatch { warning, client } => (warning, client),
    };

    if background {
        eprintln!("warning: {warning}, proceeding in background mode; try restarting your daemon");
        return Ok(client);
    }

    eprintln!("warning: {warning}, try restarting your daemon");
    eprintln!("hit enter to continue anyway or ^C to exit");
    // Block until the user acknowledges the mismatch.
    let _ = io::stdin()
        .lines()
        .next()
        .context("waiting for a continue through a version mismatch")?;
    Ok(client)
}
struct SignalHandler {
session_name: String,
socket: PathBuf,
}
impl SignalHandler {
fn new(session_name: String, socket: PathBuf) -> Self {
SignalHandler { session_name, socket }
}
fn spawn(self) -> anyhow::Result<()> {
use signal_hook::{consts::*, iterator::*};
let sigs = vec![SIGWINCH];
let mut signals = Signals::new(sigs).context("creating signal iterator")?;
thread::spawn(move || {
for signal in &mut signals {
let res = match signal {
SIGWINCH => self.handle_sigwinch(),
sig => {
error!("unknown signal: {}", sig);
panic!("unknown signal: {sig}");
}
};
if let Err(e) = res {
error!("signal handler error: {:?}", e);
}
}
});
Ok(())
}
fn handle_sigwinch(&self) -> anyhow::Result<()> {
info!("handle_sigwinch: enter");
let mut client = match protocol::Client::new(&self.socket)? {
ClientResult::JustClient(c) => c,
ClientResult::VersionMismatch { client, .. } => client,
};
let tty_size = TtySize::from_fd(0).context("getting tty size")?;
info!("handle_sigwinch: tty_size={:?}", tty_size);
client
.write_connect_header(ConnectHeader::SessionMessage(SessionMessageRequest {
session_name: self.session_name.clone(),
payload: SessionMessageRequestPayload::Resize(ResizeRequest {
tty_size: tty_size.clone(),
}),
}))
.context("writing resize request")?;
let reply: SessionMessageReply =
client.read_reply().context("reading session message reply")?;
match reply {
SessionMessageReply::NotFound => {
warn!(
"handle_sigwinch: sent resize for session '{}', but the daemon has no record of that session",
self.session_name
);
}
SessionMessageReply::Resize(ResizeReply::Ok) => {
info!("handle_sigwinch: resized session '{}' to {:?}", self.session_name, tty_size);
}
reply => {
warn!("handle_sigwinch: unexpected resize reply: {:?}", reply);
}
}
Ok(())
}
}