Skip to main content

coding_agent_search/
tui_asciicast.rs

1use anyhow::{Context, Result, anyhow, bail};
2use ftui::runtime::{AsciicastRecorder, AsciicastWriter};
3use portable_pty::{CommandBuilder, PtySize, native_pty_system};
4use std::fs::{File, OpenOptions};
5use std::io::{self, BufWriter, IsTerminal, Read, Write};
6#[cfg(unix)]
7use std::os::fd::{AsRawFd, RawFd};
8use std::path::{Path, PathBuf};
9use std::sync::{
10    Arc,
11    atomic::{AtomicBool, Ordering},
12};
13use std::time::{Duration, SystemTime, UNIX_EPOCH};
14
15// Inline POSIX constants and FFI for fcntl / EIO — avoids a direct `libc` dependency.
16#[cfg(unix)]
17mod posix {
18    use std::ffi::c_int;
19    pub const EIO: c_int = 5;
20    pub const F_GETFL: c_int = 3;
21    pub const F_SETFL: c_int = 4;
22
23    // O_NONBLOCK varies across platforms — must use the right constant.
24    #[cfg(any(
25        target_os = "macos",
26        target_os = "ios",
27        target_os = "freebsd",
28        target_os = "openbsd",
29        target_os = "netbsd",
30        target_os = "dragonfly"
31    ))]
32    pub const O_NONBLOCK: c_int = 0x0004;
33    #[cfg(not(any(
34        target_os = "macos",
35        target_os = "ios",
36        target_os = "freebsd",
37        target_os = "openbsd",
38        target_os = "netbsd",
39        target_os = "dragonfly"
40    )))]
41    pub const O_NONBLOCK: c_int = 0o4000;
42
43    unsafe extern "C" {
44        pub fn fcntl(fd: c_int, cmd: c_int, ...) -> c_int;
45    }
46}
47
48/// Run the current `cass tui` invocation inside a PTY and mirror output to an
49/// asciicast v2 file.
50///
51/// This records terminal output only by default. Input bytes are intentionally
52/// not captured to reduce accidental secret leakage (passwords/tokens typed in
53/// the terminal are not serialized into the recording stream).
54pub fn run_tui_with_asciicast(recording_path: &Path, interactive: bool) -> Result<()> {
55    ensure_asciicast_output_available(recording_path)?;
56
57    let (child_args, removed_flag) = strip_asciicast_args(std::env::args().skip(1));
58    if !removed_flag {
59        return Err(anyhow!(
60            "internal error: --asciicast flag was not found in process arguments"
61        ));
62    }
63
64    let exe_path = std::env::current_exe().context("resolve current executable path")?;
65    let exe_str = exe_path
66        .to_str()
67        .ok_or_else(|| anyhow!("executable path is not valid UTF-8"))?;
68
69    let (cols, rows) = detect_terminal_size();
70    let pty_system = native_pty_system();
71    let pair = pty_system
72        .openpty(PtySize {
73            rows,
74            cols,
75            pixel_width: 0,
76            pixel_height: 0,
77        })
78        .context("open PTY for asciicast recording")?;
79
80    let mut cmd = CommandBuilder::new(exe_str);
81    for arg in child_args {
82        cmd.arg(arg);
83    }
84    // Parent already handled update prompt check; avoid duplicate prompt in child.
85    cmd.env("CODING_AGENT_SEARCH_NO_UPDATE_PROMPT", "1");
86
87    let mut child = pair
88        .slave
89        .spawn_command(cmd)
90        .context("spawn TUI child process for asciicast recording")?;
91    drop(pair.slave);
92
93    let mut reader = pair
94        .master
95        .try_clone_reader()
96        .context("clone PTY reader for asciicast capture")?;
97
98    let mut writer_keepalive = Some(
99        pair.master
100            .take_writer()
101            .context("take PTY writer for input forwarding")?,
102    );
103    let mut stdin_forwarder: Option<std::thread::JoinHandle<()>> = None;
104    let mut stdin_stop_requested: Option<Arc<AtomicBool>> = None;
105    #[cfg(unix)]
106    let mut _stdin_nonblocking_guard: Option<StdinNonBlockingGuard> = None;
107
108    let allow_input = interactive
109        && io::stdin().is_terminal()
110        && io::stdout().is_terminal()
111        && dotenvy::var("TUI_HEADLESS").is_err();
112
113    let _raw_mode = RawModeGuard::new(allow_input)?;
114
115    if allow_input && let Some(writer) = writer_keepalive.take() {
116        let stop_requested = Arc::new(AtomicBool::new(false));
117        #[cfg(unix)]
118        {
119            _stdin_nonblocking_guard = StdinNonBlockingGuard::new(io::stdin().as_raw_fd()).ok();
120        }
121        let stop_for_thread = Arc::clone(&stop_requested);
122        stdin_forwarder = Some(std::thread::spawn(move || {
123            forward_stdin(writer, stop_for_thread)
124        }));
125        stdin_stop_requested = Some(stop_requested);
126    }
127
128    let run_result: Result<_> = (|| {
129        let recorder = open_asciicast_recorder_no_overwrite(recording_path, cols, rows)
130            .with_context(|| format!("create asciicast file at {}", recording_path.display()))?;
131        let mut mirror = AsciicastWriter::new(io::stdout(), recorder);
132
133        let mut buf = [0_u8; 8192];
134        loop {
135            match reader.read(&mut buf) {
136                Ok(0) => break,
137                Ok(n) => {
138                    mirror
139                        .write_all(&buf[..n])
140                        .context("write PTY output to terminal/asciicast mirror")?;
141                }
142                Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
143                Err(err) if is_pty_eof_error(&err) => break,
144                Err(err) => return Err(err).context("read PTY output"),
145            }
146        }
147
148        let _ = mirror.finish().context("finalize asciicast recording")?;
149        let _ = writer_keepalive.take();
150
151        child
152            .wait()
153            .context("wait for TUI child process to exit after recording")
154    })();
155
156    if let Some(stop_requested) = stdin_stop_requested.take() {
157        stop_requested.store(true, Ordering::Relaxed);
158    }
159
160    if let Some(handle) = stdin_forwarder.take() {
161        #[cfg(unix)]
162        {
163            if _stdin_nonblocking_guard.is_some() || handle.is_finished() {
164                let _ = handle.join();
165            }
166        }
167        #[cfg(not(unix))]
168        {
169            if handle.is_finished() {
170                let _ = handle.join();
171            }
172        }
173        // If stdin could not be switched to nonblocking and the reader is still
174        // blocked, dropping the handle intentionally detaches.
175    }
176
177    let status = match run_result {
178        Ok(status) => status,
179        Err(err) => {
180            let _ = writer_keepalive.take();
181            let _ = child.kill();
182            let _ = child.wait();
183            return Err(err);
184        }
185    };
186
187    if !status.success() {
188        bail!("TUI exited with non-zero status while recording: {status}");
189    }
190    Ok(())
191}
192
193fn open_asciicast_recorder_no_overwrite(
194    recording_path: &Path,
195    cols: u16,
196    rows: u16,
197) -> Result<AsciicastRecorder<BufWriter<File>>> {
198    ensure_asciicast_output_available(recording_path)?;
199    let mut options = OpenOptions::new();
200    options.write(true).create_new(true);
201    #[cfg(unix)]
202    {
203        use std::os::unix::fs::OpenOptionsExt;
204        options.mode(0o600);
205    }
206    let file = options.open(recording_path).with_context(|| {
207        format!(
208            "create asciicast output file without overwrite at {}",
209            recording_path.display()
210        )
211    })?;
212    let timestamp = SystemTime::now()
213        .duration_since(UNIX_EPOCH)
214        .map_err(|err| anyhow!("system clock is before Unix epoch: {err}"))?
215        .as_secs()
216        .try_into()
217        .context("asciicast timestamp exceeds i64 range")?;
218    AsciicastRecorder::with_writer(BufWriter::new(file), cols, rows, timestamp)
219        .with_context(|| format!("write asciicast header to {}", recording_path.display()))
220}
221
222fn ensure_asciicast_output_available(path: &Path) -> Result<()> {
223    if path.file_name().filter(|name| !name.is_empty()).is_none() {
224        bail!(
225            "asciicast output path must include a filename: {}",
226            path.display()
227        );
228    }
229
230    let parent = path
231        .parent()
232        .filter(|parent| !parent.as_os_str().is_empty())
233        .unwrap_or_else(|| Path::new("."));
234    ensure_asciicast_parent(parent)?;
235
236    match std::fs::symlink_metadata(path) {
237        Ok(_) => bail!("asciicast output already exists: {}", path.display()),
238        Err(err) if err.kind() == io::ErrorKind::NotFound => Ok(()),
239        Err(err) => {
240            Err(err).with_context(|| format!("inspect asciicast output {}", path.display()))
241        }
242    }
243}
244
245fn ensure_asciicast_parent(parent: &Path) -> Result<()> {
246    ensure_parent_chain_has_no_symlinks(parent)?;
247    std::fs::create_dir_all(parent)
248        .with_context(|| format!("create asciicast parent directory {}", parent.display()))?;
249    ensure_parent_chain_has_no_symlinks(parent)
250}
251
252fn ensure_parent_chain_has_no_symlinks(path: &Path) -> Result<()> {
253    let mut ancestors: Vec<PathBuf> = path
254        .ancestors()
255        .filter(|ancestor| !ancestor.as_os_str().is_empty())
256        .map(Path::to_path_buf)
257        .collect();
258    ancestors.reverse();
259
260    for ancestor in ancestors {
261        match std::fs::symlink_metadata(&ancestor) {
262            Ok(metadata) => {
263                let file_type = metadata.file_type();
264                if file_type.is_symlink() {
265                    if is_allowed_system_symlink_ancestor(&ancestor) {
266                        continue;
267                    }
268                    bail!(
269                        "asciicast output parent must not contain symlinks: {}",
270                        ancestor.display()
271                    );
272                }
273                if !file_type.is_dir() {
274                    bail!(
275                        "asciicast output parent is not a directory: {}",
276                        ancestor.display()
277                    );
278                }
279            }
280            Err(err) if err.kind() == io::ErrorKind::NotFound => {}
281            Err(err) => {
282                return Err(err).with_context(|| {
283                    format!("inspect asciicast parent directory {}", ancestor.display())
284                });
285            }
286        }
287    }
288
289    Ok(())
290}
291
292#[cfg(target_os = "macos")]
293fn is_allowed_system_symlink_ancestor(path: &Path) -> bool {
294    path == Path::new("/var") || path == Path::new("/tmp")
295}
296
297#[cfg(not(target_os = "macos"))]
298fn is_allowed_system_symlink_ancestor(_path: &Path) -> bool {
299    false
300}
301
302fn detect_terminal_size() -> (u16, u16) {
303    fn env_dim(key: &str) -> Option<u16> {
304        dotenvy::var(key)
305            .ok()
306            .and_then(|raw| raw.trim().parse::<u16>().ok())
307            .filter(|value| *value > 0)
308    }
309
310    let env_cols = env_dim("COLUMNS");
311    let env_rows = env_dim("LINES");
312    if let (Some(cols), Some(rows)) = (env_cols, env_rows) {
313        return (cols, rows);
314    }
315
316    #[cfg(unix)]
317    {
318        if io::stdin().is_terminal() {
319            let output = std::process::Command::new("stty").arg("size").output().ok();
320            if let Some(output) = output
321                && output.status.success()
322                && let Ok(text) = String::from_utf8(output.stdout)
323            {
324                let mut parts = text.split_whitespace();
325                if let (Some(rows), Some(cols)) = (parts.next(), parts.next())
326                    && let (Ok(rows), Ok(cols)) = (rows.parse::<u16>(), cols.parse::<u16>())
327                    && rows > 0
328                    && cols > 0
329                {
330                    return (cols, rows);
331                }
332            }
333        }
334    }
335
336    (120, 40)
337}
338
339fn forward_stdin(mut child_writer: Box<dyn Write + Send>, stop_requested: Arc<AtomicBool>) {
340    let stdin = io::stdin();
341    let mut stdin_lock = stdin.lock();
342    let mut buf = [0_u8; 256];
343    loop {
344        if stop_requested.load(Ordering::Relaxed) {
345            break;
346        }
347        match stdin_lock.read(&mut buf) {
348            Ok(0) => break,
349            Ok(n) => {
350                if child_writer.write_all(&buf[..n]).is_err() {
351                    break;
352                }
353                if child_writer.flush().is_err() {
354                    break;
355                }
356            }
357            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
358            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
359                std::thread::sleep(Duration::from_millis(10));
360            }
361            Err(_) => break,
362        }
363    }
364}
365
366fn strip_asciicast_args<I>(args: I) -> (Vec<String>, bool)
367where
368    I: IntoIterator<Item = String>,
369{
370    let mut out = Vec::new();
371    let mut removed = false;
372    let mut iter = args.into_iter();
373    while let Some(arg) = iter.next() {
374        if arg == "--asciicast" {
375            removed = true;
376            let _ = iter.next();
377            continue;
378        }
379        if arg.starts_with("--asciicast=") {
380            removed = true;
381            continue;
382        }
383        out.push(arg);
384    }
385    (out, removed)
386}
387
388fn is_pty_eof_error(err: &io::Error) -> bool {
389    if matches!(
390        err.kind(),
391        io::ErrorKind::UnexpectedEof | io::ErrorKind::BrokenPipe
392    ) {
393        return true;
394    }
395    #[cfg(unix)]
396    {
397        err.raw_os_error() == Some(posix::EIO)
398    }
399    #[cfg(not(unix))]
400    {
401        false
402    }
403}
404
405struct RawModeGuard {
406    #[cfg(unix)]
407    inner: Option<ftui_tty::RawModeGuard>,
408}
409
410impl RawModeGuard {
411    fn new(enabled: bool) -> Result<Self> {
412        #[cfg(unix)]
413        {
414            let inner = if enabled {
415                Some(
416                    ftui_tty::RawModeGuard::enter()
417                        .context("enable raw mode for input passthrough")?,
418                )
419            } else {
420                None
421            };
422            Ok(Self { inner })
423        }
424        #[cfg(not(unix))]
425        {
426            let _ = enabled;
427            Ok(Self {})
428        }
429    }
430}
431
432#[cfg(unix)]
433impl Drop for RawModeGuard {
434    fn drop(&mut self) {
435        let _ = self.inner.take();
436    }
437}
438
439#[cfg(unix)]
440struct StdinNonBlockingGuard {
441    fd: RawFd,
442    old_flags: std::ffi::c_int,
443}
444
445#[cfg(unix)]
446impl StdinNonBlockingGuard {
447    fn new(fd: RawFd) -> io::Result<Self> {
448        // SAFETY: fcntl does not outlive `fd` and is called with valid command
449        // constants; errors are surfaced via last_os_error.
450        let old_flags = unsafe { posix::fcntl(fd, posix::F_GETFL) };
451        if old_flags < 0 {
452            return Err(io::Error::last_os_error());
453        }
454
455        // SAFETY: same as above; we preserve and later restore original flags.
456        let set_result = unsafe { posix::fcntl(fd, posix::F_SETFL, old_flags | posix::O_NONBLOCK) };
457        if set_result < 0 {
458            return Err(io::Error::last_os_error());
459        }
460
461        Ok(Self { fd, old_flags })
462    }
463}
464
465#[cfg(unix)]
466impl Drop for StdinNonBlockingGuard {
467    fn drop(&mut self) {
468        // SAFETY: best-effort restoration of original descriptor flags.
469        unsafe {
470            let _ = posix::fcntl(self.fd, posix::F_SETFL, self.old_flags);
471        }
472    }
473}
474
475#[cfg(test)]
476mod tests {
477    use super::{
478        ensure_asciicast_output_available, is_pty_eof_error, open_asciicast_recorder_no_overwrite,
479        strip_asciicast_args,
480    };
481    use std::io;
482    use std::io::Write as _;
483
484    #[test]
485    fn strips_split_asciicast_flag_and_value() {
486        let input = vec![
487            "tui".to_string(),
488            "--asciicast".to_string(),
489            "demo.cast".to_string(),
490            "--once".to_string(),
491        ];
492        let (args, removed) = strip_asciicast_args(input);
493        assert!(removed);
494        assert_eq!(args, vec!["tui", "--once"]);
495    }
496
497    #[test]
498    fn strips_inline_asciicast_flag() {
499        let input = vec![
500            "tui".to_string(),
501            "--asciicast=demo.cast".to_string(),
502            "--data-dir".to_string(),
503            "/tmp/cass".to_string(),
504        ];
505        let (args, removed) = strip_asciicast_args(input);
506        assert!(removed);
507        assert_eq!(args, vec!["tui", "--data-dir", "/tmp/cass"]);
508    }
509
510    #[test]
511    fn leaves_unrelated_args_untouched() {
512        let input = vec!["tui".to_string(), "--once".to_string()];
513        let (args, removed) = strip_asciicast_args(input.clone());
514        assert!(!removed);
515        assert_eq!(args, input);
516    }
517
518    #[test]
519    fn recognizes_common_pty_eof_errors() {
520        let eof = io::Error::new(io::ErrorKind::UnexpectedEof, "eof");
521        assert!(is_pty_eof_error(&eof));
522
523        let pipe = io::Error::new(io::ErrorKind::BrokenPipe, "broken");
524        assert!(is_pty_eof_error(&pipe));
525    }
526
527    #[test]
528    fn creates_asciicast_parent_and_new_output() {
529        let tmp = tempfile::TempDir::new().expect("tempdir");
530        let output_path = tmp.path().join("nested").join("demo.cast");
531
532        let recorder =
533            open_asciicast_recorder_no_overwrite(&output_path, 80, 24).expect("open recorder");
534        let mut writer = recorder.finish().expect("finish recorder");
535        writer.flush().expect("flush recorder");
536
537        let contents = std::fs::read_to_string(&output_path).expect("read asciicast");
538        assert!(
539            contents.starts_with("{\"version\":2"),
540            "unexpected asciicast header: {contents:?}"
541        );
542    }
543
544    #[test]
545    #[cfg(unix)]
546    fn creates_asciicast_output_with_private_permissions() {
547        use std::os::unix::fs::PermissionsExt;
548
549        let tmp = tempfile::TempDir::new().expect("tempdir");
550        let output_path = tmp.path().join("demo.cast");
551
552        let recorder =
553            open_asciicast_recorder_no_overwrite(&output_path, 80, 24).expect("open recorder");
554        let mut writer = recorder.finish().expect("finish recorder");
555        writer.flush().expect("flush recorder");
556
557        let mode = std::fs::metadata(&output_path)
558            .expect("metadata")
559            .permissions()
560            .mode()
561            & 0o777;
562        assert_eq!(
563            mode, 0o600,
564            "asciicast recordings should not gain group/other permissions"
565        );
566    }
567
568    #[test]
569    fn rejects_existing_asciicast_output_without_clobbering() {
570        let tmp = tempfile::TempDir::new().expect("tempdir");
571        let output_path = tmp.path().join("demo.cast");
572        std::fs::write(&output_path, "existing cast").expect("seed existing output");
573
574        let err = open_asciicast_recorder_no_overwrite(&output_path, 80, 24)
575            .expect_err("existing output should be rejected");
576
577        assert!(
578            err.to_string().contains("already exists"),
579            "unexpected error: {err:#}"
580        );
581        assert_eq!(
582            std::fs::read_to_string(&output_path).expect("read existing output"),
583            "existing cast"
584        );
585    }
586
587    #[test]
588    #[cfg(unix)]
589    fn rejects_existing_asciicast_output_symlink_without_following() {
590        use std::os::unix::fs::symlink;
591
592        let tmp = tempfile::TempDir::new().expect("tempdir");
593        let protected_target = tmp.path().join("protected.cast");
594        let output_path = tmp.path().join("demo.cast");
595        std::fs::write(&protected_target, "protected").expect("seed protected target");
596        symlink(&protected_target, &output_path).expect("create output symlink");
597
598        let err = open_asciicast_recorder_no_overwrite(&output_path, 80, 24)
599            .expect_err("symlink output should be rejected");
600
601        assert!(
602            err.to_string().contains("already exists"),
603            "unexpected error: {err:#}"
604        );
605        assert_eq!(
606            std::fs::read_to_string(&protected_target).expect("read protected target"),
607            "protected"
608        );
609        assert_eq!(
610            std::fs::read_link(&output_path).expect("output path remains symlink"),
611            protected_target
612        );
613    }
614
615    #[test]
616    #[cfg(unix)]
617    fn rejects_symlinked_asciicast_parent_before_creating_output() {
618        use std::os::unix::fs::symlink;
619
620        let tmp = tempfile::TempDir::new().expect("tempdir");
621        let outside_dir = tmp.path().join("outside");
622        let linked_dir = tmp.path().join("linked");
623        std::fs::create_dir_all(&outside_dir).expect("create outside dir");
624        symlink(&outside_dir, &linked_dir).expect("create parent symlink");
625        let output_path = linked_dir.join("demo.cast");
626
627        let err = ensure_asciicast_output_available(&output_path)
628            .expect_err("symlinked parent should be rejected");
629
630        assert!(
631            err.to_string().contains("must not contain symlinks"),
632            "unexpected error: {err:#}"
633        );
634        assert!(
635            !outside_dir.join("demo.cast").exists(),
636            "preflight should not write through symlinked parent"
637        );
638    }
639}