1use axum::{
6 Router,
7 body::Body,
8 extract::{
9 State, WebSocketUpgrade,
10 ws::{Message, WebSocket},
11 },
12 http::{Request, header},
13 response::{IntoResponse, Response},
14 routing::get,
15};
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::sync::Arc;
19use tokio::sync::broadcast;
20use tower_http::services::ServeDir;
21
/// A notification pushed to every connected live-reload client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReloadMessage {
    /// Ask clients to reload the whole page.
    Reload,
    /// Ask clients to re-fetch their stylesheets without a full page reload.
    ///
    /// The payload is sent to the client as the `path` field of the message;
    /// presumably the path of the changed stylesheet (TODO confirm with callers).
    CssReload(String),
}
30
31pub struct ServerState {
33 pub output_dir: PathBuf,
35 pub reload_tx: broadcast::Sender<ReloadMessage>,
37}
38
/// Client-side script injected into HTML responses to enable live reload.
///
/// It connects to `/__rs_web_live_reload` over a WebSocket (using `wss://` when the
/// page itself was loaded over HTTPS) and reacts to two JSON messages:
/// `{"type":"reload"}` reloads the page, and `{"type":"css", ...}` cache-busts every
/// `<link rel="stylesheet">`. On disconnect it reconnects with a capped backoff and,
/// once the server is reachable again, reloads the page. At most one socket is open
/// at any time.
const LIVE_RELOAD_SCRIPT: &str = r#"
<script>
(function() {
    var reconnectInterval = 1000;
    var maxReconnectInterval = 5000;
    var reconnecting = false;
    var isConnecting = false;

    function connect() {
        if (isConnecting) return;
        isConnecting = true;

        var scheme = location.protocol === 'https:' ? 'wss://' : 'ws://';
        var ws;
        try {
            ws = new WebSocket(scheme + location.host + '/__rs_web_live_reload');
        } catch (e) {
            isConnecting = false;
            scheduleReconnect();
            return;
        }

        ws.onopen = function() {
            console.log('[rs-web] Live reload connected');
            isConnecting = false;
            reconnectInterval = 1000;
            if (reconnecting) {
                // Server is back - verify page is ready then reload
                fetch(location.href, { method: 'HEAD', cache: 'no-store' })
                    .then(function(resp) {
                        if (resp.ok) {
                            location.reload();
                        } else {
                            // Closing lets onclose schedule the retry, so two sockets never coexist
                            ws.close();
                        }
                    })
                    .catch(function() {
                        ws.close();
                    });
            }
        };

        ws.onmessage = function(event) {
            console.log('[rs-web] Received:', event.data);
            var msg;
            try {
                msg = JSON.parse(event.data);
            } catch (e) {
                console.warn('[rs-web] Ignoring malformed message:', event.data);
                return;
            }
            if (msg.type === 'reload') {
                console.log('[rs-web] Reloading page...');
                location.reload();
            } else if (msg.type === 'css') {
                // Hot reload CSS
                var links = document.querySelectorAll('link[rel="stylesheet"]');
                links.forEach(function(link) {
                    var href = link.getAttribute('href');
                    if (href) {
                        var url = new URL(href, location.href);
                        url.searchParams.set('_reload', Date.now());
                        link.setAttribute('href', url.toString());
                    }
                });
            }
        };

        ws.onclose = function() {
            isConnecting = false;
            if (!reconnecting) {
                console.log('[rs-web] Live reload disconnected');
            }
            reconnecting = true;
            scheduleReconnect();
        };

        ws.onerror = function() {
            // Let onclose handle reconnection
        };
    }

    function scheduleReconnect() {
        setTimeout(function() {
            reconnectInterval = Math.min(reconnectInterval * 1.5, maxReconnectInterval);
            connect();
        }, reconnectInterval);
    }

    connect();
})();
</script>
"#;
126
127pub fn create_router(state: Arc<ServerState>) -> Router {
129 let serve_dir = ServeDir::new(&state.output_dir);
131
132 Router::new()
133 .route("/__rs_web_live_reload", get(websocket_handler))
134 .fallback_service(serve_dir)
135 .with_state(state)
136 .layer(axum::middleware::from_fn(inject_live_reload))
137}
138
139async fn websocket_handler(
141 ws: WebSocketUpgrade,
142 State(state): State<Arc<ServerState>>,
143) -> impl IntoResponse {
144 ws.on_upgrade(|socket| handle_socket(socket, state))
145}
146
147async fn handle_socket(mut socket: WebSocket, state: Arc<ServerState>) {
149 let mut rx = state.reload_tx.subscribe();
150 log::debug!(
151 "[WS] Client connected. Total receivers: {}",
152 state.reload_tx.receiver_count()
153 );
154
155 loop {
156 tokio::select! {
157 biased;
158
159 msg = socket.recv() => {
161 match msg {
162 Some(Ok(Message::Ping(data))) => {
163 if socket.send(Message::Pong(data)).await.is_err() {
164 break;
165 }
166 }
167 Some(Ok(Message::Pong(_))) => {}
168 Some(Ok(Message::Close(_))) => break,
169 Some(Ok(_)) => {}
170 Some(Err(e)) => {
171 log::debug!("[WS] Receive error: {}", e);
172 break;
173 }
174 None => {
175 log::debug!("[WS] Connection closed by client");
176 break;
177 }
178 }
179 }
180
181 result = rx.recv() => {
183 match result {
184 Ok(msg) => {
185 let json = match msg {
186 ReloadMessage::Reload => r#"{"type":"reload"}"#.to_string(),
187 ReloadMessage::CssReload(path) => {
188 format!(r#"{{"type":"css","path":"{}"}}"#, path)
189 }
190 };
191 log::debug!("[WS] Sending: {}", json);
192 if socket.send(Message::Text(json.into())).await.is_err() {
193 break;
194 }
195 }
196 Err(e) => {
197 log::debug!("[WS] Broadcast recv error: {}", e);
198 }
199 }
200 }
201 }
202 }
203 log::debug!(
204 "[WS] Client disconnected. Remaining receivers: {}",
205 state.reload_tx.receiver_count()
206 );
207}
208
209async fn inject_live_reload(request: Request<Body>, next: axum::middleware::Next) -> Response {
211 if request.headers().contains_key(header::UPGRADE) {
213 return next.run(request).await;
214 }
215
216 let response = next.run(request).await;
217
218 let is_html = response
220 .headers()
221 .get(header::CONTENT_TYPE)
222 .and_then(|v| v.to_str().ok())
223 .map(|ct| ct.starts_with("text/html"))
224 .unwrap_or(false);
225
226 if !is_html {
227 return response;
228 }
229
230 let (mut parts, body) = response.into_parts();
232 let bytes = match axum::body::to_bytes(body, usize::MAX).await {
233 Ok(b) => b,
234 Err(_) => return Response::from_parts(parts, Body::empty()),
235 };
236
237 let html = String::from_utf8_lossy(&bytes);
238 let modified = if html.contains("</body>") {
239 html.replace("</body>", &format!("{}</body>", LIVE_RELOAD_SCRIPT))
240 } else if html.contains("</html>") {
241 html.replace("</html>", &format!("{}</html>", LIVE_RELOAD_SCRIPT))
242 } else {
243 format!("{}{}", html, LIVE_RELOAD_SCRIPT)
244 };
245
246 let new_len = modified.len();
248 parts.headers.remove(header::CONTENT_LENGTH);
249 parts.headers.insert(
250 header::CONTENT_LENGTH,
251 header::HeaderValue::from_str(&new_len.to_string()).unwrap(),
252 );
253
254 Response::from_parts(parts, Body::from(modified))
255}
256
/// Settings for starting the development server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// First port to try. `run_server` falls back to later ports if this one is in use.
    pub port: u16,
    /// Host or address to bind to.
    pub host: String,
    /// Directory whose contents are served.
    pub output_dir: PathBuf,
}
263
264async fn try_bind(
266 host: &str,
267 start_port: u16,
268 max_attempts: u16,
269) -> anyhow::Result<(tokio::net::TcpListener, u16)> {
270 for offset in 0..max_attempts {
271 let port = start_port + offset;
272 let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
273
274 match tokio::net::TcpListener::bind(addr).await {
275 Ok(listener) => {
276 if offset > 0 {
277 rs_print!(
278 "⚠ Port {} in use, using port {} instead (another rs-web may be running)",
279 start_port,
280 port
281 );
282 }
283 return Ok((listener, port));
284 }
285 Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
286 continue;
287 }
288 Err(e) => {
289 return Err(e.into());
290 }
291 }
292 }
293
294 anyhow::bail!(
295 "Could not find available port (tried {} to {})",
296 start_port,
297 start_port + max_attempts - 1
298 )
299}
300
301pub async fn run_server(config: ServerConfig) -> anyhow::Result<broadcast::Sender<ReloadMessage>> {
303 let (reload_tx, _) = broadcast::channel::<ReloadMessage>(16);
304
305 let state = Arc::new(ServerState {
306 output_dir: config.output_dir.clone(),
307 reload_tx: reload_tx.clone(),
308 });
309
310 let app = create_router(state);
311
312 let (listener, actual_port) = try_bind(&config.host, config.port, 10).await?;
314
315 rs_print!(
316 "Development server running at http://{}:{}",
317 config.host,
318 actual_port
319 );
320 rs_print!("Serving: {}", config.output_dir.display());
321 rs_print!("Live reload: enabled");
322
323 tokio::spawn(async move {
324 axum::serve(listener, app).await.ok();
325 });
326
327 Ok(reload_tx)
328}
329
330pub fn notify_reload(tx: &broadcast::Sender<ReloadMessage>, message: ReloadMessage) {
332 let receivers = tx.receiver_count();
333 log::debug!("Sending reload to {} receivers", receivers);
334 match tx.send(message) {
335 Ok(n) => log::debug!("Sent to {} receivers", n),
336 Err(e) => log::debug!("No receivers for reload message: {}", e),
337 }
338}