1use std::collections::HashMap;
22use std::fmt;
23use std::io::{self, Read, Write};
24use std::path::PathBuf;
25use std::sync::mpsc;
26use std::thread;
27use std::time::{Duration, Instant};
28
29use portable_pty::{CommandBuilder, ExitStatus, PtySize};
30
/// Configuration for launching a shell inside a pseudo-terminal.
///
/// Start from [`ShellConfig::default`] or [`ShellConfig::with_shell`] and refine it
/// with the chainable setters below.
#[derive(Debug, Clone)]
pub struct ShellConfig {
    /// Shell executable to run. `None` means "use `$SHELL`, else a platform default".
    pub shell: Option<PathBuf>,

    /// Extra command-line arguments passed to the shell, in order.
    pub args: Vec<String>,

    /// Environment variables to set for the child process.
    pub env: HashMap<String, String>,

    /// Working directory for the child; `None` leaves it unspecified.
    pub cwd: Option<PathBuf>,

    /// Initial terminal width in columns (default 80).
    pub cols: u16,

    /// Initial terminal height in rows (default 24).
    pub rows: u16,

    /// Value exported as `TERM` (default `xterm-256color`).
    pub term: String,

    /// Whether lifecycle and I/O events are logged (default off).
    pub log_events: bool,
}

impl Default for ShellConfig {
    fn default() -> Self {
        Self {
            shell: None,
            args: Vec::new(),
            env: HashMap::new(),
            cwd: None,
            cols: 80,
            rows: 24,
            term: "xterm-256color".to_string(),
            log_events: false,
        }
    }
}

impl ShellConfig {
    /// Creates a default configuration that runs `shell` instead of auto-detecting one.
    #[must_use]
    pub fn with_shell(shell: impl Into<PathBuf>) -> Self {
        Self {
            shell: Some(shell.into()),
            ..Default::default()
        }
    }

    /// Sets the initial terminal size.
    #[must_use]
    pub fn size(mut self, cols: u16, rows: u16) -> Self {
        self.cols = cols;
        self.rows = rows;
        self
    }

    /// Appends one command-line argument for the shell.
    #[must_use]
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        self.args.push(arg.into());
        self
    }

    /// Sets (or replaces) one environment variable for the child.
    #[must_use]
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }

    /// Copies the current process environment into `env`.
    ///
    /// Values already set explicitly win over inherited ones. Variables whose name or
    /// value is not valid Unicode are skipped, because `env` stores `String`s
    /// (`std::env::vars()` would panic on them instead).
    #[must_use]
    pub fn inherit_env(mut self) -> Self {
        for (key, value) in std::env::vars_os() {
            if let (Ok(key), Ok(value)) = (key.into_string(), value.into_string()) {
                self.env.entry(key).or_insert(value);
            }
        }
        self
    }

    /// Sets the child's working directory.
    #[must_use]
    pub fn cwd(mut self, path: impl Into<PathBuf>) -> Self {
        self.cwd = Some(path.into());
        self
    }

    /// Sets the value exported as `TERM`.
    #[must_use]
    pub fn term(mut self, term: impl Into<String>) -> Self {
        self.term = term.into();
        self
    }

    /// Enables or disables event logging.
    #[must_use]
    pub fn logging(mut self, enabled: bool) -> Self {
        self.log_events = enabled;
        self
    }

    /// Picks the shell to launch: the explicit `shell`, then a non-empty `$SHELL`,
    /// then the platform default.
    fn resolve_shell(&self) -> PathBuf {
        if let Some(shell) = &self.shell {
            return shell.clone();
        }

        // An empty `$SHELL` would yield an empty program path that can never spawn.
        // `var_os` also accepts shell paths that are not valid UTF-8.
        match std::env::var_os("SHELL") {
            Some(shell) if !shell.is_empty() => PathBuf::from(shell),
            _ => Self::default_shell(),
        }
    }

    /// Platform fallback shell on Windows: `%COMSPEC%`, else `cmd.exe`.
    #[cfg(windows)]
    fn default_shell() -> PathBuf {
        std::env::var_os("COMSPEC")
            .filter(|comspec| !comspec.is_empty())
            .map_or_else(|| PathBuf::from("cmd.exe"), PathBuf::from)
    }

    /// Platform fallback shell on non-Windows systems.
    #[cfg(not(windows))]
    fn default_shell() -> PathBuf {
        PathBuf::from("/bin/sh")
    }
}
152
/// One message from the background PTY reader thread to its owning `PtyProcess`.
#[derive(Debug)]
enum ReaderMsg {
    /// A chunk of bytes read from the PTY master.
    Data(Vec<u8>),
    /// The reader reached end of stream; no further messages follow.
    Eof,
    /// The reader failed with this error and stopped.
    Err(io::Error),
}
160
/// Last known lifecycle state of a spawned child process.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProcessState {
    /// The child has not been observed to exit.
    Running,
    /// The child exited with the contained exit code.
    Exited(i32),
    /// Termination by a signal (the payload is presumably the signal number).
    Signaled(i32),
    /// The outcome could not be determined.
    Unknown,
}

impl ProcessState {
    /// Returns `true` only for [`ProcessState::Running`].
    #[must_use]
    pub const fn is_alive(self) -> bool {
        match self {
            Self::Running => true,
            Self::Exited(_) | Self::Signaled(_) | Self::Unknown => false,
        }
    }

    /// Returns the exit code for [`ProcessState::Exited`], `None` for every other state.
    #[must_use]
    pub const fn exit_code(self) -> Option<i32> {
        if let Self::Exited(code) = self {
            Some(code)
        } else {
            None
        }
    }
}
190
191pub struct PtyProcess {
218 child: Box<dyn portable_pty::Child + Send + Sync>,
219 writer: Box<dyn Write + Send>,
220 rx: mpsc::Receiver<ReaderMsg>,
221 reader_thread: Option<thread::JoinHandle<()>>,
222 captured: Vec<u8>,
223 eof: bool,
224 state: ProcessState,
225 config: ShellConfig,
226}
227
228impl fmt::Debug for PtyProcess {
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 f.debug_struct("PtyProcess")
231 .field("pid", &self.child.process_id())
232 .field("state", &self.state)
233 .field("captured_len", &self.captured.len())
234 .field("eof", &self.eof)
235 .finish()
236 }
237}
238
239impl PtyProcess {
240 pub fn spawn(config: ShellConfig) -> io::Result<Self> {
249 let shell_path = config.resolve_shell();
250
251 if config.log_events {
252 log_event(
253 "PTY_PROCESS_SPAWN",
254 format!("shell={}", shell_path.display()),
255 );
256 }
257
258 let mut cmd = CommandBuilder::new(&shell_path);
260
261 for arg in &config.args {
263 cmd.arg(arg);
264 }
265
266 cmd.env("TERM", &config.term);
268 for (key, value) in &config.env {
269 cmd.env(key, value);
270 }
271
272 if let Some(ref cwd) = config.cwd {
274 cmd.cwd(cwd);
275 }
276
277 let pty_system = portable_pty::native_pty_system();
279 let pair = pty_system
280 .openpty(PtySize {
281 rows: config.rows,
282 cols: config.cols,
283 pixel_width: 0,
284 pixel_height: 0,
285 })
286 .map_err(|e| io::Error::other(e.to_string()))?;
287
288 let child = pair
290 .slave
291 .spawn_command(cmd)
292 .map_err(|e| io::Error::other(e.to_string()))?;
293
294 let mut reader = pair
296 .master
297 .try_clone_reader()
298 .map_err(|e| io::Error::other(e.to_string()))?;
299 let writer = pair
300 .master
301 .take_writer()
302 .map_err(|e| io::Error::other(e.to_string()))?;
303
304 let (tx, rx) = mpsc::channel::<ReaderMsg>();
306 let reader_thread = thread::spawn(move || {
307 let mut buf = [0u8; 8192];
308 loop {
309 match reader.read(&mut buf) {
310 Ok(0) => {
311 let _ = tx.send(ReaderMsg::Eof);
312 break;
313 }
314 Ok(n) => {
315 let _ = tx.send(ReaderMsg::Data(buf[..n].to_vec()));
316 }
317 Err(err) => {
318 let _ = tx.send(ReaderMsg::Err(err));
319 break;
320 }
321 }
322 }
323 });
324
325 if config.log_events {
326 log_event(
327 "PTY_PROCESS_STARTED",
328 format!("pid={:?}", child.process_id()),
329 );
330 }
331
332 Ok(Self {
333 child,
334 writer,
335 rx,
336 reader_thread: Some(reader_thread),
337 captured: Vec::new(),
338 eof: false,
339 state: ProcessState::Running,
340 config,
341 })
342 }
343
344 #[must_use]
348 pub fn is_alive(&mut self) -> bool {
349 self.poll_state();
350 self.state.is_alive()
351 }
352
353 #[must_use]
355 pub fn state(&mut self) -> ProcessState {
356 self.poll_state();
357 self.state
358 }
359
360 #[must_use]
362 pub fn pid(&self) -> Option<u32> {
363 self.child.process_id()
364 }
365
366 pub fn kill(&mut self) -> io::Result<()> {
374 if !self.state.is_alive() {
375 return Ok(());
376 }
377
378 if self.config.log_events {
379 log_event(
380 "PTY_PROCESS_KILL",
381 format!("pid={:?}", self.child.process_id()),
382 );
383 }
384
385 self.child.kill()?;
387 self.state = ProcessState::Unknown;
388
389 match self.wait_timeout(Duration::from_millis(100)) {
391 Ok(status) => {
392 self.update_state_from_exit(&status);
393 }
394 Err(_) => {
395 self.state = ProcessState::Unknown;
397 }
398 }
399
400 Ok(())
401 }
402
403 pub fn wait(&mut self) -> io::Result<ExitStatus> {
411 let status = self.child.wait()?;
412 self.update_state_from_exit(&status);
413 Ok(status)
414 }
415
416 pub fn wait_timeout(&mut self, timeout: Duration) -> io::Result<ExitStatus> {
422 let deadline = Instant::now() + timeout;
423
424 loop {
425 match self.child.try_wait()? {
427 Some(status) => {
428 self.update_state_from_exit(&status);
429 return Ok(status);
430 }
431 None => {
432 if Instant::now() >= deadline {
433 return Err(io::Error::new(
434 io::ErrorKind::TimedOut,
435 "wait_timeout: process did not exit in time",
436 ));
437 }
438 thread::sleep(Duration::from_millis(10));
439 }
440 }
441 }
442 }
443
444 pub fn write_all(&mut self, data: &[u8]) -> io::Result<()> {
450 self.writer.write_all(data)?;
451 self.writer.flush()?;
452
453 if self.config.log_events {
454 log_event("PTY_PROCESS_INPUT", format!("bytes={}", data.len()));
455 }
456
457 Ok(())
458 }
459
460 pub fn read_available(&mut self) -> io::Result<Vec<u8>> {
462 self.drain_channel(Duration::ZERO)?;
463 Ok(self.captured.clone())
464 }
465
466 pub fn read_until(&mut self, pattern: &[u8], timeout: Duration) -> io::Result<Vec<u8>> {
472 if pattern.is_empty() {
473 return Ok(self.captured.clone());
474 }
475
476 let deadline = Instant::now() + timeout;
477
478 loop {
479 if find_subsequence(&self.captured, pattern).is_some() {
481 return Ok(self.captured.clone());
482 }
483
484 if self.eof || Instant::now() >= deadline {
485 break;
486 }
487
488 let remaining = deadline.saturating_duration_since(Instant::now());
489 self.drain_channel(remaining)?;
490 }
491
492 Err(io::Error::new(
493 io::ErrorKind::TimedOut,
494 format!(
495 "read_until: pattern not found (captured {} bytes)",
496 self.captured.len()
497 ),
498 ))
499 }
500
501 pub fn drain(&mut self, timeout: Duration) -> io::Result<usize> {
503 if self.eof {
504 return Ok(0);
505 }
506
507 let start_len = self.captured.len();
508 let deadline = Instant::now() + timeout;
509
510 while !self.eof && Instant::now() < deadline {
511 let remaining = deadline.saturating_duration_since(Instant::now());
512 match self.drain_channel(remaining) {
513 Ok(0) if self.eof => break,
514 Ok(_) => continue,
515 Err(e) if e.kind() == io::ErrorKind::TimedOut => break,
516 Err(e) => return Err(e),
517 }
518 }
519
520 Ok(self.captured.len() - start_len)
521 }
522
523 #[must_use]
525 pub fn output(&self) -> &[u8] {
526 &self.captured
527 }
528
529 pub fn clear_output(&mut self) {
531 self.captured.clear();
532 }
533
534 pub fn resize(&mut self, cols: u16, rows: u16) -> io::Result<()> {
538 if self.config.log_events {
542 log_event("PTY_PROCESS_RESIZE", format!("cols={} rows={}", cols, rows));
543 }
544 Ok(())
545 }
546
547 fn poll_state(&mut self) {
550 if !self.state.is_alive() {
551 return;
552 }
553
554 match self.child.try_wait() {
555 Ok(Some(status)) => {
556 self.update_state_from_exit(&status);
557 }
558 Ok(None) => {
559 }
561 Err(_) => {
562 self.state = ProcessState::Unknown;
563 }
564 }
565 }
566
567 fn update_state_from_exit(&mut self, status: &ExitStatus) {
568 if status.success() {
569 self.state = ProcessState::Exited(0);
570 } else {
571 let code = 1; self.state = ProcessState::Exited(code);
575 }
576 }
577
578 fn drain_channel(&mut self, timeout: Duration) -> io::Result<usize> {
579 if self.eof {
580 return Ok(0);
581 }
582
583 let mut total = 0usize;
584
585 let first = if timeout.is_zero() {
587 match self.rx.try_recv() {
588 Ok(msg) => Some(msg),
589 Err(mpsc::TryRecvError::Empty) => return Ok(0),
590 Err(mpsc::TryRecvError::Disconnected) => {
591 self.eof = true;
592 return Ok(0);
593 }
594 }
595 } else {
596 match self.rx.recv_timeout(timeout) {
597 Ok(msg) => Some(msg),
598 Err(mpsc::RecvTimeoutError::Timeout) => return Ok(0),
599 Err(mpsc::RecvTimeoutError::Disconnected) => {
600 self.eof = true;
601 return Ok(0);
602 }
603 }
604 };
605
606 let mut msg = match first {
607 Some(m) => m,
608 None => return Ok(0),
609 };
610
611 loop {
612 match msg {
613 ReaderMsg::Data(bytes) => {
614 total = total.saturating_add(bytes.len());
615 self.captured.extend_from_slice(&bytes);
616 }
617 ReaderMsg::Eof => {
618 self.eof = true;
619 break;
620 }
621 ReaderMsg::Err(err) => return Err(err),
622 }
623
624 match self.rx.try_recv() {
625 Ok(next) => msg = next,
626 Err(mpsc::TryRecvError::Empty) => break,
627 Err(mpsc::TryRecvError::Disconnected) => {
628 self.eof = true;
629 break;
630 }
631 }
632 }
633
634 if total > 0 && self.config.log_events {
635 log_event("PTY_PROCESS_OUTPUT", format!("bytes={}", total));
636 }
637
638 Ok(total)
639 }
640}
641
642impl Drop for PtyProcess {
643 fn drop(&mut self) {
644 let _ = self.writer.flush();
646 let _ = self.child.kill();
647
648 if let Some(handle) = self.reader_thread.take() {
649 let _ = handle.join();
650 }
651
652 if self.config.log_events {
653 log_event(
654 "PTY_PROCESS_DROP",
655 format!("pid={:?}", self.child.process_id()),
656 );
657 }
658 }
659}
660
/// Index of the first occurrence of `needle` inside `haystack`, if any.
///
/// An empty `needle` matches at index 0. A `needle` longer than `haystack` never matches.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    match needle.len() {
        0 => Some(0),
        width => haystack
            .windows(width)
            .position(|candidate| candidate == needle),
    }
}
671
672fn log_event(event: &str, detail: impl fmt::Display) {
673 let timestamp = time::OffsetDateTime::now_utc()
674 .format(&time::format_description::well_known::Rfc3339)
675 .unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string());
676 eprintln!("[{}] {}: {}", timestamp, event, detail);
677}
678
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shell_config_defaults() {
        let config = ShellConfig::default();
        assert!(config.shell.is_none());
        assert!(config.args.is_empty());
        assert!(config.env.is_empty());
        assert!(config.cwd.is_none());
        assert_eq!(config.cols, 80);
        assert_eq!(config.rows, 24);
        assert_eq!(config.term, "xterm-256color");
        assert!(!config.log_events);
    }

    #[test]
    fn shell_config_with_shell() {
        let config = ShellConfig::with_shell("/bin/bash");
        assert_eq!(config.shell, Some(PathBuf::from("/bin/bash")));
    }

    #[test]
    fn shell_config_builder_chain() {
        let config = ShellConfig::default()
            .size(120, 40)
            .arg("-l")
            .env("FOO", "bar")
            .cwd("/tmp")
            .term("dumb")
            .logging(true);

        assert_eq!(config.cols, 120);
        assert_eq!(config.rows, 40);
        assert_eq!(config.args, vec!["-l"]);
        assert_eq!(config.env.get("FOO"), Some(&"bar".to_string()));
        assert_eq!(config.cwd, Some(PathBuf::from("/tmp")));
        assert_eq!(config.term, "dumb");
        assert!(config.log_events);
    }

    #[test]
    fn shell_config_resolve_shell_explicit() {
        let config = ShellConfig::with_shell("/bin/zsh");
        assert_eq!(config.resolve_shell(), PathBuf::from("/bin/zsh"));
    }

    #[test]
    fn shell_config_resolve_shell_env() {
        // Compare against the real environment rather than guessing the shell's
        // name: a user's $SHELL may be fish, nu, etc.
        let resolved = ShellConfig::default().resolve_shell();
        match std::env::var_os("SHELL") {
            Some(shell) => {
                if let Some(shell) = shell.to_str().filter(|s| !s.is_empty()) {
                    assert_eq!(resolved, PathBuf::from(shell));
                }
            }
            None => {
                #[cfg(unix)]
                assert_eq!(resolved, PathBuf::from("/bin/sh"));
            }
        }
    }

    #[test]
    fn process_state_is_alive() {
        assert!(ProcessState::Running.is_alive());
        assert!(!ProcessState::Exited(0).is_alive());
        assert!(!ProcessState::Signaled(9).is_alive());
        assert!(!ProcessState::Unknown.is_alive());
    }

    #[test]
    fn process_state_exit_code() {
        assert_eq!(ProcessState::Running.exit_code(), None);
        assert_eq!(ProcessState::Exited(0).exit_code(), Some(0));
        assert_eq!(ProcessState::Exited(1).exit_code(), Some(1));
        assert_eq!(ProcessState::Signaled(9).exit_code(), None);
        assert_eq!(ProcessState::Unknown.exit_code(), None);
    }

    #[test]
    fn find_subsequence_empty_needle() {
        assert_eq!(find_subsequence(b"anything", b""), Some(0));
    }

    #[test]
    fn find_subsequence_found() {
        assert_eq!(find_subsequence(b"hello world", b"world"), Some(6));
    }

    #[test]
    fn find_subsequence_not_found() {
        assert_eq!(find_subsequence(b"hello world", b"xyz"), None);
    }

    /// Config for the I/O tests. It pins `/bin/sh` so the POSIX syntax used below
    /// does not depend on the user's `$SHELL`.
    #[cfg(unix)]
    fn sh_config() -> ShellConfig {
        ShellConfig::with_shell("/bin/sh").logging(false)
    }

    #[cfg(unix)]
    #[test]
    fn spawn_and_basic_io() {
        let mut proc = PtyProcess::spawn(sh_config()).expect("spawn should succeed");

        assert!(proc.is_alive());
        assert!(proc.pid().is_some());

        // The terminal echoes the typed command back, so the marker must appear only
        // in the command's *output*. printf assembles it from two pieces.
        proc.write_all(b"printf 'hello-%s\\n' pty-process\n")
            .expect("write should succeed");

        let output = proc
            .read_until(b"hello-pty-process", Duration::from_secs(5))
            .expect("should find output");

        assert!(
            find_subsequence(&output, b"hello-pty-process").is_some(),
            "expected to find 'hello-pty-process' in output"
        );

        proc.kill().expect("kill should succeed");
        assert!(!proc.is_alive());
    }

    #[cfg(unix)]
    #[test]
    fn spawn_with_env() {
        let config = sh_config().env("TEST_VAR", "test_value_123");
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");

        // The echoed command line contains `$TEST_VAR`, never its value.
        proc.write_all(b"echo $TEST_VAR\n")
            .expect("write should succeed");

        let output = proc
            .read_until(b"test_value_123", Duration::from_secs(5))
            .expect("should find env var in output");

        assert!(
            find_subsequence(&output, b"test_value_123").is_some(),
            "expected to find env var value in output"
        );

        proc.kill().expect("kill should succeed");
    }

    #[cfg(unix)]
    #[test]
    fn exit_command_terminates() {
        let mut proc = PtyProcess::spawn(sh_config()).expect("spawn should succeed");

        proc.write_all(b"exit 0\n").expect("write should succeed");

        let status = proc
            .wait_timeout(Duration::from_secs(5))
            .expect("wait should succeed");
        assert!(status.success());
        assert!(!proc.is_alive());
    }

    #[cfg(unix)]
    #[test]
    fn kill_is_idempotent() {
        // Deliberately uses the default (auto-detected) shell.
        let config = ShellConfig::default().logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");

        proc.kill().expect("first kill should succeed");
        proc.kill().expect("second kill should succeed");
        proc.kill().expect("third kill should succeed");

        assert!(!proc.is_alive());
    }

    #[cfg(unix)]
    #[test]
    fn drain_captures_all_output() {
        let mut proc = PtyProcess::spawn(sh_config()).expect("spawn should succeed");

        proc.write_all(b"for i in 1 2 3 4 5; do echo line$i; done; exit 0\n")
            .expect("write should succeed");

        let _ = proc.wait_timeout(Duration::from_secs(5));
        let _ = proc.drain(Duration::from_secs(2));

        let output = String::from_utf8_lossy(proc.output());
        for i in 1..=5 {
            assert!(
                output.contains(&format!("line{i}")),
                "missing line{i} in output: {output:?}"
            );
        }
    }

    #[cfg(unix)]
    #[test]
    fn clear_output_works() {
        let mut proc = PtyProcess::spawn(sh_config()).expect("spawn should succeed");

        proc.write_all(b"echo test\n")
            .expect("write should succeed");
        // Wait for output instead of sleeping for a guessed duration.
        proc.read_until(b"test", Duration::from_secs(5))
            .expect("should receive output");
        let available = proc.read_available().expect("read should succeed");

        assert!(!available.is_empty());
        assert!(!proc.output().is_empty());

        proc.clear_output();
        assert!(proc.output().is_empty());

        proc.kill().expect("kill should succeed");
    }

    #[cfg(unix)]
    #[test]
    fn specific_shell_path() {
        let config = ShellConfig::with_shell("/bin/sh").logging(false);
        let mut proc = PtyProcess::spawn(config).expect("spawn should succeed");

        assert!(proc.is_alive());
        proc.kill().expect("kill should succeed");
    }

    #[cfg(unix)]
    #[test]
    fn invalid_shell_fails() {
        let config = ShellConfig::with_shell("/nonexistent/shell").logging(false);
        let result = PtyProcess::spawn(config);

        assert!(result.is_err());
    }
}