use std::{
io::{Read, Write},
sync::{
atomic::{AtomicBool, Ordering},
mpsc, Arc,
},
thread,
};
use crate::error::{Error, Result};
use portable_pty::{native_pty_system, CommandBuilder, PtySize};
use tracing::warn;
use windows_sys::Win32::{
Foundation::HANDLE,
System::Console::{
GetConsoleMode, GetStdHandle, SetConsoleMode, ENABLE_ECHO_INPUT, ENABLE_LINE_INPUT,
ENABLE_PROCESSED_INPUT, ENABLE_VIRTUAL_TERMINAL_INPUT, STD_INPUT_HANDLE,
},
};
use crate::{
checks::Check,
config::{Config, Settings},
env::Environment,
prompt::Prompter,
};
use super::common::{
handle_statement, is_control_passthrough, BufferResult, InputBuffer, StatementAction,
WrapperConfig,
};
/// Guard for the console input mode of the process's standard input.
///
/// Created by `WinRawModeGuard::enter`, which saves the current mode and
/// switches the console to raw input. The saved mode is written back when the
/// guard is dropped.
struct WinRawModeGuard {
/// Standard-input handle returned by `GetStdHandle(STD_INPUT_HANDLE)`.
handle: HANDLE,
/// Console mode reported by `GetConsoleMode` before raw mode was applied.
original_mode: u32,
}
impl WinRawModeGuard {
fn enter() -> Result<Self> {
let handle = unsafe { GetStdHandle(STD_INPUT_HANDLE) };
if handle.is_null() || handle == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE {
return Err(Error::Wrap("GetStdHandle failed".into()));
}
let mut original_mode: u32 = 0;
let ok = unsafe { GetConsoleMode(handle, &mut original_mode) };
if ok == 0 {
return Err(Error::Wrap("GetConsoleMode failed".into()));
}
let raw_mode = (original_mode
& !(ENABLE_ECHO_INPUT | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT))
| ENABLE_VIRTUAL_TERMINAL_INPUT;
let ok = unsafe { SetConsoleMode(handle, raw_mode) };
if ok == 0 {
return Err(Error::Wrap("SetConsoleMode (raw) failed".into()));
}
Ok(Self {
handle,
original_mode,
})
}
fn restore_cooked(&self) -> Result<()> {
let ok = unsafe { SetConsoleMode(self.handle, self.original_mode) };
if ok == 0 {
return Err(Error::Wrap("SetConsoleMode (cooked) failed".into()));
}
Ok(())
}
fn re_enter_raw(&self) -> Result<()> {
let raw_mode = (self.original_mode
& !(ENABLE_ECHO_INPUT | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT))
| ENABLE_VIRTUAL_TERMINAL_INPUT;
let ok = unsafe { SetConsoleMode(self.handle, raw_mode) };
if ok == 0 {
return Err(Error::Wrap("SetConsoleMode (re-raw) failed".into()));
}
Ok(())
}
}
// SAFETY: the guard stores only the standard-input handle value and a `u32`
// mode word, and uses the handle solely as an argument to `SetConsoleMode`.
// NOTE(review): this assumes the console handle is not tied to the thread that
// obtained it — confirm against the Windows console documentation.
unsafe impl Send for WinRawModeGuard {}
impl Drop for WinRawModeGuard {
fn drop(&mut self) {
unsafe {
SetConsoleMode(self.handle, self.original_mode);
}
}
}
/// Messages sent over the channel from the PTY output thread to the main loop
/// of `PtyProxy::run`.
enum OutputMsg {
/// The child exited with the given exit code.
/// NOTE(review): `run` matches on this variant, but no sender visible in this
/// file constructs it — confirm whether it is still needed.
ChildExited(u32),
/// The PTY reader returned zero bytes or a read error; the output thread
/// stops after sending this.
ReadEof,
}
/// Proxy that runs a program inside a pseudo-terminal and passes each complete
/// statement typed by the user to `handle_statement` before it reaches the
/// child (see `run`).
pub struct PtyProxy<'a> {
/// Wrapper options; `run` reads `delimiter` and `display_name` from it.
pub wrapper_config: WrapperConfig,
/// Settings passed through to `handle_statement`.
pub settings: &'a Settings,
/// Checks passed through to `handle_statement`.
pub checks: &'a [Check],
/// Environment passed through to `handle_statement`.
pub env: &'a dyn Environment,
/// Prompter passed through to `handle_statement`.
pub prompter: &'a dyn Prompter,
/// Configuration passed through to `handle_statement`.
pub config: &'a Config,
}
/// Converts a child exit code into the `i32` returned by `PtyProxy::run`.
///
/// Codes that do not fit in an `i32` are reported as `1`.
fn exit_code_from(code: u32) -> i32 {
    i32::try_from(code).unwrap_or(1)
}

/// Writes `bytes` to the PTY and flushes, deliberately ignoring failures.
///
/// A dead PTY is detected by the exit checks in the main loop of
/// `PtyProxy::run`, so a failed write is not treated as fatal here.
fn forward_bytes<W: Write + ?Sized>(writer: &mut W, bytes: &[u8]) {
    let _ = writer.write_all(bytes);
    let _ = writer.flush();
}

impl PtyProxy<'_> {
    /// Runs `program` with `args` inside a pseudo-terminal and proxies the
    /// console to it, intercepting every complete statement.
    ///
    /// Bytes read from stdin are forwarded to the child as they arrive. When
    /// the input buffer reports a complete statement, output is paused, the
    /// console is put back into its saved mode and `handle_statement` decides
    /// what happens next: on `Forward` the delimiter's trigger byte is sent to
    /// the child, on `Block` the byte `0x03` (Ctrl-C) is sent instead.
    ///
    /// Returns the child's exit code (`1` when it does not fit in an `i32` or
    /// when waiting for the child fails), `0` when stdin reaches end-of-file,
    /// and `1` when reading stdin fails.
    ///
    /// NOTE(review): `stdin.read` blocks, so after the first iteration an exit
    /// of the child is only noticed once further input arrives.
    ///
    /// # Errors
    ///
    /// Returns `Error::Wrap` when raw mode cannot be entered, the PTY cannot be
    /// opened, the child cannot be spawned, or the PTY reader/writer cannot be
    /// obtained.
    #[allow(clippy::too_many_lines)]
    pub fn run(&self, program: &str, args: &[String]) -> Result<i32> {
        // Enter raw mode before anything is spawned: if stdin is not a usable
        // console we fail here without having started (and then abandoned)
        // the child. The guard restores the mode on every early return below.
        let guard = WinRawModeGuard::enter()
            .map_err(|e| Error::Wrap(format!("failed to enter raw mode: {e}")))?;

        let pty_system = native_pty_system();
        let size = PtySize {
            rows: 24,
            cols: 80,
            pixel_width: 0,
            pixel_height: 0,
        };
        let pair = pty_system
            .openpty(size)
            .map_err(|e| Error::Wrap(format!("failed to open ConPTY: {e}")))?;

        let mut cmd = CommandBuilder::new(program);
        for arg in args {
            cmd.arg(arg);
        }
        let mut child = pair
            .slave
            .spawn_command(cmd)
            .map_err(|e| Error::Wrap(format!("failed to spawn child: {e}")))?;

        let mut pty_reader = pair
            .master
            .try_clone_reader()
            .map_err(|e| Error::Wrap(format!("failed to clone PTY reader: {e}")))?;
        let mut pty_writer = pair
            .master
            .take_writer()
            .map_err(|e| Error::Wrap(format!("failed to take PTY writer: {e}")))?;

        let output_paused = Arc::new(AtomicBool::new(false));
        let output_paused_clone = Arc::clone(&output_paused);
        let (tx, rx) = mpsc::channel::<OutputMsg>();

        // Output thread: copies PTY output to stdout until the reader reports
        // end-of-file or an error, then notifies the main loop.
        let output_thread = thread::spawn(move || {
            let mut stdout = std::io::stdout();
            let mut buf = [0u8; 4096];
            loop {
                match pty_reader.read(&mut buf) {
                    Ok(0) | Err(_) => {
                        let _ = tx.send(OutputMsg::ReadEof);
                        break;
                    }
                    Ok(n) => {
                        // Output that arrives while a statement is being
                        // handled is discarded, not buffered.
                        if !output_paused_clone.load(Ordering::Acquire) {
                            let _ = stdout.write_all(&buf[..n]);
                            let _ = stdout.flush();
                        }
                    }
                }
            }
        });

        let mut stdin = std::io::stdin();
        let mut input_buffer = InputBuffer::new(self.wrapper_config.delimiter);
        let mut buf = [0u8; 4096];

        let exit_code = loop {
            // Exit check 1: did the output thread report the end of the PTY?
            match rx.try_recv() {
                Ok(OutputMsg::ChildExited(code)) => break exit_code_from(code),
                Ok(OutputMsg::ReadEof) | Err(mpsc::TryRecvError::Disconnected) => {
                    break child
                        .wait()
                        .map_or(1, |status| exit_code_from(status.exit_code()));
                }
                Err(mpsc::TryRecvError::Empty) => {}
            }

            // Exit check 2: has the child already terminated?
            match child.try_wait() {
                Ok(Some(status)) => break exit_code_from(status.exit_code()),
                Ok(None) => {}
                Err(_) => break 1,
            }

            let n = match stdin.read(&mut buf) {
                Ok(0) => break 0,
                Ok(n) => n,
                Err(e) => {
                    warn!("[wrap] read stdin error: {e}");
                    break 1;
                }
            };

            for &byte in &buf[..n] {
                if is_control_passthrough(byte) {
                    forward_bytes(&mut pty_writer, &[byte]);
                    // 0x03 (Ctrl-C) and 0x04 (Ctrl-D) also discard the
                    // partially typed statement.
                    if byte == 0x03 || byte == 0x04 {
                        input_buffer.reset();
                    }
                    continue;
                }

                match input_buffer.feed(byte) {
                    BufferResult::Buffered => forward_bytes(&mut pty_writer, &[byte]),
                    BufferResult::Statement(stmt) => {
                        tracing::debug!(
                            "[wrap] statement detected ({} bytes): {:?}",
                            stmt.len(),
                            stmt
                        );

                        // Silence child output and restore the saved console
                        // mode while `handle_statement` runs.
                        output_paused.store(true, Ordering::Release);
                        if let Err(e) = guard.restore_cooked() {
                            warn!("[wrap] failed to restore cooked mode: {e}");
                        }
                        let action = handle_statement(
                            &stmt,
                            self.settings,
                            self.checks,
                            self.env,
                            self.prompter,
                            self.config,
                            &self.wrapper_config.display_name,
                        );
                        if let Err(e) = guard.re_enter_raw() {
                            warn!("[wrap] failed to re-enter raw mode: {e}");
                        }
                        output_paused.store(false, Ordering::Release);

                        match action {
                            StatementAction::Forward => {
                                let delim = self.wrapper_config.delimiter.trigger_byte();
                                forward_bytes(&mut pty_writer, &[delim]);
                            }
                            StatementAction::Block => forward_bytes(&mut pty_writer, &[0x03]),
                        }
                    }
                }
            }
        };

        drop(guard);
        drop(pty_writer);
        // Close the pseudoconsole before joining the reader. With ConPTY the
        // output pipe is only torn down once the pseudoconsole is closed, so
        // keeping `pair` alive across the join can leave the output thread
        // blocked in `read`. The thread keeps draining output in the meantime.
        drop(pair);
        let _ = output_thread.join();
        Ok(exit_code)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Starting from a mode with echo, line and processed input enabled, the
    /// raw-mode formula must clear those three bits and set virtual-terminal
    /// input.
    #[test]
    fn raw_mode_calculation() {
        let cooked_flags = ENABLE_ECHO_INPUT | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT;
        let original: u32 = cooked_flags;
        let raw = (original & !cooked_flags) | ENABLE_VIRTUAL_TERMINAL_INPUT;

        for cleared in [ENABLE_ECHO_INPUT, ENABLE_LINE_INPUT, ENABLE_PROCESSED_INPUT] {
            assert_eq!(raw & cleared, 0);
        }
        assert_ne!(raw & ENABLE_VIRTUAL_TERMINAL_INPUT, 0);
    }
}