1use axum::{
5 extract::ws::{WebSocketUpgrade, WebSocket, Message},
6 response::IntoResponse, routing::get,
7 Router,
8};
9use futures_util::StreamExt;
10use std::net::SocketAddr;
11use tracing::{info, error};
12
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15use tokio::sync::broadcast;
16
17#[derive(Clone)]
19pub struct AppState {
20 pub patch_tx: broadcast::Sender<WsMessage>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub enum WsMessage {
26 Patch(super::patch_engine::RuntimePatch),
27 State(super::dev_runtime::RuntimeStateSnapshot),
28 Event(super::dev_runtime::RuntimeEvent),
29 Devtools(serde_json::Value),
30}
31
32async fn runtime_ws(axum::extract::State(_state): axum::extract::State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
34 ws.on_upgrade(handle_runtime_socket)
35}
36
37async fn devtools_ws(axum::extract::State(_state): axum::extract::State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
39 ws.on_upgrade(handle_devtools_socket)
40}
41
42async fn hotreload_ws(axum::extract::State(state): axum::extract::State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
44 ws.on_upgrade(move |socket| handle_hotreload_socket(socket, state))
45}
46
47async fn agent_ws(axum::extract::State(_state): axum::extract::State<AppState>, ws: WebSocketUpgrade) -> impl IntoResponse {
49 ws.on_upgrade(handle_agent_socket)
50}
51
52async fn handle_runtime_socket(mut ws: WebSocket) {
54 info!("Runtime WebSocket client connected");
55
56 let _ = ws
58 .send(Message::Text(
59 serde_json::to_string(&serde_json::json!({
60 "type": "handshake",
61 "payload": {
62 "client": "runtime",
63 "capabilities": ["patch", "state", "event"]
64 }
65 })).unwrap().into(),
66 ))
67 .await;
68
69 while let Some(result) = ws.next().await {
70 match result {
71 Ok(Message::Text(text)) => {
72 if let Ok(message) = serde_json::from_str::<serde_json::Value>(&text) {
74 info!("Received runtime message: {}", message);
76 }
77 }
78 Ok(Message::Binary(bin)) => {
79 info!("Received binary message of {} bytes", bin.len());
81 }
82 Ok(Message::Close(_)) => {
83 info!("Runtime WebSocket client disconnected");
84 break;
85 }
86 Err(e) => {
87 error!("WebSocket error: {}", e);
88 break;
89 }
90 _ => {}
91 }
92 }
93}
94
95async fn handle_devtools_socket(mut ws: WebSocket) {
97 info!("DevTools WebSocket client connected");
98
99 while let Some(result) = ws.next().await {
100 match result {
101 Ok(Message::Text(text)) => {
102 if let Ok(message) = serde_json::from_str::<serde_json::Value>(&text) {
104 info!("Received DevTools message: {}", message);
105 }
106 }
107 Ok(Message::Close(_)) => {
108 info!("DevTools WebSocket client disconnected");
109 break;
110 }
111 Err(e) => {
112 error!("DevTools WebSocket error: {}", e);
113 break;
114 }
115 _ => {}
116 }
117 }
118}
119
120async fn handle_hotreload_socket(mut ws: WebSocket, state: AppState) {
122 info!("Hot reload WebSocket client connected");
123
124 let mut patch_rx = state.patch_tx.subscribe();
125
126 loop {
127 tokio::select! {
128 Ok(msg) = patch_rx.recv() => {
130 if let Ok(serialized) = serde_json::to_string(&msg) {
131 if let Err(e) = ws.send(Message::Text(serialized.into())).await {
132 error!("Failed to send patch to client: {}", e);
133 break;
134 }
135 }
136 }
137 Some(result) = ws.next() => {
139 match result {
140 Ok(Message::Close(_)) => {
141 info!("Hot reload WebSocket client disconnected");
142 break;
143 }
144 Err(e) => {
145 error!("Hot reload WebSocket error: {}", e);
146 break;
147 }
148 _ => {}
149 }
150 }
151 }
152 }
153}
154
155async fn handle_agent_socket(mut ws: WebSocket) {
157 info!("Agent stream WebSocket client connected");
158
159 while let Some(result) = ws.next().await {
160 match result {
161 Ok(Message::Text(text)) => {
162 if let Ok(message) = serde_json::from_str::<serde_json::Value>(&text) {
164 info!("Received agent message: {}", message);
165 }
166 }
167 Ok(Message::Close(_)) => {
168 info!("Agent stream WebSocket client disconnected");
169 break;
170 }
171 Err(e) => {
172 error!("Agent stream WebSocket error: {}", e);
173 break;
174 }
175 _ => {}
176 }
177 }
178}
179
180pub fn create_router(state: AppState) -> Router {
182 Router::new()
183 .route("/ws/runtime", get(runtime_ws))
184 .route("/ws/devtools", get(devtools_ws))
185 .route("/ws/hotreload", get(hotreload_ws))
186 .route("/ws/agent", get(agent_ws))
187 .with_state(state)
188}
189
190pub async fn start_server(addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
192 let (tx, _) = broadcast::channel(100);
193 let state = AppState { patch_tx: tx.clone() };
194
195 let tx_clone = tx.clone();
197 let patch_engine = Arc::new(tokio::sync::Mutex::new(super::patch_engine::PatchEngine::new()));
198
199 super::build_pipeline::BuildPipeline::watch_changes(".", move |artifact| {
200 let mut engine = patch_engine.blocking_lock();
201 let patch = engine.generate_patch(artifact);
202 let _ = tx_clone.send(WsMessage::Patch(patch));
203 });
204
205 let app = create_router(state);
206 info!("Starting WebSocket server on {}", addr);
207
208 let listener = tokio::net::TcpListener::bind(addr).await?;
209 axum::serve(listener, app).await?;
210
211 Ok(())
212}