use anyhow::{Context, Result};
use std::io::{Read, Write};
use std::mem::{size_of, zeroed};
use std::ptr::null_mut;
use std::thread;
use tokio::sync::broadcast;
use windows::core::PWSTR;
use windows::Win32::Foundation::{
CloseHandle, SetHandleInformation, HANDLE, HANDLE_FLAGS, HANDLE_FLAG_INHERIT,
INVALID_HANDLE_VALUE,
};
use windows::Win32::Security::SECURITY_ATTRIBUTES;
use windows::Win32::Storage::FileSystem::{ReadFile, WriteFile};
use windows::Win32::System::Console::{
ClosePseudoConsole, CreatePseudoConsole, GetConsoleScreenBufferInfo, GetStdHandle,
ResizePseudoConsole, CONSOLE_SCREEN_BUFFER_INFO, COORD, HPCON, PSEUDOCONSOLE_INHERIT_CURSOR,
STD_OUTPUT_HANDLE,
};
use windows::Win32::System::Pipes::CreatePipe;
use windows::Win32::System::Threading::{
CreateProcessW, InitializeProcThreadAttributeList, UpdateProcThreadAttribute,
WaitForSingleObject, EXTENDED_STARTUPINFO_PRESENT, LPPROC_THREAD_ATTRIBUTE_LIST,
PROCESS_CREATION_FLAGS, PROCESS_INFORMATION, PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE,
STARTUPINFOEXW,
};
use crate::core::event_router::{CaptureEvent, Event, EventRouter, LifecycleEvent};
/// Reports the visible window size (columns x rows) of the console behind this process's
/// stdout.
///
/// The extent comes from the screen buffer's visible window (`srWindow`), not the whole
/// scrollback buffer. Falls back to 120x30 when `GetStdHandle` or
/// `GetConsoleScreenBufferInfo` fails, or when the reported extent is not positive.
fn parent_console_size() -> COORD {
    const FALLBACK: COORD = COORD { X: 120, Y: 30 };
    // SAFETY: GetStdHandle has no preconditions; `info` is a plain-data out-parameter that is
    // only read after GetConsoleScreenBufferInfo reports success.
    let visible = unsafe {
        let Ok(stdout_handle) = GetStdHandle(STD_OUTPUT_HANDLE) else {
            return FALLBACK;
        };
        let mut info: CONSOLE_SCREEN_BUFFER_INFO = zeroed();
        if GetConsoleScreenBufferInfo(stdout_handle, &mut info).is_err() {
            return FALLBACK;
        }
        info.srWindow
    };
    // srWindow bounds are inclusive on both ends, hence the +1.
    let width = visible.Right - visible.Left + 1;
    let height = visible.Bottom - visible.Top + 1;
    if width > 0 && height > 0 {
        COORD {
            X: width,
            Y: height,
        }
    } else {
        FALLBACK
    }
}
/// A child process attached to a Windows pseudo-console (ConPTY), together with this side's
/// ends of the pseudo-console's input and output pipes.
pub struct PtyShell {
    /// Our write end of the pseudo-console's input pipe.
    pipe_pty_in: HANDLE,
    /// Our read end of the pseudo-console's output pipe.
    pipe_pty_out: HANDLE,
    /// The pseudo-console the child was launched into.
    hpc: HPCON,
    /// Process/thread handles and IDs filled in by `CreateProcessW`.
    process_info: PROCESS_INFORMATION,
    /// Backing storage for the process/thread attribute list used when spawning.
    _attr_list: Vec<u8>,
}
// SAFETY: every field is either a raw Win32 kernel handle (process-wide, not tied to the
// creating thread) or an owned byte buffer, so moving a PtyShell across threads is sound.
unsafe impl Send for PtyShell {}
/// Scope guard that closes a Win32 handle on drop unless ownership is released via `take`.
///
/// A wrapped `INVALID_HANDLE_VALUE` is treated as "nothing owned" and is never closed.
struct OwnedHandle(HANDLE);
impl OwnedHandle {
    /// Takes ownership of `handle`.
    fn new(handle: HANDLE) -> Self {
        OwnedHandle(handle)
    }

    /// Returns the raw handle while keeping ownership.
    fn as_raw(&self) -> HANDLE {
        self.0
    }

    /// Gives up ownership and returns the raw handle without closing it.
    fn take(self) -> HANDLE {
        let raw = self.0;
        // Skipping Drop leaves the handle open for whoever receives it.
        std::mem::forget(self);
        raw
    }
}
impl Drop for OwnedHandle {
    fn drop(&mut self) {
        if self.0 == INVALID_HANDLE_VALUE {
            return;
        }
        // SAFETY: the guard exclusively owns this handle and closes it exactly once.
        unsafe {
            let _ = CloseHandle(self.0);
        }
    }
}
/// Tears down the shell when it goes out of scope.
///
/// Order: close this side's pipe ends, close the pseudo-console, wait up to 100 ms for the
/// child process, then close the process and thread handles. Every result is ignored, so
/// dropping never fails.
///
/// NOTE(review): `PtyReader`, `PtyWriter` and `PtyResizer` values obtained from this shell
/// hold copies of these raw handles (not duplicates); using them after this runs touches
/// closed and possibly reused handles.
impl Drop for PtyShell {
fn drop(&mut self) {
unsafe {
// Our pipe ends are closed before ClosePseudoConsole, presumably so that call cannot
// block waiting for unread output to be drained — confirm before reordering.
let _ = CloseHandle(self.pipe_pty_in);
let _ = CloseHandle(self.pipe_pty_out);
ClosePseudoConsole(self.hpc);
// Bounded wait (100 ms) for the child to signal; the wait result is not checked.
WaitForSingleObject(self.process_info.hProcess, 100);
let _ = CloseHandle(self.process_info.hProcess);
let _ = CloseHandle(self.process_info.hThread);
}
}
}
impl PtyShell {
pub fn spawn(program: &str) -> Result<Self> {
unsafe { Self::spawn_impl(program) }
}
unsafe fn spawn_impl(program: &str) -> Result<Self> {
let mut pipe_in_read = INVALID_HANDLE_VALUE;
let mut pipe_in_write = INVALID_HANDLE_VALUE;
let mut pipe_out_read = INVALID_HANDLE_VALUE;
let mut pipe_out_write = INVALID_HANDLE_VALUE;
let sa = SECURITY_ATTRIBUTES {
nLength: size_of::<SECURITY_ATTRIBUTES>() as u32,
bInheritHandle: true.into(),
lpSecurityDescriptor: null_mut(),
};
CreatePipe(&mut pipe_in_read, &mut pipe_in_write, Some(&sa), 0)
.context("Failed to create input pipe")?;
let pipe_in_read = OwnedHandle::new(pipe_in_read);
let pipe_in_write = OwnedHandle::new(pipe_in_write);
CreatePipe(&mut pipe_out_read, &mut pipe_out_write, Some(&sa), 0)
.context("Failed to create output pipe")?;
let pipe_out_read = OwnedHandle::new(pipe_out_read);
let pipe_out_write = OwnedHandle::new(pipe_out_write);
SetHandleInformation(
pipe_in_write.as_raw(),
HANDLE_FLAG_INHERIT.0,
HANDLE_FLAGS(0),
)
.context("Failed to clear inherit on pipe_in_write")?;
SetHandleInformation(
pipe_out_read.as_raw(),
HANDLE_FLAG_INHERIT.0,
HANDLE_FLAGS(0),
)
.context("Failed to clear inherit on pipe_out_read")?;
let size = parent_console_size();
let hpc = CreatePseudoConsole(
size,
pipe_in_read.as_raw(), pipe_out_write.as_raw(), PSEUDOCONSOLE_INHERIT_CURSOR,
)
.context("Failed to create pseudo-console")?;
drop(pipe_in_read);
drop(pipe_out_write);
let mut attr_list_size: usize = 0;
let _ = InitializeProcThreadAttributeList(None, 1, None, &mut attr_list_size);
let mut attr_list: Vec<u8> = vec![0; attr_list_size];
let attr_list_ptr = LPPROC_THREAD_ATTRIBUTE_LIST(attr_list.as_mut_ptr() as *mut _);
InitializeProcThreadAttributeList(Some(attr_list_ptr), 1, None, &mut attr_list_size)
.context("Failed to initialize proc thread attribute list")?;
UpdateProcThreadAttribute(
attr_list_ptr,
0,
PROC_THREAD_ATTRIBUTE_PSEUDOCONSOLE as usize,
Some(hpc.0 as *const std::ffi::c_void),
size_of::<HPCON>(),
None,
None,
)
.context("Failed to update proc thread attribute")?;
use windows::Win32::System::Threading::STARTF_USESTDHANDLES;
let mut startup_info: STARTUPINFOEXW = zeroed();
startup_info.StartupInfo.cb = size_of::<STARTUPINFOEXW>() as u32;
startup_info.StartupInfo.dwFlags = STARTF_USESTDHANDLES;
startup_info.StartupInfo.hStdInput = INVALID_HANDLE_VALUE;
startup_info.StartupInfo.hStdOutput = INVALID_HANDLE_VALUE;
startup_info.StartupInfo.hStdError = INVALID_HANDLE_VALUE;
startup_info.lpAttributeList = attr_list_ptr;
let mut process_info: PROCESS_INFORMATION = zeroed();
let mut cmd_wide: Vec<u16> = program.encode_utf16().chain(std::iter::once(0)).collect();
let creation_flags = PROCESS_CREATION_FLAGS(EXTENDED_STARTUPINFO_PRESENT.0);
CreateProcessW(
None,
Some(PWSTR(cmd_wide.as_mut_ptr())),
None,
None,
false, creation_flags,
None,
None,
&startup_info.StartupInfo,
&mut process_info,
)
.context(format!("Failed to create process: {}", program))?;
log::debug!(
"ConPTY spawned: hpc={:?}, pid={}, pipe_in={:?}, pipe_out={:?}",
hpc.0,
process_info.dwProcessId,
pipe_in_write.as_raw().0,
pipe_out_read.as_raw().0
);
Ok(Self {
hpc,
process_info,
pipe_pty_out: pipe_out_read.take(),
pipe_pty_in: pipe_in_write.take(),
_attr_list: attr_list,
})
}
pub fn get_writer(&self) -> Result<PtyWriter> {
Ok(PtyWriter {
handle: self.pipe_pty_in,
})
}
pub fn get_reader(&self) -> Result<PtyReader> {
Ok(PtyReader {
handle: self.pipe_pty_out,
})
}
pub fn get_resizer(&self) -> Result<PtyResizer> {
Ok(PtyResizer { hpc: self.hpc })
}
pub fn forward_output(
&mut self,
mut event_rx: broadcast::Receiver<Event>,
router: EventRouter,
) -> Result<()> {
let reader = self.get_reader()?;
let mut writer = self.get_writer()?;
let mut stdout = std::io::stdout();
let mut buf = [0u8; 4096];
let mut pending_data = Vec::new();
let mut responded_to_cpr = false;
let mut shutdown_received = false;
loop {
match event_rx.try_recv() {
Ok(Event::Lifecycle(LifecycleEvent::Shutdown)) => {
log::debug!("Shell forwarder received shutdown signal");
shutdown_received = true;
break;
}
Ok(_) => {} Err(broadcast::error::TryRecvError::Empty) => {}
Err(broadcast::error::TryRecvError::Closed) => break,
Err(broadcast::error::TryRecvError::Lagged(_)) => {}
}
match reader.read_timeout(&mut buf, 10) {
Ok(0) => {
unsafe {
if WaitForSingleObject(self.process_info.hProcess, 0).0 == 0 {
break; }
}
thread::sleep(std::time::Duration::from_millis(10));
}
Ok(n) => {
let data = &buf[..n];
if !responded_to_cpr {
pending_data.extend_from_slice(data);
if pending_data
.windows(4)
.any(|w| w == b"\x1b[6n" || w == [0x1b, b'[', b'6', b'n'])
{
log::debug!("Responding to cursor position request");
let _ = writer.write_all(b"\x1b[1;1R");
let _ = writer.flush();
responded_to_cpr = true;
}
if pending_data.len() > 500 {
responded_to_cpr = true;
}
}
stdout.write_all(data)?;
stdout.flush()?;
}
Err(_) => {
unsafe {
if WaitForSingleObject(self.process_info.hProcess, 0).0 == 0 {
break; }
}
thread::sleep(std::time::Duration::from_millis(10));
}
}
}
if !shutdown_received {
router.send(Event::Capture(CaptureEvent::Stop));
}
Ok(())
}
}
pub struct PtyResizer {
hpc: HPCON,
}
unsafe impl Send for PtyResizer {}
impl PtyResizer {
pub fn resize(&self, rows: u16, cols: u16) -> Result<()> {
let size = COORD {
X: cols as i16,
Y: rows as i16,
};
unsafe { ResizePseudoConsole(self.hpc, size) }
.map_err(|e| anyhow::anyhow!("ResizePseudoConsole failed: {}", e))?;
Ok(())
}
}
pub struct PtyWriter {
handle: HANDLE,
}
unsafe impl Send for PtyWriter {}
unsafe impl Sync for PtyWriter {}
impl Write for PtyWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut bytes_written: u32 = 0;
unsafe {
WriteFile(self.handle, Some(buf), Some(&mut bytes_written), None)
.map_err(|e: windows::core::Error| std::io::Error::other(e.to_string()))?;
}
log::debug!(
"PtyWriter::write: requested {} bytes, wrote {} bytes",
buf.len(),
bytes_written
);
Ok(bytes_written as usize)
}
fn flush(&mut self) -> std::io::Result<()> {
use windows::Win32::Storage::FileSystem::FlushFileBuffers;
unsafe {
let _ = FlushFileBuffers(self.handle);
}
Ok(())
}
}
pub struct PtyReader {
handle: HANDLE,
}
unsafe impl Send for PtyReader {}
unsafe impl Sync for PtyReader {}
impl PtyReader {
fn read_timeout(&self, buf: &mut [u8], _timeout_ms: u32) -> std::io::Result<usize> {
use windows::Win32::System::Pipes::PeekNamedPipe;
let mut bytes_available: u32 = 0;
let mut total_bytes_avail: u32 = 0;
unsafe {
match PeekNamedPipe(
self.handle,
None,
0,
None,
Some(&mut bytes_available),
Some(&mut total_bytes_avail),
) {
Ok(_) => {}
Err(e) => {
log::debug!("PeekNamedPipe error: {}", e);
return Err(std::io::Error::new(
std::io::ErrorKind::BrokenPipe,
e.to_string(),
));
}
}
}
if bytes_available == 0 {
return Ok(0);
}
let mut bytes_read: u32 = 0;
let to_read = std::cmp::min(buf.len() as u32, bytes_available);
unsafe {
ReadFile(
self.handle,
Some(&mut buf[..to_read as usize]),
Some(&mut bytes_read),
None,
)
.map_err(|e: windows::core::Error| {
log::debug!("ReadFile error: {}", e);
std::io::Error::other(e.to_string())
})?;
}
Ok(bytes_read as usize)
}
}
impl Read for PtyReader {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.read_timeout(buf, 0)
}
}
// Windows-only integration tests: both spawn a real cmd.exe inside a ConPTY.
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: spawning succeeds, the reader/writer accessors succeed, and drop runs cleanly.
#[test]
fn test_pty_shell_spawn() {
let shell = PtyShell::spawn("cmd.exe");
assert!(
shell.is_ok(),
"Failed to spawn PTY shell: {:?}",
shell.err()
);
let shell = shell.unwrap();
assert!(shell.get_reader().is_ok(), "Failed to get reader");
assert!(shell.get_writer().is_ok(), "Failed to get writer");
drop(shell);
}
// End-to-end round trip: wait for a prompt (answering a cursor-position request if one
// appears), send an `echo` command, and check its text comes back through the output pipe.
#[test]
fn test_pty_interactive() {
let _ = env_logger::builder().is_test(true).try_init();
println!("\n=== ConPTY Interactive Test ===");
let shell = PtyShell::spawn("cmd.exe").expect("Failed to spawn shell");
let reader = shell.get_reader().expect("Failed to get reader");
let mut writer = shell.get_writer().expect("Failed to get writer");
let mut buf = [0u8; 4096];
let mut all_output = Vec::new();
let mut responded_to_cpr = false;
println!("Reading initial output...");
// Phase 1: poll for startup output. read_timeout ignores its timeout argument and does
// not block, so the sleep in the `_` arm is what paces this loop.
for i in 0..50 {
match reader.read_timeout(&mut buf, 100) {
Ok(n) if n > 0 => {
all_output.extend_from_slice(&buf[..n]);
print!("[{}:{}b]", i, n);
std::io::Write::flush(&mut std::io::stdout()).ok();
// Answer the first ESC[6n (cursor position request) with a fixed row 1 / column 1.
if !responded_to_cpr {
let output_str = String::from_utf8_lossy(&all_output);
if output_str.contains("\x1b[6n") {
println!("\nResponding to cursor position query...");
let written = writer
.write(b"\x1b[1;1R")
.expect("Failed to write cursor pos");
println!("Wrote {} bytes for cursor position response", written);
writer.flush().expect("Failed to flush");
responded_to_cpr = true;
}
}
}
_ => {
std::thread::sleep(std::time::Duration::from_millis(50));
}
}
// After a warm-up of 21 iterations, any '>' in the output is taken as cmd's prompt.
if i > 20 {
let output_str = String::from_utf8_lossy(&all_output);
if output_str.contains(">") {
println!("\nFound prompt, shell ready");
break;
}
}
}
println!(
"\nAll output so far ({} bytes):\n{:?}",
all_output.len(),
String::from_utf8_lossy(&all_output)
);
// Phase 2: send a command whose output is easy to recognize.
println!("\nSending: echo TEST_OUTPUT_12345");
let written = writer
.write(b"echo TEST_OUTPUT_12345\r\n")
.expect("Failed to write");
println!("Wrote {} bytes for echo command", written);
writer.flush().expect("Failed to flush");
println!("Reading response...");
// Phase 3: collect output for a fixed 30 polls (no early exit).
let mut response = Vec::new();
for i in 0..30 {
match reader.read_timeout(&mut buf, 100) {
Ok(n) if n > 0 => {
response.extend_from_slice(&buf[..n]);
print!("[{}:{}b]", i, n);
std::io::Write::flush(&mut std::io::stdout()).ok();
}
_ => {
std::thread::sleep(std::time::Duration::from_millis(50));
}
}
}
let response_str = String::from_utf8_lossy(&response);
println!(
"\nResponse ({} bytes):\n{:?}",
response.len(),
&response_str
);
// The marker may appear both in the echoed command line and in its output; either suffices.
assert!(
response_str.contains("TEST_OUTPUT_12345"),
"Expected response to contain TEST_OUTPUT_12345"
);
println!("\n=== Test Passed ===\n");
}
}