use std::{
ffi::OsString,
fs::{DirBuilder, OpenOptions},
io::{Read as _, Write as _, stdin, stdout},
net::SocketAddr,
path::{Path, PathBuf},
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
thread,
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};
#[cfg(target_family = "unix")]
use std::os::unix::fs::DirBuilderExt;
use anyhow::{Context as _, Result};
use clap::Parser as _;
use crossterm::terminal::{disable_raw_mode, enable_raw_mode};
use dialoguer::{Confirm, Password};
use libmoshpit::{
DiffMode, DisplayPreference, Emulator, EncryptedFrame, KEY_ALGORITHM_X25519, Kex,
KexConfig as _, KexMode, KeyPair, MoshpitError, PredictionEngine, Renderer, UdpReader,
UdpSender, UuidWrapper, init_tracing, load, paint_overlays_to_ansi, parse_server_destination,
run_key_exchange,
};
use terminal_size::terminal_size;
#[cfg(unix)]
use tokio::signal::unix::{SignalKind, signal};
use tokio::{
net::{TcpStream, UdpSocket},
select, spawn,
sync::{
Mutex,
mpsc::{Receiver, Sender, channel},
},
time,
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace};
use uuid::Uuid;
use crate::{cli::Cli, config::Config};
/// Client entry point.
///
/// Parses the command line (from `args` when provided, otherwise from the
/// process arguments), loads the configuration, initialises tracing, offers to
/// create a keypair if none exists, resolves the server destination and then
/// hands control to the reconnecting session loop.
///
/// # Errors
/// Returns argument-parsing, configuration, tracing, keypair-generation and
/// destination-parsing errors, plus whatever the session loop returns.
#[cfg_attr(coverage_nightly, coverage(off))]
pub(crate) async fn run<I, T>(args: Option<I>) -> Result<()>
where
    I: IntoIterator<Item = T>,
    T: Into<OsString> + Clone,
{
    let cli = match args {
        Some(argv) => Cli::try_parse_from(argv)?,
        None => Cli::try_parse()?,
    };
    let mut config =
        load::<Cli, Config, Cli>(&cli, &cli).with_context(|| MoshpitError::ConfigLoad)?;
    init_tracing(&config, config.tracing().file(), &cli, None)
        .with_context(|| MoshpitError::TracingInit)?;
    maybe_generate_keypair(&config)?;

    let (user, socket_addr) =
        parse_server_destination(config.server_destination(), config.server_port())?;
    let server_port = config.server_port();
    let server_ip = socket_addr.ip().to_string();
    // The destination may carry an explicit user; record it for the key exchange.
    let _ = config.set_user(user);
    run_session_loop(config, socket_addr, server_ip, server_port).await
}
/// Remembered outcome of the private-key passphrase prompt, so that
/// reconnects can answer the key exchange without prompting again.
#[derive(Debug)]
enum PassCache {
    /// No prompt has completed yet.
    Uncached,
    /// A prompt completed and produced no passphrase.
    NoPassphrase,
    /// A prompt completed and produced this passphrase.
    Passphrase(String),
}

impl PassCache {
    /// Returns `true` once a prompt outcome has been stored.
    fn is_cached(&self) -> bool {
        match self {
            Self::Uncached => false,
            Self::NoPassphrase | Self::Passphrase(_) => true,
        }
    }

    /// Returns a copy of the cached passphrase, or `None` for [`PassCache::NoPassphrase`].
    ///
    /// # Panics
    /// Panics when called on [`PassCache::Uncached`]; callers must check
    /// [`PassCache::is_cached`] first.
    fn passphrase(&self) -> Option<String> {
        match self {
            Self::Passphrase(secret) => Some(secret.to_owned()),
            Self::NoPassphrase => None,
            Self::Uncached => unreachable!("passphrase() called before caching"),
        }
    }
}
/// Upper bound on the whole TCP key exchange (including any interactive prompts).
const KEX_TIMEOUT: Duration = Duration::from_secs(30);

/// A key-exchange failure caused by the local key material (missing, corrupt,
/// mismatched or undecryptable key). The session loop does not retry it; it
/// prints a hint pointing at `key_path` and exits with the error.
#[derive(Debug)]
struct FatalKexError {
    /// The underlying key-exchange error.
    inner: MoshpitError,
    /// Configured private-key path (empty if it could not be resolved).
    key_path: PathBuf,
}

impl std::fmt::Display for FatalKexError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { inner, key_path } = self;
        write!(formatter, "{} (key: {})", inner, key_path.display())
    }
}

impl std::error::Error for FatalKexError {}
/// Progress through the client-side `Ctrl-^ .` (bytes 0x1E, 0x2E) quit sequence.
#[derive(Default, Clone, Copy)]
enum EscapeState {
    /// No escape prefix is pending.
    #[default]
    Normal,
    /// `Ctrl-^` (0x1E) was seen; a following `.` (0x2E) requests exit.
    PendingDot,
}
/// Paints the "server unreachable" banner on the first terminal row.
///
/// The sequence saves the cursor (`ESC[s`), jumps to row 1 column 1, draws the
/// coloured message, clears to end of line, resets attributes and restores the
/// cursor (`ESC[u`). A failed send (stdout writer gone) is ignored.
async fn show_reconnect_banner(stdout_tx: &Sender<Vec<u8>>) {
    const BANNER: &[u8] = b"\x1b[s\x1b[1;1H\x1b[44;97;1m [moshpit] server unreachable, reconnecting... (Ctrl-^ . to quit) \x1b[K\x1b[0m\x1b[u";
    let sent = stdout_tx.send(BANNER.to_vec()).await;
    drop(sent);
}
/// Blanks the first terminal row (where the reconnect banner is drawn) while
/// preserving the cursor position. A failed send is ignored.
async fn clear_reconnect_banner(stdout_tx: &Sender<Vec<u8>>) {
    const CLEAR: &[u8] = b"\x1b[s\x1b[1;1H\x1b[0m\x1b[K\x1b[u";
    let sent = stdout_tx.send(CLEAR.to_vec()).await;
    drop(sent);
}
/// Redraws the reconnect banner once per second, counting down from
/// `total_secs` to 0.
///
/// Returns `true` if `exit_token` was cancelled, either during one of the
/// one-second waits (which returns immediately) or by the time the countdown
/// finishes. Otherwise returns `false`.
async fn countdown_reconnect_banner(
    stdout_tx: &Sender<Vec<u8>>,
    total_secs: u64,
    attempt: u32,
    max_backoff_secs: u64,
    exit_token: &CancellationToken,
) -> bool {
    let mut remaining = total_secs;
    loop {
        let banner = format!(
            "\x1b[s\x1b[1;1H\x1b[44;97;1m [moshpit] server unreachable, reconnecting \
             (attempt #{attempt}, {remaining}s, max {max_backoff_secs}s, Ctrl-^ . to quit)... \x1b[K\x1b[0m\x1b[u"
        );
        drop(stdout_tx.send(banner.into_bytes()).await);
        if remaining == 0 {
            break;
        }
        select! {
            () = exit_token.cancelled() => return true,
            () = time::sleep(Duration::from_secs(1)) => {}
        }
        remaining -= 1;
    }
    exit_token.is_cancelled()
}
/// Watches keyboard input for the `Ctrl-^ .` quit sequence while no session
/// is forwarding keystrokes (e.g. during the reconnect countdown).
///
/// Holds the shared keyboard receiver lock for its whole lifetime and swallows
/// every byte it reads. Stops when `done_token` is cancelled or the channel
/// closes. When the sequence is seen, it cancels `exit_token` and stops.
async fn run_escape_listener(
    kb_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
    exit_token: CancellationToken,
    done_token: CancellationToken,
) {
    let mut rx = kb_rx.lock().await;
    let mut state = EscapeState::Normal;
    loop {
        let chunk = select! {
            () = done_token.cancelled() => return,
            data = rx.recv() => data,
        };
        let Some(chunk) = chunk else { return };
        for byte in chunk {
            state = match (state, byte) {
                (EscapeState::PendingDot, 0x2E) => {
                    exit_token.cancel();
                    return;
                }
                // Ctrl-^ (re)arms the sequence from any state.
                (_, 0x1E) => EscapeState::PendingDot,
                _ => EscapeState::Normal,
            };
        }
    }
}
/// Encodes a character key press as the bytes a VT/xterm-style terminal sends.
///
/// * `ctrl`: the character is mapped to its C0 control byte when it has one:
///   `a`-`z` (either case) give 0x01-0x1A; `@`, space and `2` give NUL;
///   `[` and `3` give ESC; `\` and `4` give 0x1C; `]` and `5` give 0x1D;
///   `^` and `6` give 0x1E; `_` and `7` give 0x1F; `?` and `8` give DEL.
///   Characters without a control mapping are sent as their UTF-8 bytes.
/// * `alt`: the result is prefixed with ESC (meta-sends-escape).
///
/// crossterm reports a raw NUL byte (Ctrl-Space / Ctrl-@) as `Char(' ')` with
/// CONTROL, and raw 0x1C-0x1F as `Char('4'..='7')` with CONTROL. Both forms
/// therefore have to be mapped back to their control bytes.
fn encode_char_key(c: char, ctrl: bool, alt: bool) -> Vec<u8> {
    /// The C0 control byte produced by Ctrl+`ch`, if it has a conventional one.
    fn control_byte(ch: char) -> Option<u8> {
        match ch.to_ascii_lowercase() {
            '@' | ' ' | '2' => Some(0x00),
            // `lower` is ASCII here, so the `as u8` cast cannot truncate.
            lower @ 'a'..='z' => Some(lower as u8 - b'a' + 1),
            '[' | '3' => Some(0x1b),
            '\\' | '4' => Some(0x1c),
            ']' | '5' => Some(0x1d),
            '^' | '6' => Some(0x1e),
            '_' | '7' => Some(0x1f),
            '?' | '8' => Some(0x7f),
            _ => None,
        }
    }

    let mut out = Vec::with_capacity(1 + c.len_utf8());
    if alt {
        out.push(0x1b);
    }
    let control = if ctrl { control_byte(c) } else { None };
    match control {
        Some(byte) => out.push(byte),
        None => {
            let mut buf = [0u8; 4];
            out.extend_from_slice(c.encode_utf8(&mut buf).as_bytes());
        }
    }
    out
}
/// Encodes cursor/editing keys as xterm CSI sequences.
///
/// Letter-terminated keys (arrows, Home, End) are sent as `CSI <final>`
/// unmodified and `CSI 1 ; <mod> <final>` with modifiers. Tilde keys (Insert,
/// Delete, PageUp, PageDown) are sent as `CSI <n> ~` and `CSI <n> ; <mod> ~`.
/// Returns `None` for any other key.
fn encode_nav_key(
    code: crossterm::event::KeyCode,
    has_mod: bool,
    mod_param: u8,
) -> Option<Vec<u8>> {
    use crossterm::event::KeyCode;

    /// How a navigation key is spelled as a CSI sequence.
    enum Csi {
        /// Sequence ends in this final letter.
        Letter(char),
        /// Sequence is `<n>` followed by `~`.
        Tilde(u8),
    }

    let csi = match code {
        KeyCode::Up => Csi::Letter('A'),
        KeyCode::Down => Csi::Letter('B'),
        KeyCode::Right => Csi::Letter('C'),
        KeyCode::Left => Csi::Letter('D'),
        KeyCode::Home => Csi::Letter('H'),
        KeyCode::End => Csi::Letter('F'),
        KeyCode::Insert => Csi::Tilde(2),
        KeyCode::Delete => Csi::Tilde(3),
        KeyCode::PageUp => Csi::Tilde(5),
        KeyCode::PageDown => Csi::Tilde(6),
        _ => return None,
    };
    let sequence = match (csi, has_mod) {
        (Csi::Letter(fin), false) => format!("\x1b[{fin}"),
        (Csi::Letter(fin), true) => format!("\x1b[1;{mod_param}{fin}"),
        (Csi::Tilde(n), false) => format!("\x1b[{n}~"),
        (Csi::Tilde(n), true) => format!("\x1b[{n};{mod_param}~"),
    };
    Some(sequence.into_bytes())
}
/// Encodes function key F`n` (1-12) as xterm escape bytes.
///
/// F1-F4 use SS3 (`ESC O P`..`ESC O S`) when unmodified and
/// `CSI 1 ; <mod> P..S` with modifiers. F5-F12 use the VT220 tilde codes
/// 15, 17, 18, 19, 20, 21, 23, 24 (note the gaps), with `; <mod>` inserted when
/// modified. Any other `n` yields an empty vector.
fn encode_function_key(n: u8, has_mod: bool, mod_param: u8) -> Vec<u8> {
    /// Tilde codes for F5..=F12, indexed by `n - 5`.
    const TILDE_CODES: [u8; 8] = [15, 17, 18, 19, 20, 21, 23, 24];
    match n {
        1..=4 => {
            let fin = char::from(b'P' + (n - 1));
            let seq = if has_mod {
                format!("\x1b[1;{mod_param}{fin}")
            } else {
                format!("\x1bO{fin}")
            };
            seq.into_bytes()
        }
        5..=12 => {
            let code = TILDE_CODES[usize::from(n - 5)];
            let seq = if has_mod {
                format!("\x1b[{code};{mod_param}~")
            } else {
                format!("\x1b[{code}~")
            };
            seq.into_bytes()
        }
        _ => Vec::new(),
    }
}
/// Converts a crossterm key event into the bytes an xterm-compatible terminal
/// would send to the remote application.
///
/// Only `Press` events produce output; repeats and releases yield an empty
/// vector. The xterm modifier parameter is `1 + shift + 2*alt + 4*ctrl`.
///
/// Alt is conveyed by an ESC prefix (meta-sends-escape), as `encode_char_key`
/// does. crossterm reports ESC-prefixed input (e.g. ESC DEL) as the key with
/// ALT set, so Alt+Backspace/Enter/Tab/Esc must get their ESC prefix back.
/// Otherwise Meta-Backspace (readline backward-kill-word) arrives as a plain DEL.
fn key_event_to_bytes(event: crossterm::event::KeyEvent) -> Vec<u8> {
    use crossterm::event::{KeyCode, KeyEventKind, KeyModifiers};
    if event.kind != KeyEventKind::Press {
        return Vec::new();
    }
    let mods = event.modifiers;
    let ctrl = mods.contains(KeyModifiers::CONTROL);
    let alt = mods.contains(KeyModifiers::ALT);
    let shift = mods.contains(KeyModifiers::SHIFT);
    let mod_param = 1u8 + u8::from(shift) + (u8::from(alt) * 2) + (u8::from(ctrl) * 4);
    let has_mod = mod_param > 1;
    // Single-byte keys: prefix ESC when Alt is held.
    let with_meta = |byte: u8| if alt { vec![0x1b, byte] } else { vec![byte] };
    match event.code {
        KeyCode::Char(c) => encode_char_key(c, ctrl, alt),
        KeyCode::Backspace => with_meta(0x7f),
        KeyCode::Enter => with_meta(b'\r'),
        KeyCode::Tab => with_meta(b'\t'),
        KeyCode::BackTab => b"\x1b[Z".to_vec(),
        KeyCode::Esc => with_meta(0x1b),
        KeyCode::Null => vec![0x00],
        KeyCode::F(n) => encode_function_key(n, has_mod, mod_param),
        code => encode_nav_key(code, has_mod, mod_param).unwrap_or_default(),
    }
}
/// Runs `f` with the terminal temporarily in cooked (non-raw) mode, e.g. for
/// an interactive prompt, and returns its result.
///
/// The stdin reader thread is paused first, so it stops consuming key events
/// the prompt needs. It is resumed after raw mode is re-enabled. Errors from
/// toggling raw mode are ignored.
#[cfg_attr(coverage_nightly, coverage(off))]
fn with_cooked_term<T>(paused: &AtomicBool, f: impl FnOnce() -> T) -> T {
    let set_paused = |on: bool| paused.store(on, Ordering::SeqCst);
    set_paused(true);
    // The reader polls in 50 ms slices; wait long enough for it to notice the pause.
    thread::sleep(Duration::from_millis(100));
    drop(disable_raw_mode());
    let outcome = f();
    drop(enable_raw_mode());
    set_paused(false);
    outcome
}
/// Blocking loop, run on a dedicated thread, that turns terminal key events
/// into bytes and sends them on `kb_tx`.
///
/// While `paused` is set it only sleeps, leaving stdin to an interactive prompt.
/// Non-key events and read errors are skipped; key events that encode to no
/// bytes are dropped. The loop exits when polling fails or the receiver is gone.
fn stdin_reader_loop(kb_tx: &Sender<Vec<u8>>, paused: &AtomicBool) {
    use crossterm::event::{Event, poll, read};
    let tick = Duration::from_millis(50);
    loop {
        if paused.load(Ordering::Relaxed) {
            thread::sleep(tick);
            continue;
        }
        let ready = match poll(tick) {
            Ok(ready) => ready,
            Err(_) => break,
        };
        if !ready {
            continue;
        }
        let Ok(Event::Key(key)) = read() else {
            continue;
        };
        let bytes = key_event_to_bytes(key);
        if bytes.is_empty() {
            continue;
        }
        if kb_tx.blocking_send(bytes).is_err() {
            break;
        }
    }
}
/// Runs the reconnect countdown banner while listening for the `Ctrl-^ .` quit
/// sequence on the keyboard.
///
/// Returns `true` if the user asked to exit. The escape listener is always
/// stopped and awaited before returning, so its hold on the shared keyboard
/// receiver lock is released for the next session.
#[cfg_attr(coverage_nightly, coverage(off))]
async fn countdown_with_escape(
    stdout_tx: &Sender<Vec<u8>>,
    backoff_secs: u64,
    attempt: u32,
    max_backoff_secs: u64,
    exit_token: &CancellationToken,
    kb_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
) -> bool {
    let listener_done = CancellationToken::new();
    let listener = spawn(run_escape_listener(
        kb_rx,
        exit_token.clone(),
        listener_done.clone(),
    ));

    let exit_requested = countdown_reconnect_banner(
        stdout_tx,
        backoff_secs,
        attempt,
        max_backoff_secs,
        exit_token,
    )
    .await;

    listener_done.cancel();
    drop(listener.await);
    exit_requested
}
/// Drives the client for the lifetime of the process. It repeatedly performs
/// the TCP key exchange and runs a UDP session, and reconnects with
/// exponential backoff (starting at 2 s, capped by configuration) when a
/// session ends or a connection attempt fails.
///
/// It enables raw mode and starts two helper threads that live across every
/// reconnect: a stdout writer, so all terminal output is serialised through one
/// channel, and a stdin reader.
///
/// # Errors
/// Returns, after disabling raw mode, when the UDP session itself errors or the
/// key exchange fails for a reason retrying cannot fix: [`FatalKexError`],
/// `HostKeyRejected`, `KeyNotEstablished` or `NoCommonAlgorithm`.
///
/// A user-requested exit (`Ctrl-^ .`) terminates the process with
/// `std::process::exit(0)` instead of returning.
#[cfg_attr(nightly, allow(clippy::too_many_lines))]
#[cfg_attr(coverage_nightly, coverage(off))]
async fn run_session_loop(
    config: Config,
    socket_addr: SocketAddr,
    server_ip: String,
    server_port: u16,
) -> Result<()> {
    // Backoff ceiling from config, clamped to [2 s, 1 day].
    let max_backoff = Duration::from_secs(config.max_reconnect_backoff_secs().clamp(2, 86_400));
    // Single writer thread: each queued message is written whole and flushed, in order.
    let (stdout_tx, mut stdout_rx) = channel::<Vec<u8>>(256);
    let _stdout_thread = thread::spawn(move || {
        let mut out = stdout();
        while let Some(msg) = stdout_rx.blocking_recv() {
            drop(out.write_all(&msg));
            drop(out.flush());
        }
    });
    // Passphrase prompt outcome, shared with the key-exchange callback so that
    // reconnects can skip the prompt.
    let pass_cache: Arc<std::sync::Mutex<PassCache>> =
        Arc::new(std::sync::Mutex::new(PassCache::Uncached));
    let mut config = config;
    let mut backoff = Duration::from_secs(2);
    let mut reconnect_attempt: u32 = 0;
    // Cancelled when the user types the `Ctrl-^ .` quit sequence.
    let exit_token = CancellationToken::new();
    // Set while an interactive prompt runs in cooked mode, so the stdin reader
    // thread stops consuming key events.
    let stdin_paused = Arc::new(AtomicBool::new(false));
    enable_raw_mode()?;
    // Keyboard bytes from the reader thread. The receiver sits behind an async
    // mutex because both the session forwarder and the reconnect escape listener
    // consume it (one at a time).
    let (kb_tx, kb_rx) = channel::<Vec<u8>>(64);
    let paused_for_reader = stdin_paused.clone();
    let _stdin_thread = thread::spawn(move || stdin_reader_loop(&kb_tx, &paused_for_reader));
    let kb_rx_shared = Arc::new(Mutex::new(kb_rx));
    let mut had_successful_kex = false;
    loop {
        match connect_and_kex(
            &mut config,
            socket_addr,
            &server_ip,
            server_port,
            &pass_cache,
            stdin_paused.clone(),
        )
        .await
        {
            Ok((kex, udp_arc, nak_timeout)) => {
                backoff = Duration::from_secs(2);
                // NOTE(review): reconnect_attempt is not reset here, so the
                // "attempt #N" banner keeps counting across successful sessions.
                // Confirm this is intended.
                clear_reconnect_banner(&stdout_tx).await;
                had_successful_kex = true;
                let session_result = run_udp_session(
                    kex,
                    udp_arc,
                    nak_timeout,
                    kb_rx_shared.clone(),
                    config.nat_warmup(),
                    config.nat_warmup_count(),
                    stdout_tx.clone(),
                    config.predict(),
                    config.diff_mode(),
                    exit_token.clone(),
                )
                .await;
                if let Err(e) = session_result {
                    drop(disable_raw_mode());
                    return Err(e);
                }
                // The session ended because the user asked to quit.
                if exit_token.is_cancelled() {
                    drop(disable_raw_mode());
                    // Give the stdout thread a moment to flush queued output.
                    time::sleep(Duration::from_millis(100)).await;
                    std::process::exit(0);
                }
                // Otherwise the session asked to reconnect: show the banner and loop.
                show_reconnect_banner(&stdout_tx).await;
                time::sleep(Duration::from_millis(500)).await;
            }
            Err(e) => {
                // Local key material is unusable; retrying cannot help.
                if let Some(fatal) = e.downcast_ref::<FatalKexError>() {
                    eprintln!("mp: fatal key error: {fatal}");
                    eprintln!(
                        "mp: run `mp-keygen` to regenerate your keypair at {}",
                        fatal.key_path.display()
                    );
                    drop(disable_raw_mode());
                    return Err(e);
                }
                if e.downcast_ref::<MoshpitError>()
                    .is_some_and(|e| *e == MoshpitError::HostKeyRejected)
                {
                    drop(disable_raw_mode());
                    return Err(e);
                }
                // Server-side rejections that would fail identically on every retry.
                if let Some(&err) = e.downcast_ref::<MoshpitError>() {
                    match err {
                        MoshpitError::KeyNotEstablished => {
                            eprintln!("mp: server rejected the key exchange");
                            eprintln!(
                                "mp: ensure your public key is listed in \
                                 ~/.mp/authorized_keys on the server"
                            );
                            drop(disable_raw_mode());
                            return Err(e);
                        }
                        MoshpitError::NoCommonAlgorithm => {
                            eprintln!("mp: no common algorithm found during key exchange");
                            eprintln!(
                                "mp: check --kex-algos, --aead-algos, --mac-algos, \
                                 and --kdf-algos settings on both client and server"
                            );
                            drop(disable_raw_mode());
                            return Err(e);
                        }
                        _ => {}
                    }
                }
                // Anything else is treated as transient: back off and retry.
                reconnect_attempt = reconnect_attempt.saturating_add(1);
                error!("Failed to connect to {socket_addr}: {e}, retrying in {backoff:?}");
                // Until a key exchange has succeeded once, forget the cached
                // passphrase so the next attempt prompts again. Presumably this is
                // because a first-time failure may be a wrong passphrase.
                if !had_successful_kex {
                    *pass_cache
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner) = PassCache::Uncached;
                }
                // `true` means the user typed `Ctrl-^ .` during the countdown.
                if countdown_with_escape(
                    &stdout_tx,
                    backoff.as_secs(),
                    reconnect_attempt,
                    max_backoff.as_secs(),
                    &exit_token,
                    kb_rx_shared.clone(),
                )
                .await
                {
                    clear_reconnect_banner(&stdout_tx).await;
                    let msg = b"\r\n\x1b[0m[moshpit] Disconnected.\r\n";
                    drop(stdout_tx.send(msg.to_vec()).await);
                    drop(disable_raw_mode());
                    time::sleep(Duration::from_millis(100)).await;
                    std::process::exit(0);
                }
                // Exponential backoff, capped at max_backoff.
                backoff = (backoff * 2).min(max_backoff);
            }
        }
    }
}
/// Opens the TCP control connection and runs the key exchange with three
/// interactive callbacks: passphrase, trust-on-first-use and host-key mismatch.
///
/// Any session UUID saved for this host/port/terminal is loaded into `config`
/// before the exchange. A session UUID returned by the exchange is written back
/// (best effort).
///
/// Returns the established [`Kex`], the UDP socket to use and a NAK timeout
/// derived from how long the exchange took, clamped to 20-500 ms.
///
/// # Errors
/// Fails on TCP connect errors, on exceeding `KEX_TIMEOUT`, and on any
/// key-exchange error. Key-file related errors are wrapped in [`FatalKexError`]
/// so the caller can stop retrying.
#[cfg_attr(nightly, allow(clippy::too_many_lines))]
async fn connect_and_kex(
    config: &mut Config,
    socket_addr: SocketAddr,
    server_ip: &str,
    server_port: u16,
    pass_cache: &Arc<std::sync::Mutex<PassCache>>,
    stdin_paused: Arc<AtomicBool>,
) -> Result<(Kex, Arc<UdpSocket>, Duration)> {
    // Offer the saved session UUID (None if there is no session file).
    let _ = config.set_resume_session_uuid(read_session_uuid(server_ip, server_port));
    let socket = TcpStream::connect(socket_addr).await?;
    info!("Connected to {}", socket.peer_addr()?);
    let cache = pass_cache.clone();
    let paused_pass = stdin_paused.clone();
    // Passphrase callback: answer from the cache when possible, otherwise prompt
    // in cooked mode and cache a successful answer for later reconnects.
    let pass_fn = move || -> Result<Option<String>> {
        let guard = cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if guard.is_cached() {
            info!(
                "passphrase: returning cached value (has_passphrase={})",
                guard.passphrase().is_some()
            );
            return Ok(guard.passphrase());
        }
        // Release the lock before the (potentially long) interactive prompt.
        drop(guard);
        info!("passphrase: prompting user");
        // block_in_place runs this blocking prompt on the current worker thread.
        // It requires Tokio's multi-threaded runtime.
        // NOTE(review): `read_passpharase` is a typo for "passphrase"; it is kept
        // because it is the helper's actual name.
        let result =
            tokio::task::block_in_place(|| with_cooked_term(&paused_pass, read_passpharase));
        match &result {
            Ok(Some(_)) => info!("passphrase: prompt returned a passphrase"),
            Ok(None) => info!("passphrase: prompt returned None (key may be unencrypted)"),
            Err(e) => error!("passphrase: prompt failed: {e}"),
        }
        // Only successful prompts are cached; a failed prompt is not remembered.
        if let Ok(ref pass) = result {
            *cache
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner) = match pass {
                Some(s) => PassCache::Passphrase(s.clone()),
                None => PassCache::NoPassphrase,
            };
        }
        result
    };
    let (sock_read, sock_write) = socket.into_split();
    let paused_tofu = stdin_paused.clone();
    // Trust-on-first-use callback: the host is accepted only if the user types "yes"
    // (case-insensitive).
    let tofu_fn: libmoshpit::TofuFn =
        Arc::new(move |host: &str, fingerprint: &str| -> Result<bool> {
            tokio::task::block_in_place(|| {
                with_cooked_term(&paused_tofu, || {
                    let prompt = format!(
                        "The authenticity of host '{host}' can't be established.\n\
                         Fingerprint is SHA256:{fingerprint}.\n\
                         Are you sure you want to continue connecting? (yes/no)"
                    );
                    let input: String = dialoguer::Input::new()
                        .with_prompt(prompt)
                        .interact_text()?;
                    Ok(input.eq_ignore_ascii_case("yes"))
                })
            })
        });
    let paused_mismatch = stdin_paused;
    // Host-key mismatch callback: prints an ssh-style warning and asks whether to
    // replace the stored key (default: no).
    let mismatch_fn: libmoshpit::HostKeyMismatchFn = Arc::new(
        move |host: &str, old_fingerprint: &str, new_fingerprint: &str| -> Result<bool> {
            tokio::task::block_in_place(|| {
                with_cooked_term(&paused_mismatch, || {
                    eprintln!("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@");
                    eprintln!("@ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! @");
                    eprintln!("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@");
                    eprintln!("Potential DNS spoofing or machine-in-the-middle detected.");
                    eprintln!("Host: {host}");
                    eprintln!("Offending key fingerprint: SHA256:{old_fingerprint}");
                    eprintln!("Presented key fingerprint: SHA256:{new_fingerprint}");
                    Confirm::new()
                        .with_prompt(
                            "Update ~/.mp/known_hosts with the newly presented key for this host?",
                        )
                        .default(false)
                        .wait_for_newline(true)
                        .interact()
                        .map_err(Into::into)
                })
            })
        },
    );
    // NOTE(review): this timing also covers any interactive prompts the callbacks
    // run, so it is not a pure round-trip measurement. The clamp below bounds the
    // effect.
    let kex_start = Instant::now();
    let kex_result = time::timeout(
        KEX_TIMEOUT,
        run_key_exchange(
            config.clone(),
            sock_read,
            sock_write,
            pass_fn,
            Some(tofu_fn),
            Some(mismatch_fn),
        ),
    )
    .await;
    let (kex, udp_arc, _) = match kex_result {
        Err(_elapsed) => {
            return Err(anyhow::anyhow!(
                "key exchange timed out after {KEX_TIMEOUT:?} — \
                 server accepted TCP connection but sent no data"
            ));
        }
        Ok(inner) => inner,
    }
    // Tag key-file problems as fatal so the session loop stops retrying.
    .map_err(|e| {
        if let Some(&moshpit_err) = e.downcast_ref::<MoshpitError>() {
            match moshpit_err {
                MoshpitError::KeyFileMissing
                | MoshpitError::KeyCorrupt
                | MoshpitError::KeyPairMismatch
                | MoshpitError::DecryptionFailed
                | MoshpitError::InvalidPublicKeyFormat
                | MoshpitError::InvalidKeyHeader => {
                    let key_path = config
                        .key_pair_paths()
                        .ok()
                        .map(|(p, _)| p)
                        .unwrap_or_default();
                    return anyhow::anyhow!(FatalKexError {
                        inner: moshpit_err,
                        key_path,
                    });
                }
                _ => {}
            }
        }
        e
    })?;
    // Persist the session UUID (best effort) so a later run can ask to resume it.
    if let Some(session_uuid) = kex.session_uuid() {
        if let Err(e) = write_session_uuid(server_ip, server_port, session_uuid) {
            trace!("Failed to write session file: {e}");
        }
        if kex.is_resume() {
            info!("Session {session_uuid} resumed");
        } else {
            info!("New session {session_uuid} started");
        }
    }
    let nak_timeout = kex_start
        .elapsed()
        .clamp(Duration::from_millis(20), Duration::from_millis(500));
    info!("nak_timeout set to {:?} from kex elapsed time", nak_timeout);
    Ok((kex, udp_arc, nak_timeout))
}
/// Runs one encrypted UDP session until the UDP reader signals a reconnect or
/// the user requests exit.
///
/// Spawns the UDP sender and reader loops, a terminal-resize watcher, and a
/// keyboard forwarder. The forwarder strips the `Ctrl-^ .` quit sequence, sends
/// keystrokes and paints local prediction overlays. Every spawned task observes
/// the session `token`, which is cancelled before returning.
///
/// # Errors
/// Only setup failures: building the AEAD keys, or queueing the initial resize
/// and NAT-warmup keepalive frames. Both the reconnect and the exit path return
/// `Ok(())`; the caller distinguishes them via `exit_token`.
#[cfg_attr(nightly, allow(clippy::too_many_lines))]
#[cfg_attr(nightly, allow(clippy::too_many_arguments))]
#[cfg_attr(coverage_nightly, coverage(off))]
async fn run_udp_session(
    kex: Kex,
    udp_arc: Arc<UdpSocket>,
    nak_timeout: Duration,
    kb_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
    nat_warmup: bool,
    nat_warmup_count: u32,
    stdout_tx: Sender<Vec<u8>>,
    display_preference: DisplayPreference,
    diff_mode: DiffMode,
    exit_token: CancellationToken,
) -> Result<()> {
    // The UDP reader sends on this channel when the session should be re-established.
    let (reconnect_tx, mut reconnect_rx) = channel::<()>(1);
    // Session-scoped cancellation shared by every task spawned below.
    let token = CancellationToken::new();
    // Outbound frame queue consumed by the UDP sender. The reader also uses it
    // for NAKs and query responses.
    let (tx, rx) = channel::<EncryptedFrame>(256);
    // `_control_tx` is a named binding, so the sender half lives until this
    // function returns even though nothing here sends on it.
    let (_control_tx, control_rx) = channel::<EncryptedFrame>(16);
    let (retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(512);
    // 30x the NAK timeout, but at least 9 s. Presumably this is the reader's
    // no-traffic threshold before reconnecting; confirm in UdpReader.
    let silence_timeout = (nak_timeout * 30).max(Duration::from_secs(9));
    let mac_tag_len = kex.mac_tag_len();
    let mut udp_reader = UdpReader::builder()
        .socket(udp_arc.clone())
        .id(kex.uuid())
        .hmac(kex.build_hmac())
        .rnk(kex.build_aead_key()?)
        .mac_tag_len(mac_tag_len)
        .nak_out_tx(tx.clone())
        .retransmit_tx(retransmit_tx)
        .silence_timeout(silence_timeout)
        .nak_timeout(nak_timeout)
        .reconnect_tx(reconnect_tx)
        .query_response_tx(tx.clone())
        .diff_mode(diff_mode)
        .build();
    let mut udp_sender = UdpSender::builder()
        .socket(udp_arc)
        .control_rx(control_rx)
        .rx(rx)
        .retransmit_rx(retransmit_rx)
        .id(kex.uuid())
        .hmac(kex.build_hmac())
        .rnk(kex.build_aead_key()?)
        .diff_mode(diff_mode)
        .build();
    let sender_token = token.clone();
    let _sender = spawn(async move { udp_sender.frame_loop(sender_token).await });
    // Report the current terminal size first (80x24 if it cannot be determined).
    let (cols, rows) = terminal_size().map_or((80, 24), |(w, h)| (w.0, h.0));
    tx.send(EncryptedFrame::Resize((kex.uuid_wrapper(), cols, rows)))
        .await?;
    if nat_warmup {
        info!(
            "NAT warmup: sending {} keepalive frame(s)",
            nat_warmup_count
        );
        for _ in 0..nat_warmup_count {
            // Microseconds since the Unix epoch; 0 if the clock is before the
            // epoch or the value does not fit in u64.
            let ts = u64::try_from(
                SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_micros(),
            )
            .unwrap_or(0);
            tx.send(EncryptedFrame::Keepalive(ts)).await?;
        }
    }
    // Local screen model, prediction engine and renderer, shared between the
    // reader, the resize handler and the forwarder.
    let emulator = Arc::new(std::sync::Mutex::new(Emulator::new(rows, cols)));
    let prediction = Arc::new(std::sync::Mutex::new(PredictionEngine::new(
        display_preference,
    )));
    let renderer = Arc::new(std::sync::Mutex::new(Renderer::new(rows, cols)));
    let reader_token = token.clone();
    let emu_reader = emulator.clone();
    let pred_reader = prediction.clone();
    let rend_reader = renderer.clone();
    let stdout_tx_reader = stdout_tx.clone();
    let exit_token_reader = exit_token.clone();
    let _reader = spawn(async move {
        udp_reader
            .client_frame_loop(
                reader_token,
                exit_token_reader,
                stdout_tx_reader,
                emu_reader,
                pred_reader,
                rend_reader,
            )
            .await;
    });
    spawn_resize_handler(
        tx.clone(),
        kex.uuid_wrapper(),
        token.clone(),
        emulator.clone(),
        renderer.clone(),
    );
    let fwd_token = token.clone();
    let exit_token_fwd = exit_token.clone();
    let session_tx = tx;
    let uuid_wrapper = kex.uuid_wrapper();
    let emu_fwd = emulator.clone();
    let pred_fwd = prediction.clone();
    let stdout_tx_fwd = stdout_tx;
    // Keyboard forwarder. It holds the shared keyboard receiver lock until the
    // session token is cancelled or the channel closes.
    let _forwarder = spawn(async move {
        let mut rx = kb_rx.lock().await;
        // Persists across chunks, so `Ctrl-^` and `.` may arrive separately.
        let mut escape_state = EscapeState::Normal;
        loop {
            select! {
                () = fwd_token.cancelled() => break,
                data = rx.recv() => match data {
                    Some(data) => {
                        let mut to_forward: Vec<u8> = Vec::new();
                        let mut exit_requested = false;
                        for &byte in &data {
                            escape_state = match escape_state {
                                EscapeState::Normal => {
                                    // Hold back Ctrl-^ until the next byte shows whether it
                                    // starts the quit sequence.
                                    if byte == 0x1E {
                                        EscapeState::PendingDot
                                    } else {
                                        to_forward.push(byte);
                                        EscapeState::Normal
                                    }
                                }
                                EscapeState::PendingDot => {
                                    if byte == 0x2E {
                                        // Ctrl-^ . : quit. Bytes after it in this chunk are dropped.
                                        exit_requested = true;
                                        break;
                                    } else if byte == 0x1E {
                                        // Repeated Ctrl-^: the earlier one is not forwarded and
                                        // the new one is held back.
                                        EscapeState::PendingDot
                                    } else {
                                        // Not the quit sequence: forward the held Ctrl-^ verbatim.
                                        to_forward.push(0x1E);
                                        to_forward.push(byte);
                                        EscapeState::Normal
                                    }
                                }
                            };
                        }
                        if !to_forward.is_empty() {
                            if session_tx
                                .send(EncryptedFrame::Bytes((uuid_wrapper, to_forward.clone())))
                                .await
                                .is_err()
                            {
                                break;
                            }
                            // Feed the same bytes to the prediction engine and paint its
                            // overlays locally. The std mutex guards are dropped at the end
                            // of this block, before the await below.
                            let (overlays, cursor) = {
                                let emu = emu_fwd.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
                                let screen = emu.screen();
                                let mut pred = pred_fwd.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
                                for byte in &to_forward {
                                    pred.new_user_byte(*byte, screen);
                                }
                                pred.apply(screen)
                            };
                            let preview = paint_overlays_to_ansi(&overlays, cursor);
                            if !preview.is_empty() {
                                drop(stdout_tx_fwd.send(preview).await);
                            }
                        }
                        if exit_requested {
                            let msg = b"\r\n\x1b[0m[moshpit] Disconnected.\r\n";
                            drop(stdout_tx_fwd.send(msg.to_vec()).await);
                            exit_token_fwd.cancel();
                            fwd_token.cancel();
                            break;
                        }
                    }
                    None => break,
                },
            }
        }
    });
    // Wait for the reader's reconnect signal or a user-requested exit.
    select! {
        _ = reconnect_rx.recv() => {}
        () = exit_token.cancelled() => {}
    }
    token.cancel();
    // Short grace period, presumably so the spawned tasks can observe
    // cancellation and release shared resources (e.g. the keyboard lock).
    time::sleep(Duration::from_millis(150)).await;
    Ok(())
}
#[cfg(unix)]
fn spawn_resize_handler(
resize_tx: Sender<EncryptedFrame>,
resize_uuid: UuidWrapper,
resize_token: CancellationToken,
emulator: Arc<std::sync::Mutex<Emulator>>,
renderer: Arc<std::sync::Mutex<Renderer>>,
) {
let _resize_handle = spawn(async move {
match signal(SignalKind::window_change()) {
Ok(mut sigwinch) => loop {
tokio::select! {
() = resize_token.cancelled() => break,
_ = sigwinch.recv() => {
let (columns, rows) = terminal_size()
.map_or((80, 24), |(width, height)| (width.0, height.0));
emulator.lock().unwrap_or_else(std::sync::PoisonError::into_inner).set_size(rows, columns);
renderer.lock().unwrap_or_else(std::sync::PoisonError::into_inner).set_size(rows, columns);
if let Err(e) =
resize_tx.send(EncryptedFrame::Resize((resize_uuid, columns, rows))).await
{
error!("Failed to send resize frame: {e}");
break;
}
}
}
},
Err(e) => error!("Failed to register SIGWINCH handler: {e}"),
}
});
}
/// Windows counterpart of the resize watcher. With no SIGWINCH available, a
/// background thread polls the console size every 250 ms.
///
/// On a change it resizes the local emulator and renderer and queues a `Resize`
/// frame. It stops once `resize_token` is cancelled (checked before each poll)
/// or the frame channel is closed. A failed size query keeps the last known size.
#[cfg(windows)]
fn spawn_resize_handler(
    resize_tx: Sender<EncryptedFrame>,
    resize_uuid: UuidWrapper,
    resize_token: CancellationToken,
    emulator: Arc<std::sync::Mutex<Emulator>>,
    renderer: Arc<std::sync::Mutex<Renderer>>,
) {
    let _resize_handle = thread::spawn(move || {
        let query_size = || terminal_size().map(|(w, h)| (w.0, h.0));
        let mut last_size = query_size().unwrap_or((80, 24));
        while !resize_token.is_cancelled() {
            thread::sleep(Duration::from_millis(250));
            let size = query_size().unwrap_or(last_size);
            if size == last_size {
                continue;
            }
            last_size = size;
            let (columns, rows) = size;
            emulator
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .set_size(rows, columns);
            renderer
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .set_size(rows, columns);
            let frame = EncryptedFrame::Resize((resize_uuid, columns, rows));
            if let Err(e) = resize_tx.blocking_send(frame) {
                error!("Failed to send resize frame: {e}");
                break;
            }
        }
    });
}
/// Offers to create a passphrase-protected X25519 client keypair when the
/// configured private or public key file does not exist.
///
/// Does nothing if both files exist or the user declines. Otherwise it creates
/// the key directory, prompts for a non-empty passphrase (with confirmation),
/// writes both key files (the private key with mode 0o600 on Unix), and prints
/// the fingerprint and randomart.
///
/// NOTE(review): if only one of the two files exists, accepting the prompt
/// truncates and overwrites it.
///
/// # Errors
/// Propagates path resolution, prompt, key generation and file I/O errors.
fn maybe_generate_keypair(config: &Config) -> Result<()> {
    let (priv_key_path, pub_key_path) = config.key_pair_paths()?;
    let has_private = priv_key_path.try_exists()?;
    if has_private && pub_key_path.try_exists()? {
        return Ok(());
    }

    println!("No keypair found at the configured location.");
    println!(" Private key: {}", priv_key_path.display());
    println!(" Public key: {}", pub_key_path.display());
    let confirmed = Confirm::new()
        .with_prompt("Generate a new keypair now?")
        .default(true)
        .wait_for_newline(true)
        .interact()?;
    if !confirmed {
        return Ok(());
    }

    if let Some(key_dir) = priv_key_path.parent() {
        create_key_dir(key_dir)?;
    }
    let passphrase: String = Password::new()
        .with_prompt(format!(
            "Enter passphrase for \"{}\"",
            priv_key_path.display()
        ))
        .with_confirmation(
            "Enter same passphrase again",
            "Passphrases do not match. Try again.",
        )
        .allow_empty_password(false)
        .report(false)
        .interact()?;
    let keypair =
        KeyPair::generate_key_pair(Some(&passphrase), KexMode::Client, KEY_ALGORITHM_X25519)?;

    // Private key: owner read/write only on Unix.
    let mut private_options = OpenOptions::new();
    private_options.write(true).create(true).truncate(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        private_options.mode(0o600);
    }
    let mut priv_key_file = private_options.open(&priv_key_path)?;
    keypair.write_private_key(&mut priv_key_file)?;

    let mut pub_key_file = OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .open(&pub_key_path)?;
    keypair.write_public_key(&mut pub_key_file)?;

    println!(
        "Your identification has been saved in {}",
        priv_key_path.display()
    );
    println!(
        "Your public key has been saved in {}",
        pub_key_path.display()
    );
    println!("The key fingerprint is:");
    println!("{}", keypair.fingerprint()?);
    println!("The key's randomart image is:");
    print!("{}", keypair.randomart());
    Ok(())
}
/// Creates the key directory (and any missing parents) with owner-only
/// permissions (0o700).
#[cfg(target_family = "unix")]
fn create_key_dir(path: &Path) -> Result<()> {
    let mut builder = DirBuilder::new();
    builder.recursive(true).mode(0o700);
    builder.create(path)?;
    Ok(())
}

/// Creates the key directory (and any missing parents) with default permissions.
#[cfg(not(target_family = "unix"))]
fn create_key_dir(path: &Path) -> Result<()> {
    let mut builder = DirBuilder::new();
    builder.recursive(true);
    builder.create(path)?;
    Ok(())
}
/// Prompts on the terminal for the private-key passphrase.
///
/// A completed prompt always yields `Some`. Prompt failures are returned as errors.
fn read_passpharase() -> Result<Option<String>> {
    let entered = Password::new()
        .with_prompt("Please enter your private key passphrase")
        .report(false)
        .interact()?;
    Ok(Some(entered))
}
/// Identifies the controlling terminal (Unix) so session files are kept per
/// terminal.
///
/// Resolves the link behind stdin (`/proc/self/fd/0` on Linux, `/dev/fd/0`
/// elsewhere), strips leading `/`, and replaces every character that is not
/// alphanumeric or `-` with `_`. Returns `None` if stdin is not a terminal,
/// the link cannot be read, or the result is empty.
#[cfg(unix)]
fn tty_id() -> Option<String> {
    use std::io::IsTerminal as _;
    if !stdin().is_terminal() {
        return None;
    }
    #[cfg(target_os = "linux")]
    const STDIN_LINK: &str = "/proc/self/fd/0";
    #[cfg(not(target_os = "linux"))]
    const STDIN_LINK: &str = "/dev/fd/0";
    let target = std::fs::read_link(STDIN_LINK).ok()?;
    let sanitized: String = target
        .to_string_lossy()
        .trim_start_matches('/')
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '-' { c } else { '_' })
        .collect();
    (!sanitized.is_empty()).then_some(sanitized)
}
/// Identifies the current console (Windows) so session files are kept per
/// terminal.
///
/// Returns `None` when stdin is not a terminal or there is no console window.
/// Otherwise returns the console window handle's address formatted as
/// lowercase hex.
#[cfg(windows)]
// FFI call to a kernel32 function requires unsafe; it is confined to this function.
#[allow(unsafe_code)]
fn tty_id() -> Option<String> {
    use std::io::IsTerminal as _;
    unsafe extern "system" {
        // Win32: returns the window handle of the attached console, or null if none.
        fn GetConsoleWindow() -> *mut std::ffi::c_void;
    }
    if !stdin().is_terminal() {
        return None;
    }
    // SAFETY: GetConsoleWindow takes no arguments and has no preconditions. The
    // returned handle is only null-checked and formatted, never dereferenced.
    let hwnd = unsafe { GetConsoleWindow() };
    if hwnd.is_null() {
        None
    } else {
        Some(format!("{:x}", hwnd.addr()))
    }
}
/// Fallback for platforms without a terminal identifier. Session file names
/// then omit the terminal component.
#[cfg(not(any(unix, windows)))]
fn tty_id() -> Option<String> {
    None
}
/// Location of the persistent client identifier file: `<home>/.mp/client_id`.
fn client_id_path(home: &Path) -> PathBuf {
    let mut path = home.to_path_buf();
    path.push(".mp");
    path.push("client_id");
    path
}
/// Persistent client identifier stored under the user's home directory.
/// Returns `None` only if the home directory cannot be determined.
fn client_id() -> Option<Uuid> {
    let home = dirs2::home_dir()?;
    client_id_in_home(&home)
}
/// Reads the client identifier from `<home>/.mp/client_id`. If the file is
/// missing or does not hold a valid UUID, it generates a new v4 UUID and tries
/// to store it.
///
/// Storage is best effort: if the directory or file cannot be written, the
/// fresh UUID is still returned (so it will differ on the next call).
// Always `Some`; the Option return lets `client_id` forward it with `?`.
#[allow(clippy::unnecessary_wraps)]
fn client_id_in_home(home: &Path) -> Option<Uuid> {
    let path = client_id_path(home);
    let stored = std::fs::File::open(&path).ok().and_then(|mut file| {
        let mut contents = String::new();
        drop(file.read_to_string(&mut contents));
        contents.trim().parse::<Uuid>().ok()
    });
    if let Some(existing) = stored {
        return Some(existing);
    }

    let fresh = Uuid::new_v4();
    if let Some(dir) = path.parent() {
        drop(std::fs::create_dir_all(dir));
    }
    if let Ok(mut file) = std::fs::File::create(&path) {
        drop(write!(file, "{fresh}"));
    }
    Some(fresh)
}
fn session_file_path(host: &str, port: u16) -> Option<PathBuf> {
let home = dirs2::home_dir()?;
let safe_host: String = host
.chars()
.map(|c| {
if c.is_alphanumeric() || c == '-' || c == '.' {
c
} else {
'_'
}
})
.collect();
let cid = client_id()?;
let name = match tty_id() {
Some(tty) => format!("{cid}_{safe_host}_{port}_{tty}"),
None => format!("{cid}_{safe_host}_{port}"),
};
Some(home.join(".mp").join("sessions").join(name))
}
/// Test-only variant of `session_file_path` that uses an explicit `home`
/// directory instead of the user's real one.
///
/// Same naming scheme: `<home>/.mp/sessions/<client_id>_<host>_<port>[_<tty>]`.
#[cfg(test)]
fn session_file_path_in_home(home: &Path, host: &str, port: u16) -> Option<PathBuf> {
    let safe_host: String = host
        .chars()
        .map(|c| match c {
            c if c.is_alphanumeric() || c == '-' || c == '.' => c,
            _ => '_',
        })
        .collect();
    let cid = client_id_in_home(home)?;
    let mut name = format!("{cid}_{safe_host}_{port}");
    if let Some(tty) = tty_id() {
        name.push('_');
        name.push_str(&tty);
    }
    Some(home.join(".mp").join("sessions").join(name))
}
/// Reads a UUID from `path`, ignoring surrounding whitespace.
///
/// Returns `None` if the file cannot be opened or its contents do not parse
/// as a UUID. A read error is not fatal on its own; whatever was read is parsed.
fn read_uuid_from_path(path: &Path) -> Option<Uuid> {
    let mut contents = String::new();
    let mut file = std::fs::File::open(path).ok()?;
    drop(file.read_to_string(&mut contents));
    contents.trim().parse::<Uuid>().ok()
}
/// Writes `uuid` to `path`, creating parent directories as needed and
/// replacing any existing content.
///
/// On Unix, newly created directories get mode 0o700 and a newly created file
/// gets 0o600.
///
/// # Errors
/// Propagates directory creation, open and write errors.
fn write_uuid_to_path(path: &Path, uuid: Uuid) -> Result<()> {
    if let Some(dir) = path.parent() {
        let mut builder = DirBuilder::new();
        builder.recursive(true);
        #[cfg(unix)]
        {
            builder.mode(0o700);
        }
        builder.create(dir)?;
    }

    let mut options = OpenOptions::new();
    options.write(true).create(true).truncate(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        options.mode(0o600);
    }
    let mut file = options.open(path)?;
    write!(file, "{uuid}")?;
    Ok(())
}
fn read_session_uuid(host: &str, port: u16) -> Option<Uuid> {
read_uuid_from_path(&session_file_path(host, port)?)
}
fn write_session_uuid(host: &str, port: u16, session_uuid: Uuid) -> Result<()> {
let path = session_file_path(host, port).ok_or_else(|| anyhow::anyhow!("no home dir"))?;
write_uuid_to_path(&path, session_uuid)
}
#[cfg(test)]
mod tests {
use super::*;
use clap::Parser;
struct TestHome {
path: PathBuf,
}
impl TestHome {
fn new() -> Self {
let path = std::env::temp_dir().join(Uuid::new_v4().to_string());
std::fs::create_dir_all(&path).expect("failed to create temp dir");
Self { path }
}
fn path(&self) -> &Path {
&self.path
}
}
impl Drop for TestHome {
fn drop(&mut self) {
drop(std::fs::remove_dir_all(&self.path));
}
}
#[test]
fn test_pass_cache() {
let mut cache = PassCache::Uncached;
assert!(!cache.is_cached());
cache = PassCache::NoPassphrase;
assert!(cache.is_cached());
assert_eq!(cache.passphrase(), None);
cache = PassCache::Passphrase("secret".to_string());
assert!(cache.is_cached());
assert_eq!(cache.passphrase(), Some("secret".to_string()));
}
#[test]
#[should_panic(expected = "passphrase() called before caching")]
fn test_pass_cache_panic() {
let cache = PassCache::Uncached;
drop(cache.passphrase());
}
#[tokio::test]
async fn test_banners() -> Result<()> {
let (tx, mut rx) = channel(10);
show_reconnect_banner(&tx).await;
let msg = rx
.recv()
.await
.ok_or_else(|| anyhow::anyhow!("channel closed"))?;
assert!(
String::from_utf8_lossy(&msg).contains("[moshpit] server unreachable, reconnecting...")
);
clear_reconnect_banner(&tx).await;
let msg = rx
.recv()
.await
.ok_or_else(|| anyhow::anyhow!("channel closed"))?;
assert!(String::from_utf8_lossy(&msg).ends_with("\x1b[0m\x1b[K\x1b[u"));
let token = CancellationToken::new();
let _ = countdown_reconnect_banner(&tx, 0, 1, 10, &token).await;
let msg = rx
.recv()
.await
.ok_or_else(|| anyhow::anyhow!("channel closed"))?;
assert!(String::from_utf8_lossy(&msg).contains("attempt #1"));
Ok(())
}
#[tokio::test]
async fn countdown_banner_pre_cancelled_returns_true() {
let (tx, mut _rx) = channel(10);
let token = CancellationToken::new();
token.cancel();
let result = countdown_reconnect_banner(&tx, 0, 1, 10, &token).await;
assert!(result);
}
#[test]
fn read_uuid_from_path_missing_file_returns_none() {
let dir = std::env::temp_dir().join(Uuid::new_v4().to_string());
let path = dir.join("session");
assert!(read_uuid_from_path(&path).is_none());
}
/// A file whose contents are not a UUID is treated as "no UUID" (`None`).
#[test]
fn read_uuid_from_path_garbage_returns_none() {
    // `TestHome` removes its directory on drop, so the scratch file does not
    // linger in the system temp dir after the test (the previous version
    // created a random temp dir and never deleted it).
    let home = TestHome::new();
    let path = home.path().join("session");
    std::fs::write(&path, "not-a-uuid").expect("failed to write test file");
    assert!(read_uuid_from_path(&path).is_none());
}
/// A UUID written with `write_uuid_to_path` is read back unchanged.
#[test]
fn write_and_read_uuid_roundtrip() -> Result<()> {
    // Self-cleaning fixture: nothing is left behind in the temp dir.
    let home = TestHome::new();
    // `sub/` is not created beforehand; the writer has to create it.
    let path = home.path().join("sub").join("session");
    let uuid = Uuid::new_v4();
    write_uuid_to_path(&path, uuid)?;
    assert_eq!(read_uuid_from_path(&path), Some(uuid));
    Ok(())
}
/// `write_uuid_to_path` creates every missing parent directory of its target.
#[test]
fn write_uuid_creates_parent_directories() -> Result<()> {
    // Self-cleaning fixture: the nested directories are removed on drop.
    let home = TestHome::new();
    let nested = home
        .path()
        .join("a")
        .join("b")
        .join("c")
        .join("session");
    let uuid = Uuid::new_v4();
    write_uuid_to_path(&nested, uuid)?;
    assert!(nested.exists());
    Ok(())
}
/// The client id file lives at `<home>/.mp/.../client_id`.
#[test]
fn client_id_path_is_under_dot_mp() {
    let home = TestHome::new();
    let dot_mp = home.path().join(".mp");
    let id_path = client_id_path(home.path());
    assert!(id_path.starts_with(&dot_mp));
    let file_name = id_path.file_name().expect("path has a file name");
    assert_eq!(file_name, "client_id");
}
/// A client id is produced for a fresh home and is stable across lookups.
#[test]
fn test_client_id() {
    let home = TestHome::new();
    let first = client_id_in_home(home.path());
    assert!(first.is_some());
    // Looking it up again in the same home must give back the same id.
    let second = client_id_in_home(home.path());
    assert_eq!(first, second);
}
/// A session UUID stored at the per-destination session path can be read back.
///
/// The round trip goes through the crate's own `write_uuid_to_path` /
/// `read_uuid_from_path` helpers rather than a hand-rolled write/parse, so the
/// test exercises the real persistence code instead of re-implementing it.
#[test]
fn test_session_uuid_persistence() -> Result<()> {
    let home = TestHome::new();
    let uuid = Uuid::new_v4();
    let path = session_file_path_in_home(home.path(), "test.host", 12345)
        .ok_or_else(|| anyhow::anyhow!("no session file path"))?;
    // The writer also creates any missing parent directories of `path`.
    write_uuid_to_path(&path, uuid)?;
    assert_eq!(read_uuid_from_path(&path), Some(uuid));
    Ok(())
}
/// The session file path embeds both the host name and the port.
#[test]
fn test_session_file_path() -> Result<()> {
    let home = TestHome::new();
    let path = session_file_path_in_home(home.path(), "some_host.com", 2222)
        .ok_or_else(|| anyhow::anyhow!("no session file path"))?;
    let rendered = path.to_string_lossy();
    for needle in ["some_host.com", "2222"] {
        assert!(rendered.contains(needle));
    }
    Ok(())
}
/// `create_key_dir` creates the requested directory.
#[test]
fn test_create_key_dir() -> Result<()> {
    // Self-cleaning fixture instead of a leaked random temp dir.
    let home = TestHome::new();
    // `nested/` is not created beforehand, so the missing intermediate
    // directory must be created as well.
    let key_dir = home.path().join("nested").join("keys");
    create_key_dir(&key_dir)?;
    assert!(key_dir.exists());
    assert!(key_dir.is_dir());
    Ok(())
}
/// When both key files already exist, `maybe_generate_keypair` succeeds and
/// leaves them as they are.
#[test]
fn test_maybe_generate_keypair_existing() -> Result<()> {
    // Self-cleaning fixture: the config and fake keys are removed on drop.
    let home = TestHome::new();
    let dir = home.path();
    let priv_path = dir.join("id_ed25519");
    let pub_path = dir.join("id_ed25519.pub");
    let config_path = dir.join("config.toml");
    std::fs::write(&priv_path, "fake private key")?;
    std::fs::write(&pub_path, "fake public key")?;
    std::fs::write(
        &config_path,
        "[tracing.stdout]\n\
         with_target = false\n\
         with_thread_ids = false\n\
         with_thread_names = false\n\
         with_line_number = false\n\
         with_level = false\n\
         [tracing.file]\n\
         quiet = 0\n\
         verbose = 0\n\
         [tracing.file.layer]\n\
         with_target = false\n\
         with_thread_ids = false\n\
         with_thread_names = false\n\
         with_line_number = false\n\
         with_level = false\n",
    )?;
    let cli = Cli::try_parse_from([
        "moshpit",
        "-c",
        config_path.to_str().expect("path is valid UTF-8"),
        "-p",
        priv_path.to_str().expect("path is valid UTF-8"),
        "-k",
        pub_path.to_str().expect("path is valid UTF-8"),
        "user@host",
    ])?;
    let config = load::<Cli, Config, Cli>(&cli, &cli)?;
    // Propagate the error (instead of `assert!(result.is_ok())`) so a failure
    // reports *why* it failed.
    maybe_generate_keypair(&config)?;
    // Existing key material must not be overwritten.
    assert_eq!(std::fs::read_to_string(&priv_path)?, "fake private key");
    assert_eq!(std::fs::read_to_string(&pub_path)?, "fake public key");
    Ok(())
}
/// Dialing a port with no listener surfaces a "connection refused" error.
#[tokio::test]
async fn test_connect_and_kex_tcp_failure() -> Result<()> {
    let mut config = Config::default();
    let pass_cache = Arc::new(std::sync::Mutex::new(PassCache::Uncached));
    // Grab an ephemeral port, then release it (the probe listener is dropped
    // at the end of this block) so nothing is listening there any more.
    let port = {
        let probe = std::net::TcpListener::bind("127.0.0.1:0")?;
        probe.local_addr()?.port()
    };
    let addr = format!("127.0.0.1:{port}").parse()?;
    let outcome = connect_and_kex(
        &mut config,
        addr,
        "127.0.0.1",
        port,
        &pass_cache,
        Arc::new(AtomicBool::new(false)),
    )
    .await;
    let message = outcome
        .expect_err("connecting to a closed port must fail")
        .to_string()
        .to_lowercase();
    assert!(message.contains("refused"));
    Ok(())
}
#[tokio::test]
async fn test_connect_and_kex_kex_failure() -> Result<()> {
let dir = std::env::temp_dir().join(Uuid::new_v4().to_string());
std::fs::create_dir_all(&dir)?;
let config_path = dir.join("config.toml");
let empty_priv_key_path = dir.join("empty_priv_key");
let empty_pub_key_path = dir.join("empty_pub_key");
std::fs::write(&empty_priv_key_path, b"")?;
std::fs::write(&empty_pub_key_path, b"")?;
std::fs::write(
&config_path,
"[tracing.stdout]\n\
with_target = false\n\
with_thread_ids = false\n\
with_thread_names = false\n\
with_line_number = false\n\
with_level = false\n\
[tracing.file]\n\
quiet = 0\n\
verbose = 0\n\
[tracing.file.layer]\n\
with_target = false\n\
with_thread_ids = false\n\
with_thread_names = false\n\
with_line_number = false\n\
with_level = false\n",
)?;
let cli = Cli::try_parse_from([
"moshpit",
"-c",
config_path.to_str().expect("test path is valid UTF-8"),
"-p",
empty_priv_key_path
.to_str()
.expect("test path is valid UTF-8"),
"-k",
empty_pub_key_path
.to_str()
.expect("test path is valid UTF-8"),
"user@host",
])?;
let mut config = load::<Cli, Config, Cli>(&cli, &cli)?;
let pass_cache = Arc::new(std::sync::Mutex::new(PassCache::Uncached));
let listener = match tokio::net::TcpListener::bind("127.0.0.1:0").await {
Ok(listener) => listener,
Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => return Ok(()),
Err(e) => return Err(e.into()),
};
let port = listener.local_addr()?.port();
drop(spawn(async move {
use tokio::io::AsyncWriteExt;
if let Ok((mut socket, _)) = listener.accept().await {
drop(socket.write_all(b"SSH-2.0-Moshpit\r\n").await);
}
}));
let addr = format!("127.0.0.1:{port}").parse()?;
let result = connect_and_kex(
&mut config,
addr,
"127.0.0.1",
port,
&pass_cache,
Arc::new(AtomicBool::new(false)),
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.downcast_ref::<FatalKexError>().is_some(),
"empty key files should produce FatalKexError, got: {err}"
);
Ok(())
}
/// The `Display` form of `FatalKexError` mentions both the underlying error and
/// the offending key path.
#[test]
fn fatal_kex_error_display_includes_error_and_path() {
    use libmoshpit::MoshpitError;
    let display = FatalKexError {
        inner: MoshpitError::KeyFileMissing,
        key_path: PathBuf::from("/home/user/.mp/id_ed25519"),
    }
    .to_string();
    assert!(
        display.contains("Key file not found"),
        "display should contain error message, got: {display}"
    );
    assert!(
        display.contains("/home/user/.mp/id_ed25519"),
        "display should contain key path, got: {display}"
    );
}
/// Pointing the client at key files that do not exist yields a `FatalKexError`.
#[tokio::test]
async fn connect_and_kex_missing_key_file_wrapped_as_fatal_error() -> Result<()> {
    use clap::Parser as _;

    const CONFIG_TOML: &str = "[tracing.stdout]\n\
         with_target = false\n\
         with_thread_ids = false\n\
         with_thread_names = false\n\
         with_line_number = false\n\
         with_level = false\n\
         [tracing.file]\n\
         quiet = 0\n\
         verbose = 0\n\
         [tracing.file.layer]\n\
         with_target = false\n\
         with_thread_ids = false\n\
         with_thread_names = false\n\
         with_line_number = false\n\
         with_level = false\n";

    let home = TestHome::new();
    let root = home.path();
    let config_path = root.join("config.toml");
    let priv_path = root.join("nonexistent_id_ed25519");
    let pub_path = root.join("nonexistent_id_ed25519.pub");
    std::fs::write(&config_path, CONFIG_TOML)?;

    let cli = Cli::try_parse_from([
        "moshpit",
        "-c",
        config_path.to_str().expect("path is valid UTF-8"),
        "-p",
        priv_path.to_str().expect("path is valid UTF-8"),
        "-k",
        pub_path.to_str().expect("path is valid UTF-8"),
        "user@host",
    ])?;
    let mut config = load::<Cli, Config, Cli>(&cli, &cli)?;
    let pass_cache = Arc::new(std::sync::Mutex::new(PassCache::Uncached));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
    let port = listener.local_addr()?.port();
    // The peer only accepts a single connection and then drops it.
    drop(spawn(async move {
        let _ = listener.accept().await;
    }));

    let addr = format!("127.0.0.1:{port}").parse()?;
    let outcome = connect_and_kex(
        &mut config,
        addr,
        "127.0.0.1",
        port,
        &pass_cache,
        Arc::new(AtomicBool::new(false)),
    )
    .await;
    let err = outcome.expect_err("connect_and_kex must fail without key files");
    assert!(
        err.downcast_ref::<FatalKexError>().is_some(),
        "missing key file should produce FatalKexError, got: {err}"
    );
    Ok(())
}
/// Ctrl+digit aliases map to the same control codes as their punctuation
/// counterparts (e.g. Ctrl-6 behaves like Ctrl-^).
#[test]
fn ctrl_digit_aliases_produce_correct_control_codes() {
    use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, KeyEventState, KeyModifiers};
    let ctrl = |c: char| KeyEvent {
        code: KeyCode::Char(c),
        modifiers: KeyModifiers::CONTROL,
        kind: KeyEventKind::Press,
        state: KeyEventState::empty(),
    };
    let cases = [
        ('6', b"\x1e"),
        ('4', b"\x1c"),
        ('5', b"\x1d"),
        ('7', b"\x1f"),
    ];
    for (digit, expected) in cases {
        assert_eq!(key_event_to_bytes(ctrl(digit)), expected);
    }
}
/// Tests for `run_escape_listener`: only the Ctrl-^ (0x1E) followed by '.'
/// (0x2E) sequence may cancel the exit token.
mod escape_listener {
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Mutex;
    use tokio::sync::mpsc::{Receiver, channel};
    use tokio_util::sync::CancellationToken;
    use super::super::run_escape_listener;

    /// Generous upper bound for a single listener run. Every scenario below
    /// finishes almost immediately; the bound only exists so that a regression
    /// (e.g. the exit sequence no longer being recognised while the sender is
    /// still alive) fails the test instead of hanging the test binary forever.
    const LISTENER_TIMEOUT: Duration = Duration::from_secs(5);

    /// Drives `run_escape_listener` to completion, panicking if it has not
    /// returned within [`LISTENER_TIMEOUT`].
    async fn run_bounded(
        kb_rx: Arc<Mutex<Receiver<Vec<u8>>>>,
        exit_token: CancellationToken,
        done_token: CancellationToken,
    ) {
        tokio::time::timeout(
            LISTENER_TIMEOUT,
            run_escape_listener(kb_rx, exit_token, done_token),
        )
        .await
        .expect("run_escape_listener did not return within LISTENER_TIMEOUT");
    }

    #[tokio::test]
    async fn done_token_cancels_listener_without_triggering_exit() {
        let (tx, rx) = channel::<Vec<u8>>(8);
        let kb_rx = Arc::new(Mutex::new(rx));
        let exit_token = CancellationToken::new();
        let done_token = CancellationToken::new();
        done_token.cancel();
        run_bounded(kb_rx, exit_token.clone(), done_token).await;
        assert!(!exit_token.is_cancelled());
        // Keep the sender alive until here so only `done_token` can stop the run.
        drop(tx);
    }

    #[tokio::test]
    async fn sender_drop_stops_listener_without_triggering_exit() {
        let (tx, rx) = channel::<Vec<u8>>(8);
        let kb_rx = Arc::new(Mutex::new(rx));
        let exit_token = CancellationToken::new();
        let done_token = CancellationToken::new();
        drop(tx);
        run_bounded(kb_rx, exit_token.clone(), done_token).await;
        assert!(!exit_token.is_cancelled());
    }

    #[tokio::test]
    async fn normal_bytes_do_not_trigger_exit() -> anyhow::Result<()> {
        let (tx, rx) = channel::<Vec<u8>>(8);
        let kb_rx = Arc::new(Mutex::new(rx));
        let exit_token = CancellationToken::new();
        let done_token = CancellationToken::new();
        tx.send(b"hello".to_vec()).await?;
        drop(tx);
        run_bounded(kb_rx, exit_token.clone(), done_token).await;
        assert!(!exit_token.is_cancelled());
        Ok(())
    }

    #[tokio::test]
    async fn escape_prefix_then_non_dot_does_not_trigger_exit() -> anyhow::Result<()> {
        let (tx, rx) = channel::<Vec<u8>>(8);
        let kb_rx = Arc::new(Mutex::new(rx));
        let exit_token = CancellationToken::new();
        let done_token = CancellationToken::new();
        tx.send(vec![0x1E, b'x']).await?;
        drop(tx);
        run_bounded(kb_rx, exit_token.clone(), done_token).await;
        assert!(!exit_token.is_cancelled());
        Ok(())
    }

    #[tokio::test]
    async fn repeated_escape_prefix_stays_pending_without_triggering_exit() -> anyhow::Result<()>
    {
        let (tx, rx) = channel::<Vec<u8>>(8);
        let kb_rx = Arc::new(Mutex::new(rx));
        let exit_token = CancellationToken::new();
        let done_token = CancellationToken::new();
        tx.send(vec![0x1E, 0x1E, 0x1E]).await?;
        drop(tx);
        run_bounded(kb_rx, exit_token.clone(), done_token).await;
        assert!(!exit_token.is_cancelled());
        Ok(())
    }

    #[tokio::test]
    async fn full_sequence_in_one_chunk_triggers_exit() -> anyhow::Result<()> {
        let (tx, rx) = channel::<Vec<u8>>(8);
        let kb_rx = Arc::new(Mutex::new(rx));
        let exit_token = CancellationToken::new();
        let done_token = CancellationToken::new();
        tx.send(vec![0x1E, 0x2E]).await?;
        // `tx` stays alive: only recognising the sequence can end this run,
        // which is why the run is bounded by a timeout.
        run_bounded(kb_rx, exit_token.clone(), done_token).await;
        assert!(exit_token.is_cancelled());
        Ok(())
    }

    #[tokio::test]
    async fn sequence_split_across_sends_triggers_exit() -> anyhow::Result<()> {
        let (tx, rx) = channel::<Vec<u8>>(8);
        let kb_rx = Arc::new(Mutex::new(rx));
        let exit_token = CancellationToken::new();
        let done_token = CancellationToken::new();
        tx.send(vec![0x1E]).await?;
        tx.send(vec![0x2E]).await?;
        // `tx` stays alive here as well; see the test above.
        run_bounded(kb_rx, exit_token.clone(), done_token).await;
        assert!(exit_token.is_cancelled());
        Ok(())
    }
}
/// Tests for `key_event_to_bytes`: the byte sequences sent to the remote PTY
/// for each kind of key event.
mod key_encoding {
    use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, KeyEventState, KeyModifiers};
    use super::super::key_event_to_bytes;

    /// Builds a key event with an empty state.
    fn make_event(code: KeyCode, modifiers: KeyModifiers, kind: KeyEventKind) -> KeyEvent {
        KeyEvent {
            code,
            modifiers,
            kind,
            state: KeyEventState::empty(),
        }
    }

    /// An unmodified key press.
    fn press(code: KeyCode) -> KeyEvent {
        make_event(code, KeyModifiers::NONE, KeyEventKind::Press)
    }

    /// A key press with the given modifiers.
    fn press_mod(code: KeyCode, mods: KeyModifiers) -> KeyEvent {
        make_event(code, mods, KeyEventKind::Press)
    }

    /// An unmodified key release.
    fn release(code: KeyCode) -> KeyEvent {
        make_event(code, KeyModifiers::NONE, KeyEventKind::Release)
    }

    /// Asserts that `event` encodes to exactly `expected`; a failure is
    /// reported at the calling line.
    #[track_caller]
    fn expect_bytes(event: KeyEvent, expected: &[u8]) {
        assert_eq!(key_event_to_bytes(event), expected);
    }

    /// Asserts that `event` encodes to nothing at all.
    #[track_caller]
    fn expect_no_bytes(event: KeyEvent) {
        assert!(key_event_to_bytes(event).is_empty());
    }

    #[test]
    fn release_events_produce_no_bytes() {
        for code in [KeyCode::Char('a'), KeyCode::Up] {
            expect_no_bytes(release(code));
        }
    }

    #[test]
    fn arrow_keys_produce_csi_sequences() {
        expect_bytes(press(KeyCode::Up), b"\x1b[A");
        expect_bytes(press(KeyCode::Down), b"\x1b[B");
        expect_bytes(press(KeyCode::Right), b"\x1b[C");
        expect_bytes(press(KeyCode::Left), b"\x1b[D");
    }

    #[test]
    fn arrow_keys_with_shift_use_modifier_param() {
        let shift = KeyModifiers::SHIFT;
        expect_bytes(press_mod(KeyCode::Up, shift), b"\x1b[1;2A");
        expect_bytes(press_mod(KeyCode::Down, shift), b"\x1b[1;2B");
        expect_bytes(press_mod(KeyCode::Right, shift), b"\x1b[1;2C");
        expect_bytes(press_mod(KeyCode::Left, shift), b"\x1b[1;2D");
    }

    #[test]
    fn arrow_keys_with_ctrl_use_modifier_param() {
        let ctrl = KeyModifiers::CONTROL;
        expect_bytes(press_mod(KeyCode::Up, ctrl), b"\x1b[1;5A");
        expect_bytes(press_mod(KeyCode::Left, ctrl), b"\x1b[1;5D");
    }

    #[test]
    fn navigation_keys() {
        expect_bytes(press(KeyCode::Home), b"\x1b[H");
        expect_bytes(press(KeyCode::End), b"\x1b[F");
        expect_bytes(press(KeyCode::Insert), b"\x1b[2~");
        expect_bytes(press(KeyCode::Delete), b"\x1b[3~");
        expect_bytes(press(KeyCode::PageUp), b"\x1b[5~");
        expect_bytes(press(KeyCode::PageDown), b"\x1b[6~");
    }

    #[test]
    fn navigation_keys_with_modifier() {
        let ctrl = KeyModifiers::CONTROL;
        expect_bytes(press_mod(KeyCode::Home, ctrl), b"\x1b[1;5H");
        expect_bytes(press_mod(KeyCode::End, ctrl), b"\x1b[1;5F");
        expect_bytes(press_mod(KeyCode::Insert, ctrl), b"\x1b[2;5~");
        expect_bytes(press_mod(KeyCode::Delete, ctrl), b"\x1b[3;5~");
        expect_bytes(press_mod(KeyCode::PageUp, ctrl), b"\x1b[5;5~");
        expect_bytes(press_mod(KeyCode::PageDown, ctrl), b"\x1b[6;5~");
    }

    #[test]
    fn function_keys() {
        // Expected encodings for F1..=F12, in order.
        let expected: [&[u8]; 12] = [
            b"\x1bOP",
            b"\x1bOQ",
            b"\x1bOR",
            b"\x1bOS",
            b"\x1b[15~",
            b"\x1b[17~",
            b"\x1b[18~",
            b"\x1b[19~",
            b"\x1b[20~",
            b"\x1b[21~",
            b"\x1b[23~",
            b"\x1b[24~",
        ];
        for (number, bytes) in (1..=12).zip(expected) {
            expect_bytes(press(KeyCode::F(number)), bytes);
        }
    }

    #[test]
    fn function_keys_out_of_range_produce_no_bytes() {
        for number in [0, 13] {
            expect_no_bytes(press(KeyCode::F(number)));
        }
    }

    #[test]
    fn function_keys_with_modifier() {
        let shift = KeyModifiers::SHIFT;
        expect_bytes(press_mod(KeyCode::F(1), shift), b"\x1b[1;2P");
        expect_bytes(press_mod(KeyCode::F(5), shift), b"\x1b[15;2~");
        expect_bytes(press_mod(KeyCode::F(12), shift), b"\x1b[24;2~");
    }

    #[test]
    fn simple_keys() {
        expect_bytes(press(KeyCode::Backspace), b"\x7f");
        expect_bytes(press(KeyCode::Enter), b"\r");
        expect_bytes(press(KeyCode::Tab), b"\t");
        expect_bytes(press(KeyCode::BackTab), b"\x1b[Z");
        expect_bytes(press(KeyCode::Esc), b"\x1b");
        expect_bytes(press(KeyCode::Null), b"\x00");
    }

    #[test]
    fn printable_chars() {
        expect_bytes(press(KeyCode::Char('a')), b"a");
        expect_bytes(press(KeyCode::Char('Z')), b"Z");
        expect_bytes(press(KeyCode::Char('!')), b"!");
    }

    #[test]
    fn non_ascii_char_encodes_utf8() {
        expect_bytes(press(KeyCode::Char('\u{00e9}')), "\u{00e9}".as_bytes());
    }

    #[test]
    fn ctrl_chars_produce_control_codes() {
        let ctrl = KeyModifiers::CONTROL;
        expect_bytes(press_mod(KeyCode::Char('a'), ctrl), b"\x01");
        expect_bytes(press_mod(KeyCode::Char('c'), ctrl), b"\x03");
        expect_bytes(press_mod(KeyCode::Char('z'), ctrl), b"\x1a");
        expect_bytes(press_mod(KeyCode::Char('@'), ctrl), b"\x00");
        expect_bytes(press_mod(KeyCode::Char('['), ctrl), b"\x1b");
        expect_bytes(press_mod(KeyCode::Char('^'), ctrl), b"\x1e");
    }

    #[test]
    fn ctrl_non_ascii_encodes_utf8_fallback() {
        let ctrl = KeyModifiers::CONTROL;
        expect_bytes(
            press_mod(KeyCode::Char('\u{00e9}'), ctrl),
            "\u{00e9}".as_bytes(),
        );
    }

    #[test]
    fn alt_chars_prefix_with_escape() {
        let alt = KeyModifiers::ALT;
        expect_bytes(press_mod(KeyCode::Char('a'), alt), b"\x1ba");
        expect_bytes(press_mod(KeyCode::Char('z'), alt), b"\x1bz");
    }

    #[test]
    fn ctrl_alt_chars_prefix_with_escape_and_control_code() {
        let ctrl_alt = KeyModifiers::CONTROL | KeyModifiers::ALT;
        expect_bytes(press_mod(KeyCode::Char('a'), ctrl_alt), b"\x1b\x01");
    }

    #[test]
    fn ctrl_alt_non_ascii_utf8_fallback() {
        let ctrl_alt = KeyModifiers::CONTROL | KeyModifiers::ALT;
        // ESC prefix followed by the raw UTF-8 encoding of the character.
        let expected = [b"\x1b".as_slice(), "\u{00e9}".as_bytes()].concat();
        expect_bytes(press_mod(KeyCode::Char('\u{00e9}'), ctrl_alt), &expected);
    }
}
}