use std::{
collections::{BTreeSet, VecDeque},
ffi::OsString,
io::Read as _,
net::SocketAddr,
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
},
thread::{self, sleep},
time::{Duration, SystemTime, UNIX_EPOCH},
};
#[cfg(unix)]
use std::{
ffi::{CStr, CString},
os::unix::{fs::OpenOptionsExt as _, process::CommandExt},
process::Stdio,
};
use anyhow::{Context as _, Result};
use bytes::{Buf as _, BytesMut};
use clap::Parser as _;
use libmoshpit::{
DiffMode, EncryptedFrame, KexMode, MAX_UDP_PAYLOAD, MoshpitError, SessionRegistry,
TerminalMessage, UdpReader, UdpSender, UuidWrapper, init_tracing, is_exit_title, load,
new_session_registry, run_key_exchange,
};
#[cfg(windows)]
use portable_pty::CommandBuilder;
use portable_pty::{PtySize, native_pty_system};
use tokio::{
net::{TcpListener, TcpStream},
select, spawn,
sync::{
Mutex,
mpsc::{Receiver, Sender, channel},
oneshot,
},
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, trace};
use uuid::Uuid;
use zstd::encode_all;
use crate::{
cli::Cli,
config::Config,
session::{
FullSessionRegistry, SCROLLBACK_CAPACITY, SessionOutputHandle, SessionRecord,
new_full_registry,
},
};
// Output pacing and loss recovery -------------------------------------------

/// Pause between consecutive chunked UDP frames of one PTY read, in
/// microseconds, used when the configuration does not set `pacing_delay_us`.
const DEFAULT_PACING_DELAY_US: u64 = 1_000;
/// NAKs received within one 200 ms watchdog tick at which an unsolicited
/// full-screen repaint is pushed to the client.
const PROACTIVE_REPAINT_NAK_THRESHOLD: u64 = 10;

// Adaptive MTU probing ------------------------------------------------------

/// Candidate effective MTUs in bytes, lowest first; probing walks through them.
const MTU_TIERS: &[usize] = &[1_200, 1_300, 1_348];
/// How often the MTU prober samples the NAK counter.
const MTU_POLL_INTERVAL: Duration = Duration::from_millis(200);
/// Consecutive NAK-free polls required before probing the next tier up.
const MTU_PROBE_QUIET_TICKS: u32 = 300;
/// Polls a probe must survive before the new tier is treated as settled.
const MTU_PROBE_SUCCESS_TICKS: u32 = 150;
/// NAKs within a single poll that abort a probe and drop back one tier.
const MTU_PROBE_FAIL_THRESHOLD: u64 = 3;

// Screen synchronisation ----------------------------------------------------

/// Screen-sync poll period while PTY output is quiet.
const SCREEN_SYNC_IDLE_INTERVAL: Duration = Duration::from_millis(50);
/// Screen-sync poll period while PTY output is busy.
const SCREEN_SYNC_BURST_INTERVAL: Duration = Duration::from_millis(10);
/// Dirty-counter increments per poll at which the burst interval is used.
const SCREEN_SYNC_BURST_DIRTY_THRESHOLD: u64 = 5;
/// Period of the unconditional full-screen repaint in `DiffMode::Datagram`.
const DATAGRAM_REPAINT_INTERVAL: Duration = Duration::from_millis(150);
/// Tick period of the StateSync diff loop.
const STATESYNC_INTERVAL: Duration = Duration::from_millis(50);
/// Maximum number of unacknowledged StateSync snapshots kept for ack lookup.
const STATESYNC_HISTORY_LEN: usize = 32;
/// Largest compressed payload sent as a single frame: bigger StateSync diffs
/// fall back to a full state, and bigger full states are split into chunks.
const MAX_STATESYNC_DIFF_BYTES: usize = 900;
/// Size of each piece when a compressed full state is chunked.
const STATE_CHUNK_SIZE: usize = 800;

// Liveness ------------------------------------------------------------------

/// The connection is cancelled once the UDP reader's last-receive timestamp is
/// older than this many microseconds (30 s).
const CLIENT_SILENCE_TIMEOUT_US: u64 = 30 * 1_000_000;
/// Current wall-clock time as microseconds since the Unix epoch.
///
/// Never panics: a clock set before 1970 counts as zero elapsed time, and a
/// value that does not fit in `u64` also collapses to 0.
fn now_micros() -> u64 {
    let elapsed = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        Err(_) => Duration::ZERO,
    };
    u64::try_from(elapsed.as_micros()).unwrap_or(0)
}
/// Entry point of the moshpit server; runs until Ctrl-C.
///
/// Parses the command line (from `args` when given, otherwise from the process
/// arguments), loads the configuration, initialises tracing, and binds a TCP
/// listener on the configured `mps` IP address and port. Each accepted
/// connection is served by `handle_connection` on its own spawned task.
///
/// Before accepting, the following are installed into the configuration: a UDP
/// port pool holding ports 50000..60000 (end exclusive), and the libmoshpit
/// session registry. A `FullSessionRegistry` shared by all connections is
/// created as well.
///
/// # Errors
/// Returns an error if argument parsing, configuration loading, tracing setup,
/// parsing the configured IP address, or binding the listener fails. Errors from
/// `accept`, `local_addr`, or individual connections are only logged.
#[allow(unsafe_code)]
#[cfg_attr(coverage_nightly, coverage(off))]
pub(crate) async fn run<I, T>(args: Option<I>) -> Result<()>
where
    I: IntoIterator<Item = T>,
    T: Into<OsString> + Clone,
{
    // Use the explicit argument list when supplied, otherwise the process argv.
    let cli = if let Some(args) = args {
        Cli::try_parse_from(args)?
    } else {
        Cli::try_parse()?
    };
    #[cfg(unix)]
    // SAFETY: `getuid` has no preconditions and cannot fail.
    if unsafe { libc::getuid() } == 0 {
        info!("Running as root (multi-user mode enabled)");
    }
    let mut config =
        load::<Cli, Config, Cli>(&cli, &cli).with_context(|| MoshpitError::ConfigLoad)?;
    // CLI verbosity flags override the file-tracing settings from the config.
    let mut file_tracing = config.tracing().file().clone();
    let _ = file_tracing.set_verbose(cli.verbose());
    let _ = file_tracing.set_quiet(cli.quiet());
    init_tracing(&config, &file_tracing, &cli, None).with_context(|| MoshpitError::TracingInit)?;
    info!("Configuration loaded");
    info!("Tracing initialized");
    let socket_addr = SocketAddr::new(
        config
            .mps()
            .ip()
            .parse()
            .with_context(|| MoshpitError::InvalidIpAddress)?,
        config.mps().port(),
    );
    let _ = config.set_mode(KexMode::Server(socket_addr));
    let listener = TcpListener::bind(socket_addr).await?;
    // UDP ports handed out to sessions; ports are returned to this pool when a
    // session's PTY reader shuts down.
    let mut port_pool = BTreeSet::new();
    for i in 50000..60000 {
        let _ = port_pool.insert(i);
    }
    let port_pool_arc = Arc::new(Mutex::new(port_pool));
    let _ = config.set_port_pool(port_pool_arc);
    let session_registry = new_session_registry();
    let _ = config.set_session_registry(session_registry);
    let full_registry = new_full_registry();
    // Cancelled on Ctrl-C; every connection's shutdown watcher listens on it.
    let server_token = CancellationToken::new();
    loop {
        let config_c = config.clone();
        let st = server_token.clone();
        let fr_c = full_registry.clone();
        select! {
            _ = tokio::signal::ctrl_c() => {
                info!("Received Ctrl-C, shutting down server");
                server_token.cancel();
                // Grace period so connection watchdogs can queue their Shutdown
                // frames (they wait 100 ms before cancelling) before we exit.
                tokio::time::sleep(Duration::from_millis(300)).await;
                break;
            }
            accept_res = listener.accept() => {
                match accept_res {
                    Ok((socket, _addr)) => {
                        // Each connection's key exchange is told the local address
                        // the client actually connected to.
                        let tcp_local_addr = match socket.local_addr() {
                            Ok(a) => a,
                            Err(e) => { error!("local_addr: {e}"); continue; }
                        };
                        let mut config_conn = config_c;
                        let _ = config_conn.set_mode(KexMode::Server(tcp_local_addr));
                        let _conn = spawn(async move {
                            if let Err(e) = handle_connection(config_conn, socket, st, fr_c).await {
                                error!("{e}");
                            }
                        });
                    }
                    Err(e) => error!("{e}"),
                }
            }
        }
    }
    Ok(())
}
/// Binds a freshly key-exchanged connection to a terminal session.
///
/// * Resume (`skex.is_resume()`) of a session still present in `full_registry`:
///   the shared session state is reused. `diff_in_flight` is reset, the previous
///   connection's token is cancelled, and the output handle is re-pointed at this
///   connection (kex uuid, `data_tx`, `control_tx`, `conn_token`, `udp_port`). A
///   zstd-compressed snapshot of the emulator screen is then queued on `data_tx`.
///   The returned terminal receiver is `None` because the PTY is already running.
/// * Fresh session, or resume of a session that is no longer registered: new
///   state is created and registered by `new_session`, and the receiver is `Some`.
///
/// The tuple is `(term_tx, term_rx, output_handle, scrollback, server_emulator,
/// dirty_counter, diff_in_flight, effective_mtu)`.
///
/// NOTE(review): on resume the handle's previous `udp_port` is overwritten
/// without being returned to the port pool. Confirm the old port is reclaimed
/// elsewhere.
///
/// # Errors
/// Fails if the resume snapshot cannot be queued on `data_tx` or if
/// `new_session` fails.
async fn resolve_session(
    kex: &libmoshpit::Kex,
    skex: &libmoshpit::ServerKex,
    conn_token: &CancellationToken,
    udp_port: u16,
    data_tx: Sender<EncryptedFrame>,
    control_tx: Sender<EncryptedFrame>,
    full_registry: &FullSessionRegistry,
) -> Result<(
    Sender<TerminalMessage>,
    Option<Receiver<TerminalMessage>>,
    Arc<Mutex<SessionOutputHandle>>,
    Arc<Mutex<VecDeque<u8>>>,
    Arc<Mutex<vt100::Parser>>,
    Arc<AtomicU64>,
    Arc<AtomicBool>,
    Arc<AtomicUsize>,
)> {
    let session_uuid = skex.session_uuid();

    if !skex.is_resume() {
        let created = new_session(
            kex,
            conn_token,
            udp_port,
            session_uuid,
            data_tx,
            control_tx,
            full_registry,
        )
        .await?;
        info!(
            user = skex.user(),
            session = %session_uuid,
            "new session started"
        );
        return Ok(created);
    }

    // Copy the shared handles out so the registry lock is released before any
    // further awaiting.
    let existing = {
        let registry = full_registry.lock().await;
        registry.get(&session_uuid).map(|record| {
            (
                record.term_tx.clone(),
                record.output_handle.clone(),
                record.scrollback.clone(),
                record.server_emulator.clone(),
                record.dirty_counter.clone(),
                record.diff_in_flight.clone(),
                record.effective_mtu.clone(),
            )
        })
    };

    let Some((
        term_tx,
        output_handle,
        scrollback,
        server_emulator,
        dirty_counter,
        diff_in_flight,
        effective_mtu,
    )) = existing
    else {
        info!(
            user = skex.user(),
            session = %session_uuid,
            "previous session expired, starting new session"
        );
        return new_session(
            kex,
            conn_token,
            udp_port,
            session_uuid,
            data_tx,
            control_tx,
            full_registry,
        )
        .await;
    };

    diff_in_flight.store(false, Ordering::Relaxed);
    {
        let mut handle = output_handle.lock().await;
        // Only one connection may drive a session: evict the previous one.
        if let Some(previous) = handle.conn_token.take() {
            previous.cancel();
        }
        handle.kex_uuid = kex.uuid();
        handle.data_tx = Some(data_tx.clone());
        handle.control_tx = Some(control_tx.clone());
        handle.conn_token = Some(conn_token.clone());
        handle.udp_port = Some(udp_port);
    }

    let snapshot = {
        let emulator = server_emulator.lock().await;
        emulator.screen().contents_formatted()
    };
    let screen_state_bytes = snapshot.len();
    let payload = encode_all(snapshot.as_slice(), 3).unwrap_or_else(|_| snapshot.clone());
    data_tx
        .send(EncryptedFrame::ScreenStateCompressed(payload))
        .await?;
    info!(
        user = skex.user(),
        session = %session_uuid,
        screen_state_bytes,
        "session resumed"
    );
    Ok((
        term_tx,
        None,
        output_handle,
        scrollback,
        server_emulator,
        dirty_counter,
        diff_in_flight,
        effective_mtu,
    ))
}
#[cfg_attr(nightly, allow(clippy::too_many_lines))]
#[cfg_attr(coverage_nightly, coverage(off))]
async fn handle_connection(
config: Config,
socket: TcpStream,
server_token: CancellationToken,
full_registry: FullSessionRegistry,
) -> Result<()> {
let (sock_read, sock_write) = socket.into_split();
let port_pool = config.port_pool();
let session_registry = config.session_registry();
let warmup_delay = config.warmup_delay_ms().map(Duration::from_millis);
let pacing_delay =
Duration::from_micros(config.pacing_delay_us().unwrap_or(DEFAULT_PACING_DELAY_US));
let term_type = config.term_type().clone();
let (kex, udp_arc, skex_opt) =
run_key_exchange(config, sock_read, sock_write, || Ok(None), None, None).await?;
info!("Key exchange completed with moshpit");
let skex = skex_opt.ok_or_else(|| anyhow::anyhow!("missing server kex info"))?;
let session_uuid = skex.session_uuid();
let diff_mode = skex.diff_mode();
let udp_port = udp_arc.local_addr()?.port();
let (data_tx, data_rx) = channel::<EncryptedFrame>(256);
let (control_tx, control_rx) = channel::<EncryptedFrame>(16);
let (retransmit_tx, retransmit_rx) = channel::<Vec<u64>>(512);
let udp_recv = udp_arc.clone();
let udp_send = udp_arc.clone();
let conn_token = CancellationToken::new();
let (peer_discovered_tx, peer_discovered_rx) = oneshot::channel::<SocketAddr>();
let (peer_addr_tx, peer_addr_rx) = channel::<SocketAddr>(4);
let (
term_tx,
maybe_term_rx,
output_handle,
scrollback,
server_emulator,
dirty_counter,
diff_in_flight,
effective_mtu,
) = resolve_session(
&kex,
&skex,
&conn_token,
udp_port,
data_tx.clone(),
control_tx.clone(),
&full_registry,
)
.await?;
let (repaint_tx, mut repaint_rx) = channel::<()>(1);
let (client_ack_tx, mut client_ack_rx) = channel::<u64>(16);
let nak_received_count = Arc::new(AtomicU64::new(0));
let nak_received_count_for_mtu = nak_received_count.clone();
let last_rx_us = Arc::new(AtomicU64::new(now_micros()));
let mac_tag_len = kex.mac_tag_len();
let mut udp_reader = UdpReader::builder()
.socket(udp_recv)
.id(kex.uuid())
.hmac(kex.build_hmac())
.rnk(kex.build_aead_key()?)
.mac_tag_len(mac_tag_len)
.nak_out_tx(data_tx.clone())
.retransmit_tx(retransmit_tx)
.peer_discovered_tx(peer_discovered_tx)
.peer_addr_tx(peer_addr_tx)
.repaint_tx(repaint_tx)
.nak_received_count(nak_received_count.clone())
.diff_mode(diff_mode)
.client_ack_tx(client_ack_tx)
.last_rx_us(last_rx_us.clone())
.build();
let mut udp_sender = UdpSender::builder()
.socket(udp_send)
.control_rx(control_rx)
.rx(data_rx)
.retransmit_rx(retransmit_rx)
.id(kex.uuid())
.hmac(kex.build_hmac())
.rnk(kex.build_aead_key()?)
.peer_discovered_rx(peer_discovered_rx)
.peer_addr_rx(peer_addr_rx)
.maybe_warmup_delay(warmup_delay)
.diff_mode(diff_mode)
.build();
let reader_token = conn_token.clone();
let term_tx_c = term_tx.clone();
let _udp_reader_handle = spawn(async move {
if let Err(e) = udp_reader.server_frame_loop(reader_token, term_tx_c).await {
error!("{e}");
}
});
let sender_token = conn_token.clone();
let _udp_handle = spawn(async move { udp_sender.frame_loop(sender_token).await });
spawn_connection_watchdogs(control_tx.clone(), conn_token.clone(), server_token);
spawn_silence_watchdog(conn_token.clone(), last_rx_us);
if diff_mode == DiffMode::StateSync {
let ss_emu = server_emulator.clone();
let ss_tx = data_tx.clone();
let ss_token = conn_token.clone();
let _state_sync = spawn(async move {
let mut ticker = tokio::time::interval(STATESYNC_INTERVAL);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut sent_states: VecDeque<(u64, Vec<u8>)> = VecDeque::new();
let mut ack_diff_id: u64 = 0;
let mut ack_state: Vec<u8> = Vec::new();
let mut diff_counter: u64 = 0;
let mut last_current: Vec<u8> = Vec::new();
let mut ack_dirty = true;
loop {
select! {
() = ss_token.cancelled() => break,
msg = repaint_rx.recv() => {
if msg.is_none() {
break;
}
while repaint_rx.try_recv().is_ok() {}
let (contents, is_alt) = {
let emu = ss_emu.lock().await;
let screen = emu.screen();
(screen.contents_formatted(), screen.alternate_screen())
};
let compressed = encode_all(contents.as_slice(), 3)
.unwrap_or_else(|_| contents.clone());
if !send_state_chunked(&ss_tx, compressed).await {
break;
}
let mut ack = contents;
if is_alt {
let mut prefixed = b"\x1b[?1049h".to_vec();
prefixed.extend_from_slice(&ack);
ack = prefixed;
}
ack_state = ack;
ack_diff_id = 0;
sent_states.clear();
ack_dirty = true;
}
diff_id = client_ack_rx.recv() => {
let Some(diff_id) = diff_id else { break; };
if let Some(pos) = sent_states.iter().position(|(id, _)| *id == diff_id) {
let snapshot = sent_states[pos].1.clone();
ack_state = snapshot;
ack_diff_id = diff_id;
drop(sent_states.drain(..=pos));
ack_dirty = true;
}
}
_ = ticker.tick() => {
let (current, rows, cols, is_alt) = {
let emu = ss_emu.lock().await;
let screen = emu.screen();
let formatted = screen.contents_formatted();
let (r, c) = screen.size();
let alt = screen.alternate_screen();
(formatted, r, c, alt)
};
if current == last_current && !ack_dirty && ack_diff_id == diff_counter {
continue;
}
let current_for_cache = current.clone();
let mut ack_parser = vt100::Parser::new(rows, cols, 0);
if !ack_state.is_empty() {
ack_parser.process(&ack_state);
}
let ack_is_alt = ack_parser.screen().alternate_screen();
let mut cur_parser = vt100::Parser::new(rows, cols, 0);
cur_parser.process(¤t);
let mut diff = Vec::new();
if is_alt && !ack_is_alt {
diff.extend_from_slice(b"\x1b[?1049h");
} else if !is_alt && ack_is_alt {
diff.extend_from_slice(b"\x1b[?1049l");
}
let content_diff = cur_parser.screen().contents_diff(ack_parser.screen());
if content_diff.is_empty() && diff.is_empty() {
last_current = current_for_cache;
ack_dirty = false;
continue;
}
diff.extend_from_slice(&content_diff);
let compressed = encode_all(diff.as_slice(), 1)
.unwrap_or_else(|_| diff.clone());
if compressed.len() > MAX_STATESYNC_DIFF_BYTES {
let full_compressed = encode_all(current.as_slice(), 3)
.unwrap_or_else(|_| current.clone());
if !send_state_chunked(&ss_tx, full_compressed).await {
break;
}
let mut ack = current;
if is_alt {
let mut prefixed = b"\x1b[?1049h".to_vec();
prefixed.extend_from_slice(&ack);
ack = prefixed;
}
ack_state = ack;
ack_diff_id = 0;
sent_states.clear();
} else {
diff_counter += 1;
if ss_tx
.send(EncryptedFrame::StateSyncDiff((ack_diff_id, diff_counter, compressed)))
.await
.is_err()
{
break;
}
let mut snapshot = current;
if is_alt {
let mut prefixed = b"\x1b[?1049h".to_vec();
prefixed.extend_from_slice(&snapshot);
snapshot = prefixed;
}
sent_states.push_back((diff_counter, snapshot));
if sent_states.len() > STATESYNC_HISTORY_LEN {
drop(sent_states.pop_front());
}
}
last_current = current_for_cache;
ack_dirty = false;
}
}
}
});
} else {
let sync_emu = server_emulator.clone();
let sync_tx = data_tx.clone();
let sync_token = conn_token.clone();
let sync_dirty = dirty_counter.clone();
let sync_diff = diff_in_flight.clone();
let _screen_sync = spawn(async move {
let mut last_dirty: u64 = 0;
let mut interval = SCREEN_SYNC_IDLE_INTERVAL;
loop {
select! {
() = sync_token.cancelled() => break,
() = tokio::time::sleep(interval) => {
let current = sync_dirty.load(Ordering::Relaxed);
let delta = current.wrapping_sub(last_dirty);
interval = if delta >= SCREEN_SYNC_BURST_DIRTY_THRESHOLD {
SCREEN_SYNC_BURST_INTERVAL
} else {
SCREEN_SYNC_IDLE_INTERVAL
};
if delta == 0 {
continue;
}
if sync_diff.swap(false, Ordering::Relaxed) {
last_dirty = current;
continue;
}
last_dirty = current;
let contents = {
let emu = sync_emu.lock().await;
emu.screen().contents_formatted()
};
let compressed = encode_all(contents.as_slice(), 3)
.unwrap_or_else(|_| contents.clone());
if sync_tx.send(EncryptedFrame::ScreenStateCompressed(compressed)).await.is_err() {
break;
}
}
}
}
});
let repaint_emu = server_emulator.clone();
let repaint_tx_out = data_tx.clone();
let repaint_token = conn_token.clone();
let _repaint_on_request = spawn(async move {
loop {
select! {
() = repaint_token.cancelled() => break,
msg = repaint_rx.recv() => {
if msg.is_none() {
break;
}
while repaint_rx.try_recv().is_ok() {}
let contents = {
let emu = repaint_emu.lock().await;
emu.screen().contents_formatted()
};
let compressed = encode_all(contents.as_slice(), 3)
.unwrap_or_else(|_| contents.clone());
if repaint_tx_out.send(EncryptedFrame::ScreenStateCompressed(compressed)).await.is_err() {
break;
}
}
}
}
});
if diff_mode == DiffMode::Datagram {
let datagram_emu = server_emulator.clone();
let datagram_tx = data_tx.clone();
let datagram_token = conn_token.clone();
let _datagram_repaint = spawn(async move {
loop {
select! {
() = datagram_token.cancelled() => break,
() = tokio::time::sleep(DATAGRAM_REPAINT_INTERVAL) => {
let contents = {
let emu = datagram_emu.lock().await;
emu.screen().contents_formatted()
};
let compressed = encode_all(contents.as_slice(), 3)
.unwrap_or_else(|_| contents.clone());
if datagram_tx.send(EncryptedFrame::ScreenStateCompressed(compressed)).await.is_err() {
break;
}
}
}
}
});
}
}
spawn_proactive_repaint_watchdog(
data_tx.clone(),
conn_token.clone(),
nak_received_count,
server_emulator.clone(),
);
spawn_mtu_probe_task(
conn_token.clone(),
nak_received_count_for_mtu,
effective_mtu.clone(),
);
if let Some(term_rx) = maybe_term_rx {
spawn_pty(
session_uuid,
skex.user().to_owned(),
skex.shell().to_owned(),
term_rx,
term_tx.clone(),
output_handle,
scrollback,
server_emulator,
dirty_counter,
diff_in_flight,
pacing_delay,
term_type,
port_pool,
session_registry,
full_registry,
effective_mtu,
diff_mode,
);
}
Ok(())
}
/// Spawns the per-connection control-channel watchdogs.
///
/// * Shutdown watcher: when `server_token` is cancelled, it queues
///   `EncryptedFrame::Shutdown` on `control_tx`, waits 100 ms, then cancels
///   `conn_token`. If the connection ends first (`conn_token` cancelled), the
///   watcher exits immediately. Otherwise every finished connection would leave a
///   task, and a live `control_tx` clone, behind until server shutdown.
/// * Keepalive: every 3 s it sends `EncryptedFrame::Keepalive` carrying the
///   current Unix time in microseconds. It stops when `conn_token` is cancelled or
///   the control channel closes.
fn spawn_connection_watchdogs(
    control_tx: Sender<EncryptedFrame>,
    conn_token: CancellationToken,
    server_token: CancellationToken,
) {
    let watcher_tx = control_tx.clone();
    let watcher_conn_token = conn_token.clone();
    let _shutdown_watcher = spawn(async move {
        select! {
            () = watcher_conn_token.cancelled() => {}
            () = server_token.cancelled() => {
                drop(watcher_tx.send(EncryptedFrame::Shutdown).await);
                // Give the sender a moment to flush Shutdown before tearing down.
                tokio::time::sleep(Duration::from_millis(100)).await;
                watcher_conn_token.cancel();
            }
        }
    });
    let _keepalive = spawn(async move {
        let mut ticker = tokio::time::interval(Duration::from_secs(3));
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            select! {
                () = conn_token.cancelled() => break,
                _ = ticker.tick() => {
                    if control_tx.send(EncryptedFrame::Keepalive(now_micros())).await.is_err() {
                        break;
                    }
                }
            }
        }
    });
}
/// Queues a compressed screen state on `ss_tx`.
///
/// A payload of at most `MAX_STATESYNC_DIFF_BYTES` goes out as a single
/// `EncryptedFrame::ScreenStateCompressed`. A larger one is cut into
/// `STATE_CHUNK_SIZE`-byte `EncryptedFrame::StateChunk((index, total, bytes))`
/// frames. Indices and totals that do not fit in `u16` saturate at `u16::MAX`.
///
/// Returns `false` as soon as a send fails (receiver dropped), `true` otherwise.
async fn send_state_chunked(ss_tx: &Sender<EncryptedFrame>, compressed: Vec<u8>) -> bool {
    if compressed.len() <= MAX_STATESYNC_DIFF_BYTES {
        return ss_tx
            .send(EncryptedFrame::ScreenStateCompressed(compressed))
            .await
            .is_ok();
    }

    let piece_count = compressed.len().div_ceil(STATE_CHUNK_SIZE);
    let total = u16::try_from(piece_count).unwrap_or(u16::MAX);
    for (index, piece) in compressed.chunks(STATE_CHUNK_SIZE).enumerate() {
        let seq = u16::try_from(index).unwrap_or(u16::MAX);
        let frame = EncryptedFrame::StateChunk((seq, total, piece.to_vec()));
        if ss_tx.send(frame).await.is_err() {
            return false;
        }
    }
    true
}
fn spawn_silence_watchdog(token: CancellationToken, last_rx_us: Arc<AtomicU64>) {
let _watchdog = spawn(async move {
let mut ticker = tokio::time::interval(Duration::from_secs(5));
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
select! {
() = token.cancelled() => break,
_ = ticker.tick() => {
let elapsed_us = now_micros().saturating_sub(last_rx_us.load(Ordering::Relaxed));
if elapsed_us > CLIENT_SILENCE_TIMEOUT_US {
info!("Client silence timeout (30 s): cancelling connection");
token.cancel();
break;
}
}
}
}
});
}
/// Advances the adaptive-MTU state machine by one poll.
///
/// * `current_nak` – cumulative NAK counter; the NAKs since the previous poll are
///   `current_nak - *last_nak` (wrapping), and `*last_nak` is updated to it.
/// * `tier` – index into `MTU_TIERS` currently in use.
/// * `quiet_ticks` – consecutive NAK-free polls while not probing.
/// * `probe_ticks` – polls survived by the current probe.
/// * `probing` – whether a higher tier is currently being tried.
///
/// While settled, `MTU_PROBE_QUIET_TICKS` NAK-free polls move up one tier (if a
/// higher tier exists) and start a probe. Any NAK resets the quiet count. While
/// probing, a poll with at least `MTU_PROBE_FAIL_THRESHOLD` NAKs drops back one
/// tier. `MTU_PROBE_SUCCESS_TICKS` polls without that much loss end the probe
/// on the new tier.
///
/// Returns the MTU of the new tier when the tier changed, otherwise `None`.
fn mtu_probe_step(
    current_nak: u64,
    last_nak: &mut u64,
    tier: &mut usize,
    quiet_ticks: &mut u32,
    probe_ticks: &mut u32,
    probing: &mut bool,
) -> Option<usize> {
    let new_naks = current_nak.wrapping_sub(*last_nak);
    *last_nak = current_nak;
    let starting_tier = *tier;

    match (*probing, new_naks) {
        // Probe failed: too much loss at the higher tier, step back down.
        (true, naks) if naks >= MTU_PROBE_FAIL_THRESHOLD => {
            *tier = tier.saturating_sub(1);
            *probing = false;
            *quiet_ticks = 0;
            *probe_ticks = 0;
        }
        // Probe still running; settle once it has lasted long enough.
        (true, _) => {
            *probe_ticks += 1;
            if *probe_ticks >= MTU_PROBE_SUCCESS_TICKS {
                *probing = false;
                *quiet_ticks = 0;
                *probe_ticks = 0;
            }
        }
        // Loss-free while settled: count towards the next probe.
        (false, 0) => {
            *quiet_ticks += 1;
            let has_higher_tier = *tier + 1 < MTU_TIERS.len();
            if *quiet_ticks >= MTU_PROBE_QUIET_TICKS && has_higher_tier {
                *tier += 1;
                *probing = true;
                *probe_ticks = 0;
                *quiet_ticks = 0;
            }
        }
        // Loss while settled: start counting quiet polls again.
        (false, _) => *quiet_ticks = 0,
    }

    if *tier == starting_tier {
        None
    } else {
        Some(MTU_TIERS[*tier])
    }
}
/// Spawns the adaptive-MTU task for one connection.
///
/// Every `MTU_POLL_INTERVAL` it feeds the connection's cumulative NAK counter
/// into `mtu_probe_step`, starting from tier 0 with a zero NAK baseline. Any
/// resulting MTU change is stored into `effective_mtu`. The task runs until
/// `token` is cancelled.
fn spawn_mtu_probe_task(
    token: CancellationToken,
    nak_received_count: Arc<AtomicU64>,
    effective_mtu: Arc<AtomicUsize>,
) {
    let _task = spawn(async move {
        let mut poll = tokio::time::interval(MTU_POLL_INTERVAL);
        poll.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        let (mut last_nak, mut tier, mut quiet_ticks, mut probe_ticks, mut probing) =
            (0_u64, 0_usize, 0_u32, 0_u32, false);
        loop {
            select! {
                () = token.cancelled() => break,
                _ = poll.tick() => {
                    let naks_so_far = nak_received_count.load(Ordering::Relaxed);
                    let change = mtu_probe_step(
                        naks_so_far,
                        &mut last_nak,
                        &mut tier,
                        &mut quiet_ticks,
                        &mut probe_ticks,
                        &mut probing,
                    );
                    if let Some(mtu) = change {
                        effective_mtu.store(mtu, Ordering::Relaxed);
                    }
                }
            }
        }
    });
}
/// Spawns a task that pushes a full screen when the client reports heavy loss.
///
/// Every 200 ms it checks how many NAKs arrived since the previous check. If
/// that is at least `PROACTIVE_REPAINT_NAK_THRESHOLD`, the current emulator
/// screen is zstd-compressed (level 3, raw bytes on failure) and sent as
/// `EncryptedFrame::ScreenStateCompressed`. The task stops when `token` is
/// cancelled or the send fails.
fn spawn_proactive_repaint_watchdog(
    tx: Sender<EncryptedFrame>,
    token: CancellationToken,
    nak_received_count: Arc<AtomicU64>,
    server_emulator: Arc<Mutex<vt100::Parser>>,
) {
    let _watchdog = spawn(async move {
        let mut poll = tokio::time::interval(Duration::from_millis(200));
        poll.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        let mut seen_naks: u64 = 0;
        loop {
            select! {
                () = token.cancelled() => break,
                _ = poll.tick() => {
                    let total_naks = nak_received_count.load(Ordering::Relaxed);
                    let fresh_naks = total_naks.wrapping_sub(seen_naks);
                    seen_naks = total_naks;
                    if fresh_naks < PROACTIVE_REPAINT_NAK_THRESHOLD {
                        continue;
                    }
                    let screen_bytes = {
                        let emulator = server_emulator.lock().await;
                        emulator.screen().contents_formatted()
                    };
                    let payload = encode_all(screen_bytes.as_slice(), 3)
                        .unwrap_or_else(|_| screen_bytes.clone());
                    let frame = EncryptedFrame::ScreenStateCompressed(payload);
                    if tx.send(frame).await.is_err() {
                        break;
                    }
                }
            }
        }
    });
}
/// Creates and registers the shared state for a brand-new session.
///
/// Builds a terminal channel (capacity 256) and an output handle pointing at the
/// given connection (kex uuid, `data_tx`, `control_tx`, `conn_token`,
/// `udp_port`). It also builds an empty scrollback, a 24x80 `vt100` emulator, a
/// dirty counter starting at 1, a cleared `diff_in_flight` flag, and an
/// effective MTU of `MAX_UDP_PAYLOAD`. The shared parts are inserted into
/// `full_registry` under `session_uuid`, replacing any previous entry.
///
/// Returns the same tuple shape as `resolve_session`, with `Some(term_rx)`
/// because the caller must start the PTY.
///
/// # Errors
/// Currently always succeeds.
async fn new_session(
    kex: &libmoshpit::Kex,
    conn_token: &CancellationToken,
    udp_port: u16,
    session_uuid: Uuid,
    data_tx: Sender<EncryptedFrame>,
    control_tx: Sender<EncryptedFrame>,
    full_registry: &FullSessionRegistry,
) -> Result<(
    Sender<TerminalMessage>,
    Option<Receiver<TerminalMessage>>,
    Arc<Mutex<SessionOutputHandle>>,
    Arc<Mutex<VecDeque<u8>>>,
    Arc<Mutex<vt100::Parser>>,
    Arc<AtomicU64>,
    Arc<AtomicBool>,
    Arc<AtomicUsize>,
)> {
    let (term_tx, term_rx) = channel::<TerminalMessage>(256);
    let record = SessionRecord {
        term_tx: term_tx.clone(),
        output_handle: Arc::new(Mutex::new(SessionOutputHandle {
            kex_uuid: kex.uuid(),
            data_tx: Some(data_tx),
            control_tx: Some(control_tx),
            conn_token: Some(conn_token.clone()),
            udp_port: Some(udp_port),
        })),
        scrollback: Arc::new(Mutex::new(VecDeque::with_capacity(SCROLLBACK_CAPACITY))),
        server_emulator: Arc::new(Mutex::new(vt100::Parser::new(24, 80, 0))),
        dirty_counter: Arc::new(AtomicU64::new(1)),
        diff_in_flight: Arc::new(AtomicBool::new(false)),
        effective_mtu: Arc::new(AtomicUsize::new(MAX_UDP_PAYLOAD)),
    };
    let handles = (
        term_tx,
        Some(term_rx),
        record.output_handle.clone(),
        record.scrollback.clone(),
        record.server_emulator.clone(),
        record.dirty_counter.clone(),
        record.diff_in_flight.clone(),
        record.effective_mtu.clone(),
    );
    drop(full_registry.lock().await.insert(session_uuid, record));
    Ok(handles)
}
/// Builds replies for terminal identification queries found in `buf`.
///
/// In StateSync mode the PTY reader writes these replies back to the PTY as
/// input. Recognised CSI queries:
/// * Primary Device Attributes `ESC [ c` / `ESC [ 0 c` → `ESC [ ? 62 c`
/// * Secondary Device Attributes `ESC [ > c` / `ESC [ > 0 c` → `ESC [ > 1 ; 10 ; 0 c`
///
/// Every other sequence is ignored. Replies are concatenated in query order, and
/// an empty vector means there is nothing to answer. The scan is stateless, so a
/// query split across two reads is not detected.
///
/// An `ESC` where a CSI final byte is expected aborts that sequence and starts a
/// new one, as in the DEC/VT parser model. The original implementation consumed
/// that `ESC` as the "final byte", which swallowed the query that followed.
fn server_intercept_queries(buf: &[u8]) -> Vec<u8> {
    const ESC: u8 = 0x1b;
    if !buf.contains(&ESC) {
        return Vec::new();
    }
    let mut resp = Vec::new();
    let mut i = 0;
    while i < buf.len() {
        if buf[i] != ESC || buf.get(i + 1) != Some(&b'[') {
            i += 1;
            continue;
        }
        i += 2;
        // Optional private-parameter marker.
        let marker = match buf.get(i) {
            Some(&m) if matches!(m, b'?' | b'>' | b'=') => {
                i += 1;
                Some(m)
            }
            _ => None,
        };
        let params_start = i;
        while buf.get(i).is_some_and(|b| b.is_ascii_digit() || *b == b';') {
            i += 1;
        }
        let params = &buf[params_start..i];
        // Truncated sequence, or aborted by a new ESC: leave `i` on the ESC so
        // the next iteration parses it as a fresh sequence.
        let Some(&term) = buf.get(i).filter(|&&b| b != ESC) else {
            continue;
        };
        i += 1;
        match (marker, params, term) {
            (None, b"" | b"0", b'c') => resp.extend_from_slice(b"\x1b[?62c"),
            (Some(b'>'), b"" | b"0", b'c') => resp.extend_from_slice(b"\x1b[>1;10;0c"),
            _ => {}
        }
    }
    resp
}
#[cfg_attr(nightly, allow(clippy::too_many_arguments, clippy::too_many_lines))]
#[cfg_attr(coverage_nightly, coverage(off))]
fn spawn_pty_reader(
session_uuid: Uuid,
mut term_out: Box<dyn std::io::Read + Send>,
term_tx: Sender<TerminalMessage>,
output_handle: Arc<Mutex<SessionOutputHandle>>,
scrollback: Arc<Mutex<VecDeque<u8>>>,
server_emulator: Arc<Mutex<vt100::Parser>>,
dirty_counter: Arc<AtomicU64>,
diff_in_flight: Arc<AtomicBool>,
pacing_delay: Duration,
port_pool: Arc<Mutex<BTreeSet<u16>>>,
session_registry: SessionRegistry,
full_registry: FullSessionRegistry,
effective_mtu: Arc<AtomicUsize>,
diff_mode: DiffMode,
) {
let _read_handle = thread::spawn(move || {
loop {
let mut buffer = BytesMut::zeroed(4096);
match term_out.read(&mut buffer) {
Ok(0) => {
trace!("read 0 bytes from terminal, exiting");
break;
}
Ok(n) => {
let buf_slice = &buffer[..n];
let utf8_buf = String::from_utf8_lossy(buf_slice);
{
let mut sb = scrollback.blocking_lock();
let available = SCROLLBACK_CAPACITY.saturating_sub(sb.len());
if buf_slice.len() > available {
for _ in 0..(buf_slice.len() - available) {
let _ = sb.pop_front();
}
}
sb.extend(buf_slice.iter().copied());
}
server_emulator.blocking_lock().process(buf_slice);
let _ = dirty_counter.fetch_add(1, Ordering::Relaxed);
let send_ok = {
let h = output_handle.blocking_lock();
if diff_mode == DiffMode::StateSync {
let resp = server_intercept_queries(buf_slice);
if !resp.is_empty() {
drop(term_tx.try_send(TerminalMessage::Input(resp)));
}
drop(h);
true
} else if let Some(ref sender) = h.data_tx {
let uuid_wrapper = UuidWrapper::new(h.kex_uuid);
let sender_clone = sender.clone();
drop(h);
diff_in_flight.store(true, Ordering::Relaxed);
if let Ok(compressed) = encode_all(buf_slice, 1)
&& compressed.len() < buf_slice.len()
{
let frame =
EncryptedFrame::CompressedBytes((uuid_wrapper, compressed));
sender_clone.blocking_send(frame).is_ok()
} else {
let mut ok = true;
let mtu = effective_mtu.load(Ordering::Relaxed);
let n = buf_slice.len().div_ceil(mtu);
let burst_pacing = pacing_delay * if n > 10 { 3 } else { 1 };
let mut chunks = buf_slice.chunks(mtu).peekable();
while let Some(chunk) = chunks.next() {
let more = chunks.peek().is_some();
let frame =
EncryptedFrame::Bytes((uuid_wrapper, chunk.to_vec()));
ok = sender_clone.blocking_send(frame).is_ok();
if !ok {
break;
}
if more && !burst_pacing.is_zero() {
sleep(burst_pacing);
}
}
ok
}
} else {
drop(h);
true }
};
if !send_ok {
let mut h = output_handle.blocking_lock();
h.data_tx = None;
h.control_tx = None;
}
if is_exit_title(&utf8_buf, true) {
sleep(Duration::from_millis(500));
break;
}
buffer.advance(n);
}
Err(e) => {
error!("error reading from terminal: {e}");
break;
}
}
}
{
let h = output_handle.blocking_lock();
if let Some(ref tx) = h.control_tx {
drop(tx.blocking_send(EncryptedFrame::PtyExit));
}
}
sleep(Duration::from_millis(50));
{
let mut h = output_handle.blocking_lock();
if let Some(token) = h.conn_token.take() {
token.cancel();
}
if let Some(port) = h.udp_port.take() {
let mut pool = port_pool.blocking_lock();
let _ = pool.insert(port);
}
h.data_tx = None;
h.control_tx = None;
}
{
let mut sr = session_registry.blocking_lock();
drop(sr.remove(&session_uuid));
}
{
let mut fr = full_registry.blocking_lock();
drop(fr.remove(&session_uuid));
}
info!(session = %session_uuid, "session ended, client exited cleanly");
});
}
#[allow(unsafe_code)]
#[cfg_attr(
nightly,
allow(
clippy::too_many_arguments,
clippy::needless_pass_by_value,
clippy::too_many_lines
)
)]
#[cfg_attr(not(nightly), allow(clippy::needless_pass_by_value))]
#[cfg_attr(coverage_nightly, coverage(off))]
fn spawn_pty(
session_uuid: Uuid,
#[cfg_attr(not(unix), allow(unused_variables))] user: String,
shell: String,
mut term_rx: Receiver<TerminalMessage>,
term_tx: Sender<TerminalMessage>,
output_handle: Arc<Mutex<SessionOutputHandle>>,
scrollback: Arc<Mutex<VecDeque<u8>>>,
server_emulator: Arc<Mutex<vt100::Parser>>,
dirty_counter: Arc<AtomicU64>,
diff_in_flight: Arc<AtomicBool>,
pacing_delay: Duration,
#[cfg_attr(not(unix), allow(unused_variables))] term_type: String,
port_pool: Arc<Mutex<BTreeSet<u16>>>,
session_registry: SessionRegistry,
full_registry: FullSessionRegistry,
effective_mtu: Arc<AtomicUsize>,
diff_mode: DiffMode,
) {
let _term_handle = thread::spawn(move || {
let pty_system = native_pty_system();
let pair = match pty_system.openpty(PtySize {
rows: 24,
cols: 80,
pixel_width: 0,
pixel_height: 0,
}) {
Ok(p) => p,
Err(e) => {
error!("Failed to open PTY: {e}");
return;
}
};
#[cfg(unix)]
{
let daemon_uid = unsafe { libc::getuid() };
if daemon_uid != 0 {
let daemon_user = current_daemon_user();
if daemon_user.as_deref() != Some(user.as_str()) {
error!(
"Daemon user {} cannot spawn shell for user {}",
daemon_user.unwrap_or_else(|| String::from("<unknown>")),
user
);
return;
}
}
let Some(tty_path) = pair.master.tty_name() else {
error!("Unable to determine PTY slave tty path");
return;
};
let slave = match std::fs::OpenOptions::new()
.read(true)
.write(true)
.custom_flags(libc::O_NOCTTY)
.open(&tty_path)
{
Ok(file) => file,
Err(e) => {
error!("Failed to open PTY slave {}: {e}", tty_path.display());
return;
}
};
let stdin_file = match slave.try_clone() {
Ok(file) => file,
Err(e) => {
error!("Failed to clone PTY slave for stdin: {e}");
return;
}
};
let stdout_file = match slave.try_clone() {
Ok(file) => file,
Err(e) => {
error!("Failed to clone PTY slave for stdout: {e}");
return;
}
};
let stderr_file = match slave.try_clone() {
Ok(file) => file,
Err(e) => {
error!("Failed to clone PTY slave for stderr: {e}");
return;
}
};
let mut cmd = std::process::Command::new(&shell);
let _ = cmd.arg("-li");
let mut drop_creds: Option<(CString, libc::uid_t, libc::gid_t)> = None;
if daemon_uid == 0 {
let account = match resolve_user_account(&user, &shell) {
Ok(account) => account,
Err(e) => {
error!("Failed to resolve target account for {user}: {e}");
return;
}
};
let Ok(username_c) = CString::new(account.username.clone()) else {
error!("Target username contains invalid NUL byte");
return;
};
let login_uid = account.uid;
let primary_group_id = account.gid;
let _ = cmd.current_dir(&account.home);
let _ = cmd.env("HOME", &account.home);
let _ = cmd.env("USER", &account.username);
let _ = cmd.env("LOGNAME", &account.username);
let _ = cmd.env("SHELL", &account.shell);
let _ = cmd.env("TERM", &term_type);
drop_creds = Some((username_c, login_uid, primary_group_id));
}
let _ = unsafe {
cmd.pre_exec(move || {
let tiocsctty_request = tiocsctty_ioctl_request();
if libc::setsid() < 0 {
return Err(std::io::Error::last_os_error());
}
if libc::ioctl(0, tiocsctty_request, 0) < 0 {
return Err(std::io::Error::last_os_error());
}
if let Some((username_c, login_uid, primary_group_id)) = drop_creds.as_ref() {
#[cfg(target_os = "linux")]
let initgroups_basegroup = initgroups_base_group(*primary_group_id);
#[cfg(target_os = "macos")]
let initgroups_basegroup = initgroups_base_group(*primary_group_id)?;
if libc::initgroups(username_c.as_ptr(), initgroups_basegroup) < 0 {
return Err(std::io::Error::last_os_error());
}
if libc::setgid(*primary_group_id) < 0 {
return Err(std::io::Error::last_os_error());
}
if libc::setuid(*login_uid) < 0 {
return Err(std::io::Error::last_os_error());
}
}
Ok(())
})
};
let _ = cmd
.stdin(Stdio::from(stdin_file))
.stdout(Stdio::from(stdout_file))
.stderr(Stdio::from(stderr_file));
if let Err(e) = cmd.spawn() {
error!("Failed to spawn shell for user {user}: {e}");
return;
}
drop(pair.slave);
drop(slave);
}
#[cfg(windows)]
{
let cmd = CommandBuilder::new(shell);
if let Err(e) = pair.slave.spawn_command(cmd) {
error!("Failed to spawn shell: {e}");
return;
}
}
let master = pair.master;
let term_out = match master.try_clone_reader() {
Ok(r) => r,
Err(e) => {
error!("Failed to clone PTY reader: {e}");
return;
}
};
let mut term_in = match master.take_writer() {
Ok(w) => w,
Err(e) => {
error!("Failed to take PTY writer: {e}");
return;
}
};
spawn_pty_reader(
session_uuid,
term_out,
term_tx,
output_handle,
scrollback,
server_emulator.clone(),
dirty_counter.clone(),
diff_in_flight,
pacing_delay,
port_pool,
session_registry,
full_registry,
effective_mtu,
diff_mode,
);
while let Some(terminal_message) = term_rx.blocking_recv() {
match terminal_message {
TerminalMessage::Resize { columns, rows } => {
if let Err(e) = master.resize(PtySize {
rows,
cols: columns,
pixel_width: 0,
pixel_height: 0,
}) {
error!("error resizing terminal: {e}");
}
server_emulator
.blocking_lock()
.screen_mut()
.set_size(rows, columns);
let _ = dirty_counter.fetch_add(1, Ordering::Relaxed);
}
TerminalMessage::Input(data) => {
if let Err(e) = term_in.write_all(&data) {
error!("error writing to terminal: {e}");
break;
}
}
}
}
});
}
/// Returns the account name of the daemon's real uid, or `None` if it cannot
/// be resolved.
///
/// Uses the re-entrant `getpwuid_r`, growing the lookup buffer on `ERANGE`.
/// Plain `getpwuid` returns a pointer into shared static storage, and several
/// PTY-spawning threads may perform account lookups at the same time.
#[cfg(unix)]
#[allow(unsafe_code)]
fn current_daemon_user() -> Option<String> {
    /// Upper bound for the lookup buffer, guarding against unbounded growth.
    const MAX_PWD_BUF: usize = 1 << 20;
    // SAFETY: `getuid` has no preconditions and cannot fail.
    let daemon_uid = unsafe { libc::getuid() };
    let mut buf: Vec<libc::c_char> = vec![0; 1024];
    loop {
        // SAFETY: `passwd` only holds integers and raw pointers, for which an
        // all-zero bit pattern is a valid value.
        let mut pwd: libc::passwd = unsafe { std::mem::zeroed() };
        let mut result: *mut libc::passwd = std::ptr::null_mut();
        // SAFETY: every pointer refers to live local storage and `buf.len()` is
        // the real size of `buf`.
        let rc = unsafe {
            libc::getpwuid_r(
                daemon_uid,
                &mut pwd,
                buf.as_mut_ptr(),
                buf.len(),
                &mut result,
            )
        };
        if rc == libc::ERANGE && buf.len() < MAX_PWD_BUF {
            let grown = buf.len() * 2;
            buf.resize(grown, 0);
            continue;
        }
        if rc != 0 || result.is_null() || pwd.pw_name.is_null() {
            return None;
        }
        // SAFETY: on success `pw_name` is a NUL-terminated string stored in
        // `buf`, which is still alive here.
        let name = unsafe { CStr::from_ptr(pwd.pw_name) }
            .to_string_lossy()
            .into_owned();
        return Some(name);
    }
}
/// Request code for the `TIOCSCTTY` ioctl (make a terminal the caller's
/// controlling tty), returned as `libc::Ioctl`, the request type on Linux.
#[cfg(all(unix, target_os = "linux"))]
fn tiocsctty_ioctl_request() -> libc::Ioctl {
    libc::TIOCSCTTY
}
/// Request code for the `TIOCSCTTY` ioctl, widened to the `c_ulong` request
/// type used on macOS.
#[cfg(all(unix, target_os = "macos"))]
fn tiocsctty_ioctl_request() -> libc::c_ulong {
    libc::TIOCSCTTY.into()
}
/// Produces the base-group argument for `initgroups` on Linux.
///
/// The group id is already of the required type and is passed through unchanged.
#[cfg(all(unix, target_os = "linux"))]
fn initgroups_base_group(group_id: libc::gid_t) -> libc::gid_t {
    group_id
}
/// Produces the base-group argument for `initgroups` on macOS, which takes a
/// `c_int` rather than a `gid_t`.
///
/// # Errors
///
/// Returns an [`std::io::ErrorKind::InvalidInput`] error when `group_id` is
/// larger than `c_int::MAX`.
#[cfg(all(unix, target_os = "macos"))]
fn initgroups_base_group(group_id: libc::gid_t) -> std::io::Result<libc::c_int> {
    libc::c_int::try_from(group_id).map_err(|_overflow| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "gid does not fit into c_int for initgroups",
        )
    })
}
/// The subset of a password-database entry that this module uses for a user
/// account (built by `resolve_user_account`).
#[cfg(unix)]
#[derive(Debug, Clone)]
struct ResolvedUserAccount {
    /// Login name the account was looked up by.
    username: String,
    /// Numeric user id (`pw_uid`).
    uid: libc::uid_t,
    /// Primary group id (`pw_gid`).
    gid: libc::gid_t,
    /// Home directory (`pw_dir`).
    home: String,
    /// Login shell (`pw_shell`), or the caller-supplied fallback when the entry's
    /// shell field is empty.
    shell: String,
}
#[cfg(unix)]
#[allow(unsafe_code)]
fn resolve_user_account(username: &str, fallback_shell: &str) -> Result<ResolvedUserAccount> {
let username_c = CString::new(username)?;
let pwd = unsafe { libc::getpwnam(username_c.as_ptr()) };
if pwd.is_null() {
return Err(anyhow::anyhow!("user '{username}' not found"));
}
let pw = unsafe { *pwd };
let home = unsafe { CStr::from_ptr(pw.pw_dir) }
.to_string_lossy()
.to_string();
let shell_from_db = unsafe { CStr::from_ptr(pw.pw_shell) }
.to_string_lossy()
.to_string();
Ok(ResolvedUserAccount {
username: username.to_string(),
uid: pw.pw_uid,
gid: pw.pw_gid,
home,
shell: if shell_from_db.is_empty() {
fallback_shell.to_string()
} else {
shell_from_db
},
})
}
// NOTE(review): the blanket `clippy::all` allow silences every lint in this module;
// consider narrowing it to the specific lints the tests trip.
#[cfg(test)]
#[allow(dead_code, clippy::all)]
mod test {
    //! Unit tests for the account helpers, session creation/resumption, connection
    //! watchdogs, MTU probing, chunked state sync and terminal-query interception
    //! defined in the parent module.

    use libmoshpit::{EncryptedFrame, Kex, ServerKex};
    use tokio::sync::mpsc::channel;
    use tokio_util::sync::CancellationToken;
    use uuid::Uuid;
    #[cfg(unix)]
    use super::{current_daemon_user, resolve_user_account};
    #[cfg(all(unix, target_os = "linux"))]
    use super::{initgroups_base_group, tiocsctty_ioctl_request};
    use std::{
        sync::{
            Arc,
            atomic::{AtomicU64, AtomicUsize, Ordering},
        },
        time::Duration,
    };
    use super::{
        MAX_STATESYNC_DIFF_BYTES, MTU_PROBE_FAIL_THRESHOLD, MTU_PROBE_QUIET_TICKS,
        MTU_PROBE_SUCCESS_TICKS, MTU_TIERS, PROACTIVE_REPAINT_NAK_THRESHOLD, STATE_CHUNK_SIZE,
        mtu_probe_step, new_full_registry, new_session, now_micros, resolve_session,
        send_state_chunked, server_intercept_queries, spawn_connection_watchdogs,
        spawn_mtu_probe_task, spawn_proactive_repaint_watchdog, spawn_silence_watchdog,
    };

    #[cfg(unix)]
    #[test]
    fn current_daemon_user_returns_some() {
        let user = current_daemon_user();
        assert!(user.is_some());
    }

    #[cfg(unix)]
    #[test]
    fn resolve_user_account_unknown_user_errors() {
        let result = resolve_user_account("__moshpit_no_such_user__", "/bin/sh");
        assert!(result.is_err());
    }

    #[cfg(unix)]
    #[test]
    fn resolve_user_account_current_user_roundtrip() {
        let Some(username) = current_daemon_user() else {
            panic!("expected current daemon user on unix")
        };
        let account =
            resolve_user_account(&username, "/bin/sh").expect("current daemon user should resolve");
        assert_eq!(account.username, username);
        assert!(!account.home.is_empty());
        assert!(!account.shell.is_empty());
    }

    #[tokio::test]
    async fn new_session_registers_in_full_registry() -> anyhow::Result<()> {
        let kex = Kex::default();
        let conn_token = CancellationToken::new();
        let (data_tx, _data_rx) = channel::<EncryptedFrame>(4);
        let (control_tx, _control_rx) = channel::<EncryptedFrame>(4);
        let session_uuid = Uuid::new_v4();
        let full_registry = new_full_registry();
        let _reg_result = new_session(
            &kex,
            &conn_token,
            50_000,
            session_uuid,
            data_tx,
            control_tx,
            &full_registry,
        )
        .await?;
        assert!(full_registry.lock().await.contains_key(&session_uuid));
        Ok(())
    }

    #[tokio::test]
    async fn new_session_returns_some_term_rx() -> anyhow::Result<()> {
        let kex = Kex::default();
        let conn_token = CancellationToken::new();
        let (data_tx, _data_rx) = channel::<EncryptedFrame>(4);
        let (control_tx, _control_rx) = channel::<EncryptedFrame>(4);
        let session_uuid = Uuid::new_v4();
        let full_registry = new_full_registry();
        let (_, maybe_rx, _, _, _, _, _, _) = new_session(
            &kex,
            &conn_token,
            50_000,
            session_uuid,
            data_tx,
            control_tx,
            &full_registry,
        )
        .await?;
        assert!(maybe_rx.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn new_session_output_handle_has_correct_kex_uuid() -> anyhow::Result<()> {
        let kex = Kex::default();
        let conn_token = CancellationToken::new();
        let (data_tx, _data_rx) = channel::<EncryptedFrame>(4);
        let (control_tx, _control_rx) = channel::<EncryptedFrame>(4);
        let session_uuid = Uuid::new_v4();
        let full_registry = new_full_registry();
        let (_, _, output_handle, _, _, _, _, _) = new_session(
            &kex,
            &conn_token,
            50_000,
            session_uuid,
            data_tx,
            control_tx,
            &full_registry,
        )
        .await?;
        assert_eq!(output_handle.lock().await.kex_uuid, kex.uuid());
        Ok(())
    }

    #[tokio::test]
    async fn new_session_scrollback_initially_empty() -> anyhow::Result<()> {
        let kex = Kex::default();
        let conn_token = CancellationToken::new();
        let (data_tx, _data_rx) = channel::<EncryptedFrame>(4);
        let (control_tx, _control_rx) = channel::<EncryptedFrame>(4);
        let session_uuid = Uuid::new_v4();
        let full_registry = new_full_registry();
        let (_, _, _, scrollback, _, _, _, _) = new_session(
            &kex,
            &conn_token,
            50_000,
            session_uuid,
            data_tx,
            control_tx,
            &full_registry,
        )
        .await?;
        assert!(scrollback.lock().await.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn new_session_emulator_default_size() -> anyhow::Result<()> {
        let kex = Kex::default();
        let conn_token = CancellationToken::new();
        let (data_tx, _data_rx) = channel::<EncryptedFrame>(4);
        let (control_tx, _control_rx) = channel::<EncryptedFrame>(4);
        let session_uuid = Uuid::new_v4();
        let full_registry = new_full_registry();
        let (_, _, _, _, emulator, _, _, _) = new_session(
            &kex,
            &conn_token,
            50_000,
            session_uuid,
            data_tx,
            control_tx,
            &full_registry,
        )
        .await?;
        let emu = emulator.lock().await;
        let screen = emu.screen();
        assert_eq!(screen.size(), (24, 80));
        Ok(())
    }

    #[tokio::test]
    async fn watchdogs_keepalive_sends_frame() {
        let (control_tx, mut control_rx) = channel::<EncryptedFrame>(4);
        let conn_token = CancellationToken::new();
        let server_token = CancellationToken::new();
        spawn_connection_watchdogs(control_tx, conn_token.clone(), server_token);
        let frame = tokio::time::timeout(Duration::from_millis(200), control_rx.recv()).await;
        // Stop the watchdog before asserting so a failure does not leave it running.
        conn_token.cancel();
        let frame = frame
            .expect("timeout waiting for keepalive")
            .expect("channel closed");
        assert!(matches!(frame, EncryptedFrame::Keepalive(_)));
    }

    #[tokio::test]
    async fn watchdogs_server_cancel_sends_shutdown_then_cancels_conn() {
        let (control_tx, mut control_rx) = channel::<EncryptedFrame>(4);
        let conn_token = CancellationToken::new();
        let server_token = CancellationToken::new();
        spawn_connection_watchdogs(control_tx, conn_token.clone(), server_token.clone());
        server_token.cancel();
        tokio::time::sleep(Duration::from_millis(50)).await;
        // Keepalive frames may be interleaved ahead of the Shutdown; skip past them.
        let mut saw_shutdown = false;
        while let Ok(frame) = control_rx.try_recv() {
            if matches!(frame, EncryptedFrame::Shutdown) {
                saw_shutdown = true;
                break;
            }
        }
        assert!(
            saw_shutdown,
            "expected Shutdown frame after server_token cancel"
        );
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(conn_token.is_cancelled());
    }

    #[tokio::test]
    async fn watchdogs_conn_cancel_stops_keepalive() {
        let (control_tx, mut control_rx) = channel::<EncryptedFrame>(4);
        let conn_token = CancellationToken::new();
        let server_token = CancellationToken::new();
        spawn_connection_watchdogs(control_tx, conn_token.clone(), server_token);
        conn_token.cancel();
        tokio::time::sleep(Duration::from_millis(50)).await;
        // Drain anything sent before the cancellation was observed.
        while control_rx.try_recv().is_ok() {}
        let result = tokio::time::timeout(Duration::from_millis(100), control_rx.recv()).await;
        // Either nothing arrives (timeout) or the sender was dropped (`None`).
        assert!(result.map_or(true, |v| v.is_none()));
    }

    #[tokio::test]
    async fn resolve_session_new_session_path() -> anyhow::Result<()> {
        let kex = Kex::default();
        let session_uuid = Uuid::new_v4();
        let skex = ServerKex::builder()
            .user("alice".to_string())
            .shell("/usr/bin/fish".to_string())
            .session_uuid(session_uuid)
            .build();
        let conn_token = CancellationToken::new();
        let (data_tx, _data_rx) = channel::<EncryptedFrame>(4);
        let (control_tx, _control_rx) = channel::<EncryptedFrame>(4);
        let full_registry = new_full_registry();
        let (_, maybe_rx, _, _, _, _, _, _) = resolve_session(
            &kex,
            &skex,
            &conn_token,
            50_000,
            data_tx,
            control_tx,
            &full_registry,
        )
        .await?;
        assert!(maybe_rx.is_some());
        Ok(())
    }

    #[tokio::test]
    async fn resolve_session_resume_existing() -> anyhow::Result<()> {
        let kex = Kex::default();
        let session_uuid = Uuid::new_v4();
        let conn_token = CancellationToken::new();
        let (data_tx, _data_rx) = channel::<EncryptedFrame>(16);
        let (control_tx, _control_rx) = channel::<EncryptedFrame>(4);
        let full_registry = new_full_registry();
        let _first_session = new_session(
            &kex,
            &conn_token,
            50_000,
            session_uuid,
            data_tx.clone(),
            control_tx.clone(),
            &full_registry,
        )
        .await?;
        // Reconnect to the same session uuid with a fresh key exchange.
        let new_kex = Kex::default();
        let skex_resume = ServerKex::builder()
            .user("alice".to_string())
            .shell("/usr/bin/fish".to_string())
            .session_uuid(session_uuid)
            .is_resume(true)
            .build();
        let new_conn_token = CancellationToken::new();
        let (resume_data_tx, mut resume_data_rx) = channel::<EncryptedFrame>(16);
        let (resume_ctrl_tx, _resume_ctrl_rx) = channel::<EncryptedFrame>(4);
        let (_, maybe_rx, output_handle, _, _, _, _, _) = resolve_session(
            &new_kex,
            &skex_resume,
            &new_conn_token,
            50_001,
            resume_data_tx,
            resume_ctrl_tx,
            &full_registry,
        )
        .await?;
        // A resumed session reuses the existing terminal, so no new receiver is handed out.
        assert!(maybe_rx.is_none());
        assert_eq!(output_handle.lock().await.kex_uuid, new_kex.uuid());
        let mut saw_screen_state = false;
        while let Ok(frame) = resume_data_rx.try_recv() {
            if matches!(
                frame,
                EncryptedFrame::ScreenState(_) | EncryptedFrame::ScreenStateCompressed(_)
            ) {
                saw_screen_state = true;
                break;
            }
        }
        assert!(saw_screen_state, "expected ScreenState frame on resume");
        Ok(())
    }

    #[tokio::test]
    async fn resolve_session_resume_expired() -> anyhow::Result<()> {
        let kex = Kex::default();
        let session_uuid = Uuid::new_v4();
        let skex = ServerKex::builder()
            .user("alice".to_string())
            .shell("/usr/bin/fish".to_string())
            .session_uuid(session_uuid)
            .is_resume(true)
            .build();
        let conn_token = CancellationToken::new();
        let (data_tx, _data_rx) = channel::<EncryptedFrame>(4);
        let (control_tx, _control_rx) = channel::<EncryptedFrame>(4);
        // Empty registry: the resume request cannot find the session and falls back
        // to creating a new one.
        let full_registry = new_full_registry();
        let (_, maybe_rx, _, _, _, _, _, _) = resolve_session(
            &kex,
            &skex,
            &conn_token,
            50_000,
            data_tx,
            control_tx,
            &full_registry,
        )
        .await?;
        assert!(maybe_rx.is_some());
        Ok(())
    }

    #[cfg(all(unix, target_os = "linux"))]
    #[test]
    fn tiocsctty_ioctl_request_is_nonzero() {
        assert_ne!(tiocsctty_ioctl_request(), 0);
    }

    #[cfg(all(unix, target_os = "linux"))]
    #[test]
    fn initgroups_base_group_roundtrip() {
        assert_eq!(initgroups_base_group(42), 42);
        assert_eq!(initgroups_base_group(0), 0);
        assert_eq!(initgroups_base_group(u32::MAX), u32::MAX);
    }

    /// Fresh probe state: (tier index, last NAK count, quiet ticks, probe ticks, probing).
    fn make_probe_state() -> (usize, u64, u32, u32, bool) {
        (0usize, 0u64, 0u32, 0u32, false)
    }

    #[test]
    fn mtu_probe_step_starts_at_base_mtu() {
        let (mut tier, mut last_nak, mut qt, mut pt, mut probing) = make_probe_state();
        let result = mtu_probe_step(0, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        assert!(result.is_none(), "no tier change on first quiet tick");
        assert_eq!(tier, 0);
    }

    #[test]
    fn mtu_probe_step_advances_tier_after_quiet_period() {
        let (mut tier, mut last_nak, mut qt, mut pt, mut probing) = make_probe_state();
        let mut changed = None;
        for _ in 0..MTU_PROBE_QUIET_TICKS {
            changed = mtu_probe_step(0, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        }
        assert_eq!(
            changed,
            Some(MTU_TIERS[1]),
            "tier upgrades on Nth quiet tick"
        );
        assert_eq!(tier, 1);
        assert!(probing);
    }

    #[test]
    fn mtu_probe_step_reverts_on_nak_spike_during_probe() {
        let (mut tier, mut last_nak, mut qt, mut pt, mut probing) = make_probe_state();
        for _ in 0..MTU_PROBE_QUIET_TICKS {
            let _ = mtu_probe_step(0, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        }
        assert_eq!(tier, 1);
        assert!(probing);
        // Jump the cumulative NAK count by the failure threshold in a single tick.
        let nak = MTU_PROBE_FAIL_THRESHOLD;
        let result = mtu_probe_step(
            nak,
            &mut last_nak,
            &mut tier,
            &mut qt,
            &mut pt,
            &mut probing,
        );
        assert_eq!(result, Some(MTU_TIERS[0]), "tier reverts on NAK spike");
        assert_eq!(tier, 0);
        assert!(!probing);
    }

    #[test]
    fn mtu_probe_step_confirms_upgrade_after_success_ticks() {
        let (mut tier, mut last_nak, mut qt, mut pt, mut probing) = make_probe_state();
        for _ in 0..MTU_PROBE_QUIET_TICKS {
            let _ = mtu_probe_step(0, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        }
        assert!(probing, "should be probing after quiet period");
        for _ in 0..MTU_PROBE_SUCCESS_TICKS {
            let _ = mtu_probe_step(0, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        }
        assert!(!probing, "probe confirmed — no longer probing");
        assert_eq!(tier, 1, "tier stays at 1 after confirmation");
    }

    #[test]
    fn mtu_probe_step_no_upgrade_below_threshold() {
        let (mut tier, mut last_nak, mut qt, mut pt, mut probing) = make_probe_state();
        for _ in 0..MTU_PROBE_QUIET_TICKS - 1 {
            let _ = mtu_probe_step(0, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        }
        assert_eq!(tier, 0, "no upgrade before quiet threshold");
        assert!(!probing);
    }

    #[test]
    fn mtu_probe_step_resets_quiet_on_any_nak() {
        let (mut tier, mut last_nak, mut qt, mut pt, mut probing) = make_probe_state();
        for _ in 0..MTU_PROBE_QUIET_TICKS - 1 {
            let _ = mtu_probe_step(0, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        }
        // One new NAK (cumulative count 0 -> 1) just before the threshold.
        let _ = mtu_probe_step(1, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        // The count then stays at 1, so these ticks are quiet again; without a reset the
        // accumulated quiet ticks would have crossed the threshold and upgraded.
        for _ in 0..MTU_PROBE_QUIET_TICKS - 1 {
            let _ = mtu_probe_step(1, &mut last_nak, &mut tier, &mut qt, &mut pt, &mut probing);
        }
        assert_eq!(tier, 0, "quiet counter reset — no upgrade yet");
    }

    #[tokio::test]
    async fn mtu_probe_task_starts_at_base_mtu() {
        let token = CancellationToken::new();
        let nak_count = Arc::new(AtomicU64::new(0));
        let effective_mtu = Arc::new(AtomicUsize::new(MTU_TIERS[0]));
        spawn_mtu_probe_task(token.clone(), nak_count, effective_mtu.clone());
        token.cancel();
        assert_eq!(effective_mtu.load(Ordering::Relaxed), MTU_TIERS[0]);
    }

    #[tokio::test]
    async fn proactive_repaint_fires_on_nak_saturation() {
        let (tx, mut rx) = channel::<EncryptedFrame>(4);
        let token = CancellationToken::new();
        let nak_count = Arc::new(AtomicU64::new(0));
        let emulator = Arc::new(tokio::sync::Mutex::new(vt100::Parser::new(24, 80, 0)));
        spawn_proactive_repaint_watchdog(tx, token.clone(), nak_count.clone(), emulator);
        nak_count.store(PROACTIVE_REPAINT_NAK_THRESHOLD, Ordering::Relaxed);
        let frame = tokio::time::timeout(Duration::from_millis(500), rx.recv()).await;
        token.cancel();
        let frame = frame
            .expect("timeout: proactive repaint did not fire within 500 ms")
            .expect("channel closed before proactive repaint");
        assert!(matches!(frame, EncryptedFrame::ScreenStateCompressed(_)));
    }

    #[tokio::test]
    async fn proactive_repaint_does_not_fire_below_threshold() {
        let (tx, mut rx) = channel::<EncryptedFrame>(4);
        let token = CancellationToken::new();
        let nak_count = Arc::new(AtomicU64::new(0));
        let emulator = Arc::new(tokio::sync::Mutex::new(vt100::Parser::new(24, 80, 0)));
        spawn_proactive_repaint_watchdog(tx, token.clone(), nak_count.clone(), emulator);
        nak_count.store(PROACTIVE_REPAINT_NAK_THRESHOLD - 1, Ordering::Relaxed);
        let result = tokio::time::timeout(Duration::from_millis(300), rx.recv()).await;
        token.cancel();
        assert!(
            result.is_err(),
            "expected no proactive repaint below threshold, but got a frame"
        );
    }

    #[tokio::test]
    async fn proactive_repaint_stops_on_cancel() {
        let (tx, mut rx) = channel::<EncryptedFrame>(4);
        let token = CancellationToken::new();
        let nak_count = Arc::new(AtomicU64::new(PROACTIVE_REPAINT_NAK_THRESHOLD));
        let emulator = Arc::new(tokio::sync::Mutex::new(vt100::Parser::new(24, 80, 0)));
        spawn_proactive_repaint_watchdog(tx, token.clone(), nak_count, emulator);
        token.cancel();
        tokio::time::sleep(Duration::from_millis(250)).await;
        // Drain anything sent before the cancellation was observed.
        while rx.try_recv().is_ok() {}
        let result = tokio::time::timeout(Duration::from_millis(100), rx.recv()).await;
        assert!(
            result.map_or(true, |v| v.is_none()),
            "watchdog kept sending after cancellation"
        );
    }

    #[tokio::test]
    async fn send_state_chunked_small_payload_sends_screen_state_compressed() {
        let (tx, mut rx) = channel::<EncryptedFrame>(8);
        // Exactly at the limit: still small enough for a single frame.
        let payload = vec![0u8; MAX_STATESYNC_DIFF_BYTES];
        let sent = send_state_chunked(&tx, payload.clone()).await;
        assert!(sent);
        let frame = rx.try_recv().expect("expected a frame");
        assert!(
            matches!(frame, EncryptedFrame::ScreenStateCompressed(ref d) if *d == payload),
            "expected ScreenStateCompressed, got {frame:?}"
        );
        assert!(rx.try_recv().is_err(), "expected exactly one frame");
    }

    #[tokio::test]
    async fn send_state_chunked_large_payload_sends_state_chunks() -> anyhow::Result<()> {
        let (tx, mut rx) = channel::<EncryptedFrame>(128);
        let payload = vec![0xABu8; MAX_STATESYNC_DIFF_BYTES + STATE_CHUNK_SIZE + 1];
        let sent = send_state_chunked(&tx, payload.clone()).await;
        assert!(sent);
        let expected_chunks = payload.chunks(STATE_CHUNK_SIZE).count();
        let total = u16::try_from(expected_chunks)?;
        let mut received = 0usize;
        // Every chunk must carry its index, the total count and the matching slice.
        while let Ok(frame) = rx.try_recv() {
            let EncryptedFrame::StateChunk((seq, t, data)) = frame else {
                panic!("expected StateChunk, got {frame:?}");
            };
            assert_eq!(seq, u16::try_from(received)?);
            assert_eq!(t, total);
            let expected_slice = &payload[received * STATE_CHUNK_SIZE
                ..((received + 1) * STATE_CHUNK_SIZE).min(payload.len())];
            assert_eq!(data, expected_slice);
            received += 1;
        }
        assert_eq!(received, expected_chunks);
        Ok(())
    }

    #[tokio::test]
    async fn send_state_chunked_closed_channel_returns_false() {
        let (tx, rx) = channel::<EncryptedFrame>(8);
        drop(rx);
        let payload = vec![0u8; MAX_STATESYNC_DIFF_BYTES + 1];
        let sent = send_state_chunked(&tx, payload).await;
        assert!(!sent);
    }

    #[test]
    fn server_intercept_queries_no_escape_returns_empty() {
        assert!(server_intercept_queries(b"hello world").is_empty());
    }

    #[test]
    fn server_intercept_queries_primary_da_returns_vt220() {
        assert_eq!(server_intercept_queries(b"\x1b[c"), b"\x1b[?62c");
        assert_eq!(server_intercept_queries(b"\x1b[0c"), b"\x1b[?62c");
    }

    #[test]
    fn server_intercept_queries_secondary_da_returns_response() {
        assert_eq!(server_intercept_queries(b"\x1b[>c"), b"\x1b[>1;10;0c");
        assert_eq!(server_intercept_queries(b"\x1b[>0c"), b"\x1b[>1;10;0c");
    }

    #[test]
    fn server_intercept_queries_unknown_sequence_returns_empty() {
        assert!(server_intercept_queries(b"\x1b[?25h").is_empty());
        assert!(server_intercept_queries(b"\x1b[6n").is_empty());
    }

    #[test]
    fn server_intercept_queries_multiple_queries_returns_both_responses() {
        let input = b"\x1b[c\x1b[>c";
        let resp = server_intercept_queries(input);
        assert!(
            resp.starts_with(b"\x1b[?62c"),
            "missing primary DA response"
        );
        assert!(
            resp.ends_with(b"\x1b[>1;10;0c"),
            "missing secondary DA response"
        );
    }

    // The silence-watchdog tests run with paused tokio time: `advance` fires the
    // watchdog's timers instantly and the `yield_now` calls let the spawned task run.
    #[tokio::test(start_paused = true)]
    async fn silence_watchdog_fires_on_stale_timestamp() {
        let token = CancellationToken::new();
        // A last-receive timestamp of 0 is far older than the silence timeout.
        let last_rx_us = Arc::new(AtomicU64::new(0));
        spawn_silence_watchdog(token.clone(), last_rx_us);
        tokio::time::advance(Duration::from_secs(35)).await;
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;
        assert!(
            token.is_cancelled(),
            "watchdog should have cancelled the token"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn silence_watchdog_does_not_fire_when_recently_active() {
        let token = CancellationToken::new();
        let last_rx_us = Arc::new(AtomicU64::new(now_micros()));
        spawn_silence_watchdog(token.clone(), last_rx_us.clone());
        tokio::time::advance(Duration::from_secs(5)).await;
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;
        assert!(
            !token.is_cancelled(),
            "watchdog should not fire within 30s silence threshold"
        );
        token.cancel();
    }

    #[tokio::test(start_paused = true)]
    async fn silence_watchdog_stops_on_explicit_cancel() {
        let token = CancellationToken::new();
        let last_rx_us = Arc::new(AtomicU64::new(0));
        spawn_silence_watchdog(token.clone(), last_rx_us);
        token.cancel();
        tokio::time::advance(Duration::from_mins(1)).await;
        tokio::task::yield_now().await;
        assert!(token.is_cancelled());
    }
}