1use axum::{
6 Router,
7 body::Body,
8 extract::{
9 State, WebSocketUpgrade,
10 ws::{Message, WebSocket},
11 },
12 http::{Request, header},
13 response::{IntoResponse, Response},
14 routing::get,
15};
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::sync::Arc;
19use tokio::sync::broadcast;
20use tower_http::services::ServeDir;
21
/// Notification broadcast to every connected live-reload client.
///
/// `PartialEq`/`Eq` are derived so messages can be compared directly
/// (e.g. in tests or when de-duplicating notifications).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReloadMessage {
    /// Ask the browser to reload the whole page.
    Reload,
    /// Ask the browser to refresh its stylesheets; the payload is sent to the
    /// client as the `path` field of the JSON message.
    CssReload(String),
}
30
31pub struct ServerState {
33 pub output_dir: PathBuf,
35 pub reload_tx: broadcast::Sender<ReloadMessage>,
37}
38
/// HTML snippet injected into served HTML pages to enable live reload.
///
/// The script connects to the `/__rs_web_live_reload` WebSocket endpoint and:
/// - reloads the page on `{"type":"reload"}`,
/// - cache-busts every stylesheet link on `{"type":"css", ...}`,
/// - reconnects with a growing delay (1s up to 5s) when the connection drops,
///   and reloads the page once the server answers a `HEAD` request again.
///
/// The WebSocket scheme follows the page's scheme (`wss://` on HTTPS) because
/// browsers refuse insecure `ws://` connections from secure pages.
const LIVE_RELOAD_SCRIPT: &str = r#"
<script>
(function() {
    var reconnectInterval = 1000;
    var maxReconnectInterval = 5000;
    var reconnecting = false;
    var isConnecting = false;

    function connect() {
        if (isConnecting) return;
        isConnecting = true;

        // Match the page's scheme: ws:// is blocked as mixed content on https pages.
        var scheme = location.protocol === 'https:' ? 'wss://' : 'ws://';
        var ws;
        try {
            ws = new WebSocket(scheme + location.host + '/__rs_web_live_reload');
        } catch (e) {
            isConnecting = false;
            scheduleReconnect();
            return;
        }

        ws.onopen = function() {
            console.log('[rs-web] Live reload connected');
            isConnecting = false;
            reconnectInterval = 1000;
            if (reconnecting) {
                // Server is back - verify page is ready then reload
                fetch(location.href, { method: 'HEAD', cache: 'no-store' })
                    .then(function(resp) {
                        if (resp.ok) {
                            location.reload();
                        } else {
                            // Page not ready yet: drop this socket so onclose
                            // schedules the retry instead of stacking a second
                            // connection on top of this one.
                            ws.close();
                        }
                    })
                    .catch(function() {
                        ws.close();
                    });
            }
        };

        ws.onmessage = function(event) {
            console.log('[rs-web] Received:', event.data);
            var msg;
            try {
                msg = JSON.parse(event.data);
            } catch (e) {
                console.warn('[rs-web] Ignoring malformed message');
                return;
            }
            if (msg.type === 'reload') {
                console.log('[rs-web] Reloading page...');
                location.reload();
            } else if (msg.type === 'css') {
                // Hot reload CSS
                var links = document.querySelectorAll('link[rel="stylesheet"]');
                links.forEach(function(link) {
                    var href = link.getAttribute('href');
                    if (href) {
                        var url = new URL(href, location.href);
                        url.searchParams.set('_reload', Date.now());
                        link.setAttribute('href', url.toString());
                    }
                });
            }
        };

        ws.onclose = function() {
            isConnecting = false;
            if (!reconnecting) {
                console.log('[rs-web] Live reload disconnected');
            }
            reconnecting = true;
            scheduleReconnect();
        };

        ws.onerror = function() {
            // Let onclose handle reconnection
        };
    }

    function scheduleReconnect() {
        setTimeout(function() {
            reconnectInterval = Math.min(reconnectInterval * 1.5, maxReconnectInterval);
            connect();
        }, reconnectInterval);
    }

    connect();
})();
</script>
"#;
126
127pub fn create_router(state: Arc<ServerState>) -> Router {
129 let serve_dir = ServeDir::new(&state.output_dir);
131
132 Router::new()
133 .route("/__rs_web_live_reload", get(websocket_handler))
134 .fallback_service(serve_dir)
135 .with_state(state)
136 .layer(axum::middleware::from_fn(inject_live_reload))
137}
138
139async fn websocket_handler(
141 ws: WebSocketUpgrade,
142 State(state): State<Arc<ServerState>>,
143) -> impl IntoResponse {
144 ws.on_upgrade(|socket| handle_socket(socket, state))
145}
146
147async fn handle_socket(mut socket: WebSocket, state: Arc<ServerState>) {
149 let mut rx = state.reload_tx.subscribe();
150
151 loop {
152 tokio::select! {
153 Ok(msg) = rx.recv() => {
155 let json = match msg {
156 ReloadMessage::Reload => r#"{"type":"reload"}"#.to_string(),
157 ReloadMessage::CssReload(path) => {
158 format!(r#"{{"type":"css","path":"{}"}}"#, path)
159 }
160 };
161 if socket.send(Message::Text(json.into())).await.is_err() {
162 break;
163 }
164 }
165 Some(Ok(msg)) = socket.recv() => {
167 match msg {
168 Message::Ping(data) => {
169 if socket.send(Message::Pong(data)).await.is_err() {
170 break;
171 }
172 }
173 Message::Close(_) => break,
174 _ => {}
175 }
176 }
177 else => break,
178 }
179 }
180}
181
182async fn inject_live_reload(request: Request<Body>, next: axum::middleware::Next) -> Response {
184 let response = next.run(request).await;
185
186 let is_html = response
188 .headers()
189 .get(header::CONTENT_TYPE)
190 .and_then(|v| v.to_str().ok())
191 .map(|ct| ct.starts_with("text/html"))
192 .unwrap_or(false);
193
194 if !is_html {
195 return response;
196 }
197
198 let (mut parts, body) = response.into_parts();
200 let bytes = match axum::body::to_bytes(body, usize::MAX).await {
201 Ok(b) => b,
202 Err(_) => return Response::from_parts(parts, Body::empty()),
203 };
204
205 let html = String::from_utf8_lossy(&bytes);
206 let modified = if html.contains("</body>") {
207 html.replace("</body>", &format!("{}</body>", LIVE_RELOAD_SCRIPT))
208 } else if html.contains("</html>") {
209 html.replace("</html>", &format!("{}</html>", LIVE_RELOAD_SCRIPT))
210 } else {
211 format!("{}{}", html, LIVE_RELOAD_SCRIPT)
212 };
213
214 let new_len = modified.len();
216 parts.headers.remove(header::CONTENT_LENGTH);
217 parts.headers.insert(
218 header::CONTENT_LENGTH,
219 header::HeaderValue::from_str(&new_len.to_string()).unwrap(),
220 );
221
222 Response::from_parts(parts, Body::from(modified))
223}
224
/// Settings for the development server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// TCP port to listen on.
    pub port: u16,
    /// Host or IP address to bind to.
    pub host: String,
    /// Directory whose contents are served.
    pub output_dir: PathBuf,
}
231
232pub async fn run_server(config: ServerConfig) -> anyhow::Result<broadcast::Sender<ReloadMessage>> {
234 let (reload_tx, _) = broadcast::channel::<ReloadMessage>(16);
235
236 let state = Arc::new(ServerState {
237 output_dir: config.output_dir.clone(),
238 reload_tx: reload_tx.clone(),
239 });
240
241 let app = create_router(state);
242
243 let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
244
245 println!(
246 "Development server running at http://{}:{}",
247 config.host, config.port
248 );
249 println!("Serving: {}", config.output_dir.display());
250 println!("Live reload: enabled");
251 println!();
252
253 let listener = tokio::net::TcpListener::bind(addr).await?;
254
255 tokio::spawn(async move {
256 axum::serve(listener, app).await.ok();
257 });
258
259 Ok(reload_tx)
260}
261
262pub fn notify_reload(tx: &broadcast::Sender<ReloadMessage>, message: ReloadMessage) {
264 let receivers = tx.receiver_count();
265 log::debug!("Sending reload message to {} connected clients", receivers);
266 match tx.send(message) {
267 Ok(n) => log::debug!("Reload message sent to {} receivers", n),
268 Err(e) => log::warn!("Failed to send reload message: {}", e),
269 }
270}