use anyhow::{Context, Result};
use crossterm::event::{Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers, MouseEvent};
use russh::{client::Msg, Channel, ChannelMsg};
use smallvec::SmallVec;
use std::io::{self, Write};
use tokio::sync::{mpsc, watch};
use tokio::time::Duration;
use super::{
terminal::{TerminalOps, TerminalStateGuard},
PtyConfig, PtyMessage, PtyState,
};
// Sizing constants. They are not referenced in the visible part of this file,
// hence the `dead_code` allowances.

/// Upper bound, in bytes, for a single encoded key sequence.
#[allow(dead_code)]
const MAX_KEY_SEQUENCE_SIZE: usize = 8;
/// Buffer size, in bytes, for SSH I/O.
#[allow(dead_code)]
const SSH_IO_BUFFER_SIZE: usize = 4096;
/// Chunk size, in bytes, for terminal output.
#[allow(dead_code)]
const TERMINAL_OUTPUT_CHUNK_SIZE: usize = 1024;

// Single-byte control codes sent for Ctrl+<letter>.
const CTRL_C_SEQUENCE: &[u8] = b"\x03";
const CTRL_D_SEQUENCE: &[u8] = b"\x04";
const CTRL_Z_SEQUENCE: &[u8] = b"\x1a";
const CTRL_A_SEQUENCE: &[u8] = b"\x01";
const CTRL_E_SEQUENCE: &[u8] = b"\x05";
const CTRL_U_SEQUENCE: &[u8] = b"\x15";
const CTRL_K_SEQUENCE: &[u8] = b"\x0b";
const CTRL_W_SEQUENCE: &[u8] = b"\x17";
const CTRL_L_SEQUENCE: &[u8] = b"\x0c";
const CTRL_R_SEQUENCE: &[u8] = b"\x12";

// Basic editing keys.
const ENTER_SEQUENCE: &[u8] = b"\r";
const TAB_SEQUENCE: &[u8] = b"\t";
const BACKSPACE_SEQUENCE: &[u8] = b"\x7f";
const ESC_SEQUENCE: &[u8] = b"\x1b";

// Cursor keys: ESC [ <letter>.
const UP_ARROW_SEQUENCE: &[u8] = b"\x1b[A";
const DOWN_ARROW_SEQUENCE: &[u8] = b"\x1b[B";
const RIGHT_ARROW_SEQUENCE: &[u8] = b"\x1b[C";
const LEFT_ARROW_SEQUENCE: &[u8] = b"\x1b[D";

// Function keys: F1-F4 use ESC O <letter>, F5-F12 use ESC [ <number> ~.
const F1_SEQUENCE: &[u8] = b"\x1bOP";
const F2_SEQUENCE: &[u8] = b"\x1bOQ";
const F3_SEQUENCE: &[u8] = b"\x1bOR";
const F4_SEQUENCE: &[u8] = b"\x1bOS";
const F5_SEQUENCE: &[u8] = b"\x1b[15~";
const F6_SEQUENCE: &[u8] = b"\x1b[17~";
const F7_SEQUENCE: &[u8] = b"\x1b[18~";
const F8_SEQUENCE: &[u8] = b"\x1b[19~";
const F9_SEQUENCE: &[u8] = b"\x1b[20~";
const F10_SEQUENCE: &[u8] = b"\x1b[21~";
const F11_SEQUENCE: &[u8] = b"\x1b[23~";
const F12_SEQUENCE: &[u8] = b"\x1b[24~";

// Navigation and editing block.
const HOME_SEQUENCE: &[u8] = b"\x1b[H";
const END_SEQUENCE: &[u8] = b"\x1b[F";
const PAGE_UP_SEQUENCE: &[u8] = b"\x1b[5~";
const PAGE_DOWN_SEQUENCE: &[u8] = b"\x1b[6~";
const INSERT_SEQUENCE: &[u8] = b"\x1b[2~";
const DELETE_SEQUENCE: &[u8] = b"\x1b[3~";
/// An interactive PTY session layered on a single SSH channel.
///
/// The session owns the channel, tracks a lifecycle [`PtyState`], and holds
/// the two internal channels that `run` uses to talk to its background tasks:
/// a `watch` channel carrying a cancellation flag and an `mpsc` queue of
/// [`PtyMessage`]s.
pub struct PtySession {
/// Caller-supplied identifier; only used in log messages in this file.
pub session_id: usize,
/// SSH channel on which the PTY and shell are requested and data is exchanged.
channel: Channel<Msg>,
/// Session settings; this file reads `term_type` and `enable_mouse`.
config: PtyConfig,
/// Current lifecycle state (starts as `PtyState::Inactive`).
state: PtyState,
/// Terminal-state guard held while `run` is active; cleared (set to `None`)
/// when the session ends.
terminal_guard: Option<TerminalStateGuard>,
/// Sender side of the cancellation flag; `true` means "stop".
cancel_tx: watch::Sender<bool>,
/// Receiver side of the cancellation flag; cloned for each background task.
cancel_rx: watch::Receiver<bool>,
/// Sender side of the internal message queue; cloned for background tasks.
msg_tx: Option<mpsc::Sender<PtyMessage>>,
/// Receiver side of the internal message queue; taken by `run`.
msg_rx: Option<mpsc::Receiver<PtyMessage>>,
}
impl PtySession {
/// Creates an inactive PTY session bound to `channel`.
///
/// Sets up the bounded internal message queue and the cancellation flag
/// (initially `false`). No SSH traffic happens here; the PTY itself is
/// requested later by `initialize`.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub async fn new(session_id: usize, channel: Channel<Msg>, config: PtyConfig) -> Result<Self> {
    /// Capacity of the internal `PtyMessage` queue.
    const PTY_MESSAGE_CHANNEL_SIZE: usize = 256;

    let (message_sender, message_receiver) = mpsc::channel(PTY_MESSAGE_CHANNEL_SIZE);
    let (cancel_tx, cancel_rx) = watch::channel(false);

    let session = Self {
        session_id,
        channel,
        config,
        state: PtyState::Inactive,
        terminal_guard: None,
        cancel_tx,
        cancel_rx,
        msg_tx: Some(message_sender),
        msg_rx: Some(message_receiver),
    };
    Ok(session)
}
/// Returns the session's current lifecycle state.
pub fn state(&self) -> PtyState {
self.state
}
/// Requests a PTY and a shell on the SSH channel and marks the session active.
///
/// The PTY is requested with the configured terminal type and the current
/// local terminal size. The state is `Initializing` while the requests are in
/// flight and only becomes `Active` once both have been sent successfully.
///
/// # Errors
///
/// Fails if the local terminal size cannot be determined, or if the PTY or
/// shell request on the channel fails. The state is left at `Initializing`
/// in that case.
pub async fn initialize(&mut self) -> Result<()> {
    self.state = PtyState::Initializing;

    let (cols, rows) = super::utils::get_terminal_size()?;

    // Neither request asks the server for a reply.
    let want_reply = false;

    // The trailing `0, 0, &[]` are presumably pixel dimensions and terminal
    // modes, per russh's `request_pty` signature - TODO confirm.
    let pty_request = self
        .channel
        .request_pty(want_reply, &self.config.term_type, cols, rows, 0, 0, &[]);
    pty_request
        .await
        .context("Failed to request PTY on SSH channel")?;

    let shell_request = self.channel.request_shell(want_reply);
    shell_request
        .await
        .context("Failed to request shell on SSH channel")?;

    self.state = PtyState::Active;
    tracing::debug!("PTY session {} initialized", self.session_id);
    Ok(())
}
pub async fn run(&mut self) -> Result<()> {
if self.state == PtyState::Inactive {
self.initialize().await?;
}
if self.state != PtyState::Active {
anyhow::bail!("PTY session is not in active state");
}
self.terminal_guard = Some(TerminalStateGuard::new()?);
if self.config.enable_mouse {
TerminalOps::enable_mouse()?;
}
let mut msg_rx = self
.msg_rx
.take()
.ok_or_else(|| anyhow::anyhow!("Message receiver already taken"))?;
let mut resize_signals = super::utils::setup_resize_handler()?;
let cancel_for_resize = self.cancel_rx.clone();
let resize_tx = self
.msg_tx
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Message sender not available"))?
.clone();
let resize_task = tokio::spawn(async move {
let mut cancel_for_resize = cancel_for_resize;
loop {
tokio::select! {
signal = async {
for signal in resize_signals.forever() {
if signal == signal_hook::consts::SIGWINCH {
return signal;
}
}
signal_hook::consts::SIGWINCH } => {
if signal == signal_hook::consts::SIGWINCH {
if let Ok((width, height)) = super::utils::get_terminal_size() {
if resize_tx.try_send(PtyMessage::Resize { width, height }).is_err() {
break;
}
}
}
}
_ = cancel_for_resize.changed() => {
if *cancel_for_resize.borrow() {
break;
}
}
}
}
});
let input_tx = self
.msg_tx
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Message sender not available"))?
.clone();
let cancel_for_input = self.cancel_rx.clone();
let input_task = tokio::task::spawn_blocking(move || {
loop {
if *cancel_for_input.borrow() {
break;
}
const INPUT_POLL_TIMEOUT_MS: u64 = 500;
let poll_timeout = Duration::from_millis(INPUT_POLL_TIMEOUT_MS);
if crossterm::event::poll(poll_timeout).unwrap_or(false) {
match crossterm::event::read() {
Ok(event) => {
if let Some(data) = Self::handle_input_event(event) {
if input_tx.try_send(PtyMessage::LocalInput(data)).is_err() {
break;
}
}
}
Err(e) => {
let _ =
input_tx.try_send(PtyMessage::Error(format!("Input error: {e}")));
break;
}
}
}
}
});
let mut should_terminate = false;
let mut cancel_rx = self.cancel_rx.clone();
while !should_terminate {
tokio::select! {
msg = self.channel.wait() => {
match msg {
Some(ChannelMsg::Data { ref data }) => {
if let Err(e) = io::stdout().write_all(data) {
tracing::error!("Failed to write to stdout: {e}");
should_terminate = true;
} else {
let _ = io::stdout().flush();
}
}
Some(ChannelMsg::ExtendedData { ref data, ext }) => {
if ext == 1 {
if let Err(e) = io::stdout().write_all(data) {
tracing::error!("Failed to write stderr to stdout: {e}");
should_terminate = true;
} else {
let _ = io::stdout().flush();
}
}
}
Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) => {
tracing::debug!("SSH channel closed");
let _ = self.cancel_tx.send(true);
should_terminate = true;
}
Some(_) => {
}
None => {
should_terminate = true;
}
}
}
message = msg_rx.recv() => {
match message {
Some(PtyMessage::LocalInput(data)) => {
if let Err(e) = self.channel.data(data.as_slice()).await {
tracing::error!("Failed to send data to SSH channel: {e}");
should_terminate = true;
}
}
Some(PtyMessage::RemoteOutput(data)) => {
if let Err(e) = io::stdout().write_all(&data) {
tracing::error!("Failed to write to stdout: {e}");
should_terminate = true;
} else {
let _ = io::stdout().flush();
}
}
Some(PtyMessage::Resize { width, height }) => {
if let Err(e) = self.channel.window_change(width, height, 0, 0).await {
tracing::warn!("Failed to send window resize to remote: {e}");
} else {
tracing::debug!("Terminal resized to {width}x{height}");
}
}
Some(PtyMessage::Terminate) => {
tracing::debug!("PTY session {} terminating", self.session_id);
should_terminate = true;
}
Some(PtyMessage::Error(error)) => {
tracing::error!("PTY error: {error}");
should_terminate = true;
}
None => {
should_terminate = true;
}
}
}
_ = cancel_rx.changed() => {
if *cancel_rx.borrow() {
tracing::debug!("PTY session {} received cancellation signal", self.session_id);
should_terminate = true;
}
}
}
}
let _ = self.cancel_tx.send(true);
const TASK_CLEANUP_TIMEOUT_MS: u64 = 100;
let _ = tokio::time::timeout(Duration::from_millis(TASK_CLEANUP_TIMEOUT_MS), async {
tokio::select! {
_ = resize_task => {},
_ = input_task => {},
_ = tokio::time::sleep(Duration::from_millis(TASK_CLEANUP_TIMEOUT_MS)) => {
}
}
})
.await;
if self.config.enable_mouse {
let _ = TerminalOps::disable_mouse();
}
self.terminal_guard = None;
let _ = io::stdout().flush();
self.state = PtyState::Closed;
Ok(())
}
/// Converts a local terminal event into the bytes to send to the remote side.
///
/// Only key *press* events and mouse events are translated; key release and
/// repeat events, resize events and every other event kind yield `None`.
fn handle_input_event(event: Event) -> Option<SmallVec<[u8; 8]>> {
    match event {
        Event::Key(pressed) if pressed.kind == KeyEventKind::Press => {
            Self::key_event_to_bytes(pressed)
        }
        Event::Mouse(mouse) => Self::mouse_event_to_bytes(mouse),
        // Non-press key events, resize events and anything else are not forwarded.
        _ => None,
    }
}
fn key_event_to_bytes(key_event: KeyEvent) -> Option<SmallVec<[u8; 8]>> {
match key_event {
KeyEvent {
code: KeyCode::Char(c),
modifiers: KeyModifiers::CONTROL,
..
} => {
match c {
'c' | 'C' => Some(SmallVec::from_slice(CTRL_C_SEQUENCE)), 'd' | 'D' => Some(SmallVec::from_slice(CTRL_D_SEQUENCE)), 'z' | 'Z' => Some(SmallVec::from_slice(CTRL_Z_SEQUENCE)), 'a' | 'A' => Some(SmallVec::from_slice(CTRL_A_SEQUENCE)), 'e' | 'E' => Some(SmallVec::from_slice(CTRL_E_SEQUENCE)), 'u' | 'U' => Some(SmallVec::from_slice(CTRL_U_SEQUENCE)), 'k' | 'K' => Some(SmallVec::from_slice(CTRL_K_SEQUENCE)), 'w' | 'W' => Some(SmallVec::from_slice(CTRL_W_SEQUENCE)), 'l' | 'L' => Some(SmallVec::from_slice(CTRL_L_SEQUENCE)), 'r' | 'R' => Some(SmallVec::from_slice(CTRL_R_SEQUENCE)), _ => {
let byte = (c.to_ascii_lowercase() as u8).saturating_sub(b'a' - 1);
if byte <= 26 {
Some(SmallVec::from_slice(&[byte]))
} else {
None
}
}
}
}
KeyEvent {
code: KeyCode::Char(c),
modifiers: KeyModifiers::NONE,
..
} => {
let bytes = c.to_string().into_bytes();
Some(SmallVec::from_slice(&bytes))
}
KeyEvent {
code: KeyCode::Enter,
..
} => Some(SmallVec::from_slice(ENTER_SEQUENCE)),
KeyEvent {
code: KeyCode::Tab, ..
} => Some(SmallVec::from_slice(TAB_SEQUENCE)),
KeyEvent {
code: KeyCode::Backspace,
..
} => Some(SmallVec::from_slice(BACKSPACE_SEQUENCE)),
KeyEvent {
code: KeyCode::Esc, ..
} => Some(SmallVec::from_slice(ESC_SEQUENCE)),
KeyEvent {
code: KeyCode::Up, ..
} => Some(SmallVec::from_slice(UP_ARROW_SEQUENCE)),
KeyEvent {
code: KeyCode::Down,
..
} => Some(SmallVec::from_slice(DOWN_ARROW_SEQUENCE)),
KeyEvent {
code: KeyCode::Right,
..
} => Some(SmallVec::from_slice(RIGHT_ARROW_SEQUENCE)),
KeyEvent {
code: KeyCode::Left,
..
} => Some(SmallVec::from_slice(LEFT_ARROW_SEQUENCE)),
KeyEvent {
code: KeyCode::F(n),
..
} => {
match n {
1 => Some(SmallVec::from_slice(F1_SEQUENCE)), 2 => Some(SmallVec::from_slice(F2_SEQUENCE)), 3 => Some(SmallVec::from_slice(F3_SEQUENCE)), 4 => Some(SmallVec::from_slice(F4_SEQUENCE)), 5 => Some(SmallVec::from_slice(F5_SEQUENCE)), 6 => Some(SmallVec::from_slice(F6_SEQUENCE)), 7 => Some(SmallVec::from_slice(F7_SEQUENCE)), 8 => Some(SmallVec::from_slice(F8_SEQUENCE)), 9 => Some(SmallVec::from_slice(F9_SEQUENCE)), 10 => Some(SmallVec::from_slice(F10_SEQUENCE)), 11 => Some(SmallVec::from_slice(F11_SEQUENCE)), 12 => Some(SmallVec::from_slice(F12_SEQUENCE)), _ => None, }
}
KeyEvent {
code: KeyCode::Home,
..
} => Some(SmallVec::from_slice(HOME_SEQUENCE)),
KeyEvent {
code: KeyCode::End, ..
} => Some(SmallVec::from_slice(END_SEQUENCE)),
KeyEvent {
code: KeyCode::PageUp,
..
} => Some(SmallVec::from_slice(PAGE_UP_SEQUENCE)),
KeyEvent {
code: KeyCode::PageDown,
..
} => Some(SmallVec::from_slice(PAGE_DOWN_SEQUENCE)),
KeyEvent {
code: KeyCode::Insert,
..
} => Some(SmallVec::from_slice(INSERT_SEQUENCE)),
KeyEvent {
code: KeyCode::Delete,
..
} => Some(SmallVec::from_slice(DELETE_SEQUENCE)),
_ => None,
}
}
/// Translates a local mouse event into bytes for the remote side.
///
/// Currently a stub: the event is ignored and `None` is always returned, so
/// no mouse input is forwarded even when mouse capture is enabled.
fn mouse_event_to_bytes(_mouse_event: MouseEvent) -> Option<SmallVec<[u8; 8]>> {
None
}
/// Shuts the session down: signals cancellation, sends EOF on the SSH
/// channel, releases the terminal guard and marks the session `Closed`.
///
/// A failure to send EOF is logged as a warning and otherwise ignored.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub async fn shutdown(&mut self) -> Result<()> {
    self.state = PtyState::ShuttingDown;

    // Best effort: a failed send means nobody is listening any more.
    self.cancel_tx.send(true).ok();

    let eof_outcome = self.channel.eof().await;
    if let Err(e) = eof_outcome {
        tracing::warn!("Failed to send EOF to SSH channel: {e}");
    }

    self.terminal_guard = None;
    self.state = PtyState::Closed;
    Ok(())
}
}
impl Drop for PtySession {
    /// Raises the cancellation flag so that background tasks watching it
    /// stop once the session goes away. A failed send is ignored.
    fn drop(&mut self) {
        self.cancel_tx.send(true).ok();
    }
}