Skip to main content

keep_running/
client.rs

1use crate::protocol::{decode_message, encode_message, ClientMessage, DaemonMessage};
2use crate::session::{self, SessionInfo};
3use crate::terminal::{self, status, status_dim, RawModeGuard};
4use anyhow::{Context, Result};
5use crossterm::tty::IsTty;
6use crossterm::{
7    cursor::{Hide, MoveTo, Show},
8    execute,
9    terminal::{Clear, ClearType},
10};
11use std::io::{self, Read, Write};
12use std::os::fd::AsRawFd;
13use std::os::unix::net::UnixStream;
14use std::sync::atomic::{AtomicBool, Ordering};
15use std::sync::Arc;
16use std::time::Duration;
17
/// Prefix byte (Ctrl+a, 0x01) for client-side commands.
///
/// While attached, Ctrl+a followed by `d`/`D` detaches, followed by `k`/`K`
/// kills the session, and followed by a second Ctrl+a sends one literal
/// Ctrl+a to the session. Any other follow-up byte is forwarded to the session
/// together with the held-back Ctrl+a.
const CTRL_A: u8 = 0x01;
20
21/// Drop raw mode (restoring ONLCR etc.) and emit a clean newline so the
22/// next status line lands at column 0. Centralises the previous
23/// `drop(raw_guard); eprint!("\r\n");` repetition — the explicit `\r` was
24/// redundant once cooked mode was restored.
25fn exit_raw(guard: RawModeGuard) {
26    drop(guard);
27    eprintln!();
28}
29
30/// Connect to a session and run the interactive client
31pub fn attach(session: &SessionInfo) -> Result<()> {
32    // Check if we have a terminal
33    if !io::stdin().is_tty() {
34        anyhow::bail!("stdin is not a terminal - cannot attach to session");
35    }
36
37    let socket_path = session::socket_path(&session.name)?;
38    let mut stream =
39        UnixStream::connect(&socket_path).context("Failed to connect to session daemon")?;
40
41    // Set read timeout for handshake
42    stream.set_read_timeout(Some(Duration::from_secs(5)))?;
43    stream.set_nonblocking(false)?;
44
45    // Get terminal size
46    let (cols, rows) = terminal::get_size()?;
47
48    // Send attach message
49    let msg = ClientMessage::Attach { cols, rows };
50    let encoded = encode_message(&msg)?;
51    stream.write_all(&encoded)?;
52
53    // Wait for attached confirmation (blocking with timeout)
54    let mut buf = [0u8; 8192];
55    let n = stream
56        .read(&mut buf)
57        .context("Failed to read attach confirmation")?;
58    if n == 0 {
59        anyhow::bail!("Connection closed while waiting for attach confirmation");
60    }
61    let mut msg_buf = buf[..n].to_vec();
62
63    loop {
64        if let Some((msg, consumed)) = decode_message::<DaemonMessage>(&msg_buf)? {
65            msg_buf.drain(0..consumed);
66            match msg {
67                DaemonMessage::Attached => break,
68                DaemonMessage::Error(e) => anyhow::bail!("Daemon error: {}", e),
69                _ => {}
70            }
71        } else {
72            let n = stream
73                .read(&mut buf)
74                .context("Failed to read from daemon")?;
75            if n == 0 {
76                anyhow::bail!("Connection closed while waiting for attach confirmation");
77            }
78            msg_buf.extend_from_slice(&buf[..n]);
79        }
80    }
81
82    // Clear timeout and switch to non-blocking for main loop
83    stream.set_read_timeout(None)?;
84    stream.set_nonblocking(true)?;
85
86    // Clear screen and move cursor to top-left before showing session content
87    execute!(io::stdout(), Clear(ClearType::All), MoveTo(0, 0))?;
88
89    // Welcome banner (cooked mode — survives at the top of the cleared screen
90    // until program output scrolls over it).
91    status(&format!(
92        "attached to '{}' · pid {}",
93        session.name, session.pid
94    ));
95    status_dim("detach with Ctrl+a d  ·  kill with Ctrl+a k");
96
97    // Enter raw mode
98    let raw_guard = RawModeGuard::enter()?;
99
100    // Set up signal handling for SIGWINCH (terminal resize)
101    let resize_flag = Arc::new(AtomicBool::new(false));
102    let resize_flag_clone = resize_flag.clone();
103
104    // Install SIGWINCH handler
105    unsafe {
106        signal_hook::low_level::register(signal_hook::consts::SIGWINCH, move || {
107            resize_flag_clone.store(true, Ordering::SeqCst);
108        })?;
109    }
110
111    // Main I/O loop
112    let mut input_buf = [0u8; 1024];
113    let mut daemon_buf = [0u8; 8192];
114    let mut daemon_msg_buf = msg_buf; // May have leftover data
115
116    // State for detecting Ctrl+a d
117    let mut saw_ctrl_a = false;
118
119    // Holds a protocol error message captured during decoding so we can
120    // surface it cleanly after dropping raw mode.
121    let mut protocol_error: Option<String> = None;
122
123    // Get file descriptors for polling
124    let stdin_fd = 0i32;
125    let socket_fd = stream.as_raw_fd();
126
127    loop {
128        // Check for resize
129        if resize_flag.swap(false, Ordering::SeqCst) {
130            if let Ok((cols, rows)) = terminal::get_size() {
131                let msg = ClientMessage::Resize { cols, rows };
132                if let Ok(encoded) = encode_message(&msg) {
133                    let _ = stream.write_all(&encoded);
134                }
135            }
136        }
137
138        // Use poll to check for data on stdin and socket
139        let mut poll_fds = [
140            libc::pollfd {
141                fd: stdin_fd,
142                events: libc::POLLIN,
143                revents: 0,
144            },
145            libc::pollfd {
146                fd: socket_fd,
147                events: libc::POLLIN,
148                revents: 0,
149            },
150        ];
151
152        let poll_result = unsafe { libc::poll(poll_fds.as_mut_ptr(), 2, 10) }; // 10ms timeout
153
154        if poll_result < 0 {
155            let err = io::Error::last_os_error();
156            if err.kind() != io::ErrorKind::Interrupted {
157                return Err(err).context("poll failed");
158            }
159            continue;
160        }
161
162        // Check stdin
163        if poll_fds[0].revents & libc::POLLIN != 0 {
164            let n = unsafe {
165                libc::read(
166                    stdin_fd,
167                    input_buf.as_mut_ptr() as *mut libc::c_void,
168                    input_buf.len(),
169                )
170            };
171
172            if n == 0 {
173                // EOF on stdin
174                break;
175            } else if n > 0 {
176                let data = &input_buf[..n as usize];
177
178                // Check for detach/kill sequence (Ctrl+a then d/k)
179                // Batch regular bytes together for efficiency (critical for paste)
180                let mut i = 0;
181                while i < data.len() {
182                    if saw_ctrl_a {
183                        saw_ctrl_a = false;
184                        let byte = data[i];
185                        i += 1;
186
187                        match byte {
188                            b'd' | b'D' => {
189                                // Detach!
190                                let msg = ClientMessage::Detach;
191                                if let Ok(encoded) = encode_message(&msg) {
192                                    let _ = stream.write_all(&encoded);
193                                }
194                                // Drop raw mode before printing so the banner reflows cleanly.
195                                exit_raw(raw_guard);
196                                status(&format!("detached from '{}'", session.name));
197                                status_dim(&format!("reattach: keep-running {}", session.name));
198                                return Ok(());
199                            }
200                            b'k' | b'K' => {
201                                // Kill session!
202                                unsafe {
203                                    libc::kill(session.pid as i32, libc::SIGHUP);
204                                }
205                                exit_raw(raw_guard);
206                                status(&format!("killed '{}'", session.name));
207                                return Ok(());
208                            }
209                            CTRL_A => {
210                                // Double Ctrl+a - send a literal Ctrl+a
211                                let msg = ClientMessage::Input(vec![CTRL_A]);
212                                if let Ok(encoded) = encode_message(&msg) {
213                                    let _ = stream.write_all(&encoded);
214                                }
215                            }
216                            _ => {
217                                // Not a command, send the Ctrl+a we held back, plus this byte
218                                let msg = ClientMessage::Input(vec![CTRL_A, byte]);
219                                if let Ok(encoded) = encode_message(&msg) {
220                                    let _ = stream.write_all(&encoded);
221                                }
222                            }
223                        }
224                    } else {
225                        // Find the next Ctrl+A or end of data
226                        let start = i;
227                        while i < data.len() && data[i] != CTRL_A {
228                            i += 1;
229                        }
230
231                        // Send batch of regular bytes as a single message
232                        if i > start {
233                            let msg = ClientMessage::Input(data[start..i].to_vec());
234                            if let Ok(encoded) = encode_message(&msg) {
235                                let _ = stream.write_all(&encoded);
236                            }
237                        }
238
239                        // If we stopped at Ctrl+A, consume it and set flag
240                        if i < data.len() && data[i] == CTRL_A {
241                            saw_ctrl_a = true;
242                            i += 1;
243                        }
244                    }
245                }
246            }
247        }
248
249        // Check socket
250        let mut socket_eof = false;
251        if poll_fds[1].revents & libc::POLLIN != 0 {
252            match stream.read(&mut daemon_buf) {
253                Ok(0) => {
254                    // Daemon disconnected - but process buffered messages first
255                    socket_eof = true;
256                }
257                Ok(n) => {
258                    daemon_msg_buf.extend_from_slice(&daemon_buf[..n]);
259                }
260                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
261                Err(e) => {
262                    return Err(e).context("Error reading from daemon");
263                }
264            }
265        } else if poll_fds[1].revents & (libc::POLLHUP | libc::POLLERR) != 0 {
266            // Only treat hangup as fatal when there's no data to read (POLLIN not set).
267            // On macOS, POLLHUP can be returned alongside POLLIN during normal
268            // socket state transitions; we should drain data before exiting.
269            socket_eof = true;
270        }
271
272        // Process any messages in the buffer
273        loop {
274            match decode_message::<DaemonMessage>(&daemon_msg_buf) {
275                Ok(Some((msg, consumed))) => {
276                    daemon_msg_buf.drain(0..consumed);
277
278                    match msg {
279                        DaemonMessage::Output(data) => {
280                            terminal::write_stdout(&data)?;
281                        }
282                        DaemonMessage::ReplayStart => {
283                            // Hide cursor during replay to reduce visual noise
284                            let _ = execute!(io::stdout(), Hide);
285                        }
286                        DaemonMessage::ReplayEnd => {
287                            // Restore cursor after replay
288                            let _ = execute!(io::stdout(), Show);
289                        }
290                        DaemonMessage::ChildExited { code } => {
291                            exit_raw(raw_guard);
292                            match code {
293                                Some(c) => status(&format!("process exited with code {}", c)),
294                                None => status("process terminated by signal"),
295                            }
296                            return Ok(());
297                        }
298                        DaemonMessage::Error(e) => {
299                            exit_raw(raw_guard);
300                            status(&format!("daemon error: {}", e));
301                            return Ok(());
302                        }
303                        DaemonMessage::Attached => {
304                            // Already handled
305                        }
306                    }
307                }
308                Ok(None) => break,
309                Err(e) => {
310                    protocol_error = Some(e.to_string());
311                    break;
312                }
313            }
314        }
315
316        // Now that we've processed all buffered messages, handle EOF/hangup
317        if let Some(e) = protocol_error.take() {
318            exit_raw(raw_guard);
319            status(&format!("protocol error: {}", e));
320            return Ok(());
321        }
322        if socket_eof {
323            exit_raw(raw_guard);
324            status("session ended");
325            break;
326        }
327    }
328
329    Ok(())
330}
331
332/// Start a new session and immediately attach to it
333pub fn run_and_attach(name: &str, command: &[String]) -> Result<()> {
334    // Start the daemon
335    crate::daemon::start_daemon(name.to_string(), command.to_vec())?;
336
337    // Poll for the daemon to register itself + bind its socket. Replaces a
338    // fixed 100ms sleep that could race on slower machines.
339    let deadline = std::time::Instant::now() + Duration::from_secs(2);
340    let session = loop {
341        if let Some(s) = session::load_session(name)? {
342            if std::path::Path::new(&s.socket_path).exists() {
343                break s;
344            }
345        }
346        if std::time::Instant::now() >= deadline {
347            anyhow::bail!("session daemon failed to start within 2s");
348        }
349        std::thread::sleep(Duration::from_millis(20));
350    };
351
352    // Don't print a banner here — `attach()` clears the screen before printing
353    // its own welcome banner, so anything we print would be wallpapered over.
354    attach(&session)
355}