use std::io::{self, BufReader, Write};
use std::os::unix::net::UnixStream;
use std::sync::mpsc;
use std::time::Duration;
use crossterm::{
cursor,
event::{
self, Event, KeyboardEnhancementFlags, PopKeyboardEnhancementFlags,
PushKeyboardEnhancementFlags,
},
execute,
terminal::{self, EnterAlternateScreen, LeaveAlternateScreen},
};
use crate::protocol;
/// Why the interactive client loop stopped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExitReason {
    /// The server sent `S_DETACHED`: this client was detached from the session.
    Detached,
    /// The server sent `S_EXIT`: the session itself ended.
    ServerExit,
    /// The server message channel disconnected, or sending to the server failed.
    ConnectionLost,
}
/// Attaches to `session_name` through the server socket at `socket_path`.
///
/// Convenience wrapper around [`run_with_mode`] that always uses
/// `protocol::AttachMode::Steal`.
pub fn run(socket_path: &std::path::Path, session_name: &str) -> anyhow::Result<()> {
    let mode = protocol::AttachMode::Steal;
    run_with_mode(socket_path, session_name, mode)
}
/// Error produced when the server rejects this client during the handshake
/// with an incompatibility notice.
///
/// `Display` renders only [`message`](Self::message), verbatim.
#[derive(Debug)]
pub struct IncompatibleServerError {
    /// Server-side protocol identifier taken from the incompatibility notice.
    // Not read anywhere the compiler can see; presumably kept so callers that
    // downcast the error can inspect it — confirm before removing.
    #[allow(dead_code)]
    pub server_proto: String,
    /// Client-side protocol identifier taken from the incompatibility notice.
    // Same rationale as `server_proto` for the lint allowance.
    #[allow(dead_code)]
    pub client_proto: String,
    /// Human-readable explanation supplied by the server.
    pub message: String,
}

impl std::error::Error for IncompatibleServerError {}

impl std::fmt::Display for IncompatibleServerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}
/// Connects to the server socket at `socket_path`, attaches to `session_name`
/// using `attach_mode`, and runs the interactive client until it stops.
///
/// While attached, the terminal is in raw mode and on the alternate screen,
/// with mouse capture, focus reporting, bracketed paste and keyboard
/// enhancement enabled, and the cursor hidden. All of that is undone before
/// this function returns. This includes the cases where terminal setup fails
/// part-way and where the client loop unwinds from a panic.
///
/// # Errors
///
/// Returns an error if:
/// - connecting or configuring the socket fails;
/// - the handshake fails, including an [`IncompatibleServerError`] when the
///   server reports incompatibility;
/// - terminal setup or I/O fails;
/// - the server connection is lost mid-session.
pub fn run_with_mode(
    socket_path: &std::path::Path,
    session_name: &str,
    attach_mode: protocol::AttachMode,
) -> anyhow::Result<()> {
    // Undoes the terminal changes made below when dropped. Doing this in
    // `Drop` (rather than inline after the loop) guarantees restoration even
    // when a setup step bails out via `?` or the loop panics. Every step is
    // best-effort and ignores errors so that dropping never panics.
    struct TerminalRestore;

    impl Drop for TerminalRestore {
        fn drop(&mut self) {
            let mut out = io::stdout();
            // Clear the window title set on entry.
            let _ = write!(out, "\x1b]0;\x07");
            let _ = execute!(
                out,
                PopKeyboardEnhancementFlags,
                event::DisableBracketedPaste,
                event::DisableFocusChange,
                cursor::Show,
                event::DisableMouseCapture,
                LeaveAlternateScreen
            );
            let _ = terminal::disable_raw_mode();
        }
    }

    let stream = UnixStream::connect(socket_path)?;
    stream.set_nonblocking(false)?;
    // Bound the handshake so an unresponsive server cannot hang the client;
    // the timeout is removed again once the handshake has completed.
    stream.set_read_timeout(Some(Duration::from_secs(2)))?;
    let server_hello = perform_handshake(&stream, session_name)?;
    stream.set_read_timeout(None)?;
    if std::env::var("EZPN_DEBUG").is_ok() {
        eprintln!(
            "ezpn: handshake ok — server {}.{} ({})",
            server_hello.proto_major, server_hello.proto_minor, server_hello.build
        );
    }

    let write_stream = stream.try_clone()?;
    let read_stream = stream;
    // Background reader: forwards decoded server messages to the UI loop until
    // a read fails or the receiving side is dropped.
    let (server_tx, server_rx) = mpsc::channel::<(u8, Vec<u8>)>();
    std::thread::spawn(move || {
        let mut reader = BufReader::new(read_stream);
        while let Ok(msg) = protocol::read_msg(&mut reader) {
            if server_tx.send(msg).is_err() {
                break;
            }
        }
    });

    terminal::enable_raw_mode()?;
    // From here on, any early return or unwind restores the terminal.
    let restore = TerminalRestore;
    let mut stdout = io::stdout();
    execute!(
        stdout,
        EnterAlternateScreen,
        event::EnableMouseCapture,
        event::EnableFocusChange,
        event::EnableBracketedPaste,
        PushKeyboardEnhancementFlags(KeyboardEnhancementFlags::DISAMBIGUATE_ESCAPE_CODES),
        cursor::Hide
    )?;
    // Set the window title to identify the attached session.
    let _ = write!(stdout, "\x1b]0;ezpn: {}\x07", session_name);
    let _ = stdout.flush();

    let reason = client_loop(&mut stdout, write_stream, &server_rx, attach_mode);
    // Restore the terminal before printing the status line so it lands on the
    // normal screen rather than the alternate one.
    drop(restore);

    match reason? {
        ExitReason::Detached => {
            println!("[detached from session {}]", session_name);
            Ok(())
        }
        ExitReason::ServerExit => {
            println!("[session {} ended]", session_name);
            Ok(())
        }
        ExitReason::ConnectionLost => Err(anyhow::anyhow!("server connection lost")),
    }
}
/// Runs the client side of the protocol handshake over `stream`.
///
/// On success, returns the server's hello. When the server answers with an
/// incompatibility notice, this function:
/// - prints the notice's message to stderr;
/// - adds an `ezpn kill <session>` hint unless the message already mentions
///   `ezpn kill`;
/// - returns an [`IncompatibleServerError`] carrying the notice's fields.
///
/// # Errors
///
/// Propagates any error from `protocol::client_handshake`. Returns
/// [`IncompatibleServerError`] on an incompatibility notice.
fn perform_handshake(
    stream: &UnixStream,
    session_name: &str,
) -> anyhow::Result<protocol::ServerHello> {
    // `&UnixStream` implements both `Read` and `Write`, so the same shared
    // reference serves as reader and writer.
    let (mut rd, mut wr) = (stream, stream);
    let notice = match protocol::client_handshake(&mut rd, &mut wr)? {
        protocol::HandshakeOutcome::Ok(hello) => return Ok(hello),
        protocol::HandshakeOutcome::Incompat(notice) => notice,
    };

    eprintln!("Error: {}", notice.message);
    let already_hinted = notice.message.contains("ezpn kill");
    if !already_hinted {
        eprintln!("hint: ezpn kill {}", session_name);
    }

    Err(IncompatibleServerError {
        server_proto: notice.server_proto,
        client_proto: notice.client_proto,
        message: notice.message,
    }
    .into())
}
fn client_loop(
stdout: &mut io::Stdout,
mut writer: UnixStream,
server_rx: &mpsc::Receiver<(u8, Vec<u8>)>,
attach_mode: protocol::AttachMode,
) -> anyhow::Result<ExitReason> {
let (cols, rows) = terminal::size()?;
if attach_mode == protocol::AttachMode::Steal {
let resize_data = protocol::encode_resize(cols, rows);
protocol::write_msg(&mut writer, protocol::C_RESIZE, &resize_data)?;
} else {
let req = protocol::AttachRequest {
cols,
rows,
mode: attach_mode,
};
let json = serde_json::to_vec(&req)?;
protocol::write_msg(&mut writer, protocol::C_ATTACH, &json)?;
}
loop {
let mut got_output = false;
loop {
match server_rx.try_recv() {
Ok((tag, payload)) => match tag {
protocol::S_OUTPUT => {
stdout.write_all(&payload)?;
got_output = true;
}
protocol::S_DETACHED => {
return Ok(ExitReason::Detached);
}
protocol::S_EXIT => {
return Ok(ExitReason::ServerExit);
}
_ => {}
},
Err(mpsc::TryRecvError::Empty) => break,
Err(mpsc::TryRecvError::Disconnected) => {
return Ok(ExitReason::ConnectionLost);
}
}
}
if got_output {
stdout.flush()?;
}
let poll_ms = if got_output { 1 } else { 4 };
while event::poll(Duration::from_millis(poll_ms))? {
let ev = event::read()?;
match &ev {
Event::Resize(w, h) => {
let data = protocol::encode_resize(*w, *h);
if protocol::write_msg(&mut writer, protocol::C_RESIZE, &data).is_err() {
return Ok(ExitReason::ConnectionLost);
}
}
_ => {
let json = serde_json::to_vec(&ev)?;
if protocol::write_msg(&mut writer, protocol::C_EVENT, &json).is_err() {
return Ok(ExitReason::ConnectionLost);
}
}
}
}
}
}