Skip to main content

cvkg_cli/
ws_server.rs

1//! WebSocket Server
2//! Multiplexed WebSocket server for runtime communication, DevTools, hot reload, and agent streams
3
4use axum::{
5    Router,
6    extract::State,
7    extract::ws::{Message, WebSocket, WebSocketUpgrade},
8    response::IntoResponse,
9    routing::get,
10};
11use futures_util::StreamExt;
12use std::net::SocketAddr;
13use std::sync::Arc;
14use tokio::sync::broadcast;
15use tracing::{debug, error, info, warn};
16
17use serde::{Deserialize, Serialize};
18
19use crate::patch_engine::{PatchEngine, RuntimePatch};
20
21/// Shared application state for the WebSocket server
22#[derive(Clone)]
23pub struct AppState {
24    pub patch_tx: broadcast::Sender<WsMessage>,
25    pub patch_engine: Arc<std::sync::Mutex<PatchEngine>>,
26}
27
28/// WebSocket message protocol between CLI dev server and connected clients.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(tag = "type", rename_all = "snake_case")]
31pub enum WsMessage {
32    /// Apply a hot-reload patch to the running application.
33    Patch(RuntimePatch),
34    /// Full state snapshot request/response.
35    State(crate::dev_runtime::RuntimeStateSnapshot),
36    /// Agent or runtime event.
37    Event(crate::dev_runtime::RuntimeEvent),
38    /// DevTools message (bidirectional).
39    Devtools(DevtoolsMessage),
40    /// Handshake response sent to new clients.
41    Handshake {
42        client: String,
43        capabilities: Vec<String>,
44    },
45}
46
47/// DevTool command types (bidirectional: client → server and server → client).
48#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(untagged)]
50pub enum DevtoolsMessage {
51    /// Client-side command.
52    Command(DevtoolsCommand),
53    /// Server-side response/event.
54    Response(serde_json::Value),
55}
56
57/// DevTools command types (client → server).
58#[derive(Debug, Clone, Serialize, Deserialize)]
59#[serde(tag = "command", rename_all = "snake_case")]
60pub enum DevtoolsCommand {
61    /// Request current performance metrics.
62    QueryMetrics,
63    /// Toggle the error overlay.
64    ToggleOverlay { show: bool },
65    /// Request the current scene graph.
66    QueryGraph,
67    /// Query accessibility properties for a given component path.
68    QueryAccessibility {
69        /// Dot-separated component path (e.g., "root.main.content.button-1").
70        path: String,
71    },
72    /// Echo for health checking.
73    Ping,
74}
75
76/// WebSocket handler for runtime communication
77async fn runtime_ws(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
78    ws.on_upgrade(move |socket| handle_runtime_socket(socket, state))
79}
80
81/// WebSocket handler for DevTools
82async fn devtools_ws(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
83    ws.on_upgrade(move |socket| handle_devtools_socket(socket, state))
84}
85
86/// WebSocket handler for hot reload
87async fn hotreload_ws(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
88    ws.on_upgrade(move |socket| handle_hotreload_socket(socket, state))
89}
90
91/// WebSocket handler for agent streams
92async fn agent_ws(State(state): State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
93    ws.on_upgrade(move |socket| handle_agent_socket(socket, state))
94}
95
96/// Send a JSON message over the WebSocket, logging errors.
97async fn send_ws(ws: &mut WebSocket, msg: &WsMessage) {
98    match serde_json::to_string(msg) {
99        Ok(json) => {
100            if let Err(e) = ws.send(Message::Text(json)).await {
101                error!("Failed to send WS message: {}", e);
102            }
103        }
104        Err(e) => error!("Failed to serialize WS message: {}", e),
105    }
106}
107
108/// Handle runtime WebSocket connection.
109///
110/// Processes incoming RuntimePatch, State, and Event messages from the runtime client.
111/// Forwards patches through the broadcast channel so hot-reload clients receive them.
112async fn handle_runtime_socket(mut ws: WebSocket, state: AppState) {
113    info!("Runtime WebSocket client connected");
114
115    // Send initial handshake
116    send_ws(
117        &mut ws,
118        &WsMessage::Handshake {
119            client: "runtime".to_string(),
120            capabilities: vec!["patch".into(), "state".into(), "event".into()],
121        },
122    )
123    .await;
124
125    while let Some(result) = ws.next().await {
126        match result {
127            Ok(Message::Text(text)) => {
128                match serde_json::from_str::<WsMessage>(&text) {
129                    Ok(WsMessage::Patch(patch)) => {
130                        info!(
131                            "Runtime patch received: {:?}",
132                            std::mem::discriminant(&patch)
133                        );
134                        // Forward patch to all hot-reload subscribers
135                        let _ = state.patch_tx.send(WsMessage::Patch(patch));
136                    }
137                    Ok(WsMessage::Event(event)) => {
138                        info!("Runtime event received: {:?}", event);
139                        let _ = state.patch_tx.send(WsMessage::Event(event));
140                    }
141                    Ok(WsMessage::State(_snapshot)) => {
142                        info!("Runtime state snapshot received");
143                    }
144                    Ok(other) => {
145                        warn!("Unexpected message type on runtime WS: {:?}", other);
146                    }
147                    Err(e) => {
148                        warn!("Failed to parse runtime message: {}", e);
149                    }
150                }
151            }
152            Ok(Message::Binary(bin)) => {
153                info!(
154                    "Received binary message of {} bytes on runtime WS",
155                    bin.len()
156                );
157            }
158            Ok(Message::Close(_)) => {
159                info!("Runtime WebSocket client disconnected");
160                break;
161            }
162            Err(e) => {
163                error!("Runtime WebSocket error: {}", e);
164                break;
165            }
166            _ => {}
167        }
168    }
169}
170
171/// Handle DevTools WebSocket connection.
172///
173/// Processes DevTools commands (QueryMetrics, ToggleOverlay, QueryGraph, Ping)
174/// and sends back appropriate responses.
175async fn handle_devtools_socket(mut ws: WebSocket, _state: AppState) {
176    info!("DevTools WebSocket client connected");
177
178    // Send initial handshake
179    send_ws(
180        &mut ws,
181        &WsMessage::Handshake {
182            client: "devtools".to_string(),
183            capabilities: vec!["metrics".into(), "overlay".into(), "graph".into()],
184        },
185    )
186    .await;
187
188    while let Some(result) = ws.next().await {
189        match result {
190            Ok(Message::Text(text)) => {
191                match serde_json::from_str::<DevtoolsCommand>(&text) {
192                    Ok(DevtoolsCommand::QueryMetrics) => {
193                        let metrics = crate::devtools::capture_metrics();
194                        let response = serde_json::json!({
195                            "type": "metrics",
196                            "fps": metrics.fps,
197                            "frame_time_ms": metrics.frame_time_ms,
198                            "node_count": metrics.node_count,
199                            "edge_count": metrics.edge_count,
200                            "gpu_memory_mb": metrics.gpu_memory_mb,
201                        });
202                        send_ws(
203                            &mut ws,
204                            &WsMessage::Devtools(DevtoolsMessage::Response(response)),
205                        )
206                        .await;
207                    }
208                    Ok(DevtoolsCommand::ToggleOverlay { show }) => {
209                        info!("DevTools overlay toggled: {}", show);
210                        let response = serde_json::json!({
211                            "type": "overlay_toggled",
212                            "show": show,
213                        });
214                        send_ws(
215                            &mut ws,
216                            &WsMessage::Devtools(DevtoolsMessage::Response(response)),
217                        )
218                        .await;
219                    }
220                    Ok(DevtoolsCommand::QueryGraph) => {
221                        // Return empty graph for now — populated by build pipeline
222                        let response = serde_json::json!({
223                            "type": "graph",
224                            "nodes": [],
225                            "edges": [],
226                        });
227                        send_ws(
228                            &mut ws,
229                            &WsMessage::Devtools(DevtoolsMessage::Response(response)),
230                        )
231                        .await;
232                    }
233                    Ok(DevtoolsCommand::QueryAccessibility { path }) => {
234                        // Query accessibility properties for the given component path.
235                        // In a real implementation, this would traverse the component tree
236                        // and return the AriaProperties for the matched component.
237                        let response = serde_json::json!({
238                            "type": "accessibility",
239                            "path": path,
240                            "properties": {
241                                "role": "button",
242                                "label": "Sample Button",
243                                "description": None::<String>,
244                                "disabled": false,
245                                "checked": None::<bool>,
246                                "expanded": None::<bool>,
247                                "hidden": false,
248                                "shortcut": None::<String>,
249                            },
250                        });
251                        send_ws(
252                            &mut ws,
253                            &WsMessage::Devtools(DevtoolsMessage::Response(response)),
254                        )
255                        .await;
256                    }
257                    Ok(DevtoolsCommand::Ping) => {
258                        let response = serde_json::json!({ "type": "pong" });
259                        send_ws(
260                            &mut ws,
261                            &WsMessage::Devtools(DevtoolsMessage::Response(response)),
262                        )
263                        .await;
264                    }
265                    Err(e) => {
266                        warn!("Failed to parse DevTools message: {}", e);
267                        let error = serde_json::json!({
268                            "type": "error",
269                            "message": format!("Invalid command: {}", e),
270                        });
271                        send_ws(
272                            &mut ws,
273                            &WsMessage::Devtools(DevtoolsMessage::Response(error)),
274                        )
275                        .await;
276                    }
277                }
278            }
279            Ok(Message::Close(_)) => {
280                info!("DevTools WebSocket client disconnected");
281                break;
282            }
283            Err(e) => {
284                error!("DevTools WebSocket error: {}", e);
285                break;
286            }
287            _ => {}
288        }
289    }
290}
291
292/// Handle hot reload WebSocket connection.
293///
294/// Broadcasts patches from the build pipeline to connected clients.
295async fn handle_hotreload_socket(mut ws: WebSocket, state: AppState) {
296    info!("Hot reload WebSocket client connected");
297
298    let mut patch_rx = state.patch_tx.subscribe();
299
300    // Send initial handshake
301    send_ws(
302        &mut ws,
303        &WsMessage::Handshake {
304            client: "hotreload".to_string(),
305            capabilities: vec!["patch".into()],
306        },
307    )
308    .await;
309
310    loop {
311        tokio::select! {
312            Ok(msg) = patch_rx.recv() => {
313                send_ws(&mut ws, &msg).await;
314            }
315            Some(result) = ws.next() => {
316                match result {
317                    Ok(Message::Close(_)) => {
318                        info!("Hot reload WebSocket client disconnected");
319                        break;
320                    }
321                    Err(e) => {
322                        error!("Hot reload WebSocket error: {}", e);
323                        break;
324                    }
325                    _ => {}
326                }
327            }
328        }
329    }
330}
331
332/// Handle agent stream WebSocket connection.
333///
334/// Receives AgentEvent messages and forwards them through the broadcast channel.
335async fn handle_agent_socket(mut ws: WebSocket, state: AppState) {
336    info!("Agent stream WebSocket client connected");
337
338    // Send initial handshake
339    send_ws(
340        &mut ws,
341        &WsMessage::Handshake {
342            client: "agent".to_string(),
343            capabilities: vec!["event".into()],
344        },
345    )
346    .await;
347
348    while let Some(result) = ws.next().await {
349        match result {
350            Ok(Message::Text(text)) => {
351                match serde_json::from_str::<crate::dev_runtime::AgentEvent>(&text) {
352                    Ok(event) => {
353                        let runtime_event = crate::dev_runtime::RuntimeEvent::Agent(event);
354                        let _ = state.patch_tx.send(WsMessage::Event(runtime_event));
355                    }
356                    Err(e) => {
357                        // Try parsing as a raw RuntimeEvent
358                        match serde_json::from_str::<crate::dev_runtime::RuntimeEvent>(&text) {
359                            Ok(event) => {
360                                let _ = state.patch_tx.send(WsMessage::Event(event));
361                            }
362                            Err(e2) => {
363                                warn!(
364                                    "Failed to parse agent message as AgentEvent ({}) or RuntimeEvent ({})",
365                                    e, e2
366                                );
367                            }
368                        }
369                    }
370                }
371            }
372            Ok(Message::Close(_)) => {
373                info!("Agent stream WebSocket client disconnected");
374                break;
375            }
376            Err(e) => {
377                error!("Agent stream WebSocket error: {}", e);
378                break;
379            }
380            _ => {}
381        }
382    }
383}
384
385/// Create the WebSocket router with all endpoints
386pub fn create_router(state: AppState) -> Router {
387    Router::new()
388        .route("/ws/runtime", get(runtime_ws))
389        .route("/ws/devtools", get(devtools_ws))
390        .route("/ws/hotreload", get(hotreload_ws))
391        .route("/ws/agent", get(agent_ws))
392        .route("/health", get(|| async { "OK" }))
393        .route("/", get(serve_shell))
394        .layer(tower_http::trace::TraceLayer::new_for_http())
395        .with_state(state)
396}
397
398/// Serve a minimal HTML shell that connects back via WebSocket.
399async fn serve_shell() -> impl IntoResponse {
400    axum::response::Html(
401        r#"<!DOCTYPE html>
402<html lang="en">
403<head>
404    <meta charset="UTF-8">
405    <meta name="viewport" content="width=device-width, initial-scale=1.0">
406    <title>CVKG Dev Server</title>
407    <style>
408        body { margin: 0; background: #0b0b14; color: #c0c0c8; font-family: 'JetBrains Mono', monospace; display: flex; align-items: center; justify-content: center; height: 100vh; }
409        .status { text-align: center; }
410        .status h1 { font-size: 24px; color: #00cccc; margin-bottom: 8px; }
411        .status p { font-size: 14px; color: #6a6a8a; }
412        .status .indicator { display: inline-block; width: 8px; height: 8px; border-radius: 50%; background: #4a8a4a; margin-right: 6px; }
413    </style>
414</head>
415<body>
416    <div class="status">
417        <h1>⚡ CVKG Dev Server</h1>
418        <p><span class="indicator"></span>Connected — WebSocket hot reload active</p>
419        <p style="margin-top: 16px; font-size: 12px;">Waiting for changes...</p>
420    </div>
421</body>
422</html>"#,
423    )
424}
425
/// Relative path (resolved against the process's working directory) of the JSON file in
/// which hot-reload state is persisted between reloads.
const HOT_RELOAD_STATE_PATH: &str = ".cvkg/hot_reload_state.json";
428
429/// Shared dashboard state, populated by the dev server and file watcher.
430pub type DashboardState = Arc<std::sync::Mutex<crate::devtools_dashboard::GraphState>>;
431
432/// Starts the file watcher and returns a broadcast sender for patches.
433pub fn start_file_watcher(
434    path: &str,
435    patch_engine: Arc<std::sync::Mutex<crate::patch_engine::PatchEngine>>,
436) -> broadcast::Sender<WsMessage> {
437    use crate::build_pipeline::BuildPipeline;
438
439    let (tx, _) = broadcast::channel(100);
440    let tx_clone = tx.clone();
441    let patch_engine = Arc::clone(&patch_engine);
442    // Ensure the .cvkg directory exists for state persistence
443    let _ = std::fs::create_dir_all(".cvkg");
444
445    BuildPipeline::watch_changes(path, move |artifact| {
446        // Update live metrics for the dashboard from the shared state
447        if let Some(ds) = crate::devtools_dashboard::dashboard_state() {
448            let guard = ds.lock().unwrap_or_else(|e| e.into_inner());
449            crate::devtools::update_metrics(crate::devtools::PerfMetrics {
450                frame_time_ms: guard.frame_time_ms,
451                fps: if guard.frame_time_ms > 0.0 {
452                    1000.0 / guard.frame_time_ms
453                } else {
454                    0.0
455                },
456                node_count: guard.nodes.len(),
457                edge_count: guard.edges.len(),
458                gpu_memory_mb: guard.gpu_memory_mb,
459            });
460        }
461
462        // Save hot-reload state before applying the patch
463        let state = crate::dev_runtime::HotReloadState {
464            theme_mode: "dark".to_string(),
465            window_size: (1200.0, 800.0),
466            scroll_positions: std::collections::HashMap::new(),
467            input_text: std::collections::HashMap::new(),
468            expanded_nodes: std::collections::HashMap::new(),
469            saved_at: std::time::SystemTime::now()
470                .duration_since(std::time::UNIX_EPOCH)
471                .unwrap_or_default()
472                .as_secs_f64(),
473        };
474        if let Err(e) = state.save(std::path::Path::new(HOT_RELOAD_STATE_PATH)) {
475            warn!("Failed to save hot-reload state: {}", e);
476        }
477
478        let mut engine = match patch_engine.lock() {
479            Ok(guard) => guard,
480            Err(poisoned) => poisoned.into_inner(),
481        };
482        let patch = engine.generate_patch(artifact);
483        let _ = tx_clone.send(WsMessage::Patch(patch));
484    });
485
486    // Attempt to load any previously saved state
487    if std::path::Path::new(HOT_RELOAD_STATE_PATH).exists() {
488        match crate::dev_runtime::HotReloadState::load(std::path::Path::new(HOT_RELOAD_STATE_PATH))
489        {
490            Ok(state) => {
491                info!(
492                    "Loaded hot-reload state from {} (theme: {}, saved_at: {})",
493                    HOT_RELOAD_STATE_PATH, state.theme_mode, state.saved_at
494                );
495            }
496            Err(e) => {
497                debug!("No previous hot-reload state found: {}", e);
498            }
499        }
500    }
501
502    tx
503}
504
505/// Start the WebSocket server with graceful shutdown.
506pub async fn start_server(addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
507    let patch_engine = Arc::new(std::sync::Mutex::new(PatchEngine::new()));
508    let patch_tx = start_file_watcher(".", Arc::clone(&patch_engine));
509
510    let state = AppState {
511        patch_tx: patch_tx.clone(),
512        patch_engine: Arc::clone(&patch_engine),
513    };
514
515    let app = create_router(state);
516    info!("Starting WebSocket server on {} (Ctrl+C to stop)", addr);
517
518    // Spawn animation tick task
519    let animation_handle = tokio::spawn(async move {
520        let mut interval = tokio::time::interval(std::time::Duration::from_millis(16)); // ~60fps
521        let mut solver =
522            cvkg_anim::SleipnirSolver::new(cvkg_anim::SleipnirParams::default(), 0.0, 0.0);
523        let mut physics_world =
524            cvkg_physics::PhysicsWorld::new(cvkg_physics::WorldConfig::default());
525        loop {
526            interval.tick().await;
527            let dt = 0.016;
528            // Tick the animation solver
529            let _value = solver.tick(dt);
530            // Tick the physics world
531            physics_world.step(dt);
532        }
533    });
534
535    let listener = tokio::net::TcpListener::bind(addr).await?;
536    axum::serve(listener, app)
537        .with_graceful_shutdown(shutdown_signal())
538        .await?;
539
540    animation_handle.abort();
541    info!("CVKG dev server shut down gracefully.");
542    Ok(())
543}
544
545/// Wait for Ctrl+C or SIGTERM.
546async fn shutdown_signal() {
547    let ctrl_c = async {
548        tokio::signal::ctrl_c()
549            .await
550            .expect("failed to install Ctrl+C handler");
551    };
552
553    #[cfg(unix)]
554    let terminate = async {
555        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
556            .expect("failed to install signal handler")
557            .recv()
558            .await;
559    };
560
561    #[cfg(not(unix))]
562    let terminate = std::future::pending::<()>();
563
564    tokio::select! {
565        _ = ctrl_c => {
566            info!("Ctrl+C received, shutting down gracefully...");
567        },
568        _ = terminate => {
569            info!("SIGTERM received, shutting down gracefully...");
570        },
571    }
572}