Skip to main content

heyo_sdk/
shell.rs

//! Persistent shell session over a WebSocket. Mirrors the protocol documented
//! in `sdk-ts/src/shell.ts` and implemented by
//! `mvm-ctrl/src/api.rs::shell_stream_ws`.
//!
//! Wire protocol:
//!
//! - init (client → server, first frame): JSON
//!   `{type:"init", cols, rows, env?, cwd?, sessionId?}`
//! - ready (server → client): JSON `{type:"ready", sessionId, lastSeq?}`
//! - stdin (client → server): binary `[0x01, ...bytes]`
//! - stdout (server → client): binary `[0x02, seq:u64-be, ...bytes]`
//!   (PTY merges stdout/stderr)
//! - resize (client → server): JSON `{type:"resize", cols, rows}`
//! - ack (client → server): JSON `{type:"ack", seq}`
//! - close (client → server): JSON `{type:"close"}`
//! - exit (server → client): JSON `{type:"exit", code}`
//! - error (server → client): JSON `{type:"error", code?, message}`
17
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::Duration;
21
22use futures_util::stream::Stream;
23use futures_util::{SinkExt, StreamExt};
24use http::Request;
25use serde::Deserialize;
26use serde_json::json;
27use tokio::sync::{mpsc, oneshot, watch, Mutex};
28use tokio::time::sleep;
29use tokio_tungstenite::tungstenite::Message;
30
31use crate::client::HeyoClient;
32use crate::commands::encode_path;
33use crate::errors::HeyoError;
34
/// Tag byte placed in front of every binary stdin frame sent to the server.
const FRAME_STDIN: u8 = 0x01;
/// Tag byte expected at the start of every binary stdout frame from the server.
const FRAME_STDOUT: u8 = 0x02;
/// Period of the timer that checks whether an `ack` frame is owed.
const ACK_INTERVAL: Duration = Duration::from_millis(100);
/// If no frame arrives for this long the socket is closed by the client.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(60);
39
/// Tuning for automatic reconnection of a shell session.
#[derive(Debug, Clone)]
pub struct ShellReconnectOptions {
    /// How many failed reconnect attempts in a row are tolerated before the
    /// session gives up.
    pub max_retries: u32,
    /// Delay before the first retry; it doubles with each further attempt.
    pub base_delay: Duration,
    /// Upper bound applied to the computed retry delay.
    pub max_delay: Duration,
}

impl Default for ShellReconnectOptions {
    /// Five retries, starting at 100 ms and capped at 30 s.
    fn default() -> Self {
        let base_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(30);
        Self {
            max_retries: 5,
            base_delay,
            max_delay,
        }
    }
}
56
57/// Options for [`crate::Sandbox::shell`].
58#[derive(Debug, Clone)]
59pub struct ShellOptions {
60    pub cwd: Option<String>,
61    pub env: Option<HashMap<String, String>>,
62    pub cols: u16,
63    pub rows: u16,
64    /// `None` disables auto-reconnect; `Some(_)` uses the supplied tuning.
65    pub reconnect: Option<ShellReconnectOptions>,
66}
67
68impl Default for ShellOptions {
69    fn default() -> Self {
70        Self {
71            cwd: None,
72            env: None,
73            cols: 80,
74            rows: 24,
75            reconnect: Some(ShellReconnectOptions::default()),
76        }
77    }
78}
79
/// Lifecycle events emitted on the session's event stream.
#[derive(Debug, Clone)]
pub enum ShellEvent {
    /// The connection was lost; retry number `attempt` will start after
    /// `delay` has elapsed.
    Reconnecting { attempt: u32, delay: Duration },
    /// The server sent `ready` again for a session that already had an id.
    Reconnected,
    /// The session task has finished. `exit_code` is set when the server
    /// reported one through an `exit` frame.
    Closed { exit_code: Option<i32> },
    /// A server-reported or connection-level error, as a display string.
    Error(String),
}
88
89/// A live shell session. Use `output()` for stdout chunks, `events()` for
90/// lifecycle events, and `write()` / `resize()` / `close()` to drive it.
91pub struct ShellSession {
92    inner: Arc<SessionInner>,
93    output_rx: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
94    events_rx: Arc<Mutex<mpsc::Receiver<ShellEvent>>>,
95}
96
97struct SessionInner {
98    write_tx: mpsc::UnboundedSender<OutboundMessage>,
99    session_id: Mutex<Option<String>>,
100    closed_tx: watch::Sender<bool>,
101    closed_rx: watch::Receiver<bool>,
102    exit_code: Mutex<Option<i32>>,
103}
104
105enum OutboundMessage {
106    Stdin(Vec<u8>),
107    Json(serde_json::Value),
108    Close,
109}
110
111#[derive(Deserialize)]
112#[serde(tag = "type", rename_all = "lowercase")]
113enum ServerControl {
114    Ready {
115        #[serde(rename = "sessionId")]
116        session_id: String,
117        #[serde(default, rename = "lastSeq")]
118        #[allow(dead_code)]
119        last_seq: Option<u64>,
120    },
121    Exit {
122        code: i32,
123    },
124    Error {
125        #[serde(default)]
126        code: Option<String>,
127        #[serde(default)]
128        message: Option<String>,
129    },
130}
131
132impl ShellSession {
133    /// Open and wait for the first `ready` frame.
134    pub(crate) async fn open(
135        client: HeyoClient,
136        sandbox_id: String,
137        options: ShellOptions,
138    ) -> Result<Self, HeyoError> {
139        let path = format!("/deployed-sandboxes/{}/shell-stream", encode_path(&sandbox_id));
140        let url = client.ws_url(&path)?;
141        let auth = client.ws_authorization();
142        let (output_tx, output_rx) = mpsc::channel::<Vec<u8>>(256);
143        let (events_tx, events_rx) = mpsc::channel::<ShellEvent>(64);
144        let (write_tx, write_rx) = mpsc::unbounded_channel::<OutboundMessage>();
145        let (closed_tx, closed_rx) = watch::channel(false);
146        let (ready_tx, ready_rx) = oneshot::channel::<Result<String, HeyoError>>();
147        let inner = Arc::new(SessionInner {
148            write_tx,
149            session_id: Mutex::new(None),
150            closed_tx,
151            closed_rx,
152            exit_code: Mutex::new(None),
153        });
154
155        let reconnect = options.reconnect.clone();
156        let init_opts = options;
157        let url_clone = url.clone();
158        let auth_clone = auth.clone();
159        let inner_for_task = inner.clone();
160        let output_tx_clone = output_tx.clone();
161        let events_tx_clone = events_tx.clone();
162        tokio::spawn(run_session(
163            url_clone,
164            auth_clone,
165            init_opts,
166            reconnect,
167            inner_for_task,
168            write_rx,
169            output_tx_clone,
170            events_tx_clone,
171            Some(ready_tx),
172        ));
173        drop(output_tx);
174        drop(events_tx);
175
176        match ready_rx.await {
177            Ok(Ok(_session_id)) => Ok(ShellSession {
178                inner,
179                output_rx: Arc::new(Mutex::new(output_rx)),
180                events_rx: Arc::new(Mutex::new(events_rx)),
181            }),
182            Ok(Err(e)) => Err(e),
183            Err(_) => Err(HeyoError::Connection("shell session task dropped before ready".into())),
184        }
185    }
186
187    /// Latest server-assigned session id (set after `open()` resolves; may
188    /// change across reconnects only if the server reissues one — usually it
189    /// does not).
190    pub async fn session_id(&self) -> Option<String> {
191        self.inner.session_id.lock().await.clone()
192    }
193
194    pub async fn write(&self, bytes: &[u8]) -> Result<(), HeyoError> {
195        if *self.inner.closed_rx.borrow() {
196            return Err(HeyoError::Connection("session is closed".into()));
197        }
198        self.inner
199            .write_tx
200            .send(OutboundMessage::Stdin(bytes.to_vec()))
201            .map_err(|_| HeyoError::Connection("session writer is closed".into()))
202    }
203
204    pub async fn resize(&self, cols: u16, rows: u16) -> Result<(), HeyoError> {
205        if *self.inner.closed_rx.borrow() {
206            return Err(HeyoError::Connection("session is closed".into()));
207        }
208        self.inner
209            .write_tx
210            .send(OutboundMessage::Json(json!({
211                "type": "resize",
212                "cols": cols,
213                "rows": rows,
214            })))
215            .map_err(|_| HeyoError::Connection("session writer is closed".into()))
216    }
217
218    /// Send the close frame, then tear down the socket once the server
219    /// acknowledges (or after 2s).
220    pub async fn close(&self) -> Result<(), HeyoError> {
221        if *self.inner.closed_rx.borrow() {
222            return Ok(());
223        }
224        let _ = self.inner.write_tx.send(OutboundMessage::Close);
225        // Wait until either the closed flag flips, or we time out.
226        let mut rx = self.inner.closed_rx.clone();
227        let _ = tokio::time::timeout(Duration::from_secs(2), async {
228            while !*rx.borrow_and_update() {
229                if rx.changed().await.is_err() {
230                    break;
231                }
232            }
233        })
234        .await;
235        Ok(())
236    }
237
238    pub async fn exit_code(&self) -> Option<i32> {
239        *self.inner.exit_code.lock().await
240    }
241
242    pub fn is_closed(&self) -> bool {
243        *self.inner.closed_rx.borrow()
244    }
245
246    /// Stream of stdout chunks. Yields until the session closes.
247    pub fn output(&self) -> impl Stream<Item = Vec<u8>> + Send + Unpin {
248        let rx = self.output_rx.clone();
249        Box::pin(async_stream::stream! {
250            loop {
251                let mut guard = rx.lock().await;
252                match guard.recv().await {
253                    Some(chunk) => yield chunk,
254                    None => break,
255                }
256            }
257        })
258    }
259
260    /// Stream of lifecycle events. Yields until the session closes.
261    pub fn events(&self) -> impl Stream<Item = ShellEvent> + Send + Unpin {
262        let rx = self.events_rx.clone();
263        Box::pin(async_stream::stream! {
264            loop {
265                let mut guard = rx.lock().await;
266                match guard.recv().await {
267                    Some(event) => yield event,
268                    None => break,
269                }
270            }
271        })
272    }
273}
274
275#[allow(clippy::too_many_arguments)]
276async fn run_session(
277    url: String,
278    auth: String,
279    options: ShellOptions,
280    reconnect: Option<ShellReconnectOptions>,
281    inner: Arc<SessionInner>,
282    mut write_rx: mpsc::UnboundedReceiver<OutboundMessage>,
283    output_tx: mpsc::Sender<Vec<u8>>,
284    events_tx: mpsc::Sender<ShellEvent>,
285    mut ready_tx: Option<oneshot::Sender<Result<String, HeyoError>>>,
286) {
287    let mut attempt: u32 = 0;
288    let mut last_seq_received: u64 = 0;
289    let mut last_seq_acked: u64 = 0;
290    let cols = options.cols;
291    let rows = options.rows;
292
293    'outer: loop {
294        let session_id_opt = inner.session_id.lock().await.clone();
295        let reconnecting = session_id_opt.is_some();
296
297        // Build the WS upgrade request with bearer auth.
298        let request = match Request::builder()
299            .method("GET")
300            .uri(&url)
301            .header("Authorization", &auth)
302            .header("Sec-WebSocket-Version", "13")
303            .header("Sec-WebSocket-Key", tokio_tungstenite::tungstenite::handshake::client::generate_key())
304            .header("Connection", "Upgrade")
305            .header("Upgrade", "websocket")
306            .header("Host", host_from_url(&url))
307            .body(())
308        {
309            Ok(r) => r,
310            Err(e) => {
311                let err = HeyoError::Connection(format!("build ws request: {}", e));
312                if let Some(tx) = ready_tx.take() {
313                    let _ = tx.send(Err(err));
314                    return;
315                }
316                let _ = events_tx.send(ShellEvent::Error(format!("{}", e))).await;
317                break;
318            }
319        };
320
321        let connect_result = tokio_tungstenite::connect_async(request).await;
322        let (ws_stream, _) = match connect_result {
323            Ok(s) => s,
324            Err(e) => {
325                if reconnecting {
326                    if let Some(r) = reconnect.as_ref() {
327                        attempt += 1;
328                        if attempt > r.max_retries {
329                            let err = HeyoError::Connection(format!(
330                                "gave up after {} reconnect attempts: {}",
331                                r.max_retries, e
332                            ));
333                            let _ = events_tx.send(ShellEvent::Error(err.to_string())).await;
334                            break;
335                        }
336                        let delay = backoff(r, attempt);
337                        let _ = events_tx
338                            .send(ShellEvent::Reconnecting { attempt, delay })
339                            .await;
340                        sleep(delay).await;
341                        continue;
342                    }
343                }
344                let err = HeyoError::Connection(format!("ws connect: {}", e));
345                if let Some(tx) = ready_tx.take() {
346                    let _ = tx.send(Err(err));
347                    return;
348                }
349                let _ = events_tx.send(ShellEvent::Error(format!("{}", e))).await;
350                break;
351            }
352        };
353
354        let (mut ws_tx, mut ws_rx) = ws_stream.split();
355
356        // Send init frame.
357        let mut init = serde_json::Map::new();
358        init.insert("type".into(), json!("init"));
359        init.insert("cols".into(), json!(cols));
360        init.insert("rows".into(), json!(rows));
361        if let Some(env) = &options.env {
362            init.insert("env".into(), json!(env));
363        }
364        if let Some(cwd) = &options.cwd {
365            init.insert("cwd".into(), json!(cwd));
366        }
367        if let Some(sid) = &session_id_opt {
368            init.insert("sessionId".into(), json!(sid));
369        }
370        if let Err(e) = ws_tx
371            .send(Message::Text(serde_json::Value::Object(init).to_string()))
372            .await
373        {
374            let msg = format!("send init: {}", e);
375            if let Some(tx) = ready_tx.take() {
376                let _ = tx.send(Err(HeyoError::Connection(msg.clone())));
377                return;
378            }
379            let _ = events_tx.send(ShellEvent::Error(msg)).await;
380            continue;
381        }
382
383        let mut got_ready = false;
384        let mut ack_due = false;
385        let mut graceful_close = false;
386        // ack timer
387        let mut ack_timer = tokio::time::interval(ACK_INTERVAL);
388        ack_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
389        // heartbeat timer — bumped on every received frame
390        let mut heartbeat = Box::pin(sleep(HEARTBEAT_TIMEOUT));
391
392        loop {
393            tokio::select! {
394                _ = ack_timer.tick() => {
395                    if ack_due && last_seq_received > last_seq_acked {
396                        last_seq_acked = last_seq_received;
397                        if ws_tx
398                            .send(Message::Text(json!({
399                                "type": "ack",
400                                "seq": last_seq_acked,
401                            }).to_string()))
402                            .await
403                            .is_err()
404                        {
405                            // Socket dropped; the read side will surface it.
406                        }
407                        ack_due = false;
408                    }
409                }
410                _ = &mut heartbeat => {
411                    let _ = ws_tx.send(Message::Close(None)).await;
412                    break;
413                }
414                Some(msg) = write_rx.recv() => {
415                    match msg {
416                        OutboundMessage::Stdin(bytes) => {
417                            let mut frame = Vec::with_capacity(bytes.len() + 1);
418                            frame.push(FRAME_STDIN);
419                            frame.extend_from_slice(&bytes);
420                            if ws_tx.send(Message::Binary(frame)).await.is_err() {
421                                break;
422                            }
423                        }
424                        OutboundMessage::Json(v) => {
425                            if ws_tx.send(Message::Text(v.to_string())).await.is_err() {
426                                break;
427                            }
428                        }
429                        OutboundMessage::Close => {
430                            graceful_close = true;
431                            let _ = ws_tx
432                                .send(Message::Text(json!({"type":"close"}).to_string()))
433                                .await;
434                            // Give the server 2s grace to send `exit`.
435                            let _ = tokio::time::timeout(Duration::from_secs(2), async {
436                                while let Some(msg) = ws_rx.next().await {
437                                    match msg {
438                                        Ok(Message::Text(t)) => {
439                                            if handle_control(
440                                                &t,
441                                                &inner,
442                                                &output_tx,
443                                                &events_tx,
444                                                &mut ready_tx,
445                                                &mut got_ready,
446                                            )
447                                            .await
448                                            .is_break()
449                                            {
450                                                break;
451                                            }
452                                        }
453                                        _ => {}
454                                    }
455                                }
456                            })
457                            .await;
458                            break;
459                        }
460                    }
461                }
462                Some(msg) = ws_rx.next() => {
463                    heartbeat = Box::pin(sleep(HEARTBEAT_TIMEOUT));
464                    match msg {
465                        Ok(Message::Text(t)) => {
466                            if handle_control(
467                                &t,
468                                &inner,
469                                &output_tx,
470                                &events_tx,
471                                &mut ready_tx,
472                                &mut got_ready,
473                            )
474                            .await
475                            .is_break()
476                            {
477                                break;
478                            }
479                            if got_ready {
480                                attempt = 0;
481                            }
482                        }
483                        Ok(Message::Binary(bytes)) => {
484                            if let Some(seq) = handle_binary(
485                                &bytes,
486                                last_seq_received,
487                                &output_tx,
488                            )
489                            .await
490                            {
491                                last_seq_received = seq;
492                                ack_due = true;
493                            }
494                        }
495                        Ok(Message::Close(_)) => break,
496                        Ok(_) => {}
497                        Err(_) => break,
498                    }
499                }
500                else => break,
501            }
502        }
503
504        // Socket closed. Decide if we should reconnect.
505        let exit_set = inner.exit_code.lock().await.is_some();
506        if graceful_close || exit_set {
507            break;
508        }
509        let sid = inner.session_id.lock().await.clone();
510        match (sid, reconnect.as_ref()) {
511            (Some(_), Some(r)) => {
512                attempt += 1;
513                if attempt > r.max_retries {
514                    let err = HeyoError::Connection(format!(
515                        "gave up after {} reconnect attempts",
516                        r.max_retries
517                    ));
518                    let _ = events_tx.send(ShellEvent::Error(err.to_string())).await;
519                    break 'outer;
520                }
521                let delay = backoff(r, attempt);
522                let _ = events_tx
523                    .send(ShellEvent::Reconnecting { attempt, delay })
524                    .await;
525                sleep(delay).await;
526            }
527            _ => {
528                // First-connect failure that closed without `ready`: surface
529                // to the open() waiter if it's still around.
530                if let Some(tx) = ready_tx.take() {
531                    let _ = tx.send(Err(HeyoError::Connection(
532                        "shell-stream socket closed before ready".into(),
533                    )));
534                }
535                break;
536            }
537        }
538    }
539
540    let _ = inner.closed_tx.send(true);
541    let exit = *inner.exit_code.lock().await;
542    let _ = events_tx.send(ShellEvent::Closed { exit_code: exit }).await;
543}
544
545async fn handle_control(
546    text: &str,
547    inner: &Arc<SessionInner>,
548    _output_tx: &mpsc::Sender<Vec<u8>>,
549    events_tx: &mpsc::Sender<ShellEvent>,
550    ready_tx: &mut Option<oneshot::Sender<Result<String, HeyoError>>>,
551    got_ready: &mut bool,
552) -> std::ops::ControlFlow<()> {
553    let msg: ServerControl = match serde_json::from_str(text) {
554        Ok(m) => m,
555        Err(_) => return std::ops::ControlFlow::Continue(()),
556    };
557    match msg {
558        ServerControl::Ready { session_id, last_seq: _ } => {
559            let was_reconnect;
560            {
561                let mut guard = inner.session_id.lock().await;
562                was_reconnect = guard.is_some();
563                *guard = Some(session_id.clone());
564            }
565            if let Some(tx) = ready_tx.take() {
566                let _ = tx.send(Ok(session_id));
567            }
568            *got_ready = true;
569            if was_reconnect {
570                let _ = events_tx.send(ShellEvent::Reconnected).await;
571            }
572            std::ops::ControlFlow::Continue(())
573        }
574        ServerControl::Exit { code } => {
575            *inner.exit_code.lock().await = Some(code);
576            std::ops::ControlFlow::Break(())
577        }
578        ServerControl::Error { code, message } => {
579            let err_msg = message.unwrap_or_else(|| "shell-stream server error".to_string());
580            let _ = events_tx.send(ShellEvent::Error(err_msg)).await;
581            if code.as_deref() == Some("session_expired") {
582                let _ = inner.closed_tx.send(true);
583                std::ops::ControlFlow::Break(())
584            } else {
585                std::ops::ControlFlow::Continue(())
586            }
587        }
588    }
589}
590
591async fn handle_binary(
592    bytes: &[u8],
593    last_seq_received: u64,
594    output_tx: &mpsc::Sender<Vec<u8>>,
595) -> Option<u64> {
596    if bytes.is_empty() || bytes[0] != FRAME_STDOUT || bytes.len() < 9 {
597        return None;
598    }
599    let mut seq_bytes = [0u8; 8];
600    seq_bytes.copy_from_slice(&bytes[1..9]);
601    let seq = u64::from_be_bytes(seq_bytes);
602    if seq <= last_seq_received {
603        return None;
604    }
605    let _ = output_tx.send(bytes[9..].to_vec()).await;
606    Some(seq)
607}
608
609fn backoff(r: &ShellReconnectOptions, attempt: u32) -> Duration {
610    let pow = (attempt.saturating_sub(1)).min(30);
611    let factor: u64 = 1u64.checked_shl(pow).unwrap_or(u64::MAX);
612    let raw_ms = (r.base_delay.as_millis() as u64).saturating_mul(factor);
613    let cap_ms = r.max_delay.as_millis() as u64;
614    Duration::from_millis(raw_ms.min(cap_ms))
615}
616
617fn host_from_url(url: &str) -> String {
618    url::Url::parse(url)
619        .ok()
620        .and_then(|u| {
621            let host = u.host_str()?.to_string();
622            if let Some(port) = u.port() {
623                Some(format!("{}:{}", host, port))
624            } else {
625                Some(host)
626            }
627        })
628        .unwrap_or_default()
629}