use std::os::fd::AsFd;
use std::path::{Path, PathBuf};
use libc::{kill, SIGWINCH};
use nix::pty::{forkpty, ForkptyResult, Winsize};
use nix::sys::termios::{cfmakeraw, tcgetattr, tcsetattr, LocalFlags, SetArg, Termios};
use nix::sys::wait::waitpid;
use nix::unistd::{execvp, Pid};
use signal_hook::iterator::Signals;
use std::env;
use std::ffi::CString;
use std::fs::{File, OpenOptions};
use std::io::{self, BufReader, Read, Write};
use std::os::fd::{AsRawFd, FromRawFd, RawFd};
use std::sync::{mpsc, Arc, Mutex};
use crate::cli::{Model, NEW_COMMAND_MSG};
use crate::errors::WtgError;
use crate::openai::query_chatgpt;
/// RAII guard that snapshots a terminal's attributes and reapplies them on drop.
///
/// Create it with [`RawModeGuard::new`], then call
/// [`RawModeGuard::enable_raw_mode`]. When the guard goes out of scope, the
/// attributes captured at construction are restored with `TCSANOW`.
struct RawModeGuard<F: AsFd> {
    /// Terminal whose attributes are being managed.
    fd: F,
    /// Attributes captured by `new`, reapplied in `Drop`.
    orig_termios: Termios,
}

impl<F: AsFd> RawModeGuard<F> {
    /// Captures the current attributes of `fd` without modifying them.
    ///
    /// # Errors
    /// Returns the `tcgetattr` error, e.g. when `fd` is not a terminal.
    fn new(fd: F) -> nix::Result<Self> {
        tcgetattr(fd.as_fd()).map(|orig_termios| Self { fd, orig_termios })
    }

    /// Switches the terminal to raw mode (echo off), effective immediately.
    ///
    /// # Errors
    /// Returns the `tcsetattr` error if the new attributes cannot be applied.
    fn enable_raw_mode(&self) -> nix::Result<()> {
        let raw = {
            let mut attrs = self.orig_termios.clone();
            cfmakeraw(&mut attrs);
            // cfmakeraw already clears ECHO; doing it again keeps the intent explicit.
            attrs.local_flags.remove(LocalFlags::ECHO);
            attrs
        };
        tcsetattr(self.fd.as_fd(), SetArg::TCSANOW, &raw)
    }
}

impl<F: AsFd> Drop for RawModeGuard<F> {
    fn drop(&mut self) {
        // Best effort: a destructor has no channel to report the failure.
        let _ = tcsetattr(self.fd.as_fd(), SetArg::TCSANOW, &self.orig_termios);
    }
}
/// Returns the window size of the terminal attached to this process's stdin.
///
/// If the size cannot be queried (typically because stdin is not a terminal),
/// a conventional 80x24 size is returned instead of panicking.
fn get_parent_winsize() -> Winsize {
    use libc::{ioctl, winsize, TIOCGWINSZ};
    const FALLBACK: winsize = winsize {
        ws_row: 24,
        ws_col: 80,
        ws_xpixel: 0,
        ws_ypixel: 0,
    };
    let mut ws = FALLBACK;
    // SAFETY: `ws` is a valid, writable `winsize` for the whole call, which is
    // the argument TIOCGWINSZ expects.
    let queried = unsafe { ioctl(io::stdin().as_raw_fd(), TIOCGWINSZ, &mut ws) } != -1;
    // Do not trust `ws` after a failed ioctl; use the known fallback instead.
    let ws = if queried { ws } else { FALLBACK };
    Winsize {
        ws_row: ws.ws_row,
        ws_col: ws.ws_col,
        ws_xpixel: ws.ws_xpixel,
        ws_ypixel: ws.ws_ypixel,
    }
}
fn update_pty_winsize(master_fd: RawFd) -> Result<(), WtgError> {
let window_size = get_parent_winsize();
let ret = unsafe { libc::ioctl(master_fd, libc::TIOCSWINSZ, &window_size) };
if ret == -1 {
Err(WtgError::NixError(nix::Error::last()))
} else {
Ok(())
}
}
/// Forwards terminal resizes (SIGWINCH) from the parent terminal to the PTY.
///
/// The SIGWINCH handler is registered before this function returns. A
/// registration failure is therefore reported to the caller instead of killing
/// a detached thread, and resizes arriving right after the call are not missed.
/// A background thread then, for each SIGWINCH, copies the parent's window size
/// onto `master_fd` and re-sends SIGWINCH to `child_pid_for_resize`.
///
/// `master_fd` is used from the background thread, so the caller must keep the
/// descriptor open for as long as resizes should be forwarded.
///
/// # Errors
/// Returns an error if the SIGWINCH handler cannot be registered.
fn listen_pty_resize(child_pid_for_resize: Pid, master_fd: RawFd) -> Result<(), WtgError> {
    let mut signals = Signals::new([SIGWINCH])?;
    std::thread::spawn(move || {
        for _ in signals.forever() {
            if let Err(e) = update_pty_winsize(master_fd) {
                eprintln!("Failed to update pty window size: {:?}", e);
            }
            // SAFETY: kill(2) only takes integer arguments; a stale pid at
            // worst yields an ignored error.
            let _ = unsafe { kill(child_pid_for_resize.into(), SIGWINCH) };
        }
    });
    Ok(())
}
pub fn run_session(logfile: &str) -> Result<(), WtgError> {
let path = PathBuf::from(logfile);
let log = OpenOptions::new()
.append(true)
.create(true)
.open(path.clone())
.expect("Failed to open log file");
initialize_env_vars(path)?;
println!("Starting wtg session. Type 'exit' to quit.");
let window_size = get_parent_winsize();
let fork_result = unsafe { forkpty(Some(&window_size), None).expect("forkpty failed") };
match fork_result {
ForkptyResult::Parent { child, master } => {
let stdin = std::io::stdin();
let guard =
RawModeGuard::new(stdin).expect("Failed to get terminal attributes for raw mode");
guard.enable_raw_mode().expect("Failed to enable raw mode");
let master_fd = master.as_raw_fd();
let master_file = unsafe { File::from_raw_fd(master_fd) };
let mut master_reader = master_file
.try_clone()
.expect("Failed to clone master file");
let mut master_writer = master_file;
listen_pty_resize(child, master_fd).expect("Failed to listen for pty resize");
let (enter_tx, enter_rx) = mpsc::channel::<()>();
let (truncated_tx, truncated_rx) = mpsc::channel::<()>();
std::thread::spawn(move || {
let stdin = io::stdin();
let mut input = stdin.lock();
let mut buf = [0u8; 1024];
loop {
match input.read(&mut buf) {
Ok(0) => {
break;
}
Ok(n) => {
if buf[..n].iter().any(|&b| b == b'\n' || b == b'\r') {
let _ = enter_tx.send(());
let _ = truncated_rx.recv();
}
if master_writer.write_all(&buf[..n]).is_err() {
break;
}
}
Err(_) => break,
}
}
});
let log = Arc::new(Mutex::new(log));
{
let log = Arc::clone(&log);
std::thread::spawn(move || {
while let Ok(()) = enter_rx.recv() {
let mut log = log.lock().unwrap();
log.write_all(NEW_COMMAND_MSG.as_bytes()).unwrap();
let _ = truncated_tx.send(());
}
});
}
let mut buf = [0u8; 1024];
loop {
let n = master_reader
.read(&mut buf)
.expect("Error reading from PTY");
if n == 0 {
break;
}
{
let stdout = io::stdout();
let mut out = stdout.lock();
out.write_all(&buf[..n]).expect("Failed to write to stdout");
out.flush().unwrap();
}
{
let mut log = log.lock().unwrap();
log.write_all(&buf[..n]).expect("Failed to write to log");
log.flush().unwrap();
}
}
waitpid(child, None).expect("Failed to wait on child");
}
ForkptyResult::Child => {
let shell = env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string());
let shell_c = CString::new(shell).expect("CString failed");
let args = [shell_c.clone()];
execvp(&shell_c, &args).expect("execvp failed");
}
}
Ok(())
}
/// Exports the canonical (absolute, symlink-resolved) form of `path` as `WTG_LOG`.
///
/// Processes started afterwards inherit the variable, which lets later
/// `wtg` invocations locate the session log.
///
/// # Errors
/// Fails if `path` cannot be canonicalized (for example, if it does not exist).
fn initialize_env_vars<P: AsRef<Path>>(path: P) -> Result<(), WtgError> {
    let absolute = path.as_ref().canonicalize()?;
    env::set_var("WTG_LOG", absolute);
    Ok(())
}
/// Reads the entire log file, replacing invalid UTF-8 with U+FFFD.
///
/// # Errors
/// Returns `WtgError::LogFileOpenError` if the file cannot be opened. A
/// failure while reading an opened file is propagated instead of panicking.
fn get_log_content(logfile: String) -> Result<String, WtgError> {
    let file = File::open(&logfile).map_err(|_| WtgError::LogFileOpenError { logfile })?;
    let mut bytes = Vec::new();
    BufReader::new(file).read_to_end(&mut bytes)?;
    // Terminal transcripts may contain arbitrary bytes; decode lossily.
    Ok(String::from_utf8_lossy(&bytes).into_owned())
}
/// Extracts the transcript between the last two command markers in the log.
///
/// The slice starts at the beginning of the line holding the second-to-last
/// `NEW_COMMAND_MSG`. It ends just before the line holding the last marker
/// (presumably the line where the current `wtg` command was typed — confirm).
/// Every marker inside the slice is removed.
///
/// # Errors
/// Propagates errors from `get_log_content`, and returns
/// `WtgError::NoCommandRun` when the log has fewer than two markers.
fn extract_context_from_log(logfile: &str) -> Result<String, WtgError> {
    let no_command = || WtgError::NoCommandRun {
        logfile: logfile.to_string(),
    };
    // Byte offset of the start of the line containing the end of `text`.
    let line_start = |text: &str| text.rfind('\n').map_or(0, |nl| nl + 1);

    let content = get_log_content(logfile.to_string())?;
    let last_marker = content.rfind(NEW_COMMAND_MSG).ok_or_else(no_command)?;
    let before_last = &content[..last_marker];
    let end = line_start(before_last);
    let prev_marker = before_last.rfind(NEW_COMMAND_MSG).ok_or_else(no_command)?;
    let start = line_start(&content[..prev_marker]);
    Ok(content[start..end].replace(NEW_COMMAND_MSG, ""))
}
/// Sends a single query to the model, using recent terminal output as context.
///
/// When stdin is not a TTY, everything piped in becomes the context. Invalid
/// UTF-8 is replaced rather than rejected. Otherwise the context is extracted
/// from the session log at `logfile`, or at `$WTG_LOG` when `logfile` is `None`.
/// A failed query is reported on stderr and is not returned as an error.
///
/// # Errors
/// Returns an error if piped stdin cannot be read or the log context cannot
/// be extracted.
///
/// # Panics
/// Panics if stdin is a TTY, `logfile` is `None` and `WTG_LOG` is not set.
pub fn run_query(
    logfile: Option<String>,
    prompt: Option<String>,
    model: Option<Model>,
) -> Result<(), WtgError> {
    let stdin_fileno = io::stdin().as_raw_fd();
    let context = if !nix::unistd::isatty(stdin_fileno).unwrap_or(false) {
        // Piped terminal output is not guaranteed to be UTF-8, and
        // `read_to_string` would fail on the first invalid byte. Decode
        // lossily, the same way the log file is read.
        let mut piped = Vec::new();
        io::stdin().read_to_end(&mut piped)?;
        String::from_utf8_lossy(&piped).into_owned()
    } else {
        let logfile = logfile.unwrap_or_else(|| env::var("WTG_LOG").expect("WTG_LOG not set"));
        extract_context_from_log(&logfile)?
    };
    if let Err(e) = query_chatgpt(&context, prompt.as_deref(), model) {
        eprintln!("Error querying ChatGPT: {}", e);
    }
    Ok(())
}
/// Runs an interactive chat seeded with the last command's transcript.
///
/// The context comes from the session log at `logfile`, or `$WTG_LOG` when
/// `logfile` is `None`. Each successful exchange is appended to the context as
/// `user:` / `assistant:` lines. A failed query is reported on stderr and
/// leaves the context unchanged.
///
/// Blank lines are ignored. The chat ends when the user types `exit`/`e`/
/// `quit`/`q` (case-insensitive) or stdin reaches end of input (e.g. Ctrl-D).
///
/// # Errors
/// Returns `WtgError::ChatNotTty` when stdin is not a terminal. Also returns an
/// error if the log context cannot be extracted or stdin/stdout I/O fails.
///
/// # Panics
/// Panics if `logfile` is `None` and `WTG_LOG` is not set.
pub fn run_chat(logfile: Option<String>, model: Option<Model>) -> Result<(), WtgError> {
    let stdin_fileno = io::stdin().as_raw_fd();
    if !nix::unistd::isatty(stdin_fileno).unwrap_or(false) {
        return Err(WtgError::ChatNotTty);
    }
    let logfile = logfile.unwrap_or_else(|| env::var("WTG_LOG").expect("WTG_LOG not set"));
    let mut chat_context = extract_context_from_log(&logfile)?;
    println!("(type 'exit' ('e') or 'quit' ('q') to end chat)");
    let stdin = io::stdin();
    let mut line = String::new();
    loop {
        print!("user> ");
        io::stdout().flush()?;
        line.clear();
        // `read_line` returns Ok(0) at end of input. Treating that as an empty
        // prompt would loop forever, firing empty queries.
        if stdin.read_line(&mut line)? == 0 {
            break;
        }
        let prompt_text = line.trim();
        if prompt_text.is_empty() {
            continue;
        }
        if matches!(
            prompt_text.to_lowercase().as_str(),
            "exit" | "e" | "quit" | "q"
        ) {
            break;
        }
        match query_chatgpt(&chat_context, Some(prompt_text), model) {
            Ok(response) => {
                chat_context
                    .push_str(&format!("\nuser: {}\nassistant: {}", prompt_text, response));
            }
            Err(e) => eprintln!("Error querying ChatGPT: {}", e),
        }
    }
    Ok(())
}