use portable_pty::{CommandBuilder, MasterPty};
use std::io::{Read, Write};
use std::path::Path;
/// Collects everything a PTY child writes and waits for the child to exit.
///
/// Returns the captured output, decoded as UTF-8 with invalid sequences replaced by
/// U+FFFD (line endings are left untouched), together with the child's exit code.
///
/// * Unix: `writer` is dropped before reading, `reader` is read until it reports end of
///   input, and only then is the child waited on. `master` stays open until the
///   function returns.
/// * Windows (ConPTY): a background thread reads output and answers a cursor-position
///   request (`ESC [ 6 n`) with `ESC [ 1 ; 1 R` whenever one arrives. After the child
///   exits, `master` is dropped on a separate thread and the reader thread's output is
///   awaited for at most 10 seconds; threads still running then are detached, not joined.
///
/// # Panics
///
/// Panics if reading the PTY (Unix) or waiting on the child fails.
pub fn read_pty_output(
    reader: Box<dyn Read + Send>,
    writer: Box<dyn Write + Send>,
    master: Box<dyn MasterPty + Send>,
    child: &mut Box<dyn portable_pty::Child + Send + Sync>,
) -> (String, i32) {
    #[cfg(unix)]
    {
        // Keep the master open until after the child has been reaped. A named binding is
        // used because `let _ = master;` neither moves nor drops the value.
        let _master = master;
        drop(writer);
        let mut reader = reader;
        // Read raw bytes so output that is not valid UTF-8 is decoded lossily (as on
        // Windows) instead of panicking. NOTE(review): this relies on the reader reporting
        // EOF once the child side of the PTY closes — confirm for each Unix target.
        let mut raw = Vec::new();
        reader.read_to_end(&mut raw).unwrap();
        let exit_status = child.wait().unwrap();
        let buf = String::from_utf8_lossy(&raw).into_owned();
        (buf, exit_status.exit_code() as i32)
    }
    #[cfg(windows)]
    {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::{Arc, mpsc};
        use std::thread;
        use std::time::Duration;
        // A cursor-position request may be split across two reads, so each search also
        // covers this many trailing bytes of the output received so far.
        let request_overlap = b"\x1b[6n".len() - 1;
        let should_stop = Arc::new(AtomicBool::new(false));
        let should_stop_reader = should_stop.clone();
        let (tx, rx) = mpsc::channel();
        let read_thread = thread::spawn(move || {
            let mut reader = reader;
            let mut writer = writer;
            let mut output = Vec::new();
            let mut temp_buf = [0u8; 4096];
            loop {
                // Only checked between reads; a read that is already blocked is not
                // interrupted by this flag.
                if should_stop_reader.load(Ordering::Relaxed) {
                    break;
                }
                match reader.read(&mut temp_buf) {
                    Ok(0) => {
                        break;
                    }
                    Ok(n) => {
                        let search_start = output.len().saturating_sub(request_overlap);
                        output.extend_from_slice(&temp_buf[..n]);
                        // Any match starting at or after `search_start` must end inside the
                        // new chunk, so a request is never answered twice.
                        if let Some(pos) = find_cursor_request(&output[search_start..]) {
                            // Reply to the cursor-position request with row 1, column 1.
                            let response = b"\x1b[1;1R";
                            let _ = writer.write_all(response);
                            let _ = writer.flush();
                            eprintln!(
                                "ConPTY: Responded to cursor position request at byte {}",
                                search_start + pos
                            );
                        }
                    }
                    Err(e) => {
                        if e.kind() == std::io::ErrorKind::WouldBlock {
                            thread::sleep(Duration::from_millis(10));
                            continue;
                        }
                        eprintln!("ConPTY: Read error: {}", e);
                        break;
                    }
                }
            }
            let _ = tx.send(output);
        });
        let exit_status = child.wait().unwrap();
        let exit_code = exit_status.exit_code() as i32;
        should_stop.store(true, Ordering::Relaxed);
        // Drop the master on its own thread so this function never waits on it.
        // NOTE(review): presumably because closing the pseudoconsole can block — confirm.
        let close_thread = thread::spawn(move || {
            drop(master);
        });
        let output = match rx.recv_timeout(Duration::from_secs(10)) {
            Ok(data) => data,
            Err(_) => {
                eprintln!("ConPTY: Read thread timed out after child exit");
                Vec::new()
            }
        };
        // Dropping a JoinHandle detaches the thread rather than waiting for it.
        drop(close_thread);
        drop(read_thread);
        let buf = String::from_utf8_lossy(&output).to_string();
        (buf, exit_code)
    }
}
/// Returns the byte offset of the first cursor-position request (`ESC [ 6 n`) in
/// `data`, or `None` when the sequence does not occur.
#[cfg(windows)]
fn find_cursor_request(data: &[u8]) -> Option<usize> {
    const CURSOR_POSITION_REQUEST: &[u8] = b"\x1b[6n";
    data.windows(CURSOR_POSITION_REQUEST.len())
        .enumerate()
        .find(|(_, candidate)| *candidate == CURSOR_POSITION_REQUEST)
        .map(|(offset, _)| offset)
}
/// Builds a [`CommandBuilder`] that runs `command` with `args` in `working_dir`.
///
/// `super::configure_pty_command` is applied first, then every pair in `env_vars`.
/// When `home_dir` is given, `HOME` and `XDG_CONFIG_HOME` (`<home>/.config`) are pointed
/// at it (plus `USERPROFILE` on Windows), and `WORKTRUNK_TEST_NUSHELL_ENV` is set to
/// `"0"`; that last variable is only set in this case.
pub fn build_pty_command(
    command: &str,
    args: &[&str],
    working_dir: &Path,
    env_vars: &[(String, String)],
    home_dir: Option<&Path>,
) -> CommandBuilder {
    let mut cmd = CommandBuilder::new(command);
    for arg in args {
        cmd.arg(*arg);
    }
    cmd.cwd(working_dir);
    super::configure_pty_command(&mut cmd);
    for (key, value) in env_vars {
        cmd.env(key, value);
    }
    if let Some(home) = home_dir {
        // Pass the paths through as OS strings (like `cwd` above) so a home directory
        // that is not valid UTF-8 is not mangled by a lossy conversion.
        cmd.env("HOME", home);
        cmd.env("XDG_CONFIG_HOME", home.join(".config"));
        #[cfg(windows)]
        cmd.env("USERPROFILE", home);
        cmd.env("WORKTRUNK_TEST_NUSHELL_ENV", "0");
    }
    cmd
}
/// Runs `cmd` in a freshly opened PTY, optionally types `input` into it, and returns the
/// child's output (with `\r\n` normalized to `\n`) together with its exit code.
///
/// `input` is written and flushed in one go before any output is read; an empty string
/// writes nothing.
///
/// # Panics
///
/// Panics if spawning the command, obtaining the PTY reader/writer, or writing the input
/// fails, or if `read_pty_output` panics.
pub fn exec_cmd_in_pty(cmd: CommandBuilder, input: &str) -> (String, i32) {
    let pty = super::open_pty();
    let mut child = pty.slave.spawn_command(cmd).unwrap();
    // Release our handle to the slave side as soon as the child is spawned.
    drop(pty.slave);

    let reader = pty.master.try_clone_reader().unwrap();
    let mut writer = pty.master.take_writer().unwrap();
    let payload = input.as_bytes();
    if !payload.is_empty() {
        writer.write_all(payload).unwrap();
        writer.flush().unwrap();
    }

    let (raw_output, exit_code) = read_pty_output(reader, writer, pty.master, &mut child);
    (raw_output.replace("\r\n", "\n"), exit_code)
}
/// Runs `cmd` in a freshly opened PTY and feeds it `inputs` one at a time, each sent only
/// after another occurrence of `prompt_marker` has shown up in the output.
///
/// Returns the output with `\r\n` normalized to `\n` and the child's exit code. The
/// master side of the PTY is kept open until the interaction has finished.
///
/// # Panics
///
/// Panics if spawning the command or obtaining the PTY reader/writer fails, or if the
/// prompted interaction panics (e.g. a prompt never appears).
pub fn exec_cmd_in_pty_prompted(
    cmd: CommandBuilder,
    inputs: &[&str],
    prompt_marker: &str,
) -> (String, i32) {
    let pty = super::open_pty();
    let mut child = pty.slave.spawn_command(cmd).unwrap();
    // Release our handle to the slave side as soon as the child is spawned.
    drop(pty.slave);
    let (reader, writer) = (
        pty.master.try_clone_reader().unwrap(),
        pty.master.take_writer().unwrap(),
    );
    prompted_pty_interaction(reader, writer, &mut child, inputs, prompt_marker)
}
/// Drives an interactive PTY session, answering prompts with `inputs` in order.
///
/// A background thread forwards everything read from `reader` over a channel. For the
/// k-th input (1-based), this waits until `prompt_marker` has appeared at least k times
/// in the output collected so far. It then waits for the output to go quiet (no new data
/// for 20 ms, giving up after 500 ms) before writing and flushing the input. Once every
/// input is sent, it waits for the child, drops the writer, joins the reader thread,
/// collects any remaining output, and returns it with `\r\n` normalized to `\n`,
/// together with the exit code.
///
/// NOTE(review): completion relies on the reader thread seeing EOF or an error after the
/// child exits. On Windows (ConPTY) that may not happen while the master is still open —
/// confirm before using this there.
///
/// # Panics
///
/// Panics if a prompt does not appear within 10 seconds (the message includes the output
/// so far), or if writing an input or waiting on the child fails.
fn prompted_pty_interaction(
    reader: Box<dyn std::io::Read + Send>,
    writer: Box<dyn std::io::Write + Send>,
    child: &mut Box<dyn portable_pty::Child + Send + Sync>,
    inputs: &[&str],
    prompt_marker: &str,
) -> (String, i32) {
    use std::sync::mpsc;
    use std::time::{Duration, Instant};

    const PROMPT_TIMEOUT: Duration = Duration::from_secs(10);
    const POLL_INTERVAL: Duration = Duration::from_millis(10);
    const QUIESCENCE: Duration = Duration::from_millis(20);
    const DRAIN_CEILING: Duration = Duration::from_millis(500);

    let (chunk_tx, chunk_rx) = mpsc::channel::<Vec<u8>>();
    let pump = std::thread::spawn(move || {
        let mut reader = reader;
        let mut scratch = [0u8; 4096];
        // Stop on EOF, on a read error, or once the receiving side has gone away.
        while let Ok(n) = std::io::Read::read(&mut reader, &mut scratch) {
            if n == 0 || chunk_tx.send(scratch[..n].to_vec()).is_err() {
                break;
            }
        }
    });

    // Moves every chunk currently queued into `out`; reports whether any arrived.
    let pull = |out: &mut Vec<u8>| -> bool {
        let mut got_any = false;
        for chunk in chunk_rx.try_iter() {
            out.extend_from_slice(&chunk);
            got_any = true;
        }
        got_any
    };

    let marker = prompt_marker.as_bytes();
    let mut output = Vec::new();
    let mut writer = writer;

    for (index, input) in inputs.iter().enumerate() {
        // The (index + 1)-th prompt must be visible before the index-th input is sent.
        let wanted = index + 1;
        let waiting_since = Instant::now();
        loop {
            pull(&mut output);
            if count_marker_occurrences(&output, marker) >= wanted {
                break;
            }
            if waiting_since.elapsed() > PROMPT_TIMEOUT {
                panic!(
                    "Timed out waiting for prompt marker {:?} (occurrence {}). Output so far:\n{}",
                    prompt_marker,
                    wanted,
                    String::from_utf8_lossy(&output)
                );
            }
            std::thread::sleep(POLL_INTERVAL);
        }

        // Let the rest of the prompt's output settle before typing the answer.
        let drain_started = Instant::now();
        let mut last_arrival = Instant::now();
        loop {
            if pull(&mut output) {
                last_arrival = Instant::now();
            }
            if last_arrival.elapsed() >= QUIESCENCE || drain_started.elapsed() >= DRAIN_CEILING
            {
                break;
            }
            std::thread::sleep(POLL_INTERVAL);
        }

        writer.write_all(input.as_bytes()).unwrap();
        writer.flush().unwrap();
    }

    let status = child.wait().unwrap();
    let exit_code = status.exit_code() as i32;
    drop(writer);
    let _ = pump.join();
    pull(&mut output);

    let text = String::from_utf8_lossy(&output).replace("\r\n", "\n");
    (text, exit_code)
}
/// Counts non-overlapping occurrences of `needle` in `haystack`.
///
/// Each occurrence is a separately printed marker, so matches are not allowed to share
/// bytes. Overlapping counting would treat a self-overlapping marker such as `"aa"`
/// printed once inside `"aaa"` as two prompts. Returns 0 for an empty `needle` or one
/// longer than `haystack`.
fn count_marker_occurrences(haystack: &[u8], needle: &[u8]) -> usize {
    if needle.is_empty() {
        return 0;
    }
    let mut count = 0;
    let mut rest = haystack;
    // `windows` yields nothing when `needle` is longer than `rest`, ending the loop.
    while let Some(pos) = rest.windows(needle.len()).position(|w| w == needle) {
        count += 1;
        rest = &rest[pos + needle.len()..];
    }
    count
}