use bytes::Bytes;
use crossterm::{
cursor,
event::{self, Event, KeyCode, KeyEvent, KeyModifiers},
execute,
terminal::{self, ClearType},
};
use std::io::{self, Write};
use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::error;
use crate::errors::{Error, Result};
/// Terminal window dimensions in character cells.
///
/// Serialized as a JSON object with the keys `cols` and `rows`. The explicit `rename`
/// attributes pin those wire names, so renaming a Rust field cannot silently change
/// the format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TerminalSize {
    /// Number of columns.
    #[serde(rename = "cols")]
    pub cols: u16,
    /// Number of rows.
    #[serde(rename = "rows")]
    pub rows: u16,
}

impl Default for TerminalSize {
    /// The conventional 80x24 size, used whenever the real size is unknown.
    fn default() -> Self {
        Self { cols: 80, rows: 24 }
    }
}

impl TerminalSize {
    /// Queries the size of the terminal.
    ///
    /// Falls back to [`TerminalSize::default`] (80x24) when the query fails or
    /// reports a zero dimension. A 0-column or 0-row size is never usable for a
    /// PTY, so it is treated the same as "unknown".
    pub fn current() -> Self {
        terminal::size()
            .ok()
            .filter(|&(cols, rows)| cols > 0 && rows > 0)
            .map(|(cols, rows)| Self { cols, rows })
            .unwrap_or_default()
    }

    /// Serializes the size as JSON, e.g. `{"cols":80,"rows":24}`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] if serialization fails.
    pub fn to_json(&self) -> Result<Bytes> {
        serde_json::to_vec(self)
            .map(Bytes::from)
            .map_err(|e| Error::Config(format!("Failed to serialize size: {}", e)))
    }
}
/// Tuning knobs for terminal input handling.
#[derive(Debug, Clone)]
pub struct TerminalConfig {
    /// Number of buffered keystroke bytes that forces an immediate send to the
    /// input channel.
    pub input_buffer_size: usize,
    /// Poll timeout, in microseconds, for the input reader. When a poll times out
    /// with no new event, any partially filled batch is sent.
    pub input_batch_timeout_us: u64,
    /// Local echo flag. NOTE(review): not read by any code in this section.
    pub local_echo: bool,
    /// Resize polling interval. NOTE(review): not read by any code in this section.
    pub resize_poll_interval: Duration,
}

impl Default for TerminalConfig {
    /// 1 KiB batches, a 1 ms batch timeout, no local echo and 100 ms resize polling.
    fn default() -> Self {
        let resize_poll_interval = Duration::from_millis(100);
        TerminalConfig {
            input_buffer_size: 1024,
            input_batch_timeout_us: 1_000,
            local_echo: false,
            resize_poll_interval,
        }
    }
}
/// Job-control and end-of-input chords recognized from the keyboard.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ControlSignal {
    /// Ctrl+C.
    Interrupt,
    /// Ctrl+Z.
    Suspend,
    /// Ctrl+\.
    Quit,
    /// Ctrl+D.
    EndOfFile,
}

impl ControlSignal {
    /// Returns the ASCII control byte that a terminal transmits for this chord.
    pub fn as_byte(&self) -> u8 {
        match *self {
            Self::Interrupt => 0x03, // ETX
            Self::Suspend => 0x1A,   // SUB
            Self::Quit => 0x1C,      // FS
            Self::EndOfFile => 0x04, // EOT
        }
    }
}
/// A unit of input produced by the terminal reader thread.
///
/// Derives `PartialEq`/`Eq` (all payload types support them) so consumers and tests
/// can compare inputs directly instead of pattern-matching with guards.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TerminalInput {
    /// Raw bytes to forward, possibly several keystrokes batched together.
    Data(Bytes),
    /// A Ctrl chord recognized as a control signal.
    Signal(ControlSignal),
    /// The terminal window changed size.
    Resize(TerminalSize),
    /// The reader has stopped; no further input follows.
    Eof,
}
/// Guard returned by `Terminal::enable_raw_mode` that disables raw mode on drop.
///
/// Raw mode is only disabled if this guard was the one that enabled it. If raw
/// mode was already on when the guard was created, dropping it leaves the mode
/// untouched, so an inner guard cannot undo an outer one.
///
/// The guard must be bound to a variable: a bare `terminal.enable_raw_mode()?;`
/// drops it immediately and turns raw mode straight back off.
#[must_use = "raw mode is disabled again as soon as the guard is dropped"]
#[derive(Debug)]
pub struct RawModeGuard {
    /// Whether raw mode was already enabled before this guard was created.
    was_raw: bool,
}

impl RawModeGuard {
    /// Enables raw mode unless it is already on, remembering the previous state.
    fn new() -> io::Result<Self> {
        let was_raw = terminal::is_raw_mode_enabled()?;
        if !was_raw {
            terminal::enable_raw_mode()?;
        }
        Ok(Self { was_raw })
    }
}

impl Drop for RawModeGuard {
    /// Disables raw mode if this guard enabled it.
    ///
    /// Failures are reported on stderr because `Drop` cannot return an error.
    fn drop(&mut self) {
        if !self.was_raw {
            if let Err(e) = terminal::disable_raw_mode() {
                eprintln!("Warning: Failed to restore terminal: {}", e);
            }
        }
    }
}
/// Handle to the local terminal.
///
/// Tracks the last known window size and controls the background input reader.
pub struct Terminal {
    /// Settings handed to the input reader thread.
    config: TerminalConfig,
    /// Last known `(cols, rows)`, shared with the input reader thread.
    size: Arc<(AtomicU16, AtomicU16)>,
    /// True while the input reader thread should keep running.
    running: Arc<AtomicBool>,
}
impl Terminal {
    /// Creates a terminal handle and records the current window size.
    ///
    /// The background input reader is not started here; see
    /// [`Terminal::start_input_reader`].
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub fn new(config: TerminalConfig) -> Result<Self> {
        let size = TerminalSize::current();
        Ok(Self {
            config,
            size: Arc::new((AtomicU16::new(size.cols), AtomicU16::new(size.rows))),
            running: Arc::new(AtomicBool::new(false)),
        })
    }

    /// Enables raw mode and returns a guard that disables it again on drop.
    ///
    /// If raw mode was already on, the guard leaves it on when dropped.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] if the raw-mode state cannot be queried or changed.
    pub fn enable_raw_mode(&self) -> Result<RawModeGuard> {
        RawModeGuard::new().map_err(|e| Error::Config(format!("Failed to enable raw mode: {}", e)))
    }

    /// Returns the last recorded window size.
    ///
    /// This does not query the terminal. The value is refreshed by
    /// [`Terminal::update_size`] and by resize events seen by the input reader thread.
    pub fn size(&self) -> TerminalSize {
        TerminalSize {
            cols: self.size.0.load(Ordering::Relaxed),
            rows: self.size.1.load(Ordering::Relaxed),
        }
    }

    /// Re-queries the terminal size and stores it.
    ///
    /// Returns `Some(new_size)` if it differs from the previously recorded size,
    /// otherwise `None`.
    pub fn update_size(&self) -> Option<TerminalSize> {
        let current = TerminalSize::current();
        let old_cols = self.size.0.swap(current.cols, Ordering::Relaxed);
        let old_rows = self.size.1.swap(current.rows, Ordering::Relaxed);
        if old_cols != current.cols || old_rows != current.rows {
            Some(current)
        } else {
            None
        }
    }

    /// Spawns a background OS thread that reads keyboard and resize events.
    ///
    /// Events are forwarded over the returned channel, which has capacity 256.
    /// Keystroke bytes are batched according to [`TerminalConfig`]. The thread exits
    /// in any of these cases:
    /// - [`Terminal::stop`] is called (a final [`TerminalInput::Eof`] is sent);
    /// - reading the terminal fails (a final [`TerminalInput::Eof`] is sent);
    /// - the receiver is dropped.
    ///
    /// NOTE(review): calling this more than once starts several readers that
    /// compete for the same terminal events.
    pub fn start_input_reader(&self) -> mpsc::Receiver<TerminalInput> {
        let (tx, rx) = mpsc::channel(256);
        let running = self.running.clone();
        let size = self.size.clone();
        let config = self.config.clone();
        running.store(true, Ordering::Relaxed);
        std::thread::spawn(move || {
            Self::input_loop(tx, running, size, config);
        });
        rx
    }

    /// Asks the input reader thread to exit.
    ///
    /// The thread checks this flag once per poll iteration. It is not joined.
    pub fn stop(&self) {
        self.running.store(false, Ordering::Relaxed);
    }

    /// Body of the input reader thread.
    ///
    /// Keystroke bytes accumulate in a local buffer. The buffer is sent as one
    /// [`TerminalInput::Data`] chunk in any of these cases:
    /// - the buffer reaches `input_buffer_size`;
    /// - a poll times out with no new event;
    /// - just before any non-data message is sent, which keeps data, signals and
    ///   resizes in their original order;
    /// - when the loop ends, so keystrokes typed just before a stop or a read error
    ///   are not lost.
    fn input_loop(
        tx: mpsc::Sender<TerminalInput>,
        running: Arc<AtomicBool>,
        size: Arc<(AtomicU16, AtomicU16)>,
        config: TerminalConfig,
    ) {
        let mut buffer = Vec::with_capacity(config.input_buffer_size);
        let batch_timeout = Duration::from_micros(config.input_batch_timeout_us);

        while running.load(Ordering::Relaxed) {
            let receiver_alive = match event::poll(batch_timeout) {
                Ok(true) => match event::read() {
                    Ok(Event::Key(key)) => match Self::process_key_event(key) {
                        Some(TerminalInput::Data(data)) => {
                            buffer.extend_from_slice(&data);
                            if buffer.len() >= config.input_buffer_size {
                                Self::flush_buffer(&tx, &mut buffer)
                            } else {
                                true
                            }
                        }
                        Some(other) => {
                            Self::flush_buffer(&tx, &mut buffer) && tx.blocking_send(other).is_ok()
                        }
                        None => true,
                    },
                    Ok(Event::Resize(cols, rows)) => {
                        size.0.store(cols, Ordering::Relaxed);
                        size.1.store(rows, Ordering::Relaxed);
                        let resize = TerminalInput::Resize(TerminalSize { cols, rows });
                        Self::flush_buffer(&tx, &mut buffer) && tx.blocking_send(resize).is_ok()
                    }
                    // Other event kinds (e.g. mouse) are not forwarded.
                    Ok(_) => true,
                    Err(e) => {
                        error!("Terminal read error: {}", e);
                        break;
                    }
                },
                // Idle: push out whatever has been typed so far.
                Ok(false) => Self::flush_buffer(&tx, &mut buffer),
                Err(e) => {
                    error!("Terminal poll error: {}", e);
                    break;
                }
            };

            if !receiver_alive {
                // Nobody is listening any more. Stop consuming terminal events rather
                // than swallowing them until `stop()` is called.
                return;
            }
        }

        // Deliver keystrokes buffered before the loop ended, then signal end of input.
        if Self::flush_buffer(&tx, &mut buffer) {
            let _ = tx.blocking_send(TerminalInput::Eof);
        }
    }

    /// Sends the buffered bytes, if any, as one [`TerminalInput::Data`] chunk and
    /// empties the buffer.
    ///
    /// Returns `false` only if the receiving side of the channel has been dropped.
    fn flush_buffer(tx: &mpsc::Sender<TerminalInput>, buffer: &mut Vec<u8>) -> bool {
        if buffer.is_empty() {
            return true;
        }
        let chunk = Bytes::from(std::mem::take(buffer));
        tx.blocking_send(TerminalInput::Data(chunk)).is_ok()
    }

    /// Translates a crossterm key event into the bytes, or signal, to forward.
    ///
    /// Encoding follows xterm conventions:
    /// - Ctrl+<char> becomes its C0 control byte (see `control_byte`). Chords whose
    ///   byte equals a [`ControlSignal::as_byte`] value are reported as
    ///   [`TerminalInput::Signal`]. These are Ctrl+C, Ctrl+Z, Ctrl+\ (which crossterm
    ///   may report as Ctrl+4) and Ctrl+D.
    /// - Ctrl with a non-character key is sent as the unmodified key, not dropped.
    /// - Alt+<char>, Alt+Backspace and Alt+Enter are prefixed with ESC (meta encoding).
    /// - Navigation keys, F1-F12 and Shift+Tab become their escape sequences.
    ///
    /// Returns `None` for keys with no encoding here (e.g. F13 and above).
    fn process_key_event(key: KeyEvent) -> Option<TerminalInput> {
        let ctrl = key.modifiers.contains(KeyModifiers::CONTROL);
        let alt = key.modifiers.contains(KeyModifiers::ALT);

        let mut bytes: Vec<u8> = match key.code {
            KeyCode::Char(c) if ctrl => match Self::control_byte(c) {
                Some(byte) => {
                    if !alt {
                        if let Some(signal) = Self::signal_for_byte(byte) {
                            return Some(TerminalInput::Signal(signal));
                        }
                    }
                    // NOTE(review): on Windows, AltGr is reported as Ctrl+Alt, so
                    // AltGr-typed characters land here as ESC + control byte.
                    vec![byte]
                }
                // No control-code form (e.g. Ctrl+1, or a non-ASCII char). A plain
                // `as u8` cast would truncate it into an unrelated byte, so send the
                // character itself, as xterm does.
                None => c.to_string().into_bytes(),
            },
            KeyCode::Char(c) => c.to_string().into_bytes(),
            KeyCode::Enter => vec![13],
            KeyCode::Backspace => vec![127],
            KeyCode::Tab => vec![9],
            KeyCode::BackTab => vec![27, 91, 90],
            KeyCode::Esc => vec![27],
            KeyCode::Up => vec![27, 91, 65],
            KeyCode::Down => vec![27, 91, 66],
            KeyCode::Right => vec![27, 91, 67],
            KeyCode::Left => vec![27, 91, 68],
            KeyCode::Home => vec![27, 91, 72],
            KeyCode::End => vec![27, 91, 70],
            KeyCode::PageUp => vec![27, 91, 53, 126],
            KeyCode::PageDown => vec![27, 91, 54, 126],
            KeyCode::Insert => vec![27, 91, 50, 126],
            KeyCode::Delete => vec![27, 91, 51, 126],
            KeyCode::F(1) => vec![27, 79, 80],
            KeyCode::F(2) => vec![27, 79, 81],
            KeyCode::F(3) => vec![27, 79, 82],
            KeyCode::F(4) => vec![27, 79, 83],
            KeyCode::F(5) => vec![27, 91, 49, 53, 126],
            KeyCode::F(6) => vec![27, 91, 49, 55, 126],
            KeyCode::F(7) => vec![27, 91, 49, 56, 126],
            KeyCode::F(8) => vec![27, 91, 49, 57, 126],
            KeyCode::F(9) => vec![27, 91, 50, 48, 126],
            KeyCode::F(10) => vec![27, 91, 50, 49, 126],
            KeyCode::F(11) => vec![27, 91, 50, 51, 126],
            KeyCode::F(12) => vec![27, 91, 50, 52, 126],
            _ => return None,
        };

        // crossterm decodes "ESC <key>" as <key> with ALT set; put the ESC back so the
        // remote side sees the meta chord (e.g. readline's M-b, M-f, M-DEL).
        if alt && matches!(key.code, KeyCode::Char(_) | KeyCode::Backspace | KeyCode::Enter) {
            bytes.insert(0, 27);
        }

        Some(TerminalInput::Data(Bytes::from(bytes)))
    }

    /// Maps the character of a Ctrl chord to the C0 control byte an xterm-compatible
    /// terminal transmits. Returns `None` when the chord has no control-code form.
    fn control_byte(c: char) -> Option<u8> {
        match c {
            // Ctrl+Space and Ctrl+2 both produce NUL.
            ' ' | '2' => Some(0x00),
            // Ctrl+3..=Ctrl+7 produce ESC, FS, GS, RS and US. crossterm's unix parser
            // reports the raw bytes 0x1C..=0x1F as Ctrl+'4'..='7', so this arm is what
            // makes Ctrl+\, Ctrl+], Ctrl+^ and Ctrl+_ round-trip.
            '3'..='7' => Some(c as u8 - b'3' + 0x1B),
            // Ctrl+8 and Ctrl+? produce DEL.
            '8' | '?' => Some(0x7F),
            // Classic rule: keep the low five bits of '@'..='~' (so Ctrl+A == Ctrl+a).
            '@'..='~' => Some(c as u8 & 0x1F),
            _ => None,
        }
    }

    /// Returns the [`ControlSignal`] whose byte equals `byte`, if any.
    fn signal_for_byte(byte: u8) -> Option<ControlSignal> {
        [
            ControlSignal::Interrupt,
            ControlSignal::Suspend,
            ControlSignal::Quit,
            ControlSignal::EndOfFile,
        ]
        .iter()
        .copied()
        .find(|signal| signal.as_byte() == byte)
    }

    /// Writes `data` to stdout and flushes it.
    ///
    /// The stdout lock is held for both the write and the flush.
    pub fn write_output(data: &[u8]) -> io::Result<()> {
        let mut stdout = io::stdout().lock();
        stdout.write_all(data)?;
        stdout.flush()
    }

    /// Clears the whole screen and moves the cursor to the top-left cell `(0, 0)`.
    pub fn clear_screen() -> io::Result<()> {
        execute!(
            io::stdout(),
            terminal::Clear(ClearType::All),
            cursor::MoveTo(0, 0)
        )
    }

    /// Sets the terminal window title.
    pub fn set_title(title: &str) -> io::Result<()> {
        execute!(io::stdout(), terminal::SetTitle(title))
    }
}
impl Drop for Terminal {
fn drop(&mut self) {
self.stop();
}
}
/// A terminal paired with the size most recently reported to the caller.
///
/// Remembering that size lets the session detect resizes incrementally.
pub struct ShellSession {
    terminal: Terminal,
    /// Size as of construction or the last `check_resize` call.
    last_size: TerminalSize,
}

impl ShellSession {
    /// Creates a session around a new terminal built from `config`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Terminal::new`.
    pub fn new(config: TerminalConfig) -> Result<Self> {
        let terminal = Terminal::new(config)?;
        Ok(Self {
            last_size: terminal.size(),
            terminal,
        })
    }

    /// Borrows the underlying terminal.
    pub fn terminal(&self) -> &Terminal {
        &self.terminal
    }

    /// Returns the terminal's recorded size if it changed since the last call.
    ///
    /// This reads `Terminal::size`, which is only refreshed by the input reader
    /// thread or `Terminal::update_size`; it does not query the terminal itself.
    pub fn check_resize(&mut self) -> Option<TerminalSize> {
        let latest = self.terminal.size();
        let previous = std::mem::replace(&mut self.last_size, latest);
        if previous == latest {
            None
        } else {
            Some(latest)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds a single key press through `Terminal::process_key_event`.
    fn translate(code: KeyCode, modifiers: KeyModifiers) -> Option<TerminalInput> {
        Terminal::process_key_event(KeyEvent::new(code, modifiers))
    }

    /// Asserts that `input` is a `Data` chunk holding exactly `expected`.
    fn assert_data(input: Option<TerminalInput>, expected: &[u8]) {
        match input {
            Some(TerminalInput::Data(data)) => assert_eq!(data.as_ref(), expected),
            other => panic!("expected Data({:?}), got {:?}", expected, other),
        }
    }

    #[test]
    fn test_terminal_size_default() {
        let TerminalSize { cols, rows } = TerminalSize::default();
        assert_eq!((cols, rows), (80, 24));
    }

    #[test]
    fn test_terminal_size_json() {
        let original = TerminalSize {
            cols: 120,
            rows: 40,
        };
        let encoded = original.to_json().expect("size serializes");
        let decoded: TerminalSize =
            serde_json::from_slice(&encoded).expect("size deserializes");
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_control_signal_bytes() {
        let expected: [(ControlSignal, u8); 4] = [
            (ControlSignal::Interrupt, 0x03),
            (ControlSignal::Suspend, 0x1A),
            (ControlSignal::Quit, 0x1C),
            (ControlSignal::EndOfFile, 0x04),
        ];
        for &(signal, byte) in expected.iter() {
            assert_eq!(signal.as_byte(), byte, "{:?}", signal);
        }
    }

    #[test]
    fn test_terminal_config_default() {
        let config = TerminalConfig::default();
        assert_eq!(config.input_buffer_size, 1024);
        assert!(!config.local_echo);
    }

    #[test]
    fn test_key_event_arrow_keys() {
        assert_data(translate(KeyCode::Up, KeyModifiers::NONE), &[27, 91, 65]);
        assert_data(translate(KeyCode::Down, KeyModifiers::NONE), &[27, 91, 66]);
    }

    #[test]
    fn test_key_event_enter() {
        assert_data(translate(KeyCode::Enter, KeyModifiers::NONE), &[13]);
    }

    #[test]
    fn test_key_event_ctrl_c() {
        let input = translate(KeyCode::Char('c'), KeyModifiers::CONTROL);
        assert!(
            matches!(input, Some(TerminalInput::Signal(ControlSignal::Interrupt))),
            "got {:?}",
            input
        );
    }
}