use crate::cc_tasks::{CCTask, CCTaskChangeEvent, CCTasksOverview, CCTasksWatcher, WatcherEvent};
use crate::pty::{PTYManager, SessionEvent, SessionState};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex as StdMutex};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, Mutex};
use tokio_tungstenite::tungstenite::handshake::server::{Request as WsRequest, Response as WsResponse};
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tokio_tungstenite::{accept_hdr_async, tungstenite::Message};
use tracing::{error, info, warn};
pub struct WSServerOptions {
pub port: u16,
pub pty_manager: Option<Arc<PTYManager>>,
pub cc_tasks_watcher: Option<Arc<Mutex<CCTasksWatcher>>>,
}
pub struct PTYWebSocketServer {
port: u16,
pty_manager: Option<Arc<PTYManager>>,
cc_tasks_watcher: Option<Arc<Mutex<CCTasksWatcher>>>,
shutdown_tx: Option<broadcast::Sender<()>>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Route<'a> {
Pty { slot_id: &'a str },
Tasks,
Invalid,
}
fn parse_route(path: &str) -> Route<'_> {
if path == "/tasks" {
return Route::Tasks;
}
if let Some(slot_id) = path.strip_prefix("/pty/") {
if !slot_id.is_empty() && !slot_id.contains('/') {
return Route::Pty { slot_id };
}
}
Route::Invalid
}
fn close_frame(code: u16, reason: impl Into<String>) -> CloseFrame<'static> {
CloseFrame {
code: CloseCode::from(code),
reason: reason.into().into(),
}
}
async fn send_json<S: Serialize>(
ws_tx: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<TcpStream>, Message>,
msg: &S,
) -> anyhow::Result<()> {
let text = serde_json::to_string(msg)?;
ws_tx.send(Message::Text(text)).await?;
Ok(())
}
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "lowercase")]
enum PtyOutMessage {
Screen { data: String },
Data { data: String },
State {
state: SessionState,
#[serde(rename = "prevState")]
prev_state: SessionState,
},
Exit { code: i32 },
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
enum PtyInMessage {
Input { data: String },
}
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum TasksEventMessage {
CcTasksOverview { payload: CCTasksOverview },
CcTasksChanged { payload: CCTaskChangeEvent },
CcTaskStarted { payload: TaskEventPayload },
CcTaskCompleted { payload: TaskEventPayload },
CcSessionActive { payload: SessionEventPayload },
CcSessionInactive { payload: SessionEventPayload },
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct TaskEventPayload {
session_id: String,
project_name: String,
task: CCTask,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct SessionEventPayload {
session_id: String,
project_name: String,
#[serde(skip_serializing_if = "Option::is_none")]
summary: Option<String>,
}
impl PTYWebSocketServer {
pub fn new(options: WSServerOptions) -> Self {
Self {
port: options.port,
pty_manager: options.pty_manager,
cc_tasks_watcher: options.cc_tasks_watcher,
shutdown_tx: None,
}
}
pub async fn start(&mut self) -> anyhow::Result<()> {
let addr = format!("0.0.0.0:{}", self.port);
let listener = TcpListener::bind(&addr).await?;
info!(port = self.port, "PTY WebSocket server started");
let (shutdown_tx, _) = broadcast::channel::<()>(1);
self.shutdown_tx = Some(shutdown_tx.clone());
let pty_manager = self.pty_manager.clone();
let cc_tasks_watcher = self.cc_tasks_watcher.clone();
tokio::spawn(async move {
let mut shutdown_rx = shutdown_tx.subscribe();
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, addr)) => {
let pty_manager = pty_manager.clone();
let cc_tasks_watcher = cc_tasks_watcher.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(stream, addr, pty_manager, cc_tasks_watcher).await {
error!(?e, ?addr, "WebSocket connection error");
}
});
}
Err(e) => {
error!(?e, "Failed to accept connection");
}
}
}
_ = shutdown_rx.recv() => {
info!("WebSocket server shutting down");
break;
}
}
}
});
Ok(())
}
pub async fn stop(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
info!("PTY WebSocket server stopped");
}
async fn handle_connection(
stream: TcpStream,
addr: SocketAddr,
pty_manager: Option<Arc<PTYManager>>,
cc_tasks_watcher: Option<Arc<Mutex<CCTasksWatcher>>>,
) -> anyhow::Result<()> {
let path_cell = Arc::new(StdMutex::new(String::new()));
let path_cell2 = Arc::clone(&path_cell);
let ws_stream = accept_hdr_async(stream, move |req: &WsRequest, resp: WsResponse| {
if let Ok(mut path) = path_cell2.lock() {
*path = req.uri().path().to_string();
}
Ok(resp)
})
.await?;
let path = path_cell
.lock()
.map(|p| p.clone())
.unwrap_or_else(|_| "/".to_string());
match parse_route(&path) {
Route::Tasks => Self::handle_tasks_subscription(addr, ws_stream, cc_tasks_watcher).await,
Route::Pty { slot_id } => {
Self::handle_pty_subscription(addr, ws_stream, pty_manager, slot_id).await
}
Route::Invalid => {
let (mut ws_tx, _ws_rx) = ws_stream.split();
let _ = ws_tx
.send(Message::Close(Some(close_frame(
4000,
"Invalid URL. Use /pty/<slotId> or /tasks",
))))
.await;
warn!(?addr, %path, "Invalid WebSocket URL");
Ok(())
}
}
}
async fn handle_pty_subscription(
addr: SocketAddr,
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
pty_manager: Option<Arc<PTYManager>>,
slot_id: &str,
) -> anyhow::Result<()> {
let pty_manager = match pty_manager {
Some(pm) => pm,
None => {
let (mut ws_tx, _ws_rx) = ws_stream.split();
let _ = ws_tx
.send(Message::Close(Some(close_frame(
4000,
"PTY manager not available",
))))
.await;
warn!(?addr, "PTY manager not available");
return Ok(());
}
};
let status = pty_manager.get_status(slot_id).await;
if status
.as_ref()
.map(|s| s.state == SessionState::Exited)
.unwrap_or(true)
{
let (mut ws_tx, _ws_rx) = ws_stream.split();
let _ = ws_tx
.send(Message::Close(Some(close_frame(
4001,
format!("PTY session not found: {}", slot_id),
))))
.await;
warn!(?addr, slot_id, "PTY session not found");
return Ok(());
}
let (mut ws_tx, mut ws_rx) = ws_stream.split();
info!(?addr, slot_id, "Client attached to PTY");
if let Ok(screen) = pty_manager.get_screen(slot_id).await {
let msg = PtyOutMessage::Screen { data: screen };
let _ = send_json(&mut ws_tx, &msg).await;
}
let mut session_rx = match pty_manager.subscribe_session(slot_id).await {
Ok(rx) => rx,
Err(e) => {
warn!(?addr, slot_id, error = %e, "Cannot subscribe to PTY events");
let _ = ws_tx
.send(Message::Close(Some(close_frame(
4002,
format!("Cannot attach to PTY: {}", slot_id),
))))
.await;
return Ok(());
}
};
loop {
tokio::select! {
evt = session_rx.recv() => {
let evt = match evt {
Ok(e) => e,
Err(_) => break,
};
match evt {
SessionEvent::Data(bytes) => {
let data = String::from_utf8_lossy(&bytes).to_string();
let msg = PtyOutMessage::Data { data };
if send_json(&mut ws_tx, &msg).await.is_err() {
break;
}
}
SessionEvent::StateChange { new_state, prev_state } => {
let msg = PtyOutMessage::State { state: new_state, prev_state };
if send_json(&mut ws_tx, &msg).await.is_err() {
break;
}
}
SessionEvent::Exit(code) => {
let msg = PtyOutMessage::Exit { code };
let _ = send_json(&mut ws_tx, &msg).await;
let _ = ws_tx.send(Message::Close(Some(close_frame(4003, format!("PTY exited with code {}", code))))).await;
break;
}
_ => {}
}
}
msg = ws_rx.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
if let Ok(input) = serde_json::from_str::<PtyInMessage>(&text) {
match input {
PtyInMessage::Input { data } => {
let _ = pty_manager.write(slot_id, &data).await;
}
}
} else {
let _ = pty_manager.write(slot_id, &text).await;
}
}
Some(Ok(Message::Binary(data))) => {
let text = String::from_utf8_lossy(&data).to_string();
let _ = pty_manager.write(slot_id, &text).await;
}
Some(Ok(Message::Close(_))) | None => break,
Some(Err(e)) => {
warn!(?addr, slot_id, error = %e, "WebSocket error");
break;
}
_ => {}
}
}
}
}
info!(?addr, slot_id, "Client disconnected from PTY");
Ok(())
}
async fn handle_tasks_subscription(
addr: SocketAddr,
ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
cc_tasks_watcher: Option<Arc<Mutex<CCTasksWatcher>>>,
) -> anyhow::Result<()> {
let watcher = match cc_tasks_watcher {
Some(w) => w,
None => {
let (mut ws_tx, _ws_rx) = ws_stream.split();
let _ = ws_tx
.send(Message::Close(Some(close_frame(
4000,
"CC Tasks watcher not available",
))))
.await;
warn!(?addr, "CC Tasks watcher not available");
return Ok(());
}
};
let (mut ws_tx, mut ws_rx) = ws_stream.split();
info!(?addr, "Client subscribing to Tasks events");
let overview = watcher.lock().await.get_overview().await;
let msg = TasksEventMessage::CcTasksOverview { payload: overview };
let _ = send_json(&mut ws_tx, &msg).await;
let mut events_rx = watcher.lock().await.subscribe();
loop {
tokio::select! {
event = events_rx.recv() => {
let event = match event {
Ok(e) => e,
Err(_) => break,
};
let msg = match event {
WatcherEvent::TasksChanged(e) => TasksEventMessage::CcTasksChanged { payload: e },
WatcherEvent::TaskStarted { session, task } => TasksEventMessage::CcTaskStarted {
payload: TaskEventPayload {
session_id: session.session_id,
project_name: session.project_name,
task,
}
},
WatcherEvent::TaskCompleted { session, task } => TasksEventMessage::CcTaskCompleted {
payload: TaskEventPayload {
session_id: session.session_id,
project_name: session.project_name,
task,
}
},
WatcherEvent::SessionActive(session) => TasksEventMessage::CcSessionActive {
payload: SessionEventPayload {
session_id: session.session_id,
project_name: session.project_name,
summary: Some(session.summary),
}
},
WatcherEvent::SessionInactive(session) => TasksEventMessage::CcSessionInactive {
payload: SessionEventPayload {
session_id: session.session_id,
project_name: session.project_name,
summary: None,
}
},
};
if send_json(&mut ws_tx, &msg).await.is_err() {
break;
}
}
msg = ws_rx.next() => {
match msg {
Some(Ok(Message::Close(_))) | None => break,
Some(Err(e)) => {
warn!(?addr, error = %e, "WebSocket error");
break;
}
_ => {}
}
}
}
}
info!(?addr, "Client unsubscribed from Tasks events");
Ok(())
}
}