Skip to main content

atuin_pty_proxy/
lib.rs

1pub mod osc133;
2
3use clap::{Args, Subcommand, ValueEnum};
4
// Subcommands accepted by `atuin pty-proxy`. When no subcommand is given,
// `run` starts the proxy itself instead (its `None` arm calls `app::main`).
//
// NOTE: the `///` comment on the variant is not only documentation — clap's
// derive turns it into `--help` text, so editing it changes user-visible output.
#[derive(Subcommand, Debug)]
pub enum Cmd {
    /// Print shell code to initialize atuin pty-proxy on shell startup
    Init(Init),
}
10
// Arguments for `atuin pty-proxy init`.
//
// NOTE: clap's derive uses the `///` field comment below as the help text of
// the positional `shell` argument, so treat it as user-visible output.
#[derive(Args, Debug)]
pub struct Init {
    /// Shell to generate init for. If omitted, attempt auto-detection
    #[arg(value_enum)]
    // `None` means "not given on the command line"; `detect_shell` then falls
    // back to the environment.
    shell: Option<Shell>,
}
17
// Shells that `init` can emit a pty-proxy preamble for.
//
// `rename_all = "lower"` makes clap accept the lowercase names `zsh`, `bash`,
// `fish` and `nu`. The `///` variant comments are used by clap as
// possible-value help, so they are user-visible strings.
// NOTE(review): the reason for the clippy allows below is not visible from
// this file — confirm they are still needed.
#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)]
#[value(rename_all = "lower")]
#[allow(clippy::enum_variant_names, clippy::doc_markdown)]
enum Shell {
    /// Zsh setup
    Zsh,
    /// Bash setup
    Bash,
    /// Fish setup
    Fish,
    /// Nu setup
    Nu,
}
31
32impl Init {
33    fn run(self) -> Result<(), String> {
34        let shell = detect_shell(self.shell)?;
35        let script = render_init(shell);
36        print!("{script}");
37        Ok(())
38    }
39}
40
41pub fn run(cmd: Option<Cmd>) {
42    match cmd {
43        Some(Cmd::Init(init)) => {
44            if let Err(err) = init.run() {
45                eprintln!("atuin pty-proxy: {err}");
46                std::process::exit(1);
47            }
48        }
49        None => app::main(),
50    }
51}
52
53fn detect_shell(cli_shell: Option<Shell>) -> Result<Shell, String> {
54    if let Some(shell) = cli_shell {
55        return Ok(shell);
56    }
57
58    if let Ok(shell) = std::env::var("ATUIN_SHELL")
59        && let Some(shell) = shell_from_name(&shell)
60    {
61        return Ok(shell);
62    }
63
64    if let Ok(shell) = std::env::var("SHELL")
65        && let Some(shell) = shell_from_name(&shell)
66    {
67        return Ok(shell);
68    }
69
70    Err(
71        "could not detect a supported shell. Please specify one explicitly: bash, zsh, fish, or nu"
72            .to_string(),
73    )
74}
75
76fn shell_from_name(name: &str) -> Option<Shell> {
77    let shell = name
78        .trim()
79        .rsplit('/')
80        .next()
81        .unwrap_or(name)
82        .trim_start_matches('-')
83        .to_ascii_lowercase();
84
85    match shell.as_str() {
86        "bash" => Some(Shell::Bash),
87        "zsh" => Some(Shell::Zsh),
88        "fish" => Some(Shell::Fish),
89        "nu" => Some(Shell::Nu),
90        _ => None,
91    }
92}
93
94fn render_init(shell: Shell) -> &'static str {
95    match shell {
96        Shell::Bash | Shell::Zsh => {
97            r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then
98  _atuin_pty_proxy_tmux_current="${TMUX:-}"
99  _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-${ATUIN_HEX_TMUX:-}}"
100
101  if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-${ATUIN_HEX_ACTIVE:-}}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then
102    export ATUIN_PTY_PROXY_ACTIVE=1
103    export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current"
104    exec atuin pty-proxy
105  fi
106
107  unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous
108fi
109"#
110        }
111        Shell::Fish => {
112            r#"if status is-interactive; and test -t 0; and test -t 1
113    set -l _atuin_pty_proxy_tmux_current ""
114    if set -q TMUX
115        set _atuin_pty_proxy_tmux_current "$TMUX"
116    end
117
118    set -l _atuin_pty_proxy_tmux_previous ""
119    if set -q ATUIN_PTY_PROXY_TMUX
120        set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX"
121    else if set -q ATUIN_HEX_TMUX
122        set _atuin_pty_proxy_tmux_previous "$ATUIN_HEX_TMUX"
123    end
124
125    if not set -q ATUIN_PTY_PROXY_ACTIVE; and not set -q ATUIN_HEX_ACTIVE
126        set -gx ATUIN_PTY_PROXY_ACTIVE 1
127        set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current"
128        exec atuin pty-proxy
129    else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous"
130        set -gx ATUIN_PTY_PROXY_ACTIVE 1
131        set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current"
132        exec atuin pty-proxy
133    end
134end
135"#
136        }
137        // Nushell cannot dynamically source the output of `atuin init nu`,
138        // so we only output the pty-proxy preamble here. Users must also set up
139        // `atuin init nu` separately.
140        Shell::Nu => {
141            r#"if (is-terminal --stdin) and (is-terminal --stdout) {
142    let tmux_current = ($env.TMUX? | default "")
143    let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default ($env.ATUIN_HEX_TMUX? | default ""))
144
145    if (($env.ATUIN_PTY_PROXY_ACTIVE? | default ($env.ATUIN_HEX_ACTIVE? | default "")) | is-empty) or ($tmux_current != $tmux_previous) {
146        $env.ATUIN_PTY_PROXY_ACTIVE = "1"
147        $env.ATUIN_PTY_PROXY_TMUX = $tmux_current
148        exec atuin pty-proxy
149    }
150}
151"#
152        }
153    }
154}
155
/// Fallback proxy for non-unix targets: reports that the platform is not
/// supported on stderr and exits with status 1.
#[cfg(not(unix))]
mod app {
    const UNSUPPORTED_MESSAGE: &str = "atuin pty-proxy currently supports unix platforms";

    pub(crate) fn main() {
        eprintln!("{UNSUPPORTED_MESSAGE}");
        std::process::exit(1)
    }
}
163
164#[cfg(unix)]
165mod app {
166    use std::io::{Read, Write};
167    use std::os::unix::net::UnixListener;
168    use std::sync::mpsc;
169
170    use crossterm::terminal;
171    use portable_pty::{CommandBuilder, PtySize, native_pty_system};
172
173    enum ParserMsg {
174        Data(Vec<u8>),
175        Resize { rows: u16, cols: u16 },
176        ScreenRequest(mpsc::Sender<Vec<u8>>),
177    }
178
179    pub(crate) fn main() {
180        if let Err(e) = run() {
181            let _ = terminal::disable_raw_mode();
182            eprintln!("atuin pty-proxy: {e:#}");
183            std::process::exit(1);
184        }
185    }
186
187    fn socket_path() -> std::path::PathBuf {
188        let dir = std::env::temp_dir();
189        dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id()))
190    }
191
192    /// Wire format written to the Unix socket:
193    ///
194    /// ```text
195    /// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE]
196    /// [row_0_len: u32 BE][row_0_bytes...]
197    /// [row_1_len: u32 BE][row_1_bytes...]
198    /// ...
199    /// ```
200    ///
201    /// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain
202    /// pre-built ANSI escape sequences.  The client can write them directly to
203    /// stdout without needing its own vt100 parser.
204    fn encode_screen(parser: &vt100::Parser) -> Vec<u8> {
205        let screen = parser.screen();
206        let (rows, cols) = screen.size();
207        let (cursor_row, cursor_col) = screen.cursor_position();
208
209        let mut buf: Vec<u8> = Vec::with_capacity(256 + (rows as usize * cols as usize));
210        buf.extend_from_slice(&rows.to_be_bytes());
211        buf.extend_from_slice(&cols.to_be_bytes());
212        buf.extend_from_slice(&cursor_row.to_be_bytes());
213        buf.extend_from_slice(&cursor_col.to_be_bytes());
214
215        for row_bytes in screen.rows_formatted(0, cols) {
216            let len = row_bytes.len() as u32;
217            buf.extend_from_slice(&len.to_be_bytes());
218            buf.extend_from_slice(&row_bytes);
219        }
220
221        buf
222    }
223
224    fn handle_parser_msg(parser: &mut vt100::Parser, msg: ParserMsg) {
225        match msg {
226            ParserMsg::Data(data) => parser.process(&data),
227            ParserMsg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols),
228            ParserMsg::ScreenRequest(reply_tx) => {
229                let _ = reply_tx.send(encode_screen(parser));
230            }
231        }
232    }
233
234    fn run() -> eyre::Result<()> {
235        let (cols, rows) = terminal::size()?;
236
237        let pty_system = native_pty_system();
238        let pair = pty_system
239            .openpty(PtySize {
240                rows,
241                cols,
242                pixel_width: 0,
243                pixel_height: 0,
244            })
245            .map_err(|e| eyre::eyre!("{e:#}"))?;
246
247        // Set up socket path and expose it to child processes
248        let sock_path = socket_path();
249        // Clean up any stale socket from a previous crash
250        let _ = std::fs::remove_file(&sock_path);
251
252        let mut cmd = CommandBuilder::new_default_prog();
253        cmd.cwd(std::env::current_dir()?);
254        cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str());
255        cmd.env("ATUIN_HEX_SOCKET", sock_path.as_os_str());
256
257        let mut child = pair
258            .slave
259            .spawn_command(cmd)
260            .map_err(|e| eyre::eyre!("{e:#}"))?;
261
262        // Close slave side in parent process
263        drop(pair.slave);
264
265        let mut pty_reader = pair
266            .master
267            .try_clone_reader()
268            .map_err(|e| eyre::eyre!("{e:#}"))?;
269        let mut pty_writer = pair
270            .master
271            .take_writer()
272            .map_err(|e| eyre::eyre!("{e:#}"))?;
273
274        // Channel: stdout/sigwinch/socket threads -> parser thread (bounded, non-blocking send)
275        let (msg_tx, msg_rx) = mpsc::sync_channel::<ParserMsg>(64);
276
277        // --- Parser thread ---
278        // Maintains a persistent vt100::Parser fed bytes as they arrive.
279        // On screen request: reads current state directly (no replay).
280        std::thread::spawn(move || {
281            let mut parser = vt100::Parser::new(rows, cols, 0);
282
283            loop {
284                // Block until at least one message arrives
285                let first = match msg_rx.recv() {
286                    Ok(msg) => msg,
287                    Err(_) => break,
288                };
289
290                handle_parser_msg(&mut parser, first);
291
292                // Drain all remaining pending messages so the parser stays
293                // caught up during high-throughput bursts (e.g. `cat bigfile`).
294                // The channel holds at most 64 items, so this is bounded.
295                while let Ok(msg) = msg_rx.try_recv() {
296                    handle_parser_msg(&mut parser, msg);
297                }
298            }
299        });
300
301        // --- Socket server thread ---
302        // Listens on Unix socket; on connection, requests screen state from parser thread.
303        {
304            let sock_path_clone = sock_path.clone();
305            let screen_tx = msg_tx.clone();
306            std::thread::spawn(move || {
307                let listener = match UnixListener::bind(&sock_path_clone) {
308                    Ok(l) => l,
309                    Err(e) => {
310                        eprintln!("atuin pty-proxy: failed to bind socket: {e}");
311                        return;
312                    }
313                };
314
315                for stream in listener.incoming() {
316                    let mut stream = match stream {
317                        Ok(s) => s,
318                        Err(_) => break,
319                    };
320
321                    let (reply_tx, reply_rx) = mpsc::channel();
322                    if screen_tx.send(ParserMsg::ScreenRequest(reply_tx)).is_err() {
323                        break;
324                    }
325                    if let Ok(data) = reply_rx.recv() {
326                        let _ = stream.write_all(&data);
327                        let _ = stream.flush();
328                    }
329                }
330            });
331        }
332
333        // Handle terminal resize via SIGWINCH
334        {
335            use signal_hook::consts::SIGWINCH;
336            use signal_hook::iterator::Signals;
337
338            let master = pair.master;
339            let resize_tx = msg_tx.clone();
340            let mut signals = Signals::new([SIGWINCH])?;
341
342            std::thread::spawn(move || {
343                for _ in signals.forever() {
344                    if let Ok((cols, rows)) = terminal::size() {
345                        let _ = master.resize(PtySize {
346                            rows,
347                            cols,
348                            pixel_width: 0,
349                            pixel_height: 0,
350                        });
351                        let _ = resize_tx.try_send(ParserMsg::Resize { rows, cols });
352                    }
353                }
354            });
355        }
356
357        terminal::enable_raw_mode()?;
358
359        // PTY -> stdout (with OSC 133 parsing + buffer feed)
360        let stdout_thread = std::thread::spawn(move || {
361            let mut stdout = std::io::stdout();
362            let mut parser = crate::osc133::Parser::new();
363            let mut buf = [0u8; 8192];
364            loop {
365                match pty_reader.read(&mut buf) {
366                    Ok(0) | Err(_) => break,
367                    Ok(n) => {
368                        parser.push(&buf[..n], |_event| {
369                            // Zone transitions are tracked inside the parser.
370                            // Callers can query parser.zone() after push.
371                        });
372
373                        // Feed bytes to the shadow parser. Drops on backpressure —
374                        // the screen snapshot may be stale during bursts, but
375                        // self-corrects once output settles.
376                        let _ = msg_tx.try_send(ParserMsg::Data(buf[..n].to_vec()));
377
378                        if stdout.write_all(&buf[..n]).is_err() {
379                            break;
380                        }
381                        let _ = stdout.flush();
382                    }
383                }
384            }
385        });
386
387        // stdin -> PTY
388        std::thread::spawn(move || {
389            let mut stdin = std::io::stdin();
390            let mut buf = [0u8; 8192];
391            loop {
392                match stdin.read(&mut buf) {
393                    Ok(0) | Err(_) => break,
394                    Ok(n) => {
395                        if pty_writer.write_all(&buf[..n]).is_err() {
396                            break;
397                        }
398                    }
399                }
400            }
401        });
402
403        let status = child.wait()?;
404        let _ = stdout_thread.join();
405
406        let _ = terminal::disable_raw_mode();
407
408        // Clean up socket file
409        let _ = std::fs::remove_file(&sock_path);
410
411        std::process::exit(process_exit_code(status.exit_code()));
412    }
413
414    fn process_exit_code(code: u32) -> i32 {
415        i32::try_from(code).unwrap_or(1)
416    }
417
418    #[cfg(test)]
419    mod tests {
420        use super::process_exit_code;
421
422        #[test]
423        fn process_exit_code_preserves_valid_values() {
424            assert_eq!(process_exit_code(0), 0);
425            assert_eq!(process_exit_code(127), 127);
426            assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX);
427        }
428
429        #[test]
430        fn process_exit_code_defaults_when_out_of_range() {
431            assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1);
432        }
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::{Shell, detect_shell, render_init, shell_from_name};

    #[test]
    fn shell_from_name_handles_paths() {
        assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh));
        assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash));
        assert_eq!(shell_from_name("fish"), Some(Shell::Fish));
        assert_eq!(shell_from_name("nu"), Some(Shell::Nu));
    }

    #[test]
    fn shell_from_name_handles_login_shells_whitespace_and_case() {
        assert_eq!(shell_from_name("-zsh"), Some(Shell::Zsh));
        assert_eq!(shell_from_name("  /bin/FISH\n"), Some(Shell::Fish));
    }

    #[test]
    fn shell_from_name_rejects_unsupported_names() {
        assert_eq!(shell_from_name("/bin/sh"), None);
        assert_eq!(shell_from_name(""), None);
        assert_eq!(shell_from_name("/usr/bin/"), None);
    }

    #[test]
    fn detect_shell_prefers_explicit_choice() {
        // The explicit CLI value wins without consulting the environment.
        assert_eq!(detect_shell(Some(Shell::Nu)), Ok(Shell::Nu));
    }

    #[test]
    fn posix_init_uses_exec_and_tmux_guard() {
        let script = render_init(Shell::Bash);
        assert!(script.contains("exec atuin pty-proxy"));
        assert!(script.contains("ATUIN_PTY_PROXY_TMUX"));
        assert!(!script.contains("eval \"$(atuin init bash)\""));
    }

    #[test]
    fn zsh_and_bash_share_init() {
        assert_eq!(render_init(Shell::Zsh), render_init(Shell::Bash));
    }

    #[test]
    fn posix_init_has_no_double_braces() {
        let script = render_init(Shell::Bash);
        assert!(!script.contains("${{"), "double braces in bash init script");
    }

    // The fish preamble must exec the proxy and must NOT pipe `atuin init fish`
    // into `source` (previously misnamed `fish_init_uses_source`).
    #[test]
    fn fish_init_uses_exec_without_sourcing_atuin_init() {
        let script = render_init(Shell::Fish);
        assert!(script.contains("exec atuin pty-proxy"));
        assert!(script.contains("ATUIN_PTY_PROXY_TMUX"));
        assert!(!script.contains("atuin init fish | source"));
    }

    #[test]
    fn nu_init_uses_exec_and_tty_guard() {
        let script = render_init(Shell::Nu);
        assert!(script.contains("exec atuin pty-proxy"));
        assert!(script.contains("ATUIN_PTY_PROXY_TMUX"));
        assert!(script.contains("is-terminal --stdin"));
        assert!(script.contains("is-terminal --stdout"));
        assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE"));
    }
}