#![forbid(unsafe_code, unused_unsafe)]
#![warn(clippy::all, missing_docs, nonstandard_style, future_incompatible)]
#![allow(clippy::type_complexity)]
mod inject;
mod overlay;
pub mod predicate;
mod sse;
use std::{convert::Infallible, sync::Arc, time::Duration};
use http::{header, Request, Response, StatusCode};
use tokio::sync::Notify;
use tower::{Layer, Service};
use crate::{
inject::InjectService,
overlay::OverlayService,
predicate::{Always, ContentTypeStartsWith, Predicate},
sse::ReloadEventsBody,
};
/// Path prefix used for the live-reload endpoints when no custom prefix has
/// been configured on the layer; the event stream lives at
/// `{DEFAULT_PREFIX}/event-stream`.
const DEFAULT_PREFIX: &str = "/_tower-livereload";
/// Cloneable handle used to trigger a reload of connected clients.
///
/// Every clone shares the same [`Notify`]. The [`LiveReload`] service hands a
/// clone of that notifier to each event-stream body it creates.
#[derive(Debug, Clone)]
pub struct Reloader {
    sender: Arc<Notify>,
}

impl Reloader {
    /// Creates a handle backed by its own, freshly allocated notifier.
    pub fn new() -> Self {
        Self::default()
    }

    /// Signals a reload to the event streams waiting on this handle's notifier.
    ///
    /// This calls [`Notify::notify_waiters`]. According to tokio's
    /// documentation, that wakes only the tasks that are waiting at the moment
    /// of the call. No permit is stored for tasks that start waiting later.
    pub fn reload(&self) {
        let notifier = &self.sender;
        notifier.notify_waiters();
    }
}

impl Default for Reloader {
    fn default() -> Self {
        let sender = Arc::new(Notify::new());
        Reloader { sender }
    }
}
/// [`Layer`] that wraps a service in a [`LiveReload`] service.
///
/// The produced service does two things:
/// - It answers requests to `{prefix}/event-stream` with a `text/event-stream`
///   response.
/// - It gives a `<script>` tag pointing at that endpoint, together with the
///   configured request and response predicates, to an injection service
///   around the inner service.
///
/// Defaults:
/// - prefix: `/_tower-livereload`
/// - request predicate: [`Always`]
/// - response predicate: `ContentTypeStartsWith::new("text/html")`
/// - reload interval: one second
#[derive(Clone, Debug)]
pub struct LiveReloadLayer<ReqPred = Always, ResPred = ContentTypeStartsWith<&'static str>> {
    /// Overrides the default prefix when `Some`.
    custom_prefix: Option<String>,
    /// Handle whose notifier is shared with every service built by this layer.
    reloader: Reloader,
    req_predicate: ReqPred,
    res_predicate: ResPred,
    reload_interval: Duration,
}

impl LiveReloadLayer {
    /// Creates a layer with the default configuration: the default prefix,
    /// [`Always`] as the request predicate, a `text/html` content-type
    /// response predicate, a one-second reload interval, and a fresh
    /// [`Reloader`].
    pub fn new() -> Self {
        Self {
            custom_prefix: None,
            reloader: Reloader::new(),
            req_predicate: Always,
            res_predicate: ContentTypeStartsWith::new("text/html"),
            reload_interval: Duration::from_secs(1),
        }
    }
}

impl<ReqPred, ResPred> LiveReloadLayer<ReqPred, ResPred> {
    /// Sets the path prefix under which the event-stream endpoint is served.
    /// The endpoint becomes `{prefix}/event-stream`.
    ///
    /// Defaults to `/_tower-livereload` when not called.
    pub fn custom_prefix<P: Into<String>>(self, prefix: P) -> Self {
        Self {
            custom_prefix: Some(prefix.into()),
            ..self
        }
    }

    /// Replaces the request predicate with `predicate`. Defaults to [`Always`].
    ///
    /// The predicate is forwarded to the injection service inside
    /// [`LiveReload`]. It presumably decides which requests are eligible for
    /// script injection; see the `predicate` module.
    pub fn request_predicate<Body, P: Predicate<Request<Body>>>(
        self,
        predicate: P,
    ) -> LiveReloadLayer<P, ResPred> {
        LiveReloadLayer {
            custom_prefix: self.custom_prefix,
            reloader: self.reloader,
            req_predicate: predicate,
            res_predicate: self.res_predicate,
            reload_interval: self.reload_interval,
        }
    }

    /// Replaces the response predicate with `predicate`. Defaults to
    /// `ContentTypeStartsWith::new("text/html")`.
    ///
    /// Like the request predicate, it is forwarded to the injection service
    /// inside [`LiveReload`].
    pub fn response_predicate<Body, P: Predicate<Response<Body>>>(
        self,
        predicate: P,
    ) -> LiveReloadLayer<ReqPred, P> {
        LiveReloadLayer {
            custom_prefix: self.custom_prefix,
            reloader: self.reloader,
            req_predicate: self.req_predicate,
            res_predicate: predicate,
            reload_interval: self.reload_interval,
        }
    }

    /// Sets the interval passed to each event-stream body. Defaults to one
    /// second.
    ///
    /// NOTE(review): the exact meaning of this interval (for example, the
    /// client reconnect delay) is defined in the `sse` module and the bundled
    /// script, not here.
    pub fn reload_interval(self, interval: Duration) -> Self {
        Self {
            reload_interval: interval,
            ..self
        }
    }

    /// Returns a handle that shares its notifier with this layer.
    ///
    /// Calling [`Reloader::reload`] on it reaches the event streams of every
    /// service built by this layer.
    pub fn reloader(&self) -> Reloader {
        self.reloader.clone()
    }
}

impl Default for LiveReloadLayer {
    fn default() -> Self {
        Self::new()
    }
}

// The predicates only need to be `Clone`. The previous `Copy` bound rejected
// predicates that own heap data (e.g. `String`s or closures capturing owned
// values). Every `Copy` type is `Clone`, so existing users are unaffected.
impl<S, ReqPred: Clone, ResPred: Clone> Layer<S> for LiveReloadLayer<ReqPred, ResPred> {
    type Service = LiveReload<S, ReqPred, ResPred>;

    fn layer(&self, inner: S) -> Self::Service {
        // `LiveReload::new` only needs `AsRef<str>`, so borrow the configured
        // prefix rather than allocating a new `String` for every layered service.
        let prefix = self.custom_prefix.as_deref().unwrap_or(DEFAULT_PREFIX);
        LiveReload::new(
            inner,
            self.reloader.clone(),
            self.req_predicate.clone(),
            self.res_predicate.clone(),
            self.reload_interval,
            prefix,
        )
    }
}
/// Service stack held by [`LiveReload`]. It is an overlay that answers the
/// event-stream route, placed in front of the script-injecting service.
type InnerService<S, ReqPred, ResPred> =
    OverlayService<ReloadEventsBody, Infallible, InjectService<S, ReqPred, ResPred>>;

/// Service produced by [`LiveReloadLayer`].
///
/// Requests whose path equals `{prefix}/event-stream` get a `200 OK`
/// `text/event-stream` response backed by a reload event body. For any other
/// request, the overlay closure yields `None` and defers to an injection
/// service. That injection service wraps the inner service and was given the
/// reload `<script>` tag plus the configured predicates.
#[derive(Clone, Debug)]
pub struct LiveReload<S, ReqPred = Always, ResPred = ContentTypeStartsWith<&'static str>> {
    service: InnerService<S, ReqPred, ResPred>,
}

impl<S, ReqPred, ResPred> LiveReload<S, ReqPred, ResPred> {
    /// Builds the overlay and injection stack around `service`.
    ///
    /// `prefix` is joined with `/event-stream`. The result is used both as the
    /// route the overlay intercepts and as the `data-event-stream` attribute
    /// of the injected script, so the two always agree.
    fn new<P: AsRef<str>>(
        service: S,
        reloader: Reloader,
        req_predicate: ReqPred,
        res_predicate: ResPred,
        reload_interval: Duration,
        prefix: P,
    ) -> Self {
        let stream_path = format!("{}/event-stream", prefix.as_ref());
        let script_tag = format!(
            r#"<script data-event-stream="{path}">{code}</script>"#,
            code = include_str!("../assets/sse_reload.js"),
            path = stream_path,
        );
        let inject = InjectService::new(service, script_tag.into(), req_predicate, res_predicate);
        let overlay = OverlayService::new(inject, move |parts| {
            if parts.uri.path() != stream_path {
                return None;
            }
            let events = ReloadEventsBody::new(reloader.sender.clone(), reload_interval);
            let response = Response::builder()
                .status(StatusCode::OK)
                .header(header::CONTENT_TYPE, "text/event-stream")
                .body(events)
                // The status and header are static, valid values, so the
                // builder cannot fail here.
                .map_err(|_| unreachable!());
            Some(response)
        });
        Self { service: overlay }
    }
}

impl<ReqBody, ResBody, S, ReqPred, ResPred> Service<Request<ReqBody>>
    for LiveReload<S, ReqPred, ResPred>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ResBody: http_body::Body,
    ReqPred: Predicate<Request<ReqBody>>,
    ResPred: Predicate<Response<ResBody>>,
{
    type Response = <InnerService<S, ReqPred, ResPred> as Service<Request<ReqBody>>>::Response;
    type Error = <InnerService<S, ReqPred, ResPred> as Service<Request<ReqBody>>>::Error;
    type Future = <InnerService<S, ReqPred, ResPred> as Service<Request<ReqBody>>>::Future;

    /// Readiness is delegated unchanged to the wrapped overlay stack.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        <InnerService<S, ReqPred, ResPred> as Service<Request<ReqBody>>>::poll_ready(
            &mut self.service,
            cx,
        )
    }

    /// Hands the request to the overlay stack.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let inner = &mut self.service;
        inner.call(req)
    }
}