use axum::{
extract::{
ws::{Message, WebSocket},
Path, State, WebSocketUpgrade,
},
response::IntoResponse,
};
use std::sync::Arc;
use tmai_core::api::TmaiCore;
/// Axum route handler that upgrades the request to a WebSocket attached to PTY session `id`.
///
/// The upgrade is always accepted; looking the session up (and bailing out if it does not
/// exist) happens afterwards inside `handle_ws`.
pub async fn ws_terminal(
    ws: WebSocketUpgrade,
    Path(id): Path<String>,
    State(core): State<Arc<TmaiCore>>,
) -> impl IntoResponse {
    let session_id = id;
    ws.on_upgrade(move |socket| async move { handle_ws(socket, session_id, core).await })
}
async fn handle_ws(socket: WebSocket, session_id: String, core: Arc<TmaiCore>) {
let session = match core.pty_registry().get(&session_id) {
Some(s) => s,
None => {
tracing::warn!("WS: PTY session not found: {}", session_id);
return;
}
};
tracing::debug!("WS: connected to PTY session {}", session_id);
let mut output_rx = session.subscribe();
let (mut ws_tx, mut ws_rx) = socket.split();
use futures_util::{SinkExt, StreamExt};
let snapshot = session.scrollback_snapshot();
if !snapshot.is_empty() {
tracing::debug!(
"WS: replaying {} bytes of scrollback for session {}",
snapshot.len(),
session_id
);
if ws_tx
.send(Message::Binary(snapshot.to_vec().into()))
.await
.is_err()
{
return;
}
}
loop {
tokio::select! {
result = output_rx.recv() => {
match result {
Ok(data) => {
if ws_tx.send(Message::Binary(data.to_vec().into())).await.is_err() {
break; }
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
tracing::debug!("WS: lagged {} messages for session {}", n, session_id);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
tracing::debug!("WS: PTY output channel closed for session {}", session_id);
break;
}
}
}
result = ws_rx.next() => {
match result {
Some(Ok(Message::Binary(data))) => {
if let Err(e) = session.write_input(&data) {
tracing::warn!("WS: PTY write error: {}", e);
break;
}
}
Some(Ok(Message::Text(text))) => {
if let Err(e) = handle_control_message(&text, &session) {
tracing::debug!("WS: control message error: {}", e);
}
}
Some(Ok(Message::Close(_))) | None => {
break; }
Some(Ok(Message::Ping(data))) => {
let _ = ws_tx.send(Message::Pong(data)).await;
}
Some(Ok(_)) => {} Some(Err(e)) => {
tracing::debug!("WS: receive error: {}", e);
break;
}
}
}
}
}
tracing::debug!("WS: disconnected from PTY session {}", session_id);
}
/// JSON control message a client sends as a WebSocket text frame.
///
/// Shape implied by the serde attributes below, e.g.
/// `{"type": "resize", "cols": 120, "rows": 40}`.
/// The field names and attributes define the wire format; changing them changes the protocol.
#[derive(serde::Deserialize)]
struct ControlMessage {
/// Message kind, read from the JSON key `"type"`. It has no default, so a message without
/// it fails to deserialize.
#[serde(rename = "type")]
msg_type: String,
/// Column count used by `"resize"`; `0` when the key is absent.
#[serde(default)]
cols: u16,
/// Row count used by `"resize"`; `0` when the key is absent.
#[serde(default)]
rows: u16,
}
fn handle_control_message(text: &str, session: &tmai_core::pty::PtySession) -> anyhow::Result<()> {
let msg: ControlMessage = serde_json::from_str(text)?;
match msg.msg_type.as_str() {
"resize" => {
if msg.cols > 0 && msg.rows > 0 {
session.resize(msg.rows, msg.cols)?;
tracing::debug!("WS: resized PTY to {}x{}", msg.cols, msg.rows);
}
}
other => {
tracing::debug!("WS: unknown control message type: {}", other);
}
}
Ok(())
}