rs_web/
server.rs

1//! Development server with WebSocket live reload
2//!
3//! Provides a static file server with automatic reload when files change.
4
5use axum::{
6    Router,
7    body::Body,
8    extract::{
9        State, WebSocketUpgrade,
10        ws::{Message, WebSocket},
11    },
12    http::{Request, header},
13    response::{IntoResponse, Response},
14    routing::get,
15};
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::sync::Arc;
19use tokio::sync::broadcast;
20use tower_http::services::ServeDir;
21
/// Reload notification broadcast to connected live-reload clients.
///
/// Each variant is turned into a small JSON text frame before it is sent over
/// the WebSocket.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReloadMessage {
    /// Full page reload.
    Reload,
    /// CSS-only reload (hot reload); the string is forwarded to the client as
    /// the message's `path` field.
    CssReload(String),
}
30
31/// Server state shared across handlers
32pub struct ServerState {
33    /// Output directory to serve
34    pub output_dir: PathBuf,
35    /// Broadcast channel for reload notifications
36    pub reload_tx: broadcast::Sender<ReloadMessage>,
37}
38
/// Live reload JavaScript injected into HTML pages.
///
/// The script does the following:
/// - Opens a WebSocket to `/__rs_web_live_reload`, using `wss://` when the page
///   itself was loaded over HTTPS.
/// - Reloads the page on a `{"type":"reload"}` message.
/// - Cache-busts every `<link rel="stylesheet">` on a `{"type":"css"}` message.
/// - Reconnects with a growing delay (starting at 1s, capped at 5s) when the
///   socket closes.
/// - After reconnecting, polls the page with `HEAD` and reloads once that
///   request succeeds.
const LIVE_RELOAD_SCRIPT: &str = r#"
<script>
(function() {
    var reconnectInterval = 1000;
    var maxReconnectInterval = 5000;
    var reconnecting = false;
    var isConnecting = false;
    var reloadPending = false;
    // Match the page's scheme so the socket also works when served over HTTPS.
    var wsScheme = location.protocol === 'https:' ? 'wss://' : 'ws://';

    function connect() {
        if (isConnecting) return;
        isConnecting = true;

        var ws;
        try {
            ws = new WebSocket(wsScheme + location.host + '/__rs_web_live_reload');
        } catch (e) {
            isConnecting = false;
            scheduleReconnect();
            return;
        }

        ws.onopen = function() {
            console.log('[rs-web] Live reload connected');
            isConnecting = false;
            reconnectInterval = 1000;
            if (reconnecting && !reloadPending) {
                // Server is back - reload once the page itself is served again
                reloadPending = true;
                reloadWhenReady();
            }
        };

        ws.onmessage = function(event) {
            console.log('[rs-web] Received:', event.data);
            var msg;
            try {
                msg = JSON.parse(event.data);
            } catch (e) {
                console.warn('[rs-web] Ignoring malformed message:', e);
                return;
            }
            if (msg.type === 'reload') {
                console.log('[rs-web] Reloading page...');
                location.reload();
            } else if (msg.type === 'css') {
                // Hot reload CSS
                var links = document.querySelectorAll('link[rel="stylesheet"]');
                links.forEach(function(link) {
                    var href = link.getAttribute('href');
                    if (href) {
                        var url = new URL(href, location.href);
                        url.searchParams.set('_reload', Date.now());
                        link.setAttribute('href', url.toString());
                    }
                });
            }
        };

        ws.onclose = function() {
            isConnecting = false;
            if (!reconnecting) {
                console.log('[rs-web] Live reload disconnected');
            }
            reconnecting = true;
            scheduleReconnect();
        };

        ws.onerror = function() {
            // Let onclose handle reconnection
        };
    }

    // Poll the page until it is served successfully, then reload. Retrying the
    // HEAD request (instead of calling connect() again) avoids opening a second
    // WebSocket while the current one is still open.
    function reloadWhenReady() {
        fetch(location.href, { method: 'HEAD', cache: 'no-store' })
            .then(function(resp) {
                if (resp.ok) {
                    location.reload();
                } else {
                    setTimeout(reloadWhenReady, reconnectInterval);
                }
            })
            .catch(function() {
                setTimeout(reloadWhenReady, reconnectInterval);
            });
    }

    function scheduleReconnect() {
        setTimeout(function() {
            reconnectInterval = Math.min(reconnectInterval * 1.5, maxReconnectInterval);
            connect();
        }, reconnectInterval);
    }

    connect();
})();
</script>
"#;
126
127/// Create the server router
128pub fn create_router(state: Arc<ServerState>) -> Router {
129    // Static file serving with live reload injection
130    let serve_dir = ServeDir::new(&state.output_dir);
131
132    Router::new()
133        .route("/__rs_web_live_reload", get(websocket_handler))
134        .fallback_service(serve_dir)
135        .with_state(state)
136        .layer(axum::middleware::from_fn(inject_live_reload))
137}
138
139/// WebSocket handler for live reload connections
140async fn websocket_handler(
141    ws: WebSocketUpgrade,
142    State(state): State<Arc<ServerState>>,
143) -> impl IntoResponse {
144    ws.on_upgrade(|socket| handle_socket(socket, state))
145}
146
147/// Handle WebSocket connection
148async fn handle_socket(mut socket: WebSocket, state: Arc<ServerState>) {
149    let mut rx = state.reload_tx.subscribe();
150    log::debug!(
151        "[WS] Client connected. Total receivers: {}",
152        state.reload_tx.receiver_count()
153    );
154
155    loop {
156        tokio::select! {
157            biased;
158
159            // Handle incoming messages (ping/pong) - check this first
160            msg = socket.recv() => {
161                match msg {
162                    Some(Ok(Message::Ping(data))) => {
163                        if socket.send(Message::Pong(data)).await.is_err() {
164                            break;
165                        }
166                    }
167                    Some(Ok(Message::Pong(_))) => {}
168                    Some(Ok(Message::Close(_))) => break,
169                    Some(Ok(_)) => {}
170                    Some(Err(e)) => {
171                        log::debug!("[WS] Receive error: {}", e);
172                        break;
173                    }
174                    None => {
175                        log::debug!("[WS] Connection closed by client");
176                        break;
177                    }
178                }
179            }
180
181            // Receive reload notifications
182            result = rx.recv() => {
183                match result {
184                    Ok(msg) => {
185                        let json = match msg {
186                            ReloadMessage::Reload => r#"{"type":"reload"}"#.to_string(),
187                            ReloadMessage::CssReload(path) => {
188                                format!(r#"{{"type":"css","path":"{}"}}"#, path)
189                            }
190                        };
191                        log::debug!("[WS] Sending: {}", json);
192                        if socket.send(Message::Text(json.into())).await.is_err() {
193                            break;
194                        }
195                    }
196                    Err(e) => {
197                        log::debug!("[WS] Broadcast recv error: {}", e);
198                    }
199                }
200            }
201        }
202    }
203    log::debug!(
204        "[WS] Client disconnected. Remaining receivers: {}",
205        state.reload_tx.receiver_count()
206    );
207}
208
209/// Middleware to inject live reload script into HTML responses
210async fn inject_live_reload(request: Request<Body>, next: axum::middleware::Next) -> Response {
211    // Skip for WebSocket upgrade requests
212    if request.headers().contains_key(header::UPGRADE) {
213        return next.run(request).await;
214    }
215
216    let response = next.run(request).await;
217
218    // Check if response is HTML
219    let is_html = response
220        .headers()
221        .get(header::CONTENT_TYPE)
222        .and_then(|v| v.to_str().ok())
223        .map(|ct| ct.starts_with("text/html"))
224        .unwrap_or(false);
225
226    if !is_html {
227        return response;
228    }
229
230    // Extract body and inject script
231    let (mut parts, body) = response.into_parts();
232    let bytes = match axum::body::to_bytes(body, usize::MAX).await {
233        Ok(b) => b,
234        Err(_) => return Response::from_parts(parts, Body::empty()),
235    };
236
237    let html = String::from_utf8_lossy(&bytes);
238    let modified = if html.contains("</body>") {
239        html.replace("</body>", &format!("{}</body>", LIVE_RELOAD_SCRIPT))
240    } else if html.contains("</html>") {
241        html.replace("</html>", &format!("{}</html>", LIVE_RELOAD_SCRIPT))
242    } else {
243        format!("{}{}", html, LIVE_RELOAD_SCRIPT)
244    };
245
246    // Update Content-Length header to match new body
247    let new_len = modified.len();
248    parts.headers.remove(header::CONTENT_LENGTH);
249    parts.headers.insert(
250        header::CONTENT_LENGTH,
251        header::HeaderValue::from_str(&new_len.to_string()).unwrap(),
252    );
253
254    Response::from_parts(parts, Body::from(modified))
255}
256
/// Configuration for the development server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Preferred port to listen on; consecutive ports after it may be tried if
    /// it is already in use.
    pub port: u16,
    /// Host or interface address to bind to.
    pub host: String,
    /// Directory whose contents are served.
    pub output_dir: PathBuf,
}
263
264/// Try to bind to a port, returns the listener and actual port used
265async fn try_bind(
266    host: &str,
267    start_port: u16,
268    max_attempts: u16,
269) -> anyhow::Result<(tokio::net::TcpListener, u16)> {
270    for offset in 0..max_attempts {
271        let port = start_port + offset;
272        let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
273
274        match tokio::net::TcpListener::bind(addr).await {
275            Ok(listener) => {
276                if offset > 0 {
277                    rs_print!(
278                        "⚠ Port {} in use, using port {} instead (another rs-web may be running)",
279                        start_port,
280                        port
281                    );
282                }
283                return Ok((listener, port));
284            }
285            Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
286                continue;
287            }
288            Err(e) => {
289                return Err(e.into());
290            }
291        }
292    }
293
294    anyhow::bail!(
295        "Could not find available port (tried {} to {})",
296        start_port,
297        start_port + max_attempts - 1
298    )
299}
300
301/// Run the development server
302pub async fn run_server(config: ServerConfig) -> anyhow::Result<broadcast::Sender<ReloadMessage>> {
303    let (reload_tx, _) = broadcast::channel::<ReloadMessage>(16);
304
305    let state = Arc::new(ServerState {
306        output_dir: config.output_dir.clone(),
307        reload_tx: reload_tx.clone(),
308    });
309
310    let app = create_router(state);
311
312    // Try to find an available port (up to 10 attempts)
313    let (listener, actual_port) = try_bind(&config.host, config.port, 10).await?;
314
315    rs_print!(
316        "Development server running at http://{}:{}",
317        config.host,
318        actual_port
319    );
320    rs_print!("Serving: {}", config.output_dir.display());
321    rs_print!("Live reload: enabled");
322
323    tokio::spawn(async move {
324        axum::serve(listener, app).await.ok();
325    });
326
327    Ok(reload_tx)
328}
329
330/// Notify clients to reload
331pub fn notify_reload(tx: &broadcast::Sender<ReloadMessage>, message: ReloadMessage) {
332    let receivers = tx.receiver_count();
333    log::debug!("Sending reload to {} receivers", receivers);
334    match tx.send(message) {
335        Ok(n) => log::debug!("Sent to {} receivers", n),
336        Err(e) => log::debug!("No receivers for reload message: {}", e),
337    }
338}