Skip to main content

colab_cli/
shell.rs

1use std::future::Future;
2use std::io::{self, Read, Write};
3use std::pin::Pin;
4use std::sync::Arc;
5
6use crossterm::terminal;
7use futures_util::{SinkExt, StreamExt};
8use tokio_tungstenite::tungstenite;
9
10use crate::client::ColabClient;
11use crate::error::{ColabError, Result};
12use crate::server::storage::StoredServer;
13
14// async refresher used by long-running shells to rotate the proxy token.
15// returns the new StoredServer so reconnect can pick up the rotated value.
16pub type TokenRefresher =
17    Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<StoredServer>> + Send>> + Send + Sync>;
18
19pub async fn run_shell(
20    client: &ColabClient,
21    server: &StoredServer,
22    initial_command: Option<&str>,
23    refresher: Option<TokenRefresher>,
24) -> Result<()> {
25    let term = client
26        .create_terminal(&server.proxy_url, &server.proxy_token)
27        .await?;
28
29    let ws_url = client.terminal_ws_url(&server.proxy_url, &term.name);
30    let request = build_ws_request(&ws_url, &server.proxy_token)?;
31
32    let (ws_stream, _) = tokio_tungstenite::connect_async(request)
33        .await
34        .map_err(|e| ColabError::oauth(format!("WebSocket connect failed: {e}")))?;
35
36    let (mut ws_write, mut ws_read) = ws_stream.split();
37
38    if let Ok((cols, rows)) = terminal::size() {
39        let size_msg = serde_json::json!(["set_size", rows, cols]).to_string();
40        let _ = ws_write
41            .send(tungstenite::Message::Text(size_msg.into()))
42            .await;
43    }
44
45    // PS1 = "<name> /path #". clear wipes the default prompt that flashed first.
46    let label_esc = server.label.replace('\'', "'\\''");
47    let prompt_cmd = format!("export PS1='\\[\\e[36m\\]{label_esc}\\[\\e[0m\\] \\w # ' && clear\n");
48    let _ = ws_write
49        .send(tungstenite::Message::Text(
50            serde_json::json!(["stdin", prompt_cmd]).to_string().into(),
51        ))
52        .await;
53
54    if let Some(cmd) = initial_command {
55        let msg = serde_json::json!(["stdin", format!("{cmd}\n")]).to_string();
56        let _ = ws_write.send(tungstenite::Message::Text(msg.into())).await;
57    }
58
59    // ping every 4min to keep the runtime warm and rotate the proxy token.
60    // the open ws stays pinned to its original token, but any reconnect or
61    // sibling http call via the proxy would 401 once the token expired.
62    let keepalive_client = client.clone();
63    let keepalive_endpoint = server.endpoint.clone();
64    let keepalive_refresher = refresher.clone();
65    let keepalive_handle = tokio::spawn(async move {
66        let mut interval = tokio::time::interval(std::time::Duration::from_secs(4 * 60));
67        interval.tick().await;
68        loop {
69            interval.tick().await;
70            if let Some(refresher) = keepalive_refresher.as_ref() {
71                // we don't need the new StoredServer here — the side-effect
72                // (rotated token in storage) is what matters
73                let _ = (refresher)().await;
74            }
75            let _ = keepalive_client.send_keep_alive(&keepalive_endpoint).await;
76        }
77    });
78    let _keepalive_guard = AbortOnDrop(keepalive_handle);
79
80    terminal::enable_raw_mode().map_err(|e| ColabError::config(format!("raw mode: {e}")))?;
81    let _raw_guard = RawModeGuard;
82
83    let (stdin_tx, mut stdin_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(64);
84
85    std::thread::spawn(move || {
86        let stdin = io::stdin();
87        let mut handle = stdin.lock();
88        let mut buf = [0u8; 4096];
89        loop {
90            match handle.read(&mut buf) {
91                Ok(0) => break,
92                Ok(n) => {
93                    if stdin_tx.blocking_send(buf[..n].to_vec()).is_err() {
94                        break;
95                    }
96                }
97                Err(_) => break,
98            }
99        }
100    });
101
102    loop {
103        tokio::select! {
104            msg = ws_read.next() => {
105                match msg {
106                    Some(Ok(tungstenite::Message::Text(text))) => {
107                        if let Some(data) = parse_stdout_frame(text.as_ref()) {
108                            let mut stdout = io::stdout().lock();
109                            let _ = stdout.write_all(data.as_bytes());
110                            let _ = stdout.flush();
111                        }
112                    }
113                    Some(Ok(tungstenite::Message::Binary(data))) => {
114                        let mut stdout = io::stdout().lock();
115                        let _ = stdout.write_all(&data);
116                        let _ = stdout.flush();
117                    }
118                    Some(Ok(tungstenite::Message::Close(_))) | None => break,
119                    Some(Err(_)) => break,
120                    _ => {}
121                }
122            }
123            data = stdin_rx.recv() => {
124                match data {
125                    Some(bytes) => {
126                        let text = String::from_utf8_lossy(&bytes);
127                        let msg = serde_json::json!(["stdin", text]).to_string();
128                        if ws_write
129                            .send(tungstenite::Message::Text(msg.into()))
130                            .await
131                            .is_err()
132                        {
133                            break;
134                        }
135                    }
136                    None => break,
137                }
138            }
139        }
140    }
141
142    Ok(())
143}
144
145struct AbortOnDrop(tokio::task::JoinHandle<()>);
146
147impl Drop for AbortOnDrop {
148    fn drop(&mut self) {
149        self.0.abort();
150    }
151}
152
153pub async fn capture_remote_command(
154    client: &ColabClient,
155    server: &StoredServer,
156    command: &str,
157) -> Result<String> {
158    let term = client
159        .create_terminal(&server.proxy_url, &server.proxy_token)
160        .await?;
161
162    let ws_url = client.terminal_ws_url(&server.proxy_url, &term.name);
163    let request = build_ws_request(&ws_url, &server.proxy_token)?;
164
165    let (ws_stream, _) = tokio_tungstenite::connect_async(request)
166        .await
167        .map_err(|e| ColabError::oauth(format!("WebSocket connect failed: {e}")))?;
168
169    let (mut ws_write, mut ws_read) = ws_stream.split();
170
171    let start_marker = format!("__colab_start_{}__", uuid::Uuid::new_v4().simple());
172    let end_marker = format!("__colab_end_{}__", uuid::Uuid::new_v4().simple());
173    let full_cmd = format!("printf '{start_marker}\\n'; {command}; printf '\\n{end_marker}\\n'\n");
174    ws_write
175        .send(tungstenite::Message::Text(
176            serde_json::json!(["stdin", full_cmd]).to_string().into(),
177        ))
178        .await
179        .map_err(|e| ColabError::oauth(format!("WebSocket send: {e}")))?;
180
181    let mut buf = String::new();
182    let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(30);
183    loop {
184        let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
185        if remaining.is_zero() {
186            break;
187        }
188        match tokio::time::timeout(remaining, ws_read.next()).await {
189            Ok(Some(Ok(tungstenite::Message::Text(text)))) => {
190                if let Some(data) = parse_stdout_frame(text.as_ref()) {
191                    buf.push_str(&data);
192                    if buf.contains(&end_marker) {
193                        break;
194                    }
195                }
196            }
197            Ok(Some(Ok(tungstenite::Message::Close(_)))) | Ok(None) => break,
198            Err(_) => break,
199            _ => continue,
200        }
201    }
202
203    let start = buf
204        .find(&start_marker)
205        .map(|i| i + start_marker.len())
206        .unwrap_or(0);
207    let end = buf.find(&end_marker).unwrap_or(buf.len());
208    Ok(buf[start..end].trim().to_string())
209}
210
211// long-lived remote process. each ws frame goes to on_chunk; returns when
212// cancel fires, the remote closes, or on_chunk returns false.
213pub async fn stream_remote_output<F>(
214    client: &ColabClient,
215    server: &StoredServer,
216    command: &str,
217    mut on_chunk: F,
218    cancel: impl std::future::Future<Output = ()>,
219) -> Result<()>
220where
221    F: FnMut(&str) -> bool,
222{
223    let term = client
224        .create_terminal(&server.proxy_url, &server.proxy_token)
225        .await?;
226
227    let ws_url = client.terminal_ws_url(&server.proxy_url, &term.name);
228    let request = build_ws_request(&ws_url, &server.proxy_token)?;
229
230    let (ws_stream, _) = tokio_tungstenite::connect_async(request)
231        .await
232        .map_err(|e| ColabError::oauth(format!("WebSocket connect failed: {e}")))?;
233
234    let (mut ws_write, mut ws_read) = ws_stream.split();
235
236    let msg = serde_json::json!(["stdin", format!("{command}\n")]).to_string();
237    ws_write
238        .send(tungstenite::Message::Text(msg.into()))
239        .await
240        .map_err(|e| ColabError::oauth(format!("WebSocket send: {e}")))?;
241
242    tokio::pin!(cancel);
243
244    loop {
245        tokio::select! {
246            _ = &mut cancel => {
247                let interrupt = serde_json::json!(["stdin", "\x03"]).to_string();
248                let _ = ws_write.send(tungstenite::Message::Text(interrupt.into())).await;
249                return Ok(());
250            }
251            msg = ws_read.next() => {
252                match msg {
253                    Some(Ok(tungstenite::Message::Text(text))) => {
254                        if let Some(data) = parse_stdout_frame(text.as_ref())
255                            && !on_chunk(&data)
256                        {
257                            return Ok(());
258                        }
259                    }
260                    Some(Ok(tungstenite::Message::Close(_))) | None => return Ok(()),
261                    Some(Err(e)) => return Err(ColabError::oauth(format!("ws: {e}"))),
262                    _ => {}
263                }
264            }
265        }
266    }
267}
268
269// full-screen remote TUI (bpytop/btop/htop) in alt screen + raw mode.
270// reconnects up to 3 times on a transient ws drop, then gives up.
271pub async fn run_remote_tui(
272    client: &ColabClient,
273    server: &StoredServer,
274    command: &str,
275) -> Result<()> {
276    use crossterm::{cursor, execute, terminal as ct_term};
277
278    let term = client
279        .create_terminal(&server.proxy_url, &server.proxy_token)
280        .await?;
281    let terminal_name = term.name.clone();
282
283    // drop guard so we always reap the remote terminal, even on early return
284    let cleanup_client = client.clone();
285    let cleanup_proxy_url = server.proxy_url.clone();
286    let cleanup_proxy_token = server.proxy_token.clone();
287    let cleanup_name = terminal_name.clone();
288    let _cleanup_guard = TerminalCleanupGuard::new(move || {
289        // fire-and-forget on whatever runtime owns this Drop
290        if let Ok(handle) = tokio::runtime::Handle::try_current() {
291            handle.spawn(async move {
292                let _ = cleanup_client
293                    .delete_terminal(&cleanup_proxy_url, &cleanup_proxy_token, &cleanup_name)
294                    .await;
295            });
296        }
297    });
298
299    // alt screen + raw mode BEFORE first ws connect, otherwise we flicker
300    {
301        let mut out = io::stdout();
302        execute!(out, ct_term::EnterAlternateScreen, cursor::Hide)
303            .map_err(|e| ColabError::config(format!("alt screen: {e}")))?;
304    }
305    struct AltScreenGuard;
306    impl Drop for AltScreenGuard {
307        fn drop(&mut self) {
308            let mut out = io::stdout();
309            let _ = execute!(out, cursor::Show, crossterm::terminal::LeaveAlternateScreen);
310            let _ = out.flush();
311        }
312    }
313    let _alt_guard = AltScreenGuard;
314
315    terminal::enable_raw_mode().map_err(|e| ColabError::config(format!("raw mode: {e}")))?;
316    let _raw_guard = RawModeGuard;
317
318    // shared channel: stdin reader + resize watcher → async ws writer
319    #[derive(Debug)]
320    enum WsOut {
321        Stdin(Vec<u8>),
322        Resize(u16, u16),
323    }
324    let (out_tx, mut out_rx) = tokio::sync::mpsc::channel::<WsOut>(128);
325
326    // raw stdin reader on a blocking thread — keystroke latency lives here
327    let stdin_tx = out_tx.clone();
328    std::thread::spawn(move || {
329        let stdin = io::stdin();
330        let mut handle = stdin.lock();
331        let mut buf = [0u8; 4096];
332        loop {
333            match handle.read(&mut buf) {
334                Ok(0) => break,
335                Ok(n) => {
336                    if stdin_tx
337                        .blocking_send(WsOut::Stdin(buf[..n].to_vec()))
338                        .is_err()
339                    {
340                        break;
341                    }
342                }
343                Err(_) => break,
344            }
345        }
346    });
347
348    // poll terminal size every 250ms; can't use event::read — clashes with
349    // the raw stdin thread on the same fd
350    let resize_tx = out_tx.clone();
351    let resize_handle = tokio::spawn(async move {
352        let mut last = terminal::size().unwrap_or((80, 24));
353        let mut tick = tokio::time::interval(std::time::Duration::from_millis(250));
354        tick.tick().await;
355        loop {
356            tick.tick().await;
357            let cur = terminal::size().unwrap_or(last);
358            if cur != last && resize_tx.send(WsOut::Resize(cur.1, cur.0)).await.is_err() {
359                return;
360            }
361            last = cur;
362        }
363    });
364    let _resize_guard = AbortOnDrop(resize_handle);
365
366    // reconnect loop. clean close → Ok. drop → reattach (3 retries).
367    let mut initial_command: Option<String> = Some(command.to_string());
368    let mut retries_left: u32 = 3;
369
370    loop {
371        let ws_url = client.terminal_ws_url(&server.proxy_url, &terminal_name);
372        let request = build_ws_request(&ws_url, &server.proxy_token)?;
373
374        let connect_result = tokio_tungstenite::connect_async(request).await;
375        let ws_stream = match connect_result {
376            Ok((s, _)) => s,
377            Err(_) if retries_left > 0 => {
378                retries_left -= 1;
379                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
380                continue;
381            }
382            Err(e) => {
383                return Err(ColabError::oauth(format!("WebSocket connect failed: {e}")));
384            }
385        };
386        let (mut ws_write, mut ws_read) = ws_stream.split();
387
388        // send size first so bpytop doesn't start at 80x24 and redraw
389        if let Ok((cols, rows)) = terminal::size() {
390            let size_msg = serde_json::json!(["set_size", rows, cols]).to_string();
391            let _ = ws_write
392                .send(tungstenite::Message::Text(size_msg.into()))
393                .await;
394        }
395
396        // only send the command on first connect; reattach just watches
397        if let Some(cmd) = initial_command.take() {
398            let msg = serde_json::json!(["stdin", format!("{cmd}\n")]).to_string();
399            let _ = ws_write.send(tungstenite::Message::Text(msg.into())).await;
400        }
401
402        let inner = async {
403            loop {
404                tokio::select! {
405                    msg = ws_read.next() => {
406                        match msg {
407                            Some(Ok(tungstenite::Message::Text(text))) => {
408                                if let Some(data) = parse_stdout_frame(text.as_ref()) {
409                                    let mut stdout = io::stdout().lock();
410                                    let _ = stdout.write_all(data.as_bytes());
411                                    let _ = stdout.flush();
412                                }
413                            }
414                            Some(Ok(tungstenite::Message::Binary(bin))) => {
415                                let mut stdout = io::stdout().lock();
416                                let _ = stdout.write_all(&bin);
417                                let _ = stdout.flush();
418                            }
419                            Some(Ok(tungstenite::Message::Close(_))) | None => {
420                                return InnerExit::Closed;
421                            }
422                            Some(Err(_)) => return InnerExit::Dropped,
423                            _ => {}
424                        }
425                    }
426                    out = out_rx.recv() => {
427                        let Some(msg) = out else {
428                            return InnerExit::Closed;
429                        };
430                        let serialized = match msg {
431                            WsOut::Stdin(bytes) => {
432                                let text = String::from_utf8_lossy(&bytes).into_owned();
433                                serde_json::json!(["stdin", text]).to_string()
434                            }
435                            WsOut::Resize(rows, cols) => {
436                                serde_json::json!(["set_size", rows, cols]).to_string()
437                            }
438                        };
439                        if ws_write
440                            .send(tungstenite::Message::Text(serialized.into()))
441                            .await
442                            .is_err()
443                        {
444                            return InnerExit::Dropped;
445                        }
446                    }
447                }
448            }
449        };
450
451        match inner.await {
452            InnerExit::Closed => return Ok(()),
453            InnerExit::Dropped if retries_left > 0 => {
454                retries_left -= 1;
455                // tiny reconnect banner; bpytop's next frame wipes it
456                {
457                    let mut out = io::stdout();
458                    let _ = execute!(
459                        out,
460                        cursor::MoveTo(0, 0),
461                        crossterm::style::Print("  reconnecting\u{2026}  "),
462                    );
463                    let _ = out.flush();
464                }
465                tokio::time::sleep(std::time::Duration::from_millis(400)).await;
466                continue;
467            }
468            InnerExit::Dropped => {
469                return Err(ColabError::oauth(
470                    "WebSocket dropped and could not reattach",
471                ));
472            }
473        }
474    }
475}
476
/// Why the inner read/write loop of `run_remote_tui` stopped.
enum InnerExit {
    /// The remote sent a close frame, the stream ended, or the local output
    /// channel was closed; the session is finished.
    Closed,
    /// A WebSocket read or send failed; the caller may try to reattach.
    Dropped,
}
481
/// Scope guard that runs a one-shot cleanup closure when it is dropped.
struct TerminalCleanupGuard<F>
where
    F: FnOnce(),
{
    /// `Some` until the guard is dropped; `Drop` takes the closure out to call it.
    cleanup: Option<F>,
}

impl<F> TerminalCleanupGuard<F>
where
    F: FnOnce(),
{
    /// Arms a guard that will invoke `cleanup` on drop.
    fn new(cleanup: F) -> Self {
        let cleanup = Some(cleanup);
        Self { cleanup }
    }
}

impl<F> Drop for TerminalCleanupGuard<F>
where
    F: FnOnce(),
{
    fn drop(&mut self) {
        // an `FnOnce` must be called by value, so move it out of `&mut self`
        let Some(run) = self.cleanup.take() else {
            return;
        };
        run();
    }
}
501
502fn build_ws_request(ws_url: &str, proxy_token: &str) -> Result<tungstenite::http::Request<()>> {
503    tungstenite::http::Request::builder()
504        .uri(ws_url)
505        .header("X-Colab-Runtime-Proxy-Token", proxy_token)
506        .header("X-Colab-Client-Agent", "vscode")
507        .header("Host", host_from_url(ws_url))
508        .header("Connection", "Upgrade")
509        .header("Upgrade", "websocket")
510        .header("Sec-WebSocket-Version", "13")
511        .header(
512            "Sec-WebSocket-Key",
513            tungstenite::handshake::client::generate_key(),
514        )
515        .body(())
516        .map_err(|e| ColabError::oauth(format!("failed to build WS request: {e}")))
517}
518
/// Extracts the `host[:port]` part of a URL for use in a `Host` header.
///
/// A leading `scheme://` is removed, the authority is cut at the first `/`,
/// `?` or `#`, and any `userinfo@` prefix is dropped. Input without a scheme
/// is treated as starting directly with the authority. Returns an empty string
/// when there is no host.
fn host_from_url(url: &str) -> String {
    // Only accept "://" as a scheme separator when nothing path-like precedes
    // it, so a URL embedded in a query string is not mistaken for the scheme.
    let rest = match url.split_once("://") {
        Some((scheme, rest)) if !scheme.contains(['/', '?', '#']) => rest,
        _ => url,
    };
    let authority = rest.split(['/', '?', '#']).next().unwrap_or("");
    let host = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, host)| host);
    host.to_string()
}
527
528// run argv on the remote, stream stdout/stderr through, return its exit code.
529// uses printf-marker tricks to skip shell prompt + command echo without
530// turning echo off; see run_passthrough_inner.
531pub async fn run_passthrough(
532    client: &ColabClient,
533    server: &StoredServer,
534    argv: &[String],
535) -> Result<i32> {
536    let term = client
537        .create_terminal(&server.proxy_url, &server.proxy_token)
538        .await?;
539    let terminal_name = term.name.clone();
540
541    let result = run_passthrough_inner(client, server, &terminal_name, argv).await;
542
543    // always reap the remote terminal, even on error
544    let _ = client
545        .delete_terminal(&server.proxy_url, &server.proxy_token, &terminal_name)
546        .await;
547
548    result
549}
550
551async fn run_passthrough_inner(
552    client: &ColabClient,
553    server: &StoredServer,
554    terminal_name: &str,
555    argv: &[String],
556) -> Result<i32> {
557    let ws_url = client.terminal_ws_url(&server.proxy_url, terminal_name);
558    let request = build_ws_request(&ws_url, &server.proxy_token)?;
559    let (ws_stream, _) = tokio_tungstenite::connect_async(request)
560        .await
561        .map_err(|e| ColabError::oauth(format!("WebSocket connect failed: {e}")))?;
562    let (mut ws_write, mut ws_read) = ws_stream.split();
563
564    if let Ok((cols, rows)) = terminal::size() {
565        let size_msg = serde_json::json!(["set_size", rows, cols]).to_string();
566        let _ = ws_write
567            .send(tungstenite::Message::Text(size_msg.into()))
568            .await;
569    }
570
571    let id = uuid::Uuid::new_v4().simple().to_string();
572    // marker = 0x01 0x02 colab_<phase>_<uuid> 0x03 0x04. unlikely to collide
573    // with user output, and the literal `\001\002...` chars in the wrapper
574    // command (what shows up in the PTY echo) decode to different bytes.
575    let start_marker: Vec<u8> = {
576        let mut v = vec![0x01u8, 0x02];
577        v.extend_from_slice(format!("colab_start_{id}").as_bytes());
578        v.extend_from_slice(&[0x03, 0x04]);
579        v
580    };
581    let end_marker: Vec<u8> = {
582        let mut v = vec![0x01u8, 0x02];
583        v.extend_from_slice(format!("colab_end_{id}").as_bytes());
584        v.extend_from_slice(&[0x03, 0x04]);
585        v
586    };
587
588    let user_cmd = argv
589        .iter()
590        .map(|s| shell_quote(s))
591        .collect::<Vec<_>>()
592        .join(" ");
593
594    // braces isolate $? for the user command. stderr→stdout because the
595    // jupyter terminal only gives us one fd back.
596    let wrapped = format!(
597        "printf '\\001\\002colab_start_{id}\\003\\004\\n'; \
598         {{ {user_cmd}; }} 2>&1; __colab_ec=$?; \
599         printf '\\001\\002colab_end_{id}\\003\\004%d\\n' \"$__colab_ec\"\n"
600    );
601
602    ws_write
603        .send(tungstenite::Message::Text(
604            serde_json::json!(["stdin", wrapped]).to_string().into(),
605        ))
606        .await
607        .map_err(|e| ColabError::oauth(format!("WebSocket send: {e}")))?;
608
609    enum Phase {
610        Pre,
611        Mid,
612        Done,
613    }
614    let mut phase = Phase::Pre;
615    let mut buf: Vec<u8> = Vec::new();
616    let mut tail_after_end: Vec<u8> = Vec::new();
617    let mut exit_code: i32 = 0;
618
619    'outer: loop {
620        let msg = match ws_read.next().await {
621            Some(m) => m,
622            None => break,
623        };
624        let text = match msg {
625            Ok(tungstenite::Message::Text(t)) => t,
626            Ok(tungstenite::Message::Close(_)) => break,
627            Err(_) => break,
628            _ => continue,
629        };
630        let Some(chunk) = parse_stdout_frame(text.as_ref()) else {
631            continue;
632        };
633        let chunk_bytes = chunk.as_bytes();
634
635        // after END, everything is exit-code digits — skip the scanner
636        if matches!(phase, Phase::Done) {
637            tail_after_end.extend_from_slice(chunk_bytes);
638            if parse_exit_code(&tail_after_end).is_some() {
639                break;
640            }
641            continue;
642        }
643
644        buf.extend_from_slice(chunk_bytes);
645
646        loop {
647            match phase {
648                Phase::Pre => {
649                    if let Some(idx) = find_subseq(&buf, &start_marker) {
650                        let after = idx + start_marker.len();
651                        let after = skip_one_newline(&buf, after);
652                        buf.drain(..after);
653                        phase = Phase::Mid;
654                        continue;
655                    }
656                    // hold onto marker_len-1 bytes in case it straddles chunks
657                    let keep = start_marker.len().saturating_sub(1);
658                    if buf.len() > keep {
659                        buf.drain(..buf.len() - keep);
660                    }
661                    break;
662                }
663                Phase::Mid => {
664                    if let Some(idx) = find_subseq(&buf, &end_marker) {
665                        if idx > 0 {
666                            let mut stdout = io::stdout().lock();
667                            let _ = stdout.write_all(&buf[..idx]);
668                            let _ = stdout.flush();
669                        }
670                        let after = idx + end_marker.len();
671                        tail_after_end.extend_from_slice(&buf[after..]);
672                        buf.clear();
673                        phase = Phase::Done;
674                        continue;
675                    }
676                    // flush all but the last marker_len-1 bytes (might be a partial END)
677                    let keep = end_marker.len().saturating_sub(1);
678                    if buf.len() > keep {
679                        let flush_to = buf.len() - keep;
680                        let mut stdout = io::stdout().lock();
681                        let _ = stdout.write_all(&buf[..flush_to]);
682                        let _ = stdout.flush();
683                        buf.drain(..flush_to);
684                    }
685                    break;
686                }
687                Phase::Done => break,
688            }
689        }
690
691        if matches!(phase, Phase::Done) && parse_exit_code(&tail_after_end).is_some() {
692            break 'outer;
693        }
694    }
695
696    // ws closed mid-stream — flush whatever's left so we don't drop output
697    if matches!(phase, Phase::Mid) && !buf.is_empty() {
698        let mut stdout = io::stdout().lock();
699        let _ = stdout.write_all(&buf);
700        let _ = stdout.flush();
701    }
702
703    if let Some(code) = parse_exit_code(&tail_after_end) {
704        exit_code = code;
705    }
706
707    Ok(exit_code)
708}
709
/// Byte-for-byte substring search.
///
/// Returns the index of the first occurrence of `needle` in `haystack`, or
/// `None` if it does not occur. An empty `needle` never matches.
fn find_subseq(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    match needle.len() {
        0 => None,
        n if n > haystack.len() => None,
        n => haystack.windows(n).position(|window| window == needle),
    }
}
717
/// Returns the index just past a single line break starting at `idx`.
///
/// Used to eat one `\n`, `\r\n` or lone `\r` after the START marker so it does
/// not show up in the output. If `idx` is out of range or not at a line break,
/// `idx` is returned unchanged.
fn skip_one_newline(buf: &[u8], idx: usize) -> usize {
    let width = match buf.get(idx..) {
        Some([b'\r', b'\n', ..]) => 2,
        Some([b'\r' | b'\n', ..]) => 1,
        _ => 0,
    };
    idx + width
}
734
/// Parses the exit code that follows the END marker.
///
/// Leading `\r`, `\n` and spaces are skipped; then a run of ASCII digits must
/// be followed by at least one non-digit byte. Returns `None` while that
/// terminator has not arrived yet (so the caller can wait for more bytes),
/// when the first significant byte is not a digit, or when the number does not
/// fit in an `i32`.
fn parse_exit_code(buf: &[u8]) -> Option<i32> {
    let first = buf
        .iter()
        .position(|b| !matches!(b, b'\r' | b'\n' | b' '))?;
    let rest = &buf[first..];

    // position of the terminator == number of digits; absent terminator → None
    let digit_count = rest.iter().position(|b| !b.is_ascii_digit())?;
    if digit_count == 0 {
        return None;
    }

    std::str::from_utf8(&rest[..digit_count])
        .ok()?
        .parse::<i32>()
        .ok()
}
753
/// POSIX single-quotes `s` for safe embedding in `sh -c`.
///
/// The result is wrapped in single quotes; every embedded `'` is written as
/// `'"'"'` (close the quote, emit a double-quoted `'`, reopen the quote).
pub fn shell_quote(s: &str) -> String {
    format!("'{}'", s.replace('\'', "'\"'\"'"))
}
768
769fn parse_stdout_frame(text: &str) -> Option<String> {
770    let arr: Vec<serde_json::Value> = serde_json::from_str(text).ok()?;
771    if arr.len() >= 2 && arr[0].as_str() == Some("stdout") {
772        arr[1].as_str().map(|s| s.to_string())
773    } else {
774        None
775    }
776}
777
778struct RawModeGuard;
779
780impl Drop for RawModeGuard {
781    fn drop(&mut self) {
782        let _ = terminal::disable_raw_mode();
783    }
784}
785
#[cfg(test)]
mod tests {
    use super::*;

    // Each test walks a small table of (input, expected) pairs.

    #[test]
    fn host_from_url_strips_scheme_and_path() {
        let cases = [
            (
                "wss://abc.proxy.googleusercontent.com/terminals/websocket/1",
                "abc.proxy.googleusercontent.com",
            ),
            ("ws://localhost:9000/foo", "localhost:9000"),
        ];
        for (url, expected) in cases {
            assert_eq!(host_from_url(url), expected, "url: {url}");
        }
    }

    #[test]
    fn host_from_url_no_path() {
        let host = host_from_url("wss://example.com");
        assert_eq!(host, "example.com");
    }

    #[test]
    fn shell_quote_plain() {
        let quoted = shell_quote("/content/drive");
        assert_eq!(quoted, "'/content/drive'");
    }

    #[test]
    fn shell_quote_with_embedded_single_quote() {
        let quoted = shell_quote("it's/here");
        assert_eq!(quoted, "'it'\"'\"'s/here'");
    }

    #[test]
    fn shell_quote_empty() {
        let quoted = shell_quote("");
        assert_eq!(quoted, "''");
    }

    #[test]
    fn find_subseq_basic() {
        let cases = [
            (&b"hello world"[..], &b"world"[..], Some(6)),
            (&b"hello world"[..], &b"xyz"[..], None),
            (&b""[..], &b"x"[..], None),
            (&b"abc"[..], &b""[..], None),
        ];
        for (haystack, needle, expected) in cases {
            assert_eq!(
                find_subseq(haystack, needle),
                expected,
                "haystack: {haystack:?}, needle: {needle:?}"
            );
        }
    }

    #[test]
    fn skip_one_newline_handles_lf_and_crlf() {
        let cases = [
            (&b"\nrest"[..], 1),
            (&b"\r\nrest"[..], 2),
            (&b"\rrest"[..], 1),
            (&b"rest"[..], 0),
        ];
        for (input, expected) in cases {
            assert_eq!(skip_one_newline(input, 0), expected, "input: {input:?}");
        }
    }

    #[test]
    fn parse_exit_code_simple() {
        let cases = [(&b"0\n"[..], 0), (&b"1\n"[..], 1), (&b"127\n"[..], 127)];
        for (input, code) in cases {
            assert_eq!(parse_exit_code(input), Some(code), "input: {input:?}");
        }
    }

    #[test]
    fn parse_exit_code_with_whitespace_prefix() {
        let cases = [(&b"\r\n42\n"[..], 42), (&b"  3 "[..], 3)];
        for (input, code) in cases {
            assert_eq!(parse_exit_code(input), Some(code), "input: {input:?}");
        }
    }

    #[test]
    fn parse_exit_code_incomplete_returns_none() {
        // Digits with no terminator yet — the streamer needs more bytes.
        for input in [&b"12"[..], &b""[..], &b"\r\n"[..]] {
            assert_eq!(parse_exit_code(input), None, "input: {input:?}");
        }
    }

    #[test]
    fn parse_exit_code_garbage() {
        let parsed = parse_exit_code(b"abc");
        assert_eq!(parsed, None);
    }
}