1use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::Duration;
21
22use futures_util::stream::Stream;
23use futures_util::{SinkExt, StreamExt};
24use http::Request;
25use serde::Deserialize;
26use serde_json::json;
27use tokio::sync::{mpsc, oneshot, watch, Mutex};
28use tokio::time::sleep;
29use tokio_tungstenite::tungstenite::Message;
30
31use crate::client::HeyoClient;
32use crate::commands::encode_path;
33use crate::errors::HeyoError;
34
/// Leading byte of a binary frame that carries stdin bytes to the server.
const FRAME_STDIN: u8 = 1;
/// Leading byte of a binary frame that carries stdout from the server; it is
/// followed by an 8-byte big-endian sequence number and then the payload.
const FRAME_STDOUT: u8 = 2;
/// Period of the timer that flushes pending output acknowledgements.
const ACK_INTERVAL: Duration = Duration::from_millis(100);
/// An open connection is torn down if no message of any kind arrives for this long.
const HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(60_000);
39
/// Automatic-reconnect policy for a shell session.
///
/// When a session that already has a server-assigned id loses its socket, the
/// session task waits before each retry. The wait starts at `base_delay`, doubles
/// with every consecutive retry and never exceeds `max_delay`. The task gives up
/// once more than `max_retries` retries have been made in a row.
#[derive(Debug, Clone)]
pub struct ShellReconnectOptions {
    /// Consecutive reconnect attempts allowed before the session is abandoned.
    pub max_retries: u32,
    /// Delay before the first retry; doubled for each further consecutive retry.
    pub base_delay: Duration,
    /// Upper bound on any single retry delay.
    pub max_delay: Duration,
}

impl Default for ShellReconnectOptions {
    /// Five retries, starting at 100 ms and capped at 30 s.
    fn default() -> Self {
        let base_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(30);
        ShellReconnectOptions {
            max_retries: 5,
            base_delay,
            max_delay,
        }
    }
}
56
57#[derive(Debug, Clone)]
59pub struct ShellOptions {
60 pub cwd: Option<String>,
61 pub env: Option<HashMap<String, String>>,
62 pub cols: u16,
63 pub rows: u16,
64 pub reconnect: Option<ShellReconnectOptions>,
66}
67
68impl Default for ShellOptions {
69 fn default() -> Self {
70 Self {
71 cwd: None,
72 env: None,
73 cols: 80,
74 rows: 24,
75 reconnect: Some(ShellReconnectOptions::default()),
76 }
77 }
78}
79
/// Lifecycle notifications delivered through [`ShellSession::events`].
#[derive(Debug, Clone)]
pub enum ShellEvent {
    /// The connection dropped and reconnect attempt number `attempt` will start
    /// after `delay` (sent just before the backoff sleep).
    Reconnecting { attempt: u32, delay: Duration },
    /// A resumed connection received `ready` from the server again.
    Reconnected,
    /// The session task has finished; `exit_code` is the code from the server's
    /// `exit` message, or `None` if none was received.
    Closed { exit_code: Option<i32> },
    /// A server-reported or connection-level error. This is not necessarily
    /// terminal; a `Closed` event follows when the session actually ends.
    Error(String),
}
88
/// Handle to an interactive shell running in a deployed sandbox.
///
/// The WebSocket is owned by a background task. This handle only queues outbound
/// messages and reads the output/event channels that task feeds.
pub struct ShellSession {
    // State shared with the background task.
    inner: Arc<SessionInner>,
    // Shared behind a mutex so every stream returned by `output()` reads the same
    // receiver; each chunk goes to exactly one consumer.
    output_rx: Arc<Mutex<mpsc::Receiver<Vec<u8>>>>,
    // Same sharing scheme for the streams returned by `events()`.
    events_rx: Arc<Mutex<mpsc::Receiver<ShellEvent>>>,
}
96
/// State shared between a [`ShellSession`] handle and its background task.
struct SessionInner {
    // Queue of stdin bytes, JSON control messages and close requests for the task.
    write_tx: mpsc::UnboundedSender<OutboundMessage>,
    // Id from the server's most recent `ready`; when set, new connections resume it.
    session_id: Mutex<Option<String>>,
    // Closed flag: raised when the task ends or the server reports `session_expired`.
    closed_tx: watch::Sender<bool>,
    closed_rx: watch::Receiver<bool>,
    // Code from the server's `exit` message, once received.
    exit_code: Mutex<Option<i32>>,
}
104
/// Requests from the session handle to the background task.
enum OutboundMessage {
    // Raw stdin bytes; sent as a binary frame prefixed with `FRAME_STDIN`.
    Stdin(Vec<u8>),
    // A JSON control message (e.g. `resize`); sent as a text frame.
    Json(serde_json::Value),
    // Graceful shutdown: sends `{"type":"close"}` and briefly drains replies.
    Close,
}
110
/// JSON control messages the server sends as text frames, discriminated by the
/// `"type"` field (`"ready"`, `"exit"` or `"error"`).
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
enum ServerControl {
    /// The shell is attached; `session_id` is what later connections use to resume.
    Ready {
        #[serde(rename = "sessionId")]
        session_id: String,
        // Parsed for protocol completeness but not consumed by `handle_control`
        // yet, hence the dead_code allowance.
        #[serde(default, rename = "lastSeq")]
        #[allow(dead_code)]
        last_seq: Option<u64>,
    },
    /// The shell process exited with `code`.
    Exit {
        code: i32,
    },
    /// A server-side error. Both fields are optional; `handle_control` treats
    /// `code == "session_expired"` as fatal for the session.
    Error {
        #[serde(default)]
        code: Option<String>,
        #[serde(default)]
        message: Option<String>,
    },
}
131
132impl ShellSession {
133 pub(crate) async fn open(
135 client: HeyoClient,
136 sandbox_id: String,
137 options: ShellOptions,
138 ) -> Result<Self, HeyoError> {
139 let path = format!("/deployed-sandboxes/{}/shell-stream", encode_path(&sandbox_id));
140 let url = client.ws_url(&path)?;
141 let auth = client.ws_authorization();
142 let (output_tx, output_rx) = mpsc::channel::<Vec<u8>>(256);
143 let (events_tx, events_rx) = mpsc::channel::<ShellEvent>(64);
144 let (write_tx, write_rx) = mpsc::unbounded_channel::<OutboundMessage>();
145 let (closed_tx, closed_rx) = watch::channel(false);
146 let (ready_tx, ready_rx) = oneshot::channel::<Result<String, HeyoError>>();
147 let inner = Arc::new(SessionInner {
148 write_tx,
149 session_id: Mutex::new(None),
150 closed_tx,
151 closed_rx,
152 exit_code: Mutex::new(None),
153 });
154
155 let reconnect = options.reconnect.clone();
156 let init_opts = options;
157 let url_clone = url.clone();
158 let auth_clone = auth.clone();
159 let inner_for_task = inner.clone();
160 let output_tx_clone = output_tx.clone();
161 let events_tx_clone = events_tx.clone();
162 tokio::spawn(run_session(
163 url_clone,
164 auth_clone,
165 init_opts,
166 reconnect,
167 inner_for_task,
168 write_rx,
169 output_tx_clone,
170 events_tx_clone,
171 Some(ready_tx),
172 ));
173 drop(output_tx);
174 drop(events_tx);
175
176 match ready_rx.await {
177 Ok(Ok(_session_id)) => Ok(ShellSession {
178 inner,
179 output_rx: Arc::new(Mutex::new(output_rx)),
180 events_rx: Arc::new(Mutex::new(events_rx)),
181 }),
182 Ok(Err(e)) => Err(e),
183 Err(_) => Err(HeyoError::Connection("shell session task dropped before ready".into())),
184 }
185 }
186
187 pub async fn session_id(&self) -> Option<String> {
191 self.inner.session_id.lock().await.clone()
192 }
193
194 pub async fn write(&self, bytes: &[u8]) -> Result<(), HeyoError> {
195 if *self.inner.closed_rx.borrow() {
196 return Err(HeyoError::Connection("session is closed".into()));
197 }
198 self.inner
199 .write_tx
200 .send(OutboundMessage::Stdin(bytes.to_vec()))
201 .map_err(|_| HeyoError::Connection("session writer is closed".into()))
202 }
203
204 pub async fn resize(&self, cols: u16, rows: u16) -> Result<(), HeyoError> {
205 if *self.inner.closed_rx.borrow() {
206 return Err(HeyoError::Connection("session is closed".into()));
207 }
208 self.inner
209 .write_tx
210 .send(OutboundMessage::Json(json!({
211 "type": "resize",
212 "cols": cols,
213 "rows": rows,
214 })))
215 .map_err(|_| HeyoError::Connection("session writer is closed".into()))
216 }
217
218 pub async fn close(&self) -> Result<(), HeyoError> {
221 if *self.inner.closed_rx.borrow() {
222 return Ok(());
223 }
224 let _ = self.inner.write_tx.send(OutboundMessage::Close);
225 let mut rx = self.inner.closed_rx.clone();
227 let _ = tokio::time::timeout(Duration::from_secs(2), async {
228 while !*rx.borrow_and_update() {
229 if rx.changed().await.is_err() {
230 break;
231 }
232 }
233 })
234 .await;
235 Ok(())
236 }
237
238 pub async fn exit_code(&self) -> Option<i32> {
239 *self.inner.exit_code.lock().await
240 }
241
242 pub fn is_closed(&self) -> bool {
243 *self.inner.closed_rx.borrow()
244 }
245
246 pub fn output(&self) -> impl Stream<Item = Vec<u8>> + Send + Unpin {
248 let rx = self.output_rx.clone();
249 Box::pin(async_stream::stream! {
250 loop {
251 let mut guard = rx.lock().await;
252 match guard.recv().await {
253 Some(chunk) => yield chunk,
254 None => break,
255 }
256 }
257 })
258 }
259
260 pub fn events(&self) -> impl Stream<Item = ShellEvent> + Send + Unpin {
262 let rx = self.events_rx.clone();
263 Box::pin(async_stream::stream! {
264 loop {
265 let mut guard = rx.lock().await;
266 match guard.recv().await {
267 Some(event) => yield event,
268 None => break,
269 }
270 }
271 })
272 }
273}
274
275#[allow(clippy::too_many_arguments)]
276async fn run_session(
277 url: String,
278 auth: String,
279 options: ShellOptions,
280 reconnect: Option<ShellReconnectOptions>,
281 inner: Arc<SessionInner>,
282 mut write_rx: mpsc::UnboundedReceiver<OutboundMessage>,
283 output_tx: mpsc::Sender<Vec<u8>>,
284 events_tx: mpsc::Sender<ShellEvent>,
285 mut ready_tx: Option<oneshot::Sender<Result<String, HeyoError>>>,
286) {
287 let mut attempt: u32 = 0;
288 let mut last_seq_received: u64 = 0;
289 let mut last_seq_acked: u64 = 0;
290 let cols = options.cols;
291 let rows = options.rows;
292
293 'outer: loop {
294 let session_id_opt = inner.session_id.lock().await.clone();
295 let reconnecting = session_id_opt.is_some();
296
297 let request = match Request::builder()
299 .method("GET")
300 .uri(&url)
301 .header("Authorization", &auth)
302 .header("Sec-WebSocket-Version", "13")
303 .header("Sec-WebSocket-Key", tokio_tungstenite::tungstenite::handshake::client::generate_key())
304 .header("Connection", "Upgrade")
305 .header("Upgrade", "websocket")
306 .header("Host", host_from_url(&url))
307 .body(())
308 {
309 Ok(r) => r,
310 Err(e) => {
311 let err = HeyoError::Connection(format!("build ws request: {}", e));
312 if let Some(tx) = ready_tx.take() {
313 let _ = tx.send(Err(err));
314 return;
315 }
316 let _ = events_tx.send(ShellEvent::Error(format!("{}", e))).await;
317 break;
318 }
319 };
320
321 let connect_result = tokio_tungstenite::connect_async(request).await;
322 let (ws_stream, _) = match connect_result {
323 Ok(s) => s,
324 Err(e) => {
325 if reconnecting {
326 if let Some(r) = reconnect.as_ref() {
327 attempt += 1;
328 if attempt > r.max_retries {
329 let err = HeyoError::Connection(format!(
330 "gave up after {} reconnect attempts: {}",
331 r.max_retries, e
332 ));
333 let _ = events_tx.send(ShellEvent::Error(err.to_string())).await;
334 break;
335 }
336 let delay = backoff(r, attempt);
337 let _ = events_tx
338 .send(ShellEvent::Reconnecting { attempt, delay })
339 .await;
340 sleep(delay).await;
341 continue;
342 }
343 }
344 let err = HeyoError::Connection(format!("ws connect: {}", e));
345 if let Some(tx) = ready_tx.take() {
346 let _ = tx.send(Err(err));
347 return;
348 }
349 let _ = events_tx.send(ShellEvent::Error(format!("{}", e))).await;
350 break;
351 }
352 };
353
354 let (mut ws_tx, mut ws_rx) = ws_stream.split();
355
356 let mut init = serde_json::Map::new();
358 init.insert("type".into(), json!("init"));
359 init.insert("cols".into(), json!(cols));
360 init.insert("rows".into(), json!(rows));
361 if let Some(env) = &options.env {
362 init.insert("env".into(), json!(env));
363 }
364 if let Some(cwd) = &options.cwd {
365 init.insert("cwd".into(), json!(cwd));
366 }
367 if let Some(sid) = &session_id_opt {
368 init.insert("sessionId".into(), json!(sid));
369 }
370 if let Err(e) = ws_tx
371 .send(Message::Text(serde_json::Value::Object(init).to_string()))
372 .await
373 {
374 let msg = format!("send init: {}", e);
375 if let Some(tx) = ready_tx.take() {
376 let _ = tx.send(Err(HeyoError::Connection(msg.clone())));
377 return;
378 }
379 let _ = events_tx.send(ShellEvent::Error(msg)).await;
380 continue;
381 }
382
383 let mut got_ready = false;
384 let mut ack_due = false;
385 let mut graceful_close = false;
386 let mut ack_timer = tokio::time::interval(ACK_INTERVAL);
388 ack_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
389 let mut heartbeat = Box::pin(sleep(HEARTBEAT_TIMEOUT));
391
392 loop {
393 tokio::select! {
394 _ = ack_timer.tick() => {
395 if ack_due && last_seq_received > last_seq_acked {
396 last_seq_acked = last_seq_received;
397 if ws_tx
398 .send(Message::Text(json!({
399 "type": "ack",
400 "seq": last_seq_acked,
401 }).to_string()))
402 .await
403 .is_err()
404 {
405 }
407 ack_due = false;
408 }
409 }
410 _ = &mut heartbeat => {
411 let _ = ws_tx.send(Message::Close(None)).await;
412 break;
413 }
414 Some(msg) = write_rx.recv() => {
415 match msg {
416 OutboundMessage::Stdin(bytes) => {
417 let mut frame = Vec::with_capacity(bytes.len() + 1);
418 frame.push(FRAME_STDIN);
419 frame.extend_from_slice(&bytes);
420 if ws_tx.send(Message::Binary(frame)).await.is_err() {
421 break;
422 }
423 }
424 OutboundMessage::Json(v) => {
425 if ws_tx.send(Message::Text(v.to_string())).await.is_err() {
426 break;
427 }
428 }
429 OutboundMessage::Close => {
430 graceful_close = true;
431 let _ = ws_tx
432 .send(Message::Text(json!({"type":"close"}).to_string()))
433 .await;
434 let _ = tokio::time::timeout(Duration::from_secs(2), async {
436 while let Some(msg) = ws_rx.next().await {
437 match msg {
438 Ok(Message::Text(t)) => {
439 if handle_control(
440 &t,
441 &inner,
442 &output_tx,
443 &events_tx,
444 &mut ready_tx,
445 &mut got_ready,
446 )
447 .await
448 .is_break()
449 {
450 break;
451 }
452 }
453 _ => {}
454 }
455 }
456 })
457 .await;
458 break;
459 }
460 }
461 }
462 Some(msg) = ws_rx.next() => {
463 heartbeat = Box::pin(sleep(HEARTBEAT_TIMEOUT));
464 match msg {
465 Ok(Message::Text(t)) => {
466 if handle_control(
467 &t,
468 &inner,
469 &output_tx,
470 &events_tx,
471 &mut ready_tx,
472 &mut got_ready,
473 )
474 .await
475 .is_break()
476 {
477 break;
478 }
479 if got_ready {
480 attempt = 0;
481 }
482 }
483 Ok(Message::Binary(bytes)) => {
484 if let Some(seq) = handle_binary(
485 &bytes,
486 last_seq_received,
487 &output_tx,
488 )
489 .await
490 {
491 last_seq_received = seq;
492 ack_due = true;
493 }
494 }
495 Ok(Message::Close(_)) => break,
496 Ok(_) => {}
497 Err(_) => break,
498 }
499 }
500 else => break,
501 }
502 }
503
504 let exit_set = inner.exit_code.lock().await.is_some();
506 if graceful_close || exit_set {
507 break;
508 }
509 let sid = inner.session_id.lock().await.clone();
510 match (sid, reconnect.as_ref()) {
511 (Some(_), Some(r)) => {
512 attempt += 1;
513 if attempt > r.max_retries {
514 let err = HeyoError::Connection(format!(
515 "gave up after {} reconnect attempts",
516 r.max_retries
517 ));
518 let _ = events_tx.send(ShellEvent::Error(err.to_string())).await;
519 break 'outer;
520 }
521 let delay = backoff(r, attempt);
522 let _ = events_tx
523 .send(ShellEvent::Reconnecting { attempt, delay })
524 .await;
525 sleep(delay).await;
526 }
527 _ => {
528 if let Some(tx) = ready_tx.take() {
531 let _ = tx.send(Err(HeyoError::Connection(
532 "shell-stream socket closed before ready".into(),
533 )));
534 }
535 break;
536 }
537 }
538 }
539
540 let _ = inner.closed_tx.send(true);
541 let exit = *inner.exit_code.lock().await;
542 let _ = events_tx.send(ShellEvent::Closed { exit_code: exit }).await;
543}
544
545async fn handle_control(
546 text: &str,
547 inner: &Arc<SessionInner>,
548 _output_tx: &mpsc::Sender<Vec<u8>>,
549 events_tx: &mpsc::Sender<ShellEvent>,
550 ready_tx: &mut Option<oneshot::Sender<Result<String, HeyoError>>>,
551 got_ready: &mut bool,
552) -> std::ops::ControlFlow<()> {
553 let msg: ServerControl = match serde_json::from_str(text) {
554 Ok(m) => m,
555 Err(_) => return std::ops::ControlFlow::Continue(()),
556 };
557 match msg {
558 ServerControl::Ready { session_id, last_seq: _ } => {
559 let was_reconnect;
560 {
561 let mut guard = inner.session_id.lock().await;
562 was_reconnect = guard.is_some();
563 *guard = Some(session_id.clone());
564 }
565 if let Some(tx) = ready_tx.take() {
566 let _ = tx.send(Ok(session_id));
567 }
568 *got_ready = true;
569 if was_reconnect {
570 let _ = events_tx.send(ShellEvent::Reconnected).await;
571 }
572 std::ops::ControlFlow::Continue(())
573 }
574 ServerControl::Exit { code } => {
575 *inner.exit_code.lock().await = Some(code);
576 std::ops::ControlFlow::Break(())
577 }
578 ServerControl::Error { code, message } => {
579 let err_msg = message.unwrap_or_else(|| "shell-stream server error".to_string());
580 let _ = events_tx.send(ShellEvent::Error(err_msg)).await;
581 if code.as_deref() == Some("session_expired") {
582 let _ = inner.closed_tx.send(true);
583 std::ops::ControlFlow::Break(())
584 } else {
585 std::ops::ControlFlow::Continue(())
586 }
587 }
588 }
589}
590
591async fn handle_binary(
592 bytes: &[u8],
593 last_seq_received: u64,
594 output_tx: &mpsc::Sender<Vec<u8>>,
595) -> Option<u64> {
596 if bytes.is_empty() || bytes[0] != FRAME_STDOUT || bytes.len() < 9 {
597 return None;
598 }
599 let mut seq_bytes = [0u8; 8];
600 seq_bytes.copy_from_slice(&bytes[1..9]);
601 let seq = u64::from_be_bytes(seq_bytes);
602 if seq <= last_seq_received {
603 return None;
604 }
605 let _ = output_tx.send(bytes[9..].to_vec()).await;
606 Some(seq)
607}
608
609fn backoff(r: &ShellReconnectOptions, attempt: u32) -> Duration {
610 let pow = (attempt.saturating_sub(1)).min(30);
611 let factor: u64 = 1u64.checked_shl(pow).unwrap_or(u64::MAX);
612 let raw_ms = (r.base_delay.as_millis() as u64).saturating_mul(factor);
613 let cap_ms = r.max_delay.as_millis() as u64;
614 Duration::from_millis(raw_ms.min(cap_ms))
615}
616
617fn host_from_url(url: &str) -> String {
618 url::Url::parse(url)
619 .ok()
620 .and_then(|u| {
621 let host = u.host_str()?.to_string();
622 if let Some(port) = u.port() {
623 Some(format!("{}:{}", host, port))
624 } else {
625 Some(host)
626 }
627 })
628 .unwrap_or_default()
629}