use std::sync::Arc;
use axum::extract::State;
use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use tokio::sync::broadcast;
use super::AppState;
use crate::html::insert_before_closing_tag;
/// Axum handler that upgrades the incoming HTTP request to a WebSocket.
///
/// The shared application state is moved into the upgrade callback so the
/// resulting connection can be served by [`handle_socket`].
pub(crate) async fn ws_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<AppState>>,
) -> Response {
    let app_state = state;
    ws.on_upgrade(move |socket| handle_socket(socket, app_state))
}
/// Serves one live-reload WebSocket connection.
///
/// Subscribes to `state.reload_tx` and sends a `"reload"` text frame to the
/// client for every broadcast signal, including after the receiver lagged.
/// The loop ends when:
/// - `state.shutdown` is notified;
/// - the broadcast channel is closed;
/// - sending to the client fails;
/// - the client stream ends or reports an error.
async fn handle_socket(mut socket: WebSocket, state: Arc<AppState>) {
    let mut rx = state.reload_tx.subscribe();
    // Create the `Notified` future up front so it is registered for the whole
    // lifetime of this connection, then pin it so it can be polled repeatedly.
    let shutdown = state.shutdown.notified();
    tokio::pin!(shutdown);
    loop {
        tokio::select! {
            // Poll branches in order so a pending shutdown always wins.
            biased;
            _ = &mut shutdown => break,
            result = rx.recv() => {
                match result {
                    // A lagged receiver skipped some signals; a single reload
                    // covers all of them.
                    Ok(()) | Err(broadcast::error::RecvError::Lagged(_)) => {
                        if socket.send(Message::Text("reload".into())).await.is_err() {
                            break;
                        }
                    }
                    Err(broadcast::error::RecvError::Closed) => break,
                }
            }
            msg = socket.recv() => {
                match msg {
                    // The client went away or the transport failed. Stop serving
                    // this socket instead of silently swallowing the error and
                    // relying on the stream to yield `None` afterwards.
                    None | Some(Err(_)) => break,
                    // Client frames carry no commands here; ignore them and keep
                    // polling until the stream ends.
                    Some(Ok(_)) => {}
                }
            }
        }
    }
}
/// Inline `<script>` injected into served HTML pages for live reload.
///
/// The script opens a WebSocket to `/ws` on the page's own host. It uses
/// `wss://` when the page was loaded over `https:` (browsers block plain
/// `ws://` from secure pages as mixed content) and `ws://` otherwise. It
/// reloads the page on any incoming message, and also reloads one second
/// after the socket closes.
pub const LIVE_RELOAD_SCRIPT: &str = r#"<script>(function(){var ws=new WebSocket((location.protocol==="https:"?"wss://":"ws://")+location.host+"/ws");ws.onmessage=function(){location.reload()};ws.onclose=function(){setTimeout(function(){location.reload()},1000)}})()</script>"#;
/// Returns a copy of `html` with [`LIVE_RELOAD_SCRIPT`] added.
///
/// The script is inserted before `</body>` via `insert_before_closing_tag`.
/// When that helper reports `false` (presumably because no `</body>` tag was
/// found; confirm against `crate::html`), the script is appended at the end.
/// The output buffer is sized up front for the page plus the script.
pub fn inject_live_reload(html: &str) -> String {
    let mut page = String::with_capacity(html.len() + LIVE_RELOAD_SCRIPT.len());
    page.push_str(html);
    let inserted = insert_before_closing_tag(&mut page, "</body>", LIVE_RELOAD_SCRIPT);
    if inserted {
        return page;
    }
    page.push_str(LIVE_RELOAD_SCRIPT);
    page
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The script lands right before `</body>` and the page content survives.
    #[test]
    fn injects_before_closing_body() {
        let page = inject_live_reload("<html><body><p>Hello</p></body></html>");
        let script_then_close = format!("{LIVE_RELOAD_SCRIPT}</body>");
        assert!(page.contains(script_then_close.as_str()));
        assert!(page.contains("<p>Hello</p>"));
    }

    /// Without a `</body>` tag the script is appended to the end.
    #[test]
    fn appends_when_no_body_tag() {
        let page = inject_live_reload("<p>No body tag</p>");
        assert!(page.ends_with(LIVE_RELOAD_SCRIPT));
    }

    /// The script opens a WebSocket and reloads the page.
    #[test]
    fn script_contains_websocket() {
        for needle in ["WebSocket", "location.reload()"] {
            assert!(LIVE_RELOAD_SCRIPT.contains(needle));
        }
    }
}