use axum::{
extract::{
Query, State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
http::StatusCode,
response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
use serde::Deserialize;
use std::io::{Read, Write};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use crate::api::AppState;
/// Query-string parameters accepted by the terminal endpoints.
///
/// The token is modelled as optional so each handler can decide how to react when it is
/// absent (both handlers in this module answer `401 Unauthorized`).
#[derive(Deserialize, Debug)]
pub struct WsAuthQuery {
    /// Access token, validated with `state.jwt.validate_access_token`.
    pub token: Option<String>,
}
/// Authenticates the request and upgrades it to an interactive terminal WebSocket.
///
/// The access token is read from the `token` query parameter and checked with
/// `state.jwt.validate_access_token`; on success the connection is handed to
/// `handle_terminal`.
///
/// # Errors
///
/// Returns `401 Unauthorized` with a short message when the token is missing or rejected.
pub async fn terminal_ws(
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
    Query(query): Query<WsAuthQuery>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let token = match query.token {
        Some(token) => token,
        None => {
            return Err((
                StatusCode::UNAUTHORIZED,
                "Missing token parameter".to_string(),
            ))
        }
    };

    // Only validity matters here; the decoded claims are not used further.
    if let Err(err) = state.jwt.validate_access_token(&token) {
        return Err((StatusCode::UNAUTHORIZED, format!("Invalid token: {}", err)));
    }

    Ok(ws.on_upgrade(handle_terminal))
}
async fn handle_terminal(socket: WebSocket) {
let (mut ws_sender, mut ws_receiver) = socket.split();
let pty_system = native_pty_system();
let pair = match pty_system.openpty(PtySize {
rows: 24,
cols: 80,
pixel_width: 0,
pixel_height: 0,
}) {
Ok(pair) => pair,
Err(e) => {
error!("Failed to create PTY: {}", e);
let _ = ws_sender
.send(Message::Text(format!(
"\r\nError: Failed to create PTY: {}\r\n",
e
)))
.await;
return;
}
};
let shell = if cfg!(target_os = "windows") {
"powershell.exe".to_string()
} else {
std::env::var("SHELL").unwrap_or_else(|_| "/bin/bash".to_string())
};
info!("Starting PTY terminal session with shell: {}", shell);
let mut cmd = CommandBuilder::new(&shell);
cmd.env("TERM", "xterm-256color");
let mut child = match pair.slave.spawn_command(cmd) {
Ok(child) => child,
Err(e) => {
error!("Failed to spawn shell: {}", e);
let _ = ws_sender
.send(Message::Text(format!(
"\r\nError: Failed to spawn shell: {}\r\n",
e
)))
.await;
return;
}
};
let mut reader = match pair.master.try_clone_reader() {
Ok(r) => r,
Err(e) => {
error!("Failed to get PTY reader: {}", e);
return;
}
};
let pty_writer: Box<dyn Write + Send> = match pair.master.take_writer() {
Ok(w) => w,
Err(e) => {
error!("Failed to get PTY writer: {}", e);
return;
}
};
let master = Arc::new(std::sync::Mutex::new(pair.master));
let writer = Arc::new(std::sync::Mutex::new(pty_writer));
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(256);
let read_handle = std::thread::spawn(move || {
let mut buffer = [0u8; 4096];
loop {
match reader.read(&mut buffer) {
Ok(0) => {
break;
}
Ok(n) => {
if tx.blocking_send(buffer[..n].to_vec()).is_err() {
break;
}
}
Err(e) => {
if e.kind() != std::io::ErrorKind::Other {
warn!("PTY read error: {}", e);
}
break;
}
}
}
});
let send_task = tokio::spawn(async move {
while let Some(data) = rx.recv().await {
let text = String::from_utf8_lossy(&data).to_string();
if ws_sender.send(Message::Text(text)).await.is_err() {
break;
}
}
});
let master_clone = master.clone();
let writer_clone = writer.clone();
let recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = ws_receiver.next().await {
match msg {
Message::Text(text) => {
if text.starts_with("\x1b[8;") {
if let Some(size) = parse_resize_sequence(&text) {
if let Ok(master) = master_clone.lock() {
let _ = master.resize(size);
}
continue;
}
}
if let Ok(mut pty_writer) = writer_clone.lock() {
if pty_writer.write_all(text.as_bytes()).is_err() {
break;
}
}
}
Message::Binary(data) => {
if let Ok(mut pty_writer) = writer_clone.lock() {
if pty_writer.write_all(&data).is_err() {
break;
}
}
}
Message::Close(_) => break,
_ => {}
}
}
});
tokio::select! {
_ = send_task => {
info!("Send task ended");
}
_ = recv_task => {
info!("Recv task ended");
}
_ = tokio::task::spawn_blocking(move || {
let _ = child.wait();
}) => {
info!("Shell process ended");
}
}
drop(writer);
drop(master);
let _ = read_handle.join();
info!("Terminal session ended");
}
/// Parses a resize request of the exact form `ESC [ 8 ; <rows> ; <cols> t`.
///
/// The whole string must be that sequence, with exactly two `;`-separated numeric
/// parameters that each fit in a `u16`. Anything else yields `None`. Pixel dimensions are
/// always reported as 0.
fn parse_resize_sequence(s: &str) -> Option<PtySize> {
    let params = s.strip_prefix("\x1b[8;")?.strip_suffix('t')?;
    // A third parameter stays attached to `cols` and makes its parse fail.
    let (rows, cols) = params.split_once(';')?;
    Some(PtySize {
        rows: rows.parse::<u16>().ok()?,
        cols: cols.parse::<u16>().ok()?,
        pixel_width: 0,
        pixel_height: 0,
    })
}
/// Reports whether the terminal feature is available, and which shell it would launch.
///
/// The caller must pass a valid access token in the `token` query parameter. The response
/// is a JSON object with `available` (always `true`), `shell` and `features`. On Windows,
/// `shell` is always `powershell.exe`; elsewhere it is `$SHELL`, falling back to
/// `/bin/bash`.
///
/// # Errors
///
/// Returns `401 Unauthorized` with a short message when the token is missing or rejected.
pub async fn terminal_info(
    State(state): State<AppState>,
    Query(query): Query<WsAuthQuery>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let token = match query.token {
        Some(token) => token,
        None => {
            return Err((
                StatusCode::UNAUTHORIZED,
                "Missing token parameter".to_string(),
            ))
        }
    };
    if let Err(err) = state.jwt.validate_access_token(&token) {
        return Err((StatusCode::UNAUTHORIZED, format!("Invalid token: {}", err)));
    }

    let shell = if cfg!(not(target_os = "windows")) {
        std::env::var("SHELL").unwrap_or_else(|_| String::from("/bin/bash"))
    } else {
        String::from("powershell.exe")
    };

    let body = serde_json::json!({
        "available": true,
        "shell": shell,
        "features": ["pty", "resize", "colors"],
    });
    Ok(axum::Json(body))
}