1use axum::{
6 Router,
7 body::Body,
8 extract::{
9 State, WebSocketUpgrade,
10 ws::{Message, WebSocket},
11 },
12 http::{Request, header},
13 response::{IntoResponse, Response},
14 routing::get,
15};
16use std::net::SocketAddr;
17use std::path::PathBuf;
18use std::sync::Arc;
19use tokio::sync::broadcast;
20use tower_http::services::ServeDir;
21
/// Notification pushed to connected browsers over the live-reload WebSocket.
#[derive(Clone, Debug)]
pub enum ReloadMessage {
    /// Ask the browser to reload the whole page.
    Reload,
    /// Ask the browser to refresh its stylesheets without a full page reload.
    /// The string is forwarded to the client as the message's `path` field.
    CssReload(String),
}
30
31pub struct ServerState {
33 pub output_dir: PathBuf,
35 pub reload_tx: broadcast::Sender<ReloadMessage>,
37}
38
/// Client-side snippet that `inject_live_reload` splices into HTML responses.
///
/// It opens a WebSocket to `/__rs_web_live_reload` and reacts to JSON messages:
/// `{"type":"reload"}` reloads the page, and `{"type":"css",...}` cache-busts every
/// `<link rel="stylesheet">` in place. When the socket closes, it waits until the
/// server answers HTTP requests again before reloading, so a restart does not
/// leave the tab on a browser error page.
const LIVE_RELOAD_SCRIPT: &str = r#"
<script>
(function() {
  // Follow the page's scheme so the socket is not blocked as mixed content
  // when the page itself was loaded over HTTPS (e.g. behind a TLS proxy).
  var scheme = location.protocol === 'https:' ? 'wss://' : 'ws://';
  var ws = new WebSocket(scheme + location.host + '/__rs_web_live_reload');
  ws.onmessage = function(event) {
    var msg = JSON.parse(event.data);
    if (msg.type === 'reload') {
      location.reload();
    } else if (msg.type === 'css') {
      // Hot reload CSS
      var links = document.querySelectorAll('link[rel="stylesheet"]');
      links.forEach(function(link) {
        var href = link.getAttribute('href');
        if (href) {
          var url = new URL(href, location.href);
          url.searchParams.set('_reload', Date.now());
          link.setAttribute('href', url.toString());
        }
      });
    }
  };
  ws.onclose = function() {
    console.log('[rs-web] Live reload disconnected. Attempting reconnect...');
    // Reloading while the server is still down would replace the page with a
    // browser error page that no longer contains this script, ending all
    // reconnect attempts. Poll until the server responds, then reload.
    function waitForServer() {
      fetch(location.href, { method: 'HEAD', cache: 'no-store' })
        .then(function() { location.reload(); })
        .catch(function() { setTimeout(waitForServer, 1000); });
    }
    setTimeout(waitForServer, 1000);
  };
  ws.onerror = function() {
    console.log('[rs-web] Live reload connection error');
  };
})();
</script>
"#;
71
72pub fn create_router(state: Arc<ServerState>) -> Router {
74 let serve_dir = ServeDir::new(&state.output_dir);
76
77 Router::new()
78 .route("/__rs_web_live_reload", get(websocket_handler))
79 .fallback_service(serve_dir)
80 .with_state(state)
81 .layer(axum::middleware::from_fn(inject_live_reload))
82}
83
84async fn websocket_handler(
86 ws: WebSocketUpgrade,
87 State(state): State<Arc<ServerState>>,
88) -> impl IntoResponse {
89 ws.on_upgrade(|socket| handle_socket(socket, state))
90}
91
92async fn handle_socket(mut socket: WebSocket, state: Arc<ServerState>) {
94 let mut rx = state.reload_tx.subscribe();
95
96 loop {
97 tokio::select! {
98 Ok(msg) = rx.recv() => {
100 let json = match msg {
101 ReloadMessage::Reload => r#"{"type":"reload"}"#.to_string(),
102 ReloadMessage::CssReload(path) => {
103 format!(r#"{{"type":"css","path":"{}"}}"#, path)
104 }
105 };
106 if socket.send(Message::Text(json.into())).await.is_err() {
107 break;
108 }
109 }
110 Some(Ok(msg)) = socket.recv() => {
112 match msg {
113 Message::Ping(data) => {
114 if socket.send(Message::Pong(data)).await.is_err() {
115 break;
116 }
117 }
118 Message::Close(_) => break,
119 _ => {}
120 }
121 }
122 else => break,
123 }
124 }
125}
126
127async fn inject_live_reload(request: Request<Body>, next: axum::middleware::Next) -> Response {
129 let response = next.run(request).await;
130
131 let is_html = response
133 .headers()
134 .get(header::CONTENT_TYPE)
135 .and_then(|v| v.to_str().ok())
136 .map(|ct| ct.starts_with("text/html"))
137 .unwrap_or(false);
138
139 if !is_html {
140 return response;
141 }
142
143 let (mut parts, body) = response.into_parts();
145 let bytes = match axum::body::to_bytes(body, usize::MAX).await {
146 Ok(b) => b,
147 Err(_) => return Response::from_parts(parts, Body::empty()),
148 };
149
150 let html = String::from_utf8_lossy(&bytes);
151 let modified = if html.contains("</body>") {
152 html.replace("</body>", &format!("{}</body>", LIVE_RELOAD_SCRIPT))
153 } else if html.contains("</html>") {
154 html.replace("</html>", &format!("{}</html>", LIVE_RELOAD_SCRIPT))
155 } else {
156 format!("{}{}", html, LIVE_RELOAD_SCRIPT)
157 };
158
159 let new_len = modified.len();
161 parts.headers.remove(header::CONTENT_LENGTH);
162 parts.headers.insert(
163 header::CONTENT_LENGTH,
164 header::HeaderValue::from_str(&new_len.to_string()).unwrap(),
165 );
166
167 Response::from_parts(parts, Body::from(modified))
168}
169
/// Settings for starting the development server with `run_server`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// TCP port to listen on.
    pub port: u16,
    /// Host/interface to bind, combined with `port` for the listen address.
    pub host: String,
    /// Directory whose files are served.
    pub output_dir: PathBuf,
}
176
177pub async fn run_server(config: ServerConfig) -> anyhow::Result<broadcast::Sender<ReloadMessage>> {
179 let (reload_tx, _) = broadcast::channel::<ReloadMessage>(16);
180
181 let state = Arc::new(ServerState {
182 output_dir: config.output_dir.clone(),
183 reload_tx: reload_tx.clone(),
184 });
185
186 let app = create_router(state);
187
188 let addr: SocketAddr = format!("{}:{}", config.host, config.port).parse()?;
189
190 println!(
191 "Development server running at http://{}:{}",
192 config.host, config.port
193 );
194 println!("Serving: {}", config.output_dir.display());
195 println!("Live reload: enabled");
196 println!();
197
198 let listener = tokio::net::TcpListener::bind(addr).await?;
199
200 tokio::spawn(async move {
201 axum::serve(listener, app).await.ok();
202 });
203
204 Ok(reload_tx)
205}
206
207pub fn notify_reload(tx: &broadcast::Sender<ReloadMessage>, message: ReloadMessage) {
209 let _ = tx.send(message);
210}