use anyhow::{Context, Result};
use crossterm::terminal::{self, disable_raw_mode, enable_raw_mode};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc;
/// Size of the local terminal as reported by crossterm: a column count and a row count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TerminalSize {
    /// Number of columns.
    pub cols: u16,
    /// Number of rows.
    pub rows: u16,
}

impl TerminalSize {
    /// Queries the terminal's current size through [`crossterm::terminal::size`].
    ///
    /// # Errors
    /// Fails with the context "failed to get terminal size" when crossterm cannot
    /// determine the size.
    pub fn current() -> Result<Self> {
        terminal::size()
            .map(|(cols, rows)| TerminalSize { cols, rows })
            .context("failed to get terminal size")
    }
}
pub struct RawModeGuard {
_private: (),
}
impl RawModeGuard {
pub fn new() -> Result<Self> {
enable_raw_mode().context("failed to enable raw mode")?;
Ok(Self { _private: () })
}
}
impl Drop for RawModeGuard {
fn drop(&mut self) {
let _ = disable_raw_mode();
}
}
/// Delivers terminal size changes, observed via `SIGWINCH`, to an async consumer.
///
/// A background task listens for the signal and sends the new [`TerminalSize`] over a
/// channel. Dropping the watcher stops that task.
pub struct ResizeWatcher {
    rx: mpsc::UnboundedReceiver<TerminalSize>,
    /// Set on drop; the background task checks it between signals.
    _shutdown: Arc<AtomicBool>,
}

impl ResizeWatcher {
    /// Starts watching for terminal resizes.
    ///
    /// Must be called from within a Tokio runtime, because it registers a signal handler and
    /// spawns a task. The `SIGWINCH` handler is registered before this returns, so a resize
    /// that happens right after construction is not missed. If registration fails, a
    /// warning is logged and [`recv`](Self::recv) on the returned watcher yields `None`.
    ///
    /// # Errors
    /// Currently never returns an error; the `Result` is part of the public signature.
    #[cfg(unix)]
    pub fn new() -> Result<Self> {
        use tokio::signal::unix::{SignalKind, signal};
        let (tx, rx) = mpsc::unbounded_channel();
        let shutdown = Arc::new(AtomicBool::new(false));
        match signal(SignalKind::window_change()) {
            Ok(mut sigwinch) => {
                let shutdown = Arc::clone(&shutdown);
                tokio::spawn(async move {
                    while !shutdown.load(Ordering::Relaxed) {
                        tokio::select! {
                            // The watcher (and its receiver) was dropped: stop right away
                            // instead of waiting for the next resize.
                            () = tx.closed() => break,
                            event = sigwinch.recv() => {
                                // `None` means no more signals can arrive; without this
                                // check the loop would spin.
                                if event.is_none() {
                                    break;
                                }
                                if let Ok(size) = TerminalSize::current() {
                                    if tx.send(size).is_err() {
                                        break;
                                    }
                                }
                            }
                        }
                    }
                });
            }
            Err(e) => {
                // `tx` is dropped when this function returns, closing the channel so that
                // `recv` reports `None`.
                tracing::warn!("Failed to register SIGWINCH handler: {}", e);
            }
        }
        Ok(Self {
            rx,
            _shutdown: shutdown,
        })
    }

    /// Waits for the next terminal size change.
    ///
    /// Returns `None` once the background task has stopped and all buffered sizes have
    /// been received.
    pub async fn recv(&mut self) -> Option<TerminalSize> {
        self.rx.recv().await
    }
}

impl Drop for ResizeWatcher {
    fn drop(&mut self) {
        // Tell the background task to stop; it also stops on its own once `rx` is gone.
        self._shutdown.store(true, Ordering::Relaxed);
    }
}
/// Bridges the local terminal to a remote byte stream.
///
/// Bytes read from local stdin are written to `writer`, and bytes read from `reader` are
/// copied to local stdout.
pub struct InteractiveSession<R, W> {
    reader: R,
    writer: W,
    /// Whether to put the local terminal into raw mode and honour the detach sequence.
    tty: bool,
    /// Set once the session has finished; the forwarding tasks check it between reads.
    shutdown: Arc<AtomicBool>,
}

impl<R, W> InteractiveSession<R, W>
where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
{
    /// Creates a session over `reader` (remote output) and `writer` (remote input).
    ///
    /// When `tty` is true, [`run`](Self::run) enables raw mode on the local terminal and
    /// watches stdin for the detach sequence (Ctrl-P followed by Ctrl-Q).
    pub fn new(reader: R, writer: W, tty: bool) -> Self {
        Self {
            reader,
            writer,
            tty,
            shutdown: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Forwards data in both directions until the session ends.
    ///
    /// The session ends in any of these cases:
    /// - the remote stream reaches EOF or fails;
    /// - writing to the remote fails;
    /// - stdin cannot be read;
    /// - in TTY mode, the user types Ctrl-P followed by Ctrl-Q.
    ///
    /// When local stdin reaches EOF, the writer is shut down so the remote sees end-of-input.
    /// Remote output then keeps being copied to stdout until the remote stream ends.
    ///
    /// I/O failures are logged at debug level and end the session rather than being
    /// returned. Raw mode, if it was enabled, is restored before this returns.
    ///
    /// # Errors
    /// Returns an error only if `tty` is set and raw mode cannot be enabled.
    pub async fn run(self) -> Result<()> {
        let Self {
            reader,
            writer,
            tty,
            shutdown,
        } = self;

        // Held for the whole function; dropping it restores the previous terminal mode.
        let _raw_guard = if tty {
            Some(RawModeGuard::new()?)
        } else {
            None
        };
        if tty {
            if let Ok(size) = TerminalSize::current() {
                tracing::debug!("Initial terminal size: {}x{}", size.cols, size.rows);
            }
        }

        let mut stdin_task =
            tokio::spawn(pump_stdin_to_remote(writer, Arc::clone(&shutdown), tty));
        let mut stdout_task = tokio::spawn(pump_remote_to_stdout(reader, Arc::clone(&shutdown)));

        let drain_remote_output = tokio::select! {
            outcome = &mut stdin_task => {
                tracing::debug!("stdin task completed");
                // After local EOF the remote may still be producing output (e.g. a command
                // fed from a pipe), so keep copying it instead of cutting it off.
                matches!(outcome, Ok(StdinOutcome::Eof))
            }
            _ = &mut stdout_task => {
                tracing::debug!("stdout task completed");
                false
            }
        };
        if drain_remote_output {
            if let Err(e) = (&mut stdout_task).await {
                tracing::debug!("stdout task failed: {}", e);
            }
            tracing::debug!("stdout task completed");
        }

        shutdown.store(true, Ordering::Relaxed);
        // Dropping a JoinHandle detaches its task instead of cancelling it. Abort both so a
        // leftover task cannot keep forwarding bytes after the session has ended. Aborting
        // a task that already finished has no effect.
        stdin_task.abort();
        stdout_task.abort();
        Ok(())
    }
}

/// Key sequence that detaches from a TTY session: Ctrl-P (0x10) followed by Ctrl-Q (0x11).
const DETACH_SEQUENCE: &[u8] = &[0x10, 0x11];

/// Why [`pump_stdin_to_remote`] stopped.
enum StdinOutcome {
    /// Local stdin reached EOF; the remote writer has been shut down.
    Eof,
    /// The user typed [`DETACH_SEQUENCE`].
    Detached,
    /// Reading stdin or writing to the remote failed.
    Failed,
    /// The session's shutdown flag was observed.
    Stopped,
}

/// Incrementally matches [`DETACH_SEQUENCE`] across successive stdin reads.
///
/// In raw mode each keypress usually arrives in its own read. The sequence must therefore be
/// tracked across read boundaries instead of being searched for within a single buffer.
#[derive(Default)]
struct DetachMatcher {
    /// How many leading bytes of `DETACH_SEQUENCE` have been seen and held back.
    matched: usize,
}

impl DetachMatcher {
    /// Scans `input` and appends the bytes that should be forwarded to the remote to `out`.
    ///
    /// A partial match is held back until a later byte shows whether it completes the
    /// sequence. Returns `true` once the full sequence has been seen; any bytes after it in
    /// `input` are discarded.
    fn feed(&mut self, input: &[u8], out: &mut Vec<u8>) -> bool {
        for &byte in input {
            if byte == DETACH_SEQUENCE[self.matched] {
                self.matched += 1;
                if self.matched == DETACH_SEQUENCE.len() {
                    self.matched = 0;
                    return true;
                }
                continue;
            }
            // Mismatch: release the held-back prefix, then treat this byte as a possible
            // new start. Restarting like this is exact because the sequence's two bytes
            // differ, so no held prefix can overlap a new match.
            out.extend_from_slice(&DETACH_SEQUENCE[..self.matched]);
            if byte == DETACH_SEQUENCE[0] {
                self.matched = 1;
            } else {
                self.matched = 0;
                out.push(byte);
            }
        }
        false
    }

    /// Bytes currently held back as a partial match.
    fn held(&self) -> &'static [u8] {
        &DETACH_SEQUENCE[..self.matched]
    }
}

/// Writes `bytes` to the remote and flushes, logging any failure. Returns `false` on failure.
async fn write_remote<W>(writer: &mut W, bytes: &[u8]) -> bool
where
    W: AsyncWrite + Unpin,
{
    if let Err(e) = writer.write_all(bytes).await {
        tracing::debug!("Failed to write to remote: {}", e);
        return false;
    }
    if let Err(e) = writer.flush().await {
        tracing::debug!("Failed to flush remote: {}", e);
        return false;
    }
    true
}

/// Copies local stdin to `writer` until EOF, an I/O error, detach, or shutdown.
///
/// When `detect_detach` is false, every byte is forwarded unchanged. This keeps arbitrary
/// piped data intact in non-TTY mode.
///
/// NOTE: `tokio::io::stdin()` reads on a blocking helper thread that cannot be cancelled, so
/// aborting this task does not interrupt a read that is already in progress.
async fn pump_stdin_to_remote<W>(
    mut writer: W,
    shutdown: Arc<AtomicBool>,
    detect_detach: bool,
) -> StdinOutcome
where
    W: AsyncWrite + Unpin,
{
    let mut stdin = tokio::io::stdin();
    let mut buf = [0u8; 1024];
    let mut matcher = DetachMatcher::default();
    let mut forward = Vec::with_capacity(buf.len());
    while !shutdown.load(Ordering::Relaxed) {
        let n = match stdin.read(&mut buf).await {
            Ok(0) => {
                // Release a held-back partial detach sequence, then signal end-of-input.
                let held = matcher.held();
                if !held.is_empty() && !write_remote(&mut writer, held).await {
                    return StdinOutcome::Failed;
                }
                if let Err(e) = writer.shutdown().await {
                    tracing::debug!("Failed to close remote input: {}", e);
                }
                return StdinOutcome::Eof;
            }
            Ok(n) => n,
            Err(e) => {
                tracing::debug!("Failed to read stdin: {}", e);
                return StdinOutcome::Failed;
            }
        };
        let chunk = &buf[..n];
        let (to_send, detached) = if detect_detach {
            forward.clear();
            let detached = matcher.feed(chunk, &mut forward);
            (forward.as_slice(), detached)
        } else {
            (chunk, false)
        };
        if !to_send.is_empty() && !write_remote(&mut writer, to_send).await {
            return StdinOutcome::Failed;
        }
        if detached {
            tracing::debug!("Detach sequence detected");
            return StdinOutcome::Detached;
        }
    }
    StdinOutcome::Stopped
}

/// Copies `reader` to local stdout until EOF, an I/O error, or shutdown.
async fn pump_remote_to_stdout<R>(mut reader: R, shutdown: Arc<AtomicBool>)
where
    R: AsyncRead + Unpin,
{
    let mut stdout = tokio::io::stdout();
    let mut buf = [0u8; 4096];
    while !shutdown.load(Ordering::Relaxed) {
        let n = match reader.read(&mut buf).await {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) => {
                tracing::debug!("Failed to read from remote: {}", e);
                break;
            }
        };
        if let Err(e) = stdout.write_all(&buf[..n]).await {
            tracing::debug!("Failed to write to stdout: {}", e);
            break;
        }
        // Flush per chunk so interactive output appears without waiting for more data.
        if let Err(e) = stdout.flush().await {
            tracing::debug!("Failed to flush stdout: {}", e);
            break;
        }
    }
}
/// Holds a shutdown flag stored in an `Arc<AtomicBool>`.
///
/// The flag starts cleared. [`shutdown`](Self::shutdown) sets it and
/// [`is_shutdown`](Self::is_shutdown) reads it. Both operations use `Ordering::Relaxed`.
pub struct SyncTerminalIO {
    shutdown: Arc<AtomicBool>,
}

impl SyncTerminalIO {
    /// Creates an instance whose flag is not set; identical to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the shutdown flag. Calling this more than once has no further effect.
    pub fn shutdown(&self) {
        let flag: &AtomicBool = &self.shutdown;
        flag.store(true, Ordering::Relaxed);
    }

    /// Reports whether [`shutdown`](Self::shutdown) has been called.
    pub fn is_shutdown(&self) -> bool {
        let flag: &AtomicBool = &self.shutdown;
        flag.load(Ordering::Relaxed)
    }
}

impl Default for SyncTerminalIO {
    fn default() -> Self {
        let shutdown = Arc::new(AtomicBool::default());
        SyncTerminalIO { shutdown }
    }
}
/// Boxed callback that receives a [`TerminalSize`].
pub type ResizeCallback = Box<dyn Fn(TerminalSize) + Send + 'static>;

/// Attach options assembled with a consuming builder.
///
/// Every flag defaults to `false` and no resize callback is installed. How the flags are
/// interpreted is up to the code that consumes this config.
#[derive(Default)]
pub struct AttachConfig {
    pub tty: bool,
    pub stdin: bool,
    pub stdout: bool,
    pub stderr: bool,
    pub on_resize: Option<ResizeCallback>,
}

impl AttachConfig {
    /// Returns a config with every flag cleared and no callback (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the config with the `tty` flag replaced.
    pub fn with_tty(self, tty: bool) -> Self {
        Self { tty, ..self }
    }

    /// Returns the config with the `stdin` flag replaced.
    pub fn with_stdin(self, stdin: bool) -> Self {
        Self { stdin, ..self }
    }

    /// Returns the config with the `stdout` flag replaced.
    pub fn with_stdout(self, stdout: bool) -> Self {
        Self { stdout, ..self }
    }

    /// Returns the config with the `stderr` flag replaced.
    pub fn with_stderr(self, stderr: bool) -> Self {
        Self { stderr, ..self }
    }

    /// Returns the config with `callback` installed, replacing any earlier callback.
    pub fn with_resize_callback<F>(self, callback: F) -> Self
    where
        F: Fn(TerminalSize) + Send + 'static,
    {
        let boxed: ResizeCallback = Box::new(callback);
        Self {
            on_resize: Some(boxed),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    #[test]
    fn test_terminal_size_current() {
        // Smoke test only: test runners often have no attached terminal, so both Ok and Err
        // are acceptable here; the call just must not panic.
        let _ = TerminalSize::current();
    }

    #[test]
    fn test_attach_config_builder() {
        let config = AttachConfig::new()
            .with_tty(true)
            .with_stdin(true)
            .with_stdout(true)
            .with_stderr(false);
        assert!(config.tty);
        assert!(config.stdin);
        assert!(config.stdout);
        assert!(!config.stderr);
    }

    #[test]
    fn test_attach_config_defaults() {
        let config = AttachConfig::new();
        assert!(!config.tty);
        assert!(!config.stdin);
        assert!(!config.stdout);
        assert!(!config.stderr);
        assert!(config.on_resize.is_none());
    }

    #[test]
    fn test_attach_config_resize_callback_is_invoked() {
        let seen = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&seen);
        let config = AttachConfig::new()
            .with_resize_callback(move |size| sink.lock().unwrap().push(size))
            .with_tty(true);
        let callback = config.on_resize.as_ref().expect("callback was set");
        callback(TerminalSize { cols: 80, rows: 24 });
        assert_eq!(
            *seen.lock().unwrap(),
            vec![TerminalSize { cols: 80, rows: 24 }]
        );
    }

    #[test]
    fn test_sync_terminal_io_shutdown() {
        let io = SyncTerminalIO::new();
        assert!(!io.is_shutdown());
        io.shutdown();
        assert!(io.is_shutdown());
    }

    #[test]
    fn test_sync_terminal_io_default_is_not_shutdown() {
        assert!(!SyncTerminalIO::default().is_shutdown());
    }
}