use std::path::Path;
use std::sync::Arc;
use async_trait::async_trait;
use parking_lot::RwLock;
use reqwest::Method;
use tokio::sync::mpsc;
use boxlite_shared::errors::{BoxliteError, BoxliteResult};
use crate::BoxInfo;
use crate::litebox::copy::CopyOptions;
use crate::litebox::snapshot_mgr::SnapshotInfo;
use crate::litebox::{BoxCommand, ExecResult, ExecStderr, ExecStdin, ExecStdout, Execution};
use crate::metrics::BoxMetrics;
use crate::runtime::backend::{BoxBackend, SnapshotBackend};
use crate::runtime::id::BoxID;
use crate::runtime::options::{CloneOptions, ExportOptions, SnapshotOptions};
use super::client::ApiClient;
use super::exec::RestExecControl;
use super::types::{
BoxMetricsResponse, BoxResponse, CloneBoxRequest, CreateSnapshotRequest, ExecRequest,
ExecResponse, ExecutionStatusResponse, ExportBoxRequest, ListSnapshotsResponse,
SnapshotResponse,
};
/// Box handle backed by the BoxLite REST API.
///
/// Implements [`BoxBackend`] (and, in a separate impl, [`SnapshotBackend`]) by
/// forwarding each operation to the server through [`ApiClient`].
pub(crate) struct RestBox {
    client: ApiClient,
    /// Box identity captured at construction. Stored outside the lock so
    /// `BoxBackend::id` can return a borrow tied to `&self` without `unsafe`.
    id: BoxID,
    /// Box name captured at construction. Stored outside the lock for the same
    /// reason as `id`, so `BoxBackend::name` can return a borrow.
    name: Option<String>,
    /// Latest box info reported by the server; replaced wholesale by
    /// `start`/`stop`. Note that `id`/`name` above are NOT refreshed from it.
    cached_info: RwLock<BoxInfo>,
}

impl RestBox {
    /// Wrap `info` (as returned by the server) in a REST-backed box handle.
    pub fn new(client: ApiClient, info: BoxInfo) -> Self {
        Self {
            client,
            id: info.id.clone(),
            name: info.name.clone(),
            cached_info: RwLock::new(info),
        }
    }

    /// The box ID rendered as a string, for building request paths.
    fn box_id_str(&self) -> String {
        self.id.to_string()
    }

    /// Replace the cached [`BoxInfo`] with the state carried by `resp`.
    ///
    /// # Errors
    /// Propagates any error from converting the response into a `BoxInfo`;
    /// the cache is left untouched in that case.
    fn refresh_cached_info(&self, resp: BoxResponse) -> BoxliteResult<()> {
        let new_info = resp.to_box_info()?;
        *self.cached_info.write() = new_info;
        Ok(())
    }
}

#[async_trait]
impl BoxBackend for RestBox {
    /// The box ID captured when this handle was created.
    ///
    /// Borrowed from an immutable field rather than from `cached_info`: a
    /// borrow into the lock's data would dangle (and alias a concurrent write)
    /// once `start`/`stop` replace the cached `BoxInfo`.
    fn id(&self) -> &BoxID {
        &self.id
    }

    /// The box name captured when this handle was created (see `id`).
    fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// A clone of the most recently cached `BoxInfo`.
    fn info(&self) -> BoxInfo {
        self.cached_info.read().clone()
    }

    /// Start the box via `POST /boxes/{id}/start` and refresh the cached info.
    async fn start(&self) -> BoxliteResult<()> {
        let box_id = self.box_id_str();
        let path = format!("/boxes/{}/start", box_id);
        let resp: BoxResponse = self.client.post_empty(&path).await?;
        self.refresh_cached_info(resp)
    }

    /// Launch `command` via `POST /boxes/{id}/exec`, then stream its I/O over
    /// the attach WebSocket in a spawned background task.
    async fn exec(&self, command: BoxCommand) -> BoxliteResult<Execution> {
        let box_id = self.box_id_str();
        let path = format!("/boxes/{}/exec", box_id);
        let req = ExecRequest::from_command(&command);
        let resp: ExecResponse = self.client.post(&path, &req).await?;
        let execution_id = resp.execution_id;

        let (stdout_tx, stdout_rx) = mpsc::unbounded_channel::<String>();
        let (stderr_tx, stderr_rx) = mpsc::unbounded_channel::<String>();
        let (stdin_tx, stdin_rx) = mpsc::unbounded_channel::<Vec<u8>>();
        let (result_tx, result_rx) = mpsc::unbounded_channel::<ExecResult>();

        let ws_client = self.client.clone();
        let ws_box_id = box_id.clone();
        let ws_exec_id = execution_id.clone();
        tokio::spawn(async move {
            attach_ws(
                &ws_client,
                &ws_box_id,
                &ws_exec_id,
                stdin_rx,
                stdout_tx,
                stderr_tx,
                result_tx,
            )
            .await;
        });

        let control = RestExecControl::new(self.client.clone(), box_id);
        let stdout = ExecStdout::new(stdout_rx);
        let stderr = ExecStderr::new(stderr_rx);
        let stdin = ExecStdin::new(stdin_tx);
        Ok(Execution::new(
            execution_id,
            Box::new(control),
            result_rx,
            Some(stdin),
            Some(stdout),
            Some(stderr),
        ))
    }

    /// Re-attach to an existing execution's WebSocket.
    ///
    /// Unlike `exec`, the connection is established before returning, so
    /// connect errors surface to the caller: `NotFound` is reported as
    /// `SessionReaped`, and `AlreadyExists` gets session context added.
    async fn attach(&self, execution_id: &str) -> BoxliteResult<Execution> {
        let box_id = self.box_id_str();
        let path = format!("/boxes/{}/executions/{}/attach", box_id, execution_id);
        let stream = self.client.connect_ws(&path).await.map_err(|e| match e {
            BoxliteError::NotFound(msg) => BoxliteError::SessionReaped(format!(
                "session {} not found — likely reaped after disconnect timeout: {}",
                execution_id, msg
            )),
            BoxliteError::AlreadyExists(msg) => BoxliteError::AlreadyExists(format!(
                "session {} has another client attached: {}",
                execution_id, msg
            )),
            other => other,
        })?;

        let (stdout_tx, stdout_rx) = mpsc::unbounded_channel::<String>();
        let (stderr_tx, stderr_rx) = mpsc::unbounded_channel::<String>();
        let (stdin_tx, stdin_rx) = mpsc::unbounded_channel::<Vec<u8>>();
        let (result_tx, result_rx) = mpsc::unbounded_channel::<ExecResult>();

        let ws_client = self.client.clone();
        let ws_box_id = box_id.clone();
        let ws_exec_id = execution_id.to_string();
        tokio::spawn(async move {
            attach_ws_pump(
                &ws_client,
                &ws_box_id,
                &ws_exec_id,
                stream,
                stdin_rx,
                stdout_tx,
                stderr_tx,
                result_tx,
            )
            .await;
        });

        let control = RestExecControl::new(self.client.clone(), box_id);
        let stdout = ExecStdout::new(stdout_rx);
        let stderr = ExecStderr::new(stderr_rx);
        let stdin = ExecStdin::new(stdin_tx);
        Ok(Execution::new(
            execution_id.to_string(),
            Box::new(control),
            result_rx,
            Some(stdin),
            Some(stdout),
            Some(stderr),
        ))
    }

    /// Fetch runtime and boot metrics via `GET /boxes/{id}/metrics`.
    async fn metrics(&self) -> BoxliteResult<BoxMetrics> {
        let box_id = self.box_id_str();
        let path = format!("/boxes/{}/metrics", box_id);
        let resp: BoxMetricsResponse = self.client.get(&path).await?;
        Ok(box_metrics_from_response(&resp))
    }

    /// Stop the box via `POST /boxes/{id}/stop` and refresh the cached info.
    async fn stop(&self) -> BoxliteResult<()> {
        let box_id = self.box_id_str();
        let path = format!("/boxes/{}/stop", box_id);
        let resp: BoxResponse = self.client.post_empty(&path).await?;
        self.refresh_cached_info(resp)
    }

    /// Upload `host_src` (file or directory) as a tar archive to
    /// `container_dst` via `PUT /boxes/{id}/files?path=...`.
    ///
    /// `_opts` is currently ignored by the REST backend.
    async fn copy_into(
        &self,
        host_src: &Path,
        container_dst: &str,
        _opts: CopyOptions,
    ) -> BoxliteResult<()> {
        let box_id = self.box_id_str();
        let tar_bytes = create_tar_from_path(host_src)?;
        let encoded_dst = urlencoding::encode(container_dst);
        let path = format!("/boxes/{}/files?path={}", box_id, encoded_dst);
        let builder = self
            .client
            .authorized_request(Method::PUT, &path)
            .await?
            .header("Content-Type", "application/x-tar")
            .body(tar_bytes);
        let resp = builder
            .send()
            .await
            .map_err(|e| BoxliteError::Internal(format!("copy_into upload failed: {}", e)))?;
        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            return Err(BoxliteError::Internal(format!(
                "copy_into failed (HTTP {}): {}",
                status, text
            )));
        }
        Ok(())
    }

    /// Download `container_src` as a tar archive via
    /// `GET /boxes/{id}/files?path=...` and unpack it into `host_dst`.
    ///
    /// `_opts` is currently ignored by the REST backend.
    async fn copy_out(
        &self,
        container_src: &str,
        host_dst: &Path,
        _opts: CopyOptions,
    ) -> BoxliteResult<()> {
        let box_id = self.box_id_str();
        let encoded_src = urlencoding::encode(container_src);
        let path = format!("/boxes/{}/files?path={}", box_id, encoded_src);
        let builder = self
            .client
            .authorized_request(Method::GET, &path)
            .await?
            .header("Accept", "application/x-tar");
        let resp = builder
            .send()
            .await
            .map_err(|e| BoxliteError::Internal(format!("copy_out download failed: {}", e)))?;
        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            return Err(BoxliteError::Internal(format!(
                "copy_out failed (HTTP {}): {}",
                status, text
            )));
        }
        let tar_bytes = resp
            .bytes()
            .await
            .map_err(|e| BoxliteError::Internal(format!("copy_out read body failed: {}", e)))?;
        extract_tar_to_path(&tar_bytes, host_dst)
    }

    /// Clone this box server-side via `POST /boxes/{id}/clone` and return a
    /// new REST-backed `LiteBox` for the clone.
    async fn clone_box(
        &self,
        options: CloneOptions,
        name: Option<String>,
    ) -> BoxliteResult<crate::LiteBox> {
        self.client.require_clone_enabled().await?;
        let box_id = self.box_id_str();
        let path = format!("/boxes/{}/clone", box_id);
        let req = CloneBoxRequest::from_options(&options, name.as_deref());
        let resp: BoxResponse = self.client.post(&path, &req).await?;
        let info = resp.to_box_info()?;
        // The same RestBox serves as both the box and the snapshot backend.
        let rest_box = Arc::new(RestBox::new(self.client.clone(), info));
        let box_backend: Arc<dyn BoxBackend> = rest_box.clone();
        let snapshot_backend: Arc<dyn SnapshotBackend> = rest_box;
        Ok(crate::LiteBox::new(box_backend, snapshot_backend))
    }

    /// Clone this box `count` times sequentially. `names[i]` (when present)
    /// names the i-th clone. Stops at the first failure; clones already
    /// created are not rolled back.
    async fn clone_boxes(
        &self,
        options: CloneOptions,
        count: usize,
        names: Vec<String>,
    ) -> BoxliteResult<Vec<crate::LiteBox>> {
        let mut results = Vec::with_capacity(count);
        for i in 0..count {
            let name = names.get(i).cloned();
            let litebox = self.clone_box(options.clone(), name).await?;
            results.push(litebox);
        }
        Ok(results)
    }

    /// Export the box via `POST /boxes/{id}/export` and write the returned
    /// archive to disk.
    ///
    /// If `dest` is an existing directory, the archive is written to
    /// `dest/<name>.boxlite` (`box.boxlite` for unnamed boxes); otherwise
    /// `dest` is used as the file path. Parent directories are created.
    async fn export_box(
        &self,
        options: ExportOptions,
        dest: &Path,
    ) -> BoxliteResult<crate::runtime::options::BoxArchive> {
        self.client.require_export_enabled().await?;
        let box_id = self.box_id_str();
        let path = format!("/boxes/{}/export", box_id);
        let req = ExportBoxRequest::from_options(&options);
        let archive_bytes = self.client.post_for_bytes(&path, &req).await?;
        let output_path = if dest.is_dir() {
            let name = self.name().unwrap_or("box");
            dest.join(format!("{}.boxlite", name))
        } else {
            dest.to_path_buf()
        };
        if let Some(parent) = output_path.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                BoxliteError::Storage(format!(
                    "Failed to create export destination directory {}: {}",
                    parent.display(),
                    e
                ))
            })?;
        }
        std::fs::write(&output_path, &archive_bytes).map_err(|e| {
            BoxliteError::Storage(format!(
                "Failed to write export archive {}: {}",
                output_path.display(),
                e
            ))
        })?;
        Ok(crate::runtime::options::BoxArchive::new(output_path))
    }
}
/// Snapshot operations for a REST-backed box.
///
/// Every method first checks that the server advertises snapshot support, then
/// talks to `/boxes/{id}/snapshots[/{name}[/restore]]`. Snapshot names are
/// URL-encoded before being placed in a path segment.
#[async_trait]
impl SnapshotBackend for RestBox {
    /// Create a snapshot called `name` via `POST /boxes/{id}/snapshots`.
    async fn create(&self, options: SnapshotOptions, name: &str) -> BoxliteResult<SnapshotInfo> {
        self.client.require_snapshots_enabled().await?;
        let endpoint = format!("/boxes/{}/snapshots", self.box_id_str());
        let request = CreateSnapshotRequest::from_options(&options, name);
        let created: SnapshotResponse = self.client.post(&endpoint, &request).await?;
        Ok(created.to_snapshot_info())
    }

    /// List every snapshot of this box via `GET /boxes/{id}/snapshots`.
    async fn list(&self) -> BoxliteResult<Vec<SnapshotInfo>> {
        self.client.require_snapshots_enabled().await?;
        let endpoint = format!("/boxes/{}/snapshots", self.box_id_str());
        let listing: ListSnapshotsResponse = self.client.get(&endpoint).await?;
        let snapshots: Vec<SnapshotInfo> = listing
            .snapshots
            .iter()
            .map(|snapshot| snapshot.to_snapshot_info())
            .collect();
        Ok(snapshots)
    }

    /// Look up one snapshot; a `NotFound` from the server becomes `Ok(None)`.
    async fn get(&self, name: &str) -> BoxliteResult<Option<SnapshotInfo>> {
        self.client.require_snapshots_enabled().await?;
        let endpoint = format!(
            "/boxes/{}/snapshots/{}",
            self.box_id_str(),
            urlencoding::encode(name)
        );
        match self.client.get::<SnapshotResponse>(&endpoint).await {
            Err(BoxliteError::NotFound(_)) => Ok(None),
            outcome => outcome.map(|found| Some(found.to_snapshot_info())),
        }
    }

    /// Delete a snapshot via `DELETE /boxes/{id}/snapshots/{name}`.
    async fn remove(&self, name: &str) -> BoxliteResult<()> {
        self.client.require_snapshots_enabled().await?;
        let endpoint = format!(
            "/boxes/{}/snapshots/{}",
            self.box_id_str(),
            urlencoding::encode(name)
        );
        self.client.delete(&endpoint).await
    }

    /// Roll the box back via `POST /boxes/{id}/snapshots/{name}/restore`.
    async fn restore(&self, name: &str) -> BoxliteResult<()> {
        self.client.require_snapshots_enabled().await?;
        let endpoint = format!(
            "/boxes/{}/snapshots/{}/restore",
            self.box_id_str(),
            urlencoding::encode(name)
        );
        self.client.post_empty_no_content(&endpoint).await
    }
}
/// Maximum time the attach loop waits for any WebSocket frame before treating
/// the connection as dead. Test builds use a much shorter interval so the
/// watchdog test completes quickly.
const WS_WATCHDOG: std::time::Duration = if cfg!(test) {
    std::time::Duration::from_millis(300)
} else {
    std::time::Duration::from_secs(45)
};
/// Connect the attach WebSocket for `execution_id` on `box_id`, then pump it.
///
/// A connection failure is not returned. It is routed through
/// `emit_or_fallback`, which reports a result on `result_tx`.
async fn attach_ws(
    client: &ApiClient,
    box_id: &str,
    execution_id: &str,
    stdin_rx: mpsc::UnboundedReceiver<Vec<u8>>,
    stdout_tx: mpsc::UnboundedSender<String>,
    stderr_tx: mpsc::UnboundedSender<String>,
    result_tx: mpsc::UnboundedSender<ExecResult>,
) {
    let attach_path = format!("/boxes/{}/executions/{}/attach", box_id, execution_id);
    match client.connect_ws(&attach_path).await {
        Ok(ws) => {
            attach_ws_pump(
                client,
                box_id,
                execution_id,
                ws,
                stdin_rx,
                stdout_tx,
                stderr_tx,
                result_tx,
            )
            .await;
        }
        Err(connect_err) => {
            let cause = format!("WS connect failed: {}", connect_err);
            emit_or_fallback(client, box_id, execution_id, &result_tx, cause).await;
        }
    }
}
/// Drive an established attach WebSocket until the execution finishes or the
/// connection is lost, then send exactly one `ExecResult` on `result_tx`.
///
/// - stdin: a spawned task forwards every chunk from `stdin_rx` as a binary
///   frame. When the stdin sender is dropped, it sends a `{"type":"stdin_eof"}`
///   text frame.
/// - output: binary frames are demultiplexed by their first byte (`0x01`
///   stdout, `0x02` stderr). The payload is decoded as UTF-8 incrementally, so
///   a character split across frames is not corrupted.
/// - control: text frames are parsed with `parse_control_frame`. An `exit`
///   frame ends the loop with that code. An `error` frame is only remembered as
///   the likely cause if the stream later ends without an exit frame.
/// - liveness: if no frame of any kind arrives within `WS_WATCHDOG`, the
///   connection is treated as dead.
///
/// Every non-exit termination goes through `emit_or_fallback`. That function
/// may still recover the real exit code from the status endpoint.
#[allow(clippy::too_many_arguments)]
async fn attach_ws_pump(
    client: &ApiClient,
    box_id: &str,
    execution_id: &str,
    stream: tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >,
    mut stdin_rx: mpsc::UnboundedReceiver<Vec<u8>>,
    stdout_tx: mpsc::UnboundedSender<String>,
    stderr_tx: mpsc::UnboundedSender<String>,
    result_tx: mpsc::UnboundedSender<ExecResult>,
) {
    use futures::{SinkExt, StreamExt};
    use tokio_tungstenite::tungstenite::Message;

    let (mut sink, mut stream) = stream.split();

    let stdin_task = tokio::spawn(async move {
        while let Some(bytes) = stdin_rx.recv().await {
            if sink.send(Message::Binary(bytes)).await.is_err() {
                break;
            }
        }
        let _ = sink
            .send(Message::Text(r#"{"type":"stdin_eof"}"#.to_string()))
            .await;
    });

    let mut stdout_decoder = Utf8ChunkDecoder::default();
    let mut stderr_decoder = Utf8ChunkDecoder::default();
    // Most recent server-reported error; used as the diagnostic if the stream
    // ends without an exit frame.
    let mut last_error_message: Option<String> = None;

    let outcome = loop {
        let frame = match tokio::time::timeout(WS_WATCHDOG, stream.next()).await {
            Err(_) => {
                break PumpOutcome::Lost(
                    "no WS traffic for watchdog interval (likely connection idle timeout or proxy cut)"
                        .to_string(),
                )
            }
            Ok(None) => {
                break PumpOutcome::Lost(last_error_message.take().unwrap_or_else(|| {
                    "WS stream ended before exit frame (likely connection idle timeout or proxy cut)"
                        .to_string()
                }))
            }
            Ok(Some(Err(e))) => break PumpOutcome::Lost(format!("WS stream read error: {}", e)),
            Ok(Some(Ok(msg))) => msg,
        };

        match frame {
            Message::Binary(bytes) => {
                if let Some((&channel, payload)) = bytes.split_first() {
                    match channel {
                        0x01 => stdout_decoder.forward(payload, &stdout_tx),
                        0x02 => stderr_decoder.forward(payload, &stderr_tx),
                        other => {
                            tracing::warn!(channel = other, "WS attach: unknown channel prefix");
                        }
                    }
                }
            }
            Message::Text(text) => match parse_control_frame(&text) {
                ControlFrame::Exit { exit_code } => break PumpOutcome::Exited(exit_code),
                ControlFrame::Error { message } => {
                    tracing::warn!(message = %message, "WS attach: server-reported error");
                    last_error_message = Some(message);
                }
                ControlFrame::Unknown => {
                    tracing::warn!(text = %text, "WS attach: unrecognized text frame");
                }
            },
            Message::Close(_) => {
                break PumpOutcome::Lost(last_error_message.take().unwrap_or_else(|| {
                    "WS closed before exit frame (likely connection idle timeout or proxy cut)"
                        .to_string()
                }))
            }
            Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => {}
        }
    };

    // Deliver any dangling partial UTF-8 sequence (lossily) before the result,
    // so consumers see all output ahead of the exit notification.
    stdout_decoder.finish(&stdout_tx);
    stderr_decoder.finish(&stderr_tx);

    match outcome {
        PumpOutcome::Exited(exit_code) => {
            let _ = result_tx.send(ExecResult {
                exit_code,
                error_message: None,
            });
        }
        PumpOutcome::Lost(cause) => {
            emit_or_fallback(client, box_id, execution_id, &result_tx, cause).await;
        }
    }
    stdin_task.abort();
}

/// How the attach read loop ended, decided before the final result is sent.
enum PumpOutcome {
    /// The server sent an explicit exit frame with this code.
    Exited(i32),
    /// The stream ended without an exit frame; the string is the diagnostic.
    Lost(String),
}

/// Incremental UTF-8 decoder for one output channel.
///
/// Output frames may be cut at arbitrary byte boundaries. Decoding each frame
/// on its own would turn both halves of a split multi-byte character into
/// U+FFFD. This decoder instead keeps an incomplete trailing sequence (at most
/// 3 bytes) and prepends it to the next frame.
#[derive(Default)]
struct Utf8ChunkDecoder {
    pending: Vec<u8>,
}

impl Utf8ChunkDecoder {
    /// Decode as much of `pending + chunk` as forms complete characters.
    ///
    /// Invalid sequences become U+FFFD, matching `String::from_utf8_lossy`. An
    /// incomplete sequence at the very end is kept for the next call.
    fn decode(&mut self, chunk: &[u8]) -> String {
        self.pending.extend_from_slice(chunk);
        let mut out = String::with_capacity(self.pending.len());
        let mut rest: &[u8] = &self.pending;
        loop {
            match std::str::from_utf8(rest) {
                Ok(valid) => {
                    out.push_str(valid);
                    rest = &[];
                    break;
                }
                Err(err) => {
                    let (valid, tail) = rest.split_at(err.valid_up_to());
                    // `valid` was just validated, so this borrows without replacing.
                    out.push_str(&String::from_utf8_lossy(valid));
                    match err.error_len() {
                        // A genuinely invalid sequence: replace it and keep going.
                        Some(bad_len) => {
                            out.push('\u{FFFD}');
                            rest = &tail[bad_len..];
                        }
                        // Input ended mid-character: wait for the next chunk.
                        None => {
                            rest = tail;
                            break;
                        }
                    }
                }
            }
        }
        let consumed = self.pending.len() - rest.len();
        self.pending.drain(..consumed);
        out
    }

    /// Decode `chunk` and send the resulting text on `tx`.
    ///
    /// Nothing is sent when no complete character is available yet. A closed
    /// receiver is ignored.
    fn forward(&mut self, chunk: &[u8], tx: &mpsc::UnboundedSender<String>) {
        let text = self.decode(chunk);
        if !text.is_empty() {
            let _ = tx.send(text);
        }
    }

    /// At end of stream, send any buffered incomplete sequence lossily.
    fn finish(&mut self, tx: &mpsc::UnboundedSender<String>) {
        if !self.pending.is_empty() {
            let _ = tx.send(String::from_utf8_lossy(&self.pending).into_owned());
            self.pending.clear();
        }
    }
}
/// A decoded text (control) frame from the attach WebSocket.
enum ControlFrame {
    /// `{"type":"exit","exit_code":N}`: the process finished with code `N`.
    Exit { exit_code: i32 },
    /// `{"type":"error","message":"..."}`: an error reported by the server.
    Error { message: String },
    /// Not valid JSON, or an unrecognized or missing `type`.
    Unknown,
}

/// Parse a WebSocket text frame into a [`ControlFrame`].
///
/// A missing, non-integer, or out-of-`i32`-range `exit_code` maps to `-1`.
/// A plain `as i32` cast would silently truncate such a value, possibly to
/// `0`, which would falsely report success.
fn parse_control_frame(text: &str) -> ControlFrame {
    let Ok(value) = serde_json::from_str::<serde_json::Value>(text) else {
        return ControlFrame::Unknown;
    };
    match value.get("type").and_then(|v| v.as_str()) {
        Some("exit") => {
            let exit_code = value
                .get("exit_code")
                .and_then(|v| v.as_i64())
                .and_then(|code| i32::try_from(code).ok())
                .unwrap_or(-1);
            ControlFrame::Exit { exit_code }
        }
        Some("error") => {
            let message = value
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("server reported error without message")
                .to_string();
            ControlFrame::Error { message }
        }
        _ => ControlFrame::Unknown,
    }
}
/// Send the final `ExecResult` after the attach stream ended without an
/// exit frame.
///
/// The status endpoint is probed, with a 5-second timeout. If it reports a
/// terminal state (`completed`, `killed` or `timed_out`), that exit code is
/// sent with no error message, or `-1` when the code is absent. Otherwise,
/// including on probe failure or timeout, `-1` is sent with `cause` as the
/// error message. A closed result receiver is ignored.
async fn emit_or_fallback(
    client: &ApiClient,
    box_id: &str,
    execution_id: &str,
    result_tx: &mpsc::UnboundedSender<ExecResult>,
    cause: String,
) {
    let status_path = format!("/boxes/{}/executions/{}", box_id, execution_id);
    let probed = tokio::time::timeout(
        std::time::Duration::from_secs(5),
        client.get::<ExecutionStatusResponse>(&status_path),
    )
    .await;

    let result = match probed {
        Ok(Ok(status))
            if matches!(status.status.as_str(), "completed" | "killed" | "timed_out") =>
        {
            ExecResult {
                exit_code: status.exit_code.unwrap_or(-1),
                error_message: None,
            }
        }
        // Non-terminal state, HTTP error, or probe timeout: report the cause.
        _ => ExecResult {
            exit_code: -1,
            error_message: Some(cause),
        },
    };
    let _ = result_tx.send(result);
}
/// Build an in-memory tar archive from `host_src`.
///
/// - A regular file becomes a single entry named after its file name
///   (`"file"` if it has none).
/// - A directory has its contents added under `.`.
///
/// # Errors
/// Returns `BoxliteError::Internal` when the source cannot be opened or read,
/// or when the archive cannot be finalized.
fn create_tar_from_path(host_src: &Path) -> BoxliteResult<Vec<u8>> {
    let mut builder = tar::Builder::new(Vec::new());

    if !host_src.is_dir() {
        let entry_name = match host_src.file_name() {
            Some(raw) => raw.to_string_lossy().into_owned(),
            None => String::from("file"),
        };
        let mut source_file = std::fs::File::open(host_src).map_err(|err| {
            BoxliteError::Internal(format!("failed to open {}: {}", host_src.display(), err))
        })?;
        builder
            .append_file(&entry_name, &mut source_file)
            .map_err(|err| {
                BoxliteError::Internal(format!(
                    "failed to add {} to tar: {}",
                    host_src.display(),
                    err
                ))
            })?;
    } else {
        builder.append_dir_all(".", host_src).map_err(|err| {
            BoxliteError::Internal(format!(
                "failed to create tar from {}: {}",
                host_src.display(),
                err
            ))
        })?;
    }

    builder
        .into_inner()
        .map_err(|err| BoxliteError::Internal(format!("failed to finalize tar archive: {}", err)))
}
/// Unpack the tar archive in `tar_bytes` into the directory `host_dst`.
///
/// The parent of `host_dst` is created first if it is missing.
///
/// # Errors
/// Returns `BoxliteError::Internal` if the parent directory cannot be created
/// or the archive fails to unpack.
fn extract_tar_to_path(tar_bytes: &[u8], host_dst: &Path) -> BoxliteResult<()> {
    if let Some(parent_dir) = host_dst.parent() {
        if let Err(err) = std::fs::create_dir_all(parent_dir) {
            return Err(BoxliteError::Internal(format!(
                "failed to create directory {}: {}",
                parent_dir.display(),
                err
            )));
        }
    }

    tar::Archive::new(tar_bytes)
        .unpack(host_dst)
        .map_err(|err| {
            BoxliteError::Internal(format!(
                "failed to extract tar to {}: {}",
                host_dst.display(),
                err
            ))
        })
}
/// Convert the REST metrics payload into the crate's [`BoxMetrics`].
///
/// Counters and resource gauges are copied as-is. Each boot-stage timing is
/// taken from `boot_timing` and cast to `u128`. Every stage is `None` when the
/// server omitted `boot_timing`.
fn box_metrics_from_response(resp: &BoxMetricsResponse) -> BoxMetrics {
    let timing = resp.boot_timing.as_ref();

    BoxMetrics {
        commands_executed_total: resp.commands_executed_total,
        exec_errors_total: resp.exec_errors_total,
        bytes_sent_total: resp.bytes_sent_total,
        bytes_received_total: resp.bytes_received_total,
        total_create_duration_ms: timing.and_then(|t| t.total_create_ms).map(|v| v as u128),
        guest_boot_duration_ms: timing.and_then(|t| t.guest_boot_ms).map(|v| v as u128),
        cpu_percent: resp.cpu_percent,
        memory_bytes: resp.memory_bytes,
        network_bytes_sent: resp.network_bytes_sent,
        network_bytes_received: resp.network_bytes_received,
        network_tcp_connections: resp.network_tcp_connections,
        network_tcp_errors: resp.network_tcp_errors,
        stage_filesystem_setup_ms: timing
            .and_then(|t| t.filesystem_setup_ms)
            .map(|v| v as u128),
        stage_image_prepare_ms: timing.and_then(|t| t.image_prepare_ms).map(|v| v as u128),
        stage_guest_rootfs_ms: timing.and_then(|t| t.guest_rootfs_ms).map(|v| v as u128),
        stage_box_config_ms: timing.and_then(|t| t.box_config_ms).map(|v| v as u128),
        stage_box_spawn_ms: timing.and_then(|t| t.box_spawn_ms).map(|v| v as u128),
        stage_container_init_ms: timing
            .and_then(|t| t.container_init_ms)
            .map(|v| v as u128),
    }
}
// Tests for the WebSocket attach path. Each test binds a TCP listener on an
// ephemeral 127.0.0.1 port and runs a minimal hand-rolled HTTP/WebSocket
// server (`run_server`) against the real `attach_ws` code path.
#[cfg(test)]
mod tests {
use super::*;
use crate::rest::client::ApiClient;
use crate::rest::options::BoxliteRestOptions;
use futures::{SinkExt, StreamExt};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;
/// Observations recorded by the fake server for later assertions.
#[derive(Default)]
struct ServerState {
/// Binary frame payloads that a WS handler chose to record.
received_stdin: Vec<Vec<u8>>,
/// Number of plain (non-upgrade) HTTP requests answered with the status body.
status_calls: u32,
}
/// Server state shared between the test body and the fake server task.
type SharedState = Arc<Mutex<ServerState>>;
/// Read from `stream` until the buffer contains `\r\n\r\n`.
///
/// Reading also stops on EOF, on a read error, or once more than 16 KiB has
/// been buffered. Reads happen in 512-byte chunks, so the returned bytes may
/// extend past the request head. `ChainedStream` replays them all.
async fn read_request_head(stream: &mut TcpStream) -> Vec<u8> {
let mut buf = Vec::with_capacity(1024);
let mut tmp = [0u8; 512];
loop {
let n = match stream.read(&mut tmp).await {
Ok(0) => break,
Ok(n) => n,
Err(_) => break,
};
buf.extend_from_slice(&tmp[..n]);
if buf.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
if buf.len() > 16 * 1024 {
break;
}
}
buf
}
/// Build an `ApiClient` pointed at `http://127.0.0.1:{port}`.
fn client_for(port: u16) -> ApiClient {
let opts = BoxliteRestOptions::new(format!("http://127.0.0.1:{}", port));
ApiClient::new(&opts).expect("ApiClient::new")
}
/// Write a `200 OK` JSON response carrying `body`, then shut down the write
/// side. Write and shutdown errors are ignored.
async fn write_status_response(stream: &mut TcpStream, body: &str) {
let resp = format!(
"HTTP/1.1 200 OK\r\n\
Content-Type: application/json\r\n\
Content-Length: {}\r\n\
Connection: close\r\n\r\n{}",
body.len(),
body
);
let _ = stream.write_all(resp.as_bytes()).await;
let _ = stream.shutdown().await;
}
/// A `TcpStream` wrapper that first yields the already-consumed `head` bytes.
///
/// `read_request_head` has already consumed the request head from the socket.
/// Replaying those bytes lets `tokio_tungstenite::accept_async` re-read the
/// upgrade request. Writes pass straight through to `inner`.
struct ChainedStream {
head: Vec<u8>,
head_pos: usize,
inner: TcpStream,
}
impl tokio::io::AsyncRead for ChainedStream {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
// Serve replayed head bytes first, as many as fit in `buf`.
if self.head_pos < self.head.len() {
let remaining = &self.head[self.head_pos..];
let take = remaining.len().min(buf.remaining());
buf.put_slice(&remaining[..take]);
self.head_pos += take;
return std::task::Poll::Ready(Ok(()));
}
std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl tokio::io::AsyncWrite for ChainedStream {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::pin::Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
/// Accept connections until `accept` fails.
///
/// - The first WebSocket upgrade request is handed to `ws_handler`. The
///   handler is awaited inline, so no other connection is accepted until it
///   returns. If the handshake fails, the handler is still consumed.
/// - Later upgrade requests are dropped without a response.
/// - A plain HTTP request gets `status_body` (and bumps `status_calls`) when
///   one is configured; otherwise its connection is shut down.
async fn run_server<F, Fut>(
listener: TcpListener,
state: SharedState,
status_body: Option<String>,
ws_handler: F,
) where
F: FnOnce(tokio_tungstenite::WebSocketStream<ChainedStream>, SharedState) -> Fut
+ Send
+ 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let mut ws_handler = Some(ws_handler);
loop {
let (mut stream, _) = match listener.accept().await {
Ok(p) => p,
Err(_) => return,
};
let head = read_request_head(&mut stream).await;
let head_str = String::from_utf8_lossy(&head);
let is_upgrade = head_str.to_ascii_lowercase().contains("upgrade: websocket");
if is_upgrade {
if let Some(handler) = ws_handler.take() {
let chained = ChainedStream {
head,
head_pos: 0,
inner: stream,
};
match tokio_tungstenite::accept_async(chained).await {
Ok(ws) => handler(ws, state.clone()).await,
Err(_) => continue,
}
}
} else if let Some(ref body) = status_body {
let mut s = state.lock().await;
s.status_calls += 1;
drop(s);
write_status_response(&mut stream, body).await;
} else {
let _ = stream.shutdown().await;
}
}
}
/// A normal session: stdin reaches the server, one stdout frame arrives, and
/// the exit frame's code is surfaced without an error message.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ws_clean_exit_emits_result() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let state: SharedState = Arc::new(Mutex::new(ServerState::default()));
let state_clone = state.clone();
let server = tokio::spawn(async move {
run_server(listener, state_clone, None, |mut ws, state| async move {
// Record the first frame, expected to be the queued stdin payload.
if let Some(Ok(Message::Binary(b))) = ws.next().await {
let mut s = state.lock().await;
s.received_stdin.push(b);
}
// 0x01 prefix = stdout channel.
ws.send(Message::Binary(vec![0x01, b'h', b'i']))
.await
.unwrap();
ws.send(Message::Text(r#"{"type":"exit","exit_code":7}"#.into()))
.await
.unwrap();
let _ = ws.close(None).await;
})
.await;
});
let client = client_for(port);
let (stdout_tx, mut stdout_rx) = mpsc::unbounded_channel::<String>();
let (stderr_tx, _stderr_rx) = mpsc::unbounded_channel::<String>();
let (stdin_tx, stdin_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (result_tx, mut result_rx) = mpsc::unbounded_channel::<ExecResult>();
// Queue stdin before attaching so it is the first frame the server reads.
stdin_tx.send(b"hello".to_vec()).unwrap();
let attach = tokio::spawn(async move {
attach_ws(
&client, "box1", "exec1", stdin_rx, stdout_tx, stderr_tx, result_tx,
)
.await;
});
let res = tokio::time::timeout(Duration::from_secs(3), result_rx.recv())
.await
.expect("result channel timed out")
.expect("result channel closed without value");
assert_eq!(res.exit_code, 7);
assert!(res.error_message.is_none());
let out = tokio::time::timeout(Duration::from_secs(1), stdout_rx.recv())
.await
.expect("stdout timed out")
.expect("stdout channel closed");
assert_eq!(out, "hi");
attach.await.unwrap();
let s = state.lock().await;
assert!(
s.received_stdin.iter().any(|b| b == b"hello"),
"server never observed stdin payload; got {:?}",
s.received_stdin
);
drop(s);
server.abort();
}
/// The server closes without an exit frame. The pump must query the status
/// endpoint, which reports `completed` with exit code 42.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ws_close_without_exit_falls_back_to_status() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let state: SharedState = Arc::new(Mutex::new(ServerState::default()));
let status_body =
r#"{"execution_id":"exec1","status":"completed","exit_code":42}"#.to_string();
let state_clone = state.clone();
let server = tokio::spawn(async move {
run_server(
listener,
state_clone,
Some(status_body),
|mut ws, _state| async move {
ws.send(Message::Binary(vec![0x01, b'x'])).await.unwrap();
let _ = ws.close(None).await;
},
)
.await;
});
let client = client_for(port);
let (stdout_tx, _stdout_rx) = mpsc::unbounded_channel::<String>();
let (stderr_tx, _stderr_rx) = mpsc::unbounded_channel::<String>();
let (_stdin_tx, stdin_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (result_tx, mut result_rx) = mpsc::unbounded_channel::<ExecResult>();
let attach = tokio::spawn(async move {
attach_ws(
&client, "box1", "exec1", stdin_rx, stdout_tx, stderr_tx, result_tx,
)
.await;
});
let res = tokio::time::timeout(Duration::from_secs(3), result_rx.recv())
.await
.expect("result channel timed out")
.expect("result channel closed without value");
assert_eq!(
res.exit_code, 42,
"expected status fallback to surface real exit code"
);
assert!(res.error_message.is_none());
attach.await.unwrap();
let s = state.lock().await;
assert!(
s.status_calls >= 1,
"status fallback endpoint was never called"
);
drop(s);
server.abort();
}
/// The server holds the socket open for 2s without sending anything. The
/// test-build `WS_WATCHDOG` (300ms) must fire first. No status body is
/// configured, so the watchdog diagnostic is reported with exit code -1.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ws_watchdog_fires_when_idle() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let state: SharedState = Arc::new(Mutex::new(ServerState::default()));
let state_clone = state.clone();
let server = tokio::spawn(async move {
run_server(listener, state_clone, None, |ws, _state| async move {
let _kept_alive = ws;
tokio::time::sleep(Duration::from_secs(2)).await;
})
.await;
});
let client = client_for(port);
let (stdout_tx, _stdout_rx) = mpsc::unbounded_channel::<String>();
let (stderr_tx, _stderr_rx) = mpsc::unbounded_channel::<String>();
let (_stdin_tx, stdin_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (result_tx, mut result_rx) = mpsc::unbounded_channel::<ExecResult>();
let attach = tokio::spawn(async move {
attach_ws(
&client, "box1", "exec1", stdin_rx, stdout_tx, stderr_tx, result_tx,
)
.await;
});
let res = tokio::time::timeout(Duration::from_secs(3), result_rx.recv())
.await
.expect("watchdog never fired")
.expect("result channel closed without value");
assert_eq!(res.exit_code, -1);
let msg = res.error_message.expect("expected diagnostic message");
assert!(msg.contains("watchdog"), "unexpected diagnostic: {:?}", msg);
attach.await.unwrap();
server.abort();
}
/// An `error` control frame is only logged and remembered. It must not end
/// the attach, so the later exit frame (code 0) is what gets reported.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn ws_text_error_frame_logs_but_continues() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
let state: SharedState = Arc::new(Mutex::new(ServerState::default()));
let state_clone = state.clone();
let server = tokio::spawn(async move {
run_server(listener, state_clone, None, |mut ws, _state| async move {
ws.send(Message::Text(
r#"{"type":"error","message":"signal not allowed"}"#.into(),
))
.await
.unwrap();
ws.send(Message::Text(r#"{"type":"exit","exit_code":0}"#.into()))
.await
.unwrap();
let _ = ws.close(None).await;
})
.await;
});
let client = client_for(port);
let (stdout_tx, _stdout_rx) = mpsc::unbounded_channel::<String>();
let (stderr_tx, _stderr_rx) = mpsc::unbounded_channel::<String>();
let (_stdin_tx, stdin_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (result_tx, mut result_rx) = mpsc::unbounded_channel::<ExecResult>();
let attach = tokio::spawn(async move {
attach_ws(
&client, "box1", "exec1", stdin_rx, stdout_tx, stderr_tx, result_tx,
)
.await;
});
let res = tokio::time::timeout(Duration::from_secs(3), result_rx.recv())
.await
.expect("result channel timed out")
.expect("result channel closed without value");
assert_eq!(
res.exit_code, 0,
"informational error frame must not terminate the attach"
);
assert!(res.error_message.is_none());
attach.await.unwrap();
server.abort();
}
}