use crate::protocol::{self, ClientMsg, FrameReader, ServerMsg};
use crate::session::SessionHandles;
use std::io::Write;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tracing::debug;
use super::session_setup::resize_pty;
use super::shared::{
lock_mutex, prepend_passthrough, render_and_send, store_dims, RENDER_THROTTLE,
};
/// Returns the exit code stored in `h.exit_code`, or `None` if none has been
/// recorded.
///
/// The guarded value is a plain `Option<i32>`, so a panic in another thread
/// while it held the lock cannot leave it half-written. A poisoned lock is
/// therefore recovered rather than reported as "no exit code", which would
/// silently drop the real exit status sent to the client.
fn session_exit_code(h: &SessionHandles) -> Option<i32> {
    *h.exit_code
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}
pub(super) async fn screen_to_client(
h: SessionHandles,
mut renderer: retach::screen::AnsiRenderer,
refresh_notify: Arc<tokio::sync::Notify>,
mut evict_rx: tokio::sync::watch::Receiver<bool>,
mut writer: tokio::net::unix::OwnedWriteHalf,
) -> anyhow::Result<()> {
use std::pin::pin;
use std::time::Duration;
use tokio::time::Instant;
if !h.reader_alive.load(Ordering::Acquire) {
render_and_send(&h.screen, &mut renderer, &mut writer).await?;
let msg = protocol::encode(&ServerMsg::SessionEnded {
exit_code: session_exit_code(&h),
})?;
writer.write_all(&msg).await?;
return Ok(());
}
let mut throttle_sleep = pin!(tokio::time::sleep(Duration::ZERO));
let mut pending_render = false;
loop {
tokio::select! {
_ = h.screen_notify.notified() => {
if !h.reader_alive.load(Ordering::Acquire) {
let (render_data, passthrough) = {
let mut screen = lock_mutex(&h.screen, "screen")?;
renderer.take_and_render(&mut *screen)
};
let update = prepend_passthrough(passthrough, render_data);
let msg = protocol::encode(&ServerMsg::ScreenUpdate(update))?;
writer.write_all(&msg).await?;
let msg = protocol::encode(&ServerMsg::SessionEnded {
exit_code: session_exit_code(&h),
})?;
writer.write_all(&msg).await?;
break;
}
pending_render = true;
throttle_sleep.as_mut().reset(Instant::now() + RENDER_THROTTLE);
}
_ = &mut throttle_sleep, if pending_render => {
let (render_data, passthrough) = {
let mut screen = lock_mutex(&h.screen, "screen")?;
renderer.take_and_render(&mut *screen)
};
let update = prepend_passthrough(passthrough, render_data);
if !update.is_empty() {
let msg = protocol::encode(&ServerMsg::ScreenUpdate(update))?;
writer.write_all(&msg).await?;
}
pending_render = false;
}
_ = refresh_notify.notified() => {
render_and_send(&h.screen, &mut renderer, &mut writer).await?;
}
result = evict_rx.changed() => {
match result {
Ok(()) => {
debug!(session = %h.name, "client evicted by new connection");
let msg = protocol::encode(&ServerMsg::Error("evicted by new client".into()))?;
if let Err(e) = writer.write_all(&msg).await {
debug!(session = %h.name, error = %e, "failed to send eviction notice to client");
}
}
Err(_) => {
debug!(session = %h.name, "session killed while client connected");
let msg = protocol::encode(&ServerMsg::SessionEnded { exit_code: None })?;
if let Err(e) = writer.write_all(&msg).await {
debug!(session = %h.name, error = %e, "failed to send session-ended to killed client");
}
}
}
break;
}
}
}
Ok(())
}
/// Forwards framed client messages to the session until the client detaches
/// or its socket closes. Both cases return `Ok(())`.
///
/// * `Input`: the bytes are written and flushed to `h.pty_writer` on a
///   blocking thread.
/// * `Resize`: calls `resize_pty` and records the new size with `store_dims`,
///   also on a blocking thread.
/// * `RefreshScreen`: signals `refresh_notify`, presumably the same `Notify`
///   the screen relay waits on; confirm at the call site.
/// * `Detach`: returns immediately; any frames still buffered are ignored.
/// * `Connect` / `ListSessions` / `KillSession`: not meaningful inside a
///   session relay, so they are logged and ignored.
///
/// `leftover` seeds the frame reader with bytes already read from the socket
/// before this call, presumably during the connect handshake.
///
/// # Errors
/// Propagates the following:
/// * socket read and frame decode errors;
/// * PTY write/flush and resize errors;
/// * `JoinError` if a blocking task panics or is cancelled.
pub(super) async fn client_to_pty(
    h: SessionHandles,
    mut sock_reader: tokio::net::unix::OwnedReadHalf,
    refresh_notify: Arc<tokio::sync::Notify>,
    leftover: Vec<u8>,
) -> anyhow::Result<()> {
    let mut frames = FrameReader::with_leftover(leftover);
    loop {
        if !frames.fill_from(&mut sock_reader).await? {
            debug!(session = %h.name, "client socket closed");
            break;
        }
        while let Some(msg) = frames.decode_next::<ClientMsg>()? {
            match msg {
                ClientMsg::Input(input) => {
                    // PTY writes may block, so keep them off the async worker threads.
                    let pw = h.pty_writer.clone();
                    tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
                        let mut w = lock_mutex(&pw, "pty_writer")?;
                        w.write_all(&input)?;
                        w.flush()?;
                        Ok(())
                    })
                    .await??;
                }
                ClientMsg::Resize { cols, rows } => {
                    let master_clone = h.master.clone();
                    let screen_clone = h.screen.clone();
                    let dims_clone = h.dims.clone();
                    let name_clone = h.name.clone();
                    tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
                        resize_pty(&master_clone, &screen_clone, cols, rows)?;
                        store_dims(&dims_clone, cols, rows, &name_clone);
                        Ok(())
                    })
                    .await??;
                }
                ClientMsg::RefreshScreen => {
                    refresh_notify.notify_one();
                }
                ClientMsg::Detach => {
                    debug!(session = %h.name, "client detached");
                    return Ok(());
                }
                ClientMsg::Connect { .. }
                | ClientMsg::ListSessions { .. }
                | ClientMsg::KillSession { .. } => {
                    debug!(session = %h.name, "ignoring unexpected client message in session relay");
                }
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pty::Pty;
    use retach::screen::{AnsiRenderer, Screen};
    use std::sync::atomic::AtomicBool;
    use std::sync::Mutex as StdMutex;
    use std::time::Duration;

    /// Builds session handles around a real 80x24 PTY and an empty screen.
    ///
    /// Returns `(pty, handles, refresh_notify, evict_tx, evict_rx)`. The `Pty`
    /// is returned so the caller can keep it alive for the whole test.
    ///
    /// `refresh_notify` is its own `Notify`, deliberately distinct from
    /// `handles.screen_notify`. `screen_to_client` waits on both, and if they
    /// were the same `Notify`, a `notify_one` meant for the screen branch
    /// could be consumed by the refresh branch instead.
    ///
    /// The watch channel starts at `true`, so `evict_rx.changed()` only
    /// resolves once the test sends a new value or drops `evict_tx`.
    fn test_handles(
        reader_alive: bool,
        exit_code: Option<i32>,
    ) -> (
        Pty,
        SessionHandles,
        Arc<tokio::sync::Notify>,
        tokio::sync::watch::Sender<bool>,
        tokio::sync::watch::Receiver<bool>,
    ) {
        let pty = Pty::spawn(80, 24).unwrap();
        let screen = Arc::new(StdMutex::new(Screen::new(80, 24, 1000)));
        let screen_notify = Arc::new(tokio::sync::Notify::new());
        let refresh_notify = Arc::new(tokio::sync::Notify::new());
        let (evict_tx, evict_rx) = tokio::sync::watch::channel(true);
        let handles = SessionHandles {
            screen,
            pty_writer: pty.writer_arc(),
            master: pty.master_arc(),
            dims: Arc::new(StdMutex::new(retach::screen::TerminalSize {
                cols: 80,
                rows: 24,
            })),
            screen_notify,
            reader_alive: Arc::new(AtomicBool::new(reader_alive)),
            exit_code: Arc::new(StdMutex::new(exit_code)),
            name: "relay-test".into(),
        };
        (pty, handles, refresh_notify, evict_tx, evict_rx)
    }

    /// Returns true if `needle` occurs contiguously in `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        needle.is_empty() || haystack.windows(needle.len()).any(|w| w == needle)
    }

    /// Collects every server message, in order, until the socket reaches EOF.
    ///
    /// EOF arrives once the relay task returns and drops its write half.
    async fn drain_server_msgs(client: tokio::net::UnixStream) -> Vec<ServerMsg> {
        let (mut reader, _w) = client.into_split();
        let mut frames = FrameReader::new();
        let mut out = Vec::new();
        while frames.fill_from(&mut reader).await.unwrap_or(false) {
            while let Some(msg) = frames.decode_next::<ServerMsg>().unwrap() {
                out.push(msg);
            }
        }
        while let Some(msg) = frames.decode_next::<ServerMsg>().unwrap() {
            out.push(msg);
        }
        out
    }

    #[tokio::test]
    async fn screen_to_client_dead_reader_sends_session_ended_with_exit_code() {
        let (_pty, handles, refresh_notify, _evict_tx, evict_rx) = test_handles(false, Some(7));
        {
            let mut scr = handles.screen.lock().unwrap();
            scr.process(b"FINAL OUTPUT");
        }
        let (client, server) = tokio::net::UnixStream::pair().unwrap();
        let (_r, w) = server.into_split();
        let task = tokio::spawn(screen_to_client(
            handles,
            AnsiRenderer::new(),
            refresh_notify,
            evict_rx,
            w,
        ));
        let msgs = drain_server_msgs(client).await;
        task.await.unwrap().unwrap();
        assert!(
            matches!(
                msgs.last(),
                Some(ServerMsg::SessionEnded { exit_code: Some(7) })
            ),
            "last message must be SessionEnded with the captured exit code: {:?}",
            msgs
        );
        assert!(
            msgs.iter().any(|m| matches!(m, ServerMsg::ScreenUpdate(_))),
            "a final ScreenUpdate must precede SessionEnded: {:?}",
            msgs
        );
    }

    #[tokio::test]
    async fn screen_to_client_reader_dies_while_attached() {
        let (_pty, handles, refresh_notify, _evict_tx, evict_rx) = test_handles(true, Some(3));
        let reader_alive = handles.reader_alive.clone();
        let screen_notify = handles.screen_notify.clone();
        let screen = handles.screen.clone();
        let (client, server) = tokio::net::UnixStream::pair().unwrap();
        let (_r, w) = server.into_split();
        let task = tokio::spawn(screen_to_client(
            handles,
            AnsiRenderer::new(),
            refresh_notify,
            evict_rx,
            w,
        ));
        // `#[tokio::test]` uses a current-thread runtime, so the spawned task only
        // runs once this task yields. Without this yield, `reader_alive` is already
        // false when the relay starts, and it takes the "reader already dead" early
        // return instead of the in-loop path this test is meant to cover.
        tokio::task::yield_now().await;
        {
            let mut scr = screen.lock().unwrap();
            scr.process(b"LAST_WORDS");
        }
        reader_alive.store(false, Ordering::Release);
        screen_notify.notify_one();
        let msgs = drain_server_msgs(client).await;
        task.await.unwrap().unwrap();
        assert!(
            matches!(
                msgs.last(),
                Some(ServerMsg::SessionEnded { exit_code: Some(3) })
            ),
            "reader death while attached must end with SessionEnded(exit 3): {:?}",
            msgs
        );
        assert!(
            msgs.iter().any(|m| matches!(
                m,
                ServerMsg::ScreenUpdate(data) if contains_bytes(data, b"LAST_WORDS")
            )),
            "output written just before the reader died must be delivered: {:?}",
            msgs
        );
    }

    #[tokio::test]
    async fn screen_to_client_eviction_notifies_old_client() {
        let (_pty, handles, refresh_notify, evict_tx, evict_rx) = test_handles(true, None);
        let (client, server) = tokio::net::UnixStream::pair().unwrap();
        let (_r, w) = server.into_split();
        let task = tokio::spawn(screen_to_client(
            handles,
            AnsiRenderer::new(),
            refresh_notify,
            evict_rx,
            w,
        ));
        evict_tx.send(false).unwrap();
        let msgs = drain_server_msgs(client).await;
        task.await.unwrap().unwrap();
        assert!(
            msgs.iter()
                .any(|m| matches!(m, ServerMsg::Error(s) if s.contains("evicted"))),
            "evicted client must receive an eviction Error: {:?}",
            msgs
        );
    }

    #[tokio::test]
    async fn client_to_pty_detach_terminates_cleanly() {
        let (_pty, handles, refresh_notify, _evict_tx, _evict_rx) = test_handles(true, None);
        let (client, server) = tokio::net::UnixStream::pair().unwrap();
        let (sock_reader, _w) = server.into_split();
        let (_cr, mut cw) = client.into_split();
        let task = tokio::spawn(client_to_pty(
            handles,
            sock_reader,
            refresh_notify,
            Vec::new(),
        ));
        let msg = protocol::encode(&ClientMsg::Detach).unwrap();
        cw.write_all(&msg).await.unwrap();
        let result = tokio::time::timeout(Duration::from_secs(2), task)
            .await
            .expect("client_to_pty must return promptly on Detach");
        result.unwrap().unwrap();
    }

    #[tokio::test]
    async fn screen_to_client_delivers_throttled_render() {
        let (_pty, handles, refresh_notify, evict_tx, evict_rx) = test_handles(true, None);
        let screen = handles.screen.clone();
        let screen_notify = handles.screen_notify.clone();
        let (client, server) = tokio::net::UnixStream::pair().unwrap();
        let (_r, w) = server.into_split();
        let task = tokio::spawn(screen_to_client(
            handles,
            AnsiRenderer::new(),
            refresh_notify,
            evict_rx,
            w,
        ));
        {
            let mut scr = screen.lock().unwrap();
            scr.process(b"HELLO_RELAY");
        }
        screen_notify.notify_one();
        let (mut reader, _cw) = client.into_split();
        let mut frames = FrameReader::new();
        let mut got = false;
        let deadline = tokio::time::Instant::now() + Duration::from_secs(3);
        'outer: while tokio::time::Instant::now() < deadline {
            let fill =
                tokio::time::timeout(Duration::from_millis(500), frames.fill_from(&mut reader));
            if let Ok(Ok(true)) = fill.await {
                while let Some(msg) = frames.decode_next::<ServerMsg>().unwrap() {
                    if let ServerMsg::ScreenUpdate(data) = msg {
                        if contains_bytes(&data, b"HELLO_RELAY") {
                            got = true;
                            break 'outer;
                        }
                    }
                }
            }
        }
        assert!(
            got,
            "client must receive a ScreenUpdate carrying the rendered bytes"
        );
        evict_tx.send(false).unwrap();
        let _ = task.await;
    }
}