rs_web/
server.rs

1//! Development server with WebSocket live reload
2//!
3//! Provides a static file server with automatic reload when files change.
4
5use axum::{
6    Router,
7    body::Body,
8    extract::{
9        State, WebSocketUpgrade,
10        ws::{Message, WebSocket},
11    },
12    http::{Request, header},
13    response::{IntoResponse, Response},
14    routing::get,
15};
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::sync::Arc;
19use tokio::sync::broadcast;
20use tower_http::services::ServeDir;
21
/// Reload message sent to connected clients.
///
/// Each live-reload WebSocket connection receives these values from the
/// shared broadcast channel and forwards them to the browser as JSON.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReloadMessage {
    /// Full page reload
    Reload,
    /// CSS-only reload (hot reload); carries the path reported to the client
    CssReload(String),
}
30
31/// Server state shared across handlers
32pub struct ServerState {
33    /// Output directory to serve
34    pub output_dir: PathBuf,
35    /// Broadcast channel for reload notifications
36    pub reload_tx: broadcast::Sender<ReloadMessage>,
37}
38
/// Live reload JavaScript injected into HTML pages.
///
/// The script opens a WebSocket to `/__rs_web_live_reload` on the page's own
/// host and reacts to JSON messages from the server:
///
/// * `{"type":"reload"}` reloads the whole page;
/// * `{"type":"css", ...}` re-requests every stylesheet `<link>` by bumping a
///   `_reload` query parameter, without reloading the page.
///
/// The WebSocket scheme follows the page's protocol (`wss://` for pages
/// served over HTTPS, `ws://` otherwise) so the connection is not blocked as
/// mixed content when the server sits behind a TLS-terminating proxy.
const LIVE_RELOAD_SCRIPT: &str = r#"
<script>
(function() {
    var scheme = location.protocol === 'https:' ? 'wss://' : 'ws://';
    var ws = new WebSocket(scheme + location.host + '/__rs_web_live_reload');
    ws.onmessage = function(event) {
        var msg = JSON.parse(event.data);
        if (msg.type === 'reload') {
            location.reload();
        } else if (msg.type === 'css') {
            // Hot reload CSS
            var links = document.querySelectorAll('link[rel="stylesheet"]');
            links.forEach(function(link) {
                var href = link.getAttribute('href');
                if (href) {
                    var url = new URL(href, location.href);
                    url.searchParams.set('_reload', Date.now());
                    link.setAttribute('href', url.toString());
                }
            });
        }
    };
    ws.onclose = function() {
        console.log('[rs-web] Live reload disconnected. Attempting reconnect...');
        setTimeout(function() { location.reload(); }, 1000);
    };
    ws.onerror = function() {
        console.log('[rs-web] Live reload connection error');
    };
})();
</script>
"#;
71
72/// Create the server router
73pub fn create_router(state: Arc<ServerState>) -> Router {
74    // Static file serving with live reload injection
75    let serve_dir = ServeDir::new(&state.output_dir);
76
77    Router::new()
78        .route("/__rs_web_live_reload", get(websocket_handler))
79        .fallback_service(serve_dir)
80        .with_state(state)
81        .layer(axum::middleware::from_fn(inject_live_reload))
82}
83
84/// WebSocket handler for live reload connections
85async fn websocket_handler(
86    ws: WebSocketUpgrade,
87    State(state): State<Arc<ServerState>>,
88) -> impl IntoResponse {
89    ws.on_upgrade(|socket| handle_socket(socket, state))
90}
91
92/// Handle WebSocket connection
93async fn handle_socket(mut socket: WebSocket, state: Arc<ServerState>) {
94    let mut rx = state.reload_tx.subscribe();
95
96    loop {
97        tokio::select! {
98            // Receive reload notifications
99            Ok(msg) = rx.recv() => {
100                let json = match msg {
101                    ReloadMessage::Reload => r#"{"type":"reload"}"#.to_string(),
102                    ReloadMessage::CssReload(path) => {
103                        format!(r#"{{"type":"css","path":"{}"}}"#, path)
104                    }
105                };
106                if socket.send(Message::Text(json.into())).await.is_err() {
107                    break;
108                }
109            }
110            // Handle incoming messages (ping/pong)
111            Some(Ok(msg)) = socket.recv() => {
112                match msg {
113                    Message::Ping(data) => {
114                        if socket.send(Message::Pong(data)).await.is_err() {
115                            break;
116                        }
117                    }
118                    Message::Close(_) => break,
119                    _ => {}
120                }
121            }
122            else => break,
123        }
124    }
125}
126
127/// Middleware to inject live reload script into HTML responses
128async fn inject_live_reload(request: Request<Body>, next: axum::middleware::Next) -> Response {
129    let response = next.run(request).await;
130
131    // Check if response is HTML
132    let is_html = response
133        .headers()
134        .get(header::CONTENT_TYPE)
135        .and_then(|v| v.to_str().ok())
136        .map(|ct| ct.starts_with("text/html"))
137        .unwrap_or(false);
138
139    if !is_html {
140        return response;
141    }
142
143    // Extract body and inject script
144    let (mut parts, body) = response.into_parts();
145    let bytes = match axum::body::to_bytes(body, usize::MAX).await {
146        Ok(b) => b,
147        Err(_) => return Response::from_parts(parts, Body::empty()),
148    };
149
150    let html = String::from_utf8_lossy(&bytes);
151    let modified = if html.contains("</body>") {
152        html.replace("</body>", &format!("{}</body>", LIVE_RELOAD_SCRIPT))
153    } else if html.contains("</html>") {
154        html.replace("</html>", &format!("{}</html>", LIVE_RELOAD_SCRIPT))
155    } else {
156        format!("{}{}", html, LIVE_RELOAD_SCRIPT)
157    };
158
159    // Update Content-Length header to match new body
160    let new_len = modified.len();
161    parts.headers.remove(header::CONTENT_LENGTH);
162    parts.headers.insert(
163        header::CONTENT_LENGTH,
164        header::HeaderValue::from_str(&new_len.to_string()).unwrap(),
165    );
166
167    Response::from_parts(parts, Body::from(modified))
168}
169
/// Server configuration
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// TCP port to listen on
    pub port: u16,
    /// Address to bind to; combined with `port` to form the listen address
    pub host: String,
    /// Directory whose contents are served
    pub output_dir: PathBuf,
}
176
177/// Run the development server
178pub async fn run_server(config: ServerConfig) -> anyhow::Result<broadcast::Sender<ReloadMessage>> {
179    let (reload_tx, _) = broadcast::channel::<ReloadMessage>(16);
180
181    let state = Arc::new(ServerState {
182        output_dir: config.output_dir.clone(),
183        reload_tx: reload_tx.clone(),
184    });
185
186    let app = create_router(state);
187
188    let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
189
190    println!(
191        "Development server running at http://{}:{}",
192        config.host, config.port
193    );
194    println!("Serving: {}", config.output_dir.display());
195    println!("Live reload: enabled");
196    println!();
197
198    let listener = tokio::net::TcpListener::bind(addr).await?;
199
200    tokio::spawn(async move {
201        axum::serve(listener, app).await.ok();
202    });
203
204    Ok(reload_tx)
205}
206
207/// Notify clients to reload
208pub fn notify_reload(tx: &broadcast::Sender<ReloadMessage>, message: ReloadMessage) {
209    let _ = tx.send(message);
210}