live_server/http_layer/
server.rs

1use axum::{
2    Router,
3    body::Body,
4    extract::{
5        Request, State, WebSocketUpgrade,
6        ws::{Message, WebSocket},
7    },
8    http::{HeaderMap, HeaderValue, StatusCode, header},
9    routing::get,
10};
11use futures::{sink::SinkExt, stream::StreamExt};
12use mime_guess::mime;
13use std::{
14    fs,
15    io::ErrorKind,
16    path::{Path, PathBuf},
17    sync::Arc,
18};
19use tokio::{net::TcpListener, sync::broadcast};
20
21use crate::{
22    http_layer::template::{error_html, index_html},
23    utils::is_ignored,
24};
25
26/// JS script containing a function that takes in the address and connects to the websocket.
27const WEBSOCKET_FUNCTION: &str = include_str!("../templates/websocket.js");
28
29/// JS script to inject to the HTML on reload so the client
30/// knows it's a successful reload.
31const RELOAD_PAYLOAD: &str = include_str!("../templates/reload.js");
32
33pub(crate) async fn serve(tcp_listener: TcpListener, router: Router) {
34    axum::serve(tcp_listener, router).await.unwrap();
35}
36
/// User-facing configuration of the live server.
///
/// `Options::default()` enables `index_listing` and disables both
/// `hard_reload` and `auto_ignore`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Options {
    /// Always hard reload the page instead of hot-reload
    pub hard_reload: bool,
    /// Show page list of the current URL if `index.html` does not exist
    pub index_listing: bool,
    /// Ignore hidden and ignored files
    pub auto_ignore: bool,
}
45
46pub(crate) struct AppState {
47    /// Always hard reload the page instead of hot-reload
48    pub(crate) hard_reload: bool,
49    /// Show page list of the current URL if `index.html` does not exist
50    pub(crate) index_listing: bool,
51    /// Ignore hidden and ignored files
52    pub(crate) auto_ignore: bool,
53    pub(crate) tx: Arc<broadcast::Sender<()>>,
54    pub(crate) root: PathBuf,
55}
56
57impl Default for Options {
58    fn default() -> Self {
59        Self {
60            hard_reload: false,
61            index_listing: true,
62            auto_ignore: false,
63        }
64    }
65}
66
67pub(crate) fn create_server(state: AppState) -> Router {
68    let tx = state.tx.clone();
69    Router::new()
70        .route("/", get(static_assets))
71        .route("/*path", get(static_assets))
72        .route(
73            "/live-server-ws",
74            get(|ws: WebSocketUpgrade| async move {
75                ws.on_failed_upgrade(|error| {
76                    log::error!("Failed to upgrade websocket: {error}");
77                })
78                .on_upgrade(|socket: WebSocket| on_websocket_upgrade(socket, tx))
79            }),
80        )
81        .with_state(Arc::new(state))
82}
83
84async fn on_websocket_upgrade(socket: WebSocket, tx: Arc<broadcast::Sender<()>>) {
85    let (mut sender, mut receiver) = socket.split();
86    let mut rx = tx.subscribe();
87    let mut send_task = tokio::spawn(async move {
88        while rx.recv().await.is_ok() {
89            sender.send(Message::Text(String::new())).await.unwrap();
90        }
91    });
92    let mut recv_task =
93        tokio::spawn(async move { while let Some(Ok(_)) = receiver.next().await {} });
94    tokio::select! {
95        _ = (&mut send_task) => recv_task.abort(),
96        _ = (&mut recv_task) => send_task.abort(),
97    };
98}
99
100fn get_index_listing(uri_path: &str, root: &Path, auto_ignore: bool) -> String {
101    let is_root = uri_path == "/";
102    let path = root.join(&uri_path[1..]);
103    let entries = fs::read_dir(path).unwrap();
104    let mut entry_names = entries
105        .into_iter()
106        .filter_map(|e| {
107            if let Ok(entry) = e {
108                if auto_ignore {
109                    match is_ignored(root, &entry.path()) {
110                        Ok(ignored) => {
111                            if ignored {
112                                return None;
113                            }
114                        }
115                        Err(err) => {
116                            log::error!("Failed to check ignore files: {err}");
117                            // Do nothing if we cannot know if it's an ignored entry
118                            return None;
119                        }
120                    }
121                }
122                let is_dir = entry.metadata().ok()?.is_dir();
123                let trailing = if is_dir { "/" } else { "" };
124                entry
125                    .file_name()
126                    .to_str()
127                    .map(|name| format!("{name}{trailing}"))
128            } else {
129                None
130            }
131        })
132        .collect::<Vec<String>>();
133    entry_names.sort();
134    if !is_root {
135        entry_names.insert(0, "..".to_string());
136    }
137    entry_names
138        .into_iter()
139        .map(|en| format!("<li><a href=\"{en}\">{en}</a></li>"))
140        .collect::<Vec<String>>()
141        .join("\n")
142}
143
144async fn static_assets(
145    state: State<Arc<AppState>>,
146    req: Request<Body>,
147) -> (StatusCode, HeaderMap, Body) {
148    let is_reload = req.uri().query().is_some_and(|x| x == "reload");
149
150    // Get the path and mime of the static file.
151    let uri_path = req.uri().path();
152    // Avoid [directory traversal attack](https://en.wikipedia.org/wiki/Directory_traversal_attack).
153    if uri_path.starts_with("//") {
154        let redirect = format!("/{}", uri_path.trim_start_matches("/"));
155        let mut headers = HeaderMap::new();
156        headers.append(header::LOCATION, HeaderValue::from_str(&redirect).unwrap());
157        return (StatusCode::TEMPORARY_REDIRECT, headers, Body::empty());
158    }
159    let mut path = state.root.join(&uri_path[1..]);
160    let is_accessing_dir = path.is_dir();
161    if is_accessing_dir {
162        if !uri_path.ends_with('/') {
163            // redirect so parent links work correctly
164            let redirect = format!("{uri_path}/");
165            let mut headers = HeaderMap::new();
166            headers.append(header::LOCATION, HeaderValue::from_str(&redirect).unwrap());
167            return (StatusCode::TEMPORARY_REDIRECT, headers, Body::empty());
168        }
169        path.push("index.html");
170    }
171    let mime = mime_guess::from_path(&path).first_or_text_plain();
172
173    let mut headers = HeaderMap::new();
174    let content_type = if mime.type_() == mime::TEXT {
175        format!("{}; charset=utf-8", mime.as_ref())
176    } else {
177        mime.as_ref().to_string()
178    };
179    headers.append(
180        header::CONTENT_TYPE,
181        HeaderValue::from_str(&content_type).unwrap(),
182    );
183
184    if state.auto_ignore {
185        match is_ignored(&state.root, &path) {
186            Ok(ignored) => {
187                if ignored {
188                    let err_msg =
189                        "Unable to access ignored or hidden file, because `--ignore` is enabled";
190                    let body = generate_error_body(err_msg, state.hard_reload, is_reload);
191
192                    return (StatusCode::FORBIDDEN, HeaderMap::new(), body);
193                }
194            }
195            Err(err) => {
196                let err_msg = format!("Failed to check ignore files: {err}");
197                let body = generate_error_body(&err_msg, state.hard_reload, is_reload);
198                log::error!("{err_msg}");
199
200                return (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new(), body);
201            }
202        }
203    }
204
205    // Read the file.
206    let mut file = match fs::read(&path) {
207        Ok(file) => file,
208        Err(err) => {
209            let status_code = match err.kind() {
210                ErrorKind::NotFound => {
211                    if state.index_listing && is_accessing_dir {
212                        let script = format_script(state.hard_reload, is_reload, false);
213                        let html = index_html(
214                            uri_path,
215                            &script,
216                            &get_index_listing(uri_path, &state.root, state.auto_ignore),
217                        );
218                        return (StatusCode::OK, headers, html);
219                    }
220                    StatusCode::NOT_FOUND
221                }
222                _ => StatusCode::INTERNAL_SERVER_ERROR,
223            };
224            match path.to_str() {
225                Some(path) => log::warn!("Failed to read \"{path}\": {err}"),
226                None => log::warn!("Failed to read file with invalid path: {err}"),
227            }
228            return (
229                status_code,
230                headers,
231                if mime == "text/html" {
232                    generate_error_body(&err.to_string(), state.hard_reload, is_reload)
233                } else {
234                    Body::from(err.to_string())
235                },
236            );
237        }
238    };
239
240    // Construct the response.
241    if mime == "text/html" {
242        let text = match String::from_utf8(file) {
243            Ok(text) => text,
244            Err(err) => {
245                log::error!("Failed to read {path:?} as utf-8: {err}");
246                let html = generate_error_body(&err.to_string(), state.hard_reload, is_reload);
247                return (StatusCode::INTERNAL_SERVER_ERROR, headers, html);
248            }
249        };
250        let script = format_script(state.hard_reload, is_reload, false);
251        file = format!("{text}{script}").into_bytes();
252    } else if state.hard_reload {
253        // allow client to cache assets for a smoother reload.
254        // client handles preloading to refresh cache before reloading.
255        headers.append(
256            header::CACHE_CONTROL,
257            HeaderValue::from_str("max-age=30").unwrap(),
258        );
259    }
260
261    (StatusCode::OK, headers, Body::from(file))
262}
263
264/// Inject the address into the websocket script and wrap it in a script tag
265fn format_script(hard_reload: bool, is_reload: bool, is_error: bool) -> String {
266    match (is_reload, is_error) {
267        // successful reload, inject the reload payload
268        (true, false) => format!("<script>{RELOAD_PAYLOAD}</script>"),
269        // failed reload, don't inject anything so the client polls again
270        (true, true) => String::new(),
271        // normal connection, inject the websocket client
272        _ => {
273            let hard = if hard_reload { "true" } else { "false" };
274            format!(r#"<script>{WEBSOCKET_FUNCTION}({hard})</script>"#)
275        }
276    }
277}
278
279fn generate_error_body(err_msg: &str, hard_reload: bool, is_reload: bool) -> Body {
280    let script = format_script(hard_reload, is_reload, true);
281    error_html(&script, err_msg)
282}