use std::net::IpAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use axum::Router;
use axum::extract::ConnectInfo;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::response::sse::{Event, Sse};
use axum::routing::get;
use futures_core::Stream;
use notify::{RecursiveMode, Watcher};
use tokio::sync::broadcast;
/// JavaScript sent to the browser to reload the stylesheet.
///
/// It finds the `<link data-blixt-css>` element and points its `href` at
/// `/static/css/output.css` with a `?v=<timestamp>` query, forcing a re-fetch.
/// The script is a fixed constant; nothing is interpolated into it.
const CSS_RELOAD_SCRIPT: &str = concat!(
    "document.querySelector('link[data-blixt-css]')",
    ".href='/static/css/output.css?v='+Date.now()"
);
/// Quiet period after the first filesystem event before a reload is broadcast.
/// Events arriving within this window are coalesced into one reload.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(200);
/// Fans out "CSS file changed" signals to every subscribed HMR client.
///
/// The filesystem watcher is stored alongside the channel so that it lives
/// exactly as long as the broadcaster does.
pub struct CssHmrBroadcaster {
    /// Broadcast side of the reload channel; each client calls [`Self::subscribe`].
    sender: broadcast::Sender<()>,
    /// Held only to keep the watcher alive; never read.
    _watcher: notify::RecommendedWatcher,
}
impl CssHmrBroadcaster {
    /// Starts watching `css_path` and returns a shareable broadcaster.
    ///
    /// # Errors
    ///
    /// Propagates any error from setting up the file watcher.
    pub fn new(css_path: PathBuf) -> crate::error::Result<Arc<Self>> {
        // The initial receiver is discarded; clients subscribe on demand.
        let (sender, _) = broadcast::channel(16);
        start_watcher(css_path, sender.clone()).map(|watcher| {
            Arc::new(Self {
                sender,
                _watcher: watcher,
            })
        })
    }
    /// Returns a fresh receiver that observes every reload signal sent from now on.
    pub fn subscribe(&self) -> broadcast::Receiver<()> {
        self.sender.subscribe()
    }
}
/// Watches `css_path` and broadcasts a debounced reload signal on `sender`.
///
/// Filesystem events that pass [`is_modify_event`] go into a capacity-1 mpsc
/// channel, which a spawned task consumes. After the first event the task waits
/// [`DEBOUNCE_DURATION`], drains anything that queued meanwhile, and sends a
/// single `()` on `sender`.
///
/// # Errors
///
/// Returns `Error::Internal` if the watcher cannot be created or `css_path`
/// cannot be watched.
///
/// # Panics
///
/// Calls `tokio::spawn`, so it must run inside a Tokio runtime.
fn start_watcher(
    css_path: PathBuf,
    sender: broadcast::Sender<()>,
) -> crate::error::Result<notify::RecommendedWatcher> {
    let (notify_tx, mut notify_rx) = tokio::sync::mpsc::channel::<()>(1);
    // Debounce task. It ends once every `notify_tx` handle is gone, i.e. when
    // the watcher callback that owns it is dropped.
    tokio::spawn(async move {
        while let Some(()) = notify_rx.recv().await {
            tokio::time::sleep(DEBOUNCE_DURATION).await;
            // Coalesce the burst of events a single save tends to produce.
            while notify_rx.try_recv().is_ok() {}
            // An error only means nobody is subscribed right now; nothing to do.
            let _ = sender.send(());
            tracing::debug!("CSS change detected, broadcasting reload");
        }
    });
    let mut watcher = notify::recommended_watcher(move |res: notify::Result<notify::Event>| {
        if let Ok(event) = res
            && is_modify_event(&event.kind)
        {
            // `try_send` instead of `blocking_send`: if the capacity-1 channel is
            // full, a reload is already pending, so this signal is redundant.
            // This never stalls the watcher's callback thread behind the
            // debounce sleep, and it never blocks inside an async context.
            let _ = notify_tx.try_send(());
        }
    })
    .map_err(|err| crate::error::Error::Internal(format!("File watcher init failed: {err}")))?;
    watcher
        .watch(&css_path, RecursiveMode::NonRecursive)
        .map_err(|err| {
            crate::error::Error::Internal(format!("Failed to watch {}: {err}", css_path.display()))
        })?;
    tracing::info!(path = %css_path.display(), "CSS HMR watcher started");
    Ok(watcher)
}
/// Returns `true` for event kinds that should trigger a CSS reload.
///
/// Both `Modify` and `Create` count. Presumably `Create` is included because some
/// tools write the output file by (re)creating it rather than editing it in place
/// (TODO confirm). Every other kind, e.g. `Remove`, is ignored.
fn is_modify_event(kind: &notify::EventKind) -> bool {
    matches!(
        kind,
        notify::EventKind::Modify(_) | notify::EventKind::Create(_)
    )
}
/// Returns `true` if `addr` is a loopback address (`127.0.0.0/8` or `::1`).
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are first converted to their
/// IPv4 form. On a dual-stack `[::]` listener an IPv4 loopback client appears as
/// `::ffff:127.0.0.1`, which `std::net::Ipv6Addr::is_loopback` alone reports as
/// non-loopback, so local browsers would wrongly be refused.
fn is_loopback(addr: &IpAddr) -> bool {
    addr.to_canonical().is_loopback()
}
/// SSE endpoint that streams CSS reload events to the browser.
///
/// Peers whose address is not loopback get `403 Forbidden`. Everyone else
/// receives an SSE stream that emits one reload event per broadcast signal.
/// The `ConnectInfo` extractor requires the app to be served with
/// `into_make_service_with_connect_info::<SocketAddr>()`.
pub async fn hmr_handler(
    ConnectInfo(addr): ConnectInfo<std::net::SocketAddr>,
    state: axum::extract::State<Arc<CssHmrBroadcaster>>,
) -> impl IntoResponse {
    let peer_ip = addr.ip();
    if is_loopback(&peer_ip) {
        let events = HmrEventStream::new(state.subscribe());
        Sse::new(events).into_response()
    } else {
        StatusCode::FORBIDDEN.into_response()
    }
}
/// Builds the SSE event that carries [`CSS_RELOAD_SCRIPT`] to the client.
///
/// The event name is `datastar-execute-script` and the script text is the data payload.
/// NOTE(review): Datastar's `datastar-execute-script` protocol appears to expect
/// data lines prefixed with `script ` — confirm against the client actually
/// consuming this stream before relying on it.
fn reload_event() -> Event {
    let named = Event::default().event("datastar-execute-script");
    named.data(CSS_RELOAD_SCRIPT)
}
fn bridge_broadcast(mut broadcast_rx: broadcast::Receiver<()>) -> tokio::sync::mpsc::Receiver<()> {
let (mpsc_tx, mpsc_rx) = tokio::sync::mpsc::channel(16);
tokio::spawn(async move {
while let Ok(()) | Err(broadcast::error::RecvError::Lagged(_)) = broadcast_rx.recv().await {
if mpsc_tx.send(()).await.is_err() {
break;
}
}
});
mpsc_rx
}
/// SSE body for a single HMR client; each queued signal becomes a reload [`Event`].
struct HmrEventStream {
    /// Per-client queue fed by the task spawned in `bridge_broadcast`.
    receiver: tokio::sync::mpsc::Receiver<()>,
}
impl HmrEventStream {
    /// Wraps a broadcast subscription. This spawns a forwarding task, so it must
    /// be called inside a Tokio runtime.
    fn new(broadcast_rx: broadcast::Receiver<()>) -> Self {
        let receiver = bridge_broadcast(broadcast_rx);
        Self { receiver }
    }
}
impl Stream for HmrEventStream {
    type Item = Result<Event, std::convert::Infallible>;
    /// Yields one reload event per received signal. The stream ends once the
    /// forwarding side of the queue has been dropped.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.receiver
            .poll_recv(cx)
            .map(|signal| signal.map(|()| Ok(reload_event())))
    }
}
/// Returns a router that serves [`hmr_handler`] at `GET /__blixt_hmr`, with
/// `broadcaster` as its state.
///
/// The handler extracts `ConnectInfo`, so the final app must be served with
/// `into_make_service_with_connect_info::<SocketAddr>()`.
pub fn hmr_route(broadcaster: Arc<CssHmrBroadcaster>) -> Router {
    let routes: Router<Arc<CssHmrBroadcaster>> =
        Router::new().route("/__blixt_hmr", get(hmr_handler));
    routes.with_state(broadcaster)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};

    /// Broadcaster whose watcher observes nothing, so handler tests never
    /// touch the filesystem.
    fn null_broadcaster() -> Arc<CssHmrBroadcaster> {
        let (sender, _) = broadcast::channel(16);
        Arc::new(CssHmrBroadcaster {
            sender,
            _watcher: create_null_watcher(),
        })
    }

    #[test]
    fn loopback_ipv4_detected() {
        let addr = IpAddr::V4(Ipv4Addr::LOCALHOST);
        assert!(is_loopback(&addr));
    }
    #[test]
    fn loopback_ipv6_detected() {
        let addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
        assert!(is_loopback(&addr));
    }
    #[test]
    fn non_loopback_ipv4_rejected() {
        let addr = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        assert!(!is_loopback(&addr));
    }
    #[test]
    fn non_loopback_ipv6_rejected() {
        let addr = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));
        assert!(!is_loopback(&addr));
    }
    #[test]
    fn unspecified_ipv4_is_not_loopback() {
        let addr = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
        assert!(!is_loopback(&addr));
    }
    #[test]
    fn unspecified_ipv6_is_not_loopback() {
        let addr = IpAddr::V6(Ipv6Addr::UNSPECIFIED);
        assert!(!is_loopback(&addr));
    }
    #[test]
    fn css_reload_script_is_static_constant() {
        assert!(!CSS_RELOAD_SCRIPT.contains("${"));
        assert!(!CSS_RELOAD_SCRIPT.contains("eval("));
        assert!(!CSS_RELOAD_SCRIPT.contains("innerHTML"));
        assert!(!CSS_RELOAD_SCRIPT.contains("document.write"));
        assert!(!CSS_RELOAD_SCRIPT.contains("Function("));
        assert!(CSS_RELOAD_SCRIPT.contains("data-blixt-css"));
        assert!(CSS_RELOAD_SCRIPT.contains("output.css"));
    }
    #[test]
    fn css_reload_script_no_user_input_interpolation() {
        assert!(!CSS_RELOAD_SCRIPT.contains("fetch("));
        assert!(!CSS_RELOAD_SCRIPT.contains("XMLHttpRequest"));
        assert!(!CSS_RELOAD_SCRIPT.contains("import("));
    }
    /// End-to-end check over a real TCP connection, including connect-info wiring.
    #[tokio::test]
    async fn hmr_handler_returns_event_stream_content_type() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        let app = Router::new()
            .route("/__blixt_hmr", get(hmr_handler))
            .with_state(null_broadcaster());
        let app = app.into_make_service_with_connect_info::<SocketAddr>();
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let bound_addr = listener.local_addr().expect("local addr");
        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.expect("serve");
        });
        let mut stream = tokio::net::TcpStream::connect(bound_addr)
            .await
            .expect("connect");
        let request = "GET /__blixt_hmr HTTP/1.1\r\nHost: localhost\r\n\r\n";
        stream.write_all(request.as_bytes()).await.expect("write");
        let mut buf = vec![0u8; 4096];
        let n = stream.read(&mut buf).await.expect("read");
        let response = String::from_utf8_lossy(&buf[..n]);
        assert!(
            response.contains("text/event-stream"),
            "expected text/event-stream in response headers, got:\n{response}"
        );
        server.abort();
    }
    /// Calls the handler directly with a non-loopback peer and expects 403.
    #[tokio::test]
    async fn hmr_handler_rejects_non_loopback() {
        let external = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 40000);
        let response = hmr_handler(
            ConnectInfo(external),
            axum::extract::State(null_broadcaster()),
        )
        .await
        .into_response();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }
    /// Calls the handler directly with a loopback peer and expects an SSE response.
    #[tokio::test]
    async fn hmr_handler_accepts_loopback() {
        let local = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 40000);
        let response = hmr_handler(ConnectInfo(local), axum::extract::State(null_broadcaster()))
            .await
            .into_response();
        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok());
        assert_eq!(content_type, Some("text/event-stream"));
    }
    #[test]
    fn modify_event_detected() {
        let kind = notify::EventKind::Modify(notify::event::ModifyKind::Data(
            notify::event::DataChange::Content,
        ));
        assert!(is_modify_event(&kind));
    }
    #[test]
    fn create_event_detected() {
        let kind = notify::EventKind::Create(notify::event::CreateKind::File);
        assert!(is_modify_event(&kind));
    }
    #[test]
    fn remove_event_not_detected_as_modify() {
        let kind = notify::EventKind::Remove(notify::event::RemoveKind::File);
        assert!(!is_modify_event(&kind));
    }
    /// A watcher with a no-op callback that is never asked to watch anything.
    fn create_null_watcher() -> notify::RecommendedWatcher {
        notify::recommended_watcher(|_: notify::Result<notify::Event>| {}).expect("null watcher")
    }
}