1use std::{
2 collections::VecDeque,
3 io::Write,
4 path::PathBuf,
5 process::Command,
6 sync::{Arc, Mutex},
7 thread,
8 time::Duration,
9};
10
11use completer::SqlCompleter;
12use miette::{miette, IntoDiagnostic, Result};
13use portable_pty::{CommandBuilder, NativePtySystem, PtySize, PtySystem};
14use psql_writer::PsqlWriter;
15use rustyline::{
16 history::{History as _, SearchDirection},
17 Config, Editor,
18};
19use schema_cache::SchemaCacheManager;
20use tempfile::NamedTempFile;
21use thiserror::Error;
22use tracing::{debug, info, trace, warn};
23
24pub use find::find_postgres_bin;
25pub use ots::prompt_for_ots;
26
27mod completer;
28mod find;
29pub mod highlighter;
30pub mod history;
31mod ots;
32mod prompt;
33mod psql_writer;
34mod reader;
35mod schema_cache;
36mod terminal;
37
/// Sets both the input and the output code page of the attached console to `codepage`.
///
/// The success/failure results of the underlying Win32 calls are ignored.
#[cfg(windows)]
pub fn set_console_codepage(codepage: u32) {
    use windows_sys::Win32::System::Console::{SetConsoleCP, SetConsoleOutputCP};

    // SAFETY: both functions take a plain integer argument and have no pointer or
    // memory-safety preconditions.
    unsafe { SetConsoleCP(codepage) };
    // SAFETY: as above.
    unsafe { SetConsoleOutputCP(codepage) };
}
50
/// No-op on non-Windows platforms; exists so callers can invoke it unconditionally.
#[cfg(not(windows))]
pub fn set_console_codepage(_codepage: u32) {}
56
57#[derive(Debug, Error)]
58pub enum PsqlError {
59 #[error("psql process terminated unexpectedly")]
60 ProcessTerminated,
61 #[error("failed to read from psql")]
62 ReadError,
63 #[error("failed to write to psql")]
64 WriteError,
65}
66
67#[derive(Debug, Clone)]
69pub struct PsqlConfig {
70 pub program: String,
72
73 pub write: bool,
75
76 pub args: Vec<String>,
78
79 pub psqlrc: String,
81
82 pub history_path: PathBuf,
84
85 pub user: Option<String>,
87
88 pub ots: Option<String>,
90
91 pub passthrough: bool,
93
94 pub disable_schema_cache: bool,
96
97 pub theme: highlighter::Theme,
99}
100
101impl PsqlConfig {
102 fn psqlrc(&self, boundary: Option<&str>, disable_pager: bool) -> Result<NamedTempFile> {
103 let prompts = if let Some(boundary) = boundary {
104 format!(
105 "\\set PROMPT1 '<<<{boundary}|||1|||%/|||%n|||%#|||%R|||%x>>>'\n\
106 \\set PROMPT2 '<<<{boundary}|||2|||%/|||%n|||%#|||%R|||%x>>>'\n\
107 \\set PROMPT3 '<<<{boundary}|||3|||%/|||%n|||%#|||%R|||%x>>>'\n"
108 )
109 } else {
110 String::new()
111 };
112
113 let pager_setting = if disable_pager {
114 "\\pset pager off\n"
115 } else {
116 ""
117 };
118
119 let mut rc = tempfile::Builder::new()
120 .prefix("bestool-psql-")
121 .suffix(".psqlrc")
122 .tempfile()
123 .into_diagnostic()?;
124
125 write!(
126 rc.as_file_mut(),
127 "\\encoding UTF8\n\
128 \\timing\n\
129 {pager_setting}\
130 {existing}\n\
131 {ro}\n\
132 {prompts}",
133 existing = self.psqlrc,
134 ro = if self.write {
135 ""
136 } else {
137 "SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY;"
138 },
139 )
140 .into_diagnostic()?;
141
142 Ok(rc)
143 }
144
145 fn pty_command(self, boundary: Option<&str>) -> Result<(CommandBuilder, NamedTempFile)> {
146 let mut cmd = CommandBuilder::new(crate::find_postgres_bin(&self.program)?);
147
148 if self.write {
149 cmd.arg("--set=AUTOCOMMIT=OFF");
150 }
151
152 let rc = self.psqlrc(boundary, false)?;
153 cmd.env("PSQLRC", rc.path());
154 for arg in &self.args {
157 cmd.arg(arg);
158 }
159
160 Ok((cmd, rc))
161 }
162
163 fn std_command(
164 self,
165 boundary: Option<&str>,
166 disable_pager: bool,
167 ) -> Result<(Command, NamedTempFile)> {
168 let mut cmd = Command::new(crate::find_postgres_bin(&self.program)?);
169
170 if self.write {
171 cmd.arg("--set=AUTOCOMMIT=OFF");
172 }
173
174 let rc = self.psqlrc(boundary, disable_pager)?;
175 cmd.env("PSQLRC", rc.path());
176 if disable_pager {
177 cmd.env("PAGER", "cat");
178 }
179
180 for arg in &self.args {
181 cmd.arg(arg);
182 }
183
184 Ok((cmd, rc))
185 }
186}
187
188#[cfg(unix)]
190struct RawMode {
191 term_fd: i32,
192 original_termios: libc::termios,
193 stdin_fd: i32,
194 original_flags: i32,
195}
196
197#[cfg(unix)]
198impl RawMode {
199 fn enable() -> Option<Self> {
200 use std::os::unix::io::AsRawFd;
201
202 let stdin_fd = std::io::stdin().as_raw_fd();
203
204 let tty_fd = unsafe { libc::open(c"/dev/tty".as_ptr(), libc::O_RDWR) };
206 let term_fd = if tty_fd >= 0 {
207 tty_fd
208 } else {
209 libc::STDOUT_FILENO
210 };
211
212 let mut original_termios: libc::termios = unsafe { std::mem::zeroed() };
214 if unsafe { libc::tcgetattr(term_fd, &mut original_termios) } != 0 {
215 if tty_fd >= 0 {
216 unsafe { libc::close(tty_fd) };
217 }
218 return None;
219 }
220
221 let original_flags = unsafe { libc::fcntl(stdin_fd, libc::F_GETFL) };
223 if original_flags < 0 {
224 if tty_fd >= 0 {
225 unsafe { libc::close(tty_fd) };
226 }
227 return None;
228 }
229
230 let mut raw_termios = original_termios;
232 unsafe {
233 libc::cfmakeraw(&mut raw_termios);
234 raw_termios.c_lflag &= !libc::ECHO;
236 raw_termios.c_lflag &= !libc::ECHONL;
237 libc::tcsetattr(term_fd, libc::TCSANOW, &raw_termios);
238
239 libc::fcntl(stdin_fd, libc::F_SETFL, original_flags | libc::O_NONBLOCK);
241 }
242
243 Some(RawMode {
244 term_fd,
245 original_termios,
246 stdin_fd,
247 original_flags,
248 })
249 }
250}
251
252#[cfg(unix)]
253impl Drop for RawMode {
254 fn drop(&mut self) {
255 unsafe {
257 libc::tcsetattr(self.term_fd, libc::TCSANOW, &self.original_termios);
258 libc::fcntl(self.stdin_fd, libc::F_SETFL, self.original_flags);
259 if self.term_fd != libc::STDOUT_FILENO {
260 libc::close(self.term_fd);
261 }
262 }
263 }
264}
265
266#[cfg(unix)]
268fn forward_stdin_to_pty(psql_writer: &PsqlWriter) {
269 use std::io::Read;
270
271 let stdin_handle = std::io::stdin();
272 let mut stdin_lock = stdin_handle.lock();
273
274 let mut buf = [0u8; 1024];
276 match stdin_lock.read(&mut buf) {
277 Ok(n) if n > 0 => {
278 if std::env::var("DEBUG_PTY").is_ok() {
279 use std::io::Write;
280 let data = String::from_utf8_lossy(&buf[..n]);
281 eprintln!("\x1b[33m[FWD]\x1b[0m forwarding {} bytes: {:?}", n, data);
282 std::io::stderr().flush().ok();
283 }
284 if let Err(e) = psql_writer.write_bytes(&buf[..n]) {
285 warn!("failed to forward stdin to pty: {}", e);
286 }
287 }
288 _ => {}
289 }
290}
291
/// Forwards at most one pending console key press to psql's PTY.
///
/// Called from the main loop while psql is not at a prompt. The console input queue is peeked
/// first, so the call returns immediately when no event is pending. At most one event is then
/// consumed. Only key-down events with a non-zero UTF-16 character are forwarded, re-encoded
/// as UTF-8; other events are discarded.
///
/// Characters outside the Basic Multilingual Plane arrive as two key events (a high surrogate,
/// then a low surrogate). The high half is remembered between calls so the pair can be combined
/// instead of being dropped.
///
/// Write failures are logged. Set `DEBUG_PTY` to trace forwarded characters on stderr.
#[cfg(windows)]
fn forward_stdin_to_pty(psql_writer: &PsqlWriter) {
    use std::sync::atomic::{AtomicU16, Ordering};
    use windows_sys::Win32::System::Console::{
        GetStdHandle, PeekConsoleInputW, ReadConsoleInputW, INPUT_RECORD, STD_INPUT_HANDLE,
    };

    // High surrogate of a non-BMP character whose low surrogate has not arrived yet (0 = none).
    static PENDING_HIGH_SURROGATE: AtomicU16 = AtomicU16::new(0);

    // SAFETY: the Win32 calls receive the handle from GetStdHandle plus pointers to live,
    // correctly sized locals. The union fields are only read for key events (EventType 1).
    let unit: u16 = unsafe {
        let stdin_handle = GetStdHandle(STD_INPUT_HANDLE);
        if stdin_handle.is_null() || stdin_handle as i32 == -1 {
            return;
        }

        let mut buffer: [INPUT_RECORD; 1] = std::mem::zeroed();

        // Peek first so we never block when the queue is empty.
        let mut num_events: u32 = 0;
        if PeekConsoleInputW(stdin_handle, buffer.as_mut_ptr(), 1, &mut num_events) == 0
            || num_events == 0
        {
            return;
        }

        let mut num_read: u32 = 0;
        if ReadConsoleInputW(stdin_handle, buffer.as_mut_ptr(), 1, &mut num_read) == 0
            || num_read == 0
        {
            return;
        }

        let record = &buffer[0];
        // 1 == KEY_EVENT; mouse, focus, resize and menu events are discarded.
        if record.EventType != 1 {
            return;
        }
        let key_event = record.Event.KeyEvent;
        if key_event.bKeyDown == 0 {
            return;
        }
        key_event.uChar.UnicodeChar
    };

    if unit == 0 {
        return;
    }

    let decoded = match unit {
        // High surrogate: wait for the matching low surrogate in a later event.
        0xD800..=0xDBFF => {
            PENDING_HIGH_SURROGATE.store(unit, Ordering::Relaxed);
            None
        }
        // Low surrogate: combine with the remembered high half, if any.
        0xDC00..=0xDFFF => match PENDING_HIGH_SURROGATE.swap(0, Ordering::Relaxed) {
            0 => None,
            high => char::decode_utf16([high, unit]).next().and_then(Result::ok),
        },
        _ => {
            PENDING_HIGH_SURROGATE.store(0, Ordering::Relaxed);
            char::from_u32(u32::from(unit))
        }
    };
    let Some(c) = decoded else {
        return;
    };

    let mut utf8_buf = [0u8; 4];
    let utf8_str = c.encode_utf8(&mut utf8_buf);

    if std::env::var("DEBUG_PTY").is_ok() {
        eprintln!("\x1b[33m[FWD]\x1b[0m forwarding char: {:?}", utf8_str);
        std::io::stderr().flush().ok();
    }

    if let Err(e) = psql_writer.write_bytes(utf8_str.as_bytes()) {
        warn!("failed to forward stdin to pty: {}", e);
    }
}
346
347pub fn run(config: PsqlConfig) -> Result<i32> {
348 if config.passthrough {
350 if config.write {
351 return Err(miette!(
352 "passthrough mode is only available in read-only mode"
353 ));
354 }
355 info!("launching psql in passthrough mode");
356 return run_passthrough(config);
357 }
358
359 let theme = config.theme;
361
362 let boundary = prompt::generate_boundary();
363 debug!(boundary = %boundary, "generated prompt boundary marker");
364
365 let pty_system = NativePtySystem::default();
366
367 let (cols, rows) = terminal::get_terminal_size();
368
369 let pty_pair = pty_system
370 .openpty(PtySize {
371 rows,
372 cols,
373 pixel_width: 0,
374 pixel_height: 0,
375 })
376 .map_err(|e| miette!("failed to create pty: {}", e))?;
377
378 let pty_master = Arc::new(Mutex::new(pty_pair.master));
379
380 terminal::spawn_resize_handler(pty_master.clone());
381
382 let history_path = config.history_path.clone();
383 let db_user = config.user.clone();
384 let boundary_clone = boundary.clone();
385
386 let write_mode = Arc::new(Mutex::new(config.write));
388 let ots = Arc::new(Mutex::new(config.ots.clone()));
389 let write_mode_clone = write_mode.clone();
390 let ots_clone = ots.clone();
391
392 let disable_schema_cache = config.disable_schema_cache;
393
394 let (cmd, _rc_guard) = config.pty_command(Some(&boundary))?;
395 let mut child = pty_pair
396 .slave
397 .spawn_command(cmd)
398 .map_err(|e| miette!("failed to spawn psql: {}", e))?;
399
400 drop(pty_pair.slave);
401
402 let reader = {
403 let master = pty_master.lock().unwrap();
404 master
405 .try_clone_reader()
406 .map_err(|e| miette!("failed to clone pty reader: {}", e))?
407 };
408
409 let writer = Arc::new(Mutex::new({
410 let master = pty_master.lock().unwrap();
411 master
412 .take_writer()
413 .map_err(|e| miette!("failed to get pty writer: {}", e))?
414 }));
415
416 let running = Arc::new(Mutex::new(true));
418 let running_clone = running.clone();
419
420 let output_buffer = Arc::new(Mutex::new(VecDeque::with_capacity(1024)));
422 let output_buffer_clone = output_buffer.clone();
423
424 let psql_writer = PsqlWriter::new(writer.clone(), output_buffer.clone());
425
426 let current_prompt = Arc::new(Mutex::new(String::new()));
427 let current_prompt_clone = current_prompt.clone();
428
429 let current_prompt_info = Arc::new(Mutex::new(None));
431 let current_prompt_info_clone = current_prompt_info.clone();
432
433 let last_input = Arc::new(Mutex::new(String::new()));
435
436 let print_enabled = Arc::new(Mutex::new(true));
438 let print_enabled_clone = print_enabled.clone();
439
440 let reader_thread = reader::spawn_reader_thread(reader::ReaderThreadParams {
441 reader,
442 boundary: boundary_clone,
443 output_buffer: output_buffer_clone,
444 current_prompt: current_prompt_clone,
445 current_prompt_info: current_prompt_info_clone,
446 last_input: last_input.clone(),
447 running: running_clone,
448 print_enabled: print_enabled_clone,
449 });
450
451 let history = history::History::setup(
452 history_path.clone(),
453 db_user,
454 *write_mode.lock().unwrap(),
455 ots.lock().unwrap().clone(),
456 );
457
458 let schema_cache_manager = if !disable_schema_cache {
459 debug!("initializing schema cache");
460 let manager =
461 SchemaCacheManager::new(writer.clone(), print_enabled.clone(), write_mode.clone());
462
463 if let Err(e) = manager.refresh() {
464 warn!("failed to populate schema cache: {}", e);
465 }
466
467 Some(manager)
468 } else {
469 debug!("schema cache disabled by config");
470 None
471 };
472
473 let mut completer =
474 SqlCompleter::with_pty_and_theme(writer.clone(), output_buffer.clone(), theme);
475 if let Some(ref cache_manager) = schema_cache_manager {
476 completer = completer.with_schema_cache(cache_manager.cache_arc());
477 }
478
479 let mut rl: Editor<SqlCompleter, history::History> = Editor::with_history(
480 Config::builder()
481 .auto_add_history(false)
482 .history_ignore_dups(false)
483 .unwrap()
484 .build(),
485 history,
486 )
487 .into_diagnostic()?;
488
489 rl.set_helper(Some(completer));
490
491 let mut last_reload = std::time::Instant::now();
492
493 debug!("entering main event loop");
494
495 #[cfg(unix)]
496 let mut raw_mode: Option<RawMode> = None;
497
498 loop {
499 if last_reload.elapsed() >= Duration::from_secs(60) {
500 debug!("reloading history timestamps");
501 if let Err(e) = rl.history_mut().reload_timestamps() {
502 warn!("failed to reload history timestamps: {}", e);
503 }
504 last_reload = std::time::Instant::now();
505 }
506 match child.try_wait().into_diagnostic()? {
507 Some(status) => {
508 debug!(exit_code = status.exit_code(), "psql process exited");
510 reader_thread.join().ok();
511 return Ok(status.exit_code() as i32);
512 }
513 None => {
514 }
516 }
517
518 if !*running.lock().unwrap() {
520 thread::sleep(Duration::from_millis(50));
522 if let Some(status) = child.try_wait().into_diagnostic()? {
523 return Ok(status.exit_code() as i32);
524 }
525 }
526
527 thread::sleep(Duration::from_millis(50));
529
530 let at_prompt = psql_writer.buffer_contains(&format!("<<<{boundary}|||"));
531 if !at_prompt {
532 #[cfg(unix)]
535 if raw_mode.is_none() {
536 raw_mode = RawMode::enable();
537 }
538
539 forward_stdin_to_pty(&psql_writer);
541 thread::sleep(Duration::from_millis(50));
542 continue;
543 }
544
545 #[cfg(unix)]
547 if raw_mode.is_some() {
548 raw_mode = None; }
550
551 let prompt_text = current_prompt.lock().unwrap().clone();
553 let readline_prompt = if prompt_text.is_empty() {
554 "psql> ".to_string()
555 } else {
556 prompt_text
557 };
558
559 match rl.readline(&readline_prompt) {
560 Ok(line) => {
561 trace!("received input line");
562 let trimmed = line.trim();
563 if trimmed == "\\e" || trimmed.starts_with("\\e ") {
564 debug!("editor command intercepted");
565
566 let initial_content = if trimmed == "\\e" {
568 let hist_len = rl.history().len();
570 if hist_len > 0 {
571 match rl.history().get(hist_len - 1, SearchDirection::Forward) {
572 Ok(Some(result)) => result.entry.to_string(),
573 _ => String::new(),
574 }
575 } else {
576 String::new()
577 }
578 } else {
579 trimmed
581 .strip_prefix("\\e ")
582 .unwrap_or("")
583 .trim()
584 .to_string()
585 };
586
587 match edit::edit(&initial_content) {
589 Ok(edited_content) => {
590 let edited_trimmed = edited_content.trim();
591
592 if !edited_trimmed.is_empty() {
594 info!("sending edited content to psql");
595
596 if let Err(e) = rl.history_mut().add(&edited_content) {
598 warn!("failed to add history entry: {}", e);
599 } else {
600 debug!("wrote history entry before sending to psql");
601 }
602
603 *last_input.lock().unwrap() = format!("{}\n", edited_content);
605
606 if let Err(e) = psql_writer.write_line(&edited_content) {
608 warn!("failed to write to psql: {}", e);
609 return Err(PsqlError::WriteError).into_diagnostic();
610 }
611 } else {
612 debug!("editor returned empty content, skipping");
613 }
614 }
615 Err(e) => {
616 warn!("editor failed: {}", e);
617 eprintln!("Editor failed: {}", e);
618 }
619 }
620 continue;
621 }
622
623 if trimmed == "\\refresh" {
624 let prompt_info = current_prompt_info.lock().unwrap().clone();
625 if let Some(ref info) = prompt_info {
626 if info.in_transaction() {
627 eprintln!("Cannot refresh schema cache while in a transaction. Please COMMIT or ROLLBACK first.");
628 continue;
629 }
630 }
631
632 if let Some(ref cache_manager) = schema_cache_manager {
633 info!("refreshing schema cache...");
634 eprintln!("Refreshing schema cache...");
635 match cache_manager.refresh() {
636 Ok(()) => {
637 eprintln!("Schema cache refreshed successfully");
638 }
639 Err(e) => {
640 warn!("failed to refresh schema cache: {}", e);
641 eprintln!("Failed to refresh schema cache: {}", e);
642 }
643 }
644 } else {
645 eprintln!("Schema cache is not enabled");
646 }
647 continue;
648 }
649
650 if trimmed == "\\W" {
651 let prompt_info = current_prompt_info.lock().unwrap().clone();
652 if let Some(ref info) = prompt_info {
653 if info.in_transaction() && info.transaction == "*" {
654 warn!("Pending transaction! Please COMMIT or ROLLBACK first");
655 continue;
656 }
657 }
658
659 let mut current_write_mode = write_mode_clone.lock().unwrap();
660 let mut current_ots = ots_clone.lock().unwrap();
661
662 if *current_write_mode {
663 *current_write_mode = false;
664 *current_ots = None;
665
666 let cmd = "SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY;\n\\set AUTOCOMMIT on\nROLLBACK;\n";
667 if let Err(e) = psql_writer.write_str(cmd) {
668 warn!("failed to write to psql: {}", e);
669 continue;
670 }
671
672 thread::sleep(Duration::from_millis(50));
673 info!("Write mode disabled");
674 thread::sleep(Duration::from_millis(5));
675 eprintln!("SESSION IS NOW READ ONLY");
676
677 let db_user = rl.history().db_user.clone();
678 let sys_user = rl.history().sys_user.clone();
679 rl.history_mut().set_context(db_user, sys_user, false, None);
680 } else {
681 drop(current_write_mode);
682 drop(current_ots);
683
684 let db_handle = rl.history().clone_db();
685 match ots::prompt_for_ots_with_db(Some(db_handle), Some(&history_path)) {
686 Ok(new_ots) => {
687 let mut current_write_mode = write_mode_clone.lock().unwrap();
688 let mut current_ots = ots_clone.lock().unwrap();
689
690 *current_write_mode = true;
691 *current_ots = Some(new_ots.clone());
692
693 let cmd = "SET SESSION CHARACTERISTICS AS TRANSACTION READ WRITE;\n\\set AUTOCOMMIT off\nROLLBACK;\n";
694 if let Err(e) = psql_writer.write_str(cmd) {
695 warn!("failed to write to psql: {}", e);
696 continue;
697 }
698
699 thread::sleep(Duration::from_millis(50));
700 info!("Write mode enabled");
701 thread::sleep(Duration::from_millis(5));
702 eprintln!("AUTOCOMMIT IS OFF -- REMEMBER TO `COMMIT;` YOUR WRITES");
703
704 let db_user = rl.history().db_user.clone();
705 let sys_user = rl.history().sys_user.clone();
706 rl.history_mut().set_context(
707 db_user,
708 sys_user,
709 true,
710 Some(new_ots),
711 );
712 }
713 Err(e) => {
714 eprintln!("Failed to enable write mode: {}", e);
715 }
716 }
717 }
718 continue;
719 }
720
721 if !line.trim().is_empty() {
722 if let Err(e) = rl.history_mut().add(&line) {
723 warn!("failed to add history entry: {}", e);
724 } else {
725 debug!("wrote history entry before sending to psql");
726 }
727 }
728
729 *last_input.lock().unwrap() = format!("{}\n", line);
731
732 if let Err(e) = psql_writer.write_line(&line) {
733 warn!("failed to write to psql: {}", e);
734 return Err(PsqlError::WriteError).into_diagnostic();
735 }
736 }
737 Err(rustyline::error::ReadlineError::Interrupted) => {
738 debug!("received Ctrl-C");
739 psql_writer.send_control(3).ok(); }
741 Err(rustyline::error::ReadlineError::Eof) => {
742 debug!("received Ctrl-D (EOF)");
743 psql_writer.send_control(4).ok(); break;
745 }
746 Err(err) => {
747 return Err(err).into_diagnostic();
748 }
749 }
750 }
751
752 reader_thread.join().ok();
753
754 let status = child.wait().into_diagnostic()?;
755
756 debug!("compacting history database");
757 if let Err(e) = rl.history_mut().compact() {
758 warn!("failed to compact history database: {}", e);
759 }
760
761 debug!(exit_code = status.exit_code(), "exiting");
762 Ok(status.exit_code() as i32)
763}
764
765fn run_passthrough(mut config: PsqlConfig) -> Result<i32> {
769 config.write = false;
771
772 let (mut cmd, _guard) = config.std_command(None, false)?;
773 let status = cmd.status().into_diagnostic()?;
774
775 Ok(status.code().unwrap_or(1))
776}