use super::constants::*;
use super::input::handle_input_event;
use super::terminal_modes::configure_terminal_modes;
use crate::pty::{
terminal::{TerminalOps, TerminalStateGuard},
PtyConfig, PtyMessage, PtyState,
};
use anyhow::{Context, Result};
use russh::{client::Msg, Channel, ChannelMsg};
use std::io::{self, Write};
use tokio::sync::{mpsc, watch};
use tokio::time::Duration;
/// Interactive PTY session running over a single SSH channel.
///
/// Remote channel output is written to the local stdout, while local keyboard input and
/// terminal resize events are forwarded to the channel.
pub struct PtySession {
    /// Identifier included in this session's log messages.
    pub session_id: usize,
    /// SSH channel on which the PTY and shell are requested and data is exchanged.
    channel: Channel<Msg>,
    /// Session options (terminal type, whether mouse capture is enabled, ...).
    config: PtyConfig,
    /// Current lifecycle state, exposed through `PtySession::state`.
    state: PtyState,
    /// Guard for the local terminal state; held while `run` drives the local terminal and
    /// cleared when the session finishes or shuts down.
    terminal_guard: Option<TerminalStateGuard>,
    /// Cancellation flag broadcast to the session's helper tasks (`true` means stop).
    cancel_tx: watch::Sender<bool>,
    /// Receiver kept by the session; additional subscribers are cloned from it.
    cancel_rx: watch::Receiver<bool>,
    /// Sender cloned for every task that feeds the message loop.
    msg_tx: Option<mpsc::Sender<PtyMessage>>,
    /// Receiver drained by the message loop; taken by `run`, so it can be consumed only once.
    msg_rx: Option<mpsc::Receiver<PtyMessage>>,
}
impl PtySession {
pub async fn new(session_id: usize, channel: Channel<Msg>, config: PtyConfig) -> Result<Self> {
let (msg_tx, msg_rx) = mpsc::channel(PTY_MESSAGE_CHANNEL_SIZE);
let (cancel_tx, cancel_rx) = watch::channel(false);
Ok(Self {
session_id,
channel,
config,
state: PtyState::Inactive,
terminal_guard: None,
cancel_tx,
cancel_rx,
msg_tx: Some(msg_tx),
msg_rx: Some(msg_rx),
})
}
pub fn state(&self) -> PtyState {
self.state
}
pub async fn initialize(&mut self) -> Result<()> {
self.state = PtyState::Initializing;
let (width, height) = crate::pty::utils::get_terminal_size()?;
let terminal_modes = configure_terminal_modes();
self.channel
.request_pty(
false,
&self.config.term_type,
width,
height,
0, 0, &terminal_modes, )
.await
.with_context(|| "Failed to request PTY on SSH channel")?;
self.channel
.request_shell(false)
.await
.with_context(|| "Failed to request shell on SSH channel")?;
self.state = PtyState::Active;
tracing::debug!("PTY session {} initialized", self.session_id);
Ok(())
}
pub async fn run(&mut self) -> Result<()> {
if self.state == PtyState::Inactive {
self.initialize().await?;
}
if self.state != PtyState::Active {
anyhow::bail!("PTY session is not in active state");
}
self.terminal_guard = Some(TerminalStateGuard::new()?);
if self.config.enable_mouse {
TerminalOps::enable_mouse()?;
}
let mut msg_rx = self
.msg_rx
.take()
.ok_or_else(|| anyhow::anyhow!("Message receiver already taken"))?;
let mut resize_signals = crate::pty::utils::setup_resize_handler()?;
let cancel_for_resize = self.cancel_rx.clone();
let resize_tx = self
.msg_tx
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Message sender not available"))?
.clone();
let resize_task = tokio::spawn(async move {
let mut cancel_for_resize = cancel_for_resize;
loop {
tokio::select! {
signal = async {
for signal in resize_signals.forever() {
if signal == signal_hook::consts::SIGWINCH {
return signal;
}
}
signal_hook::consts::SIGWINCH } => {
if signal == signal_hook::consts::SIGWINCH {
if let Ok((width, height)) = crate::pty::utils::get_terminal_size() {
if resize_tx.try_send(PtyMessage::Resize { width, height }).is_err() {
break;
}
}
}
}
_ = cancel_for_resize.changed() => {
if *cancel_for_resize.borrow() {
break;
}
}
}
}
});
let input_tx = self
.msg_tx
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Message sender not available"))?
.clone();
let cancel_for_input = self.cancel_rx.clone();
let input_task = tokio::task::spawn_blocking(move || {
loop {
if *cancel_for_input.borrow() {
break;
}
let poll_timeout = Duration::from_millis(INPUT_POLL_TIMEOUT_MS);
if crossterm::event::poll(poll_timeout).unwrap_or(false) {
match crossterm::event::read() {
Ok(event) => {
if let Some(data) = handle_input_event(event) {
if input_tx.try_send(PtyMessage::LocalInput(data)).is_err() {
break;
}
}
}
Err(e) => {
let _ =
input_tx.try_send(PtyMessage::Error(format!("Input error: {e}")));
break;
}
}
}
}
});
let mut should_terminate = false;
let mut cancel_rx = self.cancel_rx.clone();
while !should_terminate {
tokio::select! {
msg = self.channel.wait() => {
match msg {
Some(ChannelMsg::Data { ref data }) => {
if let Err(e) = io::stdout().write_all(data) {
tracing::error!("Failed to write to stdout: {e}");
should_terminate = true;
} else {
let _ = io::stdout().flush();
}
}
Some(ChannelMsg::ExtendedData { ref data, ext }) => {
if ext == 1 {
if let Err(e) = io::stdout().write_all(data) {
tracing::error!("Failed to write stderr to stdout: {e}");
should_terminate = true;
} else {
let _ = io::stdout().flush();
}
}
}
Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) => {
tracing::debug!("SSH channel closed");
let _ = self.cancel_tx.send(true);
should_terminate = true;
}
Some(_) => {
}
None => {
should_terminate = true;
}
}
}
message = msg_rx.recv() => {
match message {
Some(PtyMessage::LocalInput(data)) => {
if let Err(e) = self.channel.data(data.as_slice()).await {
tracing::error!("Failed to send data to SSH channel: {e}");
should_terminate = true;
}
}
Some(PtyMessage::RemoteOutput(data)) => {
if let Err(e) = io::stdout().write_all(&data) {
tracing::error!("Failed to write to stdout: {e}");
should_terminate = true;
} else {
let _ = io::stdout().flush();
}
}
Some(PtyMessage::Resize { width, height }) => {
if let Err(e) = self.channel.window_change(width, height, 0, 0).await {
tracing::warn!("Failed to send window resize to remote: {e}");
} else {
tracing::debug!("Terminal resized to {width}x{height}");
}
}
Some(PtyMessage::Terminate) => {
tracing::debug!("PTY session {} terminating", self.session_id);
should_terminate = true;
}
Some(PtyMessage::Error(error)) => {
tracing::error!("PTY error: {error}");
should_terminate = true;
}
None => {
should_terminate = true;
}
}
}
_ = cancel_rx.changed() => {
if *cancel_rx.borrow() {
tracing::debug!("PTY session {} received cancellation signal", self.session_id);
should_terminate = true;
}
}
}
}
let _ = self.cancel_tx.send(true);
let _ = tokio::time::timeout(Duration::from_millis(TASK_CLEANUP_TIMEOUT_MS), async {
tokio::select! {
_ = resize_task => {},
_ = input_task => {},
_ = tokio::time::sleep(Duration::from_millis(TASK_CLEANUP_TIMEOUT_MS)) => {
}
}
})
.await;
if self.config.enable_mouse {
let _ = TerminalOps::disable_mouse();
}
self.terminal_guard = None;
let _ = io::stdout().flush();
self.state = PtyState::Closed;
Ok(())
}
pub async fn shutdown(&mut self) -> Result<()> {
self.state = PtyState::ShuttingDown;
let _ = self.cancel_tx.send(true);
if let Err(e) = self.channel.eof().await {
tracing::warn!("Failed to send EOF to SSH channel: {e}");
}
self.terminal_guard = None;
self.state = PtyState::Closed;
Ok(())
}
}
impl Drop for PtySession {
    /// Publishes `true` on the cancellation channel, so every holder of a cloned
    /// cancellation receiver is told to stop even without an explicit shutdown.
    fn drop(&mut self) {
        let Self { cancel_tx, .. } = self;
        // The session still owns its own receiver at this point, so the send should not
        // fail; the result is deliberately ignored either way.
        let _ = cancel_tx.send(true);
    }
}