use crate::protocol::{self, ServerMsg};
use crate::session::{SessionHandles, SessionManager};
use retach::screen::AnsiRenderer;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use tracing::debug;
use super::session_relay::{client_to_pty, screen_to_client};
use super::session_setup::{setup_session, ConnectRequest};
use super::shared::lock_mutex;
/// Bytes budgeted per history line on top of its payload when sizing a frame.
///
/// NOTE(review): presumably covers the serializer's per-element length prefix
/// with some slack — confirm against the wire format.
const BINCODE_LINE_OVERHEAD: usize = 16;

/// Shortens `line` so that its length plus [`BINCODE_LINE_OVERHEAD`] fits in `limit`.
///
/// Lines that already fit are returned untouched. Otherwise the cut point is
/// moved backwards past any UTF-8 continuation bytes (`0b10xx_xxxx`) so that a
/// multi-byte character is never split; if no such point exists above index 0
/// the result is empty.
fn truncate_history_line(line: Vec<u8>, limit: usize) -> Vec<u8> {
    let budget = limit.saturating_sub(BINCODE_LINE_OVERHEAD);
    if line.len() <= budget {
        return line;
    }

    // `budget < line.len()` here, so every probed index is in bounds.
    // Search downwards for the first byte that starts a character.
    let cut = (1..=budget)
        .rev()
        .find(|&idx| (line[idx] & 0xC0) != 0x80)
        .unwrap_or(0);

    let mut shortened = line;
    shortened.truncate(cut);
    shortened
}
/// Writes `lines` to `writer` as one or more `ServerMsg::History` frames.
///
/// Each line is first passed through `truncate_history_line` and then batched
/// greedily: a batch is flushed as soon as adding the next line would push its
/// estimated size (payload plus `BINCODE_LINE_OVERHEAD` per line) past half of
/// `MAX_FRAME_SIZE`. Nothing is written when `lines` is empty.
///
/// # Errors
/// Returns an error if encoding a frame or writing to the socket fails.
async fn send_history_chunks(
    lines: Vec<Vec<u8>>,
    writer: &mut tokio::net::unix::OwnedWriteHalf,
) -> anyhow::Result<()> {
    let budget = protocol::codec::MAX_FRAME_SIZE / 2;
    let mut batch = Vec::new();
    let mut batch_bytes = 0;

    for raw in lines {
        let line = truncate_history_line(raw, budget);
        let cost = line.len() + BINCODE_LINE_OVERHEAD;

        // Flush before adding, but never emit an empty frame: a single line
        // always goes into a batch even if it alone reaches the budget.
        let would_overflow = batch_bytes + cost > budget;
        if would_overflow && !batch.is_empty() {
            let frame = protocol::encode(&ServerMsg::History(std::mem::take(&mut batch)))?;
            writer.write_all(&frame).await?;
            batch_bytes = 0;
        }

        batch.push(line);
        batch_bytes += cost;
    }

    // Whatever is left over forms the final (possibly only) frame.
    if batch.is_empty() {
        return Ok(());
    }
    let frame = protocol::encode(&ServerMsg::History(batch))?;
    writer.write_all(&frame).await?;
    Ok(())
}
async fn send_initial_state(
handles: &SessionHandles,
is_new_session: bool,
writer: &mut tokio::net::unix::OwnedWriteHalf,
) -> anyhow::Result<AnsiRenderer> {
let connected = protocol::encode(&ServerMsg::Connected {
name: handles.name.clone(),
new_session: is_new_session,
})?;
writer.write_all(&connected).await?;
let mut renderer = AnsiRenderer::new();
let in_alt_screen;
let hist_chunks = {
let mut screen = lock_mutex(&handles.screen, "screen")?;
in_alt_screen = screen.in_alt_screen();
let hist = if in_alt_screen {
Vec::new()
} else {
screen.get_history()
};
screen.discard_pending_scrollback();
hist
};
let had_history = !hist_chunks.is_empty();
send_history_chunks(hist_chunks, writer).await?;
let (pending_chunks, screen_msg) = {
let mut screen = lock_mutex(&handles.screen, "screen")?;
let pending_rows = screen.take_pending_scrollback();
let pending = renderer.render_rows(&*screen, &pending_rows);
let pending = if in_alt_screen { Vec::new() } else { pending };
let any_history = had_history || !pending.is_empty();
let notifications = screen.take_queued_notifications();
screen.take_passthrough();
let mut render_data = Vec::new();
for notif in notifications {
render_data.extend_from_slice(¬if);
}
if any_history {
render_data.extend_from_slice(b"\x1b[");
render_data.extend_from_slice(screen.rows().to_string().as_bytes());
render_data.extend_from_slice(b";1H");
render_data.extend(std::iter::repeat_n(
b'\n',
screen.rows().saturating_sub(1) as usize,
));
}
render_data.extend_from_slice(&renderer.render(&*screen, true));
let screen_msg = protocol::encode(&ServerMsg::ScreenUpdate(render_data))?;
(pending, screen_msg)
};
send_history_chunks(pending_chunks, writer).await?;
writer.write_all(&screen_msg).await?;
Ok(renderer)
}
/// Serves one attached client from session setup until either relay stops.
///
/// After `setup_session` succeeds the stream is split, the initial state is
/// sent on the write half, and two tasks are spawned: `screen_to_client`
/// (owns the write half) and `client_to_pty` (owns the read half). Whichever
/// task finishes first causes the other to be aborted, and its result —
/// join error or task error — becomes this function's result.
///
/// # Errors
/// Propagates failures from setup, from sending the initial state, and from
/// the relay task that finished first.
pub(super) async fn handle_session(
    mut stream: tokio::net::UnixStream,
    manager: Arc<Mutex<SessionManager>>,
    req: ConnectRequest,
) -> anyhow::Result<()> {
    let session = setup_session(
        &mut stream,
        &manager,
        &req.name,
        req.history,
        req.cols,
        req.rows,
        req.mode,
    )
    .await?;

    // Bound to a named local so the guard lives until this function returns.
    let _client_guard = session.client_guard;

    let (read_half, mut write_half) = stream.into_split();
    let renderer =
        send_initial_state(&session.handles, session.is_new_session, &mut write_half).await?;

    let refresh = Arc::new(tokio::sync::Notify::new());
    session.handles.screen_notify.notify_one();

    let mut outbound = tokio::spawn(screen_to_client(
        session.handles.clone(),
        renderer,
        Arc::clone(&refresh),
        session.evict_rx,
        write_half,
    ));
    let mut inbound = tokio::spawn(client_to_pty(
        session.handles,
        read_half,
        refresh,
        req.leftover,
    ));

    // First task to complete wins; the sibling is cancelled before the
    // winner's result is propagated.
    tokio::select! {
        joined = &mut outbound => {
            debug!(
                "screen_to_client finished: {:?}",
                joined.as_ref().map(|inner| inner.as_ref().map(|_| "ok"))
            );
            inbound.abort();
            joined??;
        }
        joined = &mut inbound => {
            debug!(
                "client_to_pty finished: {:?}",
                joined.as_ref().map(|inner| inner.as_ref().map(|_| "ok"))
            );
            outbound.abort();
            joined??;
        }
    }

    Ok(())
}
// Unit tests for the initial-state handshake and history truncation helpers.
#[cfg(test)]
mod tests {
use super::*;
use retach::screen::{AnsiRenderer, Screen};
use super::super::shared::prepend_passthrough;
/// Reads `reader` until `fill_from` reports no more data and returns every
/// `ServerMsg` that could be decoded, in arrival order.
///
/// Panics (via `unwrap`) on any read or decode error.
async fn drain_server_msgs(reader: tokio::net::UnixStream) -> Vec<ServerMsg> {
use crate::protocol::FrameReader;
let mut reader = reader;
let mut frames = FrameReader::new();
let mut out = Vec::new();
// Decode eagerly after each fill so the frame buffer is emptied as data arrives.
while frames.fill_from(&mut reader).await.unwrap() {
while let Some(msg) = frames.decode_next::<ServerMsg>().unwrap() {
out.push(msg);
}
}
// Pick up any frames still buffered after the last fill returned false.
while let Some(msg) = frames.decode_next::<ServerMsg>().unwrap() {
out.push(msg);
}
out
}
// Lines that scrolled off a small screen must reach the client as `History`
// frames, and a `ScreenUpdate` must be sent as well.
#[tokio::test]
async fn send_initial_state_delivers_scrollback_history() {
use crate::session::Session;
// NOTE(review): presumably 80 columns x 3 rows with a 1000-line history
// cap — confirm against `Session::new`.
let mut session = Session::new("bridge-hist".into(), 80, 3, 1000).unwrap();
{
let scr = session.screen.clone();
let mut s = scr.lock().unwrap();
// Ten lines into the screen created above, so the early ones are expected
// to end up in scrollback.
for i in 0..10 {
s.process(format!("LINE{}\r\n", i).as_bytes());
}
}
let (_guard, handles, _evict_rx) = session.connect();
let (client, server) = tokio::net::UnixStream::pair().unwrap();
let (_r, mut w) = server.into_split();
send_initial_state(&handles, false, &mut w).await.unwrap();
// Drop both server halves before draining; presumably this is what lets the
// client-side read loop observe end-of-stream and terminate.
drop(w);
drop(_r);
let msgs = drain_server_msgs(client).await;
let mut history_lines: Vec<String> = Vec::new();
let mut saw_screen = false;
for m in &msgs {
match m {
ServerMsg::History(lines) => {
for l in lines {
history_lines.push(String::from_utf8_lossy(l).into_owned());
}
}
ServerMsg::ScreenUpdate(_) => saw_screen = true,
_ => {}
}
}
// Only presence is checked here; relative ordering of History and
// ScreenUpdate frames is not asserted.
assert!(saw_screen, "a ScreenUpdate must follow the history");
// LINE0..LINE6 are checked; the remaining lines are not asserted on.
for i in 0..7 {
let needle = format!("LINE{}", i);
assert!(
history_lines.iter().any(|h| h.contains(&needle)),
"scrollback line {} missing from History: {:?}",
needle,
history_lines
);
}
}
// `ESC [ 3 J` fed to the screen must surface as a single passthrough entry,
// and `prepend_passthrough` must place it in front of the rendered output.
#[test]
fn ed3_included_in_screen_update() {
let mut screen = Screen::new(80, 24, 100);
screen.process(b"hello world");
screen.process(b"\x1b[3J");
let passthrough = screen.take_passthrough();
assert_eq!(passthrough.len(), 1);
assert_eq!(passthrough[0], b"\x1b[3J");
let mut renderer = AnsiRenderer::new();
let render_data = renderer.render(&screen, true);
let combined = prepend_passthrough(passthrough, render_data.clone());
assert!(
combined.starts_with(b"\x1b[3J"),
"passthrough should prefix screen data"
);
// The sequence is 4 bytes long; everything after it must be the render, unmodified.
assert_eq!(&combined[4..], &render_data[..]);
}
// A line far over the limit must be cut on a UTF-8 boundary and still
// encode into a frame no larger than `MAX_FRAME_SIZE`.
#[test]
fn oversized_history_line_truncated_to_valid_utf8() {
let limit = protocol::codec::MAX_FRAME_SIZE / 2;
// "é" is two bytes in UTF-8, so this line is 2 * limit bytes long.
let line: Vec<u8> = "é".repeat(limit).into_bytes();
assert!(line.len() > limit);
let truncated = truncate_history_line(line, limit);
assert!(truncated.len() + BINCODE_LINE_OVERHEAD <= limit);
assert!(
std::str::from_utf8(&truncated).is_ok(),
"truncation must land on a UTF-8 boundary"
);
let frame = protocol::encode(&ServerMsg::History(vec![truncated])).unwrap();
assert!(frame.len() <= protocol::codec::MAX_FRAME_SIZE);
}
// A line well under the limit must come back byte-for-byte unchanged.
#[test]
fn small_history_line_not_truncated() {
let limit = protocol::codec::MAX_FRAME_SIZE / 2;
let line = b"hello world".to_vec();
assert_eq!(truncate_history_line(line.clone(), limit), line);
}
}