pub mod raw_mode;
pub mod server_launcher;
use crate::protocol::{
self, read_one_message, ClientMsg, FrameReader, ServerMsg, PROTOCOL_VERSION,
};
use std::io::{self, BufWriter, Read, Write};
use tokio::io::AsyncWriteExt;
use tokio::net::UnixStream;
use raw_mode::RawMode;
use server_launcher::ensure_server_running;
/// Ctrl+\ (ASCII FS). Treated as "detach" only when it arrives as a lone keypress.
const DETACH_KEY: u8 = 0x1c;
/// Signal numbers used to derive the client's exit status on SIGINT/SIGTERM.
const SIGINT: i32 = 2;
const SIGTERM: i32 = 15;

/// Shell-convention exit status for termination by signal `signo` (`128 + signo`).
fn signal_exit_code(signo: i32) -> i32 {
    const SHELL_SIGNAL_BASE: i32 = 128;
    SHELL_SIGNAL_BASE + signo
}

/// Final byte of the focus-in report `ESC [ I`.
const FOCUS_IN: u8 = b'I';
/// Final byte of the focus-out report `ESC [ O`.
const FOCUS_OUT: u8 = b'O';
/// Unregisters the current panic hook when dropped outside of a panic.
///
/// NOTE(review): `std::panic::take_hook` reinstates the *default* hook. Any
/// hook that was active before `connect` installed its own closure was moved
/// into that closure, so it is dropped here rather than restored.
struct PanicHookGuard;

impl Drop for PanicHookGuard {
    fn drop(&mut self) {
        // `take_hook` panics when called on a panicking thread, so skip it
        // while unwinding.
        if std::thread::panicking() {
            return;
        }
        drop(std::panic::take_hook());
    }
}
/// RAII guard whose drop runs [`cleanup_terminal`], so terminal modes are
/// reset on every exit path once the guard exists.
struct TerminalModeGuard;

impl Drop for TerminalModeGuard {
    fn drop(&mut self) {
        // Best effort: cleanup_terminal discards its own write errors.
        cleanup_terminal()
    }
}
/// Outcome of handling a single server message on the output side.
enum DispatchResult {
    /// Keep reading further messages.
    Continue,
    /// Stop reading; `exit_code` is the status the client should report.
    Done { exit_code: Option<i32> },
}
/// Apply one server message to the local terminal.
///
/// `ScreenUpdate` and `History` bytes are written to `stdout` and left
/// buffered for the caller to flush. `Passthrough` bytes are flushed
/// immediately. `SessionEnded` and `Error` flush `stdout`, print a status
/// line to stderr and return [`DispatchResult::Done`]. Any other message is
/// logged at debug level and ignored.
///
/// # Errors
/// Propagates write/flush errors from `stdout`.
fn dispatch_server_msg(msg: &ServerMsg, stdout: &mut impl Write) -> io::Result<DispatchResult> {
    match msg {
        ServerMsg::ScreenUpdate(bytes) => stdout.write_all(bytes)?,
        ServerMsg::Passthrough(bytes) => {
            stdout.write_all(bytes)?;
            stdout.flush()?;
        }
        ServerMsg::History(lines) => {
            for line in lines {
                stdout.write_all(line)?;
                // Raw mode is active, so each line needs an explicit CR+LF.
                stdout.write_all(b"\r\n")?;
            }
        }
        ServerMsg::SessionEnded { exit_code } => {
            stdout.flush()?;
            eprint!("[retach: session ended]\r\n");
            return Ok(DispatchResult::Done {
                exit_code: *exit_code,
            });
        }
        ServerMsg::Error(e) => {
            stdout.flush()?;
            eprint!("[retach error: {}]\r\n", e);
            return Ok(DispatchResult::Done { exit_code: None });
        }
        unexpected => {
            tracing::debug!(
                "ignoring unexpected server message: {:?}",
                std::mem::discriminant(unexpected)
            );
        }
    }
    Ok(DispatchResult::Continue)
}
/// Current terminal size as `(cols, rows)`. Falls back to
/// `crate::session::DEFAULT_COLS` x `crate::session::DEFAULT_ROWS` when the
/// size cannot be determined.
fn get_terminal_size() -> (u16, u16) {
    match terminal_size::terminal_size() {
        Some((width, height)) => (width.0, height.0),
        None => (crate::session::DEFAULT_COLS, crate::session::DEFAULT_ROWS),
    }
}

/// Write half of the server connection. It is shared by the stdin pump, the
/// SIGWINCH task and the signal-driven detach in `connect`. The mutex is
/// async because its guard is held across awaited writes; the SIGWINCH task
/// keeps it for both its `Resize` and `RefreshScreen` frames.
type SocketWriter = std::sync::Arc<tokio::sync::Mutex<tokio::net::unix::OwnedWriteHalf>>;
/// What the client should do with a chunk of keyboard input.
#[derive(Debug, PartialEq)]
enum FilterAction {
    /// Bytes to forward verbatim to the session.
    Forward(Vec<u8>),
    /// The detach key was pressed on its own.
    Detach,
    /// The terminal reported focus-in (`ESC [ I`).
    FocusIn,
}

/// Stateful filter over raw stdin bytes.
///
/// It strips terminal focus reports and recognises the detach key. A focus
/// report may be split across reads, so a trailing `ESC` or `ESC [` is held
/// in `carry`. The held bytes are re-examined with the next chunk, or
/// released by [`InputFilter::flush`].
struct InputFilter {
    /// Incomplete escape prefix from the previous chunk (at most 2 bytes).
    carry: Vec<u8>,
}

impl InputFilter {
    fn new() -> Self {
        Self {
            carry: Vec::with_capacity(2),
        }
    }

    /// Move any pending plain bytes into `actions` as a single `Forward`.
    fn flush_filtered(actions: &mut Vec<FilterAction>, filtered: &mut Vec<u8>) {
        if !filtered.is_empty() {
            actions.push(FilterAction::Forward(std::mem::take(filtered)));
        }
    }

    /// Filter one chunk of input, prefixed by any carried bytes, into actions.
    ///
    /// * A chunk that is exactly [`DETACH_KEY`] yields `Detach`. Elsewhere
    ///   (e.g. inside a paste) that byte is forwarded like any other.
    /// * `ESC [ I` yields `FocusIn`; `ESC [ O` is dropped.
    /// * A trailing `ESC` or `ESC [` is carried to the next call.
    /// * All other bytes are forwarded in order, coalesced between events.
    fn process(&mut self, input: &[u8]) -> Vec<FilterAction> {
        let raw: Vec<u8> = if self.carry.is_empty() {
            input.to_vec()
        } else {
            let mut combined = std::mem::take(&mut self.carry);
            combined.extend_from_slice(input);
            combined
        };
        if raw == [DETACH_KEY] {
            return vec![FilterAction::Detach];
        }
        let mut actions = Vec::new();
        let mut filtered = Vec::with_capacity(raw.len());
        let mut i = 0;
        while i < raw.len() {
            if raw[i] != 0x1b {
                filtered.push(raw[i]);
                i += 1;
                continue;
            }
            let remaining = raw.len() - i;
            if remaining < 2 {
                // Trailing ESC: may start a focus report split across reads.
                Self::flush_filtered(&mut actions, &mut filtered);
                self.carry.extend_from_slice(&raw[i..]);
                return actions;
            }
            if raw[i + 1] != b'[' {
                // Not a CSI. Forward only this ESC and re-examine the next byte
                // normally, because it may itself be the ESC that starts a focus
                // report (e.g. a carried lone ESC followed by `ESC [ O`).
                filtered.push(raw[i]);
                i += 1;
                continue;
            }
            if remaining < 3 {
                // Trailing `ESC [`: wait for the final byte.
                Self::flush_filtered(&mut actions, &mut filtered);
                self.carry.extend_from_slice(&raw[i..]);
                return actions;
            }
            if raw[i + 2] == FOCUS_IN {
                Self::flush_filtered(&mut actions, &mut filtered);
                actions.push(FilterAction::FocusIn);
                i += 3;
                continue;
            }
            if raw[i + 2] == FOCUS_OUT {
                i += 3;
                continue;
            }
            // Some other CSI sequence: forward the ESC; the rest follows as plain bytes.
            filtered.push(raw[i]);
            i += 1;
        }
        Self::flush_filtered(&mut actions, &mut filtered);
        actions
    }

    /// Release any carried bytes as a `Forward` so a partial escape is not
    /// lost. The stdin pump calls this at EOF.
    fn flush(&mut self) -> Option<FilterAction> {
        if self.carry.is_empty() {
            None
        } else {
            Some(FilterAction::Forward(std::mem::take(&mut self.carry)))
        }
    }
}
async fn run_stdin_to_socket(sw: SocketWriter) -> anyhow::Result<()> {
let mut filter = InputFilter::new();
'stdin: loop {
let result = tokio::task::spawn_blocking(|| {
let mut buf = [0u8; 1024];
let n = io::stdin().read(&mut buf)?;
Ok::<_, io::Error>((buf, n))
})
.await;
match result {
Ok(Ok((_buf, 0))) => {
if let Some(FilterAction::Forward(data)) = filter.flush() {
let msg = protocol::encode(&ClientMsg::Input(data))?;
let mut w = sw.lock().await;
w.write_all(&msg).await?;
}
break;
}
Ok(Ok((buf, n))) => {
for action in filter.process(&buf[..n]) {
match action {
FilterAction::Forward(data) => {
let msg = protocol::encode(&ClientMsg::Input(data))?;
let mut w = sw.lock().await;
w.write_all(&msg).await?;
}
FilterAction::Detach => {
let mut w = sw.lock().await;
if let Ok(msg) = protocol::encode(&ClientMsg::Detach) {
w.write_all(&msg).await?;
}
drop(w);
return Ok(());
}
FilterAction::FocusIn => {
if let Ok(msg) = protocol::encode(&ClientMsg::RefreshScreen) {
let mut w = sw.lock().await;
if let Err(e) = w.write_all(&msg).await {
tracing::debug!(error = %e, "failed to send focus-in refresh");
break 'stdin;
}
}
}
}
}
}
Ok(Err(e)) => return Err(anyhow::Error::from(e)),
Err(e) => return Err(anyhow::Error::from(e)),
}
}
Ok(())
}
/// Spawn a task that forwards terminal resizes to the server.
///
/// On every SIGWINCH it sends `Resize` with the current size, followed by
/// `RefreshScreen`. The task exits when the signal stream ends or a write
/// fails.
///
/// # Errors
/// Fails if the SIGWINCH handler cannot be registered.
fn spawn_sigwinch_handler(
    sock_writer: SocketWriter,
) -> anyhow::Result<tokio::task::JoinHandle<()>> {
    let mut winch =
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::window_change())?;
    let task = tokio::spawn(async move {
        while let Some(()) = winch.recv().await {
            let (cols, rows) = get_terminal_size();
            // Hold the lock for both frames so nothing is interleaved between them.
            let mut writer = sock_writer.lock().await;
            if let Ok(frame) = protocol::encode(&ClientMsg::Resize { cols, rows }) {
                if let Err(e) = writer.write_all(&frame).await {
                    tracing::debug!(error = %e, "failed to send resize");
                    break;
                }
            }
            if let Ok(frame) = protocol::encode(&ClientMsg::RefreshScreen) {
                if let Err(e) = writer.write_all(&frame).await {
                    tracing::debug!(error = %e, "failed to send refresh after resize");
                    break;
                }
            }
        }
    });
    Ok(task)
}
/// Render server messages to stdout until the session ends or the socket closes.
///
/// Bytes already read during the handshake (`leftover`) are decoded first.
/// After each batch of decoded frames, stdout is flushed once.
///
/// Returns the exit code carried by a terminating message (see
/// [`dispatch_server_msg`]). Returns `None` when the server closes the
/// connection, after printing `[retach: detached]`.
///
/// # Errors
/// Propagates frame-decoding, socket-read and stdout write errors.
async fn run_socket_to_stdout(
    mut sock_reader: tokio::net::unix::OwnedReadHalf,
    leftover: Vec<u8>,
) -> anyhow::Result<Option<i32>> {
    let mut frames = FrameReader::with_leftover(leftover);
    let mut out = BufWriter::new(io::stdout());
    loop {
        // Drain every complete frame currently buffered.
        while let Some(msg) = frames.decode_next::<ServerMsg>()? {
            match dispatch_server_msg(&msg, &mut out)? {
                DispatchResult::Done { exit_code } => return Ok(exit_code),
                DispatchResult::Continue => {}
            }
        }
        out.flush()?;
        if !frames.fill_from(&mut sock_reader).await? {
            eprint!("[retach: detached]\r\n");
            return Ok(None);
        }
    }
}
pub async fn connect(
name: &str,
history: usize,
mode: crate::protocol::ConnectMode,
) -> anyhow::Result<Option<i32>> {
ensure_server_running().await?;
let mut stream = UnixStream::connect(crate::server::socket_path()?).await?;
let (cols, rows) = get_terminal_size();
let msg = protocol::encode(&ClientMsg::Connect {
version: PROTOCOL_VERSION,
name: name.to_string(),
history,
cols,
rows,
mode,
})?;
stream.write_all(&msg).await?;
let mut frames = FrameReader::new();
loop {
if !frames.fill_from(&mut stream).await? {
anyhow::bail!("server closed connection before handshake completed");
}
if let Some(msg) = frames.decode_next::<ServerMsg>()? {
match msg {
ServerMsg::Connected {
name: ref session_name,
new_session,
} => {
if new_session {
eprintln!("[retach: new session '{}' (detach: Ctrl+\\)]", session_name);
} else {
eprintln!(
"[retach: reattached to '{}' (detach: Ctrl+\\)]",
session_name
);
}
break;
}
ServerMsg::Error(e) => {
anyhow::bail!("{}", e);
}
_ => {
anyhow::bail!("unexpected response from server");
}
}
}
}
let leftover = frames.into_leftover();
let prev_hook = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info| {
raw_mode::emergency_restore();
cleanup_terminal();
prev_hook(info);
}));
let _hook_guard = PanicHookGuard;
let _raw = RawMode::enter()?;
if let Err(e) = io::stdout().write_all(b"\x1b[?1004h") {
tracing::debug!(error = %e, "failed to enable focus reporting");
}
if let Err(e) = io::stdout().flush() {
tracing::debug!(error = %e, "failed to flush stdout after enabling focus reporting");
}
let _mode_guard = TerminalModeGuard;
let (sock_reader, sock_writer) = stream.into_split();
let sock_writer = std::sync::Arc::new(tokio::sync::Mutex::new(sock_writer));
let sigwinch_handle = spawn_sigwinch_handler(sock_writer.clone())?;
let mut stdin_task = tokio::spawn(run_stdin_to_socket(sock_writer.clone()));
let mut socket_task = tokio::spawn(run_socket_to_stdout(sock_reader, leftover));
let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?;
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?;
enum Completed {
Stdin,
Socket,
Neither,
}
let mut session_exit_code: Option<i32> = None;
let completed = tokio::select! {
r = &mut stdin_task => {
if let Ok(Err(e)) = r {
tracing::debug!(error = %e, "stdin task error");
}
Completed::Stdin
}
r = &mut socket_task => {
match r {
Ok(Ok(code)) => session_exit_code = code,
Ok(Err(e)) => {
tracing::warn!(error = %e, "socket task error");
eprint!("[retach error: {}]\r\n", e);
}
Err(e) => {
tracing::warn!(error = %e, "socket task join error");
}
}
Completed::Socket
}
_ = sigint.recv() => {
tracing::debug!("received SIGINT, detaching");
if let Ok(msg) = protocol::encode(&ClientMsg::Detach) {
let mut w = sock_writer.lock().await;
if let Err(e) = w.write_all(&msg).await {
tracing::debug!(error = %e, "failed to send detach on SIGINT");
}
}
session_exit_code = Some(signal_exit_code(SIGINT));
Completed::Neither
}
_ = sigterm.recv() => {
tracing::debug!("received SIGTERM, detaching");
if let Ok(msg) = protocol::encode(&ClientMsg::Detach) {
let mut w = sock_writer.lock().await;
if let Err(e) = w.write_all(&msg).await {
tracing::debug!(error = %e, "failed to send detach on SIGTERM");
}
}
session_exit_code = Some(signal_exit_code(SIGTERM));
Completed::Neither
}
};
match completed {
Completed::Stdin => {
socket_task.abort();
let _ = socket_task.await;
}
Completed::Socket => {
stdin_task.abort();
let _ = stdin_task.await;
}
Completed::Neither => {
stdin_task.abort();
socket_task.abort();
let _ = tokio::join!(stdin_task, socket_task);
}
}
sigwinch_handle.abort();
drop(_mode_guard);
drop(_raw);
drop(_hook_guard);
Ok(session_exit_code)
}
/// Best-effort reset of terminal state a session may have changed.
///
/// The reset sequence is written to stdout in a single write, then flushed.
/// Write and flush errors are ignored.
fn cleanup_terminal() {
    const RESET_SEQUENCE: &str = concat!(
        "\x1b[r",      // reset scroll region (DECSTBM)
        "\x1b[2J",     // clear the screen
        "\x1b[H",      // cursor to home position
        "\x1b[?25h",   // show cursor
        "\x1b[?7h",    // enable auto-wrap (DECAWM)
        "\x1b[?1l",    // normal cursor keys (DECCKM off)
        "\x1b[?2004l", // disable bracketed paste
        "\x1b[?1000l", // disable mouse click reporting
        "\x1b[?1002l", // disable button-event mouse tracking
        "\x1b[?1003l", // disable any-event mouse tracking
        "\x1b[?1005l", // disable UTF-8 mouse encoding
        "\x1b[?1006l", // disable SGR mouse encoding
        "\x1b[?1004l", // disable focus in/out reporting
        "\x1b[?2026l", // end synchronized output
        "\x1b>",       // numeric keypad mode (DECKPNM)
        "\x1b[0 q",    // default cursor shape (DECSCUSR)
        "\x1b[0m",     // reset text attributes (SGR)
    );
    let mut out = io::stdout();
    let _ = out.write_all(RESET_SEQUENCE.as_bytes());
    let _ = out.flush();
}
/// Print one line per session on the server as `name (COLSxROWS)`.
///
/// Prints `No active sessions` when the server has none. It also prints that
/// when nothing is listening on the socket (connection refused or socket file
/// missing).
///
/// # Errors
/// Other connection failures, protocol errors, a server `Error` reply, or an
/// unexpected reply.
pub async fn list_sessions() -> anyhow::Result<()> {
    let socket = crate::server::socket_path()?;
    let mut stream = match UnixStream::connect(&socket).await {
        Ok(stream) => stream,
        Err(err)
            if matches!(
                err.kind(),
                std::io::ErrorKind::ConnectionRefused | std::io::ErrorKind::NotFound
            ) =>
        {
            println!("No active sessions");
            return Ok(());
        }
        Err(err) => return Err(err.into()),
    };
    let request = protocol::encode(&ClientMsg::ListSessions {
        version: PROTOCOL_VERSION,
    })?;
    stream.write_all(&request).await?;

    let reply: ServerMsg = read_one_message(&mut stream).await?;
    let sessions = match reply {
        ServerMsg::SessionList(sessions) => sessions,
        ServerMsg::Error(e) => anyhow::bail!("{}", e),
        other => anyhow::bail!(
            "unexpected server response: {:?}",
            std::mem::discriminant(&other)
        ),
    };
    if sessions.is_empty() {
        println!("No active sessions");
        return Ok(());
    }
    for session in sessions {
        println!("{} ({}x{})", session.name, session.cols, session.rows);
    }
    Ok(())
}
pub async fn kill_session(name: &str) -> anyhow::Result<()> {
let path = crate::server::socket_path()?;
let mut stream = match UnixStream::connect(&path).await {
Ok(s) => s,
Err(_) => anyhow::bail!("server not running"),
};
let msg = protocol::encode(&ClientMsg::KillSession {
version: PROTOCOL_VERSION,
name: name.to_string(),
})?;
stream.write_all(&msg).await?;
let resp: ServerMsg = read_one_message(&mut stream).await?;
match resp {
ServerMsg::SessionKilled { name } => println!("killed session '{}'", name),
ServerMsg::Error(e) => anyhow::bail!("{}", e),
other => anyhow::bail!(
"unexpected server response: {:?}",
std::mem::discriminant(&other)
),
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Forward` action carrying `bytes`.
    fn fwd(bytes: &[u8]) -> FilterAction {
        FilterAction::Forward(bytes.to_vec())
    }

    /// Run a single chunk through a freshly created filter.
    fn filter_once(input: &[u8]) -> Vec<FilterAction> {
        InputFilter::new().process(input)
    }

    /// Dispatch `msg` into an in-memory sink; returns the result and the bytes written.
    fn dispatch(msg: ServerMsg) -> (DispatchResult, Vec<u8>) {
        let mut sink = Vec::new();
        let result = dispatch_server_msg(&msg, &mut sink).unwrap();
        (result, sink)
    }

    #[test]
    fn dispatch_server_msg_error_returns_done() {
        let (result, _) = dispatch(ServerMsg::Error("test error".into()));
        assert!(matches!(result, DispatchResult::Done { exit_code: None }));
    }

    #[test]
    fn dispatch_server_msg_session_ended_returns_done() {
        let (result, _) = dispatch(ServerMsg::SessionEnded { exit_code: Some(0) });
        assert!(matches!(
            result,
            DispatchResult::Done { exit_code: Some(0) }
        ));
    }

    #[test]
    fn dispatch_server_msg_session_ended_propagates_exit_code() {
        let (result, _) = dispatch(ServerMsg::SessionEnded {
            exit_code: Some(42),
        });
        match result {
            DispatchResult::Done { exit_code } => assert_eq!(exit_code, Some(42)),
            DispatchResult::Continue => panic!("expected Done"),
        }
    }

    #[test]
    fn dispatch_server_msg_screen_update_continues() {
        let (result, written) = dispatch(ServerMsg::ScreenUpdate(b"hello".to_vec()));
        assert!(matches!(result, DispatchResult::Continue));
        assert_eq!(written, b"hello");
    }

    #[test]
    fn signal_exit_code_matches_shell_convention() {
        assert_eq!(signal_exit_code(SIGINT), 130);
        assert_eq!(signal_exit_code(SIGTERM), 143);
    }

    #[test]
    fn input_filter_passthrough() {
        assert_eq!(filter_once(b"hello"), vec![fwd(b"hello")]);
    }

    #[test]
    fn input_filter_detach_only_on_lone_keypress() {
        assert_eq!(filter_once(b"\x1c"), vec![FilterAction::Detach]);
    }

    #[test]
    fn input_filter_detach_byte_mid_buffer_passes_through() {
        assert_eq!(filter_once(b"abc\x1cdef"), vec![fwd(b"abc\x1cdef")]);
    }

    #[test]
    fn input_filter_detach_byte_at_start_of_multibyte_passes_through() {
        assert_eq!(filter_once(b"\x1cabc"), vec![fwd(b"\x1cabc")]);
    }

    #[test]
    fn input_filter_bracketed_paste_with_detach_byte_intact() {
        let paste = b"\x1b[200~hello\x1cworld\x1b[201~";
        assert_eq!(filter_once(paste), vec![fwd(paste)]);
    }

    #[test]
    fn input_filter_focus_in() {
        assert_eq!(filter_once(b"\x1b[I"), vec![FilterAction::FocusIn]);
    }

    #[test]
    fn input_filter_focus_out_dropped() {
        assert!(filter_once(b"\x1b[O").is_empty());
    }

    #[test]
    fn input_filter_carry_lone_esc() {
        let mut filter = InputFilter::new();
        assert_eq!(filter.process(b"abc\x1b"), vec![fwd(b"abc")]);
        assert_eq!(filter.process(b"[Amore"), vec![fwd(b"\x1b[Amore")]);
    }

    #[test]
    fn input_filter_carry_esc_bracket() {
        let mut filter = InputFilter::new();
        assert!(filter.process(b"\x1b[").is_empty());
        assert_eq!(
            filter.process(b"Irest"),
            vec![FilterAction::FocusIn, fwd(b"rest")]
        );
    }

    #[test]
    fn input_filter_flush() {
        let mut filter = InputFilter::new();
        let _ = filter.process(b"\x1b");
        assert_eq!(filter.flush(), Some(fwd(&[0x1b])));
        assert_eq!(filter.flush(), None);
    }

    #[test]
    fn input_filter_esc_non_bracket_passthrough() {
        assert_eq!(filter_once(b"\x1bOA"), vec![fwd(b"\x1bOA")]);
    }

    #[test]
    fn input_filter_mixed() {
        assert_eq!(
            filter_once(b"text\x1b[Imore"),
            vec![fwd(b"text"), FilterAction::FocusIn, fwd(b"more")]
        );
    }

    #[test]
    fn input_filter_esc_bracket_at_boundary() {
        let mut filter = InputFilter::new();
        assert!(filter.process(b"\x1b").is_empty(), "lone ESC should be carried");
        assert_eq!(filter.process(b"[I"), vec![FilterAction::FocusIn]);
    }

    #[test]
    fn input_filter_multiple_focus_events() {
        assert_eq!(
            filter_once(b"a\x1b[Ib\x1b[Oc"),
            vec![fwd(b"a"), FilterAction::FocusIn, fwd(b"bc")]
        );
    }
}