1use crate::cc_tasks::{CCTask, CCTaskChangeEvent, CCTasksOverview, CCTasksWatcher, WatcherEvent};
14use crate::pty::{PTYManager, SessionEvent, SessionState};
15use futures_util::{SinkExt, StreamExt};
16use serde::{Deserialize, Serialize};
17use std::net::SocketAddr;
18use std::sync::{Arc, Mutex as StdMutex};
19use tokio::net::{TcpListener, TcpStream};
20use tokio::sync::{broadcast, Mutex};
21use tokio_tungstenite::tungstenite::handshake::server::{Request as WsRequest, Response as WsResponse};
22use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
23use tokio_tungstenite::tungstenite::protocol::CloseFrame;
24use tokio_tungstenite::{accept_hdr_async, tungstenite::Message};
25use tracing::{error, info, warn};
26
27pub struct WSServerOptions {
29 pub port: u16,
31 pub pty_manager: Option<Arc<PTYManager>>,
33 pub cc_tasks_watcher: Option<Arc<Mutex<CCTasksWatcher>>>,
35}
36
37pub struct PTYWebSocketServer {
39 port: u16,
40 pty_manager: Option<Arc<PTYManager>>,
41 cc_tasks_watcher: Option<Arc<Mutex<CCTasksWatcher>>>,
42 shutdown_tx: Option<broadcast::Sender<()>>,
43}
44
/// Where an incoming WebSocket connection should be dispatched, decided from
/// the request path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Route<'a> {
    /// `/pty/<slot_id>`: attach to a PTY session. `slot_id` borrows from the path.
    Pty { slot_id: &'a str },
    /// `/tasks`: subscribe to task events.
    Tasks,
    /// Any path not matching the two shapes above.
    Invalid,
}
51
52fn parse_route(path: &str) -> Route<'_> {
53 if path == "/tasks" {
54 return Route::Tasks;
55 }
56 if let Some(slot_id) = path.strip_prefix("/pty/") {
57 if !slot_id.is_empty() && !slot_id.contains('/') {
58 return Route::Pty { slot_id };
59 }
60 }
61 Route::Invalid
62}
63
64fn close_frame(code: u16, reason: impl Into<String>) -> CloseFrame<'static> {
65 CloseFrame {
66 code: CloseCode::from(code),
67 reason: reason.into().into(),
68 }
69}
70
71async fn send_json<S: Serialize>(
72 ws_tx: &mut futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<TcpStream>, Message>,
73 msg: &S,
74) -> anyhow::Result<()> {
75 let text = serde_json::to_string(msg)?;
76 ws_tx.send(Message::Text(text)).await?;
77 Ok(())
78}
79
80#[derive(Debug, Serialize)]
82#[serde(tag = "type", rename_all = "lowercase")]
83enum PtyOutMessage {
84 Screen { data: String },
85 Data { data: String },
86 State {
87 state: SessionState,
88 #[serde(rename = "prevState")]
89 prev_state: SessionState,
90 },
91 Exit { code: i32 },
92}
93
94#[derive(Debug, Deserialize)]
96#[serde(tag = "type", rename_all = "lowercase")]
97enum PtyInMessage {
98 Input { data: String },
99}
100
101#[derive(Debug, Serialize)]
103#[serde(tag = "type", rename_all = "snake_case")]
104enum TasksEventMessage {
105 CcTasksOverview { payload: CCTasksOverview },
106 CcTasksChanged { payload: CCTaskChangeEvent },
107 CcTaskStarted { payload: TaskEventPayload },
108 CcTaskCompleted { payload: TaskEventPayload },
109 CcSessionActive { payload: SessionEventPayload },
110 CcSessionInactive { payload: SessionEventPayload },
111}
112
113#[derive(Debug, Serialize)]
114#[serde(rename_all = "camelCase")]
115struct TaskEventPayload {
116 session_id: String,
117 project_name: String,
118 task: CCTask,
119}
120
121#[derive(Debug, Serialize)]
122#[serde(rename_all = "camelCase")]
123struct SessionEventPayload {
124 session_id: String,
125 project_name: String,
126 #[serde(skip_serializing_if = "Option::is_none")]
127 summary: Option<String>,
128}
129
130impl PTYWebSocketServer {
131 pub fn new(options: WSServerOptions) -> Self {
133 Self {
134 port: options.port,
135 pty_manager: options.pty_manager,
136 cc_tasks_watcher: options.cc_tasks_watcher,
137 shutdown_tx: None,
138 }
139 }
140
141 pub async fn start(&mut self) -> anyhow::Result<()> {
143 let addr = format!("0.0.0.0:{}", self.port);
144 let listener = TcpListener::bind(&addr).await?;
145
146 info!(port = self.port, "PTY WebSocket server started");
147
148 let (shutdown_tx, _) = broadcast::channel::<()>(1);
149 self.shutdown_tx = Some(shutdown_tx.clone());
150
151 let pty_manager = self.pty_manager.clone();
152 let cc_tasks_watcher = self.cc_tasks_watcher.clone();
153
154 tokio::spawn(async move {
155 let mut shutdown_rx = shutdown_tx.subscribe();
156 loop {
157 tokio::select! {
158 result = listener.accept() => {
159 match result {
160 Ok((stream, addr)) => {
161 let pty_manager = pty_manager.clone();
162 let cc_tasks_watcher = cc_tasks_watcher.clone();
163 tokio::spawn(async move {
164 if let Err(e) = Self::handle_connection(stream, addr, pty_manager, cc_tasks_watcher).await {
165 error!(?e, ?addr, "WebSocket connection error");
166 }
167 });
168 }
169 Err(e) => {
170 error!(?e, "Failed to accept connection");
171 }
172 }
173 }
174 _ = shutdown_rx.recv() => {
175 info!("WebSocket server shutting down");
176 break;
177 }
178 }
179 }
180 });
181
182 Ok(())
183 }
184
185 pub async fn stop(&mut self) {
187 if let Some(tx) = self.shutdown_tx.take() {
188 let _ = tx.send(());
189 }
190 info!("PTY WebSocket server stopped");
191 }
192
193 async fn handle_connection(
194 stream: TcpStream,
195 addr: SocketAddr,
196 pty_manager: Option<Arc<PTYManager>>,
197 cc_tasks_watcher: Option<Arc<Mutex<CCTasksWatcher>>>,
198 ) -> anyhow::Result<()> {
199 let path_cell = Arc::new(StdMutex::new(String::new()));
201 let path_cell2 = Arc::clone(&path_cell);
202
203 let ws_stream = accept_hdr_async(stream, move |req: &WsRequest, resp: WsResponse| {
204 if let Ok(mut path) = path_cell2.lock() {
205 *path = req.uri().path().to_string();
206 }
207 Ok(resp)
208 })
209 .await?;
210
211 let path = path_cell
212 .lock()
213 .map(|p| p.clone())
214 .unwrap_or_else(|_| "/".to_string());
215
216 match parse_route(&path) {
217 Route::Tasks => Self::handle_tasks_subscription(addr, ws_stream, cc_tasks_watcher).await,
218 Route::Pty { slot_id } => {
219 Self::handle_pty_subscription(addr, ws_stream, pty_manager, slot_id).await
220 }
221 Route::Invalid => {
222 let (mut ws_tx, _ws_rx) = ws_stream.split();
223 let _ = ws_tx
224 .send(Message::Close(Some(close_frame(
225 4000,
226 "Invalid URL. Use /pty/<slotId> or /tasks",
227 ))))
228 .await;
229 warn!(?addr, %path, "Invalid WebSocket URL");
230 Ok(())
231 }
232 }
233 }
234
235 async fn handle_pty_subscription(
236 addr: SocketAddr,
237 ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
238 pty_manager: Option<Arc<PTYManager>>,
239 slot_id: &str,
240 ) -> anyhow::Result<()> {
241 let pty_manager = match pty_manager {
242 Some(pm) => pm,
243 None => {
244 let (mut ws_tx, _ws_rx) = ws_stream.split();
245 let _ = ws_tx
246 .send(Message::Close(Some(close_frame(
247 4000,
248 "PTY manager not available",
249 ))))
250 .await;
251 warn!(?addr, "PTY manager not available");
252 return Ok(());
253 }
254 };
255
256 let status = pty_manager.get_status(slot_id).await;
258 if status
259 .as_ref()
260 .map(|s| s.state == SessionState::Exited)
261 .unwrap_or(true)
262 {
263 let (mut ws_tx, _ws_rx) = ws_stream.split();
264 let _ = ws_tx
265 .send(Message::Close(Some(close_frame(
266 4001,
267 format!("PTY session not found: {}", slot_id),
268 ))))
269 .await;
270 warn!(?addr, slot_id, "PTY session not found");
271 return Ok(());
272 }
273
274 let (mut ws_tx, mut ws_rx) = ws_stream.split();
275
276 info!(?addr, slot_id, "Client attached to PTY");
277
278 if let Ok(screen) = pty_manager.get_screen(slot_id).await {
280 let msg = PtyOutMessage::Screen { data: screen };
281 let _ = send_json(&mut ws_tx, &msg).await;
282 }
283
284 let mut session_rx = match pty_manager.subscribe_session(slot_id).await {
286 Ok(rx) => rx,
287 Err(e) => {
288 warn!(?addr, slot_id, error = %e, "Cannot subscribe to PTY events");
289 let _ = ws_tx
290 .send(Message::Close(Some(close_frame(
291 4002,
292 format!("Cannot attach to PTY: {}", slot_id),
293 ))))
294 .await;
295 return Ok(());
296 }
297 };
298
299 loop {
300 tokio::select! {
301 evt = session_rx.recv() => {
303 let evt = match evt {
304 Ok(e) => e,
305 Err(_) => break,
306 };
307
308 match evt {
309 SessionEvent::Data(bytes) => {
310 let data = String::from_utf8_lossy(&bytes).to_string();
311 let msg = PtyOutMessage::Data { data };
312 if send_json(&mut ws_tx, &msg).await.is_err() {
313 break;
314 }
315 }
316 SessionEvent::StateChange { new_state, prev_state } => {
317 let msg = PtyOutMessage::State { state: new_state, prev_state };
318 if send_json(&mut ws_tx, &msg).await.is_err() {
319 break;
320 }
321 }
322 SessionEvent::Exit(code) => {
323 let msg = PtyOutMessage::Exit { code };
324 let _ = send_json(&mut ws_tx, &msg).await;
325 let _ = ws_tx.send(Message::Close(Some(close_frame(4003, format!("PTY exited with code {}", code))))).await;
326 break;
327 }
328 _ => {}
329 }
330 }
331
332 msg = ws_rx.next() => {
334 match msg {
335 Some(Ok(Message::Text(text))) => {
336 if let Ok(input) = serde_json::from_str::<PtyInMessage>(&text) {
337 match input {
338 PtyInMessage::Input { data } => {
339 let _ = pty_manager.write(slot_id, &data).await;
340 }
341 }
342 } else {
343 let _ = pty_manager.write(slot_id, &text).await;
345 }
346 }
347 Some(Ok(Message::Binary(data))) => {
348 let text = String::from_utf8_lossy(&data).to_string();
349 let _ = pty_manager.write(slot_id, &text).await;
350 }
351 Some(Ok(Message::Close(_))) | None => break,
352 Some(Err(e)) => {
353 warn!(?addr, slot_id, error = %e, "WebSocket error");
354 break;
355 }
356 _ => {}
357 }
358 }
359 }
360 }
361
362 info!(?addr, slot_id, "Client disconnected from PTY");
363 Ok(())
364 }
365
366 async fn handle_tasks_subscription(
367 addr: SocketAddr,
368 ws_stream: tokio_tungstenite::WebSocketStream<TcpStream>,
369 cc_tasks_watcher: Option<Arc<Mutex<CCTasksWatcher>>>,
370 ) -> anyhow::Result<()> {
371 let watcher = match cc_tasks_watcher {
372 Some(w) => w,
373 None => {
374 let (mut ws_tx, _ws_rx) = ws_stream.split();
375 let _ = ws_tx
376 .send(Message::Close(Some(close_frame(
377 4000,
378 "CC Tasks watcher not available",
379 ))))
380 .await;
381 warn!(?addr, "CC Tasks watcher not available");
382 return Ok(());
383 }
384 };
385
386 let (mut ws_tx, mut ws_rx) = ws_stream.split();
387
388 info!(?addr, "Client subscribing to Tasks events");
389
390 let overview = watcher.lock().await.get_overview().await;
392 let msg = TasksEventMessage::CcTasksOverview { payload: overview };
393 let _ = send_json(&mut ws_tx, &msg).await;
394
395 let mut events_rx = watcher.lock().await.subscribe();
397
398 loop {
399 tokio::select! {
400 event = events_rx.recv() => {
401 let event = match event {
402 Ok(e) => e,
403 Err(_) => break,
404 };
405
406 let msg = match event {
407 WatcherEvent::TasksChanged(e) => TasksEventMessage::CcTasksChanged { payload: e },
408 WatcherEvent::TaskStarted { session, task } => TasksEventMessage::CcTaskStarted {
409 payload: TaskEventPayload {
410 session_id: session.session_id,
411 project_name: session.project_name,
412 task,
413 }
414 },
415 WatcherEvent::TaskCompleted { session, task } => TasksEventMessage::CcTaskCompleted {
416 payload: TaskEventPayload {
417 session_id: session.session_id,
418 project_name: session.project_name,
419 task,
420 }
421 },
422 WatcherEvent::SessionActive(session) => TasksEventMessage::CcSessionActive {
423 payload: SessionEventPayload {
424 session_id: session.session_id,
425 project_name: session.project_name,
426 summary: Some(session.summary),
427 }
428 },
429 WatcherEvent::SessionInactive(session) => TasksEventMessage::CcSessionInactive {
430 payload: SessionEventPayload {
431 session_id: session.session_id,
432 project_name: session.project_name,
433 summary: None,
434 }
435 },
436 };
437
438 if send_json(&mut ws_tx, &msg).await.is_err() {
439 break;
440 }
441 }
442
443 msg = ws_rx.next() => {
444 match msg {
445 Some(Ok(Message::Close(_))) | None => break,
446 Some(Err(e)) => {
447 warn!(?addr, error = %e, "WebSocket error");
448 break;
449 }
450 _ => {}
451 }
452 }
453 }
454 }
455
456 info!(?addr, "Client unsubscribed from Tasks events");
457 Ok(())
458 }
459}